diff --git a/asr/sherpa_ncnn_asr.py b/asr/sherpa_ncnn_asr.py
index 9dd2e67..777d04b 100644
--- a/asr/sherpa_ncnn_asr.py
+++ b/asr/sherpa_ncnn_asr.py
@@ -26,7 +26,7 @@ class SherpaNcnnAsr(AsrBase):
         self._recognizer = self._create_recognizer()
 
     def _create_recognizer(self):
-        base_path = os.path.join(os.getcwd(), '..', 'data', 'asr', 'sherpa-ncnn',
+        base_path = os.path.join(os.getcwd(), 'data', 'asr', 'sherpa-ncnn',
                                  'sherpa-ncnn-streaming-zipformer-zh-14M-2023-02-23')
         recognizer = sherpa_ncnn.Recognizer(
             tokens=base_path + '/tokens.txt',
diff --git a/human/__init__.py b/human/__init__.py
index f1319f8..175ad85 100644
--- a/human/__init__.py
+++ b/human/__init__.py
@@ -1,4 +1,4 @@
 #encoding = utf8
 
-from .human_context import HumanContext
 from .audio_handler import AudioHandler
+from .human_context import HumanContext
diff --git a/human/audio_inference_handler.py b/human/audio_inference_handler.py
index c800818..1b829d5 100644
--- a/human/audio_inference_handler.py
+++ b/human/audio_inference_handler.py
@@ -7,7 +7,7 @@
 from threading import Event, Thread
 
 import numpy as np
 import torch
 
-from human import AudioHandler
+from .audio_handler import AudioHandler
 from utils import load_model, mirror_index, get_device
@@ -25,7 +25,7 @@ class AudioInferenceHandler(AudioHandler):
 
     def on_handle(self, stream, type_):
         if type_ == 1:
-            self._mal_queue.put(stream)
+            self._mal_queue.put(stream)
         elif type_ == 0:
             self._audio_queue.put(stream)
 
@@ -33,7 +33,7 @@ class AudioInferenceHandler(AudioHandler):
         model = load_model(r'.\checkpoints\wav2lip.pth')
         print("Model loaded")
 
-        face_list_cycle = self._context.face_list_cycle()
+        face_list_cycle = self._context.face_list_cycle
         length = len(face_list_cycle)
         index = 0
 
@@ -47,7 +47,7 @@ class AudioInferenceHandler(AudioHandler):
         while True:
             if self._exit_event.is_set():
                 start_time = time.perf_counter()
-                batch_size = self._context.batch_size()
+                batch_size = self._context.batch_size
                 try:
                     mel_batch = self._mal_queue.get(block=True, timeout=0.1)
                 except queue.Empty:
@@ -78,7 +78,7 @@ class AudioInferenceHandler(AudioHandler):
                     mel_batch = np.asarray(mel_batch)
                     img_masked = img_batch.copy()
                     img_masked[:, face.shape[0] // 2:] = 0
-
+
                     img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
                     mel_batch = np.reshape(mel_batch,
                                            [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
diff --git a/human/audio_mal_handler.py b/human/audio_mal_handler.py
index 047ff85..78dc4ae 100644
--- a/human/audio_mal_handler.py
+++ b/human/audio_mal_handler.py
@@ -7,7 +7,7 @@
 from threading import Thread, Event
 
 import numpy as np
-from human import AudioHandler
+from .audio_handler import AudioHandler
 from utils import melspectrogram
 
 logger = logging.getLogger(__name__)
@@ -24,7 +24,7 @@ class AudioMalHandler(AudioHandler):
         self._thread.start()
 
         self.frames = []
-        self.chunk = context.sample_rate() // context.fps()
+        self.chunk = context.sample_rate // context.fps
 
     def on_handle(self, stream, index):
         self._queue.put(stream)
@@ -38,25 +38,25 @@ class AudioMalHandler(AudioHandler):
         logging.info('chunk2mal exit')
 
     def _run_step(self):
-        for _ in range(self._context.batch_size() * 2):
+        for _ in range(self._context.batch_size * 2):
             frame, _type = self.get_audio_frame()
             self.frames.append(frame)
             self.on_next_handle((frame, _type), 0)
 
         # context not enough, do not run network.
-        if len(self.frames) <= self._context.stride_left_size() + self._context.stride_right_size():
+        if len(self.frames) <= self._context.stride_left_size + self._context.stride_right_size:
             return
         inputs = np.concatenate(self.frames)  # [N * chunk]
         mel = melspectrogram(inputs)
         # print(mel.shape[0],mel.shape,len(mel[0]),len(self.frames))
         # cut off stride
-        left = max(0, self._context.stride_left_size() * 80 / 50)
-        right = min(len(mel[0]), len(mel[0]) - self._context.stride_right_size() * 80 / 50)
-        mel_idx_multiplier = 80. * 2 / self._context.fps()
+        left = max(0, self._context.stride_left_size * 80 / 50)
+        right = min(len(mel[0]), len(mel[0]) - self._context.stride_right_size * 80 / 50)
+        mel_idx_multiplier = 80. * 2 / self._context.fps
         mel_step_size = 16
         i = 0
         mel_chunks = []
-        while i < (len(self.frames) - self._context.stride_left_size() - self._context.stride_right_size()) / 2:
+        while i < (len(self.frames) - self._context.stride_left_size - self._context.stride_right_size) / 2:
             start_idx = int(left + i * mel_idx_multiplier)
             # print(start_idx)
             if start_idx + mel_step_size > len(mel[0]):
@@ -67,7 +67,7 @@ class AudioMalHandler(AudioHandler):
         self.on_next_handle(mel_chunks, 1)
 
         # discard the old part to save memory
-        self.frames = self.frames[-(self._context.stride_left_size() + self._context.stride_right_size()):]
+        self.frames = self.frames[-(self._context.stride_left_size + self._context.stride_right_size):]
 
     def get_audio_frame(self):
         try:
diff --git a/human/human_context.py b/human/human_context.py
index aa8baec..34b9ce0 100644
--- a/human/human_context.py
+++ b/human/human_context.py
@@ -2,6 +2,9 @@
 import logging
 
 from asr import SherpaNcnnAsr
+from human.audio_inference_handler import AudioInferenceHandler
+from human.audio_mal_handler import AudioMalHandler
+from human.human_render import HumanRender
 from nlp import PunctuationSplit, DouBao
 from tts import TTSEdge, TTSAudioSplitHandle
 from utils import load_avatar, get_device
@@ -19,7 +22,8 @@ class HumanContext:
         self._stride_right_size = 10
         self._device = get_device()
 
-        full_images, face_frames, coord_frames = load_avatar(r'./face/', self._device, self._image_size)
+        print(f'device:{self._device}')
+        full_images, face_frames, coord_frames = load_avatar(r'./face/', self._image_size, self._device)
         self._frame_list_cycle = full_images
         self._face_list_cycle = face_frames
         self._coord_list_cycle = coord_frames
@@ -27,6 +31,14 @@ class HumanContext:
         logging.info(f'face images length: {face_images_length}')
         print(f'face images length: {face_images_length}')
 
+        self.asr = None
+        self.nlp = None
+        self.tts = None
+        self.tts_handle = None
+        self.mal_handler = None
+        self.infer_handler = None
+        self._render_handler = None
+
     @property
     def fps(self):
         return self._fps
@@ -59,12 +71,27 @@ class HumanContext:
     def face_list_cycle(self):
         return self._face_list_cycle
 
+    @property
+    def frame_list_cycle(self):
+        return self._frame_list_cycle
+
+    @property
+    def coord_list_cycle(self):
+        return self._coord_list_cycle
+
+    @property
+    def render_handler(self):
+        return self._render_handler
+
     def build(self):
-        tts_handle = TTSAudioSplitHandle(self, None)
-        tts = TTSEdge(tts_handle)
+        self._render_handler = HumanRender(self, None)
+        self.infer_handler = AudioInferenceHandler(self, self._render_handler)
+        self.mal_handler = AudioMalHandler(self, self.infer_handler)
+        self.tts_handle = TTSAudioSplitHandle(self, self.mal_handler)
+        self.tts = TTSEdge(self.tts_handle)
         split = PunctuationSplit()
-        nlp = DouBao(split, tts)
-        asr = SherpaNcnnAsr()
-        asr.attach(nlp)
+        nlp = DouBao(split, self.tts)
+        self.asr = SherpaNcnnAsr()
+        self.asr.attach(nlp)
diff --git a/human/human_render.py b/human/human_render.py
index f36537a..db319f0 100644
--- a/human/human_render.py
+++ b/human/human_render.py
@@ -1,7 +1,16 @@
 #encoding = utf8
+import copy
+import logging
+import queue
+import time
 from queue import Queue
+from threading import Thread, Event
 
-from human import AudioHandler
+import cv2
+import numpy as np
+
+from audio_render import AudioRender
+from .audio_handler import AudioHandler
 
 
 class HumanRender(AudioHandler):
@@ -9,6 +18,63 @@ class HumanRender(AudioHandler):
         super().__init__(context, handler)
         self._queue = Queue(context.batch_size * 2)
 
+        self._audio_render = None
+        self._image_render = None
+
+        self._exit_event = Event()
+        self._thread = Thread(target=self._on_run)
+        self._exit_event.set()
+        self._thread.start()
+
+    def _on_run(self):
+        logging.info('chunk2mal run')
+        while self._exit_event.is_set():
+            self._run_step()
+            time.sleep(0.002)
+
+        logging.info('chunk2mal exit')
+
+    def _run_step(self):
+        try:
+            res_frame, idx, audio_frames = self._queue.get(block=True, timeout=.002)
+        except queue.Empty:
+            # print('render queue.Empty:')
+            return None
+        if audio_frames[0][1] != 0 and audio_frames[1][1] != 0:
+            combine_frame = self._context.frame_list_cycle[idx]
+        else:
+            bbox = self._context.coord_list_cycle[idx]
+            combine_frame = copy.deepcopy(self._context.frame_list_cycle[idx])
+            y1, y2, x1, x2 = bbox
+            try:
+                res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
+            except:
+                return
+            # combine_frame = get_image(ori_frame,res_frame,bbox)
+            # t=time.perf_counter()
+            combine_frame[y1:y2, x1:x2] = res_frame
+
+        image = combine_frame
+        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+        if self._image_render is not None:
+            self._image_render.render_image(image)
+
+        for audio_frame in audio_frames:
+            frame, type_ = audio_frame
+            frame = (frame * 32767).astype(np.int16)
+            if self._audio_render is not None:
+                self._audio_render.write(frame.tobytes(), int(frame.shape[0] * 2))
+            # new_frame = AudioFrame(format='s16', layout='mono', samples=frame.shape[0])
+            # new_frame.planes[0].update(frame.tobytes())
+            # new_frame.sample_rate = 16000
+
+
+    def set_audio_render(self, render):
+        self._audio_render = render
+
+    def set_image_render(self, render):
+        self._image_render = render
 
     def on_handle(self, stream, index):
         self._queue.put(stream)
diff --git a/tts/tts_audio_handle.py b/tts/tts_audio_handle.py
index 8e7babf..e892033 100644
--- a/tts/tts_audio_handle.py
+++ b/tts/tts_audio_handle.py
@@ -31,8 +31,8 @@ class TTSAudioHandle(AudioHandler):
 class TTSAudioSplitHandle(TTSAudioHandle):
     def __init__(self, context, handler):
         super().__init__(context, handler)
-        self.sample_rate = self._context.get_audio_sample_rate()
-        self._chunk = self.sample_rate // self._context.get_fps()
+        self.sample_rate = self._context.sample_rate
+        self._chunk = self.sample_rate // self._context.fps
 
     def on_handle(self, stream, index):
         stream_len = stream.shape[0]
diff --git a/ui.py b/ui.py
index 521b186..1ea22bd 100644
--- a/ui.py
+++ b/ui.py
@@ -2,9 +2,12 @@
 import json
 import logging
 import os
+import queue
 from logging import handlers
 import tkinter
 import tkinter.messagebox
+from queue import Queue
+
 import customtkinter
 import cv2
 import requests
@@ -13,8 +16,9 @@
 from PIL import Image, ImageTk
 from playsound import playsound
 
-from Human import Human
-from tts.EdgeTTS import EdgeTTS
+# from Human import Human
+from human import HumanContext
+# from tts.EdgeTTS import EdgeTTS
 
 
 logger = logging.getLogger(__name__)
@@ -55,7 +59,12 @@ class App(customtkinter.CTk):
         self._init_image_canvas()
 
         self._is_play_audio = False
-        self._human = Human()
+        # self._human = Human()
+        self._queue = Queue()
+        self._human_context = HumanContext()
+        self._human_context.build()
+        render = self._human_context.render_handler
+        render.set_image_render(self)
         self._render()
         # self.play_audio()
 
@@ -65,29 +74,25 @@ class App(customtkinter.CTk):
 
     def on_destroy(self):
         logger.info('------------App destroy------------')
-        self._human.on_destroy()
+        # self._human.on_destroy()
 
-    def play_audio(self):
-        return
-        # if self._is_play_audio:
-        #     return
-        # self._is_play_audio = True
-        # file = os.path.curdir + '/audio/test1.wav'
-        # print(file)
-        # winsound.PlaySound(file, winsound.SND_ASYNC or winsound.SND_FILENAME)
-        # playsound(file)
+    def render_image(self, image):
+        self._queue.put(image)
 
     def _init_image_canvas(self):
         self._canvas = customtkinter.CTkCanvas(self.image_frame)
         self._canvas.pack(fill=customtkinter.BOTH, expand=customtkinter.YES)
 
     def _render(self):
-        image = self._human.render()
-        if image is None:
-            self.after(100, self._render)
+        try:
+            image = self._queue.get(block=False)
+            if image is None:
+                self.after(20, self._render)
+                return
+        except queue.Empty:
+            self.after(20, self._render)
             return
-        # self.play_audio()
 
         iheight, iwidth = image.shape[0], image.shape[1]
         width = self.winfo_width()
         height = self.winfo_height()
@@ -95,10 +100,6 @@ class App(customtkinter.CTk):
             image = cv2.resize(image, (int(width), int(iheight * width / iwidth)))
         else:
            image = cv2.resize(image, (int(iwidth * height / iheight), int(height)), interpolation=cv2.INTER_AREA)
-        # image = cv2.resize(image, (int(width), int(height)), interpolation=cv2.INTER_AREA)
-
-        # image = cv2.resize(image, (int(width), int(height)), interpolation=cv2.INTER_AREA)
-
 
         img = Image.fromarray(image)
         imgtk = ImageTk.PhotoImage(image=img)
@@ -110,7 +111,7 @@ class App(customtkinter.CTk):
         height = self.winfo_height() * 0.5
         self._canvas.create_image(width, height, anchor=customtkinter.CENTER, image=imgtk)
         self._canvas.update()
-        self.after(40, self._render)
+        self.after(20, self._render)
 
     def request_tts(self):
         content = self.entry.get()