Compare commits

3 Commits

SHA1        Message             Date
b0a600c7b7  merge override      2024-09-09 08:30:15 +08:00
c9f8ff6541  Add chunk handling  2024-09-09 08:23:04 +08:00
cf9fc3545d  Add chunk handling  2024-09-06 08:30:11 +08:00
4 changed files with 246 additions and 26 deletions

Human.py

@@ -1,23 +1,171 @@
#encoding = utf8
import glob
import logging
import multiprocessing as mp
import os
import queue
import time

import numpy as np
from models import Wav2Lip
from tts.Chunk2Mal import Chunk2Mal
import torch
import cv2
from tqdm import tqdm

logger = logging.getLogger(__name__)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using {} for inference.'.format(device))


def _load(checkpoint_path):
    if device == 'cuda':
        checkpoint = torch.load(checkpoint_path)
    else:
        checkpoint = torch.load(checkpoint_path,
                                map_location=lambda storage, loc: storage)
    return checkpoint


def load_model(path):
    model = Wav2Lip()
    print("Load checkpoint from: {}".format(path))
    checkpoint = _load(path)
    s = checkpoint["state_dict"]
    new_s = {}
    for k, v in s.items():
        # strip the 'module.' prefix left over from DataParallel training
        new_s[k.replace('module.', '')] = v
    model.load_state_dict(new_s)
    model = model.to(device)
    return model.eval()


def read_images(img_list):
    # expects a list of image file paths, not a directory
    frames = []
    print('reading images...')
    for img_path in tqdm(img_list):
        frame = cv2.imread(img_path)
        frames.append(frame)
    return frames


def __mirror_index(size, index):
    # Ping-pong through the frame list: walk forward on even passes and
    # backward on odd passes, so the looping video never jumps at the ends.
    turn = index // size
    res = index % size
    if turn % 2 == 0:
        return res
    else:
        return size - res - 1
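
For reference, a standalone sketch of how the ping-pong indexing behaves (a hypothetical demo, not part of this diff): on even passes through the frame list it walks forward, on odd passes backward.

# Hypothetical demo; mirror_index is a copy of __mirror_index above.
def mirror_index(size, index):
    turn, res = divmod(index, size)
    return res if turn % 2 == 0 else size - res - 1

print([mirror_index(3, i) for i in range(9)])  # [0, 1, 2, 2, 1, 0, 0, 1, 2]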
# python.exe .\inference.py --checkpoint_path .\checkpoints\wav2lip.pth --face
# .\face\img00016.jpg --audio .\audio\audio1.wav
def inference(render_event, batch_size, face_images_path, audio_feat_queue, audio_out_queue, res_frame_queue):
    model = load_model(r'.\checkpoints\wav2lip.pth')
    # face_images_path is a directory; gather the frame images inside it
    # (assumes .jpg frames, as in the sample command above)
    face_list = sorted(glob.glob(os.path.join(face_images_path, '*.jpg')))
    face_list_cycle = read_images(face_list)
    face_images_length = len(face_list_cycle)
    logger.info(f'face images length: {face_images_length}')

    length = len(face_list_cycle)
    index = 0
    count = 0
    count_time = 0
    logger.info('start inference')
    while render_event.is_set():
        try:
            mel_batch = audio_feat_queue.get(block=True, timeout=1)
        except queue.Empty:
            continue

        audio_frames = []
        is_all_silence = True
        for _ in range(batch_size * 2):
            # audio frames come from audio_out_queue; type 1 marks real
            # audio, type 0 a silence filler (see pull_chunk below)
            frame, frame_type = audio_out_queue.get()
            audio_frames.append((frame, frame_type))
            if frame_type == 1:
                is_all_silence = False

        if is_all_silence:
            # nothing to lip-sync: pass the original face frame through
            for i in range(batch_size):
                res_frame_queue.put((None, __mirror_index(length, index), audio_frames[i*2:i*2+2]))
                index = index + 1
        else:
            t = time.perf_counter()
            image_batch = []
            for i in range(batch_size):
                idx = __mirror_index(length, index + i)
                face = face_list_cycle[idx]
                image_batch.append(face)
            image_batch, mel_batch = np.asarray(image_batch), np.asarray(mel_batch)

            # Wav2Lip input: each reference frame plus a copy with the lower
            # (mouth) half zeroed out, stacked along the channel axis
            image_masked = image_batch.copy()
            image_masked[:, face.shape[0]//2:] = 0
            image_batch = np.concatenate((image_masked, image_batch), axis=3) / 255.
            mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

            image_batch = torch.FloatTensor(np.transpose(image_batch, (0, 3, 1, 2))).to(device)
            mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)

            with torch.no_grad():
                pred = model(mel_batch, image_batch)
            pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.

            count_time += (time.perf_counter() - t)
            count += batch_size
            if count >= 100:
                logger.info(f"------actual avg infer fps:{count/count_time:.4f}")
                count = 0
                count_time = 0

            for i, res_frame in enumerate(pred):
                res_frame_queue.put((res_frame, __mirror_index(length, index), audio_frames[i*2:i*2+2]))
                index = index + 1
    logger.info('finish inference')
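
The masking step above is the standard Wav2Lip input format: the network sees the reference face plus a copy with the lower half zeroed, stacked on the channel axis. A minimal numpy sketch of the resulting shapes (the 96x96 size is assumed here for illustration; real frames keep their own size):

import numpy as np

batch = np.zeros((16, 96, 96, 3), dtype=np.float32)   # 16 BGR face frames (assumed size)
masked = batch.copy()
masked[:, 96 // 2:] = 0                                # zero the mouth half
x = np.concatenate((masked, batch), axis=3) / 255.
print(x.shape)                                         # (16, 96, 96, 6): 6-channel input
x = np.transpose(x, (0, 3, 1, 2))
print(x.shape)                                         # (16, 6, 96, 96): NCHW for PyTorch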
class Human:
    def __init__(self):
        self._tts = None
-        self._fps = 25
+        self._fps = 50  # 20 ms per frame
        self._batch_size = 16
        self._sample_rate = 16000
        self._chunk = self._sample_rate // self._fps  # 320 samples per chunk (20 ms * 16000 / 1000)
-        self._chunk_2_mal = Chunk2Mal()
+        self._chunk_2_mal = Chunk2Mal(self)
        self._stride_left_size = 10
        self._stride_right_size = 10

        self._feat_queue = mp.Queue(2)
        self._output_queue = mp.Queue()
        self._res_frame_queue = mp.Queue(self._batch_size * 2)

        self.face_images_path = r'.\face'
        self.render_event = mp.Event()
        mp.Process(target=inference, args=(self.render_event, self._batch_size, self.face_images_path,
                                           self._feat_queue, self._output_queue, self._res_frame_queue,
                                           )).start()

    def get_fps(self):
        return self._fps

    def get_batch_size(self):
        return self._batch_size

    def get_chunk(self):
        return self._chunk

    def get_stride_left_size(self):
        return self._stride_left_size

    def get_stride_right_size(self):
        return self._stride_right_size

    def on_destroy(self):
        self.render_event.set()
        self._chunk_2_mal.stop()
        if self._tts is not None:
            self._tts.stop()
        logger.info('human destroy')
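
The constants above imply that roughly 420 ms of audio is buffered before the first mel windows are produced, since Chunk2Mal (below) waits for more than stride_left + stride_right chunks. A quick back-of-the-envelope check:

# Rough latency math from the constants above (fps 50, strides 10 + 10):
fps = 50
chunk_ms = 1000 / fps                      # 20.0 ms of audio per chunk
min_chunks = 10 + 10 + 1                   # Chunk2Mal needs > 20 buffered chunks
print(chunk_ms, min_chunks * chunk_ms)     # 20.0 420.0 (ms)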
@@ -28,6 +176,7 @@ class Human:
        self._tts = tts
        self._tts.start()
+        self._chunk_2_mal.start()

    def read(self, txt):
        if self._tts is None:
@@ -38,13 +187,17 @@ class Human:
    def push_audio_chunk(self, chunk):
-        self._audio_chunk_queue.put(chunk)
+        self._chunk_2_mal.push_chunk(chunk)

-    def pull_audio_chunk(self):
-        try:
-            chunk = self._audio_chunk_queue.get(block=True, timeout=1.0)
-            type = 1
-        except queue.Empty:
-            chunk = np.zeros(self._chunk, dtype=np.float32)
-            type = 0
-        return chunk, type
+    def push_feat_queue(self, mel_chunks):
+        self._feat_queue.put(mel_chunks)

+    # def pull_audio_chunk(self):
+    #     try:
+    #         chunk = self._audio_chunk_queue.get(block=True, timeout=1.0)
+    #         type = 1
+    #     except queue.Empty:
+    #         chunk = np.zeros(self._chunk, dtype=np.float32)
+    #         type = 0
+    #     return chunk, type
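
Taken together, the data flow is: push_audio_chunk feeds Chunk2Mal, which batches chunks into mel windows and hands them back through push_feat_queue to the inference process. A hypothetical driver (names and timing are illustrative only, assuming the module above is importable):

import numpy as np

# Hypothetical smoke test: feed ~1 s of silence through the pipeline.
human = Human()
for _ in range(50):
    chunk = np.zeros(human.get_chunk(), dtype=np.float32)  # one 20 ms chunk
    human.push_audio_chunk(chunk)                          # -> Chunk2Mal -> mel -> inference
human.on_destroy()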

tts/Chunk2Mal.py

@@ -1,11 +1,89 @@
#encoding = utf8
import logging
import queue
from queue import Queue
from threading import Thread, Event

import numpy as np

import audio


class Chunk2Mal:
-    def __init__(self):
+    def __init__(self, human):
        self._audio_chunk_queue = Queue()
        self._human = human
        self._thread = None
        self._exit_event = None
        self._chunks = []

    def _on_run(self):
        logging.info('chunk2mal run')
        while not self._exit_event.is_set():
            # pull_chunk never raises: on timeout it returns a silence chunk
            chunk, _chunk_type = self.pull_chunk()
            self._chunks.append(chunk)

            if len(self._chunks) <= self._human.get_stride_left_size() + self._human.get_stride_right_size():
                continue

            inputs = np.concatenate(self._chunks)  # [N * chunk]
            mel = audio.melspectrogram(inputs)
            # 80 mel frames per second vs 50 audio chunks per second
            left = max(0, self._human.get_stride_left_size() * 80 / 50)
            # right edge of the valid mel region (not used below)
            right = min(len(mel[0]), len(mel[0]) - self._human.get_stride_right_size() * 80 / 50)
            mel_idx_multiplier = 80. * 2 / self._human.get_fps()
            mel_step_size = 16
            i = 0
            mel_chunks = []
            while i < (len(self._chunks) - self._human.get_stride_left_size()
                       - self._human.get_stride_right_size()) / 2:
                start_idx = int(left + i * mel_idx_multiplier)
                if start_idx + mel_step_size > len(mel[0]):
                    mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:])
                else:
                    mel_chunks.append(mel[:, start_idx: start_idx + mel_step_size])
                i += 1
            self._human.push_feat_queue(mel_chunks)

            # discard the old part to save memory
            self._chunks = self._chunks[-(self._human.get_stride_left_size() + self._human.get_stride_right_size()):]
        logging.info('chunk2mal exit')

    def start(self):
        if self._exit_event is not None:
            return
        self._exit_event = Event()
        self._thread = Thread(target=self._on_run)
        self._thread.start()
        logging.info('chunk2mal start')

    def stop(self):
        if self._exit_event is None:
            return
        self._exit_event.set()
        if self._thread.is_alive():
            self._thread.join()
        logging.info('chunk2mal stop')

    def push_chunk(self, chunk):
        self._audio_chunk_queue.put(chunk)

    def pull_chunk(self):
        try:
            chunk = self._audio_chunk_queue.get(block=True, timeout=1)
            chunk_type = 1
        except queue.Empty:
            chunk = np.zeros(self._human.get_chunk(), dtype=np.float32)
            chunk_type = 0
        return chunk, chunk_type
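
The constants in _on_run assume Wav2Lip's usual audio settings, where melspectrogram at 16 kHz with hop size 200 yields 80 mel frames per second; each 16-frame window then covers 0.2 s of audio, and each loop step (two 20 ms chunks) advances 3.2 mel frames. A quick check of that arithmetic:

# Windowing arithmetic behind _on_run (assumes 80 mel frames per second):
fps = 50                                  # 20 ms audio chunks
mel_fps = 80
mel_step_size = 16
mel_idx_multiplier = mel_fps * 2 / fps    # mel frames per 2-chunk step
print(mel_idx_multiplier)                 # 3.2
print(mel_step_size / mel_fps)            # 0.2 s of audio per Wav2Lip window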


@@ -13,7 +13,7 @@ class TTSBase:
        self._queue = Queue()
        self._exit_event = None
        self._io_stream = BytesIO()
-        self._fps = 50
+        self._fps = human.get_fps()
        self._sample_rate = 16000
        self._chunk = self._sample_rate // self._fps
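
Pulling the fps from Human keeps the TTS chunk size in lock-step with the consumer; both sides must slice audio into the same number of samples. A one-line sanity check:

# Both producer and consumer must agree on this (values from this diff):
sample_rate = 16000
fps = 50
print(sample_rate // fps)   # 320 samples = 20 ms per chunk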

ui.py

@@ -120,18 +120,7 @@ def config_logging(file_name: str, console_level: int=logging.INFO, file_level:
if __name__ == "__main__":
    # logging.basicConfig(filename='./logs/info.log', level=logging.INFO)
    config_logging('./logs/info.log', logging.INFO, logging.INFO)

-    # logger = logging.getLogger('manager')
-    # # log to the console at DEBUG level
-    # console = logging.StreamHandler()
-    # console.setLevel(logging.DEBUG)
-    # logger.addHandler(console)
-    #
-    # # log to a file at INFO level, rotated by size
-    # filelog = logging.handlers.RotatingFileHandler(filename='./logs/info.log', level=logging.INFO,
-    #                                                maxBytes=1024 * 1024, backupCount=5)
-    # filelog.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
-    # logger.setLevel(logging.INFO)
-    # logger.addHandler(filelog)

    logger.info('------------start------------')
    app = App()
    app.mainloop()
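
For reference, a minimal working version of the console-plus-rotating-file setup sketched in the comments removed above; note that RotatingFileHandler takes no level keyword (a bug in the commented draft), the level is set with setLevel():

import logging
import logging.handlers

logger = logging.getLogger('manager')
logger.setLevel(logging.DEBUG)

# log to the console at DEBUG level
console = logging.StreamHandler()
console.setLevel(logging.DEBUG)
logger.addHandler(console)

# log to a file at INFO level, rotated at 1 MiB with 5 backups
filelog = logging.handlers.RotatingFileHandler('./logs/info.log',
                                               maxBytes=1024 * 1024, backupCount=5)
filelog.setLevel(logging.INFO)
filelog.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
logger.addHandler(filelog)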