Compare commits

...

13 Commits

Author SHA1 Message Date
c4eb191e18 modify tts handle 2024-11-16 14:05:13 +08:00
4e98ee0e76 modify remove asyncio 2024-11-16 10:13:11 +08:00
aa554b1209 modify mutil-thread 2024-11-15 21:23:34 +08:00
7eccc99c2a modify handler 2024-11-15 01:09:47 +08:00
88a307ed6a modify tts logs 2024-11-14 19:14:22 +08:00
742340971b add txt render 2024-11-13 19:29:40 +08:00
b3bbf40d95 add human txt 2024-11-13 12:58:56 +08:00
d08a74b4e4 modify tts and txt 2024-11-12 23:07:55 +08:00
7cbe6d073b modify tts audio handle 2024-11-11 19:07:01 +08:00
dde41769bf modify abort tts hangle 2024-11-10 14:06:47 +08:00
a6153b1eb9 modify abort 2024-11-09 21:00:22 +08:00
3334918ed1 modif clean cacah 2024-11-09 07:39:03 +08:00
92b912e162 modfiy abort 2024-11-08 19:49:53 +08:00
12 changed files with 217 additions and 165 deletions

View File

@ -17,7 +17,7 @@ class AsrBase:
self._stop_event = threading.Event()
self._stop_event.set()
self._thread = threading.Thread(target=self._recognize_loop)
self._thread = threading.Thread(target=self._recognize_loop, name="AsrBaseThread")
self._thread.start()
def __del__(self):
@ -34,6 +34,7 @@ class AsrBase:
observer.process(message)
def _notify_complete(self, message: str):
EventBus().post('clear_cache')
for observer in self._observers:
observer.completed(message)

View File

@ -63,6 +63,15 @@ class SherpaNcnnAsr(AsrBase):
logger.info(f'_recognize_loop')
print(f'_recognize_loop')
while self._stop_event.is_set():
logger.info(f'_recognize_loop000')
self._notify_complete('介绍中国5000年历史文学')
logger.info(f'_recognize_loop111')
segment_id += 1
time.sleep(150)
logger.info(f'_recognize_loop222')
logger.info(f'_recognize_loop exit')
'''
with sd.InputStream(channels=1, dtype="float32", samplerate=self._sample_rate) as s:
while self._stop_event.is_set():
samples, _ = s.read(self._samples_per_read) # a blocking read
@ -84,12 +93,4 @@ class SherpaNcnnAsr(AsrBase):
segment_id += 1
self._recognizer.reset()
'''
while self._stop_event.is_set():
logger.info(f'_recognize_loop000')
self._notify_complete('介绍中国5000年历史文学')
logger.info(f'_recognize_loop111')
segment_id += 1
time.sleep(60)
logger.info(f'_recognize_loop222')
logger.info(f'_recognize_loop exit')
'''

View File

@ -22,6 +22,7 @@ class AudioInferenceHandler(AudioHandler):
super().__init__(context, handler)
EventBus().register('stop', self._on_stop)
EventBus().register('clear_cache', self.on_clear_cache)
self._mal_queue = SyncQueue(1, "AudioInferenceHandler_Mel")
self._audio_queue = SyncQueue(context.batch_size * 2, "AudioInferenceHandler_Audio")
@ -35,10 +36,15 @@ class AudioInferenceHandler(AudioHandler):
def __del__(self):
EventBus().unregister('stop', self._on_stop)
EventBus().unregister('clear_cache', self.on_clear_cache)
def _on_stop(self, *args, **kwargs):
self.stop()
def on_clear_cache(self, *args, **kwargs):
self._mal_queue.clear()
self._audio_queue.clear()
def on_handle(self, stream, type_):
if not self._is_running:
return
@ -82,9 +88,11 @@ class AudioInferenceHandler(AudioHandler):
# print('origin mel_batch:', len(mel_batch))
is_all_silence = True
audio_frames = []
current_text = ''
for _ in range(batch_size * 2):
frame, type_ = self._audio_queue.get()
# print('AudioInferenceHandler type_', type_)
current_text = frame[1]
audio_frames.append((frame, type_))
if type_ == 0:
is_all_silence = False
@ -101,7 +109,7 @@ class AudioInferenceHandler(AudioHandler):
0)
index = index + 1
else:
logger.info('infer=======')
logger.info(f'infer======= {current_text}')
t = time.perf_counter()
img_batch = []
# for i in range(batch_size):

View File

@ -19,6 +19,7 @@ class AudioMalHandler(AudioHandler):
super().__init__(context, handler)
EventBus().register('stop', self._on_stop)
EventBus().register('clear_cache', self.on_clear_cache)
self._is_running = True
self._queue = SyncQueue(context.batch_size * 2, "AudioMalHandler_queue")
@ -34,15 +35,20 @@ class AudioMalHandler(AudioHandler):
def __del__(self):
EventBus().unregister('stop', self._on_stop)
EventBus().unregister('clear_cache', self.on_clear_cache)
def _on_stop(self, *args, **kwargs):
self.stop()
def on_clear_cache(self, *args, **kwargs):
self.frames.clear()
self._queue.clear()
def on_message(self, message):
super().on_message(message)
def on_handle(self, stream, index):
# print('AudioMalHandler on_handle', index)
# logging.info(f'AudioMalHandler on_handle {index}')
self._queue.put(stream)
def _on_run(self):
@ -57,7 +63,8 @@ class AudioMalHandler(AudioHandler):
count = 0
for _ in range(self._context.batch_size * 2):
frame, _type = self.get_audio_frame()
self.frames.append(frame)
chunk, txt = frame
self.frames.append(chunk)
self.on_next_handle((frame, _type), 0)
count = count + 1
@ -97,9 +104,10 @@ class AudioMalHandler(AudioHandler):
frame = self._queue.get()
type_ = 0
else:
frame = np.zeros(self.chunk, dtype=np.float32)
chunk = np.zeros(self.chunk, dtype=np.float32)
frame = (chunk, '')
type_ = 1
# print('AudioMalHandler get_audio_frame type:', type_)
# logging.info(f'AudioMalHandler get_audio_frame type:{type_}')
return frame, type_
def stop(self):

View File

@ -19,6 +19,7 @@ class HumanRender(AudioHandler):
super().__init__(context, handler)
EventBus().register('stop', self._on_stop)
EventBus().register('clear_cache', self.on_clear_cache)
play_clock = PlayClock()
self._voice_render = VoiceRender(play_clock, context)
self._video_render = VideoRender(play_clock, context, self)
@ -35,17 +36,21 @@ class HumanRender(AudioHandler):
def __del__(self):
EventBus().unregister('stop', self._on_stop)
EventBus().unregister('clear_cache', self.on_clear_cache)
def _on_stop(self, *args, **kwargs):
self.stop()
def on_clear_cache(self, *args, **kwargs):
self._queue.clear()
def _on_run(self):
logging.info('human render run')
while self._exit_event.is_set() and self._is_running:
# t = time.time()
self._run_step()
# delay = time.time() - t
delay = 0.03805 # - delay
delay = 0.038 # - delay
# print(delay)
# if delay <= 0.0:
# continue
@ -116,52 +121,4 @@ class HumanRender(AudioHandler):
# self._video_render.stop()
# self._exit_event.clear()
# self._thread.join()
'''
self._exit_event = Event()
self._thread = Thread(target=self._on_run)
self._exit_event.set()
self._thread.start()
def _on_run(self):
logging.info('human render run')
while self._exit_event.is_set():
self._run_step()
time.sleep(0.02)
logging.info('human render exit')
def _run_step(self):
try:
res_frame, idx, audio_frames = self._queue.get(block=True, timeout=.002)
except queue.Empty:
# print('render queue.Empty:')
return None
if audio_frames[0][1] != 0 and audio_frames[1][1] != 0:
combine_frame = self._context.frame_list_cycle[idx]
else:
bbox = self._context.coord_list_cycle[idx]
combine_frame = copy.deepcopy(self._context.frame_list_cycle[idx])
y1, y2, x1, x2 = bbox
try:
res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
except:
return
# combine_frame = get_image(ori_frame,res_frame,bbox)
# t=time.perf_counter()
combine_frame[y1:y2, x1:x2] = res_frame
image = combine_frame
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
if self._image_render is not None:
self._image_render.on_render(image)
for audio_frame in audio_frames:
frame, type_ = audio_frame
frame = (frame * 32767).astype(np.int16)
if self._audio_render is not None:
self._audio_render.write(frame.tobytes(), int(frame.shape[0]*2))
# new_frame = AudioFrame(format='s16', layout='mono', samples=frame.shape[0])
# new_frame.planes[0].update(frame.tobytes())
# new_frame.sample_rate = 16000
'''

View File

@ -17,13 +17,19 @@ class NLPBase(AsrObserver):
self._is_running = True
EventBus().register('stop', self.on_stop)
EventBus().register('clear_cache', self.on_clear_cache)
def __del__(self):
EventBus().unregister('stop', self.on_stop)
EventBus().unregister('clear_cache', self.on_clear_cache)
def on_stop(self, *args, **kwargs):
self.stop()
def on_clear_cache(self, *args, **kwargs):
logger.info('NLPBase clear_cache')
self._ask_queue.clear()
@property
def callback(self):
return self._callback
@ -36,10 +42,10 @@ class NLPBase(AsrObserver):
if self._callback is not None and self._is_running:
self._callback.on_message(txt)
async def _request(self, question):
def _request(self, question):
pass
async def _on_close(self):
def _on_close(self):
pass
def process(self, message: str):

View File

@ -17,13 +17,13 @@ class DouBaoSDK:
self.__client = AsyncArk(api_key=token)
self._stream = None
async def request(self, question, handle, callback):
def request(self, question, handle, callback):
if self.__client is None:
self.__client = AsyncArk(api_key=self._token)
t = time.time()
logger.info(f'-------dou_bao ask:{question}')
try:
self._stream = await self.__client.chat.completions.create(
self._stream = self.__client.chat.completions.create(
model="ep-20241008152048-fsgzf",
messages=[
{"role": "system", "content": "你是测试客服,是由字节跳动开发的 AI 人工智能助手"},
@ -33,27 +33,36 @@ class DouBaoSDK:
)
sec = ''
async for completion in self._stream:
for completion in self._stream:
sec = sec + completion.choices[0].delta.content
sec, message = handle.handle(sec)
if len(message) > 0:
logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
# logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
callback(message)
callback(sec)
logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
await self._stream.close()
# logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
self._stream.close()
self._stream = None
except Exception as e:
print(e)
logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
logger.error(f'-------dou_bao error:{e}')
# logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
async def close(self):
def close(self):
if self._stream is not None:
await self._stream.close()
self._stream.close()
self._stream = None
logger.info('AsyncArk close')
if self.__client is not None and not self.__client.is_closed():
await self.__client.close()
self.__client.close()
self.__client = None
def aclose(self):
if self._stream is not None:
self._stream.close()
self._stream = None
logger.info('AsyncArk close')
if self.__client is not None and not self.__client.is_closed():
self.__client.close()
self.__client = None
@ -79,7 +88,7 @@ class DouBaoHttp:
response = requests.post(url, headers=headers, json=data, stream=True)
return response
async def request(self, question, handle, callback):
def request(self, question, handle, callback):
t = time.time()
self._requesting = True
logger.info(f'-------dou_bao ask:{question}')
@ -89,7 +98,7 @@ class DouBaoHttp:
]
self._response = self.__request(msg_list)
if not self._response.ok:
logger.info(f"请求失败,状态码:{self._response.status_code}")
logger.error(f"请求失败,状态码:{self._response.status_code}")
return
sec = ''
for chunk in self._response.iter_lines():
@ -97,20 +106,35 @@ class DouBaoHttp:
if len(content) < 1:
continue
content = content[5:]
content = json.loads(content)
content = content.strip()
if content == '[DONE]':
break
try:
content = json.loads(content)
except Exception as e:
logger.error(f"json解析失败错误信息{e, content}")
continue
sec = sec + content["choices"][0]["delta"]["content"]
sec, message = handle.handle(sec)
if len(message) > 0:
logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
callback(message)
callback(sec)
if len(sec) > 0:
callback(sec)
self._requesting = False
logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
async def close(self):
def close(self):
if self._response is not None and self._requesting:
self._response.close()
def aclose(self):
if self._response is not None and self._requesting:
self._response.close()
logger.info('DouBaoHttp close')
class DouBao(NLPBase):
def __init__(self, context, split, callback=None):
@ -131,10 +155,16 @@ class DouBao(NLPBase):
self.__token = 'c9635f9e-0f9e-4ca1-ac90-8af25a541b74'
self._dou_bao = DouBaoHttp(self.__token)
async def _request(self, question):
await self._dou_bao.request(question, self._split_handle, self._on_callback)
def _request(self, question):
self._dou_bao.request(question, self._split_handle, self._on_callback)
async def _on_close(self):
def _on_close(self):
if self._dou_bao is not None:
await self._dou_bao.close()
self._dou_bao.close()
logger.info('AsyncArk close')
def on_clear_cache(self, *args, **kwargs):
super().on_clear_cache(*args, **kwargs)
if self._dou_bao is not None:
self._dou_bao.aclose()
logger.info('DouBao clear_cache')

View File

@ -16,6 +16,7 @@ class VoiceRender(BaseRender):
def __init__(self, play_clock, context):
self._audio_render = AudioRender()
super().__init__(play_clock, context, 'Voice')
self._current_text = ''
def render(self, frame, ps):
self._play_clock.update_display_time()
@ -23,12 +24,16 @@ class VoiceRender(BaseRender):
for audio_frame in frame:
frame, type_ = audio_frame
frame = (frame * 32767).astype(np.int16)
chunk, txt = frame
if txt != self._current_text:
self._current_text = txt
logging.info(f'VoiceRender: {txt}')
chunk = (chunk * 32767).astype(np.int16)
if self._audio_render is not None:
try:
chunk_len = int(frame.shape[0] * 2)
chunk_len = int(chunk.shape[0] * 2)
# print('audio frame:', frame.shape, chunk_len)
self._audio_render.write(frame.tobytes(), chunk_len)
self._audio_render.write(chunk.tobytes(), chunk_len)
except Exception as e:
logging.error(f'Error writing audio frame: {e}')

View File

@ -3,6 +3,8 @@ import heapq
import logging
import os
import shutil
import time
from threading import Lock, Thread
from eventbus import EventBus
from utils import save_wav
@ -18,13 +20,18 @@ class TTSAudioHandle(AudioHandler):
self._index = -1
EventBus().register('stop', self._on_stop)
EventBus().register('clear_cache', self.on_clear_cache)
def __del__(self):
EventBus().unregister('stop', self._on_stop)
EventBus().unregister('clear_cache', self.on_clear_cache)
def _on_stop(self, *args, **kwargs):
self.stop()
def on_clear_cache(self, *args, **kwargs):
self._index = -1
@property
def sample_rate(self):
return self._sample_rate
@ -53,40 +60,54 @@ class TTSAudioSplitHandle(TTSAudioHandle):
self.sample_rate = self._context.sample_rate
self._chunk = self.sample_rate // self._context.fps
self._priority_queue = []
self._lock = Lock()
self._current = 0
self._is_running = True
self._thread = Thread(target=self._process_loop)
self._thread.start()
logger.info("TTSAudioSplitHandle init")
def _process_loop(self):
while self._is_running:
with self._lock:
if self._priority_queue and self._priority_queue[0][0] == self._current:
self._current += 1
chunks, txt = heapq.heappop(self._priority_queue)[1]
if chunks is not None:
for chunk in chunks:
self.on_next_handle((chunk, txt), 0)
time.sleep(0.01) # Sleep briefly to prevent busy-waiting
def on_handle(self, stream, index):
if not self._is_running:
logger.info('TTSAudioSplitHandle::on_handle is not running')
return
# heapq.heappush(self._priority_queue, (index, stream))
if stream is None:
logger.info(f'TTSAudioSplitHandle::on_handle {index}')
s, txt = stream
if s is None:
heapq.heappush(self._priority_queue, (index, None))
else:
stream_len = stream.shape[0]
stream_len = s.shape[0]
idx = 0
chunks = []
while stream_len >= self._chunk and self._is_running:
# self.on_next_handle(stream[idx:idx + self._chunk], 0)
chunks.append(stream[idx:idx + self._chunk])
chunks.append(s[idx:idx + self._chunk])
stream_len -= self._chunk
idx += self._chunk
heapq.heappush(self._priority_queue, (index, chunks))
current = self._priority_queue[0][0]
print('TTSAudioSplitHandle::on_handle', index, current, self._current)
if current == self._current and self._is_running:
self._current = self._current + 1
chunks = heapq.heappop(self._priority_queue)[1]
if chunks is None:
pass
else:
for chunk in chunks:
self.on_next_handle(chunk, 0)
if not self._is_running:
return
heapq.heappush(self._priority_queue, (index, (chunks, txt)))
def stop(self):
self._is_running = False
self._thread.join()
def on_clear_cache(self, *args, **kwargs):
super().on_clear_cache()
with self._lock:
self._current = 0
self._priority_queue.clear()
class TTSAudioSaveHandle(TTSAudioHandle):

View File

@ -16,13 +16,19 @@ class TTSBase(NLPCallback):
self._message_queue = AsyncTaskQueue('TTSBaseQueue', 5)
self._is_running = True
EventBus().register('stop', self.on_stop)
EventBus().register('clear_cache', self.on_clear_cache)
def __del__(self):
EventBus().unregister('stop', self.on_stop)
EventBus().unregister('clear_cache', self.on_clear_cache)
def on_stop(self, *args, **kwargs):
self.stop()
def on_clear_cache(self, *args, **kwargs):
logger.info('TTSBase clear_cache')
self._message_queue.clear()
@property
def handle(self):
return self._handle
@ -31,32 +37,39 @@ class TTSBase(NLPCallback):
def handle(self, value):
self._handle = value
async def _request(self, txt: str, index):
# print('_request:', txt)
t = time.time()
stream = await self._on_request(txt)
if stream is None:
logger.warn(f'-------stream is None')
def _request(self, txt: str, index):
if not self._is_running:
logger.info('TTSBase::_request is not running')
return
t = time.time()
stream = self._on_request(txt)
logger.info(f'-------tts request time:{time.time() - t:.4f}s, txt:{txt}')
if stream is None or self._is_running is False:
logger.warning(f'-------stream is None or is_running {self._is_running}')
return
logger.info(f'-------tts time:{time.time() - t:.4f}s, txt:{txt}')
if self._handle is not None and self._is_running:
await self._on_handle(stream, index)
self._on_handle((stream, txt), index)
else:
logger.info(f'handle is None, running:{self._is_running}')
logger.info(f'-------tts finish time:{time.time() - t:.4f}s, txt:{txt}')
async def _on_request(self, text: str):
def _on_request(self, text: str):
pass
async def _on_handle(self, stream, index):
def _on_handle(self, stream, index):
pass
async def _on_close(self):
def _on_close(self):
pass
def on_message(self, txt: str):
self.message(txt)
def message(self, txt):
if not self._is_running:
logger.info('TTSBase::message is not running')
return
txt = txt.strip()
if len(txt) == 0:
# logger.info(f'message is empty')
@ -66,6 +79,7 @@ class TTSBase(NLPCallback):
if self._handle is not None:
index = self._handle.get_index()
# print(f'message txt-index:{txt}, index {index}')
logger.info(f'TTSBase::message request:{txt}, index:{index}')
self._message_queue.add_task(self._request, txt, index)
def stop(self):

View File

@ -1,5 +1,6 @@
#encoding = utf8
import logging
import time
from io import BytesIO
import aiohttp
@ -21,13 +22,14 @@ class TTSEdgeHttp(TTSBase):
# self._url = 'http://localhost:8082/v1/audio/speech'
self._url = 'https://tts.mzzsfy.eu.org/v1/audio/speech'
logger.info(f"TTSEdge init, {voice}")
self._response_list = []
async def _on_async_request(self, data):
async with aiohttp.ClientSession() as session:
async with session.post(self._url, json=data) as response:
def _on_async_request(self, data):
with aiohttp.ClientSession() as session:
with session.post(self._url, json=data) as response:
print('TTSEdgeHttp, _on_request, response:', response)
if response.status == 200:
stream = BytesIO(await response.read())
stream = BytesIO(response.read())
return stream
else:
byte_stream = None
@ -35,13 +37,14 @@ class TTSEdgeHttp(TTSBase):
def _on_sync_request(self, data):
response = requests.post(self._url, json=data)
self._response_list.append(response)
stream = None
if response.status_code == 200:
stream = BytesIO(response.content)
return stream
else:
return None
self._response_list.remove(response)
return stream
async def _on_request(self, txt: str):
def _on_request(self, txt: str):
logger.info(f'TTSEdgeHttp, _on_request, txt:{txt}')
data = {
"model": "tts-1",
@ -54,23 +57,25 @@ class TTSEdgeHttp(TTSBase):
# return self._on_async_request(data)
return self._on_sync_request(data)
async def _on_handle(self, stream, index):
print('-------tts _on_handle')
def _on_handle(self, stream, index):
st, txt = stream
try:
stream.seek(0)
byte_stream = self.__create_bytes_stream(stream)
print('-------tts start push chunk', index)
self._handle.on_handle(byte_stream, index)
stream.seek(0)
stream.truncate()
print('-------tts finish push chunk')
st.seek(0)
t = time.time()
byte_stream = self.__create_bytes_stream(st)
logger.info(f'-------tts resample time:{time.time() - t:.4f}s, txt:{txt}')
t = time.time()
self._handle.on_handle((byte_stream, txt), index)
logger.info(f'-------tts handle time:{time.time() - t:.4f}s')
st.seek(0)
st.truncate()
except Exception as e:
self._handle.on_handle(None, index)
stream.seek(0)
stream.truncate()
print('-------tts finish error:', e)
stream.close()
st.seek(0)
st.truncate()
logger.error(f'-------tts finish error:{e}')
st.close()
def __create_bytes_stream(self, byte_stream):
stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64
@ -87,7 +92,13 @@ class TTSEdgeHttp(TTSBase):
return stream
async def _on_close(self):
def _on_close(self):
print('TTSEdge close')
# if self._byte_stream is not None and not self._byte_stream.closed:
# self._byte_stream.close()
def on_clear_cache(self, *args, **kwargs):
logger.info('TTSEdgeHttp clear_cache')
super().on_clear_cache(*args, **kwargs)
for response in self._response_list:
response.close()

View File

@ -1,6 +1,5 @@
#encoding = utf8
import asyncio
import logging
from queue import Queue
import threading
@ -14,24 +13,14 @@ class AsyncTaskQueue:
self._worker_num = work_num
self._current_worker_num = work_num
self._name = name
self._thread = threading.Thread(target=self._run_loop, name=name)
self._thread.start()
self.__loop = None
self._threads = []
self._lock = threading.Lock()
for _ in range(work_num):
thread = threading.Thread(target=self._worker, name=f'{name}_worker_{_}')
thread.start()
self._threads.append(thread)
def _run_loop(self):
logging.info(f'{self._name}, _run_loop')
self.__loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.__loop)
self._tasks = [self.__loop.create_task(self._worker()) for _ in range(self._worker_num)]
try:
self.__loop.run_forever()
finally:
logging.info(f'{self._name}, exit run')
self.__loop.run_until_complete(asyncio.gather(*asyncio.all_tasks(self.__loop)))
self.__loop.close()
logging.info(f'{self._name}, close loop')
async def _worker(self):
def _worker(self):
logging.info(f'{self._name}, _worker')
while True:
try:
@ -43,16 +32,17 @@ class AsyncTaskQueue:
if func is None: # None as a stop signal
break
await func(*args) # Execute async function
func(*args) # Execute function
except Exception as e:
logging.error(f'{self._name} error: {e}')
logging.error(f'{self._name} error: {repr(e)}')
finally:
self._queue.task_done()
logging.info(f'{self._name}, _worker finish')
self._current_worker_num -= 1
if self._current_worker_num == 0:
self.__loop.call_soon_threadsafe(self.__loop.stop)
with self._lock:
self._current_worker_num -= 1
if self._current_worker_num == 0:
self._queue.put(None) # Send stop signal to remaining workers
def add_task(self, func, *args):
self._queue.put((func, *args))
@ -62,10 +52,10 @@ class AsyncTaskQueue:
self.add_task(None) # Send stop signal
def clear(self):
while not self._queue.empty():
self._queue.get_nowait()
self._queue.task_done()
with self._queue.mutex:
self._queue.queue.clear()
def stop(self):
self.stop_workers()
self._thread.join()
for thread in self._threads:
thread.join()