From 254e6a835998010d43d09b4a585d4ffb9c84892f Mon Sep 17 00:00:00 2001 From: brige Date: Thu, 10 Oct 2024 19:01:13 +0800 Subject: [PATCH] add asr to tts --- nlp/__init__.py | 1 + nlp/nlp_base.py | 18 ++++++++-- nlp/nlp_callback.py | 9 +++++ nlp/nlp_doubao.py | 18 +++++----- nlp/nlp_split.py | 2 +- test/asr_nlp_tts.py | 46 +++++++++++++++++++++++++ test/test_tts_only.py | 28 +++++++++++++++ tts/__init__.py | 3 ++ tts/tts_base.py | 40 ++++++++++++++++++++++ tts/tts_edge.py | 72 +++++++++++++++++++++++++++++++++++++++ utils/async_task_queue.py | 7 +++- 11 files changed, 231 insertions(+), 13 deletions(-) create mode 100644 nlp/nlp_callback.py create mode 100644 test/asr_nlp_tts.py create mode 100644 test/test_tts_only.py create mode 100644 tts/__init__.py create mode 100644 tts/tts_base.py create mode 100644 tts/tts_edge.py diff --git a/nlp/__init__.py b/nlp/__init__.py index 050c481..3ebb749 100644 --- a/nlp/__init__.py +++ b/nlp/__init__.py @@ -1,4 +1,5 @@ #encoding = utf8 +from .nlp_callback import NLPCallback from .nlp_doubao import DouBao from .nlp_split import PunctuationSplit diff --git a/nlp/nlp_base.py b/nlp/nlp_base.py index b22b62a..9f3410a 100644 --- a/nlp/nlp_base.py +++ b/nlp/nlp_base.py @@ -8,10 +8,23 @@ logger = logging.getLogger(__name__) class NLPBase(AsrObserver): - def __init__(self, split): + def __init__(self, split, callback=None): self._ask_queue = AsyncTaskQueue() self._ask_queue.start_worker() self._split_handle = split + self._callback = callback + + @property + def callback(self): + return self._callback + + @callback.setter + def callback(self, value): + self._callback = value + + def _on_callback(self, txt: str): + if self._callback is not None: + self._callback.on_message(txt) async def _request(self, question): pass @@ -20,10 +33,11 @@ class NLPBase(AsrObserver): pass def completed(self, message: str): - print('complete :', message) + logger.info(f'complete:{message}') self.ask(message) def ask(self, question): + logger.info(f'ask:{question}') self._ask_queue.add_task(self._request(question)) def stop(self): diff --git a/nlp/nlp_callback.py b/nlp/nlp_callback.py new file mode 100644 index 0000000..96619b2 --- /dev/null +++ b/nlp/nlp_callback.py @@ -0,0 +1,9 @@ +#encoding = utf8 + +from abc import ABC, abstractmethod + + +class NLPCallback(ABC): + @abstractmethod + def on_message(self, txt: str): + pass diff --git a/nlp/nlp_doubao.py b/nlp/nlp_doubao.py index 872652d..fc8c525 100644 --- a/nlp/nlp_doubao.py +++ b/nlp/nlp_doubao.py @@ -1,21 +1,17 @@ #encoding = utf8 -import re -import time -from queue import Queue -import volcenginesdkark -import volcenginesdkcore -from volcenginesdkcore.rest import ApiException +import logging +import time from nlp.nlp_base import NLPBase from volcenginesdkarkruntime import AsyncArk -nlp_queue = Queue() +logger = logging.getLogger(__name__) class DouBao(NLPBase): - def __init__(self, split): - super().__init__(split) + def __init__(self, split, callback=None): + super().__init__(split, callback) # Access Key ID # AKLTYTdmOTBmNWFjODkxNDE2Zjk3MjU0NjRhM2JhM2IyN2Y # AKLTNDZjNTdhNDlkZGE3NDZjMDlkMzk5YWQ3MDA4MTY1ZDc @@ -33,6 +29,7 @@ class DouBao(NLPBase): async def _request(self, question): t = time.time() + logger.info(f'_request:{question}') print(f'-------dou_bao ask:', question) stream = await self.__client.chat.completions.create( model="ep-20241008152048-fsgzf", @@ -51,7 +48,10 @@ class DouBao(NLPBase): sec, message = self._split_handle.handle(sec) if len(message) > 0: print(message) + self._on_callback(message) print(sec) + self._on_callback(sec) + logger.info(f'_request:{question}, time:{time.time() - t:.4f}s') print(f'-------dou_bao nlp time:{time.time() - t:.4f}s') ''' diff --git a/nlp/nlp_split.py b/nlp/nlp_split.py index 5b01228..930ff46 100644 --- a/nlp/nlp_split.py +++ b/nlp/nlp_split.py @@ -11,7 +11,7 @@ class NLPSplit(ABC): class PunctuationSplit(NLPSplit): def __init__(self): - self._pattern = r'[,。、;?!,.!?]' + self._pattern = r'(?= self.chunk: + # self._human.put_audio_frame(stream[idx:idx + self.chunk]) + # streamlen -= self.chunk + # idx += self.chunk + # if streamlen>0: #skip last frame(not 20ms) + # self.queue.put(stream[idx:]) + self._byte_stream.seek(0) + self._byte_stream.truncate() + print('-------tts finish push chunk') + except Exception as e: + self._byte_stream.seek(0) + self._byte_stream.truncate() + print('-------tts finish error:', e) + + def __create_bytes_stream(self, byte_stream): + stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64 + print(f'[INFO]tts audio stream {sample_rate}: {stream.shape}') + stream = stream.astype(np.float32) + + if stream.ndim > 1: + print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.') + stream = stream[:, 0] + + if sample_rate != self._sample_rate and stream.shape[0] > 0: + print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self._sample_rate}.') + stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self._sample_rate) + + return stream diff --git a/utils/async_task_queue.py b/utils/async_task_queue.py index 21216ab..fefb958 100644 --- a/utils/async_task_queue.py +++ b/utils/async_task_queue.py @@ -14,19 +14,24 @@ class AsyncTaskQueue: self._thread.start() def _run_loop(self): - asyncio.set_event_loop(self._loop) + print('_run_loop') self._loop_running.set() + asyncio.set_event_loop(self._loop) self._loop.run_forever() async def _worker(self): + print('_worker') while self._loop_running.is_set(): task = await self._queue.get() if task is None: break + print('run task') await task self._queue.task_done() + print('_worker finish') def add_task(self, coro): + print('add_task') asyncio.run_coroutine_threadsafe(self._queue.put(coro), self._loop) def start_worker(self):