From 2f180f0f6533a0fa95b47d7f92f2b5b04724514b Mon Sep 17 00:00:00 2001 From: brige Date: Wed, 6 Nov 2024 11:11:53 +0800 Subject: [PATCH] modify add event bus --- asr/asr_base.py | 9 +++++++++ eventbus/__init__.py | 4 ++++ eventbus/event_bus.py | 39 ++++++++++++++++++++++++++++++++++++++ human/audio_mal_handler.py | 11 ++++++++++- human/human_context.py | 4 ++-- nlp/nlp_base.py | 10 ++++++++++ nlp/nlp_doubao.py | 6 ------ 7 files changed, 74 insertions(+), 9 deletions(-) create mode 100644 eventbus/__init__.py create mode 100644 eventbus/event_bus.py diff --git a/asr/asr_base.py b/asr/asr_base.py index 0528093..6a1b904 100644 --- a/asr/asr_base.py +++ b/asr/asr_base.py @@ -2,6 +2,7 @@ import threading +from eventbus import EventBus from .asr_observer import AsrObserver @@ -12,11 +13,19 @@ class AsrBase: self._samples_per_read = 100 self._observers = [] + EventBus().register('stop', self._on_stop) + self._stop_event = threading.Event() self._stop_event.set() self._thread = threading.Thread(target=self._recognize_loop) self._thread.start() + def __del__(self): + EventBus().unregister('stop', self._on_stop) + + def _on_stop(self, *args, **kwargs): + self.stop() + def _recognize_loop(self): pass diff --git a/eventbus/__init__.py b/eventbus/__init__.py new file mode 100644 index 0000000..0d26045 --- /dev/null +++ b/eventbus/__init__.py @@ -0,0 +1,4 @@ +#encoding = utf8 + +from .event_bus import EventBus + diff --git a/eventbus/event_bus.py b/eventbus/event_bus.py new file mode 100644 index 0000000..6a02683 --- /dev/null +++ b/eventbus/event_bus.py @@ -0,0 +1,39 @@ +#encoding = utf8 +import threading + + +class EventBus: + _instance = None + _lock = threading.Lock() + + def __new__(cls, *args, **kwargs): + if not cls._instance: + with cls._lock: + if not cls._instance: + cls._instance = super(EventBus, cls).__new__(cls, *args, **kwargs) + return cls._instance + + def __init__(self): + if not hasattr(self, '_initialized'): + self._listeners = {} + self._lock = threading.Lock() + self._initialized = True + + def register(self, event_type, listener): + with self._lock: + if event_type not in self._listeners: + self._listeners[event_type] = [] + self._listeners[event_type].append(listener) + + def unregister(self, event_type, listener): + with self._lock: + if event_type in self._listeners: + self._listeners[event_type].remove(listener) + if not self._listeners[event_type]: + del self._listeners[event_type] + + def post(self, event_type, *args, **kwargs): + with self._lock: + listeners = self._listeners.get(event_type, []).copy() + for listener in listeners: + listener(*args, **kwargs) diff --git a/human/audio_mal_handler.py b/human/audio_mal_handler.py index 796ad25..a880b55 100644 --- a/human/audio_mal_handler.py +++ b/human/audio_mal_handler.py @@ -3,7 +3,8 @@ import logging import queue import time -from threading import Thread, Event, Condition +from threading import Thread, Event +from eventbus import EventBus import numpy as np @@ -17,6 +18,8 @@ class AudioMalHandler(AudioHandler): def __init__(self, context, handler): super().__init__(context, handler) + EventBus().register('stop', self._on_stop) + self._queue = SyncQueue(context.batch_size, "AudioMalHandler_queue") self._exit_event = Event() self._thread = Thread(target=self._on_run, name="AudioMalHandlerThread") @@ -27,6 +30,12 @@ class AudioMalHandler(AudioHandler): self.chunk = context.sample_rate // context.fps logger.info("AudioMalHandler init") + def __del__(self): + EventBus().unregister('stop', self._on_stop) + + def _on_stop(self, *args, **kwargs): + self.stop() + def on_message(self, message): super().on_message(message) diff --git a/human/human_context.py b/human/human_context.py index cbefb1f..2c96cde 100644 --- a/human/human_context.py +++ b/human/human_context.py @@ -3,6 +3,7 @@ import logging import os from asr import SherpaNcnnAsr +from eventbus import EventBus from .audio_inference_handler import AudioInferenceHandler from .audio_mal_handler import AudioMalHandler from .human_render import HumanRender @@ -113,8 +114,7 @@ class HumanContext: self._asr.attach(self._nlp) def stop(self): - object_stop(self._asr) - object_stop(self._nlp) + EventBus().post('stop') object_stop(self._tts) object_stop(self._tts_handle) object_stop(self._mal_handler) diff --git a/nlp/nlp_base.py b/nlp/nlp_base.py index 13bb4e7..3dbd4a6 100644 --- a/nlp/nlp_base.py +++ b/nlp/nlp_base.py @@ -2,6 +2,7 @@ import logging from asr import AsrObserver +from eventbus import EventBus from utils import AsyncTaskQueue logger = logging.getLogger(__name__) @@ -15,6 +16,14 @@ class NLPBase(AsrObserver): self._callback = callback self._is_running = False + EventBus().register('stop', self.onStop) + + def __del__(self): + EventBus().unregister('stop', self.onStop) + + def onStop(self, *args, **kwargs): + self.stop() + @property def callback(self): return self._callback @@ -50,6 +59,7 @@ class NLPBase(AsrObserver): def stop(self): logger.info('NLPBase stop') self._is_running = False + self._ask_queue.clear() self._ask_queue.add_task(self._on_close) logger.info('NLPBase add close') self._ask_queue.stop() diff --git a/nlp/nlp_doubao.py b/nlp/nlp_doubao.py index 794bdb5..d1b41da 100644 --- a/nlp/nlp_doubao.py +++ b/nlp/nlp_doubao.py @@ -66,9 +66,3 @@ class DouBao(NLPBase): logger.info('AsyncArk close') if self.__client is not None and not self.__client.is_closed(): await self.__client.close() - - def stop(self): - print('doubao stop00') - self.__client.close() - print('doubao stop11') - super().stop()