modify async

This commit is contained in:
brige 2024-11-07 20:43:46 +08:00
parent c8fc8097e7
commit 4c9ec6831f
8 changed files with 53 additions and 63 deletions

View File

@ -65,7 +65,7 @@ class SherpaNcnnAsr(AsrBase):
self._notify_complete('介绍中国5000年历史文学') self._notify_complete('介绍中国5000年历史文学')
logger.info(f'_recognize_loop111') logger.info(f'_recognize_loop111')
segment_id += 1 segment_id += 1
time.sleep(15) time.sleep(60)
logger.info(f'_recognize_loop222') logger.info(f'_recognize_loop222')
logger.info(f'_recognize_loop exit') logger.info(f'_recognize_loop exit')
''' '''

View File

@ -70,17 +70,13 @@ class AudioInferenceHandler(AudioHandler):
logger.info(f'use device:{device}') logger.info(f'use device:{device}')
while self._is_running: while self._is_running:
print('AudioInferenceHandler mel_batch:000')
if self._exit_event.is_set(): if self._exit_event.is_set():
start_time = time.perf_counter() start_time = time.perf_counter()
batch_size = self._context.batch_size batch_size = self._context.batch_size
try: try:
print('AudioInferenceHandler mel_batch:') mel_batch = self._mal_queue.get(timeout=0.02)
mel_batch = self._mal_queue.get(timeout=0.03) print('AudioInferenceHandler mel_batch:', len(mel_batch), 'size:', self._mal_queue.size())
print('AudioInferenceHandler mel_batch:111')
# print('AudioInferenceHandler mel_batch:', len(mel_batch), 'size:', size)
except queue.Empty: except queue.Empty:
print('AudioInferenceHandler mel_batch:111')
continue continue
# print('origin mel_batch:', len(mel_batch)) # print('origin mel_batch:', len(mel_batch))
@ -100,13 +96,11 @@ class AudioInferenceHandler(AudioHandler):
if is_all_silence: if is_all_silence:
for i in range(batch_size): for i in range(batch_size):
if not self._is_running: if not self._is_running:
print('AudioInferenceHandler not running1111')
break break
print('AudioInferenceHandler is_all_silence 111')
self.on_next_handle((None, mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]), self.on_next_handle((None, mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]),
0) 0)
print('AudioInferenceHandler is_all_silence 222')
index = index + 1 index = index + 1
print('AudioInferenceHandler all silence')
else: else:
logger.info('infer=======') logger.info('infer=======')
t = time.perf_counter() t = time.perf_counter()
@ -152,7 +146,6 @@ class AudioInferenceHandler(AudioHandler):
index = index + 1 index = index + 1
logger.info(f'total batch time: {time.perf_counter() - start_time}') logger.info(f'total batch time: {time.perf_counter() - start_time}')
else: else:
print('AudioInferenceHandler mel_batch:333')
time.sleep(1) time.sleep(1)
break break
logger.info('AudioInferenceHandler inference processor stop') logger.info('AudioInferenceHandler inference processor stop')
@ -162,9 +155,9 @@ class AudioInferenceHandler(AudioHandler):
self._is_running = False self._is_running = False
self._exit_event.clear() self._exit_event.clear()
if self._run_thread.is_alive(): if self._run_thread.is_alive():
print('AudioInferenceHandler stop join') logger.info('AudioInferenceHandler stop join')
self._run_thread.join() self._run_thread.join()
print('AudioInferenceHandler stop exit') logger.info('AudioInferenceHandler stop exit')
def pause_talk(self): def pause_talk(self):
print('AudioInferenceHandler pause_talk', self._audio_queue.size(), self._mal_queue.size()) print('AudioInferenceHandler pause_talk', self._audio_queue.size(), self._mal_queue.size())

View File

@ -115,11 +115,6 @@ class HumanContext:
def stop(self): def stop(self):
EventBus().post('stop') EventBus().post('stop')
object_stop(self._tts)
object_stop(self._tts_handle)
object_stop(self._mal_handler)
object_stop(self._infer_handler)
object_stop(self._render_handler)
def pause_talk(self): def pause_talk(self):
self._nlp.pause_talk() self._nlp.pause_talk()

View File

@ -53,6 +53,7 @@ class HumanRender(AudioHandler):
if value is None: if value is None:
return return
res_frame, idx, audio_frames = value res_frame, idx, audio_frames = value
print('voice render queue size', self._queue.size())
if not self._empty_log: if not self._empty_log:
self._empty_log = True self._empty_log = True
logging.info('render render:') logging.info('render render:')
@ -84,25 +85,10 @@ class HumanRender(AudioHandler):
super().on_message(message) super().on_message(message)
def on_handle(self, stream, index): def on_handle(self, stream, index):
print('human render:', self._is_running)
if not self._is_running: if not self._is_running:
return return
self._queue.put(stream) self._queue.put(stream)
# res_frame, idx, audio_frames = stream
# self._voice_render.put(audio_frames, self._last_audio_ps)
# self._last_audio_ps = self._last_audio_ps + 0.4
# type_ = 1
# if audio_frames[0][1] != 0 and audio_frames[1][1] != 0:
# type_ = 0
# self._video_render.put((res_frame, idx, type_), self._last_video_ps)
# self._last_video_ps = self._last_video_ps + 0.4
#
# if self._voice_render.is_full():
# self._context.notify({'msg_id': MessageType.Video_Render_Queue_Full})
# def get_audio_queue_size(self):
# return self._voice_render.size()
def pause_talk(self): def pause_talk(self):
logging.info('hunan pause_talk') logging.info('hunan pause_talk')

View File

@ -14,7 +14,7 @@ class NLPBase(AsrObserver):
self._context = context self._context = context
self._split_handle = split self._split_handle = split
self._callback = callback self._callback = callback
self._is_running = False self._is_running = True
EventBus().register('stop', self.on_stop) EventBus().register('stop', self.on_stop)
@ -46,6 +46,8 @@ class NLPBase(AsrObserver):
pass pass
def completed(self, message: str): def completed(self, message: str):
if not self._is_running:
return
logger.info(f'complete:{message}') logger.info(f'complete:{message}')
# self._context.pause_talk() # self._context.pause_talk()
self.ask(message) self.ask(message)

View File

@ -40,17 +40,16 @@ class DouBao(NLPBase):
], ],
stream=True stream=True
) )
t1 = time.time()
await stream.close()
logger.info(f'-------dou_bao close time:{time.time() - t1:.4f}s')
return
sec = '' sec = ''
async for completion in stream: async for completion in stream:
sec = sec + completion.choices[0].delta.content sec = sec + completion.choices[0].delta.content
sec, message = self._split_handle.handle(sec) sec, message = self._split_handle.handle(sec)
if len(message) > 0: if len(message) > 0:
logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
self._on_callback(message) self._on_callback(message)
self._on_callback(sec) self._on_callback(sec)
logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
await stream.close() await stream.close()
# sec = "你是测试客服,是由字节跳动开发的 AI 人工智能助手" # sec = "你是测试客服,是由字节跳动开发的 AI 人工智能助手"

View File

@ -3,6 +3,7 @@
import logging import logging
import time import time
from eventbus import EventBus
from nlp import NLPCallback from nlp import NLPCallback
from utils import AsyncTaskQueue from utils import AsyncTaskQueue
@ -13,7 +14,14 @@ class TTSBase(NLPCallback):
def __init__(self, handle): def __init__(self, handle):
self._handle = handle self._handle = handle
self._message_queue = AsyncTaskQueue('TTSBaseQueue', 5) self._message_queue = AsyncTaskQueue('TTSBaseQueue', 5)
self._is_running = False self._is_running = True
EventBus().register('stop', self.on_stop)
def __del__(self):
EventBus().unregister('stop', self.on_stop)
def on_stop(self, *args, **kwargs):
self.stop()
@property @property
def handle(self): def handle(self):
@ -49,7 +57,6 @@ class TTSBase(NLPCallback):
self.message(txt) self.message(txt)
def message(self, txt): def message(self, txt):
self._is_running = True
txt = txt.strip() txt = txt.strip()
if len(txt) == 0: if len(txt) == 0:
logger.info(f'message is empty') logger.info(f'message is empty')
@ -62,10 +69,10 @@ class TTSBase(NLPCallback):
self._message_queue.add_task(self._request, txt, index) self._message_queue.add_task(self._request, txt, index)
def stop(self): def stop(self):
self._is_running = False
self._message_queue.add_task(self._on_close) self._message_queue.add_task(self._on_close)
self._message_queue.stop() self._message_queue.stop()
def pause_talk(self): def pause_talk(self):
logger.info(f'TTSBase pause_talk') logger.info(f'TTSBase pause_talk')
self._is_running = False
self._message_queue.clear() self._message_queue.clear()

View File

@ -1,12 +1,16 @@
#encoding = utf8 #encoding = utf8
import asyncio import asyncio
import logging
from queue import Queue
import threading import threading
logger = logging.getLogger(__name__)
class AsyncTaskQueue: class AsyncTaskQueue:
def __init__(self, name, work_num=1): def __init__(self, name, work_num=1):
self._queue = asyncio.Queue() self._queue = Queue()
self._worker_num = work_num self._worker_num = work_num
self._current_worker_num = work_num self._current_worker_num = work_num
self._name = name self._name = name
@ -15,39 +19,43 @@ class AsyncTaskQueue:
self.__loop = None self.__loop = None
def _run_loop(self): def _run_loop(self):
print(self._name, '_run_loop') logging.info(f'{self._name}, _run_loop')
self.__loop = asyncio.new_event_loop() self.__loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.__loop) asyncio.set_event_loop(self.__loop)
self._tasks = [self.__loop.create_task(self._worker()) for _ in range(self._worker_num)] self._tasks = [self.__loop.create_task(self._worker()) for _ in range(self._worker_num)]
self.__loop.run_forever() try:
print(self._name, "exit run") self.__loop.run_forever()
if not self.__loop.is_closed(): finally:
logging.info(f'{self._name}, exit run')
self.__loop.run_until_complete(asyncio.gather(*asyncio.all_tasks(self.__loop)))
self.__loop.close() self.__loop.close()
logging.info(f'{self._name}, close loop')
async def _worker(self): async def _worker(self):
print(self._name, '_worker') logging.info(f'{self._name}, _worker')
while True: while True:
print(f'{self._name} get queue') try:
task = await self._queue.get() task = self._queue.get()
print(f'{self._name} get queue11') if task is None: # None as a stop signal
print(f"{self._name} Get task size: {self._queue.qsize()}") break
if task is None: # None as a stop signal
break
func, *args = task # Unpack task func, *args = task # Unpack task
print(f"{self._name}, Executing task with args: {args}") if func is None: # None as a stop signal
await func(*args) # Execute async function break
self._queue.task_done()
print(self._name, '_worker finish') await func(*args) # Execute async function
except Exception as e:
logging.error(f'{self._name} error:', e)
finally:
self._queue.task_done()
logging.info(f'{self._name}, _worker finish')
self._current_worker_num -= 1 self._current_worker_num -= 1
if self._current_worker_num == 0: if self._current_worker_num == 0:
print(self._name, 'loop stop') self.__loop.call_soon_threadsafe(self.__loop.stop)
self.__loop.stop()
def add_task(self, func, *args): def add_task(self, func, *args):
# return self.__loop.call_soon_threadsafe(self._queue.put_nowait, (func, *args)) self._queue.put((func, *args))
self._queue.put_nowait((func, *args))
def stop_workers(self): def stop_workers(self):
for _ in range(self._worker_num): for _ in range(self._worker_num):
@ -60,4 +68,4 @@ class AsyncTaskQueue:
def stop(self): def stop(self):
self.stop_workers() self.stop_workers()
self._thread.join() self._thread.join()