From c9f8ff6541a58c303179afc0bd8c204f36fb1059 Mon Sep 17 00:00:00 2001
From: jiegeaiai
Date: Mon, 9 Sep 2024 08:23:04 +0800
Subject: [PATCH] Add chunk processing
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 Human.py         | 132 ++++++++++++++++++++++++++++++++++++++++++++++++
 tts/Chunk2Mal.py |  10 +++-
 ui.py            |  12 -----
 3 files changed, 140 insertions(+), 14 deletions(-)

diff --git a/Human.py b/Human.py
index ca78f9e..4f35eda 100644
--- a/Human.py
+++ b/Human.py
@@ -1,11 +1,131 @@
 #encoding = utf8
 import logging
 import multiprocessing as mp
+import queue
+import time
+import numpy as np
+
+from models import Wav2Lip
 from tts.Chunk2Mal import Chunk2Mal
+import torch
+import cv2
+from tqdm import tqdm
 
 
 logger = logging.getLogger(__name__)
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+print('Using {} for inference.'.format(device))
+
+
+def _load(checkpoint_path):
+    if device == 'cuda':
+        checkpoint = torch.load(checkpoint_path)
+    else:
+        checkpoint = torch.load(checkpoint_path,
+                                map_location=lambda storage, loc: storage)
+    return checkpoint
+
+
+def load_model(path):
+    model = Wav2Lip()
+    print("Load checkpoint from: {}".format(path))
+    checkpoint = _load(path)
+    s = checkpoint["state_dict"]
+    new_s = {}
+    for k, v in s.items():
+        new_s[k.replace('module.', '')] = v
+    model.load_state_dict(new_s)
+    model = model.to(device)
+    return model.eval()
+
+
+def read_images(img_list):
+    frames = []
+    print('reading images...')
+    for img_path in tqdm(img_list):
+        frame = cv2.imread(img_path)
+        frames.append(frame)
+    return frames
+
+
+def __mirror_index(size, index):
+    # size = len(self.coord_list_cycle)
+    turn = index // size
+    res = index % size
+    if turn % 2 == 0:
+        return res
+    else:
+        return size - res - 1
+
+
+# python.exe .\inference.py --checkpoint_path .\checkpoints\wav2lip.pth --face
+# .\face\img00016.jpg --audio .\audio\audio1.wav
+def inference(render_event, batch_size, face_images_path, audio_feat_queue, audio_out_queue, res_frame_queue):
+    model = load_model(r'.\checkpoints\wav2lip.pth')
+    face_list_cycle = read_images(face_images_path)
+    face_images_length = len(face_list_cycle)
+    logger.info(f'face images length: {face_images_length}')
+
+    length = len(face_list_cycle)
+    index = 0
+    count = 0
+    count_time = 0
+    logger.info('start inference')
+    while not render_event.is_set():
+        try:
+            mel_batch = audio_feat_queue.get(block=True, timeout=1)
+        except queue.Empty:
+            continue
+
+        audio_frames = []
+        is_all_silence = True
+        for _ in range(batch_size * 2):
+            frame, type = audio_out_queue.get()
+            audio_frames.append((frame, type))
+
+            if type == 0:
+                is_all_silence = False
+
+        if is_all_silence:
+            for i in range(batch_size):
+                res_frame_queue.put((None, __mirror_index(length, index), audio_frames[i*2:i*2+2]))
+                index = index + 1
+        else:
+            t = time.perf_counter()
+            image_batch = []
+            for i in range(batch_size):
+                idx = __mirror_index(length, index + i)
+                face = face_list_cycle[idx]
+                image_batch.append(face)
+            image_batch, mel_batch = np.asarray(image_batch), np.asarray(mel_batch)
+
+            image_masked = image_batch.copy()
+            image_masked[:, face.shape[0]//2:] = 0
+
+            image_batch = np.concatenate((image_masked, image_batch), axis=3) / 255.
+            mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
+
+            image_batch = torch.FloatTensor(np.transpose(image_batch, (0, 3, 1, 2))).to(device)
+            mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
+
+            with torch.no_grad():
+                pred = model(mel_batch, image_batch)
+            pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
+
+            count_time += (time.perf_counter() - t)
+            count += batch_size
+            if count >= 100:
+                logger.info(f"------actual avg infer fps:{count/count_time:.4f}")
+                count = 0
+                count_time = 0
+
+            for i, res_frame in enumerate(pred):
+                res_frame_queue.put((res_frame, __mirror_index(length, index), audio_frames[i*2 : i*2+2]))
+                index = index + 1
+
+    logger.info('finish inference')
+
 
 
 class Human:
     def __init__(self):
@@ -18,6 +138,14 @@ class Human:
         self._stride_left_size = 10
         self._stride_right_size = 10
         self._feat_queue = mp.Queue(2)
+        self._output_queue = mp.Queue()
+        self._res_frame_queue = mp.Queue(self._batch_size * 2)
+
+        self.face_images_path = r'.\face'
+        self.render_event = mp.Event()
+        mp.Process(target=inference, args=(self.render_event, self._batch_size, self.face_images_path,
+                                           self._feat_queue, self._output_queue, self._res_frame_queue,
+                                           )).start()
 
     def get_fps(self):
         return self._fps
@@ -35,6 +163,8 @@ class Human:
         return self._stride_right_size
 
     def on_destroy(self):
+        self.render_event.set()
+
         self._chunk_2_mal.stop()
 
         if self._tts is not None:
@@ -60,4 +190,6 @@ class Human:
         self._chunk_2_mal.push_chunk(chunk)
 
     def push_feat_queue(self, mel_chunks):
+        print("21")
         self._feat_queue.put(mel_chunks)
+        print("22")
diff --git a/tts/Chunk2Mal.py b/tts/Chunk2Mal.py
index 004d682..3d4e8a6 100644
--- a/tts/Chunk2Mal.py
+++ b/tts/Chunk2Mal.py
@@ -22,6 +22,7 @@ class Chunk2Mal:
             try:
                 chunk, type = self.pull_chunk()
                 self._chunks.append(chunk)
+                print("1")
             except queue.Empty:
                 continue
 
@@ -38,6 +39,7 @@ class Chunk2Mal:
             mel_chunks = []
             while i < (len(self._chunks) - self._human.get_stride_left_size()
                        - self._human.get_stride_right_size()) / 2:
+                print("14")
                 start_idx = int(left + i * mel_idx_multiplier)
                 # print(start_idx)
                 if start_idx + mel_step_size > len(mel[0]):
@@ -45,10 +47,13 @@
                 else:
                     mel_chunks.append(mel[:, start_idx: start_idx + mel_step_size])
                     i += 1
+                print("13")
 
             self._human.push_feat_queue(mel_chunks)
+            print("15")
 
             # discard the old part to save memory
             self._chunks = self._chunks[-(self._human.get_stride_left_size() + self._human.get_stride_right_size()):]
+            print("12")
 
         logging.info('chunk2mal exit')
@@ -65,7 +70,8 @@ class Chunk2Mal:
             return
 
         self._exit_event.set()
-        self._thread.join()
+        if self._thread.is_alive():
+            self._thread.join()
         logging.info('chunk2mal stop')
 
     def push_chunk(self, chunk):
@@ -73,7 +79,7 @@ class Chunk2Mal:
 
     def pull_chunk(self):
         try:
-            chunk = self._audio_chunk_queue.get(block=True, timeout=1.0)
+            chunk = self._audio_chunk_queue.get(block=True, timeout=1)
             type = 1
         except queue.Empty:
             chunk = np.zeros(self._human.get_chunk(), dtype=np.float32)
diff --git a/ui.py b/ui.py
index b41b4c0..503d018 100644
--- a/ui.py
+++ b/ui.py
@@ -120,18 +120,6 @@ def config_logging(file_name: str, console_level: int=logging.INFO, file_level:
 if __name__ == "__main__":
     # logging.basicConfig(filename='./logs/info.log', level=logging.INFO)
     config_logging('./logs/info.log', logging.INFO, logging.INFO)
-    # logger = logging.getLogger('manager')
-    # # output to console, level DEBUG
-    # console = logging.StreamHandler()
-    # console.setLevel(logging.DEBUG)
-    # logger.addHandler(console)
-    #
-    # # output to file, level INFO, rotated by file size
-    # filelog = logging.handlers.RotatingFileHandler(filename='./logs/info.log', level=logging.INFO,
-    #                                                maxBytes=1024 * 1024, backupCount=5)
-    # filelog.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
-    # logger.setLevel(logging.INFO)
-    # logger.addHandler(filelog)
     logger.info('------------start------------')
     app = App()
     app.mainloop()