diff --git a/human/audio_inference_handler.py b/human/audio_inference_handler.py index a363cbc..5552729 100644 --- a/human/audio_inference_handler.py +++ b/human/audio_inference_handler.py @@ -10,7 +10,7 @@ import numpy as np import torch from human_handler import AudioHandler -from utils import load_model, mirror_index, get_device +from utils import load_model, mirror_index, get_device, SyncQueue logger = logging.getLogger(__name__) current_file_path = os.path.dirname(os.path.abspath(__file__)) @@ -21,7 +21,7 @@ class AudioInferenceHandler(AudioHandler): super().__init__(context, handler) self._mal_queue = Queue() - self._audio_queue = Queue() + self._audio_queue = SyncQueue(context.render_batch) self._exit_event = Event() self._run_thread = Thread(target=self.__on_run) @@ -126,5 +126,5 @@ class AudioInferenceHandler(AudioHandler): self._run_thread.join() def pause_talk(self): - self._audio_queue.queue.clear() + self._audio_queue.clear() self._mal_queue.queue.clear() diff --git a/human/audio_mal_handler.py b/human/audio_mal_handler.py index 5cad0e0..a7dc094 100644 --- a/human/audio_mal_handler.py +++ b/human/audio_mal_handler.py @@ -9,7 +9,7 @@ import numpy as np from human.message_type import MessageType from human_handler import AudioHandler -from utils import melspectrogram +from utils import melspectrogram, SyncQueue logger = logging.getLogger(__name__) @@ -18,9 +18,7 @@ class AudioMalHandler(AudioHandler): def __init__(self, context, handler): super().__init__(context, handler) - self._queue = Queue() - self._wait = False - self._condition = Condition() + self._queue = SyncQueue(context.render_batch) self._exit_event = Event() self._thread = Thread(target=self._on_run) self._exit_event.set() @@ -31,18 +29,7 @@ class AudioMalHandler(AudioHandler): logger.info("AudioMalHandler init") def on_message(self, message): - if message['msg_id'] == MessageType.Video_Render_Queue_Empty: - with self._condition: - if self._wait: - self._wait = False - self._condition.notify() - print('AudioMalHandler notify') - elif message['msg_id'] == MessageType.Video_Render_Queue_Full: - if not self._wait: - self._wait = True - print('AudioMalHandler wait') - else: - super().on_message(message) + super().on_message(message) def on_handle(self, stream, index): self._queue.put(stream) @@ -50,9 +37,6 @@ class AudioMalHandler(AudioHandler): def _on_run(self): logging.info('chunk2mal run') while self._exit_event.is_set(): - with self._condition: - self._condition.wait_for(lambda: not self._wait) - print('AudioMalHandler run') self._run_step() time.sleep(0.02) @@ -111,4 +95,4 @@ class AudioMalHandler(AudioHandler): logging.info('chunk2mal stop') def pause_talk(self): - self._queue.queue.clear() + self._queue.clear() diff --git a/human/human_render.py b/human/human_render.py index bb497db..ae9b211 100644 --- a/human/human_render.py +++ b/human/human_render.py @@ -1,13 +1,6 @@ #encoding = utf8 -import copy -import logging -import queue -import time -from queue import Queue -from threading import Thread, Event -import cv2 -import numpy as np +import logging from human.message_type import MessageType from human_handler import AudioHandler diff --git a/render/base_render.py b/render/base_render.py index cdb4c33..9e62ddc 100644 --- a/render/base_render.py +++ b/render/base_render.py @@ -5,6 +5,8 @@ from abc import ABC, abstractmethod from queue import Queue from threading import Event, Thread +from utils import SyncQueue + logger = logging.getLogger(__name__) @@ -14,7 +16,7 @@ class BaseRender(ABC): self._context = context self._type = type_ self._delay = delay - self._queue = Queue() + self._queue = SyncQueue(context.render_batch) self._exit_event = Event() self._thread = Thread(target=self._on_run) self._exit_event.set() @@ -29,16 +31,16 @@ class BaseRender(ABC): logging.info(f'{self._type} render exit') def put(self, frame, ps): - self._queue.put_nowait((frame, ps)) + self._queue.put((frame, ps)) def size(self): - return self._queue.qsize() + return self._queue.size() def pause_talk(self): - self._queue.queue.clear() + self._queue.clear() def stop(self): - self._queue.queue.clear() + self._queue.clear() self._exit_event.clear() self._thread.join() diff --git a/render/video_render.py b/render/video_render.py index 0a27ff8..27158f8 100644 --- a/render/video_render.py +++ b/render/video_render.py @@ -32,7 +32,7 @@ class VideoRender(BaseRender): self._diff_avg_count += 1 else: if time_difference < -self._play_clock.audio_diff_threshold: - sleep_time = abs(time_difference ) + sleep_time = abs(time_difference) # print("Video frame waiting to catch up with audio", sleep_time) if sleep_time <= 1.0: time.sleep(sleep_time) diff --git a/utils/sync_queue.py b/utils/sync_queue.py index 8439b24..934ad09 100644 --- a/utils/sync_queue.py +++ b/utils/sync_queue.py @@ -30,3 +30,6 @@ class SyncQueue: self._queue.get() self._queue.task_done() self._condition.notify_all() + + def size(self): + return self._queue.qsize()