diff --git a/test/asr_nlp_tts.py b/test/asr_nlp_tts.py index c64d0d4..528ee03 100644 --- a/test/asr_nlp_tts.py +++ b/test/asr_nlp_tts.py @@ -6,7 +6,7 @@ import time from asr import SherpaNcnnAsr from nlp import PunctuationSplit from nlp.nlp_doubao import DouBao -from tts import TTSEdge +from tts import TTSEdge, TTSAudioSaveHandle try: import sounddevice as sd @@ -21,7 +21,8 @@ except ImportError as e: def main(): print("Started! Please speak") - tts = TTSEdge() + handle = TTSAudioSaveHandle() + tts = TTSEdge(handle) split = PunctuationSplit() nlp = DouBao(split, tts) asr = SherpaNcnnAsr() diff --git a/tts/TTSBase.py b/tts/TTSBase.py index a5b2dfb..3c9a37b 100644 --- a/tts/TTSBase.py +++ b/tts/TTSBase.py @@ -74,7 +74,8 @@ class TTSBase: stream = stream[:, 0] if sample_rate != self._human.get_audio_sample_rate() and stream.shape[0] > 0: - print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self._human.get_audio_sample_rate()}.') + print(f'[WARN] audio sample rate is {sample_rate},' + f'resampling into {self._human.get_audio_sample_rate()}.') stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self._human.get_audio_sample_rate()) return stream diff --git a/tts/__init__.py b/tts/__init__.py index f0fa76e..291563b 100644 --- a/tts/__init__.py +++ b/tts/__init__.py @@ -1,3 +1,4 @@ #encoding = utf8 from .tts_edge import TTSEdge +from .tts_audio_handle import TTSAudioSplitHandle, TTSAudioSaveHandle diff --git a/tts/tts_audio_handle.py b/tts/tts_audio_handle.py new file mode 100644 index 0000000..b04016a --- /dev/null +++ b/tts/tts_audio_handle.py @@ -0,0 +1,72 @@ +#encoding = utf8 +import os +import shutil +from abc import ABC, abstractmethod + +from audio import save_wav + + +class TTSAudioHandle(ABC): + def __init__(self): + self._sample_rate = 16000 + + @property + def sample_rate(self): + return self._sample_rate + + @sample_rate.setter + def sample_rate(self, value): + self._sample_rate = value + + @abstractmethod + def on_handle(self, stream): + pass + + +class TTSAudioSplitHandle(TTSAudioHandle): + def __init__(self, human): + super().__init__() + self._human = human + self.sample_rate = self._human.get_audio_sample_rate() + self._chunk = self.sample_rate // self._human.get_fps() + + def on_handle(self, stream): + stream_len = stream.shape[0] + idx = 0 + + while stream_len >= self._chunk: + self._human.put_audio_frame(stream[idx:idx + self._chunk]) + stream_len -= self._chunk + idx += self._chunk + + +class TTSAudioSaveHandle(TTSAudioHandle): + def __init__(self): + super().__init__() + self._count = 1 + self._save_path_dir = '../temp/audio/' + self._clean() + + def _clean(self): + directory = self._save_path_dir + if not os.path.exists(directory): + print(f"The directory {directory} does not exist.") + return + + for filename in os.listdir(directory): + file_path = os.path.join(directory, filename) + + # 如果是文件,删除 + if os.path.isfile(file_path): + os.remove(file_path) + print(f"Deleted file: {file_path}") + # 如果是文件夹,递归删除所有文件夹中的内容 + elif os.path.isdir(file_path): + shutil.rmtree(file_path) + print(f"Deleted directory and its contents: {file_path}") + + def on_handle(self, stream): + file_name = self._save_path_dir + str(self._count) + '.wav' + save_wav(stream, file_name, self.sample_rate) + self._count = self._count + 1 + diff --git a/tts/tts_base.py b/tts/tts_base.py index cd0bc37..fd1d3e8 100644 --- a/tts/tts_base.py +++ b/tts/tts_base.py @@ -10,11 +10,19 @@ logger = logging.getLogger(__name__) class TTSBase(NLPCallback): - def __init__(self): - self._sample_rate = 16000 + def __init__(self, handle): + self._handle = handle self._message_queue = AsyncTaskQueue() self._message_queue.start_worker() + @property + def handle(self): + return self._handle + + @handle.setter + def handle(self, value): + self._handle = value + async def _request(self, txt: str): print('_request:', txt) t = time.time() diff --git a/tts/tts_edge.py b/tts/tts_edge.py index 9914e36..3778b31 100644 --- a/tts/tts_edge.py +++ b/tts/tts_edge.py @@ -7,16 +7,14 @@ import soundfile as sf import edge_tts import resampy -from audio import save_chunks, save_wav from .tts_base import TTSBase class TTSEdge(TTSBase): - def __init__(self, voice='zh-CN-XiaoyiNeural'): - super().__init__() + def __init__(self, handle, voice='zh-CN-XiaoyiNeural'): + super().__init__(handle) self._voice = voice self._byte_stream = BytesIO() - self._count = 1 async def _on_request(self, txt: str): communicate = edge_tts.Communicate(txt, self._voice) @@ -25,32 +23,21 @@ class TTSEdge(TTSBase): if first: first = False if chunk["type"] == "audio": - # self.push_audio(chunk["data"]) self._byte_stream.write(chunk["data"]) - # file.write(chunk["data"]) elif chunk["type"] == "WordBoundary": pass async def _on_handle(self): self._byte_stream.seek(0) try: + self._byte_stream.seek(0) stream = self.__create_bytes_stream(self._byte_stream) - stream_len = stream.shape[0] - idx = 0 print('-------tts start push chunk') - save_wav(stream, '../temp/audio/' + str(self._count) + '.wav', 16000) - self._count = self._count + 1 - # chunk = stream[0:] - # save_chunks(chunk, 16000, './temp/audio') - # while stream_len >= self.chunk: - # self._human.put_audio_frame(stream[idx:idx + self.chunk]) - # streamlen -= self.chunk - # idx += self.chunk - # if streamlen>0: #skip last frame(not 20ms) - # self.queue.put(stream[idx:]) + self._handle.on_handle(stream) self._byte_stream.seek(0) self._byte_stream.truncate() print('-------tts finish push chunk') + except Exception as e: self._byte_stream.seek(0) self._byte_stream.truncate() @@ -65,8 +52,8 @@ class TTSEdge(TTSBase): print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.') stream = stream[:, 0] - if sample_rate != self._sample_rate and stream.shape[0] > 0: - print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self._sample_rate}.') - stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self._sample_rate) + if sample_rate != self._handle.sample_rate and stream.shape[0] > 0: + print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self._handle.sample_rate}.') + stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self._handle.sample_rate) return stream