From 17d9437425cdda36238a9c696bd1d1201d0b3a1c Mon Sep 17 00:00:00 2001
From: brige
Date: Sun, 22 Sep 2024 16:41:19 +0800
Subject: [PATCH] add test code

---
 Human.py         | 166 +++++++++++++++++++++++++++++++++++++++++++----
 tts/Chunk2Mal.py |   4 +-
 2 files changed, 154 insertions(+), 16 deletions(-)

diff --git a/Human.py b/Human.py
index c4aa573..c1f6e8d 100644
--- a/Human.py
+++ b/Human.py
@@ -7,12 +7,15 @@ import time
 
 import numpy as np
+import audio
+import face_detection
 import utils
 from models import Wav2Lip
 from tts.Chunk2Mal import Chunk2Mal
 import torch
 import cv2
 from tqdm import tqdm
+from queue import Queue
 
 from tts.EdgeTTS import EdgeTTS
 from tts.TTSBase import TTSBase
 
@@ -140,9 +143,107 @@ def inference(render_event, batch_size, face_images_path, audio_feat_queue, audi
     logging.info('finish inference')
 
 
+def get_smoothened_boxes(boxes, T):
+    for i in range(len(boxes)):
+        if i + T > len(boxes):
+            window = boxes[len(boxes) - T:]
+        else:
+            window = boxes[i : i + T]
+        boxes[i] = np.mean(window, axis=0)
+    return boxes
+
+
+def face_detect(images):
+    detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
+                                            flip_input=False, device=device)
+
+    batch_size = 16
+
+    while 1:
+        predictions = []
+        try:
+            for i in tqdm(range(0, len(images), batch_size)):
+                predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
+        except RuntimeError:
+            if batch_size == 1:
+                raise RuntimeError(
+                    'Image too big to run face detection on GPU. Please use the --resize_factor argument')
+            batch_size //= 2
+            print('Recovering from OOM error; New batch size: {}'.format(batch_size))
+            continue
+        break
+
+    results = []
+    pady1, pady2, padx1, padx2 = [0, 10, 0, 0]
+    for rect, image in zip(predictions, images):
+        if rect is None:
+            cv2.imwrite('temp/faulty_frame.jpg', image)  # check this frame where the face was not detected.
+            raise ValueError('Face not detected! Ensure the video contains a face in all the frames.')
+
+        y1 = max(0, rect[1] - pady1)
+        y2 = min(image.shape[0], rect[3] + pady2)
+        x1 = max(0, rect[0] - padx1)
+        x2 = min(image.shape[1], rect[2] + padx2)
+
+        results.append([x1, y1, x2, y2])
+
+    boxes = np.array(results)
+    boxes = get_smoothened_boxes(boxes, T=5)  # smooth detections over a 5-frame window
+    results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)]
+
+    del detector
+    return results
+
+
+img_size = 96
+wav2lip_batch_size = 128
+
+
+def datagen(frames, mels):
+    img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
+
+    face_det_results = face_detect(frames)  # BGR2RGB for CNN face detection
+
+    # for i, m in enumerate(mels):
+    for i in range(mels.qsize()):
+        idx = 0  # reuse the first detection result for every chunk
+        frame_to_save = frames[__mirror_index(1, i)].copy()
+        face, coords = face_det_results[idx].copy()
+
+        face = cv2.resize(face, (img_size, img_size))
+        m = mels.get()
+
+        img_batch.append(face)
+        mel_batch.append(m)
+        frame_batch.append(frame_to_save)
+        coords_batch.append(coords)
+
+        if len(img_batch) >= wav2lip_batch_size:
+            img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
+
+            img_masked = img_batch.copy()
+            img_masked[:, img_size//2:] = 0
+
+            img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
+            mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
+
+            yield img_batch, mel_batch, frame_batch, coords_batch
+            img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
+
+    if len(img_batch) > 0:
+        img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
+        img_masked = img_batch.copy()
+        img_masked[:, img_size//2:] = 0
+
+        img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
+        mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
+
+        yield img_batch, mel_batch, frame_batch, coords_batch
+
+
 class Human:
     def __init__(self):
-        self._fps = 50  # 20 ms per frame
+        self._fps = 25  # 40 ms per frame
         self._batch_size = 16
         self._sample_rate = 16000
         self._stride_left_size = 10
@@ -151,17 +252,54 @@ class Human:
         self._output_queue = mp.Queue()
         self._res_frame_queue = mp.Queue(self._batch_size * 2)
 
-        self._chunk_2_mal = Chunk2Mal(self)
-        self._tts = TTSBase(self)
+        # self._chunk_2_mal = Chunk2Mal(self)
+        # self._tts = TTSBase(self)
+
+        self.mel_chunks_queue_ = Queue()
+        self.test()
+
+        # face_images_path = r'./face/'
+        # self._face_image_paths = utils.read_files_path(face_images_path)
+        # print(self._face_image_paths)
+        # self.render_event = mp.Event()
+        # mp.Process(target=inference, args=(self.render_event, self._batch_size, self._face_image_paths,
+        #                                    self._feat_queue, self._output_queue, self._res_frame_queue,
+        #                                    )).start()
+        # self.render_event.set()
+
+    def test(self):
+        wav = audio.load_wav(r'./audio/audio1.wav', 16000)
+        mel = audio.melspectrogram(wav)
+        if np.isnan(mel.reshape(-1)).sum() > 0:
+            raise ValueError(
+                'Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again')
+
+        mel_step_size = 16
+
+        print('fps:', self._fps)
+        mel_idx_multiplier = 80. / self._fps
+        print('mel_idx_multiplier:', mel_idx_multiplier)
+        i = 0
+        while 1:
+            start_idx = int(i * mel_idx_multiplier)
+            if start_idx + mel_step_size > len(mel[0]):
+                # mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:])
+                self.mel_chunks_queue_.put(mel[:, len(mel[0]) - mel_step_size:])
+                break
+            # mel_chunks.append(mel[:, start_idx: start_idx + mel_step_size])
+            self.mel_chunks_queue_.put(mel[:, start_idx: start_idx + mel_step_size])
+            i += 1
+
+        batch_size = 128
+        print('batch_size:', batch_size, ' mel_chunks len:', self.mel_chunks_queue_.qsize())
         face_images_path = r'./face/'
-        self._face_image_paths = utils.read_files_path(face_images_path)
-        print(self._face_image_paths)
-        self.render_event = mp.Event()
-        mp.Process(target=inference, args=(self.render_event, self._batch_size, self._face_image_paths,
-                                           self._feat_queue, self._output_queue, self._res_frame_queue,
-                                           )).start()
-        self.render_event.set()
+        face_images_path = utils.read_files_path(face_images_path)
+        face_list_cycle = read_images(face_images_path)
+        face_images_length = len(face_list_cycle)
+        logging.info(f'face images length: {face_images_length}')
+        print(f'face images length: {face_images_length}')
+        gen = datagen(face_list_cycle, self.mel_chunks_queue_)
 
     def get_fps(self):
         return self._fps
 
@@ -179,10 +317,10 @@ class Human:
         return self._stride_right_size
 
     def on_destroy(self):
-        self.render_event.clear()
-        self._chunk_2_mal.stop()
-        if self._tts is not None:
-            self._tts.stop()
+        # self.render_event.clear()
+        # self._chunk_2_mal.stop()
+        # if self._tts is not None:
+        #     self._tts.stop()
         logging.info('human destroy')
 
     def read(self, txt):
diff --git a/tts/Chunk2Mal.py b/tts/Chunk2Mal.py
index 460ab94..783a2a7 100644
--- a/tts/Chunk2Mal.py
+++ b/tts/Chunk2Mal.py
@@ -17,7 +17,7 @@ class Chunk2Mal:
         self._chunks = []
 
         # 320 samples per chunk (20ms * 16000 / 1000)audio_chunk
-        self._chunk_len = self._human.get_audio_sample_rate // self._human.get_fps()
+        self._chunk_len = self._human.get_audio_sample_rate() // self._human.get_fps()
 
         self._exit_event = Event()
         self._thread = Thread(target=self._on_run)
@@ -82,7 +82,7 @@ class Chunk2Mal:
             chunk = self._audio_chunk_queue.get(block=True, timeout=1)
             type = 1
         except queue.Empty:
-            chunk = np.zeros(self._human.get_chunk(), dtype=np.float32)
+            chunk = np.zeros(self._chunk_len, dtype=np.float32)
             type = 0
 
         return chunk, type
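
Note on the timing math in Human.test(): assuming the bundled audio module follows the
usual Wav2Lip hparams (16 kHz audio, hop size 200, i.e. 80 mel frames per second),
mel_idx_multiplier = 80. / fps maps video frame i to the mel frame where its audio window
starts. At fps = 25 each step advances 3.2 mel frames, and every chunk spans
mel_step_size = 16 mel frames, roughly 200 ms of audio. With the fps change from 50 to 25,
Chunk2Mal._chunk_len becomes 16000 // 25 = 640 samples (40 ms), so the "320 samples per
chunk (20ms)" comment in tts/Chunk2Mal.py is now stale. A minimal standalone sketch of the
same chunking loop, using a placeholder mel array instead of audio.melspectrogram():

    import numpy as np

    fps = 25                    # video frame rate used by Human
    mel_step_size = 16          # mel frames per Wav2Lip window (~200 ms)
    mel = np.zeros((80, 800))   # placeholder (n_mels, T) spectrogram, ~10 s of audio
    mel_idx_multiplier = 80. / fps

    mel_chunks = []
    i = 0
    while True:
        start_idx = int(i * mel_idx_multiplier)
        if start_idx + mel_step_size > mel.shape[1]:
            mel_chunks.append(mel[:, -mel_step_size:])  # clamp the last window to the end
            break
        mel_chunks.append(mel[:, start_idx:start_idx + mel_step_size])
        i += 1
    print(len(mel_chunks))      # roughly one chunk per video frame of the clip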
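
Human.test() stops after building gen = datagen(face_list_cycle, self.mel_chunks_queue_);
nothing consumes the generator yet, and __mirror_index used inside datagen() is not defined
by this patch. For reference, the upstream Wav2Lip inference script drives the same
(img_batch, mel_batch, frames, coords) batches roughly as sketched below; model (a loaded
Wav2Lip checkpoint), device and out (a cv2.VideoWriter) are assumptions, not part of this
change:

    # sketch of how the datagen() batches are typically consumed (per upstream Wav2Lip inference.py)
    for img_batch, mel_batch, frames, coords in gen:
        img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
        mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)

        with torch.no_grad():
            pred = model(mel_batch, img_batch)      # (N, 3, 96, 96), values in [0, 1]

        pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
        for p, f, c in zip(pred, frames, coords):
            y1, y2, x1, x2 = c
            p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1))
            f[y1:y2, x1:x2] = p                     # paste the predicted mouth region back
            out.write(f)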