# human/Human.py
# coding: utf-8
import logging
import multiprocessing as mp
import queue
import time

import cv2
import numpy as np
import torch
from tqdm import tqdm

import utils
from models import Wav2Lip
from tts.Chunk2Mal import Chunk2Mal

device = 'cuda' if torch.cuda.is_available() else 'cpu'


def _load(checkpoint_path):
    # Load the checkpoint on the GPU when available; otherwise map the
    # saved tensors onto the CPU.
    if device == 'cuda':
        checkpoint = torch.load(checkpoint_path)
    else:
        checkpoint = torch.load(checkpoint_path,
                                map_location=lambda storage, loc: storage)
    return checkpoint


def load_model(path):
    model = Wav2Lip()
    print("Load checkpoint from: {}".format(path))
    logging.info(f'Load checkpoint from {path}')
    checkpoint = _load(path)
    s = checkpoint["state_dict"]
    # Checkpoints trained with DataParallel prefix every key with 'module.';
    # strip the prefix so the keys match a bare Wav2Lip module.
    new_s = {}
    for k, v in s.items():
        new_s[k.replace('module.', '')] = v
    model.load_state_dict(new_s)
    model = model.to(device)
    return model.eval()


def read_images(img_list):
    frames = []
    print('reading images...')
    for img_path in tqdm(img_list):
        print(f'read image path: {img_path}')
        frame = cv2.imread(img_path)
        frames.append(frame)
    return frames


def __mirror_index(size, index):
    # Fold a monotonically increasing index into a forward/backward
    # "ping-pong" sweep over `size` frames, so the looping face video
    # reverses direction instead of jumping back to the first frame.
    turn = index // size
    res = index % size
    if turn % 2 == 0:
        return res
    else:
        return size - res - 1
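
# Illustrative example (not from the original source): with size=3, the
# successive indices 0..8 map to 0, 1, 2, 2, 1, 0, 0, 1, 2.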
# python.exe .\inference.py --checkpoint_path .\checkpoints\wav2lip.pth --face
# .\face\img00016.jpg --audio .\audio\audio1.wav


def inference(render_event, batch_size, face_images_path, audio_feat_queue, audio_out_queue, res_frame_queue):
    logging.info(f'Using {device} for inference.')
    print(f'Using {device} for inference.')
    print(f'face_images_path: {face_images_path}')

    model = load_model(r'.\checkpoints\wav2lip.pth')
    face_list_cycle = read_images(face_images_path)
    face_images_length = len(face_list_cycle)
    logging.info(f'face images length: {face_images_length}')
    print(f'face images length: {face_images_length}')

    length = len(face_list_cycle)
    index = 0
    count = 0
    count_time = 0
    logging.info('start inference')
    print(f'start inference: {render_event.is_set()}')

    while render_event.is_set():
        try:
            mel_batch = audio_feat_queue.get(block=True, timeout=1)
        except queue.Empty:
            continue

        # Collect the raw audio paired with this mel batch: two 20 ms audio
        # frames per video frame. A frame type of 0 marks speech; any other
        # type counts as silence.
        audio_frames = []
        is_all_silence = True
        for _ in range(batch_size * 2):
            frame, frame_type = audio_out_queue.get()
            audio_frames.append((frame, frame_type))
            if frame_type == 0:
                is_all_silence = False
        print(f'is_all_silence {is_all_silence}')

        if is_all_silence:
            # Nothing to lip-sync: skip the model and emit the cycled face
            # frames unchanged (None stands for "use the original face").
            for i in range(batch_size):
                res_frame_queue.put((None, __mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]))
                index = index + 1
        else:
            t = time.perf_counter()
            image_batch = []
            for i in range(batch_size):
                idx = __mirror_index(length, index + i)
                face = face_list_cycle[idx]
                image_batch.append(face)
            image_batch, mel_batch = np.asarray(image_batch), np.asarray(mel_batch)

            # Wav2Lip expects the reference image concatenated channel-wise
            # with a copy whose lower half is masked out.
            image_masked = image_batch.copy()
            image_masked[:, face.shape[0] // 2:] = 0
            image_batch = np.concatenate((image_masked, image_batch), axis=3) / 255.
            mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

            # NHWC -> NCHW for PyTorch.
            image_batch = torch.FloatTensor(np.transpose(image_batch, (0, 3, 1, 2))).to(device)
            mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)

            with torch.no_grad():
                pred = model(mel_batch, image_batch)
            pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.

            count_time += (time.perf_counter() - t)
            count += batch_size
            if count >= 100:
                logging.info(f"------actual avg infer fps:{count / count_time:.4f}")
                count = 0
                count_time = 0

            for i, res_frame in enumerate(pred):
                res_frame_queue.put((res_frame, __mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]))
                index = index + 1

    logging.info('finish inference')
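
# Each item on res_frame_queue is a 3-tuple:
#   (predicted_frame_or_None, face_index, [(audio_frame, frame_type), ...])
# The frame is None when the whole batch was silence; face_index selects the
# source face via the ping-pong mirror index.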


class Human:
    def __init__(self):
        self._tts = None
        self._fps = 50  # 20 ms per frame
        self._batch_size = 16
        self._sample_rate = 16000
        self._chunk = self._sample_rate // self._fps  # 320 samples per chunk (20 ms * 16000 / 1000)
        self._chunk_2_mal = Chunk2Mal(self)
        self._stride_left_size = 10
        self._stride_right_size = 10
        self._feat_queue = mp.Queue(2)
        self._output_queue = mp.Queue()
        self._res_frame_queue = mp.Queue(self._batch_size * 2)

        face_images_path = r'./face/'
        self._face_image_paths = utils.read_files_path(face_images_path)
        print(self._face_image_paths)

        # The inference worker loops while this event is set, so set it before
        # the process starts; on_destroy() clears it to stop the loop.
        self.render_event = mp.Event()
        self.render_event.set()
        mp.Process(target=inference, args=(self.render_event, self._batch_size, self._face_image_paths,
                                           self._feat_queue, self._output_queue, self._res_frame_queue,
                                           )).start()
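
        # Data flow (as wired above): audio pushed via push_audio_chunk() goes
        # to Chunk2Mal, which is expected to feed mel features back through
        # push_feat_queue() into _feat_queue; the inference worker then emits
        # result frames on _res_frame_queue for render() to drain.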

    def get_fps(self):
        return self._fps

    def get_batch_size(self):
        return self._batch_size

    def get_chunk(self):
        return self._chunk

    def get_stride_left_size(self):
        return self._stride_left_size

    def get_stride_right_size(self):
        return self._stride_right_size

    def on_destroy(self):
        # Clearing the event ends the `while render_event.is_set()` loop in
        # the inference worker.
        self.render_event.clear()
        self._chunk_2_mal.stop()
        if self._tts is not None:
            self._tts.stop()
        logging.info('human destroy')

    def set_tts(self, tts):
        if self._tts == tts:
            return

        self._tts = tts
        self._tts.start()
        self._chunk_2_mal.start()

    def read(self, txt):
        if self._tts is None:
            logging.warning('tts is none')
            return

        self._tts.push_txt(txt)

    def push_audio_chunk(self, chunk):
        self._chunk_2_mal.push_chunk(chunk)

    def push_feat_queue(self, mel_chunks):
        logging.debug('push_feat_queue: putting mel chunks')
        self._feat_queue.put(mel_chunks)
        logging.debug('push_feat_queue: done')

    def render(self):
        # res_frame_queue items are (frame_or_None, face_index, audio_frames);
        # only the frame is surfaced here.
        try:
            res_frame, _index, _audio_frames = self._res_frame_queue.get(block=True, timeout=.3)
        except queue.Empty:
            return None
        return res_frame

    # def pull_audio_chunk(self):
    #     try:
    #         chunk = self._audio_chunk_queue.get(block=True, timeout=1.0)
    #         type = 1
    #     except queue.Empty:
    #         chunk = np.zeros(self._chunk, dtype=np.float32)
    #         type = 0
    #     return chunk, type
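

# Minimal usage sketch (illustrative only; `SomeTts` stands in for any TTS
# object exposing the start()/stop()/push_txt() interface used above):
#
#   human = Human()
#   human.set_tts(SomeTts())
#   human.read('hello')
#   while running:
#       frame = human.render()
#       if frame is not None:
#           cv2.imshow('human', frame.astype(np.uint8))
#           cv2.waitKey(1)
#   human.on_destroy()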