# encoding: utf-8
import logging
import multiprocessing as mp
import queue
import time

import cv2
import numpy as np
import torch
from tqdm import tqdm

import utils
from models import Wav2Lip
from tts.Chunk2Mal import Chunk2Mal

device = 'cuda' if torch.cuda.is_available() else 'cpu'


def _load(checkpoint_path):
    if device == 'cuda':
        checkpoint = torch.load(checkpoint_path)
    else:
        checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
    return checkpoint


def load_model(path):
    model = Wav2Lip()
    print("Load checkpoint from: {}".format(path))
    logging.info(f'Load checkpoint from {path}')
    checkpoint = _load(path)
    s = checkpoint["state_dict"]
    # Strip the 'module.' prefix left over from DataParallel training.
    new_s = {}
    for k, v in s.items():
        new_s[k.replace('module.', '')] = v
    model.load_state_dict(new_s)
    model = model.to(device)
    return model.eval()


def read_images(img_list):
    frames = []
    print('reading images...')
    for img_path in tqdm(img_list):
        print(f'read image path: {img_path}')
        frame = cv2.imread(img_path)
        frames.append(frame)
    return frames


def __mirror_index(size, index):
    # Walk back and forth through the face image cycle so looping has no visible jump.
    turn = index // size
    res = index % size
    if turn % 2 == 0:
        return res
    else:
        return size - res - 1


# python.exe .\inference.py --checkpoint_path .\checkpoints\wav2lip.pth --face
# .\face\img00016.jpg --audio .\audio\audio1.wav
def inference(render_event, batch_size, face_images_path, audio_feat_queue, audio_out_queue,
              res_frame_queue):
    logging.info(f'Using {device} for inference.')
    print(f'Using {device} for inference.')
    print(f'face_images_path: {face_images_path}')

    model = load_model(r'.\checkpoints\wav2lip.pth')
    face_list_cycle = read_images(face_images_path)
    face_images_length = len(face_list_cycle)
    logging.info(f'face images length: {face_images_length}')
    print(f'face images length: {face_images_length}')

    length = len(face_list_cycle)
    index = 0
    count = 0
    count_time = 0
    logging.info('start inference')
    print(f'start inference: {render_event.is_set()}')
    while render_event.is_set():
        print('start inference')
        try:
            mel_batch = audio_feat_queue.get(block=True, timeout=1)
        except queue.Empty:
            continue

        audio_frames = []
        is_all_silence = True
        for _ in range(batch_size * 2):
            frame, type_ = audio_out_queue.get()
            audio_frames.append((frame, type_))
            if type_ == 0:
                is_all_silence = False

        print(f'is_all_silence {is_all_silence}')
        if is_all_silence:
            # No speech in this batch: skip the model and reuse the original face frames.
            for i in range(batch_size):
                res_frame_queue.put((None, __mirror_index(length, index), audio_frames[i*2:i*2+2]))
                index = index + 1
        else:
            t = time.perf_counter()
            image_batch = []
            for i in range(batch_size):
                idx = __mirror_index(length, index + i)
                face = face_list_cycle[idx]
                image_batch.append(face)
            image_batch, mel_batch = np.asarray(image_batch), np.asarray(mel_batch)

            # Wav2Lip expects the lower half of the face masked, concatenated with the
            # reference frame along the channel axis and normalised to [0, 1].
            image_masked = image_batch.copy()
            image_masked[:, face.shape[0]//2:] = 0
            image_batch = np.concatenate((image_masked, image_batch), axis=3) / 255.
            mel_batch = np.reshape(mel_batch,
                                   [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

            image_batch = torch.FloatTensor(np.transpose(image_batch, (0, 3, 1, 2))).to(device)
            mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)

            with torch.no_grad():
                pred = model(mel_batch, image_batch)

            pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
            count_time += (time.perf_counter() - t)
            count += batch_size
            if count >= 100:
                logging.info(f"------actual avg infer fps:{count/count_time:.4f}")
                count = 0
                count_time = 0

            for i, res_frame in enumerate(pred):
                res_frame_queue.put((res_frame, __mirror_index(length, index), audio_frames[i*2:i*2+2]))
                index = index + 1

    logging.info('finish inference')


class Human:
    def __init__(self):
        self._tts = None
        self._fps = 50  # 20 ms per frame
        self._batch_size = 16
        self._sample_rate = 16000
        self._chunk = self._sample_rate // self._fps  # 320 samples per chunk (20 ms * 16000 / 1000)
        self._chunk_2_mal = Chunk2Mal(self)
        self._stride_left_size = 10
        self._stride_right_size = 10
        self._feat_queue = mp.Queue(2)
        self._output_queue = mp.Queue()
        self._res_frame_queue = mp.Queue(self._batch_size * 2)

        face_images_path = r'./face/'
        self._face_image_paths = utils.read_files_path(face_images_path)
        print(self._face_image_paths)

        self.render_event = mp.Event()
        mp.Process(target=inference, args=(self.render_event, self._batch_size,
                                           self._face_image_paths,
                                           self._feat_queue, self._output_queue,
                                           self._res_frame_queue,
                                           )).start()
        self.render_event.set()

    def get_fps(self):
        return self._fps

    def get_batch_size(self):
        return self._batch_size

    def get_chunk(self):
        return self._chunk

    def get_stride_left_size(self):
        return self._stride_left_size

    def get_stride_right_size(self):
        return self._stride_right_size

    def on_destroy(self):
        self.render_event.clear()
        self._chunk_2_mal.stop()
        if self._tts is not None:
            self._tts.stop()
        logging.info('human destroy')

    def set_tts(self, tts):
        if self._tts == tts:
            return

        self._tts = tts
        self._tts.start()
        self._chunk_2_mal.start()

    def read(self, txt):
        if self._tts is None:
            logging.warning('tts is none')
            return

        self._tts.push_txt(txt)

    def push_audio_chunk(self, chunk):
        self._chunk_2_mal.push_chunk(chunk)

    def push_feat_queue(self, mel_chunks):
        logging.debug('push feat queue')
        self._feat_queue.put(mel_chunks)
        logging.debug('push feat queue done')

    def push_audio_frames(self, chunk, type_):
        self._output_queue.put((chunk, type_))

    def render(self):
        try:
            # The inference process puts (frame, frame_index, audio_frame_pair) tuples.
            res_frame, idx, audio_frames = self._res_frame_queue.get(block=True, timeout=.3)
        except queue.Empty:
            print('queue.Empty:')
            return None

        return res_frame

    # def pull_audio_chunk(self):
    #     try:
    #         chunk = self._audio_chunk_queue.get(block=True, timeout=1.0)
    #         type = 1
    #     except queue.Empty:
    #         chunk = np.zeros(self._chunk, dtype=np.float32)
    #         type = 0
    #     return chunk, type
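

# A minimal usage sketch, not part of the original pipeline: it assumes a TTS object
# exposing start()/stop()/push_txt() (the interface Human.set_tts()/read() relies on),
# and it simply polls render() for frames. The checkpoint and face paths above must exist.
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    human = Human()
    # tts = SomeTts(human)      # hypothetical TTS implementation, not provided here
    # human.set_tts(tts)
    # human.read('hello world')
    try:
        while True:
            frame = human.render()  # numpy image, or None when the queue is empty / frame is silent
            if frame is not None:
                cv2.imshow('human', frame.astype(np.uint8))
            if cv2.waitKey(20) & 0xFF == ord('q'):
                break
    finally:
        human.on_destroy()
        cv2.destroyAllWindows()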