modfiy tts source

This commit is contained in:
brige 2024-10-19 18:47:34 +08:00
parent 6b5361d91a
commit 055d1733f3
10 changed files with 188 additions and 6 deletions

View File

@ -7,7 +7,7 @@ from .audio_inference_handler import AudioInferenceHandler
from .audio_mal_handler import AudioMalHandler
from .human_render import HumanRender
from nlp import PunctuationSplit, DouBao
from tts import TTSEdge, TTSAudioSplitHandle
from tts import TTSEdge, TTSAudioSplitHandle, TTSEdgeHttp
from utils import load_avatar, get_device
logger = logging.getLogger(__name__)
@ -102,10 +102,14 @@ class HumanContext:
self._infer_handler = AudioInferenceHandler(self, self._render_handler)
self._mal_handler = AudioMalHandler(self, self._infer_handler)
self._tts_handle = TTSAudioSplitHandle(self, self._mal_handler)
self._tts = TTSEdge(self._tts_handle)
self._tts = TTSEdgeHttp(self._tts_handle)
split = PunctuationSplit()
self._nlp = DouBao(split, self._tts)
self._asr = SherpaNcnnAsr()
self._asr.attach(self._nlp)
def pause_talk(self):
self._nlp.pause_talk()
self._tts.pause_talk()
self._mal_handler.pause_talk()
self._infer_handler.pause_talk()

View File

@ -23,3 +23,6 @@ class AudioHandler(ABC):
self._handler.on_handle(stream, type_)
else:
logging.info(f'_handler is None')
def pause_talk(self):
pass

View File

@ -46,3 +46,6 @@ class NLPBase(AsrObserver):
self._ask_queue.add_task(self._on_close)
self._ask_queue.stop()
def pause_talk(self):
self._ask_queue.clear()

79
test/test_mzzsfy_tts.py Normal file
View File

@ -0,0 +1,79 @@
#encoding = utf8
import time
import librosa
import numpy as np
import requests
import resampy
import soundfile as sf
def download_tts(url):
file_name = url[3:]
print(file_name)
download_url = url
print('download tts', download_url)
resp = requests.get(download_url)
with open('./audio/mp3/' + file_name, 'wb') as mp3:
mp3.write(resp.content)
from pydub import AudioSegment
sound = AudioSegment.from_mp3('./audio/mp3/' + file_name)
sound.export('./audio/wav/' + file_name + '.wav', format="wav")
def __create_bytes_stream(byte_stream):
stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64
print(f'[INFO]tts audio stream {sample_rate}: {stream.shape}')
stream = stream.astype(np.float32)
if stream.ndim > 1:
print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
stream = stream[:, 0]
if sample_rate != 16000 and stream.shape[0] > 0:
print(f'[WARN] audio sample rate is {sample_rate}, resampling into {16000}.')
stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=16000)
return stream
def main():
import aiohttp
import asyncio
from io import BytesIO
async def fetch_audio():
url = "http://localhost:8082/v1/audio/speech"
data = {
"model": "tts-1",
"input": "写了一个高性能tts(文本转声音)工具,5千字仅需5秒,免费使用",
"voice": "alloy",
"speed": 1.0
}
async with aiohttp.ClientSession() as session:
async with session.post(url, json=data) as response:
if response.status == 200:
audio_data = BytesIO(await response.read())
audio_stream = __create_bytes_stream(audio_data)
# 保存为新的音频文件
sf.write("output_audio.wav", audio_stream, 16000)
print("Audio data received and saved to output_audio.wav")
else:
print("Error:", response.status, await response.text())
# Run the async function
asyncio.run(fetch_audio())
if __name__ == "__main__":
try:
t = time.time()
main()
print(f'-------tts time:{time.time() - t:.4f}s')
except KeyboardInterrupt:
print("\nCaught Ctrl + C. Exiting")
except Exception as e:
print(e)

View File

@ -6,7 +6,7 @@ import time
from asr import SherpaNcnnAsr
from nlp import PunctuationSplit
from nlp.nlp_doubao import DouBao
from tts import TTSEdge, TTSAudioSaveHandle
from tts import TTSEdge, TTSAudioSaveHandle, TTSEdgeHttp
try:
import sounddevice as sd
@ -21,12 +21,14 @@ except ImportError as e:
def main():
print("Started! Please speak")
handle = TTSAudioSaveHandle()
tts = TTSEdge(handle)
handle = TTSAudioSaveHandle(None, None)
# tts = TTSEdge(handle)
tts = TTSEdgeHttp(handle)
split = PunctuationSplit()
nlp = DouBao(split, tts)
nlp.ask('你好,你是谁?')
nlp.ask('可以帮我做什么?')
nlp.ask('背诵出师表')
nlp.stop()
tts.stop()
print("Stop! ")

View File

@ -1,4 +1,5 @@
#encoding = utf8
from .tts_edge import TTSEdge
from .tts_edge_http import TTSEdgeHttp
from .tts_audio_handle import TTSAudioSplitHandle, TTSAudioSaveHandle

View File

@ -34,6 +34,9 @@ class TTSAudioHandle(AudioHandler):
def stop(self):
pass
def pause_talk(self):
pass
class TTSAudioSplitHandle(TTSAudioHandle):
def __init__(self, context, handler):

View File

@ -59,3 +59,6 @@ class TTSBase(NLPCallback):
def stop(self):
self._message_queue.add_task(self._on_close)
self._message_queue.stop()
def pause_talk(self):
self._message_queue.clear()

79
tts/tts_edge_http.py Normal file
View File

@ -0,0 +1,79 @@
#encoding = utf8
import logging
from io import BytesIO
import aiohttp
import numpy as np
import soundfile as sf
import edge_tts
import resampy
from .tts_base import TTSBase
logger = logging.getLogger(__name__)
class TTSEdgeHttp(TTSBase):
def __init__(self, handle, voice='zh-CN-XiaoyiNeural'):
super().__init__(handle)
self._voice = voice
self._url = 'http://localhost:8082/v1/audio/speech'
logger.info(f"TTSEdge init, {voice}")
async def _on_request(self, txt: str):
print('_on_request, txt')
data = {
"model": "tts-1",
"input": txt,
"voice": "alloy",
"speed": 1.0,
"thread": 10
}
async with aiohttp.ClientSession() as session:
async with session.post(self._url, json=data) as response:
if response.status == 200:
stream = BytesIO(await response.read())
print("Audio data received and saved to output_audio.wav")
return stream
else:
byte_stream = None
return byte_stream
async def _on_handle(self, stream, index):
print('-------tts _on_handle')
try:
stream.seek(0)
byte_stream = self.__create_bytes_stream(stream)
print('-------tts start push chunk')
self._handle.on_handle(byte_stream, index)
stream.seek(0)
stream.truncate()
print('-------tts finish push chunk')
except Exception as e:
self._handle.on_handle(None, index)
stream.seek(0)
stream.truncate()
print('-------tts finish error:', e)
stream.close()
def __create_bytes_stream(self, byte_stream):
stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64
print(f'[INFO]tts audio stream {sample_rate}: {stream.shape}')
stream = stream.astype(np.float32)
if stream.ndim > 1:
print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
stream = stream[:, 0]
if sample_rate != self._handle.sample_rate and stream.shape[0] > 0:
print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self._handle.sample_rate}.')
stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self._handle.sample_rate)
return stream
async def _on_close(self):
print('TTSEdge close')
# if self._byte_stream is not None and not self._byte_stream.closed:
# self._byte_stream.close()

View File

@ -53,6 +53,11 @@ class AsyncTaskQueue:
for _ in range(self._worker_num):
self.add_task(None) # 发送结束信号
def clear(self):
while not self._queue.empty():
self._queue.get()
self._queue.task_done()
def stop(self):
self.stop_workers()
self._thread.join()