modify add event bus

This commit is contained in:
brige 2024-11-06 11:11:53 +08:00
parent 0741df8fde
commit 2f180f0f65
7 changed files with 74 additions and 9 deletions

View File

@ -2,6 +2,7 @@
import threading import threading
from eventbus import EventBus
from .asr_observer import AsrObserver from .asr_observer import AsrObserver
@ -12,11 +13,19 @@ class AsrBase:
self._samples_per_read = 100 self._samples_per_read = 100
self._observers = [] self._observers = []
EventBus().register('stop', self._on_stop)
self._stop_event = threading.Event() self._stop_event = threading.Event()
self._stop_event.set() self._stop_event.set()
self._thread = threading.Thread(target=self._recognize_loop) self._thread = threading.Thread(target=self._recognize_loop)
self._thread.start() self._thread.start()
def __del__(self):
EventBus().unregister('stop', self._on_stop)
def _on_stop(self, *args, **kwargs):
self.stop()
def _recognize_loop(self): def _recognize_loop(self):
pass pass

4
eventbus/__init__.py Normal file
View File

@ -0,0 +1,4 @@
#encoding = utf8
from .event_bus import EventBus

39
eventbus/event_bus.py Normal file
View File

@ -0,0 +1,39 @@
#encoding = utf8
import threading
class EventBus:
_instance = None
_lock = threading.Lock()
def __new__(cls, *args, **kwargs):
if not cls._instance:
with cls._lock:
if not cls._instance:
cls._instance = super(EventBus, cls).__new__(cls, *args, **kwargs)
return cls._instance
def __init__(self):
if not hasattr(self, '_initialized'):
self._listeners = {}
self._lock = threading.Lock()
self._initialized = True
def register(self, event_type, listener):
with self._lock:
if event_type not in self._listeners:
self._listeners[event_type] = []
self._listeners[event_type].append(listener)
def unregister(self, event_type, listener):
with self._lock:
if event_type in self._listeners:
self._listeners[event_type].remove(listener)
if not self._listeners[event_type]:
del self._listeners[event_type]
def post(self, event_type, *args, **kwargs):
with self._lock:
listeners = self._listeners.get(event_type, []).copy()
for listener in listeners:
listener(*args, **kwargs)

View File

@ -3,7 +3,8 @@ import logging
import queue import queue
import time import time
from threading import Thread, Event, Condition from threading import Thread, Event
from eventbus import EventBus
import numpy as np import numpy as np
@ -17,6 +18,8 @@ class AudioMalHandler(AudioHandler):
def __init__(self, context, handler): def __init__(self, context, handler):
super().__init__(context, handler) super().__init__(context, handler)
EventBus().register('stop', self._on_stop)
self._queue = SyncQueue(context.batch_size, "AudioMalHandler_queue") self._queue = SyncQueue(context.batch_size, "AudioMalHandler_queue")
self._exit_event = Event() self._exit_event = Event()
self._thread = Thread(target=self._on_run, name="AudioMalHandlerThread") self._thread = Thread(target=self._on_run, name="AudioMalHandlerThread")
@ -27,6 +30,12 @@ class AudioMalHandler(AudioHandler):
self.chunk = context.sample_rate // context.fps self.chunk = context.sample_rate // context.fps
logger.info("AudioMalHandler init") logger.info("AudioMalHandler init")
def __del__(self):
EventBus().unregister('stop', self._on_stop)
def _on_stop(self, *args, **kwargs):
self.stop()
def on_message(self, message): def on_message(self, message):
super().on_message(message) super().on_message(message)

View File

@ -3,6 +3,7 @@ import logging
import os import os
from asr import SherpaNcnnAsr from asr import SherpaNcnnAsr
from eventbus import EventBus
from .audio_inference_handler import AudioInferenceHandler from .audio_inference_handler import AudioInferenceHandler
from .audio_mal_handler import AudioMalHandler from .audio_mal_handler import AudioMalHandler
from .human_render import HumanRender from .human_render import HumanRender
@ -113,8 +114,7 @@ class HumanContext:
self._asr.attach(self._nlp) self._asr.attach(self._nlp)
def stop(self): def stop(self):
object_stop(self._asr) EventBus().post('stop')
object_stop(self._nlp)
object_stop(self._tts) object_stop(self._tts)
object_stop(self._tts_handle) object_stop(self._tts_handle)
object_stop(self._mal_handler) object_stop(self._mal_handler)

View File

@ -2,6 +2,7 @@
import logging import logging
from asr import AsrObserver from asr import AsrObserver
from eventbus import EventBus
from utils import AsyncTaskQueue from utils import AsyncTaskQueue
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -15,6 +16,14 @@ class NLPBase(AsrObserver):
self._callback = callback self._callback = callback
self._is_running = False self._is_running = False
EventBus().register('stop', self.onStop)
def __del__(self):
EventBus().unregister('stop', self.onStop)
def onStop(self, *args, **kwargs):
self.stop()
@property @property
def callback(self): def callback(self):
return self._callback return self._callback
@ -50,6 +59,7 @@ class NLPBase(AsrObserver):
def stop(self): def stop(self):
logger.info('NLPBase stop') logger.info('NLPBase stop')
self._is_running = False self._is_running = False
self._ask_queue.clear()
self._ask_queue.add_task(self._on_close) self._ask_queue.add_task(self._on_close)
logger.info('NLPBase add close') logger.info('NLPBase add close')
self._ask_queue.stop() self._ask_queue.stop()

View File

@ -66,9 +66,3 @@ class DouBao(NLPBase):
logger.info('AsyncArk close') logger.info('AsyncArk close')
if self.__client is not None and not self.__client.is_closed(): if self.__client is not None and not self.__client.is_closed():
await self.__client.close() await self.__client.close()
def stop(self):
print('doubao stop00')
self.__client.close()
print('doubao stop11')
super().stop()