modify human

This commit is contained in:
brige 2024-10-04 14:37:50 +08:00
parent b75ecce46a
commit a71740f40c
4 changed files with 260 additions and 80 deletions

View File

@ -1,9 +1,12 @@
#encoding = utf8
import copy
import glob
import io
import logging
import multiprocessing as mp
import os
import pickle
import platform, subprocess
import queue
import threading
@ -16,7 +19,7 @@ import pyaudio
import audio
import face_detection
import utils
from infer import Infer
from infer import Infer, read_images
from models import Wav2Lip
from tts.Chunk2Mal import Chunk2Mal
import torch
@ -54,16 +57,6 @@ def load_model(path):
return model.eval()
def read_images(img_list):
frames = []
print('reading images...')
for img_path in tqdm(img_list):
print(f'read image path:{img_path}')
frame = cv2.imread(img_path)
frames.append(frame)
return frames
# python.exe .\inference.py --checkpoint_path .\checkpoints\wav2lip.pth --face
# .\face\img00016.jpg --audio .\audio\audio1.wav
def inference(render_event, batch_size, face_images_path, audio_feat_queue, audio_out_queue, res_frame_queue):
@ -295,24 +288,34 @@ class Human:
self._output_queue = mp.Queue()
self._res_frame_queue = mp.Queue(self._batch_size * 2)
full_images, face_frames, coord_frames = self._avatar()
self._frame_list_cycle = full_images
self._face_list_cycle = face_frames
self._coord_list_cycle = coord_frames
face_images_length = len(self._face_list_cycle)
logging.info(f'face images length: {face_images_length}')
print(f'face images length: {face_images_length}')
# full_images, face_frames, coord_frames = self._avatar()
# self._frame_list_cycle = full_images
# self._face_list_cycle = face_frames
# self._coord_list_cycle = coord_frames
# face_images_length = len(self._face_list_cycle)
# logging.info(f'face images length: {face_images_length}')
# print(f'face images length: {face_images_length}')
self.avatar_id = 'wav2lip_avatar1'
self.avatar_path = f"./data/{self.avatar_id}"
self.full_imgs_path = f"{self.avatar_path}/full_imgs"
self.face_imgs_path = f"{self.avatar_path}/face_imgs"
self.coords_path = f"{self.avatar_path}/coords.pkl"
self.__loadavatar()
self.mel_chunks_queue_ = Queue()
self.audio_chunks_queue_ = Queue()
self._test_image_queue = Queue()
self._res_render_queue = Queue()
# self._res_render_queue = Queue()
self._chunk_2_mal = Chunk2Mal(self)
self.res_render_queue = mp.Queue(self._batch_size * 2)
self.chunk_2_mal = Chunk2Mal(self)
self._tts = TTSBase(self)
self._infer = Infer(self)
self.chunk_2_mal.warm_up()
# #
#
# self._thread = None
# thread = threading.Thread(target=self.test)
# thread.start()
@ -341,6 +344,13 @@ class Human:
# stream.close()
# p.terminate()
def __loadavatar(self):
with open(self.coords_path, 'rb') as f:
self._coord_list_cycle = pickle.load(f)
input_img_list = glob.glob(os.path.join(self.full_imgs_path, '*.[jpJP][pnPN]*[gG]'))
input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
self._frame_list_cycle = read_images(input_img_list)
def _avatar(self):
face_images_path = r'./face/'
face_images_path = utils.read_files_path(face_images_path)
@ -433,7 +443,7 @@ class Human:
print('wav length:', stream_len)
_audio_chunk_queue = queue.Queue()
index = 0
chunk_len = 640# // 200
chunk_len = 320# // 200
print('chunk_len:', chunk_len)
while stream_len >= chunk_len:
audio_chunk = stream[index:index + chunk_len]
@ -451,7 +461,7 @@ class Human:
j = 0
while not _audio_chunk_queue.empty():
chunks = []
length = min(64, _audio_chunk_queue.qsize())
length = min(128, _audio_chunk_queue.qsize())
for i in range(length):
chunks.append(_audio_chunk_queue.get())
@ -496,7 +506,7 @@ class Human:
self._tts.push_txt(txt)
def put_audio_frame(self, audio_chunk):
self._chunk_2_mal.put_audio_frame(audio_chunk)
self.chunk_2_mal.put_audio_frame(audio_chunk)
# def push_audio_chunk(self, audio_chunk):
# self._chunk_2_mal.push_chunk(audio_chunk)
@ -507,6 +517,9 @@ class Human:
def push_out_put(self, frame, type_):
self._infer.push_out_queue(frame, type_)
def get_out_put(self):
return self._infer.get_out_put()
def push_mel_chunks_queue(self, audio_chunk):
self.audio_chunks_queue_.put(audio_chunk)
@ -521,13 +534,13 @@ class Human:
self._test_image_queue.put(image)
def push_res_frame(self, res_frame, idx, audio_frames):
self._res_render_queue.put((res_frame, idx, audio_frames))
self.res_render_queue.put((res_frame, idx, audio_frames))
def render(self):
try:
# img, aud = self._res_frame_queue.get(block=True, timeout=.3)
# img = self._test_image_queue.get(block=True, timeout=.3)
res_frame, idx, audio_frames = self._res_render_queue.get(block=True, timeout=.3)
res_frame, idx, audio_frames = self.res_render_queue.get(block=True, timeout=.3)
except queue.Empty:
# print('render queue.Empty:')
return None

189
infer.py
View File

@ -1,5 +1,8 @@
#encoding = utf8
import os
import glob
import queue
import multiprocessing as mp
import time
from queue import Queue
from threading import Thread, Event
@ -11,6 +14,7 @@ import torch
from tqdm import tqdm
import face_detection
import utils
from models import Wav2Lip
from utils import mirror_index
@ -107,16 +111,19 @@ img_size = 96
wav2lip_batch_size = 128
def datagen_signal(frame, mel, face_det_results):
def datagen(frames, mels):
img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
# for i, m in enumerate(mels):
idx = 0
frame_to_save = frame.copy()
face_det_results = face_detect(frames) # BGR2RGB for CNN face detection
i = 0
for m in enumerate(mels):
# for i in range(mels.qsize()):
idx = 0 if True else i%len(frames)
frame_to_save = frames[mirror_index(1, i)].copy()
face, coords = face_det_results[idx].copy()
face = cv2.resize(face, (img_size, img_size))
m = mel
img_batch.append(face)
mel_batch.append(m)
@ -131,7 +138,9 @@ def datagen_signal(frame, mel, face_det_results):
img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
return img_batch, mel_batch, frame_batch, coords_batch
yield img_batch, mel_batch, frame_batch, coords_batch
img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
i = i + 1
if len(img_batch) > 0:
img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
@ -141,19 +150,145 @@ def datagen_signal(frame, mel, face_det_results):
img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
return img_batch, mel_batch, frame_batch, coords_batch
yield img_batch, mel_batch, frame_batch, coords_batch
def datagen_signal(frame, mel, face_det_results):
img_batch, mel_batch, frame_batch, coord_batch = [], [], [], []
# for i, m in enumerate(mels):
idx = 0
frame_to_save = frame.copy()
face, coord = face_det_results[idx].copy()
face = cv2.resize(face, (img_size, img_size))
for i, m in enumerate(mel):
img_batch.append(face)
mel_batch.append(m)
frame_batch.append(frame_to_save)
coord_batch.append(coord)
if len(img_batch) >= wav2lip_batch_size:
img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
img_masked = img_batch.copy()
img_masked[:, img_size // 2:] = 0
img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
return img_batch, mel_batch, frame_batch, coord_batch
if len(img_batch) > 0:
img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
img_masked = img_batch.copy()
img_masked[:, img_size // 2:] = 0
img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
return img_batch, mel_batch, frame_batch, coord_batch
def inference(render_event, batch_size, face_imgs_path, audio_feat_queue, audio_out_queue, res_frame_queue):
model = load_model(r'.\checkpoints\wav2lip.pth')
# face_list_cycle = read_images(face_imgs_path)
input_face_list = glob.glob(os.path.join(face_imgs_path, '*.[jpJP][pnPN]*[gG]'))
input_face_list = sorted(input_face_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
face_list_cycle = read_images(input_face_list)
# input_latent_list_cycle = torch.load(latents_out_path)
length = len(face_list_cycle)
index = 0
count = 0
counttime = 0
print('start inference')
while True:
if render_event.is_set():
starttime = time.perf_counter()
mel_batch = []
try:
mel_batch = audio_feat_queue.get(block=True, timeout=1)
except queue.Empty:
continue
is_all_silence = True
audio_frames = []
for _ in range(batch_size * 2):
frame, type_ = audio_out_queue.get()
audio_frames.append((frame, type_))
if type_ == 0:
is_all_silence = False
if is_all_silence:
for i in range(batch_size):
res_frame_queue.put((None, mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]))
index = index + 1
else:
print('infer=======')
t = time.perf_counter()
img_batch = []
for i in range(batch_size):
idx = mirror_index(length, index + i)
face = face_list_cycle[idx]
img_batch.append(face)
img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
img_masked = img_batch.copy()
img_masked[:, face.shape[0] // 2:] = 0
img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
with torch.no_grad():
pred = model(mel_batch, img_batch)
pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
counttime += (time.perf_counter() - t)
count += batch_size
# _totalframe += 1
if count >= 100:
print(f"------actual avg infer fps:{count / counttime:.4f}")
count = 0
counttime = 0
for i, res_frame in enumerate(pred):
# self.__pushmedia(res_frame,loop,audio_track,video_track)
res_frame_queue.put((res_frame, mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]))
index = index + 1
# print('total batch time:',time.perf_counter()-starttime)
else:
time.sleep(1)
print('musereal inference processor stop')
class Infer:
def __init__(self, human):
self._human = human
self._feat_queue = Queue()
self._audio_out_queue = Queue()
# self._feat_queue = Queue()
# self._audio_out_queue = Queue()
self._exit_event = Event()
self._run_thread = Thread(target=self.__on_run)
self._exit_event.set()
self._run_thread.start()
self.batch_size = human.get_batch_size()
self.asr = human.chunk_2_mal
self.res_frame_queue = human.res_render_queue
# self._exit_event = Event()
# face_images_path = r'./face/'
# self.face_images_path = utils.read_files_path(face_images_path)
self.avatar_id = 'wav2lip_avatar1'
self.avatar_path = f"./data/{self.avatar_id}"
self.full_imgs_path = f"{self.avatar_path}/full_imgs"
self.face_images_path = f"{self.avatar_path}/face_imgs"
self.coords_path = f"{self.avatar_path}/coords.pkl"
self.render_event = mp.Event()
mp.Process(target=inference, args=(self.render_event, self.batch_size, self.face_images_path,
self.asr.feat_queue, self.asr.output_queue, self.res_frame_queue,
)).start()
self.render_event.set()
# self._run_thread = Thread(target=self.__on_run)
# self._exit_event.set()
# self._run_thread.start()
def __on_run(self):
model = load_model(r'.\checkpoints\wav2lip.pth')
@ -209,9 +344,16 @@ class Infer:
count = 0
count_time = 0
print('start inference')
#
# face_images_path = r'./face/'
# face_images_path = utils.read_files_path(face_images_path)
# face_list_cycle1 = read_images(face_images_path)
# face_det_results = face_detect(face_list_cycle1)
while True:
if self._exit_event.is_set():
start_time = time.perf_counter()
batch_size = self._human.get_batch_size()
try:
mel_batch = self._feat_queue.get(block=True, timeout=1)
except queue.Empty:
@ -219,14 +361,14 @@ class Infer:
is_all_silence = True
audio_frames = []
for _ in range(self._human.get_batch_size() * 2):
for _ in range(batch_size * 2):
frame, type_ = self._audio_out_queue.get()
audio_frames.append((frame, type_))
if type_ == 0:
is_all_silence = False
if is_all_silence:
for i in range(self._human.get_batch_size()):
for i in range(batch_size):
# res_frame_queue.put((None, mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]))
self._human.push_res_frame(None, mirror_index(length, index), audio_frames[i * 2:i * 2 + 2])
index = index + 1
@ -234,23 +376,29 @@ class Infer:
print('infer=======')
t = time.perf_counter()
img_batch = []
for i in range(self._human.get_batch_size()):
for i in range(batch_size):
idx = mirror_index(length, index + i)
face = face_list_cycle[idx]
img_batch.append(face)
img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
# img_batch_1, mel_batch_1, frames, coords = datagen_signal(face_list_cycle1,
# mel_batch, face_det_results)
img_batch = np.asarray(img_batch)
mel_batch = np.asarray(mel_batch)
img_masked = img_batch.copy()
img_masked[:, face.shape[0] // 2:] = 0
#
img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
mel_batch = np.reshape(mel_batch,
[len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
with torch.no_grad():
pred = model(mel_batch, img_batch)
# pred = model(mel_batch, img_batch) * 255.0
pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
count_time += (time.perf_counter() - t)
@ -277,3 +425,6 @@ class Infer:
def push_out_queue(self, frame, type_):
self._audio_out_queue.put((frame, type_))
def get_out_put(self):
return self._audio_out_queue.get()

View File

@ -20,8 +20,13 @@ class Chunk2Mal:
self.frames = []
self.queue = Queue()
# self.output_queue = mp.Queue()
# self.feat_queue = mp.Queue(2)
self.fps = human.get_fps()
self.batch_size = human.get_batch_size()
self.stride_left_size = human.get_stride_left_size()
self.stride_right_size = human.get_stride_right_size()
self.output_queue = mp.Queue()
self.feat_queue = mp.Queue(2)
# 320 samples per chunk (20ms * 16000 / 1000)audio_chunk
self.chunk = self._human.get_audio_sample_rate() // self._human.get_fps()
@ -43,27 +48,26 @@ class Chunk2Mal:
logging.info('chunk2mal exit')
def _run_step(self):
for _ in range(self._human.get_batch_size() * 2):
for _ in range(self.batch_size * 2):
frame, _type = self.get_audio_frame()
self.frames.append(frame)
# put to output
self._human.push_out_put(frame, _type)
# self.output_queue.put((frame, _type))
self.output_queue.put((frame, _type))
# context not enough, do not run network.
if len(self.frames) <= self._human.get_stride_left_size() + self._human.get_stride_right_size():
if len(self.frames) <= self.stride_left_size + self.stride_right_size:
return
inputs = np.concatenate(self.frames) # [N * chunk]
mel = audio.melspectrogram(inputs)
# print(mel.shape[0],mel.shape,len(mel[0]),len(self.frames))
# cut off stride
left = max(0, self._human.get_stride_left_size() * 80 / 50)
right = min(len(mel[0]), len(mel[0]) - self._human.get_stride_right_size() * 80 / 50)
mel_idx_multiplier = 80. * 2 / self._human.get_fps()
left = max(0, self.stride_left_size * 80 / 50)
right = min(len(mel[0]), len(mel[0]) - self.stride_right_size * 80 / 50)
mel_idx_multiplier = 80. * 2 / self.fps
mel_step_size = 16
i = 0
mel_chunks = []
while i < (len(self.frames) - self._human.get_stride_left_size() - self._human.get_stride_right_size()) / 2:
while i < (len(self.frames) - self.stride_left_size - self.stride_right_size) / 2:
start_idx = int(left + i * mel_idx_multiplier)
# print(start_idx)
if start_idx + mel_step_size > len(mel[0]):
@ -71,11 +75,10 @@ class Chunk2Mal:
else:
mel_chunks.append(mel[:, start_idx: start_idx + mel_step_size])
i += 1
self._human.push_mel_chunks(mel_chunks)
# self.feat_queue.put(mel_chunks)
self.feat_queue.put(mel_chunks)
# discard the old part to save memory
self.frames = self.frames[-(self._human.get_stride_left_size() + self._human.get_stride_right_size()):]
self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]
def stop(self):
if self._exit_event is None:
@ -95,25 +98,25 @@ class Chunk2Mal:
def get_audio_frame(self):
try:
frame = self.queue.get(block=True, timeout=0.01)
type = 0
type_ = 0
# print(f'[INFO] get frame {frame.shape}')
except queue.Empty:
frame = np.zeros(self.chunk, dtype=np.float32)
type = 1
type_ = 1
return frame, type
return frame, type_
def get_audio_out(self): # get origin audio pcm to nerf
return self.output_queue.get()
def warm_up(self):
for _ in range(self._human.get_stride_left_size() + self._human.get_stride_right_size()):
audio_frame, _type = self.get_audio_frame()
for _ in range(self.stride_left_size + self.stride_right_size):
audio_frame, type_ = self.get_audio_frame()
self.frames.append(audio_frame)
self.output_queue.put((audio_frame, type))
for _ in range(self._human.get_stride_right_size()):
self.output_queue.put((audio_frame, type_))
for _ in range(self.stride_left_size):
self.output_queue.get()
def get_next_feat(self, block, timeout):
return self.feat_queue.get(block, timeout)
#
# def get_next_feat(self, block, timeout):
# return self.feat_queue.get(block, timeout)

View File

@ -1,6 +1,9 @@
#encoding = utf8
import os
import cv2
from tqdm import tqdm
def read_files_path(path):
file_paths = []
@ -11,6 +14,16 @@ def read_files_path(path):
return file_paths
def read_images(img_list):
frames = []
print('reading images...')
for img_path in tqdm(img_list):
print(f'read image path:{img_path}')
frame = cv2.imread(img_path)
frames.append(frame)
return frames
def mirror_index(size, index):
# size = len(self.coord_list_cycle)
turn = index // size