# human/infer.py
#encoding = utf8
import os
import glob
import queue
import multiprocessing as mp
import time
from queue import Queue
from threading import Thread, Event
import logging
import cv2
import numpy as np
import torch
from tqdm import tqdm
import face_detection
import utils
from models import Wav2Lip
from utils import mirror_index
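# `mirror_index` is imported from utils and not defined here. The calls below
# assume ping-pong ("boomerang") indexing over a frame cycle of the given size,
# e.g. for size 4: 0, 1, 2, 3, 2, 1, 0, ...  A minimal sketch under that
# assumption:
#
#   def mirror_index(size, index):
#       turn, res = divmod(index, size)
#       return res if turn % 2 == 0 else size - res - 1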
logger = logging.getLogger(__name__)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

def read_images(img_list):
    frames = []
    print('reading images...')
    for img_path in tqdm(img_list):
        print(f'read image path:{img_path}')
        frame = cv2.imread(img_path)
        frames.append(frame)
    return frames

def _load(checkpoint_path):
    if device == 'cuda':
        checkpoint = torch.load(checkpoint_path)
    else:
        checkpoint = torch.load(checkpoint_path,
                                map_location=lambda storage, loc: storage)
    return checkpoint

def load_model(path):
    model = Wav2Lip()
    print("Load checkpoint from: {}".format(path))
    logger.info(f'Load checkpoint from {path}')
    checkpoint = _load(path)
    s = checkpoint["state_dict"]
    new_s = {}
    for k, v in s.items():
        # strip the 'module.' prefix left behind by DataParallel checkpoints
        new_s[k.replace('module.', '')] = v
    model.load_state_dict(new_s)

    model = model.to(device)
    return model.eval()

def get_smoothened_boxes(boxes, T):
    for i in range(len(boxes)):
        if i + T > len(boxes):
            window = boxes[len(boxes) - T:]
        else:
            window = boxes[i: i + T]
        boxes[i] = np.mean(window, axis=0)
    return boxes

def face_detect(images):
    detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
                                            flip_input=False, device=device)
    batch_size = 16

    while 1:
        predictions = []
        try:
            for i in range(0, len(images), batch_size):
                predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
        except RuntimeError:
            if batch_size == 1:
                raise RuntimeError(
                    'Image too big to run face detection on GPU. Please use the --resize_factor argument')
            batch_size //= 2
            print('Recovering from OOM error; New batch size: {}'.format(batch_size))
            continue
        break

    results = []
    pady1, pady2, padx1, padx2 = [0, 10, 0, 0]
    for rect, image in zip(predictions, images):
        if rect is None:
            cv2.imwrite('temp/faulty_frame.jpg', image)  # check this frame where the face was not detected.
            raise ValueError('Face not detected! Ensure the video contains a face in all the frames.')

        y1 = max(0, rect[1] - pady1)
        y2 = min(image.shape[0], rect[3] + pady2)
        x1 = max(0, rect[0] - padx1)
        x2 = min(image.shape[1], rect[2] + padx2)
        results.append([x1, y1, x2, y2])

    boxes = np.array(results)
    boxes = get_smoothened_boxes(boxes, T=5)
    results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)]

    del detector
    return results

img_size = 96
wav2lip_batch_size = 128
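
# Wav2Lip input packing used by the helpers below: each 96x96 face crop is paired
# with a masked copy whose lower half is zeroed (the region the model re-generates),
# the two are concatenated along the channel axis into a 6-channel input scaled to
# [0, 1], and each mel chunk gets a trailing singleton channel dimension.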
def datagen(frames, mels):
    img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []

    face_det_results = face_detect(frames)  # BGR2RGB for CNN face detection

    i = 0
    for m in mels:
        # for i in range(mels.qsize()):
        idx = 0  # the alternating index (i % len(frames)) is disabled; always reuse the first detection
        frame_to_save = frames[mirror_index(1, i)].copy()
        face, coords = face_det_results[idx].copy()

        face = cv2.resize(face, (img_size, img_size))

        img_batch.append(face)
        mel_batch.append(m)
        frame_batch.append(frame_to_save)
        coords_batch.append(coords)

        if len(img_batch) >= wav2lip_batch_size:
            img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)

            img_masked = img_batch.copy()
            img_masked[:, img_size // 2:] = 0

            img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
            mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

            yield img_batch, mel_batch, frame_batch, coords_batch
            img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
        i = i + 1

    if len(img_batch) > 0:
        img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)

        img_masked = img_batch.copy()
        img_masked[:, img_size // 2:] = 0

        img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
        mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

        yield img_batch, mel_batch, frame_batch, coords_batch

def datagen_signal(frame, mel, face_det_results):
    img_batch, mel_batch, frame_batch, coord_batch = [], [], [], []

    # for i, m in enumerate(mels):
    idx = 0
    frame_to_save = frame.copy()
    face, coord = face_det_results[idx].copy()
    face = cv2.resize(face, (img_size, img_size))

    for i, m in enumerate(mel):
        img_batch.append(face)
        mel_batch.append(m)
        frame_batch.append(frame_to_save)
        coord_batch.append(coord)

        if len(img_batch) >= wav2lip_batch_size:
            img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
            img_masked = img_batch.copy()
            img_masked[:, img_size // 2:] = 0

            img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
            mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

            return img_batch, mel_batch, frame_batch, coord_batch

    if len(img_batch) > 0:
        img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
        img_masked = img_batch.copy()
        img_masked[:, img_size // 2:] = 0

        img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
        mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

        return img_batch, mel_batch, frame_batch, coord_batch

def inference(render_event, batch_size, face_imgs_path, audio_feat_queue, audio_out_queue, res_frame_queue):
    model = load_model(r'.\checkpoints\wav2lip.pth')
    # face_list_cycle = read_images(face_imgs_path)
    input_face_list = glob.glob(os.path.join(face_imgs_path, '*.[jpJP][pnPN]*[gG]'))
    input_face_list = sorted(input_face_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
    face_list_cycle = read_images(input_face_list)

    # input_latent_list_cycle = torch.load(latents_out_path)
    length = len(face_list_cycle)
    index = 0
    count = 0
    counttime = 0
    print('start inference')
    while True:
        if render_event.is_set():
            starttime = time.perf_counter()
            mel_batch = []
            try:
                mel_batch = audio_feat_queue.get(block=True, timeout=1)
            except queue.Empty:
                continue

            is_all_silence = True
            audio_frames = []
            # two audio chunks are paired with every rendered video frame
            for _ in range(batch_size * 2):
                frame, type_ = audio_out_queue.get()
                audio_frames.append((frame, type_))
                if type_ == 0:
                    is_all_silence = False

            if is_all_silence:
                for i in range(batch_size):
                    res_frame_queue.put((None, mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]))
                    index = index + 1
            else:
                print('infer=======')
                t = time.perf_counter()
                img_batch = []
                for i in range(batch_size):
                    idx = mirror_index(length, index + i)
                    face = face_list_cycle[idx]
                    img_batch.append(face)

                img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
                img_masked = img_batch.copy()
                img_masked[:, face.shape[0] // 2:] = 0

                img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
                mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

                img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
                mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)

                with torch.no_grad():
                    pred = model(mel_batch, img_batch)
                pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.

                counttime += (time.perf_counter() - t)
                count += batch_size
                # _totalframe += 1
                if count >= 100:
                    print(f"------actual avg infer fps:{count / counttime:.4f}")
                    count = 0
                    counttime = 0

                for i, res_frame in enumerate(pred):
                    # self.__pushmedia(res_frame,loop,audio_track,video_track)
                    res_frame_queue.put((res_frame, mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]))
                    index = index + 1
                # print('total batch time:', time.perf_counter()-starttime)
        else:
            time.sleep(1)
    print('musereal inference processor stop')

class Infer:
    def __init__(self, human):
        self._human = human
        # self._feat_queue = Queue()
        # self._audio_out_queue = Queue()

        self.batch_size = human.get_batch_size()
        self.asr = human.chunk_2_mal
        self.res_frame_queue = human.res_render_queue
        # self._exit_event = Event()

        # face_images_path = r'./face/'
        # self.face_images_path = utils.read_files_path(face_images_path)
        self.avatar_id = 'wav2lip_avatar1'
        self.avatar_path = f"./data/{self.avatar_id}"
        self.full_imgs_path = f"{self.avatar_path}/full_imgs"
        self.face_images_path = f"{self.avatar_path}/face_imgs"
        self.coords_path = f"{self.avatar_path}/coords.pkl"

        self.render_event = mp.Event()
        mp.Process(target=inference, args=(self.render_event, self.batch_size, self.face_images_path,
                                           self.asr.feat_queue, self.asr.output_queue, self.res_frame_queue,
                                           )).start()
        self.render_event.set()

        # self._run_thread = Thread(target=self.__on_run)
        # self._exit_event.set()
        # self._run_thread.start()

    def __on_run(self):
        model = load_model(r'.\checkpoints\wav2lip.pth')
        print("Model loaded")

        face_list_cycle = self._human.get_face_list_cycle()

        # self.__do_run1(face_list_cycle, model)
        self.__do_run2(face_list_cycle, model)
        # frame_h, frame_w = face_list_cycle[0].shape[:-1]

    def __do_run1(self, face_list_cycle, model):
        face_det_results = face_detect(face_list_cycle)

        j = 0
        count = 0
        while self._exit_event.is_set():
            try:
                m = self._feat_queue.get(block=True, timeout=1)
            except queue.Empty:
                continue

            img_batch, mel_batch, frames, coords = datagen_signal(face_list_cycle[0], m, face_det_results)

            img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
            mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)

            time.sleep(0.01)
            with torch.no_grad():
                pred = model(mel_batch, img_batch)

            pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
            for p, f, c in zip(pred, frames, coords):
                y1, y2, x1, x2 = c
                p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1))
                f[y1:y2, x1:x2] = p

                # name = "%04d" % j
                cv2.imwrite(f'temp/images/{j}.jpg', p)
                j = j + 1
                # count = count + 1

                p = cv2.cvtColor(f, cv2.COLOR_BGR2RGB)
                self._human.push_render_image(p)
                # out.write(f)
        # print('infer count:', count)

    def __do_run2(self, face_list_cycle, model):
        length = len(face_list_cycle)
        index = 0
        count = 0
        count_time = 0
        print('start inference')

        # face_images_path = r'./face/'
        # face_images_path = utils.read_files_path(face_images_path)
        # face_list_cycle1 = read_images(face_images_path)
        # face_det_results = face_detect(face_list_cycle1)

        while True:
            if self._exit_event.is_set():
                start_time = time.perf_counter()
                batch_size = self._human.get_batch_size()
                try:
                    mel_batch = self._feat_queue.get(block=True, timeout=1)
                except queue.Empty:
                    continue

                is_all_silence = True
                audio_frames = []
                for _ in range(batch_size * 2):
                    frame, type_ = self._audio_out_queue.get()
                    audio_frames.append((frame, type_))
                    if type_ == 0:
                        is_all_silence = False

                if is_all_silence:
                    for i in range(batch_size):
                        # res_frame_queue.put((None, mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]))
                        self._human.push_res_frame(None, mirror_index(length, index), audio_frames[i * 2:i * 2 + 2])
                        index = index + 1
                else:
                    print('infer=======')
                    t = time.perf_counter()
                    img_batch = []
                    for i in range(batch_size):
                        idx = mirror_index(length, index + i)
                        face = face_list_cycle[idx]
                        img_batch.append(face)

                    # img_batch_1, mel_batch_1, frames, coords = datagen_signal(face_list_cycle1,
                    #                                                           mel_batch, face_det_results)

                    img_batch = np.asarray(img_batch)
                    mel_batch = np.asarray(mel_batch)

                    img_masked = img_batch.copy()
                    img_masked[:, face.shape[0] // 2:] = 0

                    img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
                    mel_batch = np.reshape(mel_batch,
                                           [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

                    img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
                    mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)

                    with torch.no_grad():
                        pred = model(mel_batch, img_batch)
                        # pred = model(mel_batch, img_batch) * 255.0

                    pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
                    count_time += (time.perf_counter() - t)
                    count += batch_size
                    # _totalframe += 1
                    if count >= 100:
                        print(f"------actual avg infer fps:{count / count_time:.4f}")
                        count = 0
                        count_time = 0
                    for i, res_frame in enumerate(pred):
                        # self.__pushmedia(res_frame,loop,audio_track,video_track)
                        # res_frame_queue.put(
                        #     (res_frame, __mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]))
                        self._human.push_res_frame(res_frame, mirror_index(length, index),
                                                   audio_frames[i * 2:i * 2 + 2])
                        index = index + 1
                    # print('total batch time:', time.perf_counter()-start_time)
            else:
                time.sleep(1)
        print('musereal inference processor stop')

    def push(self, mel_chunks):
        self._feat_queue.put(mel_chunks)

    def push_out_queue(self, frame, type_):
        self._audio_out_queue.put((frame, type_))

    def get_out_put(self):
        return self._audio_out_queue.get()
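
# A minimal launch sketch (illustration only). In the class above the queues come
# from the `human` object and its `chunk_2_mal` ASR; the batch size, avatar path
# and bare mp.Queue() objects below are assumptions, so treat this as a sketch of
# the process plumbing rather than a working demo.
if __name__ == '__main__':
    render_event = mp.Event()
    feat_queue = mp.Queue()       # would carry mel-spectrogram batches from the ASR
    audio_out_queue = mp.Queue()  # would carry (audio_frame, type_) tuples
    res_frame_queue = mp.Queue()  # receives (frame_or_None, frame_index, audio_frames)

    mp.Process(target=inference,
               args=(render_event, 16, './data/wav2lip_avatar1/face_imgs',
                     feat_queue, audio_out_queue, res_frame_queue)).start()
    render_event.set()            # the worker only renders while this event is set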