modify tts source
This commit is contained in:
parent
6b5361d91a
commit
055d1733f3
@ -7,7 +7,7 @@ from .audio_inference_handler import AudioInferenceHandler
|
||||
from .audio_mal_handler import AudioMalHandler
|
||||
from .human_render import HumanRender
|
||||
from nlp import PunctuationSplit, DouBao
|
||||
from tts import TTSEdge, TTSAudioSplitHandle
|
||||
from tts import TTSEdge, TTSAudioSplitHandle, TTSEdgeHttp
|
||||
from utils import load_avatar, get_device
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -102,10 +102,14 @@ class HumanContext:
|
||||
self._infer_handler = AudioInferenceHandler(self, self._render_handler)
|
||||
self._mal_handler = AudioMalHandler(self, self._infer_handler)
|
||||
self._tts_handle = TTSAudioSplitHandle(self, self._mal_handler)
|
||||
self._tts = TTSEdge(self._tts_handle)
|
||||
self._tts = TTSEdgeHttp(self._tts_handle)
|
||||
split = PunctuationSplit()
|
||||
self._nlp = DouBao(split, self._tts)
|
||||
self._asr = SherpaNcnnAsr()
|
||||
self._asr.attach(self._nlp)
|
||||
|
||||
|
||||
def pause_talk(self):
    """Propagate a pause request through the pipeline.

    Calls ``pause_talk()`` on ``self._nlp``, ``self._tts``,
    ``self._mal_handler`` and ``self._infer_handler``, in that order.
    """
    for stage_attr in ('_nlp', '_tts', '_mal_handler', '_infer_handler'):
        # Attribute is looked up right before the call, one stage at a time.
        getattr(self, stage_attr).pause_talk()
|
||||
|
@ -23,3 +23,6 @@ class AudioHandler(ABC):
|
||||
self._handler.on_handle(stream, type_)
|
||||
else:
|
||||
logging.info(f'_handler is None')
|
||||
|
||||
def pause_talk(self):
    """Default implementation: performs no work and returns ``None``."""
    return None
|
||||
|
@ -46,3 +46,6 @@ class NLPBase(AsrObserver):
|
||||
self._ask_queue.add_task(self._on_close)
|
||||
self._ask_queue.stop()
|
||||
|
||||
def pause_talk(self):
    """Invoke ``clear()`` on ``self._ask_queue``."""
    ask_queue = self._ask_queue
    ask_queue.clear()
|
||||
|
79
test/test_mzzsfy_tts.py
Normal file
79
test/test_mzzsfy_tts.py
Normal file
@ -0,0 +1,79 @@
|
||||
#encoding = utf8
|
||||
import time
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
import requests
|
||||
import resampy
|
||||
import soundfile as sf
|
||||
|
||||
|
||||
def download_tts(url):
    """Download an MP3 from *url* and convert it to a WAV file.

    The response body is written to ``./audio/mp3/<name>`` and the converted
    audio to ``./audio/wav/<name>.wav``, where ``<name>`` is *url* with its
    first three characters removed.

    Args:
        url: address of the MP3 resource to fetch.

    Raises:
        requests.HTTPError: the server answered with an error status.
        requests.Timeout: the server did not answer within the timeout.
    """
    # NOTE(review): url[3:] presumably strips a leading "../" from a relative
    # path; for an absolute "http://..." URL it yields a name containing "/".
    # Confirm the expected URL shape against the caller.
    file_name = url[3:]
    print(file_name)
    download_url = url
    print('download tts', download_url)

    # Without a timeout a dead server blocks forever; without the status
    # check an HTML error page would be saved and later fed to the decoder.
    resp = requests.get(download_url, timeout=30)
    resp.raise_for_status()

    mp3_path = './audio/mp3/' + file_name
    with open(mp3_path, 'wb') as mp3:
        mp3.write(resp.content)

    from pydub import AudioSegment
    sound = AudioSegment.from_mp3(mp3_path)
    sound.export('./audio/wav/' + file_name + '.wav', format="wav")
|
||||
|
||||
|
||||
def __create_bytes_stream(byte_stream):
    """Decode an audio source into a single-channel float32 array at 16000 Hz.

    Args:
        byte_stream: input handed directly to ``sf.read``.

    Returns:
        The decoded samples cast to ``np.float32``. Multi-channel input is
        reduced to its first channel, and non-empty audio whose rate is not
        16000 is resampled with ``resampy``.
    """
    target_rate = 16000

    samples, source_rate = sf.read(byte_stream)  # float64 as returned by sf.read
    print(f'[INFO]tts audio stream {source_rate}: {samples.shape}')
    samples = samples.astype(np.float32)

    if samples.ndim > 1:
        print(f'[WARN] audio has {samples.shape[1]} channels, only use the first.')
        samples = samples[:, 0]

    # Nothing to resample when the rate already matches or there is no audio.
    if source_rate == target_rate or samples.shape[0] <= 0:
        return samples

    print(f'[WARN] audio sample rate is {source_rate}, resampling into {target_rate}.')
    return resampy.resample(x=samples, sr_orig=source_rate, sr_new=target_rate)
|
||||
|
||||
|
||||
def main():
    """Request speech for a fixed sentence from the local endpoint and save it.

    POSTs a JSON payload to ``http://localhost:8082/v1/audio/speech``. On a
    200 reply the body is decoded via ``__create_bytes_stream`` and written
    to ``output_audio.wav`` at 16000 Hz; otherwise the status and response
    text are printed.
    """
    import aiohttp
    import asyncio
    from io import BytesIO

    async def fetch_audio():
        endpoint = "http://localhost:8082/v1/audio/speech"
        payload = {
            "model": "tts-1",
            "input": "写了一个高性能tts(文本转声音)工具,5千字仅需5秒,免费使用",
            "voice": "alloy",
            "speed": 1.0
        }

        async with aiohttp.ClientSession() as session:
            async with session.post(endpoint, json=payload) as response:
                # Report failures first so the success path stays flat.
                if response.status != 200:
                    print("Error:", response.status, await response.text())
                    return

                raw_audio = BytesIO(await response.read())
                samples = __create_bytes_stream(raw_audio)

                # Save as a new audio file
                sf.write("output_audio.wav", samples, 16000)
                print("Audio data received and saved to output_audio.wav")

    # Run the async function
    asyncio.run(fetch_audio())
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Time a single run of main(); Ctrl+C and any other error are reported
    # on stdout instead of producing a traceback.
    started_at = time.time()
    try:
        main()
        print(f'-------tts time:{time.time() - started_at:.4f}s')
    except KeyboardInterrupt:
        print("\nCaught Ctrl + C. Exiting")
    except Exception as e:
        print(e)
|
@ -6,7 +6,7 @@ import time
|
||||
from asr import SherpaNcnnAsr
|
||||
from nlp import PunctuationSplit
|
||||
from nlp.nlp_doubao import DouBao
|
||||
from tts import TTSEdge, TTSAudioSaveHandle
|
||||
from tts import TTSEdge, TTSAudioSaveHandle, TTSEdgeHttp
|
||||
|
||||
try:
|
||||
import sounddevice as sd
|
||||
@ -21,12 +21,14 @@ except ImportError as e:
|
||||
|
||||
def main():
|
||||
print("Started! Please speak")
|
||||
handle = TTSAudioSaveHandle()
|
||||
tts = TTSEdge(handle)
|
||||
handle = TTSAudioSaveHandle(None, None)
|
||||
# tts = TTSEdge(handle)
|
||||
tts = TTSEdgeHttp(handle)
|
||||
split = PunctuationSplit()
|
||||
nlp = DouBao(split, tts)
|
||||
nlp.ask('你好,你是谁?')
|
||||
nlp.ask('可以帮我做什么?')
|
||||
nlp.ask('背诵出师表')
|
||||
nlp.stop()
|
||||
tts.stop()
|
||||
print("Stop! ")
|
||||
|
@ -1,4 +1,5 @@
|
||||
#encoding = utf8
|
||||
|
||||
from .tts_edge import TTSEdge
|
||||
from .tts_edge_http import TTSEdgeHttp
|
||||
from .tts_audio_handle import TTSAudioSplitHandle, TTSAudioSaveHandle
|
||||
|
@ -34,6 +34,9 @@ class TTSAudioHandle(AudioHandler):
|
||||
def stop(self):
    """Default implementation: performs no work and returns ``None``."""
    return None
|
||||
|
||||
def pause_talk(self):
    """Default implementation: performs no work and returns ``None``."""
    return None
|
||||
|
||||
|
||||
class TTSAudioSplitHandle(TTSAudioHandle):
|
||||
def __init__(self, context, handler):
|
||||
|
@ -59,3 +59,6 @@ class TTSBase(NLPCallback):
|
||||
def stop(self):
    """Enqueue ``self._on_close`` on the message queue, then stop the queue."""
    message_queue = self._message_queue
    # The close callback is added before stop() is called on the queue.
    message_queue.add_task(self._on_close)
    message_queue.stop()
|
||||
|
||||
def pause_talk(self):
    """Invoke ``clear()`` on ``self._message_queue``."""
    message_queue = self._message_queue
    message_queue.clear()
|
||||
|
79
tts/tts_edge_http.py
Normal file
79
tts/tts_edge_http.py
Normal file
@ -0,0 +1,79 @@
|
||||
#encoding = utf8
|
||||
import logging
|
||||
from io import BytesIO
|
||||
|
||||
import aiohttp
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
import edge_tts
|
||||
import resampy
|
||||
|
||||
from .tts_base import TTSBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TTSEdgeHttp(TTSBase):
    """TTS backend that obtains synthesized speech from an HTTP endpoint.

    Text is POSTed as JSON to ``self._url``. The response body is decoded
    with ``soundfile``, reduced to its first channel, resampled to
    ``self._handle.sample_rate`` when needed, and handed to
    ``self._handle.on_handle``.
    """

    def __init__(self, handle, voice='zh-CN-XiaoyiNeural'):
        """Store the voice and endpoint URL.

        Args:
            handle: audio handler forwarded to ``TTSBase.__init__``.
            voice: voice name, stored as ``self._voice``.
        """
        super().__init__(handle)
        # NOTE(review): _voice is stored but the request payload below always
        # sends "alloy" — confirm whether the endpoint should receive it.
        self._voice = voice
        self._url = 'http://localhost:8082/v1/audio/speech'
        logger.info(f"TTSEdge init, {voice}")

    async def _on_request(self, txt: str):
        """POST *txt* to the speech endpoint.

        Args:
            txt: text to synthesize.

        Returns:
            A ``BytesIO`` holding the response body on HTTP 200, otherwise
            ``None``.
        """
        print('_on_request, txt')
        data = {
            "model": "tts-1",
            "input": txt,
            "voice": "alloy",
            "speed": 1.0,
            "thread": 10
        }
        async with aiohttp.ClientSession() as session:
            async with session.post(self._url, json=data) as response:
                if response.status != 200:
                    # Leave a trace of the failure instead of dropping it.
                    logger.error('tts request failed, status: %s', response.status)
                    return None

                stream = BytesIO(await response.read())
                # Nothing is written to disk here; the data stays in memory.
                print("Audio data received")
                return stream

    async def _on_handle(self, stream, index):
        """Decode *stream* and pass the samples to the audio handler.

        Args:
            stream: ``BytesIO`` produced by ``_on_request``, or ``None`` when
                the request failed.
            index: value forwarded unchanged to ``self._handle.on_handle``.

        On any decoding or delivery error the handler is called with ``None``
        in place of the samples. The stream is always closed.
        """
        print('-------tts _on_handle')
        if stream is None:
            # _on_request yields None on a non-200 reply. Report it the same
            # way as a decode error rather than failing on None.seek().
            self._handle.on_handle(None, index)
            return

        try:
            stream.seek(0)
            byte_stream = self.__create_bytes_stream(stream)
            print('-------tts start push chunk')
            self._handle.on_handle(byte_stream, index)
            print('-------tts finish push chunk')
        except Exception as e:
            self._handle.on_handle(None, index)
            print('-------tts finish error:', e)
        finally:
            # close() releases the buffer, on success and on error alike.
            stream.close()

    def __create_bytes_stream(self, byte_stream):
        """Decode an audio source into a single-channel float32 array.

        Args:
            byte_stream: input handed directly to ``sf.read``.

        Returns:
            The decoded samples cast to ``np.float32``. Multi-channel input
            is reduced to its first channel, and non-empty audio is resampled
            to ``self._handle.sample_rate`` when the rates differ.
        """
        stream, sample_rate = sf.read(byte_stream)  # [T*sample_rate,] float64
        print(f'[INFO]tts audio stream {sample_rate}: {stream.shape}')
        stream = stream.astype(np.float32)

        if stream.ndim > 1:
            print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
            stream = stream[:, 0]

        if sample_rate != self._handle.sample_rate and stream.shape[0] > 0:
            print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self._handle.sample_rate}.')
            stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self._handle.sample_rate)

        return stream

    async def _on_close(self):
        """Print a close notice; this class holds nothing that needs releasing."""
        print('TTSEdge close')
|
@ -53,6 +53,11 @@ class AsyncTaskQueue:
|
||||
for _ in range(self._worker_num):
|
||||
self.add_task(None) # 发送结束信号
|
||||
|
||||
def clear(self):
    """Drop every item currently waiting in ``self._queue``.

    Items are removed with the non-blocking ``get_nowait()`` and each removal
    is acknowledged with ``task_done()``. The loop ends when the queue reports
    that it is empty, so a consumer taking the last item concurrently cannot
    leave this call blocked.
    """
    # Local imports: only the two "empty" exception types are needed, and
    # catching both supports queue.Queue as well as asyncio.Queue.
    import asyncio
    import queue

    while True:
        try:
            self._queue.get_nowait()
        except (queue.Empty, asyncio.QueueEmpty):
            break
        self._queue.task_done()
|
||||
|
||||
def stop(self):
    """Call ``stop_workers()`` and then join ``self._thread``."""
    self.stop_workers()
    worker_thread = self._thread
    # Returns only after the thread has finished.
    worker_thread.join()
|
||||
|
Loading…
Reference in New Issue
Block a user