add asr to tts

This commit is contained in:
brige 2024-10-10 19:01:13 +08:00
parent 879bd5c825
commit 254e6a8359
11 changed files with 231 additions and 13 deletions

View File

@ -1,4 +1,5 @@
#encoding = utf8 #encoding = utf8
from .nlp_callback import NLPCallback
from .nlp_doubao import DouBao from .nlp_doubao import DouBao
from .nlp_split import PunctuationSplit from .nlp_split import PunctuationSplit

View File

@ -8,10 +8,23 @@ logger = logging.getLogger(__name__)
class NLPBase(AsrObserver): class NLPBase(AsrObserver):
def __init__(self, split): def __init__(self, split, callback=None):
self._ask_queue = AsyncTaskQueue() self._ask_queue = AsyncTaskQueue()
self._ask_queue.start_worker() self._ask_queue.start_worker()
self._split_handle = split self._split_handle = split
self._callback = callback
@property
def callback(self):
return self._callback
@callback.setter
def callback(self, value):
self._callback = value
def _on_callback(self, txt: str):
if self._callback is not None:
self._callback.on_message(txt)
async def _request(self, question): async def _request(self, question):
pass pass
@ -20,10 +33,11 @@ class NLPBase(AsrObserver):
pass pass
def completed(self, message: str): def completed(self, message: str):
print('complete :', message) logger.info(f'complete:{message}')
self.ask(message) self.ask(message)
def ask(self, question): def ask(self, question):
logger.info(f'ask:{question}')
self._ask_queue.add_task(self._request(question)) self._ask_queue.add_task(self._request(question))
def stop(self): def stop(self):

9
nlp/nlp_callback.py Normal file
View File

@ -0,0 +1,9 @@
#encoding = utf8
from abc import ABC, abstractmethod
class NLPCallback(ABC):
@abstractmethod
def on_message(self, txt: str):
pass

View File

@ -1,21 +1,17 @@
#encoding = utf8 #encoding = utf8
import re
import time
from queue import Queue
import volcenginesdkark import logging
import volcenginesdkcore import time
from volcenginesdkcore.rest import ApiException
from nlp.nlp_base import NLPBase from nlp.nlp_base import NLPBase
from volcenginesdkarkruntime import AsyncArk from volcenginesdkarkruntime import AsyncArk
nlp_queue = Queue() logger = logging.getLogger(__name__)
class DouBao(NLPBase): class DouBao(NLPBase):
def __init__(self, split): def __init__(self, split, callback=None):
super().__init__(split) super().__init__(split, callback)
# Access Key ID # Access Key ID
# AKLTYTdmOTBmNWFjODkxNDE2Zjk3MjU0NjRhM2JhM2IyN2Y # AKLTYTdmOTBmNWFjODkxNDE2Zjk3MjU0NjRhM2JhM2IyN2Y
# AKLTNDZjNTdhNDlkZGE3NDZjMDlkMzk5YWQ3MDA4MTY1ZDc # AKLTNDZjNTdhNDlkZGE3NDZjMDlkMzk5YWQ3MDA4MTY1ZDc
@ -33,6 +29,7 @@ class DouBao(NLPBase):
async def _request(self, question): async def _request(self, question):
t = time.time() t = time.time()
logger.info(f'_request:{question}')
print(f'-------dou_bao ask:', question) print(f'-------dou_bao ask:', question)
stream = await self.__client.chat.completions.create( stream = await self.__client.chat.completions.create(
model="ep-20241008152048-fsgzf", model="ep-20241008152048-fsgzf",
@ -51,7 +48,10 @@ class DouBao(NLPBase):
sec, message = self._split_handle.handle(sec) sec, message = self._split_handle.handle(sec)
if len(message) > 0: if len(message) > 0:
print(message) print(message)
self._on_callback(message)
print(sec) print(sec)
self._on_callback(sec)
logger.info(f'_request:{question}, time:{time.time() - t:.4f}s')
print(f'-------dou_bao nlp time:{time.time() - t:.4f}s') print(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
''' '''

View File

@ -11,7 +11,7 @@ class NLPSplit(ABC):
class PunctuationSplit(NLPSplit): class PunctuationSplit(NLPSplit):
def __init__(self): def __init__(self):
self._pattern = r'[,。、;?!,.!?]' self._pattern = r'(?<!\d)[,.,。?!:;、]'
def handle(self, message: str): def handle(self, message: str):
match = re.search(self._pattern, message) match = re.search(self._pattern, message)

46
test/asr_nlp_tts.py Normal file
View File

@ -0,0 +1,46 @@
#encoding = utf8
import sys
import time
from asr import SherpaNcnnAsr
from nlp import PunctuationSplit
from nlp.nlp_doubao import DouBao
from tts import TTSEdge
try:
import sounddevice as sd
except ImportError as e:
print("Please install sounddevice first. You can use")
print()
print(" pip install sounddevice")
print()
print("to install it")
sys.exit(-1)
def main():
print("Started! Please speak")
tts = TTSEdge()
split = PunctuationSplit()
nlp = DouBao(split, tts)
asr = SherpaNcnnAsr()
asr.attach(nlp)
time.sleep(60)
print("Stop! ")
asr.stop()
asr.detach(nlp)
nlp.stop()
tts.stop()
if __name__ == "__main__":
devices = sd.query_devices()
print(devices)
default_input_device_idx = sd.default.device[0]
print(f'Use default device: {devices[default_input_device_idx]["name"]}')
try:
main()
except KeyboardInterrupt:
print("\nCaught Ctrl + C. Exiting")

28
test/test_tts_only.py Normal file
View File

@ -0,0 +1,28 @@
#encoding = utf8
import time
from tts import TTSEdge
def main():
print("Started! Please speak")
tts = TTSEdge()
tts.message('你好,')
tts.message('请问有什么可以帮到您,')
tts.message('很高兴为您服务。')
tts.message('祝您平安,')
tts.message('再见')
time.sleep(20)
tts.stop()
print("Stop! ")
if __name__ == "__main__":
try:
main()
except KeyboardInterrupt:
print("\nCaught Ctrl + C. Exiting")

3
tts/__init__.py Normal file
View File

@ -0,0 +1,3 @@
#encoding = utf8
from .tts_edge import TTSEdge

40
tts/tts_base.py Normal file
View File

@ -0,0 +1,40 @@
#encoding = utf8
import logging
import time
from nlp import NLPCallback
from utils import AsyncTaskQueue
logger = logging.getLogger(__name__)
class TTSBase(NLPCallback):
def __init__(self):
self._sample_rate = 16000
self._message_queue = AsyncTaskQueue()
self._message_queue.start_worker()
async def _request(self, txt: str):
print('_request:', txt)
t = time.time()
await self._on_request(txt)
print(f'-------tts time:{time.time() - t:.4f}s')
await self._on_handle()
async def _on_request(self, text: str):
pass
async def _on_handle(self):
pass
def on_message(self, txt: str):
self.message(txt)
def message(self, txt):
logger.info(f'message:{txt}')
print(f'message:{txt}')
self._message_queue.add_task(self._request(txt))
def stop(self):
self._message_queue.stop()

72
tts/tts_edge.py Normal file
View File

@ -0,0 +1,72 @@
#encoding = utf8
from io import BytesIO
import numpy as np
import soundfile as sf
import edge_tts
import resampy
from audio import save_chunks, save_wav
from .tts_base import TTSBase
class TTSEdge(TTSBase):
def __init__(self, voice='zh-CN-XiaoyiNeural'):
super().__init__()
self._voice = voice
self._byte_stream = BytesIO()
self._count = 1
async def _on_request(self, txt: str):
communicate = edge_tts.Communicate(txt, self._voice)
first = True
async for chunk in communicate.stream():
if first:
first = False
if chunk["type"] == "audio":
# self.push_audio(chunk["data"])
self._byte_stream.write(chunk["data"])
# file.write(chunk["data"])
elif chunk["type"] == "WordBoundary":
pass
async def _on_handle(self):
self._byte_stream.seek(0)
try:
stream = self.__create_bytes_stream(self._byte_stream)
stream_len = stream.shape[0]
idx = 0
print('-------tts start push chunk')
save_wav(stream, '../temp/audio/' + str(self._count) + '.wav', 16000)
self._count = self._count + 1
# chunk = stream[0:]
# save_chunks(chunk, 16000, './temp/audio')
# while stream_len >= self.chunk:
# self._human.put_audio_frame(stream[idx:idx + self.chunk])
# streamlen -= self.chunk
# idx += self.chunk
# if streamlen>0: #skip last frame(not 20ms)
# self.queue.put(stream[idx:])
self._byte_stream.seek(0)
self._byte_stream.truncate()
print('-------tts finish push chunk')
except Exception as e:
self._byte_stream.seek(0)
self._byte_stream.truncate()
print('-------tts finish error:', e)
def __create_bytes_stream(self, byte_stream):
stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64
print(f'[INFO]tts audio stream {sample_rate}: {stream.shape}')
stream = stream.astype(np.float32)
if stream.ndim > 1:
print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
stream = stream[:, 0]
if sample_rate != self._sample_rate and stream.shape[0] > 0:
print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self._sample_rate}.')
stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self._sample_rate)
return stream

View File

@ -14,19 +14,24 @@ class AsyncTaskQueue:
self._thread.start() self._thread.start()
def _run_loop(self): def _run_loop(self):
asyncio.set_event_loop(self._loop) print('_run_loop')
self._loop_running.set() self._loop_running.set()
asyncio.set_event_loop(self._loop)
self._loop.run_forever() self._loop.run_forever()
async def _worker(self): async def _worker(self):
print('_worker')
while self._loop_running.is_set(): while self._loop_running.is_set():
task = await self._queue.get() task = await self._queue.get()
if task is None: if task is None:
break break
print('run task')
await task await task
self._queue.task_done() self._queue.task_done()
print('_worker finish')
def add_task(self, coro): def add_task(self, coro):
print('add_task')
asyncio.run_coroutine_threadsafe(self._queue.put(coro), self._loop) asyncio.run_coroutine_threadsafe(self._queue.put(coro), self._loop)
def start_worker(self): def start_worker(self):