diff --git a/asr/asr_base.py b/asr/asr_base.py index 6a1b904..73dcc9e 100644 --- a/asr/asr_base.py +++ b/asr/asr_base.py @@ -34,6 +34,7 @@ class AsrBase: observer.process(message) def _notify_complete(self, message: str): + EventBus().post('clear_cache') for observer in self._observers: observer.completed(message) diff --git a/human/audio_inference_handler.py b/human/audio_inference_handler.py index 899fd2d..407e146 100644 --- a/human/audio_inference_handler.py +++ b/human/audio_inference_handler.py @@ -22,6 +22,7 @@ class AudioInferenceHandler(AudioHandler): super().__init__(context, handler) EventBus().register('stop', self._on_stop) + EventBus().register('clear_cache', self.on_clear_cache) self._mal_queue = SyncQueue(1, "AudioInferenceHandler_Mel") self._audio_queue = SyncQueue(context.batch_size * 2, "AudioInferenceHandler_Audio") @@ -35,10 +36,15 @@ class AudioInferenceHandler(AudioHandler): def __del__(self): EventBus().unregister('stop', self._on_stop) + EventBus().unregister('clear_cache', self.on_clear_cache) def _on_stop(self, *args, **kwargs): self.stop() + def on_clear_cache(self, *args, **kwargs): + self._mal_queue.clear() + self._audio_queue.clear() + def on_handle(self, stream, type_): if not self._is_running: return diff --git a/human/audio_mal_handler.py b/human/audio_mal_handler.py index 59bc98f..6c46eb9 100644 --- a/human/audio_mal_handler.py +++ b/human/audio_mal_handler.py @@ -19,6 +19,7 @@ class AudioMalHandler(AudioHandler): super().__init__(context, handler) EventBus().register('stop', self._on_stop) + EventBus().register('clear_cache', self.on_clear_cache) self._is_running = True self._queue = SyncQueue(context.batch_size * 2, "AudioMalHandler_queue") @@ -34,10 +35,15 @@ class AudioMalHandler(AudioHandler): def __del__(self): EventBus().unregister('stop', self._on_stop) + EventBus().unregister('clear_cache', self.on_clear_cache) def _on_stop(self, *args, **kwargs): self.stop() + def on_clear_cache(self, *args, **kwargs): + self.frames.clear() + self._queue.clear() + def on_message(self, message): super().on_message(message) diff --git a/human/human_render.py b/human/human_render.py index af01588..78ef4cd 100644 --- a/human/human_render.py +++ b/human/human_render.py @@ -19,6 +19,7 @@ class HumanRender(AudioHandler): super().__init__(context, handler) EventBus().register('stop', self._on_stop) + EventBus().register('clear_cache', self.on_clear_cache) play_clock = PlayClock() self._voice_render = VoiceRender(play_clock, context) self._video_render = VideoRender(play_clock, context, self) @@ -35,10 +36,14 @@ class HumanRender(AudioHandler): def __del__(self): EventBus().unregister('stop', self._on_stop) + EventBus().unregister('clear_cache', self.on_clear_cache) def _on_stop(self, *args, **kwargs): self.stop() + def on_clear_cache(self, *args, **kwargs): + self._queue.clear() + def _on_run(self): logging.info('human render run') while self._exit_event.is_set() and self._is_running: diff --git a/nlp/nlp_base.py b/nlp/nlp_base.py index 47936a9..d4cf098 100644 --- a/nlp/nlp_base.py +++ b/nlp/nlp_base.py @@ -17,13 +17,19 @@ class NLPBase(AsrObserver): self._is_running = True EventBus().register('stop', self.on_stop) + EventBus().register('clear_cache', self.on_clear_cache) def __del__(self): EventBus().unregister('stop', self.on_stop) + EventBus().unregister('clear_cache', self.on_clear_cache) def on_stop(self, *args, **kwargs): self.stop() + def on_clear_cache(self, *args, **kwargs): + logger.info('NLPBase clear_cache') + self._ask_queue.clear() + @property def callback(self): return self._callback diff --git a/tts/tts_base.py b/tts/tts_base.py index 5d14ae6..5ef0775 100644 --- a/tts/tts_base.py +++ b/tts/tts_base.py @@ -16,13 +16,19 @@ class TTSBase(NLPCallback): self._message_queue = AsyncTaskQueue('TTSBaseQueue', 5) self._is_running = True EventBus().register('stop', self.on_stop) + EventBus().register('clear_cache', self.on_clear_cache) def __del__(self): EventBus().unregister('stop', self.on_stop) + EventBus().unregister('clear_cache', self.on_clear_cache) def on_stop(self, *args, **kwargs): self.stop() + def on_clear_cache(self, *args, **kwargs): + logger.info('TTSBase clear_cache') + self._message_queue.clear() + @property def handle(self): return self._handle diff --git a/tts/tts_edge_http.py b/tts/tts_edge_http.py index 11ec055..4856039 100644 --- a/tts/tts_edge_http.py +++ b/tts/tts_edge_http.py @@ -21,6 +21,7 @@ class TTSEdgeHttp(TTSBase): # self._url = 'http://localhost:8082/v1/audio/speech' self._url = 'https://tts.mzzsfy.eu.org/v1/audio/speech' logger.info(f"TTSEdge init, {voice}") + self._response_list = [] async def _on_async_request(self, data): async with aiohttp.ClientSession() as session: @@ -35,11 +36,12 @@ class TTSEdgeHttp(TTSBase): def _on_sync_request(self, data): response = requests.post(self._url, json=data) + self._response_list.append(response) + stream = None if response.status_code == 200: stream = BytesIO(response.content) - return stream - else: - return None + self._response_list.remove(response) + return stream async def _on_request(self, txt: str): logger.info(f'TTSEdgeHttp, _on_request, txt:{txt}') @@ -91,3 +93,9 @@ class TTSEdgeHttp(TTSBase): print('TTSEdge close') # if self._byte_stream is not None and not self._byte_stream.closed: # self._byte_stream.close() + + def on_clear_cache(self, *args, **kwargs): + logger.info('TTSEdgeHttp clear_cache') + super().on_clear_cache(*args, **kwargs) + for response in self._response_list: + response.close()