modify test push
This commit is contained in:
parent
17d9437425
commit
5af8ba1878
83
Human.py
83
Human.py
@ -2,9 +2,11 @@
|
||||
import logging
|
||||
|
||||
import multiprocessing as mp
|
||||
import platform, subprocess
|
||||
import queue
|
||||
import time
|
||||
|
||||
|
||||
import numpy as np
|
||||
|
||||
import audio
|
||||
@ -162,7 +164,7 @@ def face_detect(images):
|
||||
while 1:
|
||||
predictions = []
|
||||
try:
|
||||
for i in tqdm(range(0, len(images), batch_size)):
|
||||
for i in range(0, len(images), batch_size):
|
||||
predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
|
||||
except RuntimeError:
|
||||
if batch_size == 1:
|
||||
@ -240,6 +242,44 @@ def datagen(frames, mels):
|
||||
yield img_batch, mel_batch, frame_batch, coords_batch
|
||||
|
||||
|
||||
def datagen_signal(frame, mel, face_det_results):
|
||||
img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
|
||||
|
||||
# for i, m in enumerate(mels):
|
||||
idx = 0
|
||||
frame_to_save = frame.copy()
|
||||
face, coords = face_det_results[idx].copy()
|
||||
|
||||
face = cv2.resize(face, (img_size, img_size))
|
||||
m = mel
|
||||
|
||||
img_batch.append(face)
|
||||
mel_batch.append(m)
|
||||
frame_batch.append(frame_to_save)
|
||||
coords_batch.append(coords)
|
||||
|
||||
if len(img_batch) >= wav2lip_batch_size:
|
||||
img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
|
||||
|
||||
img_masked = img_batch.copy()
|
||||
img_masked[:, img_size // 2:] = 0
|
||||
img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
|
||||
mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
|
||||
|
||||
return img_batch, mel_batch, frame_batch, coords_batch
|
||||
|
||||
|
||||
if len(img_batch) > 0:
|
||||
img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
|
||||
img_masked = img_batch.copy()
|
||||
img_masked[:, img_size//2:] = 0
|
||||
|
||||
img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
|
||||
mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
|
||||
|
||||
return img_batch, mel_batch, frame_batch, coords_batch
|
||||
|
||||
|
||||
|
||||
class Human:
|
||||
def __init__(self):
|
||||
@ -299,7 +339,46 @@ class Human:
|
||||
face_images_length = len(face_list_cycle)
|
||||
logging.info(f'face images length: {face_images_length}')
|
||||
print(f'face images length: {face_images_length}')
|
||||
gen = datagen(face_list_cycle, self.mel_chunks_queue_)
|
||||
|
||||
model = load_model(r'.\checkpoints\wav2lip.pth')
|
||||
print("Model loaded")
|
||||
|
||||
frame_h, frame_w = face_list_cycle[0].shape[:-1]
|
||||
out = cv2.VideoWriter('temp/resul_tttt.avi',
|
||||
cv2.VideoWriter_fourcc(*'DIVX'), 25, (frame_w, frame_h))
|
||||
|
||||
face_det_results = face_detect(face_list_cycle)
|
||||
|
||||
j = 0
|
||||
while not self.mel_chunks_queue_.empty():
|
||||
print("self.mel_chunks_queue_ len:", self.mel_chunks_queue_.qsize())
|
||||
m = self.mel_chunks_queue_.get()
|
||||
img_batch, mel_batch, frames, coords = datagen_signal(face_list_cycle[0], m, face_det_results)
|
||||
|
||||
img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
|
||||
mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
pred = model(mel_batch, img_batch)
|
||||
|
||||
pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
|
||||
for p, f, c in zip(pred, frames, coords):
|
||||
y1, y2, x1, x2 = c
|
||||
p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1))
|
||||
|
||||
f[y1:y2, x1:x2] = p
|
||||
# name = "%04d" % j
|
||||
# cv2.imwrite(f'temp/images/{j}.jpg', p)
|
||||
# j = j + 1
|
||||
out.write(f)
|
||||
|
||||
out.release()
|
||||
command = 'ffmpeg -y -i {} -i {} -strict -2 -q:v 1 {}'.format('./audio/audio1.wav', 'temp/resul_tttt.avi',
|
||||
'temp/resul_tttt.mp4')
|
||||
subprocess.call(command, shell=platform.system() != 'Windows')
|
||||
|
||||
|
||||
# gen = datagen(face_list_cycle, self.mel_chunks_queue_)
|
||||
|
||||
def get_fps(self):
|
||||
return self._fps
|
||||
|
Loading…
Reference in New Issue
Block a user