diff --git a/human/__init__.py b/human/__init__.py index 7d606c4..f1319f8 100644 --- a/human/__init__.py +++ b/human/__init__.py @@ -1,3 +1,4 @@ #encoding = utf8 from .human_context import HumanContext +from .audio_handler import AudioHandler diff --git a/human/audio_handler.py b/human/audio_handler.py new file mode 100644 index 0000000..6d5648b --- /dev/null +++ b/human/audio_handler.py @@ -0,0 +1,8 @@ +#encoding = utf8 +from abc import ABC, abstractmethod + + +class AudioHandler(ABC): + @abstractmethod + def on_handle(self, stream, index): + pass diff --git a/human/human_context.py b/human/human_context.py index f5231f5..0db4d10 100644 --- a/human/human_context.py +++ b/human/human_context.py @@ -1,4 +1,7 @@ #encoding = utf8 +from asr import SherpaNcnnAsr +from nlp import PunctuationSplit, DouBao +from tts import TTSEdge, TTSAudioSplitHandle class HumanContext: @@ -29,3 +32,12 @@ class HumanContext: def stride_right_size(self): return self._stride_right_size + def build(self): + tts_handle = TTSAudioSplitHandle(self) + tts = TTSEdge(tts_handle) + split = PunctuationSplit() + nlp = DouBao(split, tts) + asr = SherpaNcnnAsr() + asr.attach(nlp) + + diff --git a/tts/tts_audio_handle.py b/tts/tts_audio_handle.py index 8c1d39d..545705a 100644 --- a/tts/tts_audio_handle.py +++ b/tts/tts_audio_handle.py @@ -1,12 +1,12 @@ #encoding = utf8 import os import shutil -from abc import ABC, abstractmethod from audio import save_wav +from human import AudioHandler -class TTSAudioHandle(ABC): +class TTSAudioHandle(AudioHandler): def __init__(self): self._sample_rate = 16000 self._index = 1 @@ -19,28 +19,24 @@ class TTSAudioHandle(ABC): def sample_rate(self, value): self._sample_rate = value - @abstractmethod - def on_handle(self, stream, index): - pass - def get_index(self): self._index = self._index + 1 return self._index class TTSAudioSplitHandle(TTSAudioHandle): - def __init__(self, human): + def __init__(self, context): super().__init__() - self._human = human - self.sample_rate = self._human.get_audio_sample_rate() - self._chunk = self.sample_rate // self._human.get_fps() + self._context = context + self.sample_rate = self._context.get_audio_sample_rate() + self._chunk = self.sample_rate // self._context.get_fps() def on_handle(self, stream, index): stream_len = stream.shape[0] idx = 0 while stream_len >= self._chunk: - self._human.put_audio_frame(stream[idx:idx + self._chunk]) + self._context.put_audio_frame(stream[idx:idx + self._chunk]) stream_len -= self._chunk idx += self._chunk diff --git a/tts/tts_audio_index.py b/tts/tts_audio_index.py new file mode 100644 index 0000000..b348ff2 --- /dev/null +++ b/tts/tts_audio_index.py @@ -0,0 +1,10 @@ +#encoding = utf8 + + +class TTSAudioIndex: + def __init__(self): + self._index = 0 + + def get_index(self): + self._index = self._index + 1 + return self._index