add asr to tts
This commit is contained in:
parent
879bd5c825
commit
254e6a8359
@ -1,4 +1,5 @@
|
||||
#encoding = utf8
|
||||
|
||||
from .nlp_callback import NLPCallback
|
||||
from .nlp_doubao import DouBao
|
||||
from .nlp_split import PunctuationSplit
|
||||
|
@ -8,10 +8,23 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NLPBase(AsrObserver):
|
||||
def __init__(self, split):
|
||||
def __init__(self, split, callback=None):
|
||||
self._ask_queue = AsyncTaskQueue()
|
||||
self._ask_queue.start_worker()
|
||||
self._split_handle = split
|
||||
self._callback = callback
|
||||
|
||||
@property
def callback(self):
    """Return the object stored in ``self._callback`` (``None`` when unset)."""
    return self._callback

@callback.setter
def callback(self, value):
    """Store ``value`` as the object whose ``on_message`` receives NLP text."""
    self._callback = value
|
||||
|
||||
def _on_callback(self, txt: str):
|
||||
if self._callback is not None:
|
||||
self._callback.on_message(txt)
|
||||
|
||||
async def _request(self, question):
|
||||
pass
|
||||
@ -20,10 +33,11 @@ class NLPBase(AsrObserver):
|
||||
pass
|
||||
|
||||
def completed(self, message: str):
|
||||
print('complete :', message)
|
||||
logger.info(f'complete:{message}')
|
||||
self.ask(message)
|
||||
|
||||
def ask(self, question):
    """Log ``question`` and queue ``self._request(question)`` on the ask queue."""
    logger.info(f'ask:{question}')
    pending = self._request(question)
    self._ask_queue.add_task(pending)
|
||||
|
||||
def stop(self):
|
||||
|
9
nlp/nlp_callback.py
Normal file
9
nlp/nlp_callback.py
Normal file
@ -0,0 +1,9 @@
|
||||
#encoding = utf8
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class NLPCallback(ABC):
    """Abstract interface for objects that receive text via ``on_message``."""

    @abstractmethod
    def on_message(self, txt: str):
        """Handle one piece of text; concrete subclasses must override this."""
|
@ -1,21 +1,17 @@
|
||||
#encoding = utf8
|
||||
import re
|
||||
import time
|
||||
from queue import Queue
|
||||
|
||||
import volcenginesdkark
|
||||
import volcenginesdkcore
|
||||
from volcenginesdkcore.rest import ApiException
|
||||
import logging
|
||||
import time
|
||||
|
||||
from nlp.nlp_base import NLPBase
|
||||
from volcenginesdkarkruntime import AsyncArk
|
||||
|
||||
nlp_queue = Queue()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DouBao(NLPBase):
|
||||
def __init__(self, split):
|
||||
super().__init__(split)
|
||||
def __init__(self, split, callback=None):
|
||||
super().__init__(split, callback)
|
||||
# Access Key ID
|
||||
# AKLTYTdmOTBmNWFjODkxNDE2Zjk3MjU0NjRhM2JhM2IyN2Y
|
||||
# AKLTNDZjNTdhNDlkZGE3NDZjMDlkMzk5YWQ3MDA4MTY1ZDc
|
||||
@ -33,6 +29,7 @@ class DouBao(NLPBase):
|
||||
|
||||
async def _request(self, question):
|
||||
t = time.time()
|
||||
logger.info(f'_request:{question}')
|
||||
print(f'-------dou_bao ask:', question)
|
||||
stream = await self.__client.chat.completions.create(
|
||||
model="ep-20241008152048-fsgzf",
|
||||
@ -51,7 +48,10 @@ class DouBao(NLPBase):
|
||||
sec, message = self._split_handle.handle(sec)
|
||||
if len(message) > 0:
|
||||
print(message)
|
||||
self._on_callback(message)
|
||||
print(sec)
|
||||
self._on_callback(sec)
|
||||
logger.info(f'_request:{question}, time:{time.time() - t:.4f}s')
|
||||
print(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
|
||||
|
||||
'''
|
||||
|
@ -11,7 +11,7 @@ class NLPSplit(ABC):
|
||||
|
||||
class PunctuationSplit(NLPSplit):
|
||||
def __init__(self):
|
||||
self._pattern = r'[,。、;?!,.!?]'
|
||||
self._pattern = r'(?<!\d)[,.,。?!:;、]'
|
||||
|
||||
def handle(self, message: str):
|
||||
match = re.search(self._pattern, message)
|
||||
|
46
test/asr_nlp_tts.py
Normal file
46
test/asr_nlp_tts.py
Normal file
@ -0,0 +1,46 @@
|
||||
#encoding = utf8
|
||||
|
||||
import sys
|
||||
import time
|
||||
|
||||
from asr import SherpaNcnnAsr
|
||||
from nlp import PunctuationSplit
|
||||
from nlp.nlp_doubao import DouBao
|
||||
from tts import TTSEdge
|
||||
|
||||
try:
|
||||
import sounddevice as sd
|
||||
except ImportError as e:
|
||||
print("Please install sounddevice first. You can use")
|
||||
print()
|
||||
print(" pip install sounddevice")
|
||||
print()
|
||||
print("to install it")
|
||||
sys.exit(-1)
|
||||
|
||||
|
||||
def main():
    """Run the ASR -> NLP -> TTS chain for 60 seconds, then shut it down.

    The TTS object is handed to ``DouBao`` as its callback and the NLP object
    is attached to the recognizer. Each component is stopped in a ``finally``
    clause so that an interruption (e.g. Ctrl+C during the sleep) or a failure
    while building a later component still stops everything already created.
    """
    print("Started! Please speak")
    tts = TTSEdge()
    try:
        nlp = DouBao(PunctuationSplit(), tts)
        try:
            asr = SherpaNcnnAsr()
            try:
                asr.attach(nlp)
                time.sleep(60)
            finally:
                # Same shutdown order as before: recognizer first, then
                # the components downstream of it.
                print("Stop! ")
                asr.stop()
                asr.detach(nlp)
        finally:
            nlp.stop()
    finally:
        tts.stop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Show every audio device sounddevice reports, then name the one at
    # index 0 of ``sd.default.device`` (treated here as the default input).
    devices = sd.query_devices()
    print(devices)
    default_input_device_idx = sd.default.device[0]
    print(f'Use default device: {devices[default_input_device_idx]["name"]}')

    # Ctrl+C is an expected way to end the demo, so report it quietly.
    try:
        main()
    except KeyboardInterrupt:
        print("\nCaught Ctrl + C. Exiting")
|
28
test/test_tts_only.py
Normal file
28
test/test_tts_only.py
Normal file
@ -0,0 +1,28 @@
|
||||
#encoding = utf8
|
||||
|
||||
import time
|
||||
|
||||
from tts import TTSEdge
|
||||
|
||||
|
||||
def main():
    """Queue a fixed set of phrases on ``TTSEdge``, wait 20 seconds, stop it.

    ``tts.stop()`` runs in a ``finally`` clause so the TTS queue is also
    stopped when the wait is interrupted (e.g. by Ctrl+C).
    """
    print("Started! Please speak")

    tts = TTSEdge()
    try:
        for phrase in (
            '你好,',
            '请问有什么可以帮到您,',
            '很高兴为您服务。',
            '祝您平安,',
            '再见',
        ):
            tts.message(phrase)

        # message() only enqueues the request; wait while the queue works.
        time.sleep(20)
    finally:
        tts.stop()

    print("Stop! ")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Ctrl+C is an expected way to end the demo, so report it quietly.
    try:
        main()
    except KeyboardInterrupt:
        print("\nCaught Ctrl + C. Exiting")
|
3
tts/__init__.py
Normal file
3
tts/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
#encoding = utf8
|
||||
|
||||
from .tts_edge import TTSEdge
|
40
tts/tts_base.py
Normal file
40
tts/tts_base.py
Normal file
@ -0,0 +1,40 @@
|
||||
#encoding = utf8
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
from nlp import NLPCallback
|
||||
from utils import AsyncTaskQueue
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TTSBase(NLPCallback):
    """Base class for text-to-speech back ends driven by ``on_message``.

    Each text is wrapped in a ``_request`` coroutine and handed to an
    ``AsyncTaskQueue``. ``_request`` awaits the ``_on_request`` hook and then
    the ``_on_handle`` hook; subclasses override those two.
    """

    def __init__(self):
        # Target sample rate for subclasses (TTSEdge resamples to it).
        self._sample_rate = 16000
        queue = AsyncTaskQueue()
        queue.start_worker()
        self._message_queue = queue

    # ---- public API -------------------------------------------------------

    def on_message(self, txt: str):
        """``NLPCallback`` entry point; forwards ``txt`` to ``message``."""
        self.message(txt)

    def message(self, txt):
        """Log ``txt`` and enqueue a ``_request`` coroutine for it."""
        logger.info(f'message:{txt}')
        print(f'message:{txt}')
        pending = self._request(txt)
        self._message_queue.add_task(pending)

    def stop(self):
        """Stop the underlying message queue."""
        self._message_queue.stop()

    # ---- request pipeline -------------------------------------------------

    async def _request(self, txt: str):
        """Await ``_on_request(txt)``, print its duration, then ``_on_handle``."""
        print('_request:', txt)
        t = time.time()
        await self._on_request(txt)
        print(f'-------tts time:{time.time() - t:.4f}s')
        await self._on_handle()

    async def _on_request(self, text: str):
        """Hook awaited first for every text; the base version does nothing."""

    async def _on_handle(self):
        """Hook awaited after ``_on_request``; the base version does nothing."""
|
72
tts/tts_edge.py
Normal file
72
tts/tts_edge.py
Normal file
@ -0,0 +1,72 @@
|
||||
#encoding = utf8
|
||||
|
||||
from io import BytesIO
|
||||
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
import edge_tts
|
||||
import resampy
|
||||
|
||||
from audio import save_chunks, save_wav
|
||||
from .tts_base import TTSBase
|
||||
|
||||
|
||||
class TTSEdge(TTSBase):
    """TTS back end that synthesizes speech with ``edge_tts``.

    ``_on_request`` accumulates the audio chunks streamed for one text in an
    in-memory buffer. ``_on_handle`` then decodes that buffer, converts it to
    mono float32 at ``self._sample_rate`` and saves it as a numbered WAV file.
    """

    def __init__(self, voice='zh-CN-XiaoyiNeural'):
        """Create the back end.

        Args:
            voice: voice name passed to ``edge_tts.Communicate``.
        """
        super().__init__()
        self._voice = voice
        self._byte_stream = BytesIO()
        # Sequence number used in the name of the next saved WAV file.
        self._count = 1

    def _reset_buffer(self):
        """Empty the audio buffer so the next request starts from scratch."""
        self._byte_stream.seek(0)
        self._byte_stream.truncate()

    async def _on_request(self, txt: str):
        """Stream synthesized audio for ``txt`` into the byte buffer."""
        # If a previous request failed mid-stream, _on_handle never ran and
        # its partial audio is still buffered; drop it before writing.
        self._reset_buffer()
        communicate = edge_tts.Communicate(txt, self._voice)
        async for chunk in communicate.stream():
            # Only audio chunks are kept; other chunk types are ignored.
            if chunk["type"] == "audio":
                self._byte_stream.write(chunk["data"])

    async def _on_handle(self):
        """Decode the buffered audio and save it as ``<count>.wav``.

        Best effort: any error is printed rather than raised, and the buffer
        is emptied whether or not decoding/saving succeeded.
        """
        self._byte_stream.seek(0)
        try:
            stream = self.__create_bytes_stream(self._byte_stream)
            print('-------tts start push chunk')
            # The stream was resampled to self._sample_rate, so save with
            # that same rate instead of a separate hard-coded constant.
            save_wav(stream, f'../temp/audio/{self._count}.wav', self._sample_rate)
            self._count = self._count + 1
            # NOTE(review): an earlier commented-out draft here pushed
            # fixed-size chunks of `stream` to a consumer; not implemented.
        except Exception as e:
            print('-------tts finish error:', e)
        else:
            print('-------tts finish push chunk')
        finally:
            self._reset_buffer()

    def __create_bytes_stream(self, byte_stream):
        """Decode ``byte_stream`` into a mono float32 array.

        The audio is read with ``soundfile``, reduced to its first channel
        when it has several, and resampled to ``self._sample_rate`` when the
        source rate differs and the stream is not empty.
        """
        stream, sample_rate = sf.read(byte_stream)  # [T*sample_rate,] float64
        print(f'[INFO]tts audio stream {sample_rate}: {stream.shape}')
        stream = stream.astype(np.float32)

        if stream.ndim > 1:
            print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
            stream = stream[:, 0]

        if sample_rate != self._sample_rate and stream.shape[0] > 0:
            print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self._sample_rate}.')
            stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self._sample_rate)

        return stream
|
@ -14,19 +14,24 @@ class AsyncTaskQueue:
|
||||
self._thread.start()
|
||||
|
||||
def _run_loop(self):
|
||||
asyncio.set_event_loop(self._loop)
|
||||
print('_run_loop')
|
||||
self._loop_running.set()
|
||||
asyncio.set_event_loop(self._loop)
|
||||
self._loop.run_forever()
|
||||
|
||||
async def _worker(self):
|
||||
print('_worker')
|
||||
while self._loop_running.is_set():
|
||||
task = await self._queue.get()
|
||||
if task is None:
|
||||
break
|
||||
print('run task')
|
||||
await task
|
||||
self._queue.task_done()
|
||||
print('_worker finish')
|
||||
|
||||
def add_task(self, coro):
    """Schedule ``coro`` to be put on the queue from the calling thread.

    The put itself is scheduled on ``self._loop`` with
    ``asyncio.run_coroutine_threadsafe``; its future is not returned.
    """
    print('add_task')
    enqueue = self._queue.put(coro)
    asyncio.run_coroutine_threadsafe(enqueue, self._loop)
|
||||
|
||||
def start_worker(self):
|
||||
|
Loading…
Reference in New Issue
Block a user