From c8fc8097e7dc177815ce4766ce372be97e339919 Mon Sep 17 00:00:00 2001
From: jiegeaiai <jiegeaiai@163.com>
Date: Thu, 7 Nov 2024 08:26:03 +0800
Subject: [PATCH] modify sync block

---
 human/audio_inference_handler.py | 23 +++++++++++++++++++----
 human/audio_mal_handler.py       |  5 ++---
 human/human_render.py            |  6 +++++-
 nlp/nlp_base.py                  |  8 ++++----
 tts/tts_base.py                  |  2 +-
 ui/pygame_ui.py                  |  2 +-
 utils/async_task_queue.py        | 24 ++++++++++++++----------
 utils/sync_queue.py              |  3 +++
 8 files changed, 49 insertions(+), 24 deletions(-)

diff --git a/human/audio_inference_handler.py b/human/audio_inference_handler.py
index 0a5acbe..abfe34d 100644
--- a/human/audio_inference_handler.py
+++ b/human/audio_inference_handler.py
@@ -70,14 +70,17 @@ class AudioInferenceHandler(AudioHandler):
         logger.info(f'use device:{device}')
 
         while self._is_running:
+            print('AudioInferenceHandler mel_batch:000')
             if self._exit_event.is_set():
                 start_time = time.perf_counter()
                 batch_size = self._context.batch_size
                 try:
-                    mel_batch = self._mal_queue.get()
-                    size = self._audio_queue.size()
+                    print('AudioInferenceHandler mel_batch:')
+                    mel_batch = self._mal_queue.get(timeout=0.03)
+                    print('AudioInferenceHandler mel_batch:111')
                     # print('AudioInferenceHandler mel_batch:', len(mel_batch), 'size:', size)
                 except queue.Empty:
+                    print('AudioInferenceHandler mel_batch:111')
                     continue
 
                 # print('origin mel_batch:', len(mel_batch))
@@ -91,12 +94,18 @@ class AudioInferenceHandler(AudioHandler):
                         is_all_silence = False
 
                 if not self._is_running:
-                    return
+                    print('AudioInferenceHandler not running')
+                    break
 
                 if is_all_silence:
                     for i in range(batch_size):
+                        if not self._is_running:
+                            print('AudioInferenceHandler not running1111')
+                            break
+                        print('AudioInferenceHandler is_all_silence 111')
                         self.on_next_handle((None, mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]),
                                             0)
+                        print('AudioInferenceHandler is_all_silence 222')
                         index = index + 1
                 else:
                     logger.info('infer=======')
@@ -135,12 +144,15 @@ class AudioInferenceHandler(AudioHandler):
                         count_time = 0
 
                     for i, res_frame in enumerate(pred):
+                        if not self._is_running:
+                            break
                         self.on_next_handle(
                             (res_frame, mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]),
                             0)
                         index = index + 1
                     logger.info(f'total batch time: {time.perf_counter() - start_time}')
             else:
+                print('AudioInferenceHandler mel_batch:333')
                 time.sleep(1)
                 break
         logger.info('AudioInferenceHandler inference processor stop')
@@ -149,7 +161,10 @@ class AudioInferenceHandler(AudioHandler):
         logger.info('AudioInferenceHandler stop')
         self._is_running = False
         self._exit_event.clear()
-        # self._run_thread.join()
+        if self._run_thread.is_alive():
+            print('AudioInferenceHandler stop join')
+            self._run_thread.join()
+        print('AudioInferenceHandler stop exit')
 
     def pause_talk(self):
         print('AudioInferenceHandler pause_talk', self._audio_queue.size(), self._mal_queue.size())
diff --git a/human/audio_mal_handler.py b/human/audio_mal_handler.py
index 4637255..ed752ae 100644
--- a/human/audio_mal_handler.py
+++ b/human/audio_mal_handler.py
@@ -92,11 +92,10 @@ class AudioMalHandler(AudioHandler):
         self.frames = self.frames[-(self._context.stride_left_size + self._context.stride_right_size):]
 
     def get_audio_frame(self):
-        try:
-            # print('AudioMalHandler get_audio_frame')
+        if not self._queue.is_empty():
             frame = self._queue.get()
             type_ = 0
-        except queue.Empty:
+        else:
             frame = np.zeros(self.chunk, dtype=np.float32)
             type_ = 1
         # print('AudioMalHandler get_audio_frame type:', type_)
diff --git a/human/human_render.py b/human/human_render.py
index eef3f1c..8c1851e 100644
--- a/human/human_render.py
+++ b/human/human_render.py
@@ -48,7 +48,6 @@ class HumanRender(AudioHandler):
         logging.info('human render exit')
 
     def _run_step(self):
-
         try:
             value = self._queue.get(timeout=.005)
             if value is None:
@@ -85,6 +84,10 @@ class HumanRender(AudioHandler):
         super().on_message(message)
 
     def on_handle(self, stream, index):
+        print('human render:', self._is_running)
+        if not self._is_running:
+            return
+
         self._queue.put(stream)
         # res_frame, idx, audio_frames = stream
         # self._voice_render.put(audio_frames, self._last_audio_ps)
@@ -112,6 +115,7 @@ class HumanRender(AudioHandler):
         if self._exit_event is None:
             return
 
+        self._queue.clear()
         self._exit_event.clear()
         if self._thread.is_alive():
             self._thread.join()
diff --git a/nlp/nlp_base.py b/nlp/nlp_base.py
index f186254..7e8f458 100644
--- a/nlp/nlp_base.py
+++ b/nlp/nlp_base.py
@@ -10,18 +10,18 @@ logger = logging.getLogger(__name__)
 
 class NLPBase(AsrObserver):
     def __init__(self, context, split, callback=None):
-        self._ask_queue = AsyncTaskQueue()
+        self._ask_queue = AsyncTaskQueue('NLPBaseQueue')
         self._context = context
         self._split_handle = split
         self._callback = callback
         self._is_running = False
 
-        EventBus().register('stop', self.onStop)
+        EventBus().register('stop', self.on_stop)
 
     def __del__(self):
-        EventBus().unregister('stop', self.onStop)
+        EventBus().unregister('stop', self.on_stop)
 
-    def onStop(self, *args, **kwargs):
+    def on_stop(self, *args, **kwargs):
         self.stop()
 
     @property
diff --git a/tts/tts_base.py b/tts/tts_base.py
index db7a2c0..7fc1ba2 100644
--- a/tts/tts_base.py
+++ b/tts/tts_base.py
@@ -12,7 +12,7 @@ logger = logging.getLogger(__name__)
 class TTSBase(NLPCallback):
     def __init__(self, handle):
         self._handle = handle
-        self._message_queue = AsyncTaskQueue(5)
+        self._message_queue = AsyncTaskQueue('TTSBaseQueue', 5)
         self._is_running = False
 
     @property
diff --git a/ui/pygame_ui.py b/ui/pygame_ui.py
index 50e6ab2..6dca3ec 100644
--- a/ui/pygame_ui.py
+++ b/ui/pygame_ui.py
@@ -60,7 +60,7 @@ class PyGameUI:
     def stop(self):
         logger.info('stop')
         if self._human_context is not None:
-            self._human_context.pause_talk()
+            # self._human_context.pause_talk()
             self._human_context.stop()
 
     def on_render(self, image):
diff --git a/utils/async_task_queue.py b/utils/async_task_queue.py
index e3d395a..b7a32bb 100644
--- a/utils/async_task_queue.py
+++ b/utils/async_task_queue.py
@@ -5,45 +5,49 @@ import threading
 
 
 class AsyncTaskQueue:
-    def __init__(self, work_num=1):
+    def __init__(self, name, work_num=1):
         self._queue = asyncio.Queue()
         self._worker_num = work_num
         self._current_worker_num = work_num
-        self._thread = threading.Thread(target=self._run_loop)
+        self._name = name
+        self._thread = threading.Thread(target=self._run_loop, name=name)
         self._thread.start()
         self.__loop = None
 
     def _run_loop(self):
-        print('_run_loop')
+        print(self._name, '_run_loop')
         self.__loop = asyncio.new_event_loop()
         asyncio.set_event_loop(self.__loop)
         self._tasks = [self.__loop.create_task(self._worker()) for _ in range(self._worker_num)]
         self.__loop.run_forever()
-        print("exit run")
+        print(self._name, "exit run")
         if not self.__loop.is_closed():
             self.__loop.close()
 
     async def _worker(self):
-        print('_worker')
+        print(self._name, '_worker')
         while True:
+            print(f'{self._name} get queue')
             task = await self._queue.get()
-            print(f"Get task size: {self._queue.qsize()}")
+            print(f'{self._name} get queue11')
+            print(f"{self._name} Get task size: {self._queue.qsize()}")
             if task is None:  # None as a stop signal
                 break
 
             func, *args = task  # Unpack task
-            print(f"Executing task with args: {args}")
+            print(f"{self._name}, Executing task with args: {args}")
             await func(*args)  # Execute async function
             self._queue.task_done()
 
-        print('_worker finish')
+        print(self._name, '_worker finish')
         self._current_worker_num -= 1
         if self._current_worker_num == 0:
-            print('loop stop')
+            print(self._name, 'loop stop')
             self.__loop.stop()
 
     def add_task(self, func, *args):
-        return self.__loop.call_soon_threadsafe(self._queue.put_nowait, (func, *args))
+        # return self.__loop.call_soon_threadsafe(self._queue.put_nowait, (func, *args))
+        self._queue.put_nowait((func, *args))
 
     def stop_workers(self):
         for _ in range(self._worker_num):
diff --git a/utils/sync_queue.py b/utils/sync_queue.py
index 8809f24..cbf31e7 100644
--- a/utils/sync_queue.py
+++ b/utils/sync_queue.py
@@ -10,6 +10,9 @@ class SyncQueue:
         self._queue = Queue(maxsize)
         self._condition = threading.Condition()
 
+    def is_empty(self):
+        return self._queue.empty()
+
     def put(self, item):
         with self._condition:
             while self._queue.full():