modify test push

This commit is contained in:
brige 2024-09-23 15:52:39 +08:00
parent 17d9437425
commit 5af8ba1878

View File

@ -2,9 +2,11 @@
import logging import logging
import multiprocessing as mp import multiprocessing as mp
import platform, subprocess
import queue import queue
import time import time
import numpy as np import numpy as np
import audio import audio
@ -162,7 +164,7 @@ def face_detect(images):
while 1: while 1:
predictions = [] predictions = []
try: try:
for i in tqdm(range(0, len(images), batch_size)): for i in range(0, len(images), batch_size):
predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size]))) predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
except RuntimeError: except RuntimeError:
if batch_size == 1: if batch_size == 1:
@ -240,6 +242,44 @@ def datagen(frames, mels):
yield img_batch, mel_batch, frame_batch, coords_batch yield img_batch, mel_batch, frame_batch, coords_batch
def datagen_signal(frame, mel, face_det_results):
img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
# for i, m in enumerate(mels):
idx = 0
frame_to_save = frame.copy()
face, coords = face_det_results[idx].copy()
face = cv2.resize(face, (img_size, img_size))
m = mel
img_batch.append(face)
mel_batch.append(m)
frame_batch.append(frame_to_save)
coords_batch.append(coords)
if len(img_batch) >= wav2lip_batch_size:
img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
img_masked = img_batch.copy()
img_masked[:, img_size // 2:] = 0
img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
return img_batch, mel_batch, frame_batch, coords_batch
if len(img_batch) > 0:
img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
img_masked = img_batch.copy()
img_masked[:, img_size//2:] = 0
img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
return img_batch, mel_batch, frame_batch, coords_batch
class Human: class Human:
def __init__(self): def __init__(self):
@ -299,7 +339,46 @@ class Human:
face_images_length = len(face_list_cycle) face_images_length = len(face_list_cycle)
logging.info(f'face images length: {face_images_length}') logging.info(f'face images length: {face_images_length}')
print(f'face images length: {face_images_length}') print(f'face images length: {face_images_length}')
gen = datagen(face_list_cycle, self.mel_chunks_queue_)
model = load_model(r'.\checkpoints\wav2lip.pth')
print("Model loaded")
frame_h, frame_w = face_list_cycle[0].shape[:-1]
out = cv2.VideoWriter('temp/resul_tttt.avi',
cv2.VideoWriter_fourcc(*'DIVX'), 25, (frame_w, frame_h))
face_det_results = face_detect(face_list_cycle)
j = 0
while not self.mel_chunks_queue_.empty():
print("self.mel_chunks_queue_ len:", self.mel_chunks_queue_.qsize())
m = self.mel_chunks_queue_.get()
img_batch, mel_batch, frames, coords = datagen_signal(face_list_cycle[0], m, face_det_results)
img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
with torch.no_grad():
pred = model(mel_batch, img_batch)
pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
for p, f, c in zip(pred, frames, coords):
y1, y2, x1, x2 = c
p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1))
f[y1:y2, x1:x2] = p
# name = "%04d" % j
# cv2.imwrite(f'temp/images/{j}.jpg', p)
# j = j + 1
out.write(f)
out.release()
command = 'ffmpeg -y -i {} -i {} -strict -2 -q:v 1 {}'.format('./audio/audio1.wav', 'temp/resul_tttt.avi',
'temp/resul_tttt.mp4')
subprocess.call(command, shell=platform.system() != 'Windows')
# gen = datagen(face_list_cycle, self.mel_chunks_queue_)
def get_fps(self): def get_fps(self):
return self._fps return self._fps