human/Human.py
2024-09-09 08:23:04 +08:00

196 lines
5.8 KiB
Python

#encoding = utf8
import logging
import multiprocessing as mp
import queue
import time
import numpy as np
from models import Wav2Lip
from tts.Chunk2Mal import Chunk2Mal
import torch
import cv2
from tqdm import tqdm
logger = logging.getLogger(__name__)

# Inference device, chosen once at import time: CUDA when torch can see a GPU,
# otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print('Using %s for inference.' % device)
def _load(checkpoint_path):
    """Deserialize the checkpoint file at *checkpoint_path* with ``torch.load``.

    When ``device`` is ``'cuda'`` the file is loaded with torch's default
    placement. Otherwise every storage is kept on the CPU via ``map_location``,
    so checkpoints saved from a GPU can be opened on a machine without one.
    """
    if device != 'cuda':
        # Identity map_location: keep each deserialized storage where it was
        # created (the CPU) instead of moving it back to its saved device.
        return torch.load(checkpoint_path,
                          map_location=lambda storage, loc: storage)
    return torch.load(checkpoint_path)
def load_model(path):
    """Build a Wav2Lip network and load the weights stored at *path*.

    The checkpoint must hold its weights under the ``"state_dict"`` key. A
    leading ``module.`` prefix on a key (presumably left by a DataParallel-style
    wrapper at save time) is stripped so the names match the bare model. Only
    the prefix is removed; ``module.`` appearing later in a key is preserved.

    Args:
        path: filesystem path of the checkpoint, passed to ``_load``.

    Returns:
        The model moved to ``device`` and switched to eval mode.
    """
    model = Wav2Lip()
    print("Load checkpoint from: {}".format(path))
    checkpoint = _load(path)
    prefix = 'module.'
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint["state_dict"].items()
    }
    model.load_state_dict(state_dict)
    return model.to(device).eval()
def read_images(img_list):
    """Read face images from disk with OpenCV.

    Args:
        img_list: an iterable of image file paths, read in iteration order.
            A single ``str`` or ``os.PathLike`` is also accepted. If it names a
            directory, every file in it whose name ends in .jpg/.jpeg/.png/.bmp
            (case-insensitive) is read in sorted filename order. Otherwise it is
            treated as the path of one image. A bare string is therefore never
            iterated character by character.

    Returns:
        list: the arrays returned by ``cv2.imread``, one per path, in order.

    Raises:
        ValueError: if ``cv2.imread`` cannot read one of the paths (OpenCV
            signals this by returning ``None`` rather than raising).
    """
    import os

    if isinstance(img_list, (str, os.PathLike)):
        source = os.fspath(img_list)
        if os.path.isdir(source):
            image_exts = ('.jpg', '.jpeg', '.png', '.bmp')
            img_list = [
                os.path.join(source, name)
                for name in sorted(os.listdir(source))
                if name.lower().endswith(image_exts)
            ]
        else:
            img_list = [source]
    frames = []
    print('reading images...')
    for img_path in tqdm(img_list):
        frame = cv2.imread(img_path)
        if frame is None:
            # Fail here with the offending path instead of storing None and
            # crashing later when the frame is stacked into a batch.
            raise ValueError(f'could not read image: {img_path}')
        frames.append(frame)
    return frames
def __mirror_index(size, index):
    """Fold an ever-growing *index* back and forth over ``range(size)``.

    Even passes over the range run forwards and odd passes run backwards. For
    example, with ``size == 3`` the indices 0..8 map to 0, 1, 2, 2, 1, 0, 0, 1, 2.
    Raises ZeroDivisionError when *size* is 0.
    """
    turn, offset = divmod(index, size)
    return size - offset - 1 if turn % 2 else offset
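# Example command line for the standalone Wav2Lip inference script
# (a single invocation, wrapped here):
#   python.exe .\inference.py --checkpoint_path .\checkpoints\wav2lip.pth
#       --face .\face\img00016.jpg --audio .\audio\audio1.wav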
def _wav2lip_predict(model, faces, mel_batch):
    """Run one Wav2Lip forward pass and return the predicted frames.

    Args:
        model: the network returned by ``load_model``.
        faces: list of H x W x C face images (as read by OpenCV), one per
            output frame. They are stacked with ``np.asarray``, so they are
            expected to share one shape.
        mel_batch: one mel chunk per face. After ``np.asarray`` it must be 3-D
            (N, A, B); it is reshaped to (N, 1, A, B) for the model.

    Returns:
        numpy.ndarray: the model output transposed from NCHW to NHWC and
        multiplied by 255.
    """
    image_batch = np.asarray(faces)
    mel_batch = np.asarray(mel_batch)
    # Zero the lower half of each face (rows from H // 2 down) and stack the
    # masked copy channel-wise in front of the untouched face.
    masked = image_batch.copy()
    masked[:, image_batch.shape[1] // 2:] = 0
    image_input = np.concatenate((masked, image_batch), axis=3) / 255.
    mel_input = mel_batch.reshape(len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1)
    image_tensor = torch.FloatTensor(np.transpose(image_input, (0, 3, 1, 2))).to(device)
    mel_tensor = torch.FloatTensor(np.transpose(mel_input, (0, 3, 1, 2))).to(device)
    with torch.no_grad():
        pred = model(mel_tensor, image_tensor)
    return pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.


def inference(render_event, batch_size, face_images_path, audio_feat_queue,
              audio_out_queue, res_frame_queue,
              checkpoint_path=r'.\checkpoints\wav2lip.pth'):
    """Lip-sync worker loop (started by ``Human`` in a separate process).

    While *render_event* is set, each iteration:
      1. takes one mel batch from *audio_feat_queue*, polling with a 1 s timeout
         so that a cleared event is noticed even when no audio arrives;
      2. takes ``batch_size * 2`` ``(audio_frame, audio_type)`` pairs from
         *audio_out_queue*, i.e. two audio frames per output video frame;
      3. if none of those pairs has ``audio_type == 0`` the batch counts as
         silence and no inference runs (the output frame is ``None``);
         otherwise Wav2Lip predicts one frame per selected face image;
      4. puts ``(frame_or_None, face_index, two_audio_frames)`` on
         *res_frame_queue* for each of the ``batch_size`` output frames.

    Face images are walked back and forth with ``__mirror_index`` so the
    sequence never jumps from the last image straight back to the first.

    Args:
        render_event: event-like object; the loop runs while it is set.
        batch_size: number of video frames produced per iteration.
        face_images_path: passed unchanged to ``read_images``.
        audio_feat_queue: source of mel batches.
        audio_out_queue: source of ``(audio_frame, audio_type)`` pairs.
        res_frame_queue: destination of the per-frame result tuples.
        checkpoint_path: Wav2Lip checkpoint loaded via ``load_model``.
    """
    model = load_model(checkpoint_path)
    face_list_cycle = read_images(face_images_path)
    length = len(face_list_cycle)
    logger.info(f'face images length: {length}')
    index = 0
    count = 0
    count_time = 0
    logger.info('start inference')
    while render_event.is_set():
        try:
            mel_batch = audio_feat_queue.get(block=True, timeout=1)
        except queue.Empty:
            continue
        # The audio frames matching this mel batch arrive on audio_out_queue.
        # Reading them from the feature queue would consume mel batches instead.
        audio_frames = []
        is_all_silence = True
        for _ in range(batch_size * 2):
            frame, audio_type = audio_out_queue.get()
            audio_frames.append((frame, audio_type))
            if audio_type == 0:
                is_all_silence = False

        if is_all_silence:
            for i in range(batch_size):
                res_frame_queue.put((None, __mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]))
                index += 1
            continue

        t = time.perf_counter()
        faces = [face_list_cycle[__mirror_index(length, index + i)] for i in range(batch_size)]
        pred = _wav2lip_predict(model, faces, mel_batch)
        count_time += time.perf_counter() - t
        count += batch_size
        if count >= 100:
            logger.info(f"------actual avg infer fps:{count/count_time:.4f}")
            count = 0
            count_time = 0
        for i, res_frame in enumerate(pred):
            res_frame_queue.put((res_frame, __mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]))
            index += 1
    logger.info('finish inference')
class Human:
    """Owns the inter-process queues and the Wav2Lip inference worker.

    Construction creates the queues and immediately starts ``inference`` in a
    child process. Text goes in through ``read`` (forwarded to the TTS engine
    installed with ``set_tts``), audio chunks through ``push_audio_chunk`` and
    mel features through ``push_feat_queue``. ``on_destroy`` stops everything.
    """

    def __init__(self):
        self._tts = None
        self._fps = 50  # 20 ms per frame
        self._batch_size = 16
        self._sample_rate = 16000
        # 320 samples per chunk (20 ms * 16000 / 1000)
        self._chunk = self._sample_rate // self._fps
        self._chunk_2_mal = Chunk2Mal(self)
        self._stride_left_size = 10
        self._stride_right_size = 10
        # Handed to inference() as audio_feat_queue; filled by push_feat_queue().
        self._feat_queue = mp.Queue(2)
        # Handed to inference() as audio_out_queue.
        self._output_queue = mp.Queue()
        # Handed to inference() as res_frame_queue.
        self._res_frame_queue = mp.Queue(self._batch_size * 2)
        self.face_images_path = r'.\face'
        self.render_event = mp.Event()
        # inference() loops only while this event is set, so it must be set
        # before the worker starts; on_destroy() clears it to end the loop.
        self.render_event.set()
        # daemon=True: the worker now runs until render_event is cleared, and a
        # non-daemon child would be joined (and block) at interpreter exit if
        # on_destroy() is never called.
        self._infer_process = mp.Process(
            target=inference,
            args=(self.render_event, self._batch_size, self.face_images_path,
                  self._feat_queue, self._output_queue, self._res_frame_queue),
            daemon=True,
        )
        self._infer_process.start()

    def get_fps(self):
        """Return the video frame rate (frames per second)."""
        return self._fps

    def get_batch_size(self):
        """Return the number of frames the worker produces per batch."""
        return self._batch_size

    def get_chunk(self):
        """Return the number of audio samples per frame-sized chunk."""
        return self._chunk

    def get_stride_left_size(self):
        """Return the configured left stride size."""
        return self._stride_left_size

    def get_stride_right_size(self):
        """Return the configured right stride size."""
        return self._stride_right_size

    def on_destroy(self):
        """Stop the inference loop, the Chunk2Mal stage and the TTS (if any)."""
        # Clearing (not setting) the event is what ends inference()'s loop.
        self.render_event.clear()
        self._chunk_2_mal.stop()
        if self._tts is not None:
            self._tts.stop()
        logger.info('human destroy')

    def set_tts(self, tts):
        """Install *tts* and start it together with the Chunk2Mal stage.

        Re-installing the same object is a no-op.
        """
        # NOTE(review): a previously installed TTS is not stopped here, and
        # Chunk2Mal.start() runs on every change of engine; confirm both are
        # safe to repeat.
        if self._tts == tts:
            return
        self._tts = tts
        self._tts.start()
        self._chunk_2_mal.start()

    def read(self, txt):
        """Queue *txt* for synthesis; log a warning if no TTS is installed."""
        if self._tts is None:
            logger.warning('tts is none')
            return
        self._tts.push_txt(txt)

    def push_audio_chunk(self, chunk):
        """Forward one audio chunk to the Chunk2Mal stage."""
        self._chunk_2_mal.push_chunk(chunk)

    def push_feat_queue(self, mel_chunks):
        """Enqueue one batch of mel features for the inference worker.

        Blocks while the (size-2) feature queue is full.
        """
        self._feat_queue.put(mel_chunks)