add asr to tts
This commit is contained in:
parent
879bd5c825
commit
254e6a8359
@ -1,4 +1,5 @@
|
|||||||
#encoding = utf8
|
#encoding = utf8
|
||||||
|
|
||||||
|
from .nlp_callback import NLPCallback
|
||||||
from .nlp_doubao import DouBao
|
from .nlp_doubao import DouBao
|
||||||
from .nlp_split import PunctuationSplit
|
from .nlp_split import PunctuationSplit
|
||||||
|
@ -8,10 +8,23 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class NLPBase(AsrObserver):
|
class NLPBase(AsrObserver):
|
||||||
def __init__(self, split):
|
def __init__(self, split, callback=None):
|
||||||
self._ask_queue = AsyncTaskQueue()
|
self._ask_queue = AsyncTaskQueue()
|
||||||
self._ask_queue.start_worker()
|
self._ask_queue.start_worker()
|
||||||
self._split_handle = split
|
self._split_handle = split
|
||||||
|
self._callback = callback
|
||||||
|
|
||||||
|
@property
|
||||||
|
def callback(self):
|
||||||
|
return self._callback
|
||||||
|
|
||||||
|
@callback.setter
|
||||||
|
def callback(self, value):
|
||||||
|
self._callback = value
|
||||||
|
|
||||||
|
def _on_callback(self, txt: str):
|
||||||
|
if self._callback is not None:
|
||||||
|
self._callback.on_message(txt)
|
||||||
|
|
||||||
async def _request(self, question):
|
async def _request(self, question):
|
||||||
pass
|
pass
|
||||||
@ -20,10 +33,11 @@ class NLPBase(AsrObserver):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def completed(self, message: str):
|
def completed(self, message: str):
|
||||||
print('complete :', message)
|
logger.info(f'complete:{message}')
|
||||||
self.ask(message)
|
self.ask(message)
|
||||||
|
|
||||||
def ask(self, question):
|
def ask(self, question):
|
||||||
|
logger.info(f'ask:{question}')
|
||||||
self._ask_queue.add_task(self._request(question))
|
self._ask_queue.add_task(self._request(question))
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
|
9
nlp/nlp_callback.py
Normal file
9
nlp/nlp_callback.py
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
#encoding = utf8
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
|
||||||
|
class NLPCallback(ABC):
    """Interface for objects that want to receive text from an NLP backend.

    Implementations are handed each piece of text through ``on_message``.
    """

    @abstractmethod
    def on_message(self, txt: str) -> None:
        """Handle one piece of text.

        Args:
            txt: The text passed in by the caller.
        """
        pass
|
@ -1,21 +1,17 @@
|
|||||||
#encoding = utf8
|
#encoding = utf8
|
||||||
import re
|
|
||||||
import time
|
|
||||||
from queue import Queue
|
|
||||||
|
|
||||||
import volcenginesdkark
|
import logging
|
||||||
import volcenginesdkcore
|
import time
|
||||||
from volcenginesdkcore.rest import ApiException
|
|
||||||
|
|
||||||
from nlp.nlp_base import NLPBase
|
from nlp.nlp_base import NLPBase
|
||||||
from volcenginesdkarkruntime import AsyncArk
|
from volcenginesdkarkruntime import AsyncArk
|
||||||
|
|
||||||
nlp_queue = Queue()
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class DouBao(NLPBase):
|
class DouBao(NLPBase):
|
||||||
def __init__(self, split):
|
def __init__(self, split, callback=None):
|
||||||
super().__init__(split)
|
super().__init__(split, callback)
|
||||||
# Access Key ID
|
# Access Key ID
|
||||||
# AKLTYTdmOTBmNWFjODkxNDE2Zjk3MjU0NjRhM2JhM2IyN2Y
|
# AKLTYTdmOTBmNWFjODkxNDE2Zjk3MjU0NjRhM2JhM2IyN2Y
|
||||||
# AKLTNDZjNTdhNDlkZGE3NDZjMDlkMzk5YWQ3MDA4MTY1ZDc
|
# AKLTNDZjNTdhNDlkZGE3NDZjMDlkMzk5YWQ3MDA4MTY1ZDc
|
||||||
@ -33,6 +29,7 @@ class DouBao(NLPBase):
|
|||||||
|
|
||||||
async def _request(self, question):
|
async def _request(self, question):
|
||||||
t = time.time()
|
t = time.time()
|
||||||
|
logger.info(f'_request:{question}')
|
||||||
print(f'-------dou_bao ask:', question)
|
print(f'-------dou_bao ask:', question)
|
||||||
stream = await self.__client.chat.completions.create(
|
stream = await self.__client.chat.completions.create(
|
||||||
model="ep-20241008152048-fsgzf",
|
model="ep-20241008152048-fsgzf",
|
||||||
@ -51,7 +48,10 @@ class DouBao(NLPBase):
|
|||||||
sec, message = self._split_handle.handle(sec)
|
sec, message = self._split_handle.handle(sec)
|
||||||
if len(message) > 0:
|
if len(message) > 0:
|
||||||
print(message)
|
print(message)
|
||||||
|
self._on_callback(message)
|
||||||
print(sec)
|
print(sec)
|
||||||
|
self._on_callback(sec)
|
||||||
|
logger.info(f'_request:{question}, time:{time.time() - t:.4f}s')
|
||||||
print(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
|
print(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
|
||||||
|
|
||||||
'''
|
'''
|
||||||
|
@ -11,7 +11,7 @@ class NLPSplit(ABC):
|
|||||||
|
|
||||||
class PunctuationSplit(NLPSplit):
|
class PunctuationSplit(NLPSplit):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._pattern = r'[,。、;?!,.!?]'
|
self._pattern = r'(?<!\d)[,.,。?!:;、]'
|
||||||
|
|
||||||
def handle(self, message: str):
|
def handle(self, message: str):
|
||||||
match = re.search(self._pattern, message)
|
match = re.search(self._pattern, message)
|
||||||
|
46
test/asr_nlp_tts.py
Normal file
46
test/asr_nlp_tts.py
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
#encoding = utf8
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
|
||||||
|
from asr import SherpaNcnnAsr
|
||||||
|
from nlp import PunctuationSplit
|
||||||
|
from nlp.nlp_doubao import DouBao
|
||||||
|
from tts import TTSEdge
|
||||||
|
|
||||||
|
try:
|
||||||
|
import sounddevice as sd
|
||||||
|
except ImportError as e:
|
||||||
|
print("Please install sounddevice first. You can use")
|
||||||
|
print()
|
||||||
|
print(" pip install sounddevice")
|
||||||
|
print()
|
||||||
|
print("to install it")
|
||||||
|
sys.exit(-1)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Wire ASR -> NLP -> TTS together and let the pipeline run for 60 seconds.

    The TTS object is passed to the NLP backend as its callback, and the NLP
    backend is attached to the ASR object as an observer. Teardown runs in a
    ``finally`` block so that every component is stopped even when the wait
    is interrupted (for example by Ctrl+C); the interrupting exception is
    still propagated to the caller afterwards.
    """
    print("Started! Please speak")
    tts = TTSEdge()
    split = PunctuationSplit()
    nlp = DouBao(split, tts)
    asr = SherpaNcnnAsr()
    asr.attach(nlp)
    try:
        time.sleep(60)
    finally:
        print("Stop! ")
        # Stop the producer first, then the consumers further down the chain.
        asr.stop()
        asr.detach(nlp)
        nlp.stop()
        tts.stop()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Show the available audio devices and report which input will be used.
    devices = sd.query_devices()
    print(devices)
    default_input_device_idx = sd.default.device[0]
    # An empty device list or a negative index (no default input device)
    # would make the lookup below fail or report the wrong device.
    if len(devices) == 0 or default_input_device_idx < 0:
        print("No default audio input device found")
        sys.exit(-1)
    print(f'Use default device: {devices[default_input_device_idx]["name"]}')

    try:
        main()
    except KeyboardInterrupt:
        print("\nCaught Ctrl + C. Exiting")
|
28
test/test_tts_only.py
Normal file
28
test/test_tts_only.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
#encoding = utf8
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
from tts import TTSEdge
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Queue a few fixed sentences on ``TTSEdge`` and wait 20 seconds.

    ``tts.stop()`` is called from a ``finally`` block so the TTS queue is shut
    down even if the wait is interrupted; the interrupting exception is still
    propagated to the caller.
    """
    print("Started! Please speak")

    tts = TTSEdge()
    try:
        tts.message('你好,')
        tts.message('请问有什么可以帮到您,')
        tts.message('很高兴为您服务。')
        tts.message('祝您平安,')
        tts.message('再见')

        # Give the background worker time to process the queued sentences.
        time.sleep(20)
    finally:
        tts.stop()

    print("Stop! ")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Ctrl+C is an expected way to end the demo; report it instead of
    # showing a traceback.
    try:
        main()
    except KeyboardInterrupt:
        print("\nCaught Ctrl + C. Exiting")
|
3
tts/__init__.py
Normal file
3
tts/__init__.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
#encoding = utf8
|
||||||
|
|
||||||
|
from .tts_edge import TTSEdge
|
40
tts/tts_base.py
Normal file
40
tts/tts_base.py
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
#encoding = utf8
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
|
||||||
|
from nlp import NLPCallback
|
||||||
|
from utils import AsyncTaskQueue
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TTSBase(NLPCallback):
    """Base class for text-to-speech backends driven by NLP output.

    Each text passed to ``message`` (or ``on_message``) is wrapped in a
    ``_request`` coroutine and handed to an ``AsyncTaskQueue``. For every
    request the coroutine awaits ``_on_request`` and then ``_on_handle``;
    subclasses override those two hooks.
    """

    def __init__(self):
        # Target audio sample rate used by subclasses.
        self._sample_rate = 16000
        self._message_queue = AsyncTaskQueue()
        self._message_queue.start_worker()

    async def _request(self, txt: str):
        """Run one synthesis request: ``_on_request`` followed by ``_on_handle``.

        Failures are logged instead of being raised, so that one bad request
        does not propagate into whatever is awaiting this coroutine and
        prevent later requests from being processed.

        Args:
            txt: The text to synthesize.
        """
        print('_request:', txt)
        t = time.time()
        try:
            await self._on_request(txt)
            print(f'-------tts time:{time.time() - t:.4f}s')
            await self._on_handle()
        except Exception:
            logger.exception(f'tts request failed:{txt}')

    async def _on_request(self, text: str):
        """Hook: synthesize *text*. The base implementation does nothing."""
        pass

    async def _on_handle(self):
        """Hook: process the result of the preceding ``_on_request`` call.

        The base implementation does nothing.
        """
        pass

    def on_message(self, txt: str):
        """``NLPCallback`` entry point; forwards *txt* to ``message``."""
        self.message(txt)

    def message(self, txt):
        """Queue *txt* for synthesis.

        Empty or whitespace-only text is skipped, since there is nothing to
        synthesize for it.

        Args:
            txt: The text to speak.
        """
        if not txt or not txt.strip():
            logger.info('message: skip empty text')
            return
        logger.info(f'message:{txt}')
        print(f'message:{txt}')
        self._message_queue.add_task(self._request(txt))

    def stop(self):
        """Stop the underlying task queue."""
        self._message_queue.stop()
|
72
tts/tts_edge.py
Normal file
72
tts/tts_edge.py
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
#encoding = utf8
|
||||||
|
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import soundfile as sf
|
||||||
|
import edge_tts
|
||||||
|
import resampy
|
||||||
|
|
||||||
|
from audio import save_chunks, save_wav
|
||||||
|
from .tts_base import TTSBase
|
||||||
|
|
||||||
|
|
||||||
|
class TTSEdge(TTSBase):
    """TTS backend that synthesizes speech with ``edge_tts``.

    The encoded audio of one request is collected in an in-memory buffer by
    ``_on_request``. ``_on_handle`` then decodes it, converts it to mono
    float32 at ``self._sample_rate`` and writes it to
    ``../temp/audio/<n>.wav``.
    """

    def __init__(self, voice='zh-CN-XiaoyiNeural'):
        """Create the backend.

        Args:
            voice: Voice name passed to ``edge_tts.Communicate``.
        """
        super().__init__()
        self._voice = voice
        # Encoded audio bytes of the request currently being processed.
        self._byte_stream = BytesIO()
        # Number used to name the next saved wav file.
        self._count = 1

    def _reset_buffer(self):
        """Empty the audio buffer so the next request starts clean."""
        self._byte_stream.seek(0)
        self._byte_stream.truncate()

    async def _on_request(self, txt: str):
        """Stream synthesized audio for *txt* into the internal buffer.

        Args:
            txt: The text to synthesize.
        """
        # Discard bytes left behind by an earlier request that failed before
        # _on_handle could clear the buffer.
        self._reset_buffer()
        communicate = edge_tts.Communicate(txt, self._voice)
        async for chunk in communicate.stream():
            # Only "audio" chunks carry data; other chunk types such as
            # "WordBoundary" are ignored.
            if chunk["type"] == "audio":
                self._byte_stream.write(chunk["data"])

    async def _on_handle(self):
        """Decode the buffered audio and save it as a numbered wav file.

        Errors are reported on stdout and not raised. The buffer is always
        cleared afterwards, whether or not decoding/saving succeeded.
        """
        self._byte_stream.seek(0)
        try:
            stream = self.__create_bytes_stream(self._byte_stream)
            print('-------tts start push chunk')
            # The decoded stream has been resampled to self._sample_rate, so
            # the file is written with that same rate.
            save_wav(stream, '../temp/audio/' + str(self._count) + '.wav', self._sample_rate)
            self._count = self._count + 1
            # TODO: push the audio in fixed-size chunks to the consumer
            # instead of only saving it to disk.
            print('-------tts finish push chunk')
        except Exception as e:
            print('-------tts finish error:', e)
        finally:
            self._reset_buffer()

    def __create_bytes_stream(self, byte_stream):
        """Decode *byte_stream* into a mono float32 array at the target rate.

        Args:
            byte_stream: File-like object positioned at the start of the
                encoded audio.

        Returns:
            The decoded samples as a 1-D float32 array, resampled to
            ``self._sample_rate`` when the source rate differs.
        """
        stream, sample_rate = sf.read(byte_stream)
        print(f'[INFO]tts audio stream {sample_rate}: {stream.shape}')
        stream = stream.astype(np.float32)

        if stream.ndim > 1:
            print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
            stream = stream[:, 0]

        # Skip resampling for empty audio; there is nothing to convert.
        if sample_rate != self._sample_rate and stream.shape[0] > 0:
            print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self._sample_rate}.')
            stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self._sample_rate)

        return stream
|
@ -14,19 +14,24 @@ class AsyncTaskQueue:
|
|||||||
self._thread.start()
|
self._thread.start()
|
||||||
|
|
||||||
def _run_loop(self):
    """Run ``self._loop`` until it is stopped.

    Sets ``self._loop_running`` first, then installs ``self._loop`` as the
    current event loop of the calling thread and blocks in ``run_forever()``.
    NOTE(review): presumably the target of ``self._thread`` - confirm against
    the constructor, which is not visible here.
    """
    print('_run_loop')
    self._loop_running.set()
    asyncio.set_event_loop(self._loop)
    self._loop.run_forever()
|
||||||
|
|
||||||
async def _worker(self):
    """Await queued coroutines one at a time until told to stop.

    The loop ends when ``self._loop_running`` is cleared or when ``None`` is
    taken from the queue. A task that raises is logged and the worker moves
    on to the next item. ``task_done()`` is called for every item taken from
    the queue, including failed tasks and the ``None`` sentinel, so a
    ``join()`` on the queue cannot hang on them.
    """
    import logging

    print('_worker')
    while self._loop_running.is_set():
        task = await self._queue.get()
        try:
            if task is None:
                break
            print('run task')
            await task
        except Exception:
            # One failing task must not end processing of later tasks.
            logging.getLogger(__name__).exception('task failed')
        finally:
            self._queue.task_done()
    print('_worker finish')
|
||||||
|
|
||||||
def add_task(self, coro):
    """Put *coro* on the queue from any thread.

    The ``put`` is scheduled on ``self._loop`` with
    ``asyncio.run_coroutine_threadsafe``; the returned future is not kept.

    Args:
        coro: The item to enqueue.
    """
    print('add_task')
    asyncio.run_coroutine_threadsafe(self._queue.put(coro), self._loop)
|
||||||
|
|
||||||
def start_worker(self):
|
def start_worker(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user