# human/human/human_context.py
# -*- coding: utf-8 -*-
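"""Shared runtime context for the digital human.

HumanContext holds the audio/video configuration and the avatar frame data,
and wires the ASR -> NLP -> TTS -> lip-sync inference -> render pipeline.
"""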
import logging
import os
from asr import SherpaNcnnAsr
from eventbus import EventBus
from .audio_inference_handler import AudioInferenceHandler
from .audio_mal_handler import AudioMalHandler
from nlp import PunctuationSplit, DouBao, Kimi
from tts import TTSEdge, TTSAudioSplitHandle, TTSEdgeHttp
from utils import load_avatar, get_device, object_stop, load_avatar_from_processed, load_avatar_from_256_processed
logger = logging.getLogger(__name__)
current_file_path = os.path.dirname(os.path.abspath(__file__))


class HumanContext:
    def __init__(self):
        self._fps = 50  # 20 ms per frame
        self._image_size = 288
        self._batch_size = 6
        self._sample_rate = 16000
        self._stride_left_size = 4
        self._stride_right_size = 4
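        # Derived timing, from the constants above: 16000 Hz / 50 fps
        # = 320 audio samples per video frame, so a batch of 6 frames
        # covers 120 ms of audio, plus the left/right stride frames of
        # surrounding context.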
        self._asr = None
        self._nlp = None
        self._tts = None
        self._tts_handle = None
        self._mal_handler = None
        self._infer_handler = None
        self._render_handler = None
        self._device = get_device()
        logger.info(f'device: {self._device}')
        base_path = os.path.join(current_file_path, '..')
        logger.info(f'base path: {base_path}')
        # Alternative avatar loaders, kept for reference:
        # full_images, face_frames, coord_frames = load_avatar(base_path, self._image_size, self._device)
        # full_images, face_frames, coord_frames = load_avatar_from_processed(base_path, 'wav2lip_avatar3')
        full_images, face_frames, coord_frames, align_frames, m_frames, inv_m_frames = load_avatar_from_256_processed(
            base_path, 'wav2lip_avatar4', '26.pkl')
        self._frame_list_cycle = full_images
        self._face_list_cycle = face_frames
        self._coord_list_cycle = coord_frames
        self._align_frames = align_frames
        self._m_frames = m_frames
        self._inv_m_frames = inv_m_frames
        face_images_length = len(self._face_list_cycle)
        # TODO: get person config
        self.person_config = {
            "frame_config": [[1, face_images_length - 1, True]],
        }
        logger.info(f'face images length: {face_images_length}')

    def __del__(self):
        print('HumanContext: __del__')

    @property
    def fps(self):
        return self._fps

    @property
    def image_size(self):
        return self._image_size

    @property
    def device(self):
        return self._device

    @property
    def batch_size(self):
        return self._batch_size

    @property
    def sample_rate(self):
        return self._sample_rate

    @property
    def stride_left_size(self):
        return self._stride_left_size

    @property
    def stride_right_size(self):
        return self._stride_right_size

    @property
    def face_list_cycle(self):
        return self._face_list_cycle

    @property
    def frame_list_cycle(self):
        return self._frame_list_cycle

    @property
    def coord_list_cycle(self):
        return self._coord_list_cycle

    @property
    def align_frames(self):
        return self._align_frames

    @property
    def inv_m_frames(self):
        return self._inv_m_frames

    @property
    def render_handler(self):
        return self._render_handler

    def notify(self, message):
        if self._tts_handle is not None:
            self._tts_handle.on_message(message)
        else:
            logger.info(f'notify message: {message}')

    def build(self, render_handler):
        # Wire the pipeline back to front: each handler forwards its
        # output to the next one, ending at the render handler.
        self._render_handler = render_handler
        self._infer_handler = AudioInferenceHandler(self, self._render_handler, self.person_config)
        self._mal_handler = AudioMalHandler(self, self._infer_handler)
        self._tts_handle = TTSAudioSplitHandle(self, self._mal_handler)
        self._tts = TTSEdgeHttp(self._tts_handle)
        split = PunctuationSplit()
        self._nlp = DouBao(self, split, self._tts)
        # self._nlp = Kimi(self, split, self._tts)
        self._asr = SherpaNcnnAsr()
        self._asr.attach(self._nlp)
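
    # Data flow once build() has run:
    #   SherpaNcnnAsr -> DouBao (PunctuationSplit) -> TTSEdgeHttp
    #   -> TTSAudioSplitHandle -> AudioMalHandler -> AudioInferenceHandler
    #   -> render_handler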

    def stop(self):
        # Broadcast a global stop event to any registered listeners.
        EventBus().post('stop')

    def pause_talk(self):
        self._nlp.pause_talk()
        self._tts.pause_talk()
        self._mal_handler.pause_talk()
        self._infer_handler.pause_talk()
        self._render_handler.pause_talk()
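

if __name__ == '__main__':
    # Minimal usage sketch (not part of the original module). It assumes the
    # avatar assets referenced in __init__ are present on disk; StubRender is
    # a hypothetical stand-in for the project's real render handler, which
    # may require more methods than shown here.
    class StubRender:
        def pause_talk(self):
            pass

    context = HumanContext()
    context.build(StubRender())
    # ... interact via the microphone/ASR or notify() ...
    context.stop()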