add tts audio tts

This commit is contained in:
brige 2024-10-11 20:09:54 +08:00
parent 254e6a8359
commit 9372020747
6 changed files with 96 additions and 26 deletions

View File

@ -6,7 +6,7 @@ import time
from asr import SherpaNcnnAsr from asr import SherpaNcnnAsr
from nlp import PunctuationSplit from nlp import PunctuationSplit
from nlp.nlp_doubao import DouBao from nlp.nlp_doubao import DouBao
from tts import TTSEdge from tts import TTSEdge, TTSAudioSaveHandle
try: try:
import sounddevice as sd import sounddevice as sd
@ -21,7 +21,8 @@ except ImportError as e:
def main(): def main():
print("Started! Please speak") print("Started! Please speak")
tts = TTSEdge() handle = TTSAudioSaveHandle()
tts = TTSEdge(handle)
split = PunctuationSplit() split = PunctuationSplit()
nlp = DouBao(split, tts) nlp = DouBao(split, tts)
asr = SherpaNcnnAsr() asr = SherpaNcnnAsr()

View File

@ -74,7 +74,8 @@ class TTSBase:
stream = stream[:, 0] stream = stream[:, 0]
if sample_rate != self._human.get_audio_sample_rate() and stream.shape[0] > 0: if sample_rate != self._human.get_audio_sample_rate() and stream.shape[0] > 0:
print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self._human.get_audio_sample_rate()}.') print(f'[WARN] audio sample rate is {sample_rate},'
f'resampling into {self._human.get_audio_sample_rate()}.')
stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self._human.get_audio_sample_rate()) stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self._human.get_audio_sample_rate())
return stream return stream

View File

@ -1,3 +1,4 @@
#encoding = utf8 #encoding = utf8
from .tts_edge import TTSEdge from .tts_edge import TTSEdge
from .tts_audio_handle import TTSAudioSplitHandle, TTSAudioSaveHandle

72
tts/tts_audio_handle.py Normal file
View File

@ -0,0 +1,72 @@
#encoding = utf8
import os
import shutil
from abc import ABC, abstractmethod
from audio import save_wav
class TTSAudioHandle(ABC):
    """Abstract sink for audio produced by a TTS engine.

    A handle carries the sample rate it wants to receive audio at and
    implements :meth:`on_handle` to consume one synthesized stream.
    """

    def __init__(self, sample_rate: int = 16000):
        """Initialise the handle.

        Args:
            sample_rate: Sample rate, in Hz, this handle expects. Defaults
                to 16000, the value that used to be hard-coded here, so
                existing ``super().__init__()`` calls keep their behavior.
        """
        self._sample_rate = sample_rate

    @property
    def sample_rate(self):
        """Sample rate (Hz) this handle expects its audio in."""
        return self._sample_rate

    @sample_rate.setter
    def sample_rate(self, value):
        self._sample_rate = value

    @abstractmethod
    def on_handle(self, stream):
        """Consume one synthesized audio stream.

        Args:
            stream: The audio samples to process. Subclasses decide what
                to do with them (forward, save, ...).
        """
class TTSAudioSplitHandle(TTSAudioHandle):
    """Handle that slices a stream into fixed-size frames for a ``human``.

    The frame length is ``sample_rate // fps`` samples, both values being
    read from the ``human`` object at construction time.
    """

    def __init__(self, human):
        """Bind the handle to ``human`` and derive the frame size from it.

        Args:
            human: Object providing ``get_audio_sample_rate()``,
                ``get_fps()`` and ``put_audio_frame(frame)``.

        Raises:
            ValueError: If the computed frame size is not positive (for
                example when fps exceeds the sample rate). Such a value
                would otherwise make ``on_handle`` loop forever.
        """
        super().__init__()
        self._human = human
        self.sample_rate = self._human.get_audio_sample_rate()
        # int() keeps the chunk usable as a slice bound even when the
        # frame rate is reported as a float.
        self._chunk = int(self.sample_rate // self._human.get_fps())
        if self._chunk <= 0:
            raise ValueError(
                f'invalid audio chunk size {self._chunk}: '
                f'sample rate {self.sample_rate} is too small for the frame rate'
            )

    def on_handle(self, stream):
        """Push every complete frame of ``stream`` to the human.

        Args:
            stream: Sample array exposing ``shape[0]`` and slicing. A
                trailing remainder shorter than one frame is dropped.
        """
        chunk = self._chunk
        total = stream.shape[0]
        # Stop early enough that every slice is a full frame.
        for start in range(0, total - chunk + 1, chunk):
            self._human.put_audio_frame(stream[start:start + chunk])
class TTSAudioSaveHandle(TTSAudioHandle):
    """Handle that writes each stream to a numbered ``.wav`` file.

    Files are named ``1.wav``, ``2.wav``, ... inside ``../temp/audio/``
    (relative to the current working directory). The directory is emptied
    when the handle is constructed.
    """

    def __init__(self):
        """Reset the file counter and wipe the previous output files."""
        super().__init__()
        self._count = 1
        self._save_path_dir = '../temp/audio/'
        self._clean()

    def _clean(self):
        """Delete everything inside the output directory.

        Prints a notice and returns if the directory does not exist.
        Regular files and symbolic links are unlinked; real
        sub-directories are removed recursively.
        """
        directory = self._save_path_dir
        if not os.path.exists(directory):
            print(f"The directory {directory} does not exist.")
            return

        for filename in os.listdir(directory):
            file_path = os.path.join(directory, filename)
            # Files are removed directly. Links are unlinked rather than
            # followed, because shutil.rmtree refuses to operate on a
            # symbolic link to a directory.
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.remove(file_path)
                print(f"Deleted file: {file_path}")
            # Directories are removed together with all their contents.
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
                print(f"Deleted directory and its contents: {file_path}")

    def on_handle(self, stream):
        """Save ``stream`` as the next numbered wav file.

        Args:
            stream: Audio samples passed through to ``save_wav`` together
                with this handle's sample rate.
        """
        # _clean() only reports a missing directory, so make sure the
        # target exists before writing into it.
        os.makedirs(self._save_path_dir, exist_ok=True)
        file_name = os.path.join(self._save_path_dir, f'{self._count}.wav')
        save_wav(stream, file_name, self.sample_rate)
        self._count += 1

View File

@ -10,11 +10,19 @@ logger = logging.getLogger(__name__)
class TTSBase(NLPCallback): class TTSBase(NLPCallback):
def __init__(self): def __init__(self, handle):
self._sample_rate = 16000 self._handle = handle
self._message_queue = AsyncTaskQueue() self._message_queue = AsyncTaskQueue()
self._message_queue.start_worker() self._message_queue.start_worker()
@property
def handle(self):
return self._handle
@handle.setter
def handle(self, value):
self._handle = value
async def _request(self, txt: str): async def _request(self, txt: str):
print('_request:', txt) print('_request:', txt)
t = time.time() t = time.time()

View File

@ -7,16 +7,14 @@ import soundfile as sf
import edge_tts import edge_tts
import resampy import resampy
from audio import save_chunks, save_wav
from .tts_base import TTSBase from .tts_base import TTSBase
class TTSEdge(TTSBase): class TTSEdge(TTSBase):
def __init__(self, voice='zh-CN-XiaoyiNeural'): def __init__(self, handle, voice='zh-CN-XiaoyiNeural'):
super().__init__() super().__init__(handle)
self._voice = voice self._voice = voice
self._byte_stream = BytesIO() self._byte_stream = BytesIO()
self._count = 1
async def _on_request(self, txt: str): async def _on_request(self, txt: str):
communicate = edge_tts.Communicate(txt, self._voice) communicate = edge_tts.Communicate(txt, self._voice)
@ -25,32 +23,21 @@ class TTSEdge(TTSBase):
if first: if first:
first = False first = False
if chunk["type"] == "audio": if chunk["type"] == "audio":
# self.push_audio(chunk["data"])
self._byte_stream.write(chunk["data"]) self._byte_stream.write(chunk["data"])
# file.write(chunk["data"])
elif chunk["type"] == "WordBoundary": elif chunk["type"] == "WordBoundary":
pass pass
async def _on_handle(self): async def _on_handle(self):
self._byte_stream.seek(0) self._byte_stream.seek(0)
try: try:
self._byte_stream.seek(0)
stream = self.__create_bytes_stream(self._byte_stream) stream = self.__create_bytes_stream(self._byte_stream)
stream_len = stream.shape[0]
idx = 0
print('-------tts start push chunk') print('-------tts start push chunk')
save_wav(stream, '../temp/audio/' + str(self._count) + '.wav', 16000) self._handle.on_handle(stream)
self._count = self._count + 1
# chunk = stream[0:]
# save_chunks(chunk, 16000, './temp/audio')
# while stream_len >= self.chunk:
# self._human.put_audio_frame(stream[idx:idx + self.chunk])
# streamlen -= self.chunk
# idx += self.chunk
# if streamlen>0: #skip last frame(not 20ms)
# self.queue.put(stream[idx:])
self._byte_stream.seek(0) self._byte_stream.seek(0)
self._byte_stream.truncate() self._byte_stream.truncate()
print('-------tts finish push chunk') print('-------tts finish push chunk')
except Exception as e: except Exception as e:
self._byte_stream.seek(0) self._byte_stream.seek(0)
self._byte_stream.truncate() self._byte_stream.truncate()
@ -65,8 +52,8 @@ class TTSEdge(TTSBase):
print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.') print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
stream = stream[:, 0] stream = stream[:, 0]
if sample_rate != self._sample_rate and stream.shape[0] > 0: if sample_rate != self._handle.sample_rate and stream.shape[0] > 0:
print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self._sample_rate}.') print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self._handle.sample_rate}.')
stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self._sample_rate) stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self._handle.sample_rate)
return stream return stream