From d0c8ddb18b4479572a29390ce6eef020b14bf517 Mon Sep 17 00:00:00 2001 From: jiegeaiai Date: Sun, 24 Nov 2024 15:19:40 +0800 Subject: [PATCH] modify video render --- human/__init__.py | 1 + human/audio_inference_onnx_handler.py | 207 ++++++++++++++++++++++++++ human/human_context.py | 9 ++ render/video_render.py | 23 ++- 4 files changed, 236 insertions(+), 4 deletions(-) create mode 100644 human/audio_inference_onnx_handler.py diff --git a/human/__init__.py b/human/__init__.py index 966011f..503a041 100644 --- a/human/__init__.py +++ b/human/__init__.py @@ -3,4 +3,5 @@ from .human_context import HumanContext from .audio_mal_handler import AudioMalHandler from .audio_inference_handler import AudioInferenceHandler +from .audio_inference_onnx_handler import AudioInferenceOnnxHandler from .human_render import HumanRender diff --git a/human/audio_inference_onnx_handler.py b/human/audio_inference_onnx_handler.py new file mode 100644 index 0000000..f7d28db --- /dev/null +++ b/human/audio_inference_onnx_handler.py @@ -0,0 +1,207 @@ +#encoding = utf8 +import logging +import os +import queue +import time +from threading import Event, Thread + +from gfpgan import GFPGANer +from eventbus import EventBus +from human_handler import AudioHandler +from utils import load_model, mirror_index, get_device, SyncQueue + +logger = logging.getLogger(__name__) +current_file_path = os.path.dirname(os.path.abspath(__file__)) + + +def load_gfpgan_model(model_path): + logger.info(f'load_gfpgan_model, path:{model_path}') + model = GFPGANer( + model_path=model_path, + upscale=1, + arch='clean', + channel_multiplier=2, + bg_upsampler=None, + ) + return model + + +def load_model(model_path): + import onnxruntime as ort + + sess_opt = ort.SessionOptions() + sess_opt.intra_op_num_threads = 8 + sess = ort.InferenceSession(model_path, sess_options=sess_opt, providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) + + return sess + + +class AudioInferenceOnnxHandler(AudioHandler): + def __init__(self, context, handler): + super().__init__(context, handler) + + EventBus().register('stop', self._on_stop) + EventBus().register('clear_cache', self.on_clear_cache) + self._mal_queue = SyncQueue(1, "AudioInferenceHandler_Mel") + self._audio_queue = SyncQueue(context.batch_size * 2, "AudioInferenceHandler_Audio") + + self._is_running = True + self._exit_event = Event() + self._run_thread = Thread(target=self.__on_run, name="AudioInferenceHandlerThread") + self._exit_event.set() + self._run_thread.start() + + logger.info("AudioInferenceHandler init") + + def __del__(self): + EventBus().unregister('stop', self._on_stop) + EventBus().unregister('clear_cache', self.on_clear_cache) + + def _on_stop(self, *args, **kwargs): + self.stop() + + def on_clear_cache(self, *args, **kwargs): + self._mal_queue.clear() + self._audio_queue.clear() + + def on_handle(self, stream, type_): + if not self._is_running: + return + + if type_ == 1: + self._mal_queue.put(stream) + elif type_ == 0: + self._audio_queue.put(stream) + # print('AudioInferenceHandler on_handle', type_, self._audio_queue.size()) + + def on_message(self, message): + super().on_message(message) + + def __on_run(self): + model_path = os.path.join(current_file_path, '..', 'checkpoints', 'weights', 'wav2lip') + model_a_path = os.path.join(model_path, 'model_a_general.onnx') + logger.info(f'AudioInferenceHandler model_a_path, path:{model_a_path}') + model_g_path = os.path.join(model_path, 'model_g_general.onnx') + logger.info(f'AudioInferenceHandler model_g_path, path:{model_g_path}') + + model_a = load_model(model_a_path) + model_g = load_model(model_g_path) + logger.info("Model loaded") + + gfpgan_model_path = os.path.join(current_file_path, '..', 'checkpoints', 'gfpgan', 'weights', 'GFPGANv1.4.pth') + gfpgan_model = load_gfpgan_model(gfpgan_model_path) + + face_list_cycle = self._context.face_list_cycle + + + length = len(face_list_cycle) + index = 0 + count = 0 + count_time = 0 + logger.info('start inference') + + device = get_device() + logger.info(f'use device:{device}') + + while self._is_running: + if self._exit_event.is_set(): + start_time = time.perf_counter() + batch_size = self._context.batch_size + try: + mel_batch = self._mal_queue.get(timeout=0.02) + # print('AudioInferenceHandler mel_batch:', len(mel_batch), 'size:', self._mal_queue.size()) + except queue.Empty: + continue + + # print('origin mel_batch:', len(mel_batch)) + is_all_silence = True + audio_frames = [] + current_text = '' + for _ in range(batch_size * 2): + frame, type_ = self._audio_queue.get() + # print('AudioInferenceHandler type_', type_) + current_text = frame[1] + audio_frames.append((frame, type_)) + if type_ == 0: + is_all_silence = False + + if not self._is_running: + print('AudioInferenceHandler not running') + break + + if is_all_silence: + for i in range(batch_size): + if not self._is_running: + break + self.on_next_handle((None, mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]), + 0) + index = index + 1 + else: + logger.info(f'infer======= {current_text}') + t = time.perf_counter() + img_batch = [] + # for i in range(batch_size): + for i in range(len(mel_batch)): + idx = mirror_index(length, index + i) + face = face_list_cycle[idx] + img_batch.append(face) + + # print('orign img_batch:', len(img_batch), 'origin mel_batch:', len(mel_batch)) + onnx_input = {'audio_seqs__0': mel_batch, } + onnx_names = [output.name for output in model_a.get_outputs()] + embeddings = model_a.run(onnx_names, onnx_input)[0] + + onnx_input = {"audio_embedings__0": embeddings, "img_seqs__1": img_batch} + onnx_names = [output.name for output in model_g.get_outputs()] + onnx_out = model_g.run(onnx_names, onnx_input)[0] + pred = onnx_out + + # onnxruntime_inputs = {"audio_seqs__0": mel_batch, } + # onnxruntime_names = [output.name for output in model_a.get_outputs()] + # embeddings = model_a.run(onnxruntime_names, onnxruntime_inputs)[0] + # + # onnxruntime_inputs = {"audio_embedings__0": embeddings, "img_seqs__1": img_batch} + # onnxruntime_names = [output.name for output in model_g.get_outputs()] + # + # start_model = time.time() + # onnxruntime_output = model_g.run(onnxruntime_names, onnxruntime_inputs)[0] + # end_model = time.time() + # pred = onnxruntime_output + + count_time += (time.perf_counter() - t) + count += batch_size + + if count >= 100: + logger.info(f"------actual avg infer fps:{count / count_time:.4f}") + count = 0 + count_time = 0 + + for res_frame in zip(pred): + if not self._is_running: + break + + self.on_next_handle( + (res_frame[0], mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]), + 0) + index = index + 1 + logger.info(f'total batch time: {time.perf_counter() - start_time}') + else: + time.sleep(1) + break + logger.info('AudioInferenceHandler inference processor stop') + + def stop(self): + logger.info('AudioInferenceHandler stop') + self._is_running = False + self._exit_event.clear() + if self._run_thread.is_alive(): + logger.info('AudioInferenceHandler stop join') + self._run_thread.join() + logger.info('AudioInferenceHandler stop exit') + + def pause_talk(self): + print('AudioInferenceHandler pause_talk', self._audio_queue.size(), self._mal_queue.size()) + self._audio_queue.clear() + print('AudioInferenceHandler111') + self._mal_queue.clear() + print('AudioInferenceHandler222') diff --git a/human/human_context.py b/human/human_context.py index 193e0ea..e76e8ad 100644 --- a/human/human_context.py +++ b/human/human_context.py @@ -4,6 +4,7 @@ import os from asr import SherpaNcnnAsr from eventbus import EventBus +from .audio_inference_onnx_handler import AudioInferenceOnnxHandler from .audio_inference_handler import AudioInferenceHandler from .audio_mal_handler import AudioMalHandler from .human_render import HumanRender @@ -99,6 +100,14 @@ class HumanContext: def coord_list_cycle(self): return self._coord_list_cycle + @property + def align_frames(self): + return self._align_frames + + @property + def inv_m_frames(self): + return self._inv_m_frames + @property def render_handler(self): return self._render_handler diff --git a/render/video_render.py b/render/video_render.py index ab88fd2..8831180 100644 --- a/render/video_render.py +++ b/render/video_render.py @@ -10,6 +10,17 @@ import numpy as np from .base_render import BaseRender +def img_warp_back_inv_m(img, img_to, inv_m): + h_up, w_up, c = img_to.shape + + mask = np.ones_like(img).astype(np.float32) + inv_mask = cv2.warpAffine(mask, inv_m, (w_up, h_up)) + inv_img = cv2.warpAffine(img, inv_m, (w_up, h_up)) + + img_to[inv_mask == 1] = inv_img[inv_mask == 1] + return img_to + + class VideoRender(BaseRender): def __init__(self, play_clock, context, human_render): super().__init__(play_clock, context, 'Video') @@ -24,18 +35,22 @@ class VideoRender(BaseRender): else: bbox = self._context.coord_list_cycle[idx] combine_frame = copy.deepcopy(self._context.frame_list_cycle[idx]) + af = copy.deepcopy(self._context.align_frames[idx]) + inv_m = self._context.inv_m_frames[idx] y1, y2, x1, x2 = bbox try: res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1)) - except: - print('resize error') + af[y1:y2, x1:x2] = res_frame + combine_frame = img_warp_back_inv_m(af, combine_frame, inv_m) + except Exception as e: + print('resize error', e) return # cv2.imwrite(f'./images/res_frame_{ self.index }.png', res_frame) - combine_frame[y1:y2, x1:x2] = res_frame + # combine_frame[y1:y2, x1:x2] = res_frame # cv2.imwrite(f'/combine_frame_{self.index}.png', combine_frame) # self.index = self.index + 1 image = combine_frame - # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) if self._human_render is not None: self._human_render.put_image(image)