# -*- coding: utf-8 -*-
|
|
import logging
|
|
import os
|
|
|
|
from asr import SherpaNcnnAsr
|
|
from .audio_inference_handler import AudioInferenceHandler
|
|
from .audio_mal_handler import AudioMalHandler
|
|
from .human_render import HumanRender
|
|
from nlp import PunctuationSplit, DouBao
|
|
from tts import TTSEdge, TTSAudioSplitHandle, TTSEdgeHttp
|
|
from utils import load_avatar, get_device
|
|
|
|
logger = logging.getLogger(__name__)
|
|
current_file_path = os.path.dirname(os.path.abspath(__file__))
|
|
|
|
|
|
class HumanContext:
    """Owns and wires together the talking-head audio/video pipeline.

    Data flow after :meth:`build`::

        SherpaNcnnAsr -> DouBao (NLP) -> TTSEdgeHttp -> TTSAudioSplitHandle
            -> AudioMalHandler -> AudioInferenceHandler -> HumanRender

    The constructor also loads the avatar assets (full frames, face crops and
    face coordinates) consumed by the inference/render stages, and exposes the
    pipeline constants (fps, batch sizes, strides, sample rate) as read-only
    properties.
    """

    def __init__(self):
        # Timing / batching constants for the audio-to-video pipeline.
        self._fps = 50  # 20 ms per frame
        self._image_size = 96  # face-crop edge length fed to the model
        self._batch_size = 16
        self._sample_rate = 16000  # Hz; audio rate shared by ASR/TTS stages
        self._stride_left_size = 10
        self._stride_right_size = 10
        self._render_batch = 5

        self._device = get_device()
        print(f'device:{self._device}')

        # Avatar assets live in <this package>/../face.
        base_path = os.path.join(current_file_path, '..', 'face')
        logger.info(f'_create_recognizer init, path:{base_path}')
        full_images, face_frames, coord_frames = load_avatar(
            base_path, self._image_size, self._device)
        self._frame_list_cycle = full_images
        self._face_list_cycle = face_frames
        self._coord_list_cycle = coord_frames

        face_images_length = len(self._face_list_cycle)
        # Use the module logger; the original called logging.info, which
        # bypassed `logger` and went to the root logger.
        logger.info(f'face images length: {face_images_length}')
        print(f'face images length: {face_images_length}')

        # Pipeline stages; all remain None until build() is called.
        self._asr = None
        self._nlp = None
        self._tts = None
        self._tts_handle = None
        self._mal_handler = None
        self._infer_handler = None
        self._render_handler = None

    def __del__(self):
        print('HumanContext: __del__')
        # Stop from the producer (ASR) downstream. getattr + None guards keep
        # teardown safe when build() was never called, or when __init__ failed
        # before these attributes were assigned.
        for name in ('_asr', '_nlp', '_tts', '_tts_handle',
                     '_mal_handler', '_infer_handler', '_render_handler'):
            stage = getattr(self, name, None)
            if stage is not None:
                stage.stop()

    @property
    def fps(self):
        """Output frame rate (frames per second)."""
        return self._fps

    @property
    def image_size(self):
        """Face-crop edge length in pixels."""
        return self._image_size

    @property
    def render_batch(self):
        """Number of frames handed to the renderer per batch."""
        return self._render_batch

    @property
    def device(self):
        """Compute device selected by ``get_device()``."""
        return self._device

    @property
    def batch_size(self):
        """Inference batch size."""
        return self._batch_size

    @property
    def sample_rate(self):
        """Audio sample rate in Hz."""
        return self._sample_rate

    @property
    def stride_left_size(self):
        """Left context stride for audio feature windows."""
        return self._stride_left_size

    @property
    def stride_right_size(self):
        """Right context stride for audio feature windows."""
        return self._stride_right_size

    @property
    def face_list_cycle(self):
        """Cyclic list of avatar face crops."""
        return self._face_list_cycle

    @property
    def frame_list_cycle(self):
        """Cyclic list of full avatar frames."""
        return self._frame_list_cycle

    @property
    def coord_list_cycle(self):
        """Cyclic list of face coordinates matching ``frame_list_cycle``."""
        return self._coord_list_cycle

    @property
    def render_handler(self):
        """The HumanRender stage, or None before build()."""
        return self._render_handler

    def notify(self, message):
        """Forward *message* to the TTS split handle; just log it if the
        pipeline has not been built yet."""
        if self._tts_handle is not None:
            self._tts_handle.on_message(message)
        else:
            logger.info(f'notify message:{message}')

    def build(self):
        """Instantiate all pipeline stages, wiring each to its downstream
        consumer (built back-to-front so each stage receives its sink)."""
        self._render_handler = HumanRender(self, None)
        self._infer_handler = AudioInferenceHandler(self, self._render_handler)
        self._mal_handler = AudioMalHandler(self, self._infer_handler)
        self._tts_handle = TTSAudioSplitHandle(self, self._mal_handler)
        self._tts = TTSEdgeHttp(self._tts_handle)
        split = PunctuationSplit()
        self._nlp = DouBao(self, split, self._tts)
        self._asr = SherpaNcnnAsr()
        self._asr.attach(self._nlp)

    def pause_talk(self):
        """Pause speech across all built stages; stages that are still None
        (build() not called) are skipped instead of raising AttributeError."""
        for stage in (self._nlp, self._tts, self._mal_handler,
                      self._infer_handler, self._render_handler):
            if stage is not None:
                stage.pause_talk()