# human/Human.py
# -*- coding: utf-8 -*-
import logging
import multiprocessing as mp
import queue
import time
import numpy as np
import audio
import face_detection
import utils
from models import Wav2Lip
from tts.Chunk2Mal import Chunk2Mal
import torch
import cv2
from tqdm import tqdm
from queue import Queue
from tts.EdgeTTS import EdgeTTS
from tts.TTSBase import TTSBase
device = 'cuda' if torch.cuda.is_available() else 'cpu'

def _load(checkpoint_path):
    # Load the Wav2Lip checkpoint, mapping tensors to CPU when CUDA is unavailable.
    if device == 'cuda':
        checkpoint = torch.load(checkpoint_path)
    else:
        checkpoint = torch.load(checkpoint_path,
                                map_location=lambda storage, loc: storage)
    return checkpoint

def load_model(path):
    model = Wav2Lip()
    print("Load checkpoint from: {}".format(path))
    logging.info(f'Load checkpoint from {path}')
    checkpoint = _load(path)
    s = checkpoint["state_dict"]
    new_s = {}
    for k, v in s.items():
        # Strip the 'module.' prefix left over from DataParallel training.
        new_s[k.replace('module.', '')] = v
    model.load_state_dict(new_s)
    model = model.to(device)
    return model.eval()
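# Example call (the same checkpoint path that inference() uses below):
#   model = load_model(r'.\checkpoints\wav2lip.pth')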

def read_images(img_list):
    frames = []
    print('reading images...')
    for img_path in tqdm(img_list):
        print(f'read image path:{img_path}')
        frame = cv2.imread(img_path)
        frames.append(frame)
    return frames

def __mirror_index(size, index):
    # size = len(self.coord_list_cycle)
    # Ping-pong over [0, size): walk forward on even passes and backward on odd ones.
    turn = index // size
    res = index % size
    if turn % 2 == 0:
        return res
    else:
        return size - res - 1
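# Example: with size == 3 the returned indices follow 0, 1, 2, 2, 1, 0, 0, 1, 2, ...
# so a short face image list is played forward and backward instead of looping with a jump.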

# python.exe .\inference.py --checkpoint_path .\checkpoints\wav2lip.pth --face
# .\face\img00016.jpg --audio .\audio\audio1.wav
def inference(render_event, batch_size, face_images_path, audio_feat_queue, audio_out_queue, res_frame_queue):
    logging.info(f'Using {device} for inference.')
    print(f'Using {device} for inference.')
    print(f'face_images_path: {face_images_path}')

    model = load_model(r'.\checkpoints\wav2lip.pth')
    face_list_cycle = read_images(face_images_path)
    face_images_length = len(face_list_cycle)
    logging.info(f'face images length: {face_images_length}')
    print(f'face images length: {face_images_length}')

    length = len(face_list_cycle)
    index = 0
    count = 0
    count_time = 0
    logging.info('start inference')
    print(f'start inference: {render_event.is_set()}')

    while render_event.is_set():
        # Wait for the next batch of mel chunks from the audio feature queue.
        try:
            mel_batch = audio_feat_queue.get(block=True, timeout=1)
        except queue.Empty:
            continue

        # Each video frame pairs with two audio chunks, hence batch_size * 2 reads.
        audio_frames = []
        is_all_silence = True
        for _ in range(batch_size * 2):
            frame, type_ = audio_out_queue.get()
            audio_frames.append((frame, type_))
            if type_ == 0:
                is_all_silence = False
        print(f'is_all_silence {is_all_silence}')

        if is_all_silence:
            # No speech in this batch: pass the original face frames through without running the model.
            for i in range(batch_size):
                res_frame_queue.put((None, __mirror_index(length, index), audio_frames[i*2:i*2+2]))
                index = index + 1
        else:
            t = time.perf_counter()
            image_batch = []
            for i in range(batch_size):
                idx = __mirror_index(length, index + i)
                face = face_list_cycle[idx]
                image_batch.append(face)
            image_batch, mel_batch = np.asarray(image_batch), np.asarray(mel_batch)

            # Mask the lower half of each face and stack it with the reference image (6 input channels).
            image_masked = image_batch.copy()
            image_masked[:, face.shape[0]//2:] = 0
            image_batch = np.concatenate((image_masked, image_batch), axis=3) / 255.
            mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

            image_batch = torch.FloatTensor(np.transpose(image_batch, (0, 3, 1, 2))).to(device)
            mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)

            with torch.no_grad():
                pred = model(mel_batch, image_batch)
            pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.

            count_time += (time.perf_counter() - t)
            count += batch_size
            if count >= 100:
                logging.info(f"------actual avg infer fps:{count/count_time:.4f}")
                count = 0
                count_time = 0

            for i, res_frame in enumerate(pred):
                res_frame_queue.put((res_frame, __mirror_index(length, index), audio_frames[i*2: i*2+2]))
                index = index + 1
    logging.info('finish inference')
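# Sketch of how inference() is meant to be driven (it mirrors the commented-out wiring in
# Human.__init__ below; the variable names here are placeholders for the attributes created there):
#
#   render_event = mp.Event()
#   mp.Process(target=inference,
#              args=(render_event, batch_size, face_image_paths,
#                    feat_queue, output_queue, res_frame_queue)).start()
#   render_event.set()      # setting the event lets the while-loop run
#   ...
#   render_event.clear()    # clearing it lets the loop exit and the process finish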

def get_smoothened_boxes(boxes, T):
    for i in range(len(boxes)):
        if i + T > len(boxes):
            window = boxes[len(boxes) - T:]
        else:
            window = boxes[i: i + T]
        boxes[i] = np.mean(window, axis=0)
    return boxes
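# Note: box i is replaced by the mean of boxes[i:i+T]; near the end of the list the window
# is clamped to the last T boxes, so the smoothing never reads past the array.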

def face_detect(images):
    detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
                                            flip_input=False, device=device)
    batch_size = 16

    while 1:
        predictions = []
        try:
            for i in tqdm(range(0, len(images), batch_size)):
                predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
        except RuntimeError:
            if batch_size == 1:
                raise RuntimeError(
                    'Image too big to run face detection on GPU. Please use the --resize_factor argument')
            # Halve the batch size and retry after an out-of-memory error.
            batch_size //= 2
            print('Recovering from OOM error; New batch size: {}'.format(batch_size))
            continue
        break

    results = []
    pady1, pady2, padx1, padx2 = [0, 10, 0, 0]
    for rect, image in zip(predictions, images):
        if rect is None:
            cv2.imwrite('temp/faulty_frame.jpg', image)  # check this frame where the face was not detected.
            raise ValueError('Face not detected! Ensure the video contains a face in all the frames.')

        y1 = max(0, rect[1] - pady1)
        y2 = min(image.shape[0], rect[3] + pady2)
        x1 = max(0, rect[0] - padx1)
        x2 = min(image.shape[1], rect[2] + padx2)
        results.append([x1, y1, x2, y2])

    boxes = np.array(results)
    # Always smooth the detected boxes (the original `if not False:` guard was a constant).
    boxes = get_smoothened_boxes(boxes, T=5)
    results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)]

    del detector
    return results
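# face_detect() returns one entry per input image: [cropped_face, (y1, y2, x1, x2)],
# where the crop is taken from the smoothed, padded bounding box.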

img_size = 96
wav2lip_batch_size = 128

def datagen(frames, mels):
    img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
    face_det_results = face_detect(frames)  # BGR2RGB for CNN face detection

    # for i, m in enumerate(mels):
    for i in range(mels.qsize()):
        idx = 0  # was `0 if True else i % len(frames)`, which always picks the first detected face
        frame_to_save = frames[__mirror_index(1, i)].copy()
        face, coords = face_det_results[idx].copy()
        face = cv2.resize(face, (img_size, img_size))

        m = mels.get()
        img_batch.append(face)
        mel_batch.append(m)
        frame_batch.append(frame_to_save)
        coords_batch.append(coords)

        if len(img_batch) >= wav2lip_batch_size:
            img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)

            img_masked = img_batch.copy()
            img_masked[:, img_size//2:] = 0
            img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
            mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

            yield img_batch, mel_batch, frame_batch, coords_batch
            img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []

    # Flush whatever is left over as a final, smaller batch.
    if len(img_batch) > 0:
        img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)

        img_masked = img_batch.copy()
        img_masked[:, img_size//2:] = 0
        img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
        mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

        yield img_batch, mel_batch, frame_batch, coords_batch
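# Each yield from datagen() is (img_batch, mel_batch, frame_batch, coords_batch):
#   img_batch  - float array of shape (N, 96, 96, 6): the masked face stacked on the reference face, scaled to [0, 1]
#   mel_batch  - float array of shape (N, mel_bins, 16, 1), one mel chunk per face (typically 80 mel bins in the Wav2Lip setup)
#   frame_batch, coords_batch - the original frames and crop boxes for pasting the predictions back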

class Human:
    def __init__(self):
        self._fps = 25  # 25 fps -> 40 ms per video frame
        self._batch_size = 16
        self._sample_rate = 16000
        self._stride_left_size = 10
        self._stride_right_size = 10
        self._feat_queue = mp.Queue(2)
        self._output_queue = mp.Queue()
        self._res_frame_queue = mp.Queue(self._batch_size * 2)

        # self._chunk_2_mal = Chunk2Mal(self)
        # self._tts = TTSBase(self)
        self._chunk_2_mal = None  # set when the Chunk2Mal wiring above is re-enabled
        self._tts = None          # read() checks for None; set when the TTSBase wiring above is re-enabled
        self.mel_chunks_queue_ = Queue()
        self.test()

        # face_images_path = r'./face/'
        # self._face_image_paths = utils.read_files_path(face_images_path)
        # print(self._face_image_paths)
        # self.render_event = mp.Event()
        # mp.Process(target=inference, args=(self.render_event, self._batch_size, self._face_image_paths,
        #                                    self._feat_queue, self._output_queue, self._res_frame_queue,
        #                                    )).start()
        # self.render_event.set()

    def test(self):
        wav = audio.load_wav(r'./audio/audio1.wav', 16000)
        mel = audio.melspectrogram(wav)
        if np.isnan(mel.reshape(-1)).sum() > 0:
            raise ValueError(
                'Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again')

        # Slice the mel spectrogram into 16-frame chunks, one chunk per video frame.
        mel_step_size = 16
        print('fps:', self._fps)
        mel_idx_multiplier = 80. / self._fps
        print('mel_idx_multiplier:', mel_idx_multiplier)
        i = 0
        while 1:
            start_idx = int(i * mel_idx_multiplier)
            if start_idx + mel_step_size > len(mel[0]):
                # mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:])
                self.mel_chunks_queue_.put(mel[:, len(mel[0]) - mel_step_size:])
                break
            # mel_chunks.append(mel[:, start_idx: start_idx + mel_step_size])
            self.mel_chunks_queue_.put(mel[:, start_idx: start_idx + mel_step_size])
            i += 1

        batch_size = 128
        print('batch_size:', batch_size, ' mel_chunks len:', self.mel_chunks_queue_.qsize())

        face_images_path = r'./face/'
        face_images_path = utils.read_files_path(face_images_path)
        face_list_cycle = read_images(face_images_path)
        face_images_length = len(face_list_cycle)
        logging.info(f'face images length: {face_images_length}')
        print(f'face images length: {face_images_length}')

        gen = datagen(face_list_cycle, self.mel_chunks_queue_)  # generator is created here but not consumed yet

    def get_fps(self):
        return self._fps

    def get_batch_size(self):
        return self._batch_size

    def get_audio_sample_rate(self):
        return self._sample_rate

    def get_stride_left_size(self):
        return self._stride_left_size

    def get_stride_right_size(self):
        return self._stride_right_size

    def on_destroy(self):
        # self.render_event.clear()
        # self._chunk_2_mal.stop()
        # if self._tts is not None:
        #     self._tts.stop()
        logging.info('human destroy')

    def read(self, txt):
        if self._tts is None:
            logging.warning('tts is none')
            return
        self._tts.push_txt(txt)

    def push_audio_chunk(self, audio_chunk):
        self._chunk_2_mal.push_chunk(audio_chunk)

    def push_feat_queue(self, mel_chunks):
        print("push_feat_queue")
        self._feat_queue.put(mel_chunks)

    def push_audio_frames(self, chunk, type_):
        print("push_audio_frames")
        self._output_queue.put((chunk, type_))

    def render(self):
        try:
            # inference() puts (frame_or_None, face_index, audio_frames) tuples on this queue.
            img, idx, aud = self._res_frame_queue.get(block=True, timeout=.3)
        except queue.Empty:
            # print('render queue.Empty:')
            return None
        return img

    # def pull_audio_chunk(self):
    #     try:
    #         chunk = self._audio_chunk_queue.get(block=True, timeout=1.0)
    #         type = 1
    #     except queue.Empty:
    #         chunk = np.zeros(self._chunk, dtype=np.float32)
    #         type = 0
    #     return chunk, type
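

# Minimal entry-point sketch; it assumes the demo assets that test() already hard-codes are present
# (the ./face/ image folder and ./audio/audio1.wav), since Human() runs test() from __init__.
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    human = Human()
    human.on_destroy()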