From 4c9ec6831f3ad88fd8350e7581ff136b3feecffe Mon Sep 17 00:00:00 2001 From: brige Date: Thu, 7 Nov 2024 20:43:46 +0800 Subject: [PATCH] modify async --- asr/sherpa_ncnn_asr.py | 2 +- human/audio_inference_handler.py | 17 +++-------- human/human_context.py | 5 --- human/human_render.py | 16 +--------- nlp/nlp_base.py | 4 ++- nlp/nlp_doubao.py | 7 ++--- tts/tts_base.py | 13 ++++++-- utils/async_task_queue.py | 52 ++++++++++++++++++-------------- 8 files changed, 53 insertions(+), 63 deletions(-) diff --git a/asr/sherpa_ncnn_asr.py b/asr/sherpa_ncnn_asr.py index eca2459..5c7a25f 100644 --- a/asr/sherpa_ncnn_asr.py +++ b/asr/sherpa_ncnn_asr.py @@ -65,7 +65,7 @@ class SherpaNcnnAsr(AsrBase): self._notify_complete('介绍中国5000年历史文学') logger.info(f'_recognize_loop111') segment_id += 1 - time.sleep(15) + time.sleep(60) logger.info(f'_recognize_loop222') logger.info(f'_recognize_loop exit') ''' diff --git a/human/audio_inference_handler.py b/human/audio_inference_handler.py index abfe34d..f24bce4 100644 --- a/human/audio_inference_handler.py +++ b/human/audio_inference_handler.py @@ -70,17 +70,13 @@ class AudioInferenceHandler(AudioHandler): logger.info(f'use device:{device}') while self._is_running: - print('AudioInferenceHandler mel_batch:000') if self._exit_event.is_set(): start_time = time.perf_counter() batch_size = self._context.batch_size try: - print('AudioInferenceHandler mel_batch:') - mel_batch = self._mal_queue.get(timeout=0.03) - print('AudioInferenceHandler mel_batch:111') - # print('AudioInferenceHandler mel_batch:', len(mel_batch), 'size:', size) + mel_batch = self._mal_queue.get(timeout=0.02) + print('AudioInferenceHandler mel_batch:', len(mel_batch), 'size:', self._mal_queue.size()) except queue.Empty: - print('AudioInferenceHandler mel_batch:111') continue # print('origin mel_batch:', len(mel_batch)) @@ -100,13 +96,11 @@ class AudioInferenceHandler(AudioHandler): if is_all_silence: for i in range(batch_size): if not self._is_running: - print('AudioInferenceHandler not running1111') break - print('AudioInferenceHandler is_all_silence 111') self.on_next_handle((None, mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]), 0) - print('AudioInferenceHandler is_all_silence 222') index = index + 1 + print('AudioInferenceHandler all silence') else: logger.info('infer=======') t = time.perf_counter() @@ -152,7 +146,6 @@ class AudioInferenceHandler(AudioHandler): index = index + 1 logger.info(f'total batch time: {time.perf_counter() - start_time}') else: - print('AudioInferenceHandler mel_batch:333') time.sleep(1) break logger.info('AudioInferenceHandler inference processor stop') @@ -162,9 +155,9 @@ class AudioInferenceHandler(AudioHandler): self._is_running = False self._exit_event.clear() if self._run_thread.is_alive(): - print('AudioInferenceHandler stop join') + logger.info('AudioInferenceHandler stop join') self._run_thread.join() - print('AudioInferenceHandler stop exit') + logger.info('AudioInferenceHandler stop exit') def pause_talk(self): print('AudioInferenceHandler pause_talk', self._audio_queue.size(), self._mal_queue.size()) diff --git a/human/human_context.py b/human/human_context.py index 2c96cde..cca030d 100644 --- a/human/human_context.py +++ b/human/human_context.py @@ -115,11 +115,6 @@ class HumanContext: def stop(self): EventBus().post('stop') - object_stop(self._tts) - object_stop(self._tts_handle) - object_stop(self._mal_handler) - object_stop(self._infer_handler) - object_stop(self._render_handler) def pause_talk(self): self._nlp.pause_talk() diff --git a/human/human_render.py b/human/human_render.py index 8c1851e..a627ff7 100644 --- a/human/human_render.py +++ b/human/human_render.py @@ -53,6 +53,7 @@ class HumanRender(AudioHandler): if value is None: return res_frame, idx, audio_frames = value + print('voice render queue size', self._queue.size()) if not self._empty_log: self._empty_log = True logging.info('render render:') @@ -84,25 +85,10 @@ class HumanRender(AudioHandler): super().on_message(message) def on_handle(self, stream, index): - print('human render:', self._is_running) if not self._is_running: return self._queue.put(stream) - # res_frame, idx, audio_frames = stream - # self._voice_render.put(audio_frames, self._last_audio_ps) - # self._last_audio_ps = self._last_audio_ps + 0.4 - # type_ = 1 - # if audio_frames[0][1] != 0 and audio_frames[1][1] != 0: - # type_ = 0 - # self._video_render.put((res_frame, idx, type_), self._last_video_ps) - # self._last_video_ps = self._last_video_ps + 0.4 - # - # if self._voice_render.is_full(): - # self._context.notify({'msg_id': MessageType.Video_Render_Queue_Full}) - - # def get_audio_queue_size(self): - # return self._voice_render.size() def pause_talk(self): logging.info('hunan pause_talk') diff --git a/nlp/nlp_base.py b/nlp/nlp_base.py index 7e8f458..bf11ab5 100644 --- a/nlp/nlp_base.py +++ b/nlp/nlp_base.py @@ -14,7 +14,7 @@ class NLPBase(AsrObserver): self._context = context self._split_handle = split self._callback = callback - self._is_running = False + self._is_running = True EventBus().register('stop', self.on_stop) @@ -46,6 +46,8 @@ class NLPBase(AsrObserver): pass def completed(self, message: str): + if not self._is_running: + return logger.info(f'complete:{message}') # self._context.pause_talk() self.ask(message) diff --git a/nlp/nlp_doubao.py b/nlp/nlp_doubao.py index 7bbc637..6935c10 100644 --- a/nlp/nlp_doubao.py +++ b/nlp/nlp_doubao.py @@ -40,17 +40,16 @@ class DouBao(NLPBase): ], stream=True ) - t1 = time.time() - await stream.close() - logger.info(f'-------dou_bao close time:{time.time() - t1:.4f}s') - return + sec = '' async for completion in stream: sec = sec + completion.choices[0].delta.content sec, message = self._split_handle.handle(sec) if len(message) > 0: + logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s') self._on_callback(message) self._on_callback(sec) + logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s') await stream.close() # sec = "你是测试客服,是由字节跳动开发的 AI 人工智能助手" diff --git a/tts/tts_base.py b/tts/tts_base.py index 7fc1ba2..940e399 100644 --- a/tts/tts_base.py +++ b/tts/tts_base.py @@ -3,6 +3,7 @@ import logging import time +from eventbus import EventBus from nlp import NLPCallback from utils import AsyncTaskQueue @@ -13,7 +14,14 @@ class TTSBase(NLPCallback): def __init__(self, handle): self._handle = handle self._message_queue = AsyncTaskQueue('TTSBaseQueue', 5) - self._is_running = False + self._is_running = True + EventBus().register('stop', self.on_stop) + + def __del__(self): + EventBus().unregister('stop', self.on_stop) + + def on_stop(self, *args, **kwargs): + self.stop() @property def handle(self): @@ -49,7 +57,6 @@ class TTSBase(NLPCallback): self.message(txt) def message(self, txt): - self._is_running = True txt = txt.strip() if len(txt) == 0: logger.info(f'message is empty') @@ -62,10 +69,10 @@ class TTSBase(NLPCallback): self._message_queue.add_task(self._request, txt, index) def stop(self): + self._is_running = False self._message_queue.add_task(self._on_close) self._message_queue.stop() def pause_talk(self): logger.info(f'TTSBase pause_talk') - self._is_running = False self._message_queue.clear() diff --git a/utils/async_task_queue.py b/utils/async_task_queue.py index b7a32bb..f95317b 100644 --- a/utils/async_task_queue.py +++ b/utils/async_task_queue.py @@ -1,12 +1,16 @@ #encoding = utf8 import asyncio +import logging +from queue import Queue import threading +logger = logging.getLogger(__name__) + class AsyncTaskQueue: def __init__(self, name, work_num=1): - self._queue = asyncio.Queue() + self._queue = Queue() self._worker_num = work_num self._current_worker_num = work_num self._name = name @@ -15,39 +19,43 @@ class AsyncTaskQueue: self.__loop = None def _run_loop(self): - print(self._name, '_run_loop') + logging.info(f'{self._name}, _run_loop') self.__loop = asyncio.new_event_loop() asyncio.set_event_loop(self.__loop) self._tasks = [self.__loop.create_task(self._worker()) for _ in range(self._worker_num)] - self.__loop.run_forever() - print(self._name, "exit run") - if not self.__loop.is_closed(): + try: + self.__loop.run_forever() + finally: + logging.info(f'{self._name}, exit run') + self.__loop.run_until_complete(asyncio.gather(*asyncio.all_tasks(self.__loop))) self.__loop.close() + logging.info(f'{self._name}, close loop') async def _worker(self): - print(self._name, '_worker') + logging.info(f'{self._name}, _worker') while True: - print(f'{self._name} get queue') - task = await self._queue.get() - print(f'{self._name} get queue11') - print(f"{self._name} Get task size: {self._queue.qsize()}") - if task is None: # None as a stop signal - break + try: + task = self._queue.get() + if task is None: # None as a stop signal + break - func, *args = task # Unpack task - print(f"{self._name}, Executing task with args: {args}") - await func(*args) # Execute async function - self._queue.task_done() + func, *args = task # Unpack task + if func is None: # None as a stop signal + break - print(self._name, '_worker finish') + await func(*args) # Execute async function + except Exception as e: + logging.error(f'{self._name} error:', e) + finally: + self._queue.task_done() + + logging.info(f'{self._name}, _worker finish') self._current_worker_num -= 1 if self._current_worker_num == 0: - print(self._name, 'loop stop') - self.__loop.stop() + self.__loop.call_soon_threadsafe(self.__loop.stop) def add_task(self, func, *args): - # return self.__loop.call_soon_threadsafe(self._queue.put_nowait, (func, *args)) - self._queue.put_nowait((func, *args)) + self._queue.put((func, *args)) def stop_workers(self): for _ in range(self._worker_num): @@ -60,4 +68,4 @@ class AsyncTaskQueue: def stop(self): self.stop_workers() - self._thread.join() + self._thread.join() \ No newline at end of file