# encoding: utf-8
import os
import glob
import queue
import multiprocessing as mp
import time
from queue import Queue
from threading import Thread, Event
import logging

import cv2
import numpy as np
import torch
from tqdm import tqdm

import face_detection
import utils
from models import Wav2Lip
from utils import mirror_index

logger = logging.getLogger(__name__)

device = 'cuda' if torch.cuda.is_available() else 'cpu'


def read_images(img_list):
    """Load every image path in img_list into memory as BGR frames."""
    frames = []
    print('reading images...')
    for img_path in tqdm(img_list):
        print(f'read image path:{img_path}')
        frame = cv2.imread(img_path)
        frames.append(frame)
    return frames


def _load(checkpoint_path):
    if device == 'cuda':
        checkpoint = torch.load(checkpoint_path)
    else:
        checkpoint = torch.load(checkpoint_path,
                                map_location=lambda storage, loc: storage)
    return checkpoint


def load_model(path):
    """Build a Wav2Lip model and load a (possibly DataParallel) checkpoint."""
    model = Wav2Lip()
    print("Load checkpoint from: {}".format(path))
    logger.info(f'Load checkpoint from {path}')
    checkpoint = _load(path)
    s = checkpoint["state_dict"]
    new_s = {}
    for k, v in s.items():
        new_s[k.replace('module.', '')] = v
    model.load_state_dict(new_s)

    model = model.to(device)
    return model.eval()


def get_smoothened_boxes(boxes, T):
    """Smooth face boxes over a sliding window of T frames."""
    for i in range(len(boxes)):
        if i + T > len(boxes):
            window = boxes[len(boxes) - T:]
        else:
            window = boxes[i: i + T]
        boxes[i] = np.mean(window, axis=0)
    return boxes


def face_detect(images):
    detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
                                            flip_input=False, device=device)

    batch_size = 16
    while True:
        predictions = []
        try:
            for i in range(0, len(images), batch_size):
                predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
        except RuntimeError:
            if batch_size == 1:
                raise RuntimeError(
                    'Image too big to run face detection on GPU. Please use the --resize_factor argument')
            batch_size //= 2
            print('Recovering from OOM error; New batch size: {}'.format(batch_size))
            continue
        break

    results = []
    pady1, pady2, padx1, padx2 = [0, 10, 0, 0]
    for rect, image in zip(predictions, images):
        if rect is None:
            # Save this frame so the missing face can be inspected.
            cv2.imwrite('temp/faulty_frame.jpg', image)
            raise ValueError('Face not detected! Ensure the video contains a face in all the frames.')

        y1 = max(0, rect[1] - pady1)
        y2 = min(image.shape[0], rect[3] + pady2)
        x1 = max(0, rect[0] - padx1)
        x2 = min(image.shape[1], rect[2] + padx2)
        results.append([x1, y1, x2, y2])

    boxes = np.array(results)
    boxes = get_smoothened_boxes(boxes, T=5)
    results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)]
               for image, (x1, y1, x2, y2) in zip(images, boxes)]

    del detector
    return results


img_size = 96
wav2lip_batch_size = 128


def datagen(frames, mels):
    """Yield (img_batch, mel_batch, frame_batch, coords_batch) batches for Wav2Lip."""
    img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []

    face_det_results = face_detect(frames)  # BGR2RGB for CNN face detection

    i = 0
    for m in mels:
        # for i in range(mels.qsize()):
        idx = 0 if True else i % len(frames)
        frame_to_save = frames[mirror_index(1, i)].copy()
        face, coords = face_det_results[idx].copy()

        face = cv2.resize(face, (img_size, img_size))

        img_batch.append(face)
        mel_batch.append(m)
        frame_batch.append(frame_to_save)
        coords_batch.append(coords)

        if len(img_batch) >= wav2lip_batch_size:
            img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)

            # Mask the lower half of the face; the model receives the masked
            # face concatenated with the reference face along the channel axis.
            img_masked = img_batch.copy()
            img_masked[:, img_size // 2:] = 0

            img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
            mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

            yield img_batch, mel_batch, frame_batch, coords_batch
            img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []

        i = i + 1

    if len(img_batch) > 0:
        img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)

        img_masked = img_batch.copy()
        img_masked[:, img_size // 2:] = 0

        img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
        mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

        yield img_batch, mel_batch, frame_batch, coords_batch


def datagen_signal(frame, mel, face_det_results):
    """Build a single batch from one reference frame and a list of mel chunks."""
    img_batch, mel_batch, frame_batch, coord_batch = [], [], [], []

    # for i, m in enumerate(mels):
    idx = 0
    frame_to_save = frame.copy()
    face, coord = face_det_results[idx].copy()
    face = cv2.resize(face, (img_size, img_size))

    for i, m in enumerate(mel):
        img_batch.append(face)
        mel_batch.append(m)
        frame_batch.append(frame_to_save)
        coord_batch.append(coord)

        if len(img_batch) >= wav2lip_batch_size:
            img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)

            img_masked = img_batch.copy()
            img_masked[:, img_size // 2:] = 0

            img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
            mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

            return img_batch, mel_batch, frame_batch, coord_batch

    if len(img_batch) > 0:
        img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)

        img_masked = img_batch.copy()
        img_masked[:, img_size // 2:] = 0

        img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
        mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

    return img_batch, mel_batch, frame_batch, coord_batch


def inference(render_event, batch_size, face_imgs_path, audio_feat_queue, audio_out_queue, res_frame_queue):
    """Standalone inference loop, intended to be driven from a separate process."""
    model = load_model(r'.\checkpoints\wav2lip.pth')
    # face_list_cycle = read_images(face_imgs_path)
    input_face_list = glob.glob(os.path.join(face_imgs_path, '*.[jpJP][pnPN]*[gG]'))
    input_face_list = sorted(input_face_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
    face_list_cycle = read_images(input_face_list)

    # input_latent_list_cycle = torch.load(latents_out_path)
    length = len(face_list_cycle)
    index = 0
    count = 0
    counttime = 0
    print('start inference')
    while True:
        if render_event.is_set():
            starttime = time.perf_counter()
            mel_batch = []
            try:
                mel_batch = audio_feat_queue.get(block=True, timeout=1)
            except queue.Empty:
                continue

            is_all_silence = True
            audio_frames = []
            for _ in range(batch_size * 2):
                frame, type_ = audio_out_queue.get()
                audio_frames.append((frame, type_))
                if type_ == 0:
                    is_all_silence = False

            if is_all_silence:
                for i in range(batch_size):
                    res_frame_queue.put((None, mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]))
                    index = index + 1
            else:
                print('infer=======')
                t = time.perf_counter()
                img_batch = []
                for i in range(batch_size):
                    idx = mirror_index(length, index + i)
                    face = face_list_cycle[idx]
                    img_batch.append(face)

                img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)

                img_masked = img_batch.copy()
                img_masked[:, face.shape[0] // 2:] = 0

                img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
                mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

                img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
                mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)

                with torch.no_grad():
                    pred = model(mel_batch, img_batch)
                pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.

                counttime += (time.perf_counter() - t)
                count += batch_size
                # _totalframe += 1
                if count >= 100:
                    print(f"------actual avg infer fps:{count / counttime:.4f}")
                    count = 0
                    counttime = 0
                for i, res_frame in enumerate(pred):
                    # self.__pushmedia(res_frame,loop,audio_track,video_track)
                    res_frame_queue.put((res_frame, mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]))
                    index = index + 1
                # print('total batch time:', time.perf_counter() - starttime)
        else:
            time.sleep(1)
    print('musereal inference processor stop')


class Infer:
    def __init__(self, human):
        self._human = human
        self._feat_queue = Queue()
        self._audio_out_queue = Queue()

        self.batch_size = human.get_batch_size()

        # self.asr = human.chunk_2_mal
        # self.res_frame_queue = human.res_render_queue

        # face_images_path = r'./face/'
        # self.face_images_path = utils.read_files_path(face_images_path)
        self.avatar_id = 'wav2lip_avatar1'
        self.avatar_path = f"./data/{self.avatar_id}"
        self.full_imgs_path = f"{self.avatar_path}/full_imgs"
        self.face_images_path = f"{self.avatar_path}/face_imgs"
        self.coords_path = f"{self.avatar_path}/coords.pkl"

        # self.render_event = mp.Event()
        # mp.Process(target=inference, args=(self.render_event, self.batch_size, self.face_images_path,
        #                                    self.asr.feat_queue, self.asr.output_queue, self.res_frame_queue,
        #                                    )).start()
        # self.render_event.set()

        self._exit_event = Event()
        self._run_thread = Thread(target=self.__on_run)
        self._exit_event.set()
        self._run_thread.start()

    def __on_run(self):
        model = load_model(r'.\checkpoints\wav2lip.pth')
        print("Model loaded")

        face_list_cycle = self._human.get_face_list_cycle()
        # input_face_list = glob.glob(os.path.join(self.face_images_path, '*.[jpJP][pnPN]*[gG]'))
        # input_face_list = sorted(input_face_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
        # face_list_cycle = read_images(input_face_list)

        # self.__do_run1(face_list_cycle, model)
        self.__do_run2(face_list_cycle, model)

        # frame_h, frame_w = face_list_cycle[0].shape[:-1]

    def __do_run1(self, face_list_cycle, model):
        face_det_results = face_detect(face_list_cycle)

        j = 0
        count = 0
        while self._exit_event.is_set():
            try:
                m = self._feat_queue.get(block=True, timeout=1)
            except queue.Empty:
                continue

            img_batch, mel_batch, frames, coords = datagen_signal(face_list_cycle[0], m, face_det_results)

            img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
            mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)

            time.sleep(0.01)
            with torch.no_grad():
                pred = model(mel_batch, img_batch)
            pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
            for p, f, c in zip(pred, frames, coords):
                y1, y2, x1, x2 = c
                p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1))
                f[y1:y2, x1:x2] = p

                # name = "%04d" % j
                cv2.imwrite(f'temp/images/{j}.jpg', p)
                j = j + 1
                # count = count + 1

                p = cv2.cvtColor(f, cv2.COLOR_BGR2RGB)
                self._human.push_render_image(p)
                # out.write(f)
            # print('infer count:', count)

    def __do_run2(self, face_list_cycle, model):
        length = len(face_list_cycle)
        index = 0
        count = 0
        count_time = 0
        print('start inference')

        # face_images_path = r'./face/'
        # face_images_path = utils.read_files_path(face_images_path)
        # face_list_cycle1 = read_images(face_images_path)
        # face_det_results = face_detect(face_list_cycle1)

        while True:
            if self._exit_event.is_set():
                start_time = time.perf_counter()
                batch_size = self._human.get_batch_size()
                try:
                    mel_batch = self._feat_queue.get(block=True, timeout=1)
                except queue.Empty:
                    continue

                is_all_silence = True
                audio_frames = []
                for _ in range(batch_size * 2):
                    frame, type_ = self._audio_out_queue.get()
                    audio_frames.append((frame, type_))
                    if type_ == 0:
                        is_all_silence = False

                if is_all_silence:
                    for i in range(batch_size):
                        # res_frame_queue.put((None, mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]))
                        self._human.push_res_frame(None, mirror_index(length, index),
                                                   audio_frames[i * 2:i * 2 + 2])
                        index = index + 1
                else:
                    print('infer=======')
                    t = time.perf_counter()
                    img_batch = []
                    for i in range(batch_size):
                        idx = mirror_index(length, index + i)
                        face = face_list_cycle[idx]
                        img_batch.append(face)

                    # img_batch_1, mel_batch_1, frames, coords = datagen_signal(face_list_cycle1,
                    #                                                           mel_batch, face_det_results)
                    img_batch = np.asarray(img_batch)
                    mel_batch = np.asarray(mel_batch)

                    img_masked = img_batch.copy()
                    img_masked[:, face.shape[0] // 2:] = 0

                    img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
                    mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

                    img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
                    mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)

                    with torch.no_grad():
                        pred = model(mel_batch, img_batch)
                    # pred = model(mel_batch, img_batch) * 255.0
                    pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.

                    count_time += (time.perf_counter() - t)
                    count += batch_size
                    # _totalframe += 1
                    if count >= 100:
                        print(f"------actual avg infer fps:{count / count_time:.4f}")
                        count = 0
                        count_time = 0

                    for i, res_frame in enumerate(pred):
                        # self.__pushmedia(res_frame,loop,audio_track,video_track)
                        # res_frame_queue.put(
                        #     (res_frame, __mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]))
                        self._human.push_res_frame(res_frame, mirror_index(length, index),
                                                   audio_frames[i * 2:i * 2 + 2])
                        index = index + 1
                    # print('total batch time:', time.perf_counter() - start_time)
            else:
                time.sleep(1)
        print('musereal inference processor stop')

    def push(self, mel_chunks):
        self._feat_queue.put(mel_chunks)

    def push_out_queue(self, frame, type_):
        self._audio_out_queue.put((frame, type_))

    def get_out_put(self):
        return self._audio_out_queue.get()
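
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the pipeline above).
# It assumes the checkpoint at .\checkpoints\wav2lip.pth and the avatar face
# images under ./data/wav2lip_avatar1/face_imgs exist on disk. `_HumanStub`
# is a hypothetical stand-in that implements only the methods Infer actually
# calls; the (80, 16) mel-chunk shape and the 320-sample audio frame are
# assumptions about the upstream audio pipeline, not values defined here.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    class _HumanStub:
        def __init__(self, face_frames, batch_size=16):
            self._faces = face_frames
            self._batch_size = batch_size

        def get_batch_size(self):
            return self._batch_size

        def get_face_list_cycle(self):
            return self._faces

        def push_res_frame(self, res_frame, idx, audio_frames):
            # A real consumer would paste res_frame back into the full frame
            # at the stored face coordinates and hand it to the renderer.
            state = 'silence' if res_frame is None else 'speech'
            print(f'frame for face index {idx}: {state}')

    faces = read_images(sorted(glob.glob('./data/wav2lip_avatar1/face_imgs/*.png')))
    infer = Infer(_HumanStub(faces))

    # Feed one batch: batch_size mel chunks plus two audio frames per video
    # frame (type_ == 0 marks voiced audio; anything else counts as silence).
    mel_chunks = [np.zeros((80, 16), dtype=np.float32) for _ in range(infer.batch_size)]
    infer.push(mel_chunks)
    for _ in range(infer.batch_size * 2):
        infer.push_out_queue(np.zeros(320, dtype=np.int16), 0)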