add test code

brige 2024-09-22 16:41:19 +08:00
parent 4e1e923c0b
commit 17d9437425
2 changed files with 154 additions and 16 deletions

Human.py

@@ -7,12 +7,15 @@ import time
import numpy as np
+import audio
+import face_detection
import utils
from models import Wav2Lip
from tts.Chunk2Mal import Chunk2Mal
import torch
import cv2
from tqdm import tqdm
+from queue import Queue
from tts.EdgeTTS import EdgeTTS
from tts.TTSBase import TTSBase
@@ -140,9 +143,107 @@ def inference(render_event, batch_size, face_images_path, audio_feat_queue, audi
    logging.info('finish inference')

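+# Smooth the detected face boxes with a sliding-window mean over T frames to reduce jitter.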
+def get_smoothened_boxes(boxes, T):
+    for i in range(len(boxes)):
+        if i + T > len(boxes):
+            window = boxes[len(boxes) - T:]
+        else:
+            window = boxes[i : i + T]
+        boxes[i] = np.mean(window, axis=0)
+    return boxes
+
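+# Run batched face detection with face_detection.FaceAlignment, halving the batch size and retrying
+# on a CUDA out-of-memory RuntimeError, then pad and smooth the boxes before cropping each face.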
+def face_detect(images):
+    detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
+                                            flip_input=False, device=device)
+
+    batch_size = 16
+
+    while 1:
+        predictions = []
+        try:
+            for i in tqdm(range(0, len(images), batch_size)):
+                predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
+        except RuntimeError:
+            if batch_size == 1:
+                raise RuntimeError(
+                    'Image too big to run face detection on GPU. Please use the --resize_factor argument')
+            batch_size //= 2
+            print('Recovering from OOM error; New batch size: {}'.format(batch_size))
+            continue
+        break
+
+    results = []
+    pady1, pady2, padx1, padx2 = [0, 10, 0, 0]
+    for rect, image in zip(predictions, images):
+        if rect is None:
+            cv2.imwrite('temp/faulty_frame.jpg', image)  # check this frame where the face was not detected.
+            raise ValueError('Face not detected! Ensure the video contains a face in all the frames.')
+
+        y1 = max(0, rect[1] - pady1)
+        y2 = min(image.shape[0], rect[3] + pady2)
+        x1 = max(0, rect[0] - padx1)
+        x2 = min(image.shape[1], rect[2] + padx2)
+        results.append([x1, y1, x2, y2])
+
+    boxes = np.array(results)
+    if not False: boxes = get_smoothened_boxes(boxes, T=5)
+    results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)]
+
+    del detector
+    return results
+
+
+img_size = 96
+wav2lip_batch_size = 128
+
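+# Generator that pops mel chunks from the queue, pairs each with a face crop resized to img_size,
+# masks the lower half of the face, and yields batches of up to wav2lip_batch_size items.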
+def datagen(frames, mels):
+    img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
+
+    face_det_results = face_detect(frames)  # BGR2RGB for CNN face detection
+
+    # for i, m in enumerate(mels):
+    for i in range(mels.qsize()):
+        idx = 0 if True else i % len(frames)
+        frame_to_save = frames[__mirror_index(1, i)].copy()
+        face, coords = face_det_results[idx].copy()
+
+        face = cv2.resize(face, (img_size, img_size))
+        m = mels.get()
+
+        img_batch.append(face)
+        mel_batch.append(m)
+        frame_batch.append(frame_to_save)
+        coords_batch.append(coords)
+
+        if len(img_batch) >= wav2lip_batch_size:
+            img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
+
+            img_masked = img_batch.copy()
+            img_masked[:, img_size//2:] = 0
+
+            img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
+            mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
+
+            yield img_batch, mel_batch, frame_batch, coords_batch
+            img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
+
+    if len(img_batch) > 0:
+        img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
+
+        img_masked = img_batch.copy()
+        img_masked[:, img_size//2:] = 0
+
+        img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
+        mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
+
+        yield img_batch, mel_batch, frame_batch, coords_batch
+
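+# Note: concatenating the masked faces with the unmasked faces along the channel axis gives the
+# six-channel image input that Wav2Lip expects.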

class Human:
    def __init__(self):
-        self._fps = 50  # 20 ms per frame
+        self._fps = 25  # 20 ms per frame
        self._batch_size = 16
        self._sample_rate = 16000
        self._stride_left_size = 10
@@ -151,17 +252,54 @@ class Human:
        self._output_queue = mp.Queue()
        self._res_frame_queue = mp.Queue(self._batch_size * 2)

-        self._chunk_2_mal = Chunk2Mal(self)
-        self._tts = TTSBase(self)
+        # self._chunk_2_mal = Chunk2Mal(self)
+        # self._tts = TTSBase(self)
+        self.mel_chunks_queue_ = Queue()
+        self.test()
+        # face_images_path = r'./face/'
+        # self._face_image_paths = utils.read_files_path(face_images_path)
+        # print(self._face_image_paths)
+        # self.render_event = mp.Event()
+        # mp.Process(target=inference, args=(self.render_event, self._batch_size, self._face_image_paths,
+        #                                    self._feat_queue, self._output_queue, self._res_frame_queue,
+        #                                    )).start()
+        # self.render_event.set()
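+
+    # Offline self-test: load a local wav, split its mel spectrogram into 16-frame chunks
+    # (80 mel frames span one second of audio, so each video frame advances 80 / fps frames),
+    # queue them, and drive datagen over the face images.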
+    def test(self):
+        wav = audio.load_wav(r'./audio/audio1.wav', 16000)
+        mel = audio.melspectrogram(wav)
+        if np.isnan(mel.reshape(-1)).sum() > 0:
+            raise ValueError(
+                'Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again')
+
+        mel_step_size = 16
+        print('fps:', self._fps)
+        mel_idx_multiplier = 80. / self._fps
+        print('mel_idx_multiplier:', mel_idx_multiplier)
+
+        i = 0
+        while 1:
+            start_idx = int(i * mel_idx_multiplier)
+            if start_idx + mel_step_size > len(mel[0]):
+                # mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:])
+                self.mel_chunks_queue_.put(mel[:, len(mel[0]) - mel_step_size:])
+                break
+            # mel_chunks.append(mel[:, start_idx: start_idx + mel_step_size])
+            self.mel_chunks_queue_.put(mel[:, start_idx: start_idx + mel_step_size])
+            i += 1
+
+        batch_size = 128
+        print('batch_size:', batch_size, ' mel_chunks len:', self.mel_chunks_queue_.qsize())
-        face_images_path = r'./face/'
-        self._face_image_paths = utils.read_files_path(face_images_path)
-        print(self._face_image_paths)
-        self.render_event = mp.Event()
-        mp.Process(target=inference, args=(self.render_event, self._batch_size, self._face_image_paths,
-                                           self._feat_queue, self._output_queue, self._res_frame_queue,
-                                           )).start()
-        self.render_event.set()
+        face_images_path = r'./face/'
+        face_images_path = utils.read_files_path(face_images_path)
+        face_list_cycle = read_images(face_images_path)
+        face_images_length = len(face_list_cycle)
+        logging.info(f'face images length: {face_images_length}')
+        print(f'face images length: {face_images_length}')
+        gen = datagen(face_list_cycle, self.mel_chunks_queue_)

    def get_fps(self):
        return self._fps
@@ -179,10 +317,10 @@ class Human:
        return self._stride_right_size

    def on_destroy(self):
-        self.render_event.clear()
-        self._chunk_2_mal.stop()
-        if self._tts is not None:
-            self._tts.stop()
+        # self.render_event.clear()
+        # self._chunk_2_mal.stop()
+        # if self._tts is not None:
+        #     self._tts.stop()

        logging.info('human destroy')

    def read(self, txt):

tts/Chunk2Mal.py

@@ -17,7 +17,7 @@ class Chunk2Mal:
        self._chunks = []
        # 320 samples per chunk (20ms * 16000 / 1000)audio_chunk
-        self._chunk_len = self._human.get_audio_sample_rate // self._human.get_fps()
+        self._chunk_len = self._human.get_audio_sample_rate() // self._human.get_fps()

        self._exit_event = Event()
        self._thread = Thread(target=self._on_run)
@@ -82,7 +82,7 @@ class Chunk2Mal:
            chunk = self._audio_chunk_queue.get(block=True, timeout=1)
            type = 1
        except queue.Empty:
-            chunk = np.zeros(self._human.get_chunk(), dtype=np.float32)
+            chunk = np.zeros(self._chunk_len, dtype=np.float32)
            type = 0

        return chunk, type