Compare commits
3 Commits
b0753a5220
...
b0a600c7b7
Author | SHA1 | Date | |
---|---|---|---|
b0a600c7b7 | |||
c9f8ff6541 | |||
cf9fc3545d |
177
Human.py
177
Human.py
@ -1,23 +1,171 @@
|
|||||||
#encoding = utf8
|
#encoding = utf8
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
import multiprocessing as mp
|
||||||
import queue
|
import queue
|
||||||
from queue import Queue
|
import time
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from models import Wav2Lip
|
||||||
from tts.Chunk2Mal import Chunk2Mal
|
from tts.Chunk2Mal import Chunk2Mal
|
||||||
|
import torch
|
||||||
|
import cv2
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Run the model on the GPU when CUDA is available, otherwise on the CPU.
# The chosen device is announced once, at import time.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using {} for inference.'.format(device))
|
||||||
|
|
||||||
|
|
||||||
|
def _load(checkpoint_path):
    """Deserialize the checkpoint stored at ``checkpoint_path``.

    When the module-level ``device`` is ``'cuda'`` the file is loaded with
    ``torch.load`` defaults. Otherwise an identity ``map_location`` is
    supplied, which hands every storage back unchanged instead of moving it
    to the device it was saved from.

    Returns:
        Whatever object ``torch.load`` produces for the file.
    """
    if device != 'cuda':
        # Identity mapping: keep each storage exactly as it was deserialized.
        return torch.load(checkpoint_path,
                          map_location=lambda storage, loc: storage)
    return torch.load(checkpoint_path)
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(path):
    """Create a ``Wav2Lip`` network and fill it from the checkpoint at ``path``.

    The checkpoint's ``"state_dict"`` entry is loaded after every occurrence
    of ``'module.'`` has been removed from its keys (presumably the prefix
    left by a ``DataParallel`` wrapper -- confirm against the training code).

    Returns:
        The model moved to the module-level ``device`` and switched to
        evaluation mode.
    """
    net = Wav2Lip()
    print("Load checkpoint from: {}".format(path))
    saved_state = _load(path)["state_dict"]
    cleaned_state = {
        key.replace('module.', ''): tensor
        for key, tensor in saved_state.items()
    }
    net.load_state_dict(cleaned_state)
    return net.to(device).eval()
|
||||||
|
|
||||||
|
|
||||||
|
def read_images(img_list):
    """Load frames with OpenCV.

    Args:
        img_list: Either an iterable of image file paths, or a single path
            (``str`` / ``os.PathLike``). A single path that is a directory is
            expanded to its ``.jpg`` / ``.jpeg`` / ``.png`` / ``.bmp`` files
            in sorted file-name order; a single path that is a file is
            treated as a one-element list.

    Returns:
        list: The decoded frames, in input order. Images that OpenCV cannot
        read are skipped (with a warning) instead of being stored as ``None``.
    """
    import os  # local import: the file's import block is left untouched

    if isinstance(img_list, (str, os.PathLike)):
        # A bare path must not be iterated directly: a str would yield its
        # individual characters, each of which would be fed to cv2.imread.
        location = os.fspath(img_list)
        if os.path.isdir(location):
            img_list = sorted(
                os.path.join(location, name)
                for name in os.listdir(location)
                if name.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp'))
            )
        else:
            img_list = [location]

    frames = []
    print('reading images...')
    for img_path in tqdm(img_list):
        frame = cv2.imread(img_path)
        if frame is None:
            # cv2.imread reports failure by returning None rather than raising.
            logger.warning('skip unreadable image: %s', img_path)
            continue
        frames.append(frame)
    return frames
|
||||||
|
|
||||||
|
|
||||||
|
def __mirror_index(size, index):
    """Map a running counter onto a back-and-forth walk over ``size`` items.

    Even passes run forward (``0 .. size-1``) and odd passes run backward
    (``size-1 .. 0``), so for ``size == 3`` the counter ``0, 1, 2, 3, 4, 5, 6``
    yields ``0, 1, 2, 2, 1, 0, 0``.

    Args:
        size: Number of items in the cycle; must be positive.
        index: Running counter.

    Returns:
        int: Position inside ``range(size)``.

    Raises:
        ValueError: If ``size`` is not positive.
    """
    if size <= 0:
        raise ValueError(f'size must be positive, got {size}')
    turn, res = divmod(index, size)
    return res if turn % 2 == 0 else size - res - 1
|
||||||
|
|
||||||
|
|
||||||
|
# python.exe .\inference.py --checkpoint_path .\checkpoints\wav2lip.pth --face
|
||||||
|
# .\face\img00016.jpg --audio .\audio\audio1.wav
|
||||||
|
def inference(render_event, batch_size, face_images_path, audio_feat_queue, audio_out_queue, res_frame_queue):
    """Render loop: pair mel batches with face frames and run the model.

    Loads the model from ``.\\checkpoints\\wav2lip.pth`` and the face frames
    via ``read_images(face_images_path)``, then repeats while
    ``render_event`` is set:

    1. take one mel batch from ``audio_feat_queue`` (1 s timeout, retry when
       empty);
    2. take ``batch_size * 2`` ``(frame, type)`` items from
       ``audio_out_queue``;
    3. if no item has type ``0``, emit ``batch_size`` results without running
       the model; otherwise run the model on ``batch_size`` face frames;
    4. put one ``(res_frame, face_index, audio_frames_pair)`` tuple per
       output frame on ``res_frame_queue`` (``res_frame`` is ``None`` when
       the model was skipped).

    Args:
        render_event: Event-like object; the loop runs while ``is_set()``.
        batch_size: Number of video frames produced per iteration.
        face_images_path: Passed unchanged to ``read_images``.
        audio_feat_queue: Source of mel batches.
        audio_out_queue: Source of ``(frame, type)`` audio items.
        res_frame_queue: Destination of the result tuples.
    """
    model = load_model(r'.\checkpoints\wav2lip.pth')
    face_list_cycle = read_images(face_images_path)
    length = len(face_list_cycle)
    logger.info(f'face images length: {length}')
    if length == 0:
        # Without frames the mirror-index arithmetic would divide by zero.
        logger.error('no face images loaded, inference not started')
        return

    index = 0
    count = 0
    count_time = 0
    logger.info('start inference')
    while render_event.is_set():
        try:
            mel_batch = audio_feat_queue.get(block=True, timeout=1)
        except queue.Empty:
            continue

        audio_frames = []
        is_all_silence = True
        for _ in range(batch_size * 2):
            # Audio items come from audio_out_queue. Reading them from
            # audio_feat_queue (as before) consumed mel batches instead and
            # left audio_out_queue unused.
            # NOTE: this get() has no timeout, so the loop cannot observe
            # render_event while it waits here.
            frame, frame_type = audio_out_queue.get()
            audio_frames.append((frame, frame_type))
            # NOTE(review): type 0 is treated as "not silence" here, while
            # Chunk2Mal.pull_chunk tags zero-filled chunks with 0 and real
            # audio with 1 -- confirm the convention used by the producer of
            # audio_out_queue.
            if frame_type == 0:
                is_all_silence = False

        if is_all_silence:
            # Nothing to lip-sync: emit placeholders, two audio items each.
            for i in range(batch_size):
                res_frame_queue.put((None, __mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]))
                index = index + 1
            continue

        t = time.perf_counter()
        image_batch = np.asarray([
            face_list_cycle[__mirror_index(length, index + i)]
            for i in range(batch_size)
        ])
        mel_batch = np.asarray(mel_batch)

        # Blank out the rows from the vertical midpoint down, then stack the
        # masked and unmasked images along the last axis and scale to [0, 1].
        image_masked = image_batch.copy()
        image_masked[:, image_batch.shape[1] // 2:] = 0
        image_batch = np.concatenate((image_masked, image_batch), axis=3) / 255.
        # Append a trailing singleton axis to the mel batch.
        mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

        # Move the last axis to position 1 before handing data to the model.
        image_batch = torch.FloatTensor(np.transpose(image_batch, (0, 3, 1, 2))).to(device)
        mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)

        with torch.no_grad():
            pred = model(mel_batch, image_batch)
        # Back to last-axis-last layout and the 0..255 value range.
        pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.

        count_time += (time.perf_counter() - t)
        count += batch_size
        if count >= 100:
            logger.info(f"------actual avg infer fps:{count/count_time:.4f}")
            count = 0
            count_time = 0

        for i, res_frame in enumerate(pred):
            res_frame_queue.put((res_frame, __mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]))
            index = index + 1

    logger.info('finish inference')
|
||||||
|
|
||||||
|
|
||||||
class Human:
|
class Human:
|
||||||
def __init__(self):
    """Set up audio parameters and queues, then start the inference process.

    The inference worker (``inference``) loops only while ``render_event``
    is set, so the event is set before the process is started. The process
    handle is kept on ``self._render_process``.
    """
    self._tts = None
    self._fps = 50  # 20 ms per frame
    self._batch_size = 16
    self._sample_rate = 16000
    self._chunk = self._sample_rate // self._fps  # 320 samples per chunk (20ms * 16000 / 1000)
    self._chunk_2_mal = Chunk2Mal(self)
    self._stride_left_size = 10
    self._stride_right_size = 10
    self._feat_queue = mp.Queue(2)
    self._output_queue = mp.Queue()
    self._res_frame_queue = mp.Queue(self._batch_size * 2)

    self.face_images_path = r'.\face'
    self.render_event = mp.Event()
    # Without this the worker's `while render_event.is_set()` loop ends
    # before its first iteration.
    self.render_event.set()
    # daemon=True: the worker contains a blocking queue read, so it must not
    # be able to keep the interpreter alive after the main process exits.
    self._render_process = mp.Process(
        target=inference,
        args=(self.render_event, self._batch_size, self.face_images_path,
              self._feat_queue, self._output_queue, self._res_frame_queue,
              ),
        daemon=True,
    )
    self._render_process.start()
|
||||||
|
|
||||||
|
def get_fps(self):
    """Return ``self._fps``, the number of audio chunks per second (``_chunk = _sample_rate // _fps``)."""
    return self._fps
|
||||||
|
|
||||||
|
def get_batch_size(self):
    """Return ``self._batch_size``, the batch size handed to the inference process."""
    return self._batch_size
|
||||||
|
|
||||||
|
def get_chunk(self):
    """Return ``self._chunk``, the number of samples in one audio chunk."""
    return self._chunk
|
||||||
|
|
||||||
|
def get_stride_left_size(self):
    """Return ``self._stride_left_size`` (read by ``Chunk2Mal`` as its left context, counted in chunks)."""
    return self._stride_left_size
|
||||||
|
|
||||||
|
def get_stride_right_size(self):
    """Return ``self._stride_right_size`` (read by ``Chunk2Mal`` as its right context, counted in chunks)."""
    return self._stride_right_size
|
||||||
|
|
||||||
def on_destroy(self):
    """Stop the render loop, the chunk-to-mel worker and the TTS (if any)."""
    # inference() keeps looping while the event is set, so asking it to stop
    # means clearing the event; set() would tell it to keep running.
    self.render_event.clear()
    self._chunk_2_mal.stop()
    if self._tts is not None:
        self._tts.stop()
    logger.info('human destroy')
|
||||||
@ -28,6 +176,7 @@ class Human:
|
|||||||
|
|
||||||
self._tts = tts
|
self._tts = tts
|
||||||
self._tts.start()
|
self._tts.start()
|
||||||
|
self._chunk_2_mal.start()
|
||||||
|
|
||||||
def read(self, txt):
|
def read(self, txt):
|
||||||
if self._tts is None:
|
if self._tts is None:
|
||||||
@ -38,13 +187,17 @@ class Human:
|
|||||||
|
|
||||||
def push_audio_chunk(self, chunk):
    """Forward one audio chunk to the ``Chunk2Mal`` worker's input queue."""
    self._chunk_2_mal.push_chunk(chunk)
|
||||||
|
|
||||||
def pull_audio_chunk(self):
|
def push_feat_queue(self, mel_chunks):
|
||||||
try:
|
print("21")
|
||||||
chunk = self._audio_chunk_queue.get(block=True, timeout=1.0)
|
self._feat_queue.put(mel_chunks)
|
||||||
type = 1
|
print("22")
|
||||||
except queue.Empty:
|
|
||||||
chunk = np.zeros(self._chunk, dtype=np.float32)
|
# def pull_audio_chunk(self):
|
||||||
type = 0
|
# try:
|
||||||
return chunk, type
|
# chunk = self._audio_chunk_queue.get(block=True, timeout=1.0)
|
||||||
|
# type = 1
|
||||||
|
# except queue.Empty:
|
||||||
|
# chunk = np.zeros(self._chunk, dtype=np.float32)
|
||||||
|
# type = 0
|
||||||
|
# return chunk, type
|
||||||
|
@ -1,11 +1,89 @@
|
|||||||
#encoding = utf8
|
#encoding = utf8
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import queue
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
|
from threading import Thread, Event
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import audio
|
||||||
|
|
||||||
|
|
||||||
class Chunk2Mal:
    """Worker thread that turns pushed audio chunks into mel-spectrogram windows.

    Audio chunks arrive through ``push_chunk``. The worker accumulates them,
    computes a mel spectrogram over the accumulated audio with
    ``audio.melspectrogram`` and hands fixed-width windows of it to
    ``human.push_feat_queue``.

    Args:
        human: Object providing ``get_chunk``, ``get_fps``,
            ``get_stride_left_size``, ``get_stride_right_size`` and
            ``push_feat_queue``.
    """

    def __init__(self, human):
        self._audio_chunk_queue = Queue()
        self._human = human
        self._thread = None
        self._exit_event = None  # None means "no worker running"
        self._chunks = []

    def _on_run(self):
        """Thread body: accumulate chunks and emit mel windows until stopped."""
        logging.info('chunk2mal run')
        human = self._human
        mel_step_size = 16  # width of each emitted mel window, in mel frames
        while not self._exit_event.is_set():
            # pull_chunk never raises: on timeout it returns a zero chunk.
            chunk, _ = self.pull_chunk()
            self._chunks.append(chunk)

            stride_left = human.get_stride_left_size()
            context = stride_left + human.get_stride_right_size()
            if len(self._chunks) <= context:
                continue  # not enough audio beyond the context yet

            inputs = np.concatenate(self._chunks)  # [N * chunk]
            mel = audio.melspectrogram(inputs)
            mel_len = len(mel[0])

            fps = human.get_fps()
            # Offset of the first window, in mel frames. The factor 80 is
            # taken from the original code (presumably mel frames per
            # second -- confirm against audio.melspectrogram). The chunk
            # rate now comes from get_fps() rather than a literal 50, which
            # matches mel_idx_multiplier below and is identical at 50 fps.
            left = max(0, stride_left * 80 / fps)
            mel_idx_multiplier = 80. * 2 / fps

            # One window per two chunks beyond the context, rounded up.
            fresh = len(self._chunks) - context
            mel_chunks = []
            for i in range((fresh + 1) // 2):
                start_idx = int(left + i * mel_idx_multiplier)
                if start_idx + mel_step_size > mel_len:
                    # Window would run past the end: take the last full one.
                    mel_chunks.append(mel[:, mel_len - mel_step_size:])
                else:
                    mel_chunks.append(mel[:, start_idx: start_idx + mel_step_size])
            human.push_feat_queue(mel_chunks)

            # Discard the old part to save memory, keeping only the context.
            # A plain [-context:] slice would keep everything when context is 0.
            self._chunks = self._chunks[-context:] if context else []

        logging.info('chunk2mal exit')

    def start(self):
        """Start the worker thread; does nothing if one is already running."""
        if self._exit_event is not None:
            return
        self._exit_event = Event()
        self._thread = Thread(target=self._on_run)
        self._thread.start()
        logging.info('chunk2mal start')

    def stop(self):
        """Signal the worker to exit and wait for it; does nothing if not started."""
        if self._exit_event is None:
            return

        self._exit_event.set()
        if self._thread.is_alive():
            self._thread.join()
        # Reset the state so a later start() launches a fresh worker instead
        # of returning early.
        self._thread = None
        self._exit_event = None
        logging.info('chunk2mal stop')

    def push_chunk(self, chunk):
        """Enqueue one audio chunk for the worker."""
        self._audio_chunk_queue.put(chunk)

    def pull_chunk(self):
        """Return ``(chunk, type)`` for the next audio chunk.

        Waits up to one second for a queued chunk and returns it with type
        ``1``. On timeout returns a zero-filled float32 array of
        ``human.get_chunk()`` samples with type ``0``.
        """
        try:
            chunk = self._audio_chunk_queue.get(block=True, timeout=1)
            chunk_type = 1
        except queue.Empty:
            chunk = np.zeros(self._human.get_chunk(), dtype=np.float32)
            chunk_type = 0
        return chunk, chunk_type
|
||||||
|
|
||||||
|
@ -13,7 +13,7 @@ class TTSBase:
|
|||||||
self._queue = Queue()
|
self._queue = Queue()
|
||||||
self._exit_event = None
|
self._exit_event = None
|
||||||
self._io_stream = BytesIO()
|
self._io_stream = BytesIO()
|
||||||
self._fps = 50
|
self._fps = human.get_fps()
|
||||||
self._sample_rate = 16000
|
self._sample_rate = 16000
|
||||||
self._chunk = self._sample_rate // self._fps
|
self._chunk = self._sample_rate // self._fps
|
||||||
|
|
||||||
|
13
ui.py
13
ui.py
@ -120,18 +120,7 @@ def config_logging(file_name: str, console_level: int=logging.INFO, file_level:
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# logging.basicConfig(filename='./logs/info.log', level=logging.INFO)
|
# logging.basicConfig(filename='./logs/info.log', level=logging.INFO)
|
||||||
config_logging('./logs/info.log', logging.INFO, logging.INFO)
|
config_logging('./logs/info.log', logging.INFO, logging.INFO)
|
||||||
# logger = logging.getLogger('manager')
|
|
||||||
# # 输出到控制台, 级别为DEBUG
|
|
||||||
# console = logging.StreamHandler()
|
|
||||||
# console.setLevel(logging.DEBUG)
|
|
||||||
# logger.addHandler(console)
|
|
||||||
#
|
|
||||||
# # 输出到文件, 级别为INFO, 文件按大小切分
|
|
||||||
# filelog = logging.handlers.RotatingFileHandler(filename='./logs/info.log', level=logging.INFO,
|
|
||||||
# maxBytes=1024 * 1024, backupCount=5)
|
|
||||||
# filelog.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
|
|
||||||
# logger.setLevel(logging.INFO)
|
|
||||||
# logger.addHandler(filelog)
|
|
||||||
logger.info('------------start------------')
|
logger.info('------------start------------')
|
||||||
app = App()
|
app = App()
|
||||||
app.mainloop()
|
app.mainloop()
|
||||||
|
Loading…
Reference in New Issue
Block a user