#encoding = utf8
"""Wav2Lip inference: model loading, face-detection helpers and a background
inference worker (``Infer``)."""
import logging
import os
import queue
import time
from queue import Queue
from threading import Event, Thread

import cv2
import numpy as np
import torch
from tqdm import tqdm

import face_detection
from models import Wav2Lip
from utils import mirror_index

logger = logging.getLogger(__name__)

device = 'cuda' if torch.cuda.is_available() else 'cpu'


def read_images(img_list):
    """Read every path in *img_list* with ``cv2.imread`` and return the frames.

    The result is index-aligned with *img_list*; ``cv2.imread`` yields ``None``
    for paths it cannot read, and such entries are kept as-is.
    """
    frames = []
    print('reading images...')
    for img_path in tqdm(img_list):
        print(f'read image path:{img_path}')
        frames.append(cv2.imread(img_path))
    return frames


def _load(checkpoint_path):
    """Load a torch checkpoint; on CPU-only hosts keep tensors on CPU storage.

    NOTE: ``torch.load`` unpickles the file - only load trusted checkpoints.
    """
    if device == 'cuda':
        return torch.load(checkpoint_path)
    # Map every deserialized storage to the CPU storage it was loaded into.
    return torch.load(checkpoint_path, map_location=lambda storage, loc: storage)


def load_model(path):
    """Build a ``Wav2Lip`` model, load weights from *path*, return it in eval mode on ``device``."""
    model = Wav2Lip()
    print("Load checkpoint from: {}".format(path))
    logger.info('Load checkpoint from %s', path)
    checkpoint = _load(path)
    # Strip the 'module.' prefix that nn.DataParallel adds to parameter names.
    state_dict = {k.replace('module.', ''): v for k, v in checkpoint["state_dict"].items()}
    model.load_state_dict(state_dict)
    model = model.to(device)
    return model.eval()


def get_smoothened_boxes(boxes, T):
    """Smooth *boxes* in place with a forward-looking moving average and return it.

    Row ``i`` becomes the mean of ``boxes[i:i + T]``. Where fewer than ``T``
    rows remain, the last ``T`` rows (or all rows, if there are fewer than
    ``T``) are averaged instead. Because the update is in place, those tail
    windows include rows that were already smoothed. If *boxes* has an
    integer dtype the means are truncated on assignment.

    Args:
        boxes: numpy array with one box per row.
        T: window length.
    """
    n = len(boxes)
    for i in range(n):
        if i + T > n:
            # Clamp at 0: with fewer than T boxes a negative slice start would
            # wrap around and could exclude row i from its own average.
            window = boxes[max(0, n - T):]
        else:
            window = boxes[i:i + T]
        boxes[i] = np.mean(window, axis=0)
    return boxes


def face_detect(images):
    """Detect one face per image; return ``[crop, (y1, y2, x1, x2)]`` per image.

    Detection runs in batches of 16. On a RuntimeError (treated as GPU OOM)
    the batch size is halved and detection restarts from the first image.
    Boxes are padded by 10 px at the bottom, clipped to the image, and
    smoothed over a 5-box window before cropping.

    Raises:
        RuntimeError: if detection still fails with a batch size of 1.
        ValueError: if an image has no detected face; that image is first
            written to ``temp/faulty_frame.jpg``.
    """
    detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
                                            flip_input=False, device=device)

    batch_size = 16
    while True:
        predictions = []
        try:
            for i in range(0, len(images), batch_size):
                predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
        except RuntimeError as err:
            if batch_size == 1:
                raise RuntimeError(
                    'Image too big to run face detection on GPU. Please use the --resize_factor argument') from err
            batch_size //= 2
            print('Recovering from OOM error; New batch size: {}'.format(batch_size))
            continue
        break

    pady1, pady2, padx1, padx2 = 0, 10, 0, 0
    results = []
    for rect, image in zip(predictions, images):
        if rect is None:
            cv2.imwrite('temp/faulty_frame.jpg', image)  # check this frame where the face was not detected.
            raise ValueError('Face not detected! Ensure the video contains a face in all the frames.')

        y1 = max(0, rect[1] - pady1)
        y2 = min(image.shape[0], rect[3] + pady2)
        x1 = max(0, rect[0] - padx1)
        x2 = min(image.shape[1], rect[2] + padx2)
        results.append([x1, y1, x2, y2])

    boxes = get_smoothened_boxes(np.array(results), T=5)
    results = [[image[y1:y2, x1:x2], (y1, y2, x1, x2)]
               for image, (x1, y1, x2, y2) in zip(images, boxes)]

    del detector
    return results


# Side length (pixels) that face crops are resized to in datagen_signal.
img_size = 96
# Kept for compatibility; datagen_signal only ever builds single-sample batches.
wav2lip_batch_size = 128


def _prepare_batch(faces, mels):
    """Turn face crops and mel chunks into the numpy inputs the model expects.

    A copy of each face has rows from ``height // 2`` onward zeroed; it is
    concatenated with the unmasked face along axis 3 and scaled by 1/255.
    The mel array gains a trailing singleton axis.
    """
    img_batch = np.asarray(faces)
    mel_batch = np.asarray(mels)
    img_masked = img_batch.copy()
    img_masked[:, img_batch.shape[1] // 2:] = 0
    img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
    mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
    return img_batch, mel_batch


def _to_tensors(img_batch, mel_batch):
    """Move axis 3 to axis 1 for both 4-D arrays; return float tensors on ``device``."""
    img_tensor = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
    mel_tensor = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
    return img_tensor, mel_tensor


def datagen_signal(frame, mel, face_det_results):
    """Build a single-sample model batch from *frame* and one mel chunk.

    Uses the first entry of *face_det_results* (crop plus its
    ``(y1, y2, x1, x2)`` coords); the crop is resized to ``img_size``.

    Returns:
        ``(img_batch, mel_batch, frame_batch, coords_batch)``: two numpy arrays
        with a leading batch dimension of 1, a one-element list holding a copy
        of *frame*, and a one-element list holding the coords.
    """
    face, coords = face_det_results[0]
    face = cv2.resize(face, (img_size, img_size))
    img_batch, mel_batch = _prepare_batch([face], [mel])
    return img_batch, mel_batch, [frame.copy()], [coords]


class Infer:
    """Background Wav2Lip inference worker bound to a *human* object.

    The constructor starts a thread (not marked daemon) that loads the model
    and then consumes mel batches queued with :meth:`push` and
    ``(frame, type_)`` audio items queued with :meth:`push_out_queue`,
    delivering results through ``human.push_res_frame``.
    """

    def __init__(self, human):
        self._human = human
        self._feat_queue = Queue()
        self._audio_out_queue = Queue()
        # Despite its name, the event being *set* means "keep running".
        self._exit_event = Event()
        self._run_thread = Thread(target=self.__on_run)
        self._exit_event.set()
        self._run_thread.start()

    def __on_run(self):
        """Thread entry point: load the checkpoint, then run the inference loop."""
        # os.path.join keeps this portable; on Windows it is '.\checkpoints\wav2lip.pth'.
        model = load_model(os.path.join('.', 'checkpoints', 'wav2lip.pth'))
        print("Model loaded")
        face_list_cycle = self._human.get_face_list_cycle()
        # __do_run1 is an alternative loop that runs face detection first.
        self.__do_run2(face_list_cycle, model)

    def __do_run1(self, face_list_cycle, model):
        """Alternative loop: detect faces once, then infer one mel chunk at a time.

        Each predicted crop is written to ``temp/images/<n>.jpg``, pasted into
        a copy of the first frame, converted BGR->RGB and handed to
        ``human.push_render_image``. Runs while ``_exit_event`` is set.
        """
        face_det_results = face_detect(face_list_cycle)
        j = 0
        while self._exit_event.is_set():
            try:
                m = self._feat_queue.get(block=True, timeout=1)
            except queue.Empty:
                continue

            img_batch, mel_batch, frames, coords = datagen_signal(face_list_cycle[0], m, face_det_results)
            img_batch, mel_batch = _to_tensors(img_batch, mel_batch)

            time.sleep(0.01)
            with torch.no_grad():
                pred = model(mel_batch, img_batch)
            pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.

            for p, f, c in zip(pred, frames, coords):
                y1, y2, x1, x2 = c
                p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1))
                f[y1:y2, x1:x2] = p
                cv2.imwrite(f'temp/images/{j}.jpg', p)
                j += 1
                self._human.push_render_image(cv2.cvtColor(f, cv2.COLOR_BGR2RGB))

    def __do_run2(self, face_list_cycle, model):
        """Main loop: run the model on mel batches, skipping all-silent batches.

        For each mel batch from ``_feat_queue``, ``2 * batch_size`` items are
        read from ``_audio_out_queue`` and each output frame is paired with two
        of them. If no item has ``type_ == 0`` the batch is treated as silence
        and ``None`` frames are pushed without running the model. Otherwise
        faces are taken from *face_list_cycle* at ``mirror_index(length,
        index + i)``. Results go to ``human.push_res_frame``. While
        ``_exit_event`` is clear the loop idles in 1 s sleeps.
        """
        length = len(face_list_cycle)
        index = 0
        count = 0
        count_time = 0
        print('start inference')
        while True:
            if not self._exit_event.is_set():
                time.sleep(1)
                continue

            try:
                mel_batch = self._feat_queue.get(block=True, timeout=1)
            except queue.Empty:
                continue

            batch_size = self._human.get_batch_size()
            audio_frames = [self._audio_out_queue.get() for _ in range(batch_size * 2)]
            is_all_silence = all(type_ != 0 for _, type_ in audio_frames)

            if is_all_silence:
                for i in range(batch_size):
                    self._human.push_res_frame(None, mirror_index(length, index), audio_frames[i * 2:i * 2 + 2])
                    index += 1
                continue

            print('infer=======')
            t = time.perf_counter()
            faces = [face_list_cycle[mirror_index(length, index + i)] for i in range(batch_size)]
            img_batch, mel_batch = _to_tensors(*_prepare_batch(faces, mel_batch))

            with torch.no_grad():
                pred = model(mel_batch, img_batch)
            pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.

            count_time += time.perf_counter() - t
            # Was self._human.batch_size(): every other call site uses get_batch_size().
            count += batch_size
            if count >= 100:
                print(f"------actual avg infer fps:{count / count_time:.4f}")
                count = 0
                count_time = 0

            for i, res_frame in enumerate(pred):
                self._human.push_res_frame(res_frame, mirror_index(length, index), audio_frames[i * 2:i * 2 + 2])
                index += 1

    def push(self, mel_chunks):
        """Queue a batch of mel chunks for the inference thread."""
        self._feat_queue.put(mel_chunks)

    def push_out_queue(self, frame, type_):
        """Queue an audio frame and its type (0 marks non-silent audio in __do_run2)."""
        self._audio_out_queue.put((frame, type_))