diff --git a/human/audio_handler.py b/human/audio_handler.py
index 94a7b41..79f9239 100644
--- a/human/audio_handler.py
+++ b/human/audio_handler.py
@@ -1,6 +1,9 @@
 #encoding = utf8
+import logging
 from abc import ABC, abstractmethod
 
+logger = logging.getLogger(__name__)
+
 
 class AudioHandler(ABC):
     def __init__(self, context, handler):
@@ -8,6 +11,11 @@ class AudioHandler(ABC):
         self._handler = handler
 
     @abstractmethod
-    def on_handle(self, stream, index):
+    def on_handle(self, stream, index):
+        pass
+
+    def on_next_handle(self, stream, type_):
         if self._handler is not None:
-            self._handler.on_handle(stream, index)
+            self._handler.on_handle(stream, type_)
+        else:
+            logger.info('_handler is None')
diff --git a/human/audio_inference_handler.py b/human/audio_inference_handler.py
index d64b141..c800818 100644
--- a/human/audio_inference_handler.py
+++ b/human/audio_inference_handler.py
@@ -1,6 +1,7 @@
 #encoding = utf8
 import queue
 import time
+from queue import Queue
 from threading import Event, Thread
 
 import numpy as np
@@ -14,20 +15,25 @@ class AudioInferenceHandler(AudioHandler):
     def __init__(self, context, handler):
         super().__init__(context, handler)
 
+        self._mal_queue = Queue()
+        self._audio_queue = Queue()
+
         self._exit_event = Event()
         self._run_thread = Thread(target=self.__on_run)
         self._exit_event.set()
         self._run_thread.start()
 
-    def on_handle(self, stream, index):
-        if self._handler is not None:
-            self._handler.on_handle(stream, index)
+    def on_handle(self, stream, type_):
+        if type_ == 1:
+            self._mal_queue.put(stream)
+        elif type_ == 0:
+            self._audio_queue.put(stream)
 
     def __on_run(self):
         model = load_model(r'.\checkpoints\wav2lip.pth')
         print("Model loaded")
 
-        face_list_cycle = self._human.get_face_list_cycle()
+        face_list_cycle = self._context.face_list_cycle
         length = len(face_list_cycle)
         index = 0
@@ -43,20 +49,21 @@ class AudioInferenceHandler(AudioHandler):
             start_time = time.perf_counter()
             batch_size = self._context.batch_size()
             try:
-                mel_batch = self._feat_queue.get(block=True, timeout=0.1)
+                mel_batch = self._mal_queue.get(block=True, timeout=0.1)
             except queue.Empty:
                 continue
 
             is_all_silence = True
             audio_frames = []
             for _ in range(batch_size * 2):
-                frame, type_ = self._audio_out_queue.get()
+                frame, type_ = self._audio_queue.get()
                 audio_frames.append((frame, type_))
                 if type_ == 0:
                     is_all_silence = False
 
             if is_all_silence:
                 for i in range(batch_size):
-                    self._human.push_res_frame(None, mirror_index(length, index), audio_frames[i * 2:i * 2 + 2])
+                    self.on_next_handle((None, mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]),
+                                        0)
                     index = index + 1
             else:
                 print('infer=======')
@@ -71,7 +78,7 @@ class AudioInferenceHandler(AudioHandler):
                 mel_batch = np.asarray(mel_batch)
                 img_masked = img_batch.copy()
                 img_masked[:, face.shape[0] // 2:] = 0
-                #
+
                 img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
                 mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
@@ -94,8 +101,9 @@ class AudioInferenceHandler(AudioHandler):
                 image_index = 0
                 for i, res_frame in enumerate(pred):
-                    self._human.push_res_frame(res_frame, mirror_index(length, index),
-                                               audio_frames[i * 2:i * 2 + 2])
+                    self.on_next_handle(
+                        (res_frame, mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]),
+                        0)
                     index = index + 1
                     image_index = image_index + 1
                 print('batch count', image_index)
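Note on the pattern: the changes above replace direct self._human.push_* calls with a generic handler chain. Each AudioHandler forwards its output to its successor via on_next_handle(stream, type_), where type_ == 1 marks mel-spectrogram batches and type_ == 0 marks audio/frame data. A minimal sketch of the pattern in isolation (SinkHandler and PassThroughHandler are hypothetical, written only to exercise the base class from human/audio_handler.py above):

#encoding = utf8
# Hypothetical handlers, for illustration only; they show how
# on_next_handle() moves type-tagged data down the chain.
from human import AudioHandler


class SinkHandler(AudioHandler):
    def on_handle(self, stream, type_):
        print('sink received type', type_, 'payload', stream)


class PassThroughHandler(AudioHandler):
    def on_handle(self, stream, type_):
        # Forward unchanged to the next handler in the chain.
        self.on_next_handle(stream, type_)


if __name__ == '__main__':
    sink = SinkHandler(None, None)         # end of the chain, no successor
    head = PassThroughHandler(None, sink)  # forwards into sink
    head.on_handle(b'mel batch', 1)        # routed as a mel batch
    head.on_handle(b'audio frame', 0)      # routed as an audio frame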
diff --git a/human/audio_mal_handler.py b/human/audio_mal_handler.py
index 3b12eee..047ff85 100644
--- a/human/audio_mal_handler.py
+++ b/human/audio_mal_handler.py
@@ -27,8 +27,7 @@ class AudioMalHandler(AudioHandler):
         self.chunk = context.sample_rate() // context.fps()
 
     def on_handle(self, stream, index):
-        if self._handler is not None:
-            self._handler.on_handle(stream, index)
+        self._queue.put(stream)
 
     def _on_run(self):
         logging.info('chunk2mal run')
@@ -42,9 +41,7 @@ class AudioMalHandler(AudioHandler):
         for _ in range(self._context.batch_size() * 2):
             frame, _type = self.get_audio_frame()
             self.frames.append(frame)
-            # put to output
-            # self.output_queue.put((frame, _type))
-            self._human.push_out_put(frame, _type)
+            self.on_next_handle((frame, _type), 0)
         # context not enough, do not run network.
         if len(self.frames) <= self._context.stride_left_size() + self._context.stride_right_size():
             return
@@ -67,8 +64,7 @@ class AudioMalHandler(AudioHandler):
             else:
                 mel_chunks.append(mel[:, start_idx: start_idx + mel_step_size])
             i += 1
-        # self.feat_queue.put(mel_chunks)
-        self._human.push_mel_chunks(mel_chunks)
+        self.on_next_handle(mel_chunks, 1)
         # discard the old part to save memory
         self.frames = self.frames[-(self._context.stride_left_size() + self._context.stride_right_size()):]
diff --git a/human/human_context.py b/human/human_context.py
index a728e24..aa8baec 100644
--- a/human/human_context.py
+++ b/human/human_context.py
@@ -4,6 +4,7 @@ import logging
 from asr import SherpaNcnnAsr
 from nlp import PunctuationSplit, DouBao
 from tts import TTSEdge, TTSAudioSplitHandle
+from utils import load_avatar, get_device
 
 logger = logging.getLogger(__name__)
 
@@ -11,12 +12,14 @@ logger = logging.getLogger(__name__)
 class HumanContext:
     def __init__(self):
         self._fps = 50  # 20 ms per frame
+        self._image_size = 96
         self._batch_size = 16
         self._sample_rate = 16000
         self._stride_left_size = 10
         self._stride_right_size = 10
-        full_images, face_frames, coord_frames = load_avatar(r'./face/')
+        self._device = get_device()
+        full_images, face_frames, coord_frames = load_avatar(r'./face/', self._device, self._image_size)
         self._frame_list_cycle = full_images
         self._face_list_cycle = face_frames
         self._coord_list_cycle = coord_frames
@@ -28,6 +31,14 @@ def fps(self):
         return self._fps
 
+    @property
+    def image_size(self):
+        return self._image_size
+
+    @property
+    def device(self):
+        return self._device
+
     @property
     def batch_size(self):
         return self._batch_size
@@ -44,6 +55,10 @@ def stride_right_size(self):
         return self._stride_right_size
 
+    @property
+    def face_list_cycle(self):
+        return self._face_list_cycle
+
     def build(self):
         tts_handle = TTSAudioSplitHandle(self, None)
         tts = TTSEdge(tts_handle)
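get_device and load_avatar are new imports from utils, whose implementations are not part of this diff. get_device presumably performs the usual torch device probe; a sketch under that assumption (not the repo's actual code):

# Assumed shape of utils.get_device; the real implementation is not shown here.
import torch


def get_device():
    # Prefer CUDA when available, otherwise fall back to the CPU.
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')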
diff --git a/human/human_render.py b/human/human_render.py
new file mode 100644
index 0000000..f36537a
--- /dev/null
+++ b/human/human_render.py
@@ -0,0 +1,16 @@
+#encoding = utf8
+from queue import Queue
+
+from human import AudioHandler
+
+
+class HumanRender(AudioHandler):
+    def __init__(self, context, handler):
+        super().__init__(context, handler)
+
+        self._queue = Queue(context.batch_size * 2)
+
+    def on_handle(self, stream, index):
+        self._queue.put(stream)
+
+
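Taken together, the handlers compose back to front, each constructor receiving the next stage as its handler argument. A possible wiring (an assumption for illustration; the diff does not show HumanContext.build creating this chain):

#encoding = utf8
# Hypothetical wiring of the new pipeline; the constructor signatures
# (context, handler) come from the diff, the composition itself is assumed.
from human.human_context import HumanContext
from human.human_render import HumanRender
from human.audio_inference_handler import AudioInferenceHandler
from human.audio_mal_handler import AudioMalHandler

context = HumanContext()
render = HumanRender(context, None)             # tail of the chain
infer = AudioInferenceHandler(context, render)  # wav2lip inference stage
mal = AudioMalHandler(context, infer)           # PCM -> mel front end

# 20 ms PCM chunks would then enter at the front of the chain, e.g.:
# mal.on_handle(pcm_chunk, 0)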