modify use sync queue

This commit is contained in:
brige 2024-10-29 18:09:26 +08:00
parent 4c6b27ad43
commit 273dbfb0ec
6 changed files with 19 additions and 37 deletions

View File

@ -10,7 +10,7 @@ import numpy as np
import torch import torch
from human_handler import AudioHandler from human_handler import AudioHandler
from utils import load_model, mirror_index, get_device from utils import load_model, mirror_index, get_device, SyncQueue
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
current_file_path = os.path.dirname(os.path.abspath(__file__)) current_file_path = os.path.dirname(os.path.abspath(__file__))
@ -21,7 +21,7 @@ class AudioInferenceHandler(AudioHandler):
super().__init__(context, handler) super().__init__(context, handler)
self._mal_queue = Queue() self._mal_queue = Queue()
self._audio_queue = Queue() self._audio_queue = SyncQueue(context.render_batch)
self._exit_event = Event() self._exit_event = Event()
self._run_thread = Thread(target=self.__on_run) self._run_thread = Thread(target=self.__on_run)
@ -126,5 +126,5 @@ class AudioInferenceHandler(AudioHandler):
self._run_thread.join() self._run_thread.join()
def pause_talk(self): def pause_talk(self):
self._audio_queue.queue.clear() self._audio_queue.clear()
self._mal_queue.queue.clear() self._mal_queue.queue.clear()

View File

@ -9,7 +9,7 @@ import numpy as np
from human.message_type import MessageType from human.message_type import MessageType
from human_handler import AudioHandler from human_handler import AudioHandler
from utils import melspectrogram from utils import melspectrogram, SyncQueue
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -18,9 +18,7 @@ class AudioMalHandler(AudioHandler):
def __init__(self, context, handler): def __init__(self, context, handler):
super().__init__(context, handler) super().__init__(context, handler)
self._queue = Queue() self._queue = SyncQueue(context.render_batch)
self._wait = False
self._condition = Condition()
self._exit_event = Event() self._exit_event = Event()
self._thread = Thread(target=self._on_run) self._thread = Thread(target=self._on_run)
self._exit_event.set() self._exit_event.set()
@ -31,18 +29,7 @@ class AudioMalHandler(AudioHandler):
logger.info("AudioMalHandler init") logger.info("AudioMalHandler init")
def on_message(self, message): def on_message(self, message):
if message['msg_id'] == MessageType.Video_Render_Queue_Empty: super().on_message(message)
with self._condition:
if self._wait:
self._wait = False
self._condition.notify()
print('AudioMalHandler notify')
elif message['msg_id'] == MessageType.Video_Render_Queue_Full:
if not self._wait:
self._wait = True
print('AudioMalHandler wait')
else:
super().on_message(message)
def on_handle(self, stream, index): def on_handle(self, stream, index):
self._queue.put(stream) self._queue.put(stream)
@ -50,9 +37,6 @@ class AudioMalHandler(AudioHandler):
def _on_run(self): def _on_run(self):
logging.info('chunk2mal run') logging.info('chunk2mal run')
while self._exit_event.is_set(): while self._exit_event.is_set():
with self._condition:
self._condition.wait_for(lambda: not self._wait)
print('AudioMalHandler run')
self._run_step() self._run_step()
time.sleep(0.02) time.sleep(0.02)
@ -111,4 +95,4 @@ class AudioMalHandler(AudioHandler):
logging.info('chunk2mal stop') logging.info('chunk2mal stop')
def pause_talk(self): def pause_talk(self):
self._queue.queue.clear() self._queue.clear()

View File

@ -1,13 +1,6 @@
#encoding = utf8 #encoding = utf8
import copy
import logging
import queue
import time
from queue import Queue
from threading import Thread, Event
import cv2 import logging
import numpy as np
from human.message_type import MessageType from human.message_type import MessageType
from human_handler import AudioHandler from human_handler import AudioHandler

View File

@ -5,6 +5,8 @@ from abc import ABC, abstractmethod
from queue import Queue from queue import Queue
from threading import Event, Thread from threading import Event, Thread
from utils import SyncQueue
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -14,7 +16,7 @@ class BaseRender(ABC):
self._context = context self._context = context
self._type = type_ self._type = type_
self._delay = delay self._delay = delay
self._queue = Queue() self._queue = SyncQueue(context.render_batch)
self._exit_event = Event() self._exit_event = Event()
self._thread = Thread(target=self._on_run) self._thread = Thread(target=self._on_run)
self._exit_event.set() self._exit_event.set()
@ -29,16 +31,16 @@ class BaseRender(ABC):
logging.info(f'{self._type} render exit') logging.info(f'{self._type} render exit')
def put(self, frame, ps): def put(self, frame, ps):
self._queue.put_nowait((frame, ps)) self._queue.put((frame, ps))
def size(self): def size(self):
return self._queue.qsize() return self._queue.size()
def pause_talk(self): def pause_talk(self):
self._queue.queue.clear() self._queue.clear()
def stop(self): def stop(self):
self._queue.queue.clear() self._queue.clear()
self._exit_event.clear() self._exit_event.clear()
self._thread.join() self._thread.join()

View File

@ -32,7 +32,7 @@ class VideoRender(BaseRender):
self._diff_avg_count += 1 self._diff_avg_count += 1
else: else:
if time_difference < -self._play_clock.audio_diff_threshold: if time_difference < -self._play_clock.audio_diff_threshold:
sleep_time = abs(time_difference ) sleep_time = abs(time_difference)
# print("Video frame waiting to catch up with audio", sleep_time) # print("Video frame waiting to catch up with audio", sleep_time)
if sleep_time <= 1.0: if sleep_time <= 1.0:
time.sleep(sleep_time) time.sleep(sleep_time)

View File

@ -30,3 +30,6 @@ class SyncQueue:
self._queue.get() self._queue.get()
self._queue.task_done() self._queue.task_done()
self._condition.notify_all() self._condition.notify_all()
def size(self):
return self._queue.qsize()