modify test push

This commit is contained in:
brige 2024-09-23 15:52:39 +08:00
parent 17d9437425
commit 5af8ba1878

View File

@ -2,9 +2,11 @@
import logging
import multiprocessing as mp
import platform, subprocess
import queue
import time
import numpy as np
import audio
@ -162,7 +164,7 @@ def face_detect(images):
while 1:
predictions = []
try:
for i in tqdm(range(0, len(images), batch_size)):
for i in range(0, len(images), batch_size):
predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
except RuntimeError:
if batch_size == 1:
@ -240,6 +242,44 @@ def datagen(frames, mels):
yield img_batch, mel_batch, frame_batch, coords_batch
def datagen_signal(frame, mel, face_det_results):
img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
# for i, m in enumerate(mels):
idx = 0
frame_to_save = frame.copy()
face, coords = face_det_results[idx].copy()
face = cv2.resize(face, (img_size, img_size))
m = mel
img_batch.append(face)
mel_batch.append(m)
frame_batch.append(frame_to_save)
coords_batch.append(coords)
if len(img_batch) >= wav2lip_batch_size:
img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
img_masked = img_batch.copy()
img_masked[:, img_size // 2:] = 0
img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
return img_batch, mel_batch, frame_batch, coords_batch
if len(img_batch) > 0:
img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
img_masked = img_batch.copy()
img_masked[:, img_size//2:] = 0
img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
return img_batch, mel_batch, frame_batch, coords_batch
class Human:
def __init__(self):
@ -299,7 +339,46 @@ class Human:
face_images_length = len(face_list_cycle)
logging.info(f'face images length: {face_images_length}')
print(f'face images length: {face_images_length}')
gen = datagen(face_list_cycle, self.mel_chunks_queue_)
model = load_model(r'.\checkpoints\wav2lip.pth')
print("Model loaded")
frame_h, frame_w = face_list_cycle[0].shape[:-1]
out = cv2.VideoWriter('temp/resul_tttt.avi',
cv2.VideoWriter_fourcc(*'DIVX'), 25, (frame_w, frame_h))
face_det_results = face_detect(face_list_cycle)
j = 0
while not self.mel_chunks_queue_.empty():
print("self.mel_chunks_queue_ len:", self.mel_chunks_queue_.qsize())
m = self.mel_chunks_queue_.get()
img_batch, mel_batch, frames, coords = datagen_signal(face_list_cycle[0], m, face_det_results)
img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
with torch.no_grad():
pred = model(mel_batch, img_batch)
pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
for p, f, c in zip(pred, frames, coords):
y1, y2, x1, x2 = c
p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1))
f[y1:y2, x1:x2] = p
# name = "%04d" % j
# cv2.imwrite(f'temp/images/{j}.jpg', p)
# j = j + 1
out.write(f)
out.release()
command = 'ffmpeg -y -i {} -i {} -strict -2 -q:v 1 {}'.format('./audio/audio1.wav', 'temp/resul_tttt.avi',
'temp/resul_tttt.mp4')
subprocess.call(command, shell=platform.system() != 'Windows')
# gen = datagen(face_list_cycle, self.mel_chunks_queue_)
def get_fps(self):
return self._fps