add audio index and handle

This commit is contained in:
jiegeaiai 2024-10-15 08:31:43 +08:00
parent 205c8f21fe
commit 1bb6684416
5 changed files with 38 additions and 11 deletions

View File

@ -1,3 +1,4 @@
#encoding = utf8 #encoding = utf8
from .human_context import HumanContext from .human_context import HumanContext
from .audio_handler import AudioHandler

8
human/audio_handler.py Normal file
View File

@ -0,0 +1,8 @@
#encoding = utf8
from abc import ABC, abstractmethod
class AudioHandler(ABC):
@abstractmethod
def on_handle(self, stream, index):
pass

View File

@ -1,4 +1,7 @@
#encoding = utf8 #encoding = utf8
from asr import SherpaNcnnAsr
from nlp import PunctuationSplit, DouBao
from tts import TTSEdge, TTSAudioSplitHandle
class HumanContext: class HumanContext:
@ -29,3 +32,12 @@ class HumanContext:
def stride_right_size(self): def stride_right_size(self):
return self._stride_right_size return self._stride_right_size
def build(self):
tts_handle = TTSAudioSplitHandle(self)
tts = TTSEdge(tts_handle)
split = PunctuationSplit()
nlp = DouBao(split, tts)
asr = SherpaNcnnAsr()
asr.attach(nlp)

View File

@ -1,12 +1,12 @@
#encoding = utf8 #encoding = utf8
import os import os
import shutil import shutil
from abc import ABC, abstractmethod
from audio import save_wav from audio import save_wav
from human import AudioHandler
class TTSAudioHandle(ABC): class TTSAudioHandle(AudioHandler):
def __init__(self): def __init__(self):
self._sample_rate = 16000 self._sample_rate = 16000
self._index = 1 self._index = 1
@ -19,28 +19,24 @@ class TTSAudioHandle(ABC):
def sample_rate(self, value): def sample_rate(self, value):
self._sample_rate = value self._sample_rate = value
@abstractmethod
def on_handle(self, stream, index):
pass
def get_index(self): def get_index(self):
self._index = self._index + 1 self._index = self._index + 1
return self._index return self._index
class TTSAudioSplitHandle(TTSAudioHandle): class TTSAudioSplitHandle(TTSAudioHandle):
def __init__(self, human): def __init__(self, context):
super().__init__() super().__init__()
self._human = human self._context = context
self.sample_rate = self._human.get_audio_sample_rate() self.sample_rate = self._context.get_audio_sample_rate()
self._chunk = self.sample_rate // self._human.get_fps() self._chunk = self.sample_rate // self._context.get_fps()
def on_handle(self, stream, index): def on_handle(self, stream, index):
stream_len = stream.shape[0] stream_len = stream.shape[0]
idx = 0 idx = 0
while stream_len >= self._chunk: while stream_len >= self._chunk:
self._human.put_audio_frame(stream[idx:idx + self._chunk]) self._context.put_audio_frame(stream[idx:idx + self._chunk])
stream_len -= self._chunk stream_len -= self._chunk
idx += self._chunk idx += self._chunk

10
tts/tts_audio_index.py Normal file
View File

@ -0,0 +1,10 @@
#encoding = utf8
class TTSAudioIndex:
def __init__(self):
self._index = 0
def get_index(self):
self._index = self._index + 1
return self._index