Merge branch 'render_sync' into develop

This commit is contained in:
brige 2024-11-08 09:33:41 +08:00
commit d9f55d1ba1
23 changed files with 307 additions and 124 deletions

View File

@ -2,6 +2,7 @@
import threading
from eventbus import EventBus
from .asr_observer import AsrObserver
@ -12,11 +13,19 @@ class AsrBase:
self._samples_per_read = 100
self._observers = []
EventBus().register('stop', self._on_stop)
self._stop_event = threading.Event()
self._stop_event.set()
self._thread = threading.Thread(target=self._recognize_loop)
self._thread.start()
def __del__(self):
EventBus().unregister('stop', self._on_stop)
def _on_stop(self, *args, **kwargs):
self.stop()
def _recognize_loop(self):
pass

View File

@ -27,6 +27,7 @@ class SherpaNcnnAsr(AsrBase):
super().__init__()
self._recognizer = self._create_recognizer()
logger.info('SherpaNcnnAsr init')
print('SherpaNcnnAsr init')
def __del__(self):
self.__del__()
@ -60,17 +61,10 @@ class SherpaNcnnAsr(AsrBase):
time.sleep(3)
last_result = ""
logger.info(f'_recognize_loop')
while self._stop_event.is_set():
logger.info(f'_recognize_loop000')
self._notify_complete('介绍中国5000年历史文学')
logger.info(f'_recognize_loop111')
segment_id += 1
time.sleep(50)
logger.info(f'_recognize_loop222')
logger.info(f'_recognize_loop exit')
'''
print(f'_recognize_loop')
with sd.InputStream(channels=1, dtype="float32", samplerate=self._sample_rate) as s:
while not self._stop_event.is_set():
while self._stop_event.is_set():
samples, _ = s.read(self._samples_per_read) # a blocking read
samples = samples.reshape(-1)
self._recognizer.accept_waveform(self._sample_rate, samples)
@ -89,4 +83,13 @@ class SherpaNcnnAsr(AsrBase):
self._notify_complete(result)
segment_id += 1
self._recognizer.reset()
'''
while self._stop_event.is_set():
logger.info(f'_recognize_loop000')
self._notify_complete('介绍中国5000年历史文学')
logger.info(f'_recognize_loop111')
segment_id += 1
time.sleep(60)
logger.info(f'_recognize_loop222')
logger.info(f'_recognize_loop exit')
'''

Binary file not shown.

After

Width:  |  Height:  |  Size: 61 KiB

4
eventbus/__init__.py Normal file
View File

@ -0,0 +1,4 @@
#encoding = utf8
from .event_bus import EventBus

39
eventbus/event_bus.py Normal file
View File

@ -0,0 +1,39 @@
#encoding = utf8
import threading
class EventBus:
_instance = None
_lock = threading.Lock()
def __new__(cls, *args, **kwargs):
if not cls._instance:
with cls._lock:
if not cls._instance:
cls._instance = super(EventBus, cls).__new__(cls, *args, **kwargs)
return cls._instance
def __init__(self):
if not hasattr(self, '_initialized'):
self._listeners = {}
self._lock = threading.Lock()
self._initialized = True
def register(self, event_type, listener):
with self._lock:
if event_type not in self._listeners:
self._listeners[event_type] = []
self._listeners[event_type].append(listener)
def unregister(self, event_type, listener):
with self._lock:
if event_type in self._listeners:
self._listeners[event_type].remove(listener)
if not self._listeners[event_type]:
del self._listeners[event_type]
def post(self, event_type, *args, **kwargs):
with self._lock:
listeners = self._listeners.get(event_type, []).copy()
for listener in listeners:
listener(*args, **kwargs)

BIN
face/img00016.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 452 KiB

BIN
face/img00016.jpg.bak1 Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 452 KiB

BIN
face/img00020.png.bak Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 258 KiB

View File

@ -9,6 +9,7 @@ from threading import Event, Thread
import numpy as np
import torch
from eventbus import EventBus
from human_handler import AudioHandler
from utils import load_model, mirror_index, get_device, SyncQueue
@ -20,16 +21,28 @@ class AudioInferenceHandler(AudioHandler):
def __init__(self, context, handler):
super().__init__(context, handler)
EventBus().register('stop', self._on_stop)
self._mal_queue = SyncQueue(1, "AudioInferenceHandler_Mel")
self._audio_queue = SyncQueue(context.batch_size * 2, "AudioInferenceHandler_Audio")
self._is_running = True
self._exit_event = Event()
self._run_thread = Thread(target=self.__on_run, name="AudioInferenceHandlerThread")
self._exit_event.set()
self._run_thread.start()
logger.info("AudioInferenceHandler init")
def __del__(self):
EventBus().unregister('stop', self._on_stop)
def _on_stop(self, *args, **kwargs):
self.stop()
def on_handle(self, stream, type_):
if not self._is_running:
return
if type_ == 1:
self._mal_queue.put(stream)
elif type_ == 0:
@ -56,14 +69,13 @@ class AudioInferenceHandler(AudioHandler):
device = get_device()
logger.info(f'use device:{device}')
while True:
while self._is_running:
if self._exit_event.is_set():
start_time = time.perf_counter()
batch_size = self._context.batch_size
try:
mel_batch = self._mal_queue.get()
size = self._audio_queue.size()
# print('AudioInferenceHandler mel_batch:', len(mel_batch), 'size:', size)
mel_batch = self._mal_queue.get(timeout=0.02)
# print('AudioInferenceHandler mel_batch:', len(mel_batch), 'size:', self._mal_queue.size())
except queue.Empty:
continue
@ -76,8 +88,15 @@ class AudioInferenceHandler(AudioHandler):
audio_frames.append((frame, type_))
if type_ == 0:
is_all_silence = False
if not self._is_running:
print('AudioInferenceHandler not running')
break
if is_all_silence:
for i in range(batch_size):
if not self._is_running:
break
self.on_next_handle((None, mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]),
0)
index = index + 1
@ -118,6 +137,8 @@ class AudioInferenceHandler(AudioHandler):
count_time = 0
for i, res_frame in enumerate(pred):
if not self._is_running:
break
self.on_next_handle(
(res_frame, mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]),
0)
@ -129,10 +150,17 @@ class AudioInferenceHandler(AudioHandler):
logger.info('AudioInferenceHandler inference processor stop')
def stop(self):
logger.info('AudioInferenceHandler stop')
self._is_running = False
self._exit_event.clear()
self._run_thread.join()
if self._run_thread.is_alive():
logger.info('AudioInferenceHandler stop join')
self._run_thread.join()
logger.info('AudioInferenceHandler stop exit')
def pause_talk(self):
print('AudioInferenceHandler pause_talk', self._audio_queue.size(), self._mal_queue.size())
self._audio_queue.clear()
print('AudioInferenceHandler111')
self._mal_queue.clear()
print('AudioInferenceHandler222')

View File

@ -3,7 +3,8 @@ import logging
import queue
import time
from threading import Thread, Event, Condition
from threading import Thread, Event
from eventbus import EventBus
import numpy as np
@ -17,16 +18,26 @@ class AudioMalHandler(AudioHandler):
def __init__(self, context, handler):
super().__init__(context, handler)
self._queue = SyncQueue(context.batch_size, "AudioMalHandler_queue")
self._exit_event = Event()
self._thread = Thread(target=self._on_run, name="AudioMalHandlerThread")
self._exit_event.set()
self._thread.start()
EventBus().register('stop', self._on_stop)
self._is_running = True
self._queue = SyncQueue(context.batch_size * 2, "AudioMalHandler_queue")
self.frames = []
self.chunk = context.sample_rate // context.fps
self._exit_event = Event()
self._exit_event.set()
self._thread = Thread(target=self._on_run, name="AudioMalHandlerThread")
self._thread.start()
logger.info("AudioMalHandler init")
def __del__(self):
EventBus().unregister('stop', self._on_stop)
def _on_stop(self, *args, **kwargs):
self.stop()
def on_message(self, message):
super().on_message(message)
@ -36,7 +47,7 @@ class AudioMalHandler(AudioHandler):
def _on_run(self):
logging.info('chunk2mal run')
while self._exit_event.is_set():
while self._exit_event.is_set() and self._is_running:
self._run_step()
time.sleep(0.02)
@ -49,6 +60,9 @@ class AudioMalHandler(AudioHandler):
self.frames.append(frame)
self.on_next_handle((frame, _type), 0)
count = count + 1
if self._is_running is False:
return
# context not enough, do not run network.
if len(self.frames) <= self._context.stride_left_size + self._context.stride_right_size:
return
@ -64,7 +78,8 @@ class AudioMalHandler(AudioHandler):
mel_step_size = 16
i = 0
mel_chunks = []
while i < (len(self.frames) - self._context.stride_left_size - self._context.stride_right_size) / 2:
while i < (len(self.frames) - self._context.stride_left_size - self._context.stride_right_size) / 2\
and self._is_running:
start_idx = int(left + i * mel_idx_multiplier)
# print(start_idx)
if start_idx + mel_step_size > len(mel[0]):
@ -78,11 +93,10 @@ class AudioMalHandler(AudioHandler):
self.frames = self.frames[-(self._context.stride_left_size + self._context.stride_right_size):]
def get_audio_frame(self):
try:
# print('AudioMalHandler get_audio_frame')
if not self._queue.is_empty():
frame = self._queue.get()
type_ = 0
except queue.Empty:
else:
frame = np.zeros(self.chunk, dtype=np.float32)
type_ = 1
# print('AudioMalHandler get_audio_frame type:', type_)
@ -90,6 +104,7 @@ class AudioMalHandler(AudioHandler):
def stop(self):
logging.info('stop')
self._is_running = False
if self._exit_event is None:
return

View File

@ -3,6 +3,7 @@ import logging
import os
from asr import SherpaNcnnAsr
from eventbus import EventBus
from .audio_inference_handler import AudioInferenceHandler
from .audio_mal_handler import AudioMalHandler
from .human_render import HumanRender
@ -113,13 +114,7 @@ class HumanContext:
self._asr.attach(self._nlp)
def stop(self):
object_stop(self._asr)
object_stop(self._nlp)
object_stop(self._tts)
object_stop(self._tts_handle)
object_stop(self._mal_handler)
object_stop(self._infer_handler)
object_stop(self._render_handler)
EventBus().post('stop')
def pause_talk(self):
self._nlp.pause_talk()

View File

@ -5,6 +5,7 @@ import time
from queue import Empty
from threading import Event, Thread
from eventbus import EventBus
from human.message_type import MessageType
from human_handler import AudioHandler
from render import VoiceRender, VideoRender, PlayClock
@ -17,9 +18,11 @@ class HumanRender(AudioHandler):
def __init__(self, context, handler):
super().__init__(context, handler)
EventBus().register('stop', self._on_stop)
play_clock = PlayClock()
self._voice_render = VoiceRender(play_clock, context)
self._video_render = VideoRender(play_clock, context, self)
self._is_running = True
self._queue = SyncQueue(context.batch_size, "HumanRender_queue")
self._exit_event = Event()
self._thread = Thread(target=self._on_run, name="AudioMalHandlerThread")
@ -28,24 +31,43 @@ class HumanRender(AudioHandler):
self._image_render = None
self._last_audio_ps = 0
self._last_video_ps = 0
self._empty_log = True
def __del__(self):
EventBus().unregister('stop', self._on_stop)
def _on_stop(self, *args, **kwargs):
self.stop()
def _on_run(self):
logging.info('human render run')
while self._exit_event.is_set():
while self._exit_event.is_set() and self._is_running:
# t = time.time()
self._run_step()
time.sleep(0.038)
# delay = time.time() - t
delay = 0.03805 # - delay
# print(delay)
# if delay <= 0.0:
# continue
time.sleep(delay)
logging.info('human render exit')
def _run_step(self):
try:
value = self._queue.get()
value = self._queue.get(timeout=.005)
if value is None:
return
res_frame, idx, audio_frames = value
# print('render queue size', self._queue.size())
if not self._empty_log:
self._empty_log = True
logging.info('render render:')
# print('voice render queue size', self._queue.size())
except Empty:
print('render queue.Empty:')
if self._empty_log:
self._empty_log = False
logging.info('render queue.Empty:')
return
type_ = 1
@ -69,32 +91,23 @@ class HumanRender(AudioHandler):
super().on_message(message)
def on_handle(self, stream, index):
self._queue.put(stream)
# res_frame, idx, audio_frames = stream
# self._voice_render.put(audio_frames, self._last_audio_ps)
# self._last_audio_ps = self._last_audio_ps + 0.4
# type_ = 1
# if audio_frames[0][1] != 0 and audio_frames[1][1] != 0:
# type_ = 0
# self._video_render.put((res_frame, idx, type_), self._last_video_ps)
# self._last_video_ps = self._last_video_ps + 0.4
#
# if self._voice_render.is_full():
# self._context.notify({'msg_id': MessageType.Video_Render_Queue_Full})
if not self._is_running:
return
# def get_audio_queue_size(self):
# return self._voice_render.size()
self._queue.put(stream)
def pause_talk(self):
pass
logging.info('hunan pause_talk')
# self._voice_render.pause_talk()
# self._video_render.pause_talk()
def stop(self):
logging.info('hunan render stop')
self._is_running = False
if self._exit_event is None:
return
self._queue.clear()
self._exit_event.clear()
if self._thread.is_alive():
self._thread.join()

View File

@ -2,6 +2,7 @@
import logging
from asr import AsrObserver
from eventbus import EventBus
from utils import AsyncTaskQueue
logger = logging.getLogger(__name__)
@ -9,11 +10,19 @@ logger = logging.getLogger(__name__)
class NLPBase(AsrObserver):
def __init__(self, context, split, callback=None):
self._ask_queue = AsyncTaskQueue()
self._ask_queue = AsyncTaskQueue('NLPBaseQueue')
self._context = context
self._split_handle = split
self._callback = callback
self._is_running = False
self._is_running = True
EventBus().register('stop', self.on_stop)
def __del__(self):
EventBus().unregister('stop', self.on_stop)
def on_stop(self, *args, **kwargs):
self.stop()
@property
def callback(self):
@ -37,6 +46,8 @@ class NLPBase(AsrObserver):
pass
def completed(self, message: str):
if not self._is_running:
return
logger.info(f'complete:{message}')
# self._context.pause_talk()
self.ask(message)
@ -50,8 +61,11 @@ class NLPBase(AsrObserver):
def stop(self):
logger.info('NLPBase stop')
self._is_running = False
self._ask_queue.clear()
self._ask_queue.add_task(self._on_close)
logger.info('NLPBase add close')
self._ask_queue.stop()
logger.info('NLPBase _ask_queue stop')
def pause_talk(self):
logger.info('NLPBase pause_talk')

View File

@ -30,7 +30,6 @@ class DouBao(NLPBase):
async def _request(self, question):
t = time.time()
logger.info(f'_request:{question}')
logger.info(f'-------dou_bao ask:{question}')
try:
stream = await self.__client.chat.completions.create(
@ -41,13 +40,16 @@ class DouBao(NLPBase):
],
stream=True
)
sec = ''
async for completion in stream:
sec = sec + completion.choices[0].delta.content
sec, message = self._split_handle.handle(sec)
if len(message) > 0:
logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
self._on_callback(message)
self._on_callback(sec)
logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
await stream.close()
# sec = "你是测试客服,是由字节跳动开发的 AI 人工智能助手"
@ -59,7 +61,6 @@ class DouBao(NLPBase):
# self._on_callback(sec)
except Exception as e:
print(e)
logger.info(f'_request:{question}, time:{time.time() - t:.4f}s')
logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
async def _on_close(self):

View File

@ -15,28 +15,9 @@ class VideoRender(BaseRender):
def __init__(self, play_clock, context, human_render):
super().__init__(play_clock, context, 'Video')
self._human_render = human_render
self._diff_avg_count = 0
def render(self, frame, ps):
res_frame, idx, type_ = frame
clock_time = self._play_clock.clock_time()
time_difference = clock_time - ps
if abs(time_difference) > self._play_clock.audio_diff_threshold:
if self._diff_avg_count < 5:
self._diff_avg_count += 1
else:
if time_difference < -self._play_clock.audio_diff_threshold:
sleep_time = abs(time_difference)
print("Video frame waiting to catch up with audio", sleep_time)
if sleep_time <= 1.0:
time.sleep(sleep_time)
# elif time_difference > self._play_clock.audio_diff_threshold: # 视频比音频快超过10ms
# print("Video frame dropped to catch up with audio")
# continue
else:
self._diff_avg_count = 0
if type_ == 0:
combine_frame = self._context.frame_list_cycle[idx]

View File

@ -6,6 +6,7 @@ import numpy as np
import requests
import resampy
import soundfile as sf
from io import BytesIO
def download_tts(url):
@ -38,16 +39,15 @@ def __create_bytes_stream(byte_stream):
return stream
def main():
def test_async_tts(url, content):
import aiohttp
import asyncio
from io import BytesIO
async def fetch_audio():
url = "http://localhost:8082/v1/audio/speech"
data = {
"model": "tts-1",
"input": "写了一个高性能tts(文本转声音)工具,5千字仅需5秒,免费使用",
"input": content,
"voice": "alloy",
"speed": 1.0
}
@ -68,6 +68,33 @@ def main():
asyncio.run(fetch_audio())
def test_sync_tts(url, content):
data = {
"model": "tts-1",
"input": content,
"voice": "alloy",
"speed": 1.0
}
response = requests.post(url, json=data)
if response.status_code == 200:
audio_data = BytesIO(response.content)
audio_stream = __create_bytes_stream(audio_data)
# 保存为新的音频文件
sf.write("output_audio.wav", audio_stream, 16000)
print("Audio data received and saved to output_audio.wav")
else:
print("Error:", response.status_code, response.text)
def main():
# url = "http://localhost:8082/v1/audio/speech"
url = "https://tts.mzzsfy.eu.org/v1/audio/speech"
content = "写了一个高性能tts(文本转声音)工具,5千字仅需5秒,免费使用"
# test_async_tts(url, content)
test_sync_tts(url, content)
if __name__ == "__main__":
try:
t = time.time()

View File

@ -4,6 +4,7 @@ import logging
import os
import shutil
from eventbus import EventBus
from utils import save_wav
from human_handler import AudioHandler
@ -16,6 +17,14 @@ class TTSAudioHandle(AudioHandler):
self._sample_rate = 16000
self._index = -1
EventBus().register('stop', self._on_stop)
def __del__(self):
EventBus().unregister('stop', self._on_stop)
def _on_stop(self, *args, **kwargs):
self.stop()
@property
def sample_rate(self):
return self._sample_rate
@ -45,9 +54,13 @@ class TTSAudioSplitHandle(TTSAudioHandle):
self._chunk = self.sample_rate // self._context.fps
self._priority_queue = []
self._current = 0
self._is_running = True
logger.info("TTSAudioSplitHandle init")
def on_handle(self, stream, index):
if not self._is_running:
logger.info('TTSAudioSplitHandle::on_handle is not running')
return
# heapq.heappush(self._priority_queue, (index, stream))
if stream is None:
heapq.heappush(self._priority_queue, (index, None))
@ -55,7 +68,7 @@ class TTSAudioSplitHandle(TTSAudioHandle):
stream_len = stream.shape[0]
idx = 0
chunks = []
while stream_len >= self._chunk:
while stream_len >= self._chunk and self._is_running:
# self.on_next_handle(stream[idx:idx + self._chunk], 0)
chunks.append(stream[idx:idx + self._chunk])
stream_len -= self._chunk
@ -63,7 +76,7 @@ class TTSAudioSplitHandle(TTSAudioHandle):
heapq.heappush(self._priority_queue, (index, chunks))
current = self._priority_queue[0][0]
print('TTSAudioSplitHandle::on_handle', index, current, self._current)
if current == self._current:
if current == self._current and self._is_running:
self._current = self._current + 1
chunks = heapq.heappop(self._priority_queue)[1]
if chunks is None:
@ -73,7 +86,7 @@ class TTSAudioSplitHandle(TTSAudioHandle):
self.on_next_handle(chunk, 0)
def stop(self):
pass
self._is_running = False
class TTSAudioSaveHandle(TTSAudioHandle):

View File

@ -3,6 +3,7 @@
import logging
import time
from eventbus import EventBus
from nlp import NLPCallback
from utils import AsyncTaskQueue
@ -12,8 +13,15 @@ logger = logging.getLogger(__name__)
class TTSBase(NLPCallback):
def __init__(self, handle):
self._handle = handle
self._message_queue = AsyncTaskQueue(5)
self._is_running = False
self._message_queue = AsyncTaskQueue('TTSBaseQueue', 5)
self._is_running = True
EventBus().register('stop', self.on_stop)
def __del__(self):
EventBus().unregister('stop', self.on_stop)
def on_stop(self, *args, **kwargs):
self.stop()
@property
def handle(self):
@ -24,13 +32,13 @@ class TTSBase(NLPCallback):
self._handle = value
async def _request(self, txt: str, index):
print('_request:', txt)
# print('_request:', txt)
t = time.time()
stream = await self._on_request(txt)
if stream is None:
print(f'-------stream is None')
logger.warn(f'-------stream is None')
return
print(f'-------tts time:{time.time() - t:.4f}s')
logger.info(f'-------tts time:{time.time() - t:.4f}s, txt:{txt}')
if self._handle is not None and self._is_running:
await self._on_handle(stream, index)
else:
@ -49,23 +57,22 @@ class TTSBase(NLPCallback):
self.message(txt)
def message(self, txt):
self._is_running = True
txt = txt.strip()
if len(txt) == 0:
logger.info(f'message is empty')
# logger.info(f'message is empty')
return
logger.info(f'message:{txt}')
index = 0
if self._handle is not None:
index = self._handle.get_index()
print(f'message txt-index:{txt}, index {index}')
# print(f'message txt-index:{txt}, index {index}')
self._message_queue.add_task(self._request, txt, index)
def stop(self):
self._is_running = False
self._message_queue.add_task(self._on_close)
self._message_queue.stop()
def pause_talk(self):
logger.info(f'TTSBase pause_talk')
self._is_running = False
self._message_queue.clear()

View File

@ -4,6 +4,7 @@ from io import BytesIO
import aiohttp
import numpy as np
import requests
import soundfile as sf
import edge_tts
import resampy
@ -21,15 +22,7 @@ class TTSEdgeHttp(TTSBase):
self._url = 'https://tts.mzzsfy.eu.org/v1/audio/speech'
logger.info(f"TTSEdge init, {voice}")
async def _on_request(self, txt: str):
print('TTSEdgeHttp, _on_request, txt:', txt)
data = {
"model": "tts-1",
"input": txt,
"voice": "alloy",
"speed": 1.0,
"thread": 10
}
async def _on_async_request(self, data):
async with aiohttp.ClientSession() as session:
async with session.post(self._url, json=data) as response:
print('TTSEdgeHttp, _on_request, response:', response)
@ -38,7 +31,28 @@ class TTSEdgeHttp(TTSBase):
return stream
else:
byte_stream = None
return byte_stream
return byte_stream, None
def _on_sync_request(self, data):
response = requests.post(self._url, json=data)
if response.status_code == 200:
stream = BytesIO(response.content)
return stream
else:
return None
async def _on_request(self, txt: str):
logger.info(f'TTSEdgeHttp, _on_request, txt:{txt}')
data = {
"model": "tts-1",
"input": txt,
"voice": "alloy",
"speed": 1.0,
"thread": 10
}
# return self._on_async_request(data)
return self._on_sync_request(data)
async def _on_handle(self, stream, index):
print('-------tts _on_handle')

View File

@ -55,12 +55,16 @@ class PyGameUI:
if self._queue.empty():
return
image = self._queue.get()
self._human_image = pygame.image.frombuffer(image.tobytes(), image.shape[1::-1], "RGB")
color_format = "RGB"
if 4 == image.shape[2]:
color_format = "RGBA"
self._human_image = pygame.image.frombuffer(image.tobytes(), image.shape[1::-1], color_format)
def stop(self):
logger.info('stop')
if self._human_context is not None:
self._human_context.pause_talk()
# self._human_context.pause_talk()
self._human_context.stop()
def on_render(self, image):

View File

@ -1,48 +1,61 @@
#encoding = utf8
import asyncio
import logging
from queue import Queue
import threading
logger = logging.getLogger(__name__)
class AsyncTaskQueue:
def __init__(self, work_num=1):
self._queue = asyncio.Queue()
def __init__(self, name, work_num=1):
self._queue = Queue()
self._worker_num = work_num
self._current_worker_num = work_num
self._thread = threading.Thread(target=self._run_loop)
self._name = name
self._thread = threading.Thread(target=self._run_loop, name=name)
self._thread.start()
self.__loop = None
def _run_loop(self):
print('_run_loop')
logging.info(f'{self._name}, _run_loop')
self.__loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.__loop)
self._tasks = [self.__loop.create_task(self._worker()) for _ in range(self._worker_num)]
self.__loop.run_forever()
print("exit run")
if not self.__loop.is_closed():
try:
self.__loop.run_forever()
finally:
logging.info(f'{self._name}, exit run')
self.__loop.run_until_complete(asyncio.gather(*asyncio.all_tasks(self.__loop)))
self.__loop.close()
logging.info(f'{self._name}, close loop')
async def _worker(self):
print('_worker')
logging.info(f'{self._name}, _worker')
while True:
task = await self._queue.get()
if task is None: # None as a stop signal
break
try:
task = self._queue.get()
if task is None: # None as a stop signal
break
func, *args = task # Unpack task
print(f"Executing task with args: {args}")
await func(*args) # Execute async function
self._queue.task_done()
func, *args = task # Unpack task
if func is None: # None as a stop signal
break
print('_worker finish')
await func(*args) # Execute async function
except Exception as e:
logging.error(f'{self._name} error: {e}')
finally:
self._queue.task_done()
logging.info(f'{self._name}, _worker finish')
self._current_worker_num -= 1
if self._current_worker_num == 0:
print('loop stop')
self.__loop.stop()
self.__loop.call_soon_threadsafe(self.__loop.stop)
def add_task(self, func, *args):
self.__loop.call_soon_threadsafe(self._queue.put_nowait, (func, *args))
self._queue.put((func, *args))
def stop_workers(self):
for _ in range(self._worker_num):
@ -55,4 +68,4 @@ class AsyncTaskQueue:
def stop(self):
self.stop_workers()
self._thread.join()
self._thread.join()

View File

@ -10,6 +10,9 @@ class SyncQueue:
self._queue = Queue(maxsize)
self._condition = threading.Condition()
def is_empty(self):
return self._queue.empty()
def put(self, item):
with self._condition:
while self._queue.full():

View File

@ -194,7 +194,7 @@ def config_logging(file_name: str, console_level: int = logging.INFO, file_level
console_handler = logging.StreamHandler()
console_handler.setFormatter(logging.Formatter(
'[%(asctime)s %(levelname)s] %(message)s',
'[%(asctime)s.%(msecs)03d %(levelname)s] %(message)s',
datefmt="%Y/%m/%d %H:%M:%S"
))
console_handler.setLevel(console_level)