modify crash handling

This commit is contained in:
brige 2024-10-30 16:34:12 +08:00
parent 273dbfb0ec
commit cb6dd54497
9 changed files with 59 additions and 23 deletions

View File

@ -21,7 +21,7 @@ class AudioInferenceHandler(AudioHandler):
super().__init__(context, handler)
self._mal_queue = Queue()
self._audio_queue = SyncQueue(context.render_batch)
self._audio_queue = SyncQueue(context.batch_size * 2)
self._exit_event = Event()
self._run_thread = Thread(target=self.__on_run)
@ -34,6 +34,7 @@ class AudioInferenceHandler(AudioHandler):
self._mal_queue.put(stream)
elif type_ == 0:
self._audio_queue.put(stream)
print('AudioInferenceHandler on_handle', type_)
def on_message(self, message):
super().on_message(message)
@ -60,7 +61,7 @@ class AudioInferenceHandler(AudioHandler):
start_time = time.perf_counter()
batch_size = self._context.batch_size
try:
mel_batch = self._mal_queue.get(block=True, timeout=0.1)
mel_batch = self._mal_queue.get(block=True, timeout=1)
except queue.Empty:
continue
@ -85,6 +86,7 @@ class AudioInferenceHandler(AudioHandler):
face = face_list_cycle[idx]
img_batch.append(face)
print('orign img_batch:', len(img_batch), 'origin mel_batch:', len(mel_batch))
img_batch = np.asarray(img_batch)
mel_batch = np.asarray(mel_batch)
img_masked = img_batch.copy()
@ -96,7 +98,7 @@ class AudioInferenceHandler(AudioHandler):
img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
print('img_batch:', img_batch.shape, 'mel_batch:', mel_batch.shape)
with torch.no_grad():
pred = model(mel_batch, img_batch)

View File

@ -18,7 +18,7 @@ class AudioMalHandler(AudioHandler):
def __init__(self, context, handler):
super().__init__(context, handler)
self._queue = SyncQueue(context.render_batch)
self._queue = SyncQueue(context.batch_size)
self._exit_event = Event()
self._thread = Thread(target=self._on_run)
self._exit_event.set()
@ -32,6 +32,7 @@ class AudioMalHandler(AudioHandler):
super().on_message(message)
def on_handle(self, stream, index):
print('AudioMalHandler on_handle', index)
self._queue.put(stream)
def _on_run(self):
@ -51,12 +52,13 @@ class AudioMalHandler(AudioHandler):
if len(self.frames) <= self._context.stride_left_size + self._context.stride_right_size:
return
print('AudioMalHandler _run_step')
inputs = np.concatenate(self.frames) # [N * chunk]
mel = melspectrogram(inputs)
# print(mel.shape[0],mel.shape,len(mel[0]),len(self.frames))
# cut off stride
left = max(0, self._context.stride_left_size * 80 / 50)
right = min(len(mel[0]), len(mel[0]) - self._context.stride_right_size * 80 / 50)
# right = min(len(mel[0]), len(mel[0]) - self._context.stride_right_size * 80 / 50)
mel_idx_multiplier = 80. * 2 / self._context.fps
mel_step_size = 16
i = 0
@ -76,12 +78,15 @@ class AudioMalHandler(AudioHandler):
def get_audio_frame(self):
try:
frame = self._queue.get(block=True, timeout=0.01)
frame = self._queue.get()
type_ = 0
if frame is None:
frame = np.zeros(self.chunk, dtype=np.float32)
type_ = 1
except queue.Empty:
frame = np.zeros(self.chunk, dtype=np.float32)
type_ = 1
print('AudioMalHandler get_audio_frame type:', type_)
return frame, type_
def stop(self):

View File

@ -16,7 +16,7 @@ class BaseRender(ABC):
self._context = context
self._type = type_
self._delay = delay
self._queue = SyncQueue(context.render_batch)
self._queue = SyncQueue(context.batch_size)
self._exit_event = Event()
self._thread = Thread(target=self._on_run)
self._exit_event.set()

View File

@ -20,7 +20,10 @@ class VideoRender(BaseRender):
def _run_step(self):
while self._exit_event.is_set():
try:
frame, ps = self._queue.get(block=True, timeout=0.02)
value = self._queue.get()
if value is None:
return
frame, ps = value
res_frame, idx, type_ = frame
except Empty:
return
@ -45,7 +48,7 @@ class VideoRender(BaseRender):
self._diff_avg_count = 0
print('video render:', ps, ' ', clock_time, ' ', time_difference,
'get face', self._queue.qsize(), self._diff_avg_count)
'get face', self._queue.size(), self._diff_avg_count)
if type_ == 0:
combine_frame = self._context.frame_list_cycle[idx]

View File

@ -19,11 +19,14 @@ class VoiceRender(BaseRender):
super().__init__(play_clock, context, 'Voice')
def is_full(self):
return self._queue.qsize() >= self._context.render_batch * 2
return self._queue.size() >= self._context.render_batch * 2
def _run_step(self):
try:
audio_frames, ps = self._queue.get(block=True, timeout=0.01)
value = self._queue.get()
if value is None:
return
audio_frames, ps = value
# print('voice render queue size', self._queue.qsize())
except Empty:
self._context.notify({'msg_id': MessageType.Video_Render_Queue_Empty})
@ -37,9 +40,9 @@ class VoiceRender(BaseRender):
self._is_empty = False
status = MessageType.Video_Render_Queue_Not_Empty
if self._queue.qsize() < self._context.render_batch:
if self._queue.size() < self._context.render_batch:
status = MessageType.Video_Render_Queue_Empty
elif self._queue.qsize() >= self._context.render_batch * 2:
elif self._queue.size() >= self._context.render_batch * 2:
status = MessageType.Video_Render_Queue_Full
self._context.notify({'msg_id': status})

View File

@ -2,12 +2,12 @@
import time
from tts import TTSEdge, TTSAudioSaveHandle
from tts import TTSEdge, TTSAudioSaveHandle, TTSEdgeHttp
def main():
handle = TTSAudioSaveHandle()
tts = TTSEdge(handle)
handle = TTSAudioSaveHandle(None, None)
tts = TTSEdgeHttp(handle)
tts.message('你好,')
tts.message('请问有什么可以帮到您,')
tts.message('很高兴为您服务。')

View File

@ -12,7 +12,7 @@ logger = logging.getLogger(__name__)
class TTSBase(NLPCallback):
def __init__(self, handle):
self._handle = handle
self._message_queue = AsyncTaskQueue(5)
self._message_queue = AsyncTaskQueue(1)
@property
def handle(self):

View File

@ -21,7 +21,7 @@ class TTSEdgeHttp(TTSBase):
logger.info(f"TTSEdge init, {voice}")
async def _on_request(self, txt: str):
print('_on_request, txt')
print('TTSEdgeHttp, _on_request, txt:', txt)
data = {
"model": "tts-1",
"input": txt,
@ -31,6 +31,7 @@ class TTSEdgeHttp(TTSBase):
}
async with aiohttp.ClientSession() as session:
async with session.post(self._url, json=data) as response:
print('TTSEdgeHttp, _on_request, response:', response)
if response.status == 200:
stream = BytesIO(await response.read())
return stream
@ -43,7 +44,7 @@ class TTSEdgeHttp(TTSBase):
try:
stream.seek(0)
byte_stream = self.__create_bytes_stream(stream)
print('-------tts start push chunk')
print('-------tts start push chunk', index)
self._handle.on_handle(byte_stream, index)
stream.seek(0)
stream.truncate()

View File

@ -3,20 +3,24 @@
import threading
from queue import Queue
'''
class SyncQueue:
def __init__(self, maxsize):
self._queue = Queue(maxsize)
# self._queue = Queue()
self._condition = threading.Condition()
def put(self, item):
# self._queue.put(item)
with self._condition:
while self._queue.full():
print('put wait')
self._condition.wait()
self._queue.put(item)
self._condition.notify()
def get(self):
# return self._queue.get(block=True, timeout=0.01)
with self._condition:
while self._queue.empty():
self._condition.wait()
@ -25,11 +29,29 @@ class SyncQueue:
return item
def clear(self):
# self._queue.queue.clear()
with self._condition:
while not self._queue.empty():
self._queue.get()
self._queue.task_done()
self._queue.queue.clear()
self._condition.notify_all()
def size(self):
return self._queue.qsize()
'''
class SyncQueue:
    """Thread-safe FIFO queue wrapper around ``queue.Queue``.

    NOTE(review): ``maxsize`` is accepted for backward compatibility with
    existing callers but is intentionally ignored — the underlying queue is
    unbounded, so ``put`` never blocks producers (the bounded variant was
    replaced, presumably to avoid a producer deadlock/crash — confirm).
    """

    def __init__(self, maxsize):
        # maxsize deliberately unused: an unbounded queue keeps put() non-blocking.
        self._queue = Queue()

    def put(self, item):
        """Enqueue *item*; never blocks (the queue is unbounded)."""
        self._queue.put(item)

    def get(self):
        """Dequeue one item, waiting up to 0.2 seconds.

        Raises:
            queue.Empty: if no item becomes available within the timeout.
        """
        return self._queue.get(block=True, timeout=0.2)

    def clear(self):
        """Discard all pending items.

        Fix: ``Queue.queue`` (the internal deque) must only be touched while
        holding ``Queue.mutex``; clearing it unlocked races with concurrent
        put()/get() bookkeeping.
        """
        with self._queue.mutex:
            self._queue.queue.clear()

    def size(self):
        """Return the approximate number of queued items."""
        return self._queue.qsize()