From 5d98af51de4e5418ae79b4491e6dadfe419c5a2e Mon Sep 17 00:00:00 2001 From: jiegeaiai Date: Mon, 14 Oct 2024 18:20:55 +0800 Subject: [PATCH] modify async task --- nlp/nlp_base.py | 7 +++-- nlp/nlp_doubao.py | 44 +++++++++++++++------------ test/test_nlp_only.py | 5 +-- test/test_nlp_tts.py | 44 +++++++++++++++++++++++++++ test/test_tts_only.py | 9 ++---- tts/tts_audio_handle.py | 15 +++++---- tts/tts_base.py | 23 +++++++++----- tts/tts_edge.py | 31 ++++++++++++------- utils/async_task_queue.py | 64 +++++++++++++++++++++------------------ 9 files changed, 159 insertions(+), 83 deletions(-) create mode 100644 test/test_nlp_tts.py diff --git a/nlp/nlp_base.py b/nlp/nlp_base.py index 9f3410a..b4fd8e5 100644 --- a/nlp/nlp_base.py +++ b/nlp/nlp_base.py @@ -10,7 +10,6 @@ logger = logging.getLogger(__name__) class NLPBase(AsrObserver): def __init__(self, split, callback=None): self._ask_queue = AsyncTaskQueue() - self._ask_queue.start_worker() self._split_handle = split self._callback = callback @@ -29,6 +28,9 @@ class NLPBase(AsrObserver): async def _request(self, question): pass + async def _on_close(self): + pass + def process(self, message: str): pass @@ -38,8 +40,9 @@ class NLPBase(AsrObserver): def ask(self, question): logger.info(f'ask:{question}') - self._ask_queue.add_task(self._request(question)) + self._ask_queue.add_task(self._request, question) def stop(self): + self._ask_queue.add_task(self._on_close) self._ask_queue.stop() \ No newline at end of file diff --git a/nlp/nlp_doubao.py b/nlp/nlp_doubao.py index fc8c525..ab7c9ad 100644 --- a/nlp/nlp_doubao.py +++ b/nlp/nlp_doubao.py @@ -31,29 +31,33 @@ class DouBao(NLPBase): t = time.time() logger.info(f'_request:{question}') print(f'-------dou_bao ask:', question) - stream = await self.__client.chat.completions.create( - model="ep-20241008152048-fsgzf", - messages=[ - {"role": "system", "content": "你是测试客服,是由字节跳动开发的 AI 人工智能助手"}, - {"role": "user", "content": question}, - ], - stream=True - ) - sec = '' - async for completion in stream: - # print(f'-------dou_bao nlp time:{time.time() - t:.4f}s') - # nlp_queue.put(completion.choices[0].delta.content) - # print(completion.choices[0].delta.content, end="") - sec = sec + completion.choices[0].delta.content - sec, message = self._split_handle.handle(sec) - if len(message) > 0: - print(message) - self._on_callback(message) - print(sec) - self._on_callback(sec) + try: + stream = await self.__client.chat.completions.create( + model="ep-20241008152048-fsgzf", + messages=[ + {"role": "system", "content": "你是测试客服,是由字节跳动开发的 AI 人工智能助手"}, + {"role": "user", "content": question}, + ], + stream=True + ) + sec = '' + async for completion in stream: + sec = sec + completion.choices[0].delta.content + sec, message = self._split_handle.handle(sec) + if len(message) > 0: + self._on_callback(message) + self._on_callback(sec) + await stream.close() + except Exception as e: + print(e) logger.info(f'_request:{question}, time:{time.time() - t:.4f}s') print(f'-------dou_bao nlp time:{time.time() - t:.4f}s') + async def _on_close(self): + print('AsyncArk close') + if self.__client is not None and not self.__client.is_closed(): + await self.__client.close() + ''' if __name__ == "__main__": # print(get_dou_bao_api()) diff --git a/test/test_nlp_only.py b/test/test_nlp_only.py index e07a7e4..9de74fb 100644 --- a/test/test_nlp_only.py +++ b/test/test_nlp_only.py @@ -12,9 +12,8 @@ def main(): nlp.ask('你好') nlp.ask('你是谁') nlp.ask('能做什么') - time.sleep(20) - print("Stop! ") nlp.stop() + print("stop") if __name__ == "__main__": @@ -22,3 +21,5 @@ if __name__ == "__main__": main() except KeyboardInterrupt: print("\nCaught Ctrl + C. Exiting") + except Exception as e: + print(e) diff --git a/test/test_nlp_tts.py b/test/test_nlp_tts.py new file mode 100644 index 0000000..85888a5 --- /dev/null +++ b/test/test_nlp_tts.py @@ -0,0 +1,44 @@ +#encoding = utf8 + +import sys +import time + +from asr import SherpaNcnnAsr +from nlp import PunctuationSplit +from nlp.nlp_doubao import DouBao +from tts import TTSEdge, TTSAudioSaveHandle + +try: + import sounddevice as sd +except ImportError as e: + print("Please install sounddevice first. You can use") + print() + print(" pip install sounddevice") + print() + print("to install it") + sys.exit(-1) + + +def main(): + print("Started! Please speak") + handle = TTSAudioSaveHandle() + tts = TTSEdge(handle) + split = PunctuationSplit() + nlp = DouBao(split, tts) + nlp.ask('你好,你是谁?') + nlp.ask('可以帮我做什么?') + nlp.stop() + tts.stop() + print("Stop! ") + + +if __name__ == "__main__": + devices = sd.query_devices() + print(devices) + default_input_device_idx = sd.default.device[0] + print(f'Use default device: {devices[default_input_device_idx]["name"]}') + + try: + main() + except KeyboardInterrupt: + print("\nCaught Ctrl + C. Exiting") diff --git a/test/test_tts_only.py b/test/test_tts_only.py index a5bfe4a..b6bf2da 100644 --- a/test/test_tts_only.py +++ b/test/test_tts_only.py @@ -2,20 +2,17 @@ import time -from tts import TTSEdge +from tts import TTSEdge, TTSAudioSaveHandle def main(): - print("Started! Please speak") - - tts = TTSEdge() + handle = TTSAudioSaveHandle() + tts = TTSEdge(handle) tts.message('你好,') tts.message('请问有什么可以帮到您,') tts.message('很高兴为您服务。') tts.message('祝您平安,') tts.message('再见') - - time.sleep(20) tts.stop() print("Stop! ") diff --git a/tts/tts_audio_handle.py b/tts/tts_audio_handle.py index b04016a..8c1d39d 100644 --- a/tts/tts_audio_handle.py +++ b/tts/tts_audio_handle.py @@ -9,6 +9,7 @@ from audio import save_wav class TTSAudioHandle(ABC): def __init__(self): self._sample_rate = 16000 + self._index = 1 @property def sample_rate(self): @@ -19,9 +20,13 @@ class TTSAudioHandle(ABC): self._sample_rate = value @abstractmethod - def on_handle(self, stream): + def on_handle(self, stream, index): pass + def get_index(self): + self._index = self._index + 1 + return self._index + class TTSAudioSplitHandle(TTSAudioHandle): def __init__(self, human): @@ -30,7 +35,7 @@ class TTSAudioSplitHandle(TTSAudioHandle): self.sample_rate = self._human.get_audio_sample_rate() self._chunk = self.sample_rate // self._human.get_fps() - def on_handle(self, stream): + def on_handle(self, stream, index): stream_len = stream.shape[0] idx = 0 @@ -43,7 +48,6 @@ class TTSAudioSplitHandle(TTSAudioHandle): class TTSAudioSaveHandle(TTSAudioHandle): def __init__(self): super().__init__() - self._count = 1 self._save_path_dir = '../temp/audio/' self._clean() @@ -65,8 +69,7 @@ class TTSAudioSaveHandle(TTSAudioHandle): shutil.rmtree(file_path) print(f"Deleted directory and its contents: {file_path}") - def on_handle(self, stream): - file_name = self._save_path_dir + str(self._count) + '.wav' + def on_handle(self, stream, index): + file_name = self._save_path_dir + str(index) + '.wav' save_wav(stream, file_name, self.sample_rate) - self._count = self._count + 1 diff --git a/tts/tts_base.py b/tts/tts_base.py index fd1d3e8..067e2fe 100644 --- a/tts/tts_base.py +++ b/tts/tts_base.py @@ -12,8 +12,7 @@ logger = logging.getLogger(__name__) class TTSBase(NLPCallback): def __init__(self, handle): self._handle = handle - self._message_queue = AsyncTaskQueue() - self._message_queue.start_worker() + self._message_queue = AsyncTaskQueue(5) @property def handle(self): @@ -23,17 +22,23 @@ class TTSBase(NLPCallback): def handle(self, value): self._handle = value - async def _request(self, txt: str): + async def _request(self, txt: str, index): print('_request:', txt) t = time.time() - await self._on_request(txt) + stream = await self._on_request(txt) + if stream is None: + print(f'-------stream is None') + return print(f'-------tts time:{time.time() - t:.4f}s') - await self._on_handle() + await self._on_handle(stream, index) async def _on_request(self, text: str): pass - async def _on_handle(self): + async def _on_handle(self, stream, index): + pass + + async def _on_close(self): pass def on_message(self, txt: str): @@ -42,7 +47,11 @@ class TTSBase(NLPCallback): def message(self, txt): logger.info(f'message:{txt}') print(f'message:{txt}') - self._message_queue.add_task(self._request(txt)) + index = 0 + if self._handle is not None: + index = self._handle.get_index() + self._message_queue.add_task(self._request, txt, index) def stop(self): + self._message_queue.add_task(self._on_close) self._message_queue.stop() diff --git a/tts/tts_edge.py b/tts/tts_edge.py index 3778b31..0926c25 100644 --- a/tts/tts_edge.py +++ b/tts/tts_edge.py @@ -14,34 +14,38 @@ class TTSEdge(TTSBase): def __init__(self, handle, voice='zh-CN-XiaoyiNeural'): super().__init__(handle) self._voice = voice - self._byte_stream = BytesIO() async def _on_request(self, txt: str): + print('_on_request, txt') communicate = edge_tts.Communicate(txt, self._voice) first = True + byte_stream = BytesIO() async for chunk in communicate.stream(): if first: first = False if chunk["type"] == "audio": - self._byte_stream.write(chunk["data"]) + byte_stream.write(chunk["data"]) elif chunk["type"] == "WordBoundary": pass + await communicate.stream().aclose() + return byte_stream - async def _on_handle(self): - self._byte_stream.seek(0) + async def _on_handle(self, stream, index): + stream.seek(0) try: - self._byte_stream.seek(0) - stream = self.__create_bytes_stream(self._byte_stream) + stream.seek(0) + byte_stream = self.__create_bytes_stream(stream) print('-------tts start push chunk') - self._handle.on_handle(stream) - self._byte_stream.seek(0) - self._byte_stream.truncate() + self._handle.on_handle(byte_stream, index) + stream.seek(0) + stream.truncate() print('-------tts finish push chunk') except Exception as e: - self._byte_stream.seek(0) - self._byte_stream.truncate() + stream.seek(0) + stream.truncate() print('-------tts finish error:', e) + stream.close() def __create_bytes_stream(self, byte_stream): stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64 @@ -57,3 +61,8 @@ class TTSEdge(TTSBase): stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self._handle.sample_rate) return stream + + async def _on_close(self): + print('TTSEdge close') + # if self._byte_stream is not None and not self._byte_stream.closed: + # self._byte_stream.close() diff --git a/utils/async_task_queue.py b/utils/async_task_queue.py index 0dfa505..e2b7b4c 100644 --- a/utils/async_task_queue.py +++ b/utils/async_task_queue.py @@ -5,48 +5,54 @@ import threading class AsyncTaskQueue: - def __init__(self): + def __init__(self, work_num=1): self._queue = asyncio.Queue() - self._loop = asyncio.new_event_loop() + self._worker_num = work_num + self._current_worker_num = work_num + self._condition = threading.Condition() self._thread = threading.Thread(target=self._run_loop) - self._worker_task = None - # self._loop_running = threading.Event() - # self._loop_running.set() self._thread.start() + self.__loop = None def _run_loop(self): print('_run_loop') - asyncio.set_event_loop(self._loop) - self._loop.run_forever() + self.__loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.__loop) + self._tasks = [self.__loop.create_task(self._worker()) for _ in range(self._worker_num)] + self.__loop.run_forever() + print("exit run") + if not self.__loop.is_closed(): + self.__loop.close() async def _worker(self): print('_worker') while True: - print('_worker1') - task = await self._queue.get() - print('_worker2') - if task is None: - break - print('run task') - await task - self._queue.task_done() + with self._condition: + self._condition.wait_for(lambda: not self._queue.empty()) + task = await self._queue.get() + func, *args = task # 解包任务 + if func is None: # None 作为结束信号 + break + + print(f"Executing task with args: {args}") + await func(*args) # 执行异步函数 + self._queue.task_done() + print('_worker finish') + self._current_worker_num = self._current_worker_num - 1 + if self._current_worker_num == 0: + print('loop stop') + self.__loop.stop() - def add_task(self, coro): - print('add_task') - asyncio.run_coroutine_threadsafe(self._queue.put(coro), self._loop) - print('add_task1') + def add_task(self, func, *args): + with self._condition: + self._queue.put_nowait((func, *args)) + self._condition.notify() - def start_worker(self): - print('start_worker') - if not self._worker_task: - self._worker_task = asyncio.run_coroutine_threadsafe(self._worker(), self._loop) + def stop_workers(self): + for _ in range(self._worker_num): + self.add_task(None) # 发送结束信号 def stop(self): - # self._loop_running.clear() - asyncio.run_coroutine_threadsafe(self._queue.put(None), self._loop).result() - if self._worker_task: - self._worker_task.result() - self._loop.call_soon_threadsafe(self._loop.stop) + self.stop_workers() self._thread.join() - self._loop.close()