From a71740f40c61ce2a6a4910b36738267b9f6410d3 Mon Sep 17 00:00:00 2001
From: brige
Date: Fri, 4 Oct 2024 14:37:50 +0800
Subject: [PATCH] modify human

---
 Human.py         |  65 +++++++++------
 infer.py         | 213 ++++++++++++++++++++++++++++++++++++++++-------
 tts/Chunk2Mal.py |  49 ++++++-----
 utils.py         |  13 +++
 4 files changed, 260 insertions(+), 80 deletions(-)

diff --git a/Human.py b/Human.py
index 02a2479..3de9cc9 100644
--- a/Human.py
+++ b/Human.py
@@ -1,9 +1,12 @@
 #encoding = utf8
 import copy
+import glob
 import io
 import logging
 import multiprocessing as mp
+import os
+import pickle
 import platform, subprocess
 import queue
 import threading
 
@@ -16,7 +19,7 @@ import pyaudio
 import audio
 import face_detection
 import utils
-from infer import Infer
+from infer import Infer, read_images
 from models import Wav2Lip
 from tts.Chunk2Mal import Chunk2Mal
 import torch
@@ -54,16 +57,6 @@ def load_model(path):
     return model.eval()
 
 
-def read_images(img_list):
-    frames = []
-    print('reading images...')
-    for img_path in tqdm(img_list):
-        print(f'read image path:{img_path}')
-        frame = cv2.imread(img_path)
-        frames.append(frame)
-    return frames
-
-
 # python.exe .\inference.py --checkpoint_path .\checkpoints\wav2lip.pth --face
 # .\face\img00016.jpg --audio .\audio\audio1.wav
 def inference(render_event, batch_size, face_images_path, audio_feat_queue, audio_out_queue, res_frame_queue):
@@ -295,24 +288,34 @@ class Human:
         self._output_queue = mp.Queue()
         self._res_frame_queue = mp.Queue(self._batch_size * 2)
 
-        full_images, face_frames, coord_frames = self._avatar()
-        self._frame_list_cycle = full_images
-        self._face_list_cycle = face_frames
-        self._coord_list_cycle = coord_frames
-        face_images_length = len(self._face_list_cycle)
-        logging.info(f'face images length: {face_images_length}')
-        print(f'face images length: {face_images_length}')
+        # full_images, face_frames, coord_frames = self._avatar()
+        # self._frame_list_cycle = full_images
+        # self._face_list_cycle = face_frames
+        # self._coord_list_cycle = coord_frames
+        # face_images_length = len(self._face_list_cycle)
+        # logging.info(f'face images length: {face_images_length}')
+        # print(f'face images length: {face_images_length}')
+        self.avatar_id = 'wav2lip_avatar1'
+        self.avatar_path = f"./data/{self.avatar_id}"
+        self.full_imgs_path = f"{self.avatar_path}/full_imgs"
+        self.face_imgs_path = f"{self.avatar_path}/face_imgs"
+        self.coords_path = f"{self.avatar_path}/coords.pkl"
+
+        self.__loadavatar()
 
         self.mel_chunks_queue_ = Queue()
         self.audio_chunks_queue_ = Queue()
         self._test_image_queue = Queue()
-        self._res_render_queue = Queue()
+        # self._res_render_queue = Queue()
 
-        self._chunk_2_mal = Chunk2Mal(self)
+        self.res_render_queue = mp.Queue(self._batch_size * 2)
+
+        self.chunk_2_mal = Chunk2Mal(self)
         self._tts = TTSBase(self)
         self._infer = Infer(self)
+        self.chunk_2_mal.warm_up()
 
-        # #
+        # #
         # self._thread = None
         # thread = threading.Thread(target=self.test)
         # thread.start()
@@ -341,6 +344,13 @@ class Human:
     #     stream.close()
     #     p.terminate()
 
+    def __loadavatar(self):
+        with open(self.coords_path, 'rb') as f:
+            self._coord_list_cycle = pickle.load(f)
+        input_img_list = glob.glob(os.path.join(self.full_imgs_path, '*.[jpJP][pnPN]*[gG]'))
+        input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
+        self._frame_list_cycle = read_images(input_img_list)
+
     def _avatar(self):
         face_images_path = r'./face/'
         face_images_path = utils.read_files_path(face_images_path)
@@ -433,7 +443,7 @@ class Human:
         print('wav length:', stream_len)
         _audio_chunk_queue = queue.Queue()
         index = 0
-        chunk_len = 640# // 200
+        chunk_len = 320# // 200
         print('chunk_len:', chunk_len)
         while stream_len >= chunk_len:
             audio_chunk = stream[index:index + chunk_len]
@@ -451,7 +461,7 @@ class Human:
         j = 0
         while not _audio_chunk_queue.empty():
             chunks = []
-            length = min(64, _audio_chunk_queue.qsize())
+            length = min(128, _audio_chunk_queue.qsize())
             for i in range(length):
                 chunks.append(_audio_chunk_queue.get())
 
@@ -496,7 +506,7 @@ class Human:
         self._tts.push_txt(txt)
 
     def put_audio_frame(self, audio_chunk):
-        self._chunk_2_mal.put_audio_frame(audio_chunk)
+        self.chunk_2_mal.put_audio_frame(audio_chunk)
 
     # def push_audio_chunk(self, audio_chunk):
     #     self._chunk_2_mal.push_chunk(audio_chunk)
@@ -507,6 +517,9 @@ class Human:
     def push_out_put(self, frame, type_):
         self._infer.push_out_queue(frame, type_)
 
+    def get_out_put(self):
+        return self._infer.get_out_put()
+
     def push_mel_chunks_queue(self, audio_chunk):
         self.audio_chunks_queue_.put(audio_chunk)
 
@@ -521,13 +534,13 @@ class Human:
         self._test_image_queue.put(image)
 
     def push_res_frame(self, res_frame, idx, audio_frames):
-        self._res_render_queue.put((res_frame, idx, audio_frames))
+        self.res_render_queue.put((res_frame, idx, audio_frames))
 
     def render(self):
         try:
             # img, aud = self._res_frame_queue.get(block=True, timeout=.3)
             # img = self._test_image_queue.get(block=True, timeout=.3)
-            res_frame, idx, audio_frames = self._res_render_queue.get(block=True, timeout=.3)
+            res_frame, idx, audio_frames = self.res_render_queue.get(block=True, timeout=.3)
         except queue.Empty:
             # print('render queue.Empty:')
             return None
diff --git a/infer.py b/infer.py
index 420b51b..a0c8cf3 100644
--- a/infer.py
+++ b/infer.py
@@ -1,5 +1,8 @@
 #encoding = utf8
+import os
+import glob
 import queue
+import multiprocessing as mp
 import time
 from queue import Queue
 from threading import Thread, Event
@@ -11,6 +14,7 @@ import torch
 from tqdm import tqdm
 
 import face_detection
+import utils
 from models import Wav2Lip
 from utils import mirror_index
 
@@ -107,31 +111,36 @@ img_size = 96
 wav2lip_batch_size = 128
 
 
-def datagen_signal(frame, mel, face_det_results):
+def datagen(frames, mels):
     img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
 
-    # for i, m in enumerate(mels):
-    idx = 0
-    frame_to_save = frame.copy()
-    face, coords = face_det_results[idx].copy()
+    face_det_results = face_detect(frames)  # BGR2RGB for CNN face detection
 
-    face = cv2.resize(face, (img_size, img_size))
-    m = mel
+    i = 0
+    for m in enumerate(mels):
+        # for i in range(mels.qsize()):
+        idx = 0 if True else i%len(frames)
+        frame_to_save = frames[mirror_index(1, i)].copy()
+        face, coords = face_det_results[idx].copy()
 
-    img_batch.append(face)
-    mel_batch.append(m)
-    frame_batch.append(frame_to_save)
-    coords_batch.append(coords)
+        face = cv2.resize(face, (img_size, img_size))
 
-    if len(img_batch) >= wav2lip_batch_size:
-        img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
+        img_batch.append(face)
+        mel_batch.append(m)
+        frame_batch.append(frame_to_save)
+        coords_batch.append(coords)
 
-        img_masked = img_batch.copy()
-        img_masked[:, img_size // 2:] = 0
-        img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
-        mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
+        if len(img_batch) >= wav2lip_batch_size:
+            img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
 
-    return img_batch, mel_batch, frame_batch, coords_batch
+            img_masked = img_batch.copy()
+            img_masked[:, img_size//2:] = 0
+            img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
+            mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
+
+            yield img_batch, mel_batch, frame_batch, coords_batch
+            img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
+        i = i + 1
 
     if len(img_batch) > 0:
         img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
@@ -141,19 +150,145 @@ def datagen_signal(frame, mel, face_det_results):
         img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
         mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
 
-    return img_batch, mel_batch, frame_batch, coords_batch
+        yield img_batch, mel_batch, frame_batch, coords_batch
+def datagen_signal(frame, mel, face_det_results):
+    img_batch, mel_batch, frame_batch, coord_batch = [], [], [], []
+
+    # for i, m in enumerate(mels):
+    idx = 0
+    frame_to_save = frame.copy()
+    face, coord = face_det_results[idx].copy()
+
+    face = cv2.resize(face, (img_size, img_size))
+
+    for i, m in enumerate(mel):
+        img_batch.append(face)
+        mel_batch.append(m)
+        frame_batch.append(frame_to_save)
+        coord_batch.append(coord)
+
+        if len(img_batch) >= wav2lip_batch_size:
+            img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
+
+            img_masked = img_batch.copy()
+            img_masked[:, img_size // 2:] = 0
+            img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
+            mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
+
+            return img_batch, mel_batch, frame_batch, coord_batch
+
+    if len(img_batch) > 0:
+        img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
+        img_masked = img_batch.copy()
+        img_masked[:, img_size // 2:] = 0
+
+        img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
+        mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
+
+        return img_batch, mel_batch, frame_batch, coord_batch
+
+
+def inference(render_event, batch_size, face_imgs_path, audio_feat_queue, audio_out_queue, res_frame_queue):
+    model = load_model(r'.\checkpoints\wav2lip.pth')
+    # face_list_cycle = read_images(face_imgs_path)
+    input_face_list = glob.glob(os.path.join(face_imgs_path, '*.[jpJP][pnPN]*[gG]'))
+    input_face_list = sorted(input_face_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
+    face_list_cycle = read_images(input_face_list)
+
+    # input_latent_list_cycle = torch.load(latents_out_path)
+    length = len(face_list_cycle)
+    index = 0
+    count = 0
+    counttime = 0
+    print('start inference')
+    while True:
+        if render_event.is_set():
+            starttime = time.perf_counter()
+            mel_batch = []
+            try:
+                mel_batch = audio_feat_queue.get(block=True, timeout=1)
+            except queue.Empty:
+                continue
+
+            is_all_silence = True
+            audio_frames = []
+            for _ in range(batch_size * 2):
+                frame, type_ = audio_out_queue.get()
+                audio_frames.append((frame, type_))
+                if type_ == 0:
+                    is_all_silence = False
+
+            if is_all_silence:
+                for i in range(batch_size):
+                    res_frame_queue.put((None, mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]))
+                    index = index + 1
+            else:
+                print('infer=======')
+                t = time.perf_counter()
+                img_batch = []
+                for i in range(batch_size):
+                    idx = mirror_index(length, index + i)
+                    face = face_list_cycle[idx]
+                    img_batch.append(face)
+                img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
+
+                img_masked = img_batch.copy()
+                img_masked[:, face.shape[0] // 2:] = 0
+
+                img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
+                mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
+
+                img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
+                mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
+
+                with torch.no_grad():
+                    pred = model(mel_batch, img_batch)
+                pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
+
+                counttime += (time.perf_counter() - t)
+                count += batch_size
+                # _totalframe += 1
+                if count >= 100:
+                    print(f"------actual avg infer fps:{count / counttime:.4f}")
+                    count = 0
+                    counttime = 0
+                for i, res_frame in enumerate(pred):
+                    # self.__pushmedia(res_frame,loop,audio_track,video_track)
+                    res_frame_queue.put((res_frame, mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]))
+                    index = index + 1
+                # print('total batch time:',time.perf_counter()-starttime)
+        else:
+            time.sleep(1)
+    print('musereal inference processor stop')
+
 
 
 class Infer:
     def __init__(self, human):
         self._human = human
 
-        self._feat_queue = Queue()
-        self._audio_out_queue = Queue()
+        # self._feat_queue = Queue()
+        # self._audio_out_queue = Queue()
 
-        self._exit_event = Event()
-        self._run_thread = Thread(target=self.__on_run)
-        self._exit_event.set()
-        self._run_thread.start()
+        self.batch_size = human.get_batch_size()
+        self.asr = human.chunk_2_mal
+        self.res_frame_queue = human.res_render_queue
+
+        # self._exit_event = Event()
+        # face_images_path = r'./face/'
+        # self.face_images_path = utils.read_files_path(face_images_path)
+        self.avatar_id = 'wav2lip_avatar1'
+        self.avatar_path = f"./data/{self.avatar_id}"
+        self.full_imgs_path = f"{self.avatar_path}/full_imgs"
+        self.face_images_path = f"{self.avatar_path}/face_imgs"
+        self.coords_path = f"{self.avatar_path}/coords.pkl"
+        self.render_event = mp.Event()
+        mp.Process(target=inference, args=(self.render_event, self.batch_size, self.face_images_path,
+                                           self.asr.feat_queue, self.asr.output_queue, self.res_frame_queue,
+                                           )).start()
+        self.render_event.set()
+        # self._run_thread = Thread(target=self.__on_run)
+        # self._exit_event.set()
+        # self._run_thread.start()
 
     def __on_run(self):
         model = load_model(r'.\checkpoints\wav2lip.pth')
@@ -209,9 +344,16 @@ class Infer:
         count = 0
         count_time = 0
         print('start inference')
+        #
+        # face_images_path = r'./face/'
+        # face_images_path = utils.read_files_path(face_images_path)
+        # face_list_cycle1 = read_images(face_images_path)
+        # face_det_results = face_detect(face_list_cycle1)
+
         while True:
             if self._exit_event.is_set():
                 start_time = time.perf_counter()
+                batch_size = self._human.get_batch_size()
                 try:
                     mel_batch = self._feat_queue.get(block=True, timeout=1)
                 except queue.Empty:
@@ -219,14 +361,14 @@ class Infer:
 
                 is_all_silence = True
                 audio_frames = []
-                for _ in range(self._human.get_batch_size() * 2):
+                for _ in range(batch_size * 2):
                     frame, type_ = self._audio_out_queue.get()
                     audio_frames.append((frame, type_))
                     if type_ == 0:
                         is_all_silence = False
 
                 if is_all_silence:
-                    for i in range(self._human.get_batch_size()):
+                    for i in range(batch_size):
                         # res_frame_queue.put((None, mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]))
                         self._human.push_res_frame(None, mirror_index(length, index), audio_frames[i * 2:i * 2 + 2])
                         index = index + 1
@@ -234,23 +376,29 @@ class Infer:
                     print('infer=======')
                     t = time.perf_counter()
                     img_batch = []
-                    for i in range(self._human.get_batch_size()):
+                    for i in range(batch_size):
                         idx = mirror_index(length, index + i)
                         face = face_list_cycle[idx]
                         img_batch.append(face)
-                    img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
+                    # img_batch_1, mel_batch_1, frames, coords = datagen_signal(face_list_cycle1,
+                    #                                                           mel_batch, face_det_results)
+
+                    img_batch = np.asarray(img_batch)
+                    mel_batch = np.asarray(mel_batch)
 
                     img_masked = img_batch.copy()
                     img_masked[:, face.shape[0] // 2:] = 0
-
+                    #
                     img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
-                    mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
+                    mel_batch = np.reshape(mel_batch,
+                                           [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
 
                     img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
                     mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
 
                     with torch.no_grad():
                         pred = model(mel_batch, img_batch)
+                        # pred = model(mel_batch, img_batch) * 255.0
                     pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
 
                     count_time += (time.perf_counter() - t)
@@ -277,3 +425,6 @@ class Infer:
 
     def push_out_queue(self, frame, type_):
         self._audio_out_queue.put((frame, type_))
+
+    def get_out_put(self):
+        return self._audio_out_queue.get()
diff --git a/tts/Chunk2Mal.py b/tts/Chunk2Mal.py
index 21d2cc2..78890d8 100644
--- a/tts/Chunk2Mal.py
+++ b/tts/Chunk2Mal.py
@@ -20,8 +20,13 @@ class Chunk2Mal:
 
         self.frames = []
         self.queue = Queue()
-        # self.output_queue = mp.Queue()
-        # self.feat_queue = mp.Queue(2)
+
+        self.fps = human.get_fps()
+        self.batch_size = human.get_batch_size()
+        self.stride_left_size = human.get_stride_left_size()
+        self.stride_right_size = human.get_stride_right_size()
+        self.output_queue = mp.Queue()
+        self.feat_queue = mp.Queue(2)
 
         # 320 samples per chunk (20ms * 16000 / 1000)audio_chunk
         self.chunk = self._human.get_audio_sample_rate() // self._human.get_fps()
@@ -43,27 +48,26 @@ class Chunk2Mal:
         logging.info('chunk2mal exit')
 
     def _run_step(self):
-        for _ in range(self._human.get_batch_size() * 2):
+        for _ in range(self.batch_size * 2):
             frame, _type = self.get_audio_frame()
             self.frames.append(frame)
             # put to output
-            self._human.push_out_put(frame, _type)
-            # self.output_queue.put((frame, _type))
+            self.output_queue.put((frame, _type))
 
         # context not enough, do not run network.
-        if len(self.frames) <= self._human.get_stride_left_size() + self._human.get_stride_right_size():
+        if len(self.frames) <= self.stride_left_size + self.stride_right_size:
            return
         inputs = np.concatenate(self.frames)  # [N * chunk]
         mel = audio.melspectrogram(inputs)
         # print(mel.shape[0],mel.shape,len(mel[0]),len(self.frames))
         # cut off stride
-        left = max(0, self._human.get_stride_left_size() * 80 / 50)
-        right = min(len(mel[0]), len(mel[0]) - self._human.get_stride_right_size() * 80 / 50)
-        mel_idx_multiplier = 80. * 2 / self._human.get_fps()
+        left = max(0, self.stride_left_size * 80 / 50)
+        right = min(len(mel[0]), len(mel[0]) - self.stride_right_size * 80 / 50)
+        mel_idx_multiplier = 80. * 2 / self.fps
         mel_step_size = 16
         i = 0
         mel_chunks = []
-        while i < (len(self.frames) - self._human.get_stride_left_size() - self._human.get_stride_right_size()) / 2:
+        while i < (len(self.frames) - self.stride_left_size - self.stride_right_size) / 2:
             start_idx = int(left + i * mel_idx_multiplier)
             # print(start_idx)
             if start_idx + mel_step_size > len(mel[0]):
@@ -71,11 +75,10 @@ class Chunk2Mal:
             else:
                 mel_chunks.append(mel[:, start_idx: start_idx + mel_step_size])
             i += 1
-        self._human.push_mel_chunks(mel_chunks)
-        # self.feat_queue.put(mel_chunks)
+        self.feat_queue.put(mel_chunks)
 
         # discard the old part to save memory
-        self.frames = self.frames[-(self._human.get_stride_left_size() + self._human.get_stride_right_size()):]
+        self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]
 
     def stop(self):
         if self._exit_event is None:
@@ -95,25 +98,25 @@ class Chunk2Mal:
     def get_audio_frame(self):
         try:
             frame = self.queue.get(block=True, timeout=0.01)
-            type = 0
+            type_ = 0
             # print(f'[INFO] get frame {frame.shape}')
         except queue.Empty:
             frame = np.zeros(self.chunk, dtype=np.float32)
-            type = 1
+            type_ = 1
 
-        return frame, type
+        return frame, type_
 
     def get_audio_out(self):  # get origin audio pcm to nerf
         return self.output_queue.get()
 
     def warm_up(self):
-        for _ in range(self._human.get_stride_left_size() + self._human.get_stride_right_size()):
-            audio_frame, _type = self.get_audio_frame()
+        for _ in range(self.stride_left_size + self.stride_right_size):
+            audio_frame, type_ = self.get_audio_frame()
             self.frames.append(audio_frame)
-            self.output_queue.put((audio_frame, type))
-        for _ in range(self._human.get_stride_right_size()):
+            self.output_queue.put((audio_frame, type_))
+        for _ in range(self.stride_left_size):
             self.output_queue.get()
 
-    def get_next_feat(self, block, timeout):
-        return self.feat_queue.get(block, timeout)
-
+    #
+    # def get_next_feat(self, block, timeout):
+    #     return self.feat_queue.get(block, timeout)
diff --git a/utils.py b/utils.py
index e1bbb86..12aa4f7 100644
--- a/utils.py
+++ b/utils.py
@@ -1,6 +1,9 @@
 #encoding = utf8
 import os
 
+import cv2
+from tqdm import tqdm
+
 
 def read_files_path(path):
     file_paths = []
@@ -11,6 +14,16 @@ def read_files_path(path):
     return file_paths
 
 
+def read_images(img_list):
+    frames = []
+    print('reading images...')
+    for img_path in tqdm(img_list):
+        print(f'read image path:{img_path}')
+        frame = cv2.imread(img_path)
+        frames.append(frame)
+    return frames
+
+
 def mirror_index(size, index):
     # size = len(self.coord_list_cycle)
     turn = index // size