From 055d1733f32ff2970324a8f24b74888d4647522b Mon Sep 17 00:00:00 2001 From: brige Date: Sat, 19 Oct 2024 18:47:34 +0800 Subject: [PATCH] modfiy tts source --- human/human_context.py | 10 +++-- human_handler/audio_handler.py | 3 ++ nlp/nlp_base.py | 3 ++ test/test_mzzsfy_tts.py | 79 ++++++++++++++++++++++++++++++++++ test/test_nlp_tts.py | 8 ++-- tts/__init__.py | 1 + tts/tts_audio_handle.py | 3 ++ tts/tts_base.py | 3 ++ tts/tts_edge_http.py | 79 ++++++++++++++++++++++++++++++++++ utils/async_task_queue.py | 5 +++ 10 files changed, 188 insertions(+), 6 deletions(-) create mode 100644 test/test_mzzsfy_tts.py create mode 100644 tts/tts_edge_http.py diff --git a/human/human_context.py b/human/human_context.py index 12fcb0e..ff172ca 100644 --- a/human/human_context.py +++ b/human/human_context.py @@ -7,7 +7,7 @@ from .audio_inference_handler import AudioInferenceHandler from .audio_mal_handler import AudioMalHandler from .human_render import HumanRender from nlp import PunctuationSplit, DouBao -from tts import TTSEdge, TTSAudioSplitHandle +from tts import TTSEdge, TTSAudioSplitHandle, TTSEdgeHttp from utils import load_avatar, get_device logger = logging.getLogger(__name__) @@ -102,10 +102,14 @@ class HumanContext: self._infer_handler = AudioInferenceHandler(self, self._render_handler) self._mal_handler = AudioMalHandler(self, self._infer_handler) self._tts_handle = TTSAudioSplitHandle(self, self._mal_handler) - self._tts = TTSEdge(self._tts_handle) + self._tts = TTSEdgeHttp(self._tts_handle) split = PunctuationSplit() self._nlp = DouBao(split, self._tts) self._asr = SherpaNcnnAsr() self._asr.attach(self._nlp) - + def pause_talk(self): + self._nlp.pause_talk() + self._tts.pause_talk() + self._mal_handler.pause_talk() + self._infer_handler.pause_talk() diff --git a/human_handler/audio_handler.py b/human_handler/audio_handler.py index ce05787..89e01d5 100644 --- a/human_handler/audio_handler.py +++ b/human_handler/audio_handler.py @@ -23,3 +23,6 @@ class AudioHandler(ABC): self._handler.on_handle(stream, type_) else: logging.info(f'_handler is None') + + def pause_talk(self): + pass diff --git a/nlp/nlp_base.py b/nlp/nlp_base.py index b4fd8e5..facf678 100644 --- a/nlp/nlp_base.py +++ b/nlp/nlp_base.py @@ -45,4 +45,7 @@ class NLPBase(AsrObserver): def stop(self): self._ask_queue.add_task(self._on_close) self._ask_queue.stop() + + def pause_talk(self): + self._ask_queue.clear() \ No newline at end of file diff --git a/test/test_mzzsfy_tts.py b/test/test_mzzsfy_tts.py new file mode 100644 index 0000000..5211ca2 --- /dev/null +++ b/test/test_mzzsfy_tts.py @@ -0,0 +1,79 @@ +#encoding = utf8 +import time + +import librosa +import numpy as np +import requests +import resampy +import soundfile as sf + + +def download_tts(url): + file_name = url[3:] + print(file_name) + download_url = url + print('download tts', download_url) + resp = requests.get(download_url) + with open('./audio/mp3/' + file_name, 'wb') as mp3: + mp3.write(resp.content) + + from pydub import AudioSegment + sound = AudioSegment.from_mp3('./audio/mp3/' + file_name) + sound.export('./audio/wav/' + file_name + '.wav', format="wav") + + +def __create_bytes_stream(byte_stream): + stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64 + print(f'[INFO]tts audio stream {sample_rate}: {stream.shape}') + stream = stream.astype(np.float32) + + if stream.ndim > 1: + print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.') + stream = stream[:, 0] + + if sample_rate != 16000 and stream.shape[0] > 0: + print(f'[WARN] audio sample rate is {sample_rate}, resampling into {16000}.') + stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=16000) + + return stream + + +def main(): + import aiohttp + import asyncio + from io import BytesIO + + async def fetch_audio(): + url = "http://localhost:8082/v1/audio/speech" + data = { + "model": "tts-1", + "input": "写了一个高性能tts(文本转声音)工具,5千字仅需5秒,免费使用", + "voice": "alloy", + "speed": 1.0 + } + + async with aiohttp.ClientSession() as session: + async with session.post(url, json=data) as response: + if response.status == 200: + audio_data = BytesIO(await response.read()) + audio_stream = __create_bytes_stream(audio_data) + + # 保存为新的音频文件 + sf.write("output_audio.wav", audio_stream, 16000) + print("Audio data received and saved to output_audio.wav") + else: + print("Error:", response.status, await response.text()) + + # Run the async function + asyncio.run(fetch_audio()) + + +if __name__ == "__main__": + try: + t = time.time() + main() + print(f'-------tts time:{time.time() - t:.4f}s') + except KeyboardInterrupt: + print("\nCaught Ctrl + C. Exiting") + except Exception as e: + print(e) diff --git a/test/test_nlp_tts.py b/test/test_nlp_tts.py index 85888a5..f2e02c3 100644 --- a/test/test_nlp_tts.py +++ b/test/test_nlp_tts.py @@ -6,7 +6,7 @@ import time from asr import SherpaNcnnAsr from nlp import PunctuationSplit from nlp.nlp_doubao import DouBao -from tts import TTSEdge, TTSAudioSaveHandle +from tts import TTSEdge, TTSAudioSaveHandle, TTSEdgeHttp try: import sounddevice as sd @@ -21,12 +21,14 @@ except ImportError as e: def main(): print("Started! Please speak") - handle = TTSAudioSaveHandle() - tts = TTSEdge(handle) + handle = TTSAudioSaveHandle(None, None) + # tts = TTSEdge(handle) + tts = TTSEdgeHttp(handle) split = PunctuationSplit() nlp = DouBao(split, tts) nlp.ask('你好,你是谁?') nlp.ask('可以帮我做什么?') + nlp.ask('背诵出师表') nlp.stop() tts.stop() print("Stop! ") diff --git a/tts/__init__.py b/tts/__init__.py index 291563b..48171b8 100644 --- a/tts/__init__.py +++ b/tts/__init__.py @@ -1,4 +1,5 @@ #encoding = utf8 from .tts_edge import TTSEdge +from .tts_edge_http import TTSEdgeHttp from .tts_audio_handle import TTSAudioSplitHandle, TTSAudioSaveHandle diff --git a/tts/tts_audio_handle.py b/tts/tts_audio_handle.py index 79eacf1..badbe7f 100644 --- a/tts/tts_audio_handle.py +++ b/tts/tts_audio_handle.py @@ -34,6 +34,9 @@ class TTSAudioHandle(AudioHandler): def stop(self): pass + def pause_talk(self): + pass + class TTSAudioSplitHandle(TTSAudioHandle): def __init__(self, context, handler): diff --git a/tts/tts_base.py b/tts/tts_base.py index a81a59c..9cebe93 100644 --- a/tts/tts_base.py +++ b/tts/tts_base.py @@ -59,3 +59,6 @@ class TTSBase(NLPCallback): def stop(self): self._message_queue.add_task(self._on_close) self._message_queue.stop() + + def pause_talk(self): + self._message_queue.clear() diff --git a/tts/tts_edge_http.py b/tts/tts_edge_http.py new file mode 100644 index 0000000..46ba42c --- /dev/null +++ b/tts/tts_edge_http.py @@ -0,0 +1,79 @@ +#encoding = utf8 +import logging +from io import BytesIO + +import aiohttp +import numpy as np +import soundfile as sf +import edge_tts +import resampy + +from .tts_base import TTSBase + +logger = logging.getLogger(__name__) + + +class TTSEdgeHttp(TTSBase): + def __init__(self, handle, voice='zh-CN-XiaoyiNeural'): + super().__init__(handle) + self._voice = voice + self._url = 'http://localhost:8082/v1/audio/speech' + logger.info(f"TTSEdge init, {voice}") + + async def _on_request(self, txt: str): + print('_on_request, txt') + data = { + "model": "tts-1", + "input": txt, + "voice": "alloy", + "speed": 1.0, + "thread": 10 + } + async with aiohttp.ClientSession() as session: + async with session.post(self._url, json=data) as response: + if response.status == 200: + stream = BytesIO(await response.read()) + + print("Audio data received and saved to output_audio.wav") + return stream + else: + byte_stream = None + return byte_stream + + async def _on_handle(self, stream, index): + print('-------tts _on_handle') + try: + stream.seek(0) + byte_stream = self.__create_bytes_stream(stream) + print('-------tts start push chunk') + self._handle.on_handle(byte_stream, index) + stream.seek(0) + stream.truncate() + print('-------tts finish push chunk') + + except Exception as e: + self._handle.on_handle(None, index) + stream.seek(0) + stream.truncate() + print('-------tts finish error:', e) + stream.close() + + def __create_bytes_stream(self, byte_stream): + stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64 + print(f'[INFO]tts audio stream {sample_rate}: {stream.shape}') + stream = stream.astype(np.float32) + + if stream.ndim > 1: + print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.') + stream = stream[:, 0] + + if sample_rate != self._handle.sample_rate and stream.shape[0] > 0: + print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self._handle.sample_rate}.') + stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self._handle.sample_rate) + + return stream + + async def _on_close(self): + print('TTSEdge close') + # if self._byte_stream is not None and not self._byte_stream.closed: + # self._byte_stream.close() diff --git a/utils/async_task_queue.py b/utils/async_task_queue.py index e2b7b4c..716b9cd 100644 --- a/utils/async_task_queue.py +++ b/utils/async_task_queue.py @@ -53,6 +53,11 @@ class AsyncTaskQueue: for _ in range(self._worker_num): self.add_task(None) # 发送结束信号 + def clear(self): + while not self._queue.empty(): + self._queue.get() + self._queue.task_done() + def stop(self): self.stop_workers() self._thread.join()