modify render sync

This commit is contained in:
brige 2024-11-01 20:38:57 +08:00
parent 82eac73454
commit b69dd5a800
9 changed files with 47 additions and 47 deletions

View File

@ -20,11 +20,11 @@ class AudioInferenceHandler(AudioHandler):
def __init__(self, context, handler): def __init__(self, context, handler):
super().__init__(context, handler) super().__init__(context, handler)
self._mal_queue = Queue() self._mal_queue = SyncQueue(1, "AudioInferenceHandler_Mel")
self._audio_queue = SyncQueue(context.batch_size * 2) self._audio_queue = SyncQueue(context.batch_size * 2, "AudioInferenceHandler_Audio")
self._exit_event = Event() self._exit_event = Event()
self._run_thread = Thread(target=self.__on_run) self._run_thread = Thread(target=self.__on_run, name="AudioInferenceHandlerThread")
self._exit_event.set() self._exit_event.set()
self._run_thread.start() self._run_thread.start()
logger.info("AudioInferenceHandler init") logger.info("AudioInferenceHandler init")
@ -34,7 +34,7 @@ class AudioInferenceHandler(AudioHandler):
self._mal_queue.put(stream) self._mal_queue.put(stream)
elif type_ == 0: elif type_ == 0:
self._audio_queue.put(stream) self._audio_queue.put(stream)
print('AudioInferenceHandler on_handle', type_, self._audio_queue.size()) # print('AudioInferenceHandler on_handle', type_, self._audio_queue.size())
def on_message(self, message): def on_message(self, message):
super().on_message(message) super().on_message(message)
@ -61,16 +61,18 @@ class AudioInferenceHandler(AudioHandler):
start_time = time.perf_counter() start_time = time.perf_counter()
batch_size = self._context.batch_size batch_size = self._context.batch_size
try: try:
mel_batch = self._mal_queue.get(block=True, timeout=1) mel_batch = self._mal_queue.get()
size = self._audio_queue.size()
# print('AudioInferenceHandler mel_batch:', len(mel_batch), 'size:', size)
except queue.Empty: except queue.Empty:
continue continue
print('origin mel_batch:', len(mel_batch)) # print('origin mel_batch:', len(mel_batch))
is_all_silence = True is_all_silence = True
audio_frames = [] audio_frames = []
for _ in range(batch_size * 2): for _ in range(batch_size * 2):
frame, type_ = self._audio_queue.get() frame, type_ = self._audio_queue.get()
print('AudioInferenceHandler type_', type_) # print('AudioInferenceHandler type_', type_)
audio_frames.append((frame, type_)) audio_frames.append((frame, type_))
if type_ == 0: if type_ == 0:
is_all_silence = False is_all_silence = False
@ -89,7 +91,7 @@ class AudioInferenceHandler(AudioHandler):
face = face_list_cycle[idx] face = face_list_cycle[idx]
img_batch.append(face) img_batch.append(face)
print('orign img_batch:', len(img_batch), 'origin mel_batch:', len(mel_batch)) # print('orign img_batch:', len(img_batch), 'origin mel_batch:', len(mel_batch))
img_batch = np.asarray(img_batch) img_batch = np.asarray(img_batch)
mel_batch = np.asarray(mel_batch) mel_batch = np.asarray(mel_batch)
img_masked = img_batch.copy() img_masked = img_batch.copy()

View File

@ -2,12 +2,11 @@
import logging import logging
import queue import queue
import time import time
from queue import Queue
from threading import Thread, Event, Condition from threading import Thread, Event, Condition
import numpy as np import numpy as np
from human.message_type import MessageType
from human_handler import AudioHandler from human_handler import AudioHandler
from utils import melspectrogram, SyncQueue from utils import melspectrogram, SyncQueue
@ -18,9 +17,9 @@ class AudioMalHandler(AudioHandler):
def __init__(self, context, handler): def __init__(self, context, handler):
super().__init__(context, handler) super().__init__(context, handler)
self._queue = SyncQueue(context.batch_size) self._queue = SyncQueue(context.batch_size, "AudioMalHandler_queue")
self._exit_event = Event() self._exit_event = Event()
self._thread = Thread(target=self._on_run) self._thread = Thread(target=self._on_run, name="AudioMalHandlerThread")
self._exit_event.set() self._exit_event.set()
self._thread.start() self._thread.start()
@ -32,7 +31,7 @@ class AudioMalHandler(AudioHandler):
super().on_message(message) super().on_message(message)
def on_handle(self, stream, index): def on_handle(self, stream, index):
print('AudioMalHandler on_handle', index) # print('AudioMalHandler on_handle', index)
self._queue.put(stream) self._queue.put(stream)
def _on_run(self): def _on_run(self):
@ -49,13 +48,12 @@ class AudioMalHandler(AudioHandler):
frame, _type = self.get_audio_frame() frame, _type = self.get_audio_frame()
self.frames.append(frame) self.frames.append(frame)
self.on_next_handle((frame, _type), 0) self.on_next_handle((frame, _type), 0)
print("AudioMalHandler _type", _type)
count = count + 1 count = count + 1
# context not enough, do not run network. # context not enough, do not run network.
if len(self.frames) <= self._context.stride_left_size + self._context.stride_right_size: if len(self.frames) <= self._context.stride_left_size + self._context.stride_right_size:
return return
print('AudioMalHandler _run_step', count) # print('AudioMalHandler _run_step', count)
inputs = np.concatenate(self.frames) # [N * chunk] inputs = np.concatenate(self.frames) # [N * chunk]
mel = melspectrogram(inputs) mel = melspectrogram(inputs)
# print(mel.shape[0],mel.shape,len(mel[0]),len(self.frames)) # print(mel.shape[0],mel.shape,len(mel[0]),len(self.frames))
@ -70,13 +68,10 @@ class AudioMalHandler(AudioHandler):
start_idx = int(left + i * mel_idx_multiplier) start_idx = int(left + i * mel_idx_multiplier)
# print(start_idx) # print(start_idx)
if start_idx + mel_step_size > len(mel[0]): if start_idx + mel_step_size > len(mel[0]):
print("AudioMalHandler start_idx", start_idx)
mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:]) mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:])
else: else:
print("AudioMalHandler start_idx222", start_idx + mel_step_size - start_idx)
mel_chunks.append(mel[:, start_idx: start_idx + mel_step_size]) mel_chunks.append(mel[:, start_idx: start_idx + mel_step_size])
i += 1 i += 1
print("AudioMalHandler start_idx333", i)
self.on_next_handle(mel_chunks, 1) self.on_next_handle(mel_chunks, 1)
# discard the old part to save memory # discard the old part to save memory
@ -84,15 +79,13 @@ class AudioMalHandler(AudioHandler):
def get_audio_frame(self): def get_audio_frame(self):
try: try:
frame = self._queue.get() # print('AudioMalHandler get_audio_frame')
frame = self._queue.get(timeout=0.02)
type_ = 0 type_ = 0
if frame is None:
frame = np.zeros(self.chunk, dtype=np.float32)
type_ = 1
except queue.Empty: except queue.Empty:
frame = np.zeros(self.chunk, dtype=np.float32) frame = np.zeros(self.chunk, dtype=np.float32)
type_ = 1 type_ = 1
print('AudioMalHandler get_audio_frame type:', type_) # print('AudioMalHandler get_audio_frame type:', type_)
return frame, type_ return frame, type_
def stop(self): def stop(self):

View File

@ -33,7 +33,7 @@ class HumanRender(AudioHandler):
def on_handle(self, stream, index): def on_handle(self, stream, index):
res_frame, idx, audio_frames = stream res_frame, idx, audio_frames = stream
self._voice_render.put(audio_frames, self._last_audio_ps) self._voice_render.put(audio_frames, self._last_audio_ps)
self._last_audio_ps = self._last_audio_ps + 0.2 self._last_audio_ps = self._last_audio_ps + 0.4
type_ = 1 type_ = 1
if audio_frames[0][1] != 0 and audio_frames[1][1] != 0: if audio_frames[0][1] != 0 and audio_frames[1][1] != 0:
type_ = 0 type_ = 0
@ -43,6 +43,9 @@ class HumanRender(AudioHandler):
if self._voice_render.is_full(): if self._voice_render.is_full():
self._context.notify({'msg_id': MessageType.Video_Render_Queue_Full}) self._context.notify({'msg_id': MessageType.Video_Render_Queue_Full})
def get_audio_queue_size(self):
return self._voice_render.size()
def pause_talk(self): def pause_talk(self):
self._voice_render.pause_talk() self._voice_render.pause_talk()
self._video_render.pause_talk() self._video_render.pause_talk()

View File

@ -44,7 +44,6 @@ class DouBao(NLPBase):
sec = '' sec = ''
async for completion in stream: async for completion in stream:
sec = sec + completion.choices[0].delta.content sec = sec + completion.choices[0].delta.content
print('DouBao content:', sec)
sec, message = self._split_handle.handle(sec) sec, message = self._split_handle.handle(sec)
if len(message) > 0: if len(message) > 0:
self._on_callback(message) self._on_callback(message)

View File

@ -11,14 +11,14 @@ logger = logging.getLogger(__name__)
class BaseRender(ABC): class BaseRender(ABC):
def __init__(self, play_clock, context, type_, delay=0.02): def __init__(self, play_clock, context, type_, delay=0.02, thread_name="BaseRenderThread"):
self._play_clock = play_clock self._play_clock = play_clock
self._context = context self._context = context
self._type = type_ self._type = type_
self._delay = delay self._delay = delay
self._queue = SyncQueue(context.batch_size) self._queue = SyncQueue(context.batch_size, f'{type_}RenderQueue')
self._exit_event = Event() self._exit_event = Event()
self._thread = Thread(target=self._on_run) self._thread = Thread(target=self._on_run, name=thread_name)
self._exit_event.set() self._exit_event.set()
self._thread.start() self._thread.start()
@ -31,7 +31,6 @@ class BaseRender(ABC):
logging.info(f'{self._type} render exit') logging.info(f'{self._type} render exit')
def put(self, frame, ps): def put(self, frame, ps):
print('put:', ps)
self._queue.put((frame, ps)) self._queue.put((frame, ps))
def size(self): def size(self):

View File

@ -13,7 +13,7 @@ from human.message_type import MessageType
class VideoRender(BaseRender): class VideoRender(BaseRender):
def __init__(self, play_clock, context, human_render): def __init__(self, play_clock, context, human_render):
super().__init__(play_clock, context, 'Video', 0.038) super().__init__(play_clock, context, 'Video', 0.03, "VideoRenderThread")
self._human_render = human_render self._human_render = human_render
self._diff_avg_count = 0 self._diff_avg_count = 0
@ -31,7 +31,7 @@ class VideoRender(BaseRender):
clock_time = self._play_clock.clock_time() clock_time = self._play_clock.clock_time()
time_difference = clock_time - ps time_difference = clock_time - ps
if abs(time_difference) > self._play_clock.audio_diff_threshold: if abs(time_difference) > self._play_clock.audio_diff_threshold:
if self._diff_avg_count < 10: if self._diff_avg_count < 5:
self._diff_avg_count += 1 self._diff_avg_count += 1
else: else:
if time_difference < -self._play_clock.audio_diff_threshold: if time_difference < -self._play_clock.audio_diff_threshold:
@ -47,8 +47,9 @@ class VideoRender(BaseRender):
else: else:
self._diff_avg_count = 0 self._diff_avg_count = 0
print('video render:', ps, clock_time, time_difference, print('video render:',
'get face', self._queue.size(), self._diff_avg_count) 'get face', self._queue.size(),
'audio queue', self._human_render.get_audio_queue_size())
if type_ == 0: if type_ == 0:
combine_frame = self._context.frame_list_cycle[idx] combine_frame = self._context.frame_list_cycle[idx]

View File

@ -16,7 +16,7 @@ class VoiceRender(BaseRender):
def __init__(self, play_clock, context): def __init__(self, play_clock, context):
self._audio_render = AudioRender() self._audio_render = AudioRender()
self._is_empty = True self._is_empty = True
super().__init__(play_clock, context, 'Voice') super().__init__(play_clock, context, 'Voice', 0.03, "VoiceRenderThread")
def is_full(self): def is_full(self):
return self._queue.size() >= self._context.render_batch * 2 return self._queue.size() >= self._context.render_batch * 2
@ -27,7 +27,7 @@ class VoiceRender(BaseRender):
if value is None: if value is None:
return return
audio_frames, ps = value audio_frames, ps = value
# print('voice render queue size', self._queue.qsize()) # print('voice render queue size', self._queue.size())
except Empty: except Empty:
self._context.notify({'msg_id': MessageType.Video_Render_Queue_Empty}) self._context.notify({'msg_id': MessageType.Video_Render_Queue_Empty})
if not self._is_empty: if not self._is_empty:
@ -55,6 +55,8 @@ class VoiceRender(BaseRender):
if self._audio_render is not None: if self._audio_render is not None:
try: try:
self._audio_render.write(frame.tobytes(), int(frame.shape[0] * 2)) chunk_len = int(frame.shape[0] * 2)
# print('audio frame:', frame.shape, chunk_len)
self._audio_render.write(frame.tobytes(), chunk_len)
except Exception as e: except Exception as e:
logging.error(f'Error writing audio frame: {e}') logging.error(f'Error writing audio frame: {e}')

2
ui.py
View File

@ -57,7 +57,7 @@ class App(customtkinter.CTk):
# self.main_button_1.grid(row=2, column=2, padx=(20, 20), pady=(20, 20), sticky="nsew") # self.main_button_1.grid(row=2, column=2, padx=(20, 20), pady=(20, 20), sticky="nsew")
background = os.path.join(current_file_path, 'data', 'background', 'background.webp') background = os.path.join(current_file_path, 'data', 'background', 'background.webp')
logger.info(f'background: {background}') logger.info(f'background: {background}')
self._background = ImageTk.PhotoImage(read_image(background)) # self._background = ImageTk.PhotoImage(read_image(background))
self._init_image_canvas() self._init_image_canvas()

View File

@ -5,35 +5,36 @@ from queue import Queue
class SyncQueue: class SyncQueue:
def __init__(self, maxsize): def __init__(self, maxsize, name):
self._name = name
self._queue = Queue(maxsize) self._queue = Queue(maxsize)
# self._queue = Queue()
self._condition = threading.Condition() self._condition = threading.Condition()
def put(self, item): def put(self, item):
# self._queue.put(item)
with self._condition: with self._condition:
while self._queue.full(): while self._queue.full():
print('put wait') # print(self._name, 'put wait')
self._condition.wait() self._condition.wait()
self._queue.put(item) self._queue.put(item)
self._condition.notify() self._condition.notify()
def get(self): def get(self, timeout=None):
# return self._queue.get(block=True, timeout=0.01) # 添加超时时间,防止死锁
with self._condition: with self._condition:
while self._queue.empty(): while self._queue.empty():
self._condition.wait() self._condition.wait(timeout=timeout)
item = self._queue.get() # print(self._name, 'get wait')
if timeout is not None:
break
item = self._queue.get(block=False)
self._condition.notify() self._condition.notify()
return item return item
def clear(self): def clear(self):
# self._queue.queue.clear()
with self._condition: with self._condition:
while not self._queue.empty(): while not self._queue.empty():
self._queue.queue.clear() self._queue.queue.clear()
self._condition.notify_all() self._condition.notify()
def size(self): def size(self):
return self._queue.qsize() return self._queue.qsize()