# -*- coding: utf-8 -*-
import queue
import time
from queue import Queue
from threading import Thread, Event
import logging

import cv2
import numpy as np
import torch
from tqdm import tqdm

import face_detection
import utils
from models import Wav2Lip

logger = logging.getLogger(__name__)

device = 'cuda' if torch.cuda.is_available() else 'cpu'


def read_images(img_list):
    """Read every image path in *img_list* with OpenCV.

    Returns:
        A list of decoded frames, in the same order as *img_list*.

    Raises:
        ValueError: if OpenCV cannot decode one of the paths. ``cv2.imread``
            returns ``None`` instead of raising, which would otherwise only
            surface later as an obscure failure inside face detection.
    """
    frames = []
    print('reading images...')
    for img_path in tqdm(img_list):
        print(f'read image path:{img_path}')
        frame = cv2.imread(img_path)
        if frame is None:
            raise ValueError(f'Could not read image: {img_path}')
        frames.append(frame)
    return frames


def _load(checkpoint_path):
    """Load a torch checkpoint from *checkpoint_path*.

    On CPU-only hosts every storage is mapped to CPU so that checkpoints
    saved from a GPU can still be loaded.
    """
    if device == 'cuda':
        return torch.load(checkpoint_path)
    return torch.load(checkpoint_path, map_location=lambda storage, loc: storage)


def load_model(path):
    """Build a Wav2Lip model, load its weights from *path* and return it in eval mode on ``device``."""
    model = Wav2Lip()
    print("Load checkpoint from: {}".format(path))
    logger.info(f'Load checkpoint from {path}')
    checkpoint = _load(path)
    # Strip the 'module.' key prefix (presumably added when the checkpoint was
    # saved from an nn.DataParallel wrapper) so keys match the bare model.
    state_dict = {k.replace('module.', ''): v for k, v in checkpoint["state_dict"].items()}
    model.load_state_dict(state_dict)
    model = model.to(device)
    return model.eval()


def get_smoothened_boxes(boxes, T):
    """Smooth *boxes* in place with a forward-looking moving average over T rows.

    Row i becomes the mean of rows i .. i+T-1. Near the end, where that window
    would run past the last row, the final T rows are used instead. Because
    the update is in place, those tail windows include rows already smoothed
    earlier in the loop. With fewer than T rows, every window is the whole
    array. If *boxes* has an integer dtype, the means are truncated on
    assignment.

    Returns:
        The same (mutated) *boxes* array.
    """
    n = len(boxes)
    for i in range(n):
        if i + T > n:
            # Clamp at 0: with n < T a negative start would wrap around and
            # average only the last T - n rows instead of all of them.
            window = boxes[max(0, n - T):]
        else:
            window = boxes[i:i + T]
        boxes[i] = np.mean(window, axis=0)
    return boxes


def face_detect(images):
    """Detect one face per image and return face crops with their coordinates.

    Detection runs in batches that start at 16. On a RuntimeError (e.g. GPU
    out of memory) the batch size is halved and detection restarts. Each
    detected box is padded 10 px downwards, clamped to the image, then
    smoothed across frames with a window of 5.

    Returns:
        A list with one ``[face_crop, (y1, y2, x1, x2)]`` entry per image.

    Raises:
        RuntimeError: if detection still fails at batch size 1.
        ValueError: if some image has no detected face. That frame is written
            to ``temp/faulty_frame.jpg`` first, for inspection.
    """
    detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
                                            flip_input=False, device=device)

    batch_size = 16
    while True:
        predictions = []
        try:
            for i in range(0, len(images), batch_size):
                predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
        except RuntimeError as err:
            if batch_size == 1:
                raise RuntimeError(
                    'Image too big to run face detection on GPU. Please use the --resize_factor argument') from err
            batch_size //= 2
            print('Recovering from OOM error; New batch size: {}'.format(batch_size))
            continue
        break

    results = []
    pady1, pady2, padx1, padx2 = [0, 10, 0, 0]
    for rect, image in zip(predictions, images):
        if rect is None:
            cv2.imwrite('temp/faulty_frame.jpg', image)  # check this frame where the face was not detected.
            raise ValueError('Face not detected! Ensure the video contains a face in all the frames.')

        # rect is (x1, y1, x2, y2); pad and clamp to the image bounds.
        y1 = max(0, rect[1] - pady1)
        y2 = min(image.shape[0], rect[3] + pady2)
        x1 = max(0, rect[0] - padx1)
        x2 = min(image.shape[1], rect[2] + padx2)
        results.append([x1, y1, x2, y2])

    boxes = get_smoothened_boxes(np.array(results), T=5)
    results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)]

    del detector
    return results


img_size = 96
wav2lip_batch_size = 128


def datagen_signal(frame, mel, face_det_results):
    """Build a single-sample Wav2Lip input batch for one mel chunk.

    Always uses the first detection, ``face_det_results[0]``. *frame* should be
    the image that detection came from; the caller passes
    ``face_list_cycle[0]`` (TODO confirm if that ever changes).

    Returns:
        (img_batch, mel_batch, frame_batch, coords_batch):
        - img_batch: shape (1, img_size, img_size, 2 * C). The face with its
          lower half zeroed is concatenated on the last axis with the
          unmasked face, then scaled by 1/255.
        - mel_batch: *mel* with a batch axis in front and a trailing axis of
          size 1 (assumes *mel* is 2-D, since shape[1] and shape[2] of the
          batched array are read).
        - frame_batch: ``[a copy of frame]``, which the caller may write into.
        - coords_batch: ``[(y1, y2, x1, x2)]`` of the face in *frame*.
    """
    face, coords = face_det_results[0]
    face = cv2.resize(face, (img_size, img_size))

    img_batch = np.asarray([face])
    mel_batch = np.asarray([mel])

    img_masked = img_batch.copy()
    img_masked[:, img_size // 2:] = 0  # hide the lower half of the face
    img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
    mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

    return img_batch, mel_batch, [frame.copy()], [coords]


class Infer:
    """Background Wav2Lip inference worker.

    Mel chunks passed to push() are queued. A thread started in __init__
    loads the face images and the model, then runs inference for each queued
    chunk and hands the resulting RGB frame to ``human.push_render_image``.
    Call stop() to end the worker thread.
    """

    def __init__(self, human):
        self._human = human
        self._queue = Queue()
        # Despite its name, this event is SET while the worker should keep
        # running; stop() clears it.
        self._exit_event = Event()
        self._run_thread = Thread(target=self.__on_run)
        self._exit_event.set()
        self._run_thread.start()

    def __on_run(self):
        face_images_path = utils.read_files_path(r'./face/')
        face_list_cycle = read_images(face_images_path)
        face_images_length = len(face_list_cycle)
        logger.info(f'face images length: {face_images_length}')
        print(f'face images length: {face_images_length}')

        # Forward slashes resolve on both Windows and POSIX; the previous
        # r'.\checkpoints\wav2lip.pth' only worked on Windows.
        model = load_model('./checkpoints/wav2lip.pth')
        print("Model loaded")

        face_det_results = face_detect(face_list_cycle)

        j = 0
        while self._exit_event.is_set():
            try:
                # Time out periodically so a cleared event is noticed.
                m = self._queue.get(block=True, timeout=1)
            except queue.Empty:
                continue

            img_batch, mel_batch, frames, coords = datagen_signal(face_list_cycle[0], m, face_det_results)
            # NHWC -> NCHW for the model.
            img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
            mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)

            time.sleep(0.01)  # NOTE(review): purpose of this delay is unclear
            with torch.no_grad():
                pred = model(mel_batch, img_batch)
            pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.

            for p, f, c in zip(pred, frames, coords):
                y1, y2, x1, x2 = c
                mouth = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1))
                f[y1:y2, x1:x2] = mouth
                # NOTE(review): debug dump, one file per frame, never cleaned up.
                cv2.imwrite(f'temp/images/{j}.jpg', mouth)
                j += 1
                self._human.push_render_image(cv2.cvtColor(f, cv2.COLOR_BGR2RGB))

    def push(self, chunk):
        """Queue a mel chunk for inference."""
        self._queue.put(chunk)

    def stop(self, timeout=None):
        """Signal the worker thread to finish and wait up to *timeout* seconds for it.

        The loop checks the flag at least once a second, but only after the
        initial image and model loading has finished.
        """
        self._exit_event.clear()
        self._run_thread.join(timeout)