From e606fb6ef53aec893b8c64bba05edfe4df8ae56c Mon Sep 17 00:00:00 2001
From: jiegeaiai
Date: Fri, 27 Sep 2024 01:34:52 +0800
Subject: [PATCH] render image to ui

---
 .gitignore       |   1 -
 Human.py         |  29 ++++++-
 edge_tts_test.py |   2 +-
 infer.py         | 200 +++++++++++++++++++++++++++++++++++++++++++++++
 tts/Chunk2Mal.py |  18 +++-
 tts/TTSBase.py   |  52 +++++++-----
 ui.py            |   5 +-
 7 files changed, 273 insertions(+), 34 deletions(-)
 create mode 100644 infer.py

diff --git a/.gitignore b/.gitignore
index 56f0280..7803c36 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,5 +1,4 @@
 *.pkl
-*.jpg
 *.mp4
 *.pth
 *.pyc
diff --git a/Human.py b/Human.py
index 41f7d17..80709fc 100644
--- a/Human.py
+++ b/Human.py
@@ -1,4 +1,5 @@
 #encoding = utf8
+import io
 import logging
 
 import multiprocessing as mp
@@ -14,6 +15,7 @@ import pyaudio
 import audio
 import face_detection
 import utils
+from infer import Infer
 from models import Wav2Lip
 from tts.Chunk2Mal import Chunk2Mal
 import torch
@@ -160,7 +162,6 @@ def get_smoothened_boxes(boxes, T):
 def face_detect(images):
     detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
                                             flip_input=False, device=device)
-
     batch_size = 16
 
     while 1:
@@ -281,6 +282,16 @@
     return img_batch, mel_batch, frame_batch, coords_batch
 
 
+# Load audio data from a byte stream
+def load_audio_from_bytes(byte_data):
+    # wrap the raw bytes in a BytesIO object so they can be read like a file
+    with io.BytesIO(byte_data) as b:
+        wav = audio.load_wav(b, 16000)  # adjust the arguments to match the audio library in use
+    return wav
+
+# Assuming you already have the audio file's byte data
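+# (illustrative usage only, not exercised anywhere yet):
+#     with open('./audio/test.wav', 'rb') as f:
+#         wav = load_audio_from_bytes(f.read())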
+
+
 class Human:
     def __init__(self):
         self._fps = 25  # 40 ms per frame
@@ -294,12 +305,15 @@ class Human:
 
         self._chunk_2_mal = Chunk2Mal(self)
         self._tts = TTSBase(self)
+        self._infer = Infer(self)
 
         self.mel_chunks_queue_ = Queue()
         self.audio_chunks_queue_ = Queue()
         self._test_image_queue = Queue()
 
         self._thread = None
+        thread = threading.Thread(target=self.test)
+        thread.start()
         # self.test()
         # self.play_pcm()
 
@@ -326,7 +340,13 @@ class Human:
     #     p.terminate()
 
     def test(self):
-        wav = audio.load_wav(r'./audio/audio1.wav', 16000)
+        wav = audio.load_wav(r'./audio/test.wav', 16000)
+        # with open(r'./audio/test.wav', 'rb') as f:
+        #     byte_data = f.read()
+        #
+        # byte_data = byte_data[16:]
+        # inputs = np.concatenate(byte_data)  # [N * chunk]
+        # wav = load_audio_from_bytes(inputs)
         mel = audio.melspectrogram(wav)
         if np.isnan(mel.reshape(-1)).sum() > 0:
             raise ValueError(
                 'Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again')
@@ -432,7 +452,7 @@ class Human:
         self._chunk_2_mal.push_chunk(audio_chunk)
 
     def push_mel_chunks_queue(self, mel_chunk):
-        self.mel_chunks_queue_.put(mel_chunk)
+        self._infer.push(mel_chunk)
         # self.audio_chunks_queue_.put(audio_chunk)
 
     def push_feat_queue(self, mel_chunks):
@@ -443,6 +463,9 @@ class Human:
         print("push_audio_frames")
         self._output_queue.put((chunk, type_))
 
+    def push_render_image(self, image):
+        self._test_image_queue.put(image)
+
     def render(self):
         try:
             # img, aud = self._res_frame_queue.get(block=True, timeout=.3)
diff --git a/edge_tts_test.py b/edge_tts_test.py
index 9b72790..682b318 100644
--- a/edge_tts_test.py
+++ b/edge_tts_test.py
@@ -22,7 +22,7 @@
 
     # set up PyAudio
     audio = pyaudio.PyAudio()
-    stream = audio.open(format=pyaudio.paInt16, channels=1, rate=24000, output=True)
+    stream = audio.open(format=pyaudio.paInt16, channels=1, rate=16000, output=True)
 
     # async for chunk in communicate.stream():  # use the stream method
     #     if chunk['type'] == 'audio':  # make sure the chunk is a byte stream
diff --git a/infer.py b/infer.py
new file mode 100644
index 0000000..b18f972
--- /dev/null
+++ b/infer.py
@@ -0,0 +1,200 @@
+#encoding = utf8
+import queue
+from queue import Queue
+from threading import Thread, Event
+import logging
+
+import cv2
+import numpy as np
+import torch
+from tqdm import tqdm
+
+import face_detection
+import utils
+from models import Wav2Lip
+
+logger = logging.getLogger(__name__)
+
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+
+def read_images(img_list):
+    frames = []
+    print('reading images...')
+    for img_path in tqdm(img_list):
+        print(f'read image path: {img_path}')
+        frame = cv2.imread(img_path)
+        frames.append(frame)
+    return frames
+
+
+def _load(checkpoint_path):
+    if device == 'cuda':
+        checkpoint = torch.load(checkpoint_path)
+    else:
+        checkpoint = torch.load(checkpoint_path,
+                                map_location=lambda storage, loc: storage)
+    return checkpoint
+
+
+def load_model(path):
+    model = Wav2Lip()
+    print("Load checkpoint from: {}".format(path))
+    logging.info(f'Load checkpoint from {path}')
+    checkpoint = _load(path)
+    s = checkpoint["state_dict"]
+    new_s = {}
+    for k, v in s.items():
+        new_s[k.replace('module.', '')] = v
+    model.load_state_dict(new_s)
+    model = model.to(device)
+    return model.eval()
+
+
+def get_smoothened_boxes(boxes, T):
+    for i in range(len(boxes)):
+        if i + T > len(boxes):
+            window = boxes[len(boxes) - T:]
+        else:
+            window = boxes[i : i + T]
+        boxes[i] = np.mean(window, axis=0)
+    return boxes
+
+
+def face_detect(images):
+    detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
+                                            flip_input=False, device=device)
+    batch_size = 16
+
+    while 1:
+        predictions = []
+        try:
+            for i in range(0, len(images), batch_size):
+                predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
+        except RuntimeError:
+            if batch_size == 1:
+                raise RuntimeError(
+                    'Image too big to run face detection on GPU. Please use the --resize_factor argument')
+            batch_size //= 2
+            print('Recovering from OOM error; New batch size: {}'.format(batch_size))
+            continue
+        break
+
+    results = []
+    pady1, pady2, padx1, padx2 = [0, 10, 0, 0]
+    for rect, image in zip(predictions, images):
+        if rect is None:
+            cv2.imwrite('temp/faulty_frame.jpg', image)  # check this frame where the face was not detected.
+            raise ValueError('Face not detected! Ensure the video contains a face in all the frames.')
+
+        y1 = max(0, rect[1] - pady1)
+        y2 = min(image.shape[0], rect[3] + pady2)
+        x1 = max(0, rect[0] - padx1)
+        x2 = min(image.shape[1], rect[2] + padx2)
+
+        results.append([x1, y1, x2, y2])
+
+    boxes = np.array(results)
+    boxes = get_smoothened_boxes(boxes, T=5)
+    results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)]
+
+    del detector
+    return results
+
+
+img_size = 96
+wav2lip_batch_size = 128
+
+
+def datagen_signal(frame, mel, face_det_results):
+    img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
+
+    # for i, m in enumerate(mels):
+    idx = 0
+    frame_to_save = frame.copy()
+    face, coords = face_det_results[idx].copy()
+
+    face = cv2.resize(face, (img_size, img_size))
+    m = mel
+
+    img_batch.append(face)
+    mel_batch.append(m)
+    frame_batch.append(frame_to_save)
+    coords_batch.append(coords)
+
+    if len(img_batch) >= wav2lip_batch_size:
+        img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
+
+        img_masked = img_batch.copy()
+        img_masked[:, img_size // 2:] = 0
+        img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
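+        # Wav2Lip takes a six-channel input: the face with its lower half
+        # masked (the region the model in-paints from audio) stacked
+        # channel-wise with the unmasked reference face, scaled to [0, 1].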
+        mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
+
+        return img_batch, mel_batch, frame_batch, coords_batch
+
+    if len(img_batch) > 0:
+        img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
+        img_masked = img_batch.copy()
+        img_masked[:, img_size//2:] = 0
+
+        img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
+        mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
+
+        return img_batch, mel_batch, frame_batch, coords_batch
+
+
+class Infer:
+    def __init__(self, human):
+        self._human = human
+        self._queue = Queue()
+
+        self._exit_event = Event()
+        self._run_thread = Thread(target=self.__on_run)
+        self._exit_event.set()
+        self._run_thread.start()
+
+    def __on_run(self):
+        face_images_path = r'./face/'
+        face_images_path = utils.read_files_path(face_images_path)
+        face_list_cycle = read_images(face_images_path)
+        face_images_length = len(face_list_cycle)
+        logging.info(f'face images length: {face_images_length}')
+        print(f'face images length: {face_images_length}')
+
+        model = load_model(r'.\checkpoints\wav2lip.pth')
+        print("Model loaded")
+
+        # frame_h, frame_w = face_list_cycle[0].shape[:-1]
+        face_det_results = face_detect(face_list_cycle)
+
+        j = 0
+
+        while self._exit_event.is_set():
+            try:
+                m = self._queue.get(block=True, timeout=1)
+            except queue.Empty:
+                continue
+
+            img_batch, mel_batch, frames, coords = datagen_signal(face_list_cycle[0], m, face_det_results)
+
+            img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
+            mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
+
+            with torch.no_grad():
+                pred = model(mel_batch, img_batch)
+
+            pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
+            for p, f, c in zip(pred, frames, coords):
+                y1, y2, x1, x2 = c
+                p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1))
+
+                f[y1:y2, x1:x2] = p
+                # name = "%04d" % j
+                # cv2.imwrite(f'temp/images/{j}.jpg', p)
+                # j = j + 1
+                p = cv2.cvtColor(f, cv2.COLOR_BGR2RGB)
+                self._human.push_render_image(p)
+            # out.write(f)
+
+    def push(self, chunk):
+        self._queue.put(chunk)
\ No newline at end of file
diff --git a/tts/Chunk2Mal.py b/tts/Chunk2Mal.py
index 7388b3b..5dff798 100644
--- a/tts/Chunk2Mal.py
+++ b/tts/Chunk2Mal.py
@@ -2,6 +2,7 @@
 
 import logging
 import queue
+import time
 from queue import Queue
 from threading import Thread, Event
 
@@ -28,25 +29,28 @@ class Chunk2Mal:
     def _on_run(self):
         logging.info('chunk2mal run')
         while self._exit_event.is_set():
+            if self._audio_chunk_queue.empty():
+                time.sleep(0.5)
+                continue
             try:
-                chunk, type_ = self.pull_chunk()
+                chunk = self._audio_chunk_queue.get(block=True, timeout=1)
                 self._chunks.append(chunk)
-                self._human.push_audio_frames(chunk, type_)
+                self._human.push_audio_frames(chunk, 0)
+                if len(self._chunks) < 10:
+                    continue
             except queue.Empty:
                 # print('Chunk2Mal queue.Empty')
                 continue
 
-            if type_ == 0:
-                continue
 
             logging.info('np.concatenate')
-            mel = audio.melspectrogram(chunk)
+            inputs = np.concatenate(self._chunks)  # [N * chunk]
+            mel = audio.melspectrogram(inputs)
             if np.isnan(mel.reshape(-1)).sum() > 0:
                 raise ValueError(
                     'Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again')
 
             mel_step_size = 16
-            print('fps:', self._human.get_fps())
             mel_idx_multiplier = 80. / self._human.get_fps()
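+            # audio.melspectrogram yields roughly 80 mel frames per second of
+            # 16 kHz audio, so stepping the start index by 80 / fps keeps each
+            # 16-frame mel window aligned with one video frame (3.2 at 25 fps)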
             print('mel_idx_multiplier:', mel_idx_multiplier)
 
@@ -55,10 +59,8 @@ class Chunk2Mal:
             while 1:
                 start_idx = int(i * mel_idx_multiplier)
                 if start_idx + mel_step_size > len(mel[0]):
-                    # mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:])
                     self._human.push_mel_chunks_queue(mel[:, len(mel[0]) - mel_step_size:])
                     break
-                # mel_chunks.append(mel[:, start_idx: start_idx + mel_step_size])
                 self._human.push_mel_chunks_queue(mel[:, start_idx: start_idx + mel_step_size])
                 i += 1
 
diff --git a/tts/TTSBase.py b/tts/TTSBase.py
index 49f3b01..9c93c1a 100644
--- a/tts/TTSBase.py
+++ b/tts/TTSBase.py
@@ -7,6 +7,7 @@ import edge_tts
 import numpy as np
 import pyaudio
 import soundfile
+import sounddevice
 import resampy
 import queue
 from io import BytesIO
@@ -23,18 +24,16 @@ class TTSBase:
         self._human = human
         self._thread = None
         self._queue = Queue()
-        self._exit_event = None
         self._io_stream = BytesIO()
-        self._sample_rate = 16000
-        self._chunk_len = self._sample_rate // self._human.get_fps()
+        self._chunk_len = self._human.get_audio_sample_rate() // self._human.get_fps()
 
         self._exit_event = Event()
         self._thread = Thread(target=self._on_run)
         self._exit_event.set()
         self._thread.start()
-        self._pcm_player = pyaudio.PyAudio()
-        self._pcm_stream = self._pcm_player.open(format=pyaudio.paInt16,
-                                                 channels=1, rate=16000, output=True)
+        # self._pcm_player = pyaudio.PyAudio()
+        # self._pcm_stream = self._pcm_player.open(format=pyaudio.paInt16,
+        #                                          channels=1, rate=24000, output=True)
         logging.info('tts start')
 
     def _on_run(self):
@@ -56,16 +55,24 @@ class TTSBase:
         self._io_stream.seek(0)
         stream = self.__create_bytes_stream(self._io_stream)
         stream_len = stream.shape[0]
+        # try:
+        #     sounddevice.play(stream, samplerate=self._human.get_audio_sample_rate())
+        #     sounddevice.wait()  # wait until playback finishes
+        # except Exception as e:
+        #     logger.error(f"audio playback error: {e}")
         index = 0
         while stream_len >= self._chunk_len:
             audio_chunk = stream[index:index + self._chunk_len]
+            # sounddevice.play(audio_chunk, samplerate=self._human.get_audio_sample_rate())
             # self._pcm_stream.write(audio_chunk)
-            # self._pcm_stream.write(AudioSegment.from_mp3(audio_chunk))
+            # self._pcm_stream.write(audio_chunk.tobytes())
             # self._human.push_audio_chunk(audio_chunk)
             # self._human.push_mel_chunks_queue(audio_chunk)
             self._human.push_audio_chunk(audio_chunk)
             stream_len -= self._chunk_len
             index += self._chunk_len
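+        # reset and clear the BytesIO buffer so the next TTS reply does not
+        # replay audio that has already been chunked and pushed downstream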
+        self._io_stream.seek(0)
+        self._io_stream.truncate()
 
     def __create_bytes_stream(self, io_stream):
         stream, sample_rate = soundfile.read(io_stream)
@@ -76,29 +83,34 @@ class TTSBase:
             logger.warning(f'tts audio has {stream.shape[1]} channels, only use the first')
             stream = stream[:, 1]
 
-        if sample_rate != self._sample_rate and stream.shape[0] > 0:
-            logger.warning(f'tts audio sample rate is {sample_rate}, resample to {self._sample_rate}')
-            stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self._sample_rate)
+        if sample_rate != self._human.get_audio_sample_rate() and stream.shape[0] > 0:
+            logger.warning(f'tts audio sample rate is {sample_rate}, resample to {self._human.get_audio_sample_rate()}')
+            stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self._human.get_audio_sample_rate())
 
         return stream
 
     async def __on_request(self, voice, txt):
         communicate = edge_tts.Communicate(txt, voice)
 
         first = True
-        # total_data = b''
-        # CHUNK_SIZE = self._chunk_len
+        total_data = b''
+        CHUNK_SIZE = self._chunk_len
         async for chunk in communicate.stream():
             if chunk["type"] == "audio" and chunk["data"]:
-                self._io_stream.write(chunk['data'])
-            # total_data += chunk["data"]
-            # if len(total_data) >= CHUNK_SIZE:
+                data = chunk['data']
+                self._io_stream.write(data)
+            elif chunk["type"] == "WordBoundary":
+                pass
+            '''
+            total_data += chunk["data"]
+            if len(total_data) >= CHUNK_SIZE:
                 # print(f"Time elapsed: {time.time() - start_time:.2f} seconds")  # Print time
-                # audio_data = AudioSegment.from_mp3(BytesIO(total_data[:CHUNK_SIZE])) #.raw_data
-                # audio_data = audio_data.set_frame_rate(self._human.get_audio_sample_rate())
+                audio_data = AudioSegment.from_mp3(BytesIO(total_data[:CHUNK_SIZE]))  # .raw_data
+                audio_data = audio_data.set_frame_rate(self._human.get_audio_sample_rate())
                 # self._human.push_audio_chunk(audio_data)
-                # self._pcm_stream.write(audio_data.raw_data)
+                self._pcm_stream.write(audio_data.raw_data)
                 # play_audio(total_data[:CHUNK_SIZE], stream)  # Play first CHUNK_SIZE bytes
-                # total_data = total_data[CHUNK_SIZE:]  # Remove played data
+                total_data = total_data[CHUNK_SIZE:]  # Remove played data
+            '''
 
         # if first:
         #     first = False
@@ -106,10 +118,12 @@ class TTSBase:
         #     if chuck['type'] == 'audio':
        #         # self._io_stream.write(chuck['data'])
         #         self._io_stream.write(AudioSegment.from_mp3(BytesIO(total_data[:CHUNK_SIZE])).raw_data)
+
         # if len(total_data) > 0:
         #     self._pcm_stream.write(AudioSegment.from_mp3(BytesIO(total_data)).raw_data)
         #     audio_data = AudioSegment.from_mp3(BytesIO(total_data))  # .raw_data
         #     audio_data = audio_data.set_frame_rate(self._human.get_audio_sample_rate())
+
         #     self._pcm_stream.write(audio_data.raw_data)
         #     self._human.push_audio_chunk(audio_data)
         # self._io_stream.write(AudioSegment.from_mp3(BytesIO(total_data)).raw_data)
diff --git a/ui.py b/ui.py
index a4a80cc..9dd6895 100644
--- a/ui.py
+++ b/ui.py
@@ -63,10 +63,11 @@ class App(customtkinter.CTk):
         self._human.on_destroy()
 
     def play_audio(self):
+        # return
         if self._is_play_audio:
             return
         self._is_play_audio = True
-        file = os.path.curdir + '/audio/audio1.wav'
+        file = os.path.curdir + '/audio/test.wav'
         print(file)
         winsound.PlaySound(file, winsound.SND_ASYNC or winsound.SND_FILENAME)
         # playsound(file)
@@ -104,7 +105,7 @@ class App(customtkinter.CTk):
         height = self.winfo_height() * 0.5
         self._canvas.create_image(width, height, anchor=customtkinter.CENTER, image=imgtk)
         self._canvas.update()
-        self.after(34, self._render)
+        self.after(33, self._render)
 
     def request_tts(self):
         content = self.entry.get()