添加chunk处理

This commit is contained in:
jiegeaiai 2024-09-09 08:23:04 +08:00
parent cf9fc3545d
commit c9f8ff6541
3 changed files with 140 additions and 14 deletions

132
Human.py
View File

@ -1,11 +1,131 @@
#encoding = utf8
import logging
import multiprocessing as mp
import queue
import time
import numpy as np
from models import Wav2Lip
from tts.Chunk2Mal import Chunk2Mal
import torch
import cv2
from tqdm import tqdm
logger = logging.getLogger(__name__)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using {} for inference.'.format(device))
def _load(checkpoint_path):
if device == 'cuda':
checkpoint = torch.load(checkpoint_path)
else:
checkpoint = torch.load(checkpoint_path,
map_location=lambda storage, loc: storage)
return checkpoint
def load_model(path):
model = Wav2Lip()
print("Load checkpoint from: {}".format(path))
checkpoint = _load(path)
s = checkpoint["state_dict"]
new_s = {}
for k, v in s.items():
new_s[k.replace('module.', '')] = v
model.load_state_dict(new_s)
model = model.to(device)
return model.eval()
def read_images(img_list):
frames = []
print('reading images...')
for img_path in tqdm(img_list):
frame = cv2.imread(img_path)
frames.append(frame)
return frames
def __mirror_index(size, index):
# size = len(self.coord_list_cycle)
turn = index // size
res = index % size
if turn % 2 == 0:
return res
else:
return size - res - 1
# python.exe .\inference.py --checkpoint_path .\checkpoints\wav2lip.pth --face
# .\face\img00016.jpg --audio .\audio\audio1.wav
def inference(render_event, batch_size, face_images_path, audio_feat_queue, audio_out_queue, res_frame_queue):
model = load_model(r'.\checkpoints\wav2lip.pth')
face_list_cycle = read_images(face_images_path)
face_images_length = len(face_list_cycle)
logger.info(f'face images length: {face_images_length}')
length = len(face_list_cycle)
index = 0
count = 0
count_time = 0
logger.info('start inference')
while render_event.is_set():
try:
mel_batch = audio_feat_queue.get(block=True, timeout=1)
except queue.Empty:
continue
audio_frames = []
is_all_silence = True
for _ in range(batch_size * 2):
frame, type = audio_feat_queue.get()
audio_frames.append((frame, type))
if type == 0:
is_all_silence = False
if is_all_silence:
for i in range(batch_size):
res_frame_queue.put((None, __mirror_index(length, index), audio_frames[i*2:i*2+2]))
index = index + 1
else:
t = time.perf_counter()
image_batch = []
for i in range(batch_size):
idx = __mirror_index(length, index + i)
face = face_list_cycle[idx]
image_batch.append(face)
image_batch, mel_batch = np.asarray(image_batch), np.asarray(mel_batch)
image_masked = image_batch.copy()
image_masked[:, face.shape[0]//2:] = 0
image_batch = np.concatenate((image_masked, image_batch), axis=3) / 255.
mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
image_batch = torch.FloatTensor(np.transpose(image_batch, (0, 3, 1, 2))).to(device)
mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
with torch.no_grad():
pred = model(mel_batch, image_batch)
pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
count_time += (time.perf_counter() - t)
count += batch_size
if count >= 100:
logger.info(f"------actual avg infer fps:{count/count_time:.4f}")
count = 0
count_time = 0
for i, res_frame in enumerate(pred):
res_frame_queue.put((res_frame, __mirror_index(length, index), audio_frames[i*2 : i*2+2]))
index = index + 1
logger.info('finish inference')
class Human:
def __init__(self):
@ -18,6 +138,14 @@ class Human:
self._stride_left_size = 10
self._stride_right_size = 10
self._feat_queue = mp.Queue(2)
self._output_queue = mp.Queue()
self._res_frame_queue = mp.Queue(self._batch_size * 2)
self.face_images_path = r'.\face'
self.render_event = mp.Event()
mp.Process(target=inference, args=(self.render_event, self._batch_size, self.face_images_path,
self._feat_queue, self._output_queue, self._res_frame_queue,
)).start()
def get_fps(self):
return self._fps
@ -35,6 +163,8 @@ class Human:
return self._stride_right_size
def on_destroy(self):
self.render_event.set()
self._chunk_2_mal.stop()
if self._tts is not None:
@ -60,4 +190,6 @@ class Human:
self._chunk_2_mal.push_chunk(chunk)
def push_feat_queue(self, mel_chunks):
print("21")
self._feat_queue.put(mel_chunks)
print("22")

View File

@ -22,6 +22,7 @@ class Chunk2Mal:
try:
chunk, type = self.pull_chunk()
self._chunks.append(chunk)
print("1")
except queue.Empty:
continue
@ -38,6 +39,7 @@ class Chunk2Mal:
mel_chunks = []
while i < (len(self._chunks) - self._human.get_stride_left_size()
- self._human.get_stride_right_size()) / 2:
print("14")
start_idx = int(left + i * mel_idx_multiplier)
# print(start_idx)
if start_idx + mel_step_size > len(mel[0]):
@ -45,10 +47,13 @@ class Chunk2Mal:
else:
mel_chunks.append(mel[:, start_idx: start_idx + mel_step_size])
i += 1
print("13")
self._human.push_feat_queue(mel_chunks)
print("15")
# discard the old part to save memory
self._chunks = self._chunks[-(self._human.get_stride_left_size() + self._human.get_stride_right_size()):]
print("12")
logging.info('chunk2mal exit')
@ -65,7 +70,8 @@ class Chunk2Mal:
return
self._exit_event.set()
self._thread.join()
if self._thread.is_alive():
self._thread.join()
logging.info('chunk2mal stop')
def push_chunk(self, chunk):
@ -73,7 +79,7 @@ class Chunk2Mal:
def pull_chunk(self):
try:
chunk = self._audio_chunk_queue.get(block=True, timeout=1.0)
chunk = self._audio_chunk_queue.get(block=True, timeout=1)
type = 1
except queue.Empty:
chunk = np.zeros(self._human.get_chunk(), dtype=np.float32)

12
ui.py
View File

@ -120,18 +120,6 @@ def config_logging(file_name: str, console_level: int=logging.INFO, file_level:
if __name__ == "__main__":
# logging.basicConfig(filename='./logs/info.log', level=logging.INFO)
config_logging('./logs/info.log', logging.INFO, logging.INFO)
# logger = logging.getLogger('manager')
# # 输出到控制台, 级别为DEBUG
# console = logging.StreamHandler()
# console.setLevel(logging.DEBUG)
# logger.addHandler(console)
#
# # 输出到文件, 级别为INFO, 文件按大小切分
# filelog = logging.handlers.RotatingFileHandler(filename='./logs/info.log', level=logging.INFO,
# maxBytes=1024 * 1024, backupCount=5)
# filelog.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
# logger.setLevel(logging.INFO)
# logger.addHandler(filelog)
logger.info('------------start------------')
app = App()
app.mainloop()