modify add event bus
This commit is contained in:
parent
0741df8fde
commit
2f180f0f65
@ -2,6 +2,7 @@
|
||||
|
||||
import threading
|
||||
|
||||
from eventbus import EventBus
|
||||
from .asr_observer import AsrObserver
|
||||
|
||||
|
||||
@ -12,11 +13,19 @@ class AsrBase:
|
||||
self._samples_per_read = 100
|
||||
self._observers = []
|
||||
|
||||
EventBus().register('stop', self._on_stop)
|
||||
|
||||
self._stop_event = threading.Event()
|
||||
self._stop_event.set()
|
||||
self._thread = threading.Thread(target=self._recognize_loop)
|
||||
self._thread.start()
|
||||
|
||||
def __del__(self):
|
||||
EventBus().unregister('stop', self._on_stop)
|
||||
|
||||
def _on_stop(self, *args, **kwargs):
|
||||
self.stop()
|
||||
|
||||
def _recognize_loop(self):
|
||||
pass
|
||||
|
||||
|
4
eventbus/__init__.py
Normal file
4
eventbus/__init__.py
Normal file
@ -0,0 +1,4 @@
|
||||
#encoding = utf8
|
||||
|
||||
from .event_bus import EventBus
|
||||
|
39
eventbus/event_bus.py
Normal file
39
eventbus/event_bus.py
Normal file
@ -0,0 +1,39 @@
|
||||
#encoding = utf8
|
||||
import threading
|
||||
|
||||
|
||||
class EventBus:
|
||||
_instance = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if not cls._instance:
|
||||
with cls._lock:
|
||||
if not cls._instance:
|
||||
cls._instance = super(EventBus, cls).__new__(cls, *args, **kwargs)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(self, '_initialized'):
|
||||
self._listeners = {}
|
||||
self._lock = threading.Lock()
|
||||
self._initialized = True
|
||||
|
||||
def register(self, event_type, listener):
|
||||
with self._lock:
|
||||
if event_type not in self._listeners:
|
||||
self._listeners[event_type] = []
|
||||
self._listeners[event_type].append(listener)
|
||||
|
||||
def unregister(self, event_type, listener):
|
||||
with self._lock:
|
||||
if event_type in self._listeners:
|
||||
self._listeners[event_type].remove(listener)
|
||||
if not self._listeners[event_type]:
|
||||
del self._listeners[event_type]
|
||||
|
||||
def post(self, event_type, *args, **kwargs):
|
||||
with self._lock:
|
||||
listeners = self._listeners.get(event_type, []).copy()
|
||||
for listener in listeners:
|
||||
listener(*args, **kwargs)
|
@ -3,7 +3,8 @@ import logging
|
||||
import queue
|
||||
import time
|
||||
|
||||
from threading import Thread, Event, Condition
|
||||
from threading import Thread, Event
|
||||
from eventbus import EventBus
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -17,6 +18,8 @@ class AudioMalHandler(AudioHandler):
|
||||
def __init__(self, context, handler):
|
||||
super().__init__(context, handler)
|
||||
|
||||
EventBus().register('stop', self._on_stop)
|
||||
|
||||
self._queue = SyncQueue(context.batch_size, "AudioMalHandler_queue")
|
||||
self._exit_event = Event()
|
||||
self._thread = Thread(target=self._on_run, name="AudioMalHandlerThread")
|
||||
@ -27,6 +30,12 @@ class AudioMalHandler(AudioHandler):
|
||||
self.chunk = context.sample_rate // context.fps
|
||||
logger.info("AudioMalHandler init")
|
||||
|
||||
def __del__(self):
|
||||
EventBus().unregister('stop', self._on_stop)
|
||||
|
||||
def _on_stop(self, *args, **kwargs):
|
||||
self.stop()
|
||||
|
||||
def on_message(self, message):
|
||||
super().on_message(message)
|
||||
|
||||
|
@ -3,6 +3,7 @@ import logging
|
||||
import os
|
||||
|
||||
from asr import SherpaNcnnAsr
|
||||
from eventbus import EventBus
|
||||
from .audio_inference_handler import AudioInferenceHandler
|
||||
from .audio_mal_handler import AudioMalHandler
|
||||
from .human_render import HumanRender
|
||||
@ -113,8 +114,7 @@ class HumanContext:
|
||||
self._asr.attach(self._nlp)
|
||||
|
||||
def stop(self):
|
||||
object_stop(self._asr)
|
||||
object_stop(self._nlp)
|
||||
EventBus().post('stop')
|
||||
object_stop(self._tts)
|
||||
object_stop(self._tts_handle)
|
||||
object_stop(self._mal_handler)
|
||||
|
@ -2,6 +2,7 @@
|
||||
import logging
|
||||
|
||||
from asr import AsrObserver
|
||||
from eventbus import EventBus
|
||||
from utils import AsyncTaskQueue
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -15,6 +16,14 @@ class NLPBase(AsrObserver):
|
||||
self._callback = callback
|
||||
self._is_running = False
|
||||
|
||||
EventBus().register('stop', self.onStop)
|
||||
|
||||
def __del__(self):
|
||||
EventBus().unregister('stop', self.onStop)
|
||||
|
||||
def onStop(self, *args, **kwargs):
|
||||
self.stop()
|
||||
|
||||
@property
|
||||
def callback(self):
|
||||
return self._callback
|
||||
@ -50,6 +59,7 @@ class NLPBase(AsrObserver):
|
||||
def stop(self):
|
||||
logger.info('NLPBase stop')
|
||||
self._is_running = False
|
||||
self._ask_queue.clear()
|
||||
self._ask_queue.add_task(self._on_close)
|
||||
logger.info('NLPBase add close')
|
||||
self._ask_queue.stop()
|
||||
|
@ -66,9 +66,3 @@ class DouBao(NLPBase):
|
||||
logger.info('AsyncArk close')
|
||||
if self.__client is not None and not self.__client.is_closed():
|
||||
await self.__client.close()
|
||||
|
||||
def stop(self):
|
||||
print('doubao stop00')
|
||||
self.__client.close()
|
||||
print('doubao stop11')
|
||||
super().stop()
|
||||
|
Loading…
Reference in New Issue
Block a user