modify add event bus

This commit is contained in:
brige 2024-11-06 11:11:53 +08:00
parent 0741df8fde
commit 2f180f0f65
7 changed files with 74 additions and 9 deletions

View File

@ -2,6 +2,7 @@
import threading
from eventbus import EventBus
from .asr_observer import AsrObserver
@ -12,11 +13,19 @@ class AsrBase:
self._samples_per_read = 100
self._observers = []
EventBus().register('stop', self._on_stop)
self._stop_event = threading.Event()
self._stop_event.set()
self._thread = threading.Thread(target=self._recognize_loop)
self._thread.start()
def __del__(self):
EventBus().unregister('stop', self._on_stop)
def _on_stop(self, *args, **kwargs):
self.stop()
def _recognize_loop(self):
pass

4
eventbus/__init__.py Normal file
View File

@ -0,0 +1,4 @@
#encoding = utf8
from .event_bus import EventBus

39
eventbus/event_bus.py Normal file
View File

@ -0,0 +1,39 @@
#encoding = utf8
import threading
class EventBus:
_instance = None
_lock = threading.Lock()
def __new__(cls, *args, **kwargs):
if not cls._instance:
with cls._lock:
if not cls._instance:
cls._instance = super(EventBus, cls).__new__(cls, *args, **kwargs)
return cls._instance
def __init__(self):
if not hasattr(self, '_initialized'):
self._listeners = {}
self._lock = threading.Lock()
self._initialized = True
def register(self, event_type, listener):
with self._lock:
if event_type not in self._listeners:
self._listeners[event_type] = []
self._listeners[event_type].append(listener)
def unregister(self, event_type, listener):
with self._lock:
if event_type in self._listeners:
self._listeners[event_type].remove(listener)
if not self._listeners[event_type]:
del self._listeners[event_type]
def post(self, event_type, *args, **kwargs):
with self._lock:
listeners = self._listeners.get(event_type, []).copy()
for listener in listeners:
listener(*args, **kwargs)

View File

@ -3,7 +3,8 @@ import logging
import queue
import time
from threading import Thread, Event, Condition
from threading import Thread, Event
from eventbus import EventBus
import numpy as np
@ -17,6 +18,8 @@ class AudioMalHandler(AudioHandler):
def __init__(self, context, handler):
super().__init__(context, handler)
EventBus().register('stop', self._on_stop)
self._queue = SyncQueue(context.batch_size, "AudioMalHandler_queue")
self._exit_event = Event()
self._thread = Thread(target=self._on_run, name="AudioMalHandlerThread")
@ -27,6 +30,12 @@ class AudioMalHandler(AudioHandler):
self.chunk = context.sample_rate // context.fps
logger.info("AudioMalHandler init")
def __del__(self):
EventBus().unregister('stop', self._on_stop)
def _on_stop(self, *args, **kwargs):
self.stop()
def on_message(self, message):
super().on_message(message)

View File

@ -3,6 +3,7 @@ import logging
import os
from asr import SherpaNcnnAsr
from eventbus import EventBus
from .audio_inference_handler import AudioInferenceHandler
from .audio_mal_handler import AudioMalHandler
from .human_render import HumanRender
@ -113,8 +114,7 @@ class HumanContext:
self._asr.attach(self._nlp)
def stop(self):
object_stop(self._asr)
object_stop(self._nlp)
EventBus().post('stop')
object_stop(self._tts)
object_stop(self._tts_handle)
object_stop(self._mal_handler)

View File

@ -2,6 +2,7 @@
import logging
from asr import AsrObserver
from eventbus import EventBus
from utils import AsyncTaskQueue
logger = logging.getLogger(__name__)
@ -15,6 +16,14 @@ class NLPBase(AsrObserver):
self._callback = callback
self._is_running = False
EventBus().register('stop', self.onStop)
def __del__(self):
EventBus().unregister('stop', self.onStop)
def onStop(self, *args, **kwargs):
self.stop()
@property
def callback(self):
return self._callback
@ -50,6 +59,7 @@ class NLPBase(AsrObserver):
def stop(self):
logger.info('NLPBase stop')
self._is_running = False
self._ask_queue.clear()
self._ask_queue.add_task(self._on_close)
logger.info('NLPBase add close')
self._ask_queue.stop()

View File

@ -66,9 +66,3 @@ class DouBao(NLPBase):
logger.info('AsyncArk close')
if self.__client is not None and not self.__client.is_closed():
await self.__client.close()
def stop(self):
print('doubao stop00')
self.__client.close()
print('doubao stop11')
super().stop()