human/tts/tts_edge_http.py
2024-11-16 10:13:11 +08:00

105 lines
3.4 KiB
Python

# encoding: utf-8
import logging
import time
from io import BytesIO
import aiohttp
import numpy as np
import requests
import soundfile as sf
import edge_tts
import resampy
from .tts_base import TTSBase
logger = logging.getLogger(__name__)
class TTSEdgeHttp(TTSBase):
    """TTS backend that fetches synthesized speech from an HTTP endpoint.

    Text is posted as JSON to ``self._url``; the returned audio is decoded,
    reduced to a single channel, resampled to ``self._handle.sample_rate``
    and forwarded to ``self._handle.on_handle``.
    """

    # (connect, read) timeout in seconds for the blocking request, so that a
    # stalled server cannot block the calling thread forever.
    _REQUEST_TIMEOUT = (5, 60)

    def __init__(self, handle, voice='zh-CN-XiaoyiNeural'):
        """Create the backend.

        Args:
            handle: Passed to ``TTSBase.__init__``; this class later uses it
                as ``self._handle`` (presumably stored there by the base
                class) for ``sample_rate`` and ``on_handle``.
            voice: Voice name, stored in ``self._voice`` and logged.
        """
        super().__init__(handle)
        self._voice = voice
        # Local development endpoint:
        # self._url = 'http://localhost:8082/v1/audio/speech'
        self._url = 'https://tts.mzzsfy.eu.org/v1/audio/speech'
        logger.info(f"TTSEdge init, {voice}")
        # Responses of synchronous requests that are still being processed;
        # on_clear_cache() closes whatever is registered here.
        self._response_list = []

    async def _on_async_request(self, data):
        """Post ``data`` with aiohttp and return the audio bytes.

        Not called at the moment (see ``_on_request``). aiohttp sessions and
        responses are asynchronous context managers and ``response.read()``
        is a coroutine, so this method has to be awaited.

        Args:
            data: JSON-serialisable request payload.

        Returns:
            A ``BytesIO`` holding the response body, or ``None`` when the
            server did not answer with HTTP 200.
        """
        async with aiohttp.ClientSession() as session:
            async with session.post(self._url, json=data) as response:
                print('TTSEdgeHttp, _on_request, response:', response)
                if response.status != 200:
                    return None
                return BytesIO(await response.read())

    def _on_sync_request(self, data):
        """Post ``data`` with requests and return the audio bytes.

        Args:
            data: JSON-serialisable request payload.

        Returns:
            A ``BytesIO`` holding the response body, or ``None`` when the
            request failed or the server did not answer with HTTP 200.
        """
        try:
            response = requests.post(self._url, json=data,
                                     timeout=self._REQUEST_TIMEOUT)
        except requests.RequestException as e:
            # Report transport failures the same way as a non-200 answer.
            logger.error(f'TTSEdgeHttp, request error:{e}')
            return None

        self._response_list.append(response)
        try:
            if response.status_code != 200:
                logger.warning(
                    f'TTSEdgeHttp, request failed, status:{response.status_code}')
                return None
            return BytesIO(response.content)
        finally:
            # Always unregister, even if building the stream raised.
            try:
                self._response_list.remove(response)
            except ValueError:
                # Already removed elsewhere; nothing left to do.
                pass

    def _on_request(self, txt: str):
        """Request synthesized speech for ``txt``.

        Args:
            txt: Text to synthesize.

        Returns:
            Whatever ``_on_sync_request`` returns: a ``BytesIO`` with the
            audio, or ``None`` on failure.
        """
        logger.info(f'TTSEdgeHttp, _on_request, txt:{txt}')
        data = {
            "model": "tts-1",
            "input": txt,
            # NOTE(review): self._voice is stored by __init__ but the request
            # always sends "alloy" - confirm whether that is intended.
            "voice": "alloy",
            "speed": 1.0,
            "thread": 10
        }
        # return self._on_async_request(data)
        return self._on_sync_request(data)

    def _on_handle(self, stream, index):
        """Decode a fetched audio stream and pass it on to the handler.

        On success ``self._handle.on_handle((samples, txt), index)`` is
        called; on any failure ``self._handle.on_handle(None, index)`` is
        called instead. The stream is emptied and closed in both cases.

        Args:
            stream: ``(st, txt)`` pair - the audio stream as produced by
                ``_on_request`` and the text it was synthesized from.
            index: Forwarded unchanged to ``self._handle.on_handle``.
        """
        st, txt = stream
        if st is None:
            # A failed request yields no stream; report the failure without
            # touching the missing object.
            logger.error('-------tts finish error:stream is None')
            self._handle.on_handle(None, index)
            return

        try:
            st.seek(0)
            t = time.time()
            byte_stream = self.__create_bytes_stream(st)
            logger.info(f'-------tts resample time:{time.time() - t:.4f}s, txt:{txt}')
            t = time.time()
            self._handle.on_handle((byte_stream, txt), index)
            logger.info(f'-------tts handle time:{time.time() - t:.4f}s')
        except Exception as e:
            logger.exception(f'-------tts finish error:{e}')
            self._handle.on_handle(None, index)
        finally:
            # Cleanup lives here so that an error in the except branch can
            # never skip it, and so it is not attempted on a closed stream.
            if not st.closed:
                st.seek(0)
                st.truncate()
                st.close()

    def __create_bytes_stream(self, byte_stream):
        """Decode audio into mono float32 samples at the handler's rate.

        Args:
            byte_stream: File-like object readable by ``soundfile.read``.

        Returns:
            A float32 numpy array; one-dimensional after the channel
            reduction below.
        """
        stream, sample_rate = sf.read(byte_stream)  # [T*sample_rate,] float64
        print(f'[INFO]tts audio stream {sample_rate}: {stream.shape}')
        stream = stream.astype(np.float32)

        if stream.ndim > 1:
            print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
            stream = stream[:, 0]

        # Empty audio is skipped: there is nothing to resample.
        if sample_rate != self._handle.sample_rate and stream.shape[0] > 0:
            print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self._handle.sample_rate}.')
            stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self._handle.sample_rate)

        return stream

    def _on_close(self):
        """Close hook; this backend only prints a notice."""
        print('TTSEdge close')
        # if self._byte_stream is not None and not self._byte_stream.closed:
        #     self._byte_stream.close()

    def on_clear_cache(self, *args, **kwargs):
        """Run the base-class cache clearing, then close registered responses.

        Args:
            *args: Forwarded to ``TTSBase.on_clear_cache``.
            **kwargs: Forwarded to ``TTSBase.on_clear_cache``.
        """
        logger.info('TTSEdgeHttp clear_cache')
        super().on_clear_cache(*args, **kwargs)
        # Iterate over a snapshot: _on_sync_request removes entries from the
        # list, and mutating a list while iterating it can skip elements.
        for response in list(self._response_list):
            response.close()