modfiy tts source
This commit is contained in:
parent
6b5361d91a
commit
055d1733f3
@ -7,7 +7,7 @@ from .audio_inference_handler import AudioInferenceHandler
|
|||||||
from .audio_mal_handler import AudioMalHandler
|
from .audio_mal_handler import AudioMalHandler
|
||||||
from .human_render import HumanRender
|
from .human_render import HumanRender
|
||||||
from nlp import PunctuationSplit, DouBao
|
from nlp import PunctuationSplit, DouBao
|
||||||
from tts import TTSEdge, TTSAudioSplitHandle
|
from tts import TTSEdge, TTSAudioSplitHandle, TTSEdgeHttp
|
||||||
from utils import load_avatar, get_device
|
from utils import load_avatar, get_device
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -102,10 +102,14 @@ class HumanContext:
|
|||||||
self._infer_handler = AudioInferenceHandler(self, self._render_handler)
|
self._infer_handler = AudioInferenceHandler(self, self._render_handler)
|
||||||
self._mal_handler = AudioMalHandler(self, self._infer_handler)
|
self._mal_handler = AudioMalHandler(self, self._infer_handler)
|
||||||
self._tts_handle = TTSAudioSplitHandle(self, self._mal_handler)
|
self._tts_handle = TTSAudioSplitHandle(self, self._mal_handler)
|
||||||
self._tts = TTSEdge(self._tts_handle)
|
self._tts = TTSEdgeHttp(self._tts_handle)
|
||||||
split = PunctuationSplit()
|
split = PunctuationSplit()
|
||||||
self._nlp = DouBao(split, self._tts)
|
self._nlp = DouBao(split, self._tts)
|
||||||
self._asr = SherpaNcnnAsr()
|
self._asr = SherpaNcnnAsr()
|
||||||
self._asr.attach(self._nlp)
|
self._asr.attach(self._nlp)
|
||||||
|
|
||||||
|
def pause_talk(self):
|
||||||
|
self._nlp.pause_talk()
|
||||||
|
self._tts.pause_talk()
|
||||||
|
self._mal_handler.pause_talk()
|
||||||
|
self._infer_handler.pause_talk()
|
||||||
|
@ -23,3 +23,6 @@ class AudioHandler(ABC):
|
|||||||
self._handler.on_handle(stream, type_)
|
self._handler.on_handle(stream, type_)
|
||||||
else:
|
else:
|
||||||
logging.info(f'_handler is None')
|
logging.info(f'_handler is None')
|
||||||
|
|
||||||
|
def pause_talk(self):
|
||||||
|
pass
|
||||||
|
@ -46,3 +46,6 @@ class NLPBase(AsrObserver):
|
|||||||
self._ask_queue.add_task(self._on_close)
|
self._ask_queue.add_task(self._on_close)
|
||||||
self._ask_queue.stop()
|
self._ask_queue.stop()
|
||||||
|
|
||||||
|
def pause_talk(self):
|
||||||
|
self._ask_queue.clear()
|
||||||
|
|
79
test/test_mzzsfy_tts.py
Normal file
79
test/test_mzzsfy_tts.py
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
#encoding = utf8
|
||||||
|
import time
|
||||||
|
|
||||||
|
import librosa
|
||||||
|
import numpy as np
|
||||||
|
import requests
|
||||||
|
import resampy
|
||||||
|
import soundfile as sf
|
||||||
|
|
||||||
|
|
||||||
|
def download_tts(url):
|
||||||
|
file_name = url[3:]
|
||||||
|
print(file_name)
|
||||||
|
download_url = url
|
||||||
|
print('download tts', download_url)
|
||||||
|
resp = requests.get(download_url)
|
||||||
|
with open('./audio/mp3/' + file_name, 'wb') as mp3:
|
||||||
|
mp3.write(resp.content)
|
||||||
|
|
||||||
|
from pydub import AudioSegment
|
||||||
|
sound = AudioSegment.from_mp3('./audio/mp3/' + file_name)
|
||||||
|
sound.export('./audio/wav/' + file_name + '.wav', format="wav")
|
||||||
|
|
||||||
|
|
||||||
|
def __create_bytes_stream(byte_stream):
|
||||||
|
stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64
|
||||||
|
print(f'[INFO]tts audio stream {sample_rate}: {stream.shape}')
|
||||||
|
stream = stream.astype(np.float32)
|
||||||
|
|
||||||
|
if stream.ndim > 1:
|
||||||
|
print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
|
||||||
|
stream = stream[:, 0]
|
||||||
|
|
||||||
|
if sample_rate != 16000 and stream.shape[0] > 0:
|
||||||
|
print(f'[WARN] audio sample rate is {sample_rate}, resampling into {16000}.')
|
||||||
|
stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=16000)
|
||||||
|
|
||||||
|
return stream
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
import aiohttp
|
||||||
|
import asyncio
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
async def fetch_audio():
|
||||||
|
url = "http://localhost:8082/v1/audio/speech"
|
||||||
|
data = {
|
||||||
|
"model": "tts-1",
|
||||||
|
"input": "写了一个高性能tts(文本转声音)工具,5千字仅需5秒,免费使用",
|
||||||
|
"voice": "alloy",
|
||||||
|
"speed": 1.0
|
||||||
|
}
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(url, json=data) as response:
|
||||||
|
if response.status == 200:
|
||||||
|
audio_data = BytesIO(await response.read())
|
||||||
|
audio_stream = __create_bytes_stream(audio_data)
|
||||||
|
|
||||||
|
# 保存为新的音频文件
|
||||||
|
sf.write("output_audio.wav", audio_stream, 16000)
|
||||||
|
print("Audio data received and saved to output_audio.wav")
|
||||||
|
else:
|
||||||
|
print("Error:", response.status, await response.text())
|
||||||
|
|
||||||
|
# Run the async function
|
||||||
|
asyncio.run(fetch_audio())
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
try:
|
||||||
|
t = time.time()
|
||||||
|
main()
|
||||||
|
print(f'-------tts time:{time.time() - t:.4f}s')
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\nCaught Ctrl + C. Exiting")
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
@ -6,7 +6,7 @@ import time
|
|||||||
from asr import SherpaNcnnAsr
|
from asr import SherpaNcnnAsr
|
||||||
from nlp import PunctuationSplit
|
from nlp import PunctuationSplit
|
||||||
from nlp.nlp_doubao import DouBao
|
from nlp.nlp_doubao import DouBao
|
||||||
from tts import TTSEdge, TTSAudioSaveHandle
|
from tts import TTSEdge, TTSAudioSaveHandle, TTSEdgeHttp
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import sounddevice as sd
|
import sounddevice as sd
|
||||||
@ -21,12 +21,14 @@ except ImportError as e:
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
print("Started! Please speak")
|
print("Started! Please speak")
|
||||||
handle = TTSAudioSaveHandle()
|
handle = TTSAudioSaveHandle(None, None)
|
||||||
tts = TTSEdge(handle)
|
# tts = TTSEdge(handle)
|
||||||
|
tts = TTSEdgeHttp(handle)
|
||||||
split = PunctuationSplit()
|
split = PunctuationSplit()
|
||||||
nlp = DouBao(split, tts)
|
nlp = DouBao(split, tts)
|
||||||
nlp.ask('你好,你是谁?')
|
nlp.ask('你好,你是谁?')
|
||||||
nlp.ask('可以帮我做什么?')
|
nlp.ask('可以帮我做什么?')
|
||||||
|
nlp.ask('背诵出师表')
|
||||||
nlp.stop()
|
nlp.stop()
|
||||||
tts.stop()
|
tts.stop()
|
||||||
print("Stop! ")
|
print("Stop! ")
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
#encoding = utf8
|
#encoding = utf8
|
||||||
|
|
||||||
from .tts_edge import TTSEdge
|
from .tts_edge import TTSEdge
|
||||||
|
from .tts_edge_http import TTSEdgeHttp
|
||||||
from .tts_audio_handle import TTSAudioSplitHandle, TTSAudioSaveHandle
|
from .tts_audio_handle import TTSAudioSplitHandle, TTSAudioSaveHandle
|
||||||
|
@ -34,6 +34,9 @@ class TTSAudioHandle(AudioHandler):
|
|||||||
def stop(self):
|
def stop(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def pause_talk(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class TTSAudioSplitHandle(TTSAudioHandle):
|
class TTSAudioSplitHandle(TTSAudioHandle):
|
||||||
def __init__(self, context, handler):
|
def __init__(self, context, handler):
|
||||||
|
@ -59,3 +59,6 @@ class TTSBase(NLPCallback):
|
|||||||
def stop(self):
|
def stop(self):
|
||||||
self._message_queue.add_task(self._on_close)
|
self._message_queue.add_task(self._on_close)
|
||||||
self._message_queue.stop()
|
self._message_queue.stop()
|
||||||
|
|
||||||
|
def pause_talk(self):
|
||||||
|
self._message_queue.clear()
|
||||||
|
79
tts/tts_edge_http.py
Normal file
79
tts/tts_edge_http.py
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
#encoding = utf8
|
||||||
|
import logging
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import numpy as np
|
||||||
|
import soundfile as sf
|
||||||
|
import edge_tts
|
||||||
|
import resampy
|
||||||
|
|
||||||
|
from .tts_base import TTSBase
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TTSEdgeHttp(TTSBase):
|
||||||
|
def __init__(self, handle, voice='zh-CN-XiaoyiNeural'):
|
||||||
|
super().__init__(handle)
|
||||||
|
self._voice = voice
|
||||||
|
self._url = 'http://localhost:8082/v1/audio/speech'
|
||||||
|
logger.info(f"TTSEdge init, {voice}")
|
||||||
|
|
||||||
|
async def _on_request(self, txt: str):
|
||||||
|
print('_on_request, txt')
|
||||||
|
data = {
|
||||||
|
"model": "tts-1",
|
||||||
|
"input": txt,
|
||||||
|
"voice": "alloy",
|
||||||
|
"speed": 1.0,
|
||||||
|
"thread": 10
|
||||||
|
}
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(self._url, json=data) as response:
|
||||||
|
if response.status == 200:
|
||||||
|
stream = BytesIO(await response.read())
|
||||||
|
|
||||||
|
print("Audio data received and saved to output_audio.wav")
|
||||||
|
return stream
|
||||||
|
else:
|
||||||
|
byte_stream = None
|
||||||
|
return byte_stream
|
||||||
|
|
||||||
|
async def _on_handle(self, stream, index):
|
||||||
|
print('-------tts _on_handle')
|
||||||
|
try:
|
||||||
|
stream.seek(0)
|
||||||
|
byte_stream = self.__create_bytes_stream(stream)
|
||||||
|
print('-------tts start push chunk')
|
||||||
|
self._handle.on_handle(byte_stream, index)
|
||||||
|
stream.seek(0)
|
||||||
|
stream.truncate()
|
||||||
|
print('-------tts finish push chunk')
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self._handle.on_handle(None, index)
|
||||||
|
stream.seek(0)
|
||||||
|
stream.truncate()
|
||||||
|
print('-------tts finish error:', e)
|
||||||
|
stream.close()
|
||||||
|
|
||||||
|
def __create_bytes_stream(self, byte_stream):
|
||||||
|
stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64
|
||||||
|
print(f'[INFO]tts audio stream {sample_rate}: {stream.shape}')
|
||||||
|
stream = stream.astype(np.float32)
|
||||||
|
|
||||||
|
if stream.ndim > 1:
|
||||||
|
print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
|
||||||
|
stream = stream[:, 0]
|
||||||
|
|
||||||
|
if sample_rate != self._handle.sample_rate and stream.shape[0] > 0:
|
||||||
|
print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self._handle.sample_rate}.')
|
||||||
|
stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self._handle.sample_rate)
|
||||||
|
|
||||||
|
return stream
|
||||||
|
|
||||||
|
async def _on_close(self):
|
||||||
|
print('TTSEdge close')
|
||||||
|
# if self._byte_stream is not None and not self._byte_stream.closed:
|
||||||
|
# self._byte_stream.close()
|
@ -53,6 +53,11 @@ class AsyncTaskQueue:
|
|||||||
for _ in range(self._worker_num):
|
for _ in range(self._worker_num):
|
||||||
self.add_task(None) # 发送结束信号
|
self.add_task(None) # 发送结束信号
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
while not self._queue.empty():
|
||||||
|
self._queue.get()
|
||||||
|
self._queue.task_done()
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
self.stop_workers()
|
self.stop_workers()
|
||||||
self._thread.join()
|
self._thread.join()
|
||||||
|
Loading…
Reference in New Issue
Block a user