From 5af8ba1878d1e0bdcc2862464c06c6882a13e8fc Mon Sep 17 00:00:00 2001 From: brige Date: Mon, 23 Sep 2024 15:52:39 +0800 Subject: [PATCH] modify test push --- Human.py | 83 ++++++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 81 insertions(+), 2 deletions(-) diff --git a/Human.py b/Human.py index c1f6e8d..d27a65b 100644 --- a/Human.py +++ b/Human.py @@ -2,9 +2,11 @@ import logging import multiprocessing as mp +import platform, subprocess import queue import time + import numpy as np import audio @@ -162,7 +164,7 @@ def face_detect(images): while 1: predictions = [] try: - for i in tqdm(range(0, len(images), batch_size)): + for i in range(0, len(images), batch_size): predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size]))) except RuntimeError: if batch_size == 1: @@ -240,6 +242,44 @@ def datagen(frames, mels): yield img_batch, mel_batch, frame_batch, coords_batch +def datagen_signal(frame, mel, face_det_results): + img_batch, mel_batch, frame_batch, coords_batch = [], [], [], [] + + # for i, m in enumerate(mels): + idx = 0 + frame_to_save = frame.copy() + face, coords = face_det_results[idx].copy() + + face = cv2.resize(face, (img_size, img_size)) + m = mel + + img_batch.append(face) + mel_batch.append(m) + frame_batch.append(frame_to_save) + coords_batch.append(coords) + + if len(img_batch) >= wav2lip_batch_size: + img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch) + + img_masked = img_batch.copy() + img_masked[:, img_size // 2:] = 0 + img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255. + mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1]) + + return img_batch, mel_batch, frame_batch, coords_batch + + + if len(img_batch) > 0: + img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch) + img_masked = img_batch.copy() + img_masked[:, img_size//2:] = 0 + + img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255. + mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1]) + + return img_batch, mel_batch, frame_batch, coords_batch + + class Human: def __init__(self): @@ -299,7 +339,46 @@ class Human: face_images_length = len(face_list_cycle) logging.info(f'face images length: {face_images_length}') print(f'face images length: {face_images_length}') - gen = datagen(face_list_cycle, self.mel_chunks_queue_) + + model = load_model(r'.\checkpoints\wav2lip.pth') + print("Model loaded") + + frame_h, frame_w = face_list_cycle[0].shape[:-1] + out = cv2.VideoWriter('temp/resul_tttt.avi', + cv2.VideoWriter_fourcc(*'DIVX'), 25, (frame_w, frame_h)) + + face_det_results = face_detect(face_list_cycle) + + j = 0 + while not self.mel_chunks_queue_.empty(): + print("self.mel_chunks_queue_ len:", self.mel_chunks_queue_.qsize()) + m = self.mel_chunks_queue_.get() + img_batch, mel_batch, frames, coords = datagen_signal(face_list_cycle[0], m, face_det_results) + + img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device) + mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device) + + with torch.no_grad(): + pred = model(mel_batch, img_batch) + + pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255. + for p, f, c in zip(pred, frames, coords): + y1, y2, x1, x2 = c + p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1)) + + f[y1:y2, x1:x2] = p + # name = "%04d" % j + # cv2.imwrite(f'temp/images/{j}.jpg', p) + # j = j + 1 + out.write(f) + + out.release() + command = 'ffmpeg -y -i {} -i {} -strict -2 -q:v 1 {}'.format('./audio/audio1.wav', 'temp/resul_tttt.avi', + 'temp/resul_tttt.mp4') + subprocess.call(command, shell=platform.system() != 'Windows') + + + # gen = datagen(face_list_cycle, self.mel_chunks_queue_) def get_fps(self): return self._fps