Compare commits
No commits in common. "d9f55d1ba17d506c9e54c5dced38b533db28207c" and "e3c2a79ce7b36025c99c7ad2f3ed42c3ccc47891" have entirely different histories.
d9f55d1ba1 ... e3c2a79ce7
@@ -2,7 +2,6 @@
 
 import threading
 
-from eventbus import EventBus
 from .asr_observer import AsrObserver
 
 
@@ -13,19 +12,11 @@ class AsrBase:
         self._samples_per_read = 100
         self._observers = []
 
-        EventBus().register('stop', self._on_stop)
-
         self._stop_event = threading.Event()
         self._stop_event.set()
         self._thread = threading.Thread(target=self._recognize_loop)
         self._thread.start()
 
-    def __del__(self):
-        EventBus().unregister('stop', self._on_stop)
-
-    def _on_stop(self, *args, **kwargs):
-        self.stop()
-
     def _recognize_loop(self):
         pass
 
@@ -27,7 +27,6 @@ class SherpaNcnnAsr(AsrBase):
         super().__init__()
         self._recognizer = self._create_recognizer()
         logger.info('SherpaNcnnAsr init')
-        print('SherpaNcnnAsr init')
 
     def __del__(self):
         self.__del__()
@@ -61,10 +60,17 @@ class SherpaNcnnAsr(AsrBase):
         time.sleep(3)
         last_result = ""
         logger.info(f'_recognize_loop')
-        print(f'_recognize_loop')
+        while self._stop_event.is_set():
+            logger.info(f'_recognize_loop000')
+            self._notify_complete('介绍中国5000年历史文学')
+            logger.info(f'_recognize_loop111')
+            segment_id += 1
+            time.sleep(50)
+            logger.info(f'_recognize_loop222')
+        logger.info(f'_recognize_loop exit')
+        '''
         with sd.InputStream(channels=1, dtype="float32", samplerate=self._sample_rate) as s:
-            while self._stop_event.is_set():
+            while not self._stop_event.is_set():
                 samples, _ = s.read(self._samples_per_read)  # a blocking read
                 samples = samples.reshape(-1)
                 self._recognizer.accept_waveform(self._sample_rate, samples)
@@ -83,13 +89,4 @@ class SherpaNcnnAsr(AsrBase):
                     self._notify_complete(result)
                     segment_id += 1
                 self._recognizer.reset()
-        '''
-        while self._stop_event.is_set():
-            logger.info(f'_recognize_loop000')
-            self._notify_complete('介绍中国5000年历史文学')
-            logger.info(f'_recognize_loop111')
-            segment_id += 1
-            time.sleep(60)
-            logger.info(f'_recognize_loop222')
-        logger.info(f'_recognize_loop exit')
         '''
Binary file not shown (before: 61 KiB).
@@ -1,4 +0,0 @@
-#encoding = utf8
-
-from .event_bus import EventBus
-
@@ -1,39 +0,0 @@
-#encoding = utf8
-import threading
-
-
-class EventBus:
-    _instance = None
-    _lock = threading.Lock()
-
-    def __new__(cls, *args, **kwargs):
-        if not cls._instance:
-            with cls._lock:
-                if not cls._instance:
-                    cls._instance = super(EventBus, cls).__new__(cls, *args, **kwargs)
-        return cls._instance
-
-    def __init__(self):
-        if not hasattr(self, '_initialized'):
-            self._listeners = {}
-            self._lock = threading.Lock()
-            self._initialized = True
-
-    def register(self, event_type, listener):
-        with self._lock:
-            if event_type not in self._listeners:
-                self._listeners[event_type] = []
-            self._listeners[event_type].append(listener)
-
-    def unregister(self, event_type, listener):
-        with self._lock:
-            if event_type in self._listeners:
-                self._listeners[event_type].remove(listener)
-                if not self._listeners[event_type]:
-                    del self._listeners[event_type]
-
-    def post(self, event_type, *args, **kwargs):
-        with self._lock:
-            listeners = self._listeners.get(event_type, []).copy()
-        for listener in listeners:
-            listener(*args, **kwargs)
Binary file not shown (before: 452 KiB).
Binary file not shown (before: 452 KiB).
Binary file not shown (before: 258 KiB).
@@ -9,7 +9,6 @@ from threading import Event, Thread
 import numpy as np
 import torch
 
-from eventbus import EventBus
 from human_handler import AudioHandler
 from utils import load_model, mirror_index, get_device, SyncQueue
 
@@ -21,28 +20,16 @@ class AudioInferenceHandler(AudioHandler):
     def __init__(self, context, handler):
         super().__init__(context, handler)
-
-        EventBus().register('stop', self._on_stop)
         self._mal_queue = SyncQueue(1, "AudioInferenceHandler_Mel")
         self._audio_queue = SyncQueue(context.batch_size * 2, "AudioInferenceHandler_Audio")
 
-        self._is_running = True
         self._exit_event = Event()
         self._run_thread = Thread(target=self.__on_run, name="AudioInferenceHandlerThread")
         self._exit_event.set()
         self._run_thread.start()
 
         logger.info("AudioInferenceHandler init")
 
-    def __del__(self):
-        EventBus().unregister('stop', self._on_stop)
-
-    def _on_stop(self, *args, **kwargs):
-        self.stop()
-
     def on_handle(self, stream, type_):
-        if not self._is_running:
-            return
-
         if type_ == 1:
             self._mal_queue.put(stream)
         elif type_ == 0:
@@ -69,13 +56,14 @@ class AudioInferenceHandler(AudioHandler):
         device = get_device()
         logger.info(f'use device:{device}')
 
-        while self._is_running:
+        while True:
             if self._exit_event.is_set():
                 start_time = time.perf_counter()
                 batch_size = self._context.batch_size
                 try:
-                    mel_batch = self._mal_queue.get(timeout=0.02)
-                    # print('AudioInferenceHandler mel_batch:', len(mel_batch), 'size:', self._mal_queue.size())
+                    mel_batch = self._mal_queue.get()
+                    size = self._audio_queue.size()
+                    # print('AudioInferenceHandler mel_batch:', len(mel_batch), 'size:', size)
                 except queue.Empty:
                     continue
 
@@ -88,15 +76,8 @@ class AudioInferenceHandler(AudioHandler):
                     audio_frames.append((frame, type_))
                     if type_ == 0:
                         is_all_silence = False
-
-                if not self._is_running:
-                    print('AudioInferenceHandler not running')
-                    break
-
                 if is_all_silence:
                     for i in range(batch_size):
-                        if not self._is_running:
-                            break
                         self.on_next_handle((None, mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]),
                                             0)
                         index = index + 1
@@ -137,8 +118,6 @@ class AudioInferenceHandler(AudioHandler):
                     count_time = 0
 
                 for i, res_frame in enumerate(pred):
-                    if not self._is_running:
-                        break
                     self.on_next_handle(
                         (res_frame, mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]),
                         0)
@@ -150,17 +129,10 @@ class AudioInferenceHandler(AudioHandler):
         logger.info('AudioInferenceHandler inference processor stop')
 
     def stop(self):
-        logger.info('AudioInferenceHandler stop')
-        self._is_running = False
         self._exit_event.clear()
-        if self._run_thread.is_alive():
-            logger.info('AudioInferenceHandler stop join')
-            self._run_thread.join()
-        logger.info('AudioInferenceHandler stop exit')
+        self._run_thread.join()
 
     def pause_talk(self):
         print('AudioInferenceHandler pause_talk', self._audio_queue.size(), self._mal_queue.size())
         self._audio_queue.clear()
-        print('AudioInferenceHandler111')
         self._mal_queue.clear()
-        print('AudioInferenceHandler222')
@@ -3,8 +3,7 @@ import logging
 import queue
 import time
 
-from threading import Thread, Event
-from eventbus import EventBus
+from threading import Thread, Event, Condition
 
 import numpy as np
 
@@ -18,26 +17,16 @@ class AudioMalHandler(AudioHandler):
     def __init__(self, context, handler):
         super().__init__(context, handler)
 
-        EventBus().register('stop', self._on_stop)
-        self._is_running = True
-        self._queue = SyncQueue(context.batch_size * 2, "AudioMalHandler_queue")
+        self._queue = SyncQueue(context.batch_size, "AudioMalHandler_queue")
+        self._exit_event = Event()
+        self._thread = Thread(target=self._on_run, name="AudioMalHandlerThread")
+        self._exit_event.set()
+        self._thread.start()
 
         self.frames = []
         self.chunk = context.sample_rate // context.fps
-
-        self._exit_event = Event()
-        self._exit_event.set()
-        self._thread = Thread(target=self._on_run, name="AudioMalHandlerThread")
-        self._thread.start()
         logger.info("AudioMalHandler init")
 
-    def __del__(self):
-        EventBus().unregister('stop', self._on_stop)
-
-    def _on_stop(self, *args, **kwargs):
-        self.stop()
-
     def on_message(self, message):
         super().on_message(message)
 
@@ -47,7 +36,7 @@ class AudioMalHandler(AudioHandler):
 
     def _on_run(self):
         logging.info('chunk2mal run')
-        while self._exit_event.is_set() and self._is_running:
+        while self._exit_event.is_set():
             self._run_step()
             time.sleep(0.02)
 
@@ -60,9 +49,6 @@ class AudioMalHandler(AudioHandler):
             self.frames.append(frame)
            self.on_next_handle((frame, _type), 0)
            count = count + 1
-
-        if self._is_running is False:
-            return
         # context not enough, do not run network.
         if len(self.frames) <= self._context.stride_left_size + self._context.stride_right_size:
             return
@@ -78,8 +64,7 @@ class AudioMalHandler(AudioHandler):
         mel_step_size = 16
         i = 0
         mel_chunks = []
-        while i < (len(self.frames) - self._context.stride_left_size - self._context.stride_right_size) / 2\
-                and self._is_running:
+        while i < (len(self.frames) - self._context.stride_left_size - self._context.stride_right_size) / 2:
             start_idx = int(left + i * mel_idx_multiplier)
             # print(start_idx)
             if start_idx + mel_step_size > len(mel[0]):
@@ -93,10 +78,11 @@ class AudioMalHandler(AudioHandler):
         self.frames = self.frames[-(self._context.stride_left_size + self._context.stride_right_size):]
 
     def get_audio_frame(self):
-        if not self._queue.is_empty():
+        try:
+            # print('AudioMalHandler get_audio_frame')
             frame = self._queue.get()
             type_ = 0
-        else:
+        except queue.Empty:
             frame = np.zeros(self.chunk, dtype=np.float32)
             type_ = 1
         # print('AudioMalHandler get_audio_frame type:', type_)
@@ -104,7 +90,6 @@ class AudioMalHandler(AudioHandler):
 
     def stop(self):
         logging.info('stop')
-        self._is_running = False
         if self._exit_event is None:
             return
 
@@ -3,7 +3,6 @@ import logging
 import os
 
 from asr import SherpaNcnnAsr
-from eventbus import EventBus
 from .audio_inference_handler import AudioInferenceHandler
 from .audio_mal_handler import AudioMalHandler
 from .human_render import HumanRender
@@ -114,7 +113,13 @@ class HumanContext:
         self._asr.attach(self._nlp)
 
     def stop(self):
-        EventBus().post('stop')
+        object_stop(self._asr)
+        object_stop(self._nlp)
+        object_stop(self._tts)
+        object_stop(self._tts_handle)
+        object_stop(self._mal_handler)
+        object_stop(self._infer_handler)
+        object_stop(self._render_handler)
 
     def pause_talk(self):
         self._nlp.pause_talk()
@@ -5,7 +5,6 @@ import time
 from queue import Empty
 from threading import Event, Thread
 
-from eventbus import EventBus
 from human.message_type import MessageType
 from human_handler import AudioHandler
 from render import VoiceRender, VideoRender, PlayClock
@@ -18,11 +17,9 @@ class HumanRender(AudioHandler):
     def __init__(self, context, handler):
         super().__init__(context, handler)
 
-        EventBus().register('stop', self._on_stop)
         play_clock = PlayClock()
         self._voice_render = VoiceRender(play_clock, context)
         self._video_render = VideoRender(play_clock, context, self)
-        self._is_running = True
         self._queue = SyncQueue(context.batch_size, "HumanRender_queue")
         self._exit_event = Event()
         self._thread = Thread(target=self._on_run, name="AudioMalHandlerThread")
@@ -31,43 +28,24 @@ class HumanRender(AudioHandler):
         self._image_render = None
         self._last_audio_ps = 0
         self._last_video_ps = 0
-        self._empty_log = True
-
-    def __del__(self):
-        EventBus().unregister('stop', self._on_stop)
-
-    def _on_stop(self, *args, **kwargs):
-        self.stop()
 
     def _on_run(self):
         logging.info('human render run')
-        while self._exit_event.is_set() and self._is_running:
-            # t = time.time()
+        while self._exit_event.is_set():
             self._run_step()
-            # delay = time.time() - t
-            delay = 0.03805  # - delay
-            # print(delay)
-            # if delay <= 0.0:
-            # continue
-            time.sleep(delay)
+            time.sleep(0.038)
 
         logging.info('human render exit')
 
     def _run_step(self):
         try:
-            value = self._queue.get(timeout=.005)
+            value = self._queue.get()
             if value is None:
                 return
             res_frame, idx, audio_frames = value
-            # print('render queue size', self._queue.size())
-            if not self._empty_log:
-                self._empty_log = True
-                logging.info('render render:')
             # print('voice render queue size', self._queue.size())
         except Empty:
-            if self._empty_log:
-                self._empty_log = False
-                logging.info('render queue.Empty:')
+            print('render queue.Empty:')
             return
 
         type_ = 1
@@ -91,23 +69,32 @@ class HumanRender(AudioHandler):
         super().on_message(message)
 
     def on_handle(self, stream, index):
-        if not self._is_running:
-            return
-
         self._queue.put(stream)
+        # res_frame, idx, audio_frames = stream
+        # self._voice_render.put(audio_frames, self._last_audio_ps)
+        # self._last_audio_ps = self._last_audio_ps + 0.4
+        # type_ = 1
+        # if audio_frames[0][1] != 0 and audio_frames[1][1] != 0:
+        #     type_ = 0
+        # self._video_render.put((res_frame, idx, type_), self._last_video_ps)
+        # self._last_video_ps = self._last_video_ps + 0.4
+        #
+        # if self._voice_render.is_full():
+        #     self._context.notify({'msg_id': MessageType.Video_Render_Queue_Full})
 
+    # def get_audio_queue_size(self):
+    #     return self._voice_render.size()
 
     def pause_talk(self):
-        logging.info('hunan pause_talk')
+        pass
         # self._voice_render.pause_talk()
         # self._video_render.pause_talk()
 
     def stop(self):
         logging.info('hunan render stop')
-        self._is_running = False
         if self._exit_event is None:
             return
 
-        self._queue.clear()
         self._exit_event.clear()
         if self._thread.is_alive():
             self._thread.join()
@@ -2,7 +2,6 @@
 import logging
 
 from asr import AsrObserver
-from eventbus import EventBus
 from utils import AsyncTaskQueue
 
 logger = logging.getLogger(__name__)
@@ -10,19 +9,11 @@ logger = logging.getLogger(__name__)
 
 class NLPBase(AsrObserver):
     def __init__(self, context, split, callback=None):
-        self._ask_queue = AsyncTaskQueue('NLPBaseQueue')
+        self._ask_queue = AsyncTaskQueue()
         self._context = context
         self._split_handle = split
         self._callback = callback
-        self._is_running = True
-
-        EventBus().register('stop', self.on_stop)
-
-    def __del__(self):
-        EventBus().unregister('stop', self.on_stop)
-
-    def on_stop(self, *args, **kwargs):
-        self.stop()
+        self._is_running = False
 
     @property
     def callback(self):
@@ -46,8 +37,6 @@ class NLPBase(AsrObserver):
         pass
 
     def completed(self, message: str):
-        if not self._is_running:
-            return
         logger.info(f'complete:{message}')
         # self._context.pause_talk()
         self.ask(message)
@@ -61,11 +50,8 @@ class NLPBase(AsrObserver):
     def stop(self):
         logger.info('NLPBase stop')
         self._is_running = False
-        self._ask_queue.clear()
         self._ask_queue.add_task(self._on_close)
-        logger.info('NLPBase add close')
         self._ask_queue.stop()
-        logger.info('NLPBase _ask_queue stop')
 
     def pause_talk(self):
         logger.info('NLPBase pause_talk')
@@ -30,6 +30,7 @@ class DouBao(NLPBase):
 
     async def _request(self, question):
         t = time.time()
+        logger.info(f'_request:{question}')
         logger.info(f'-------dou_bao ask:{question}')
         try:
             stream = await self.__client.chat.completions.create(
@@ -40,16 +41,13 @@ class DouBao(NLPBase):
                 ],
                 stream=True
             )
 
            sec = ''
            async for completion in stream:
                sec = sec + completion.choices[0].delta.content
                sec, message = self._split_handle.handle(sec)
                if len(message) > 0:
-                    logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
                    self._on_callback(message)
            self._on_callback(sec)
-            logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
            await stream.close()
 
            # sec = "你是测试客服,是由字节跳动开发的 AI 人工智能助手"
@@ -61,6 +59,7 @@ class DouBao(NLPBase):
            # self._on_callback(sec)
         except Exception as e:
             print(e)
+        logger.info(f'_request:{question}, time:{time.time() - t:.4f}s')
         logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
 
     async def _on_close(self):
@@ -15,9 +15,28 @@ class VideoRender(BaseRender):
     def __init__(self, play_clock, context, human_render):
         super().__init__(play_clock, context, 'Video')
         self._human_render = human_render
+        self._diff_avg_count = 0
 
     def render(self, frame, ps):
         res_frame, idx, type_ = frame
+        clock_time = self._play_clock.clock_time()
+        time_difference = clock_time - ps
+        if abs(time_difference) > self._play_clock.audio_diff_threshold:
+            if self._diff_avg_count < 5:
+                self._diff_avg_count += 1
+            else:
+                if time_difference < -self._play_clock.audio_diff_threshold:
+                    sleep_time = abs(time_difference)
+                    print("Video frame waiting to catch up with audio", sleep_time)
+                    if sleep_time <= 1.0:
+                        time.sleep(sleep_time)
+
+                # elif time_difference > self._play_clock.audio_diff_threshold:  # 视频比音频快超过10ms
+                #     print("Video frame dropped to catch up with audio")
+                #     continue
+
+        else:
+            self._diff_avg_count = 0
 
         if type_ == 0:
             combine_frame = self._context.frame_list_cycle[idx]
@@ -6,7 +6,6 @@ import numpy as np
 import requests
 import resampy
 import soundfile as sf
-from io import BytesIO
 
 
 def download_tts(url):
@@ -39,15 +38,16 @@ def __create_bytes_stream(byte_stream):
     return stream
 
 
-def test_async_tts(url, content):
+def main():
     import aiohttp
     import asyncio
+    from io import BytesIO
 
     async def fetch_audio():
+        url = "http://localhost:8082/v1/audio/speech"
         data = {
             "model": "tts-1",
-            "input": content,
+            "input": "写了一个高性能tts(文本转声音)工具,5千字仅需5秒,免费使用",
             "voice": "alloy",
             "speed": 1.0
         }
@@ -68,33 +68,6 @@ def test_async_tts(url, content):
     asyncio.run(fetch_audio())
 
 
-def test_sync_tts(url, content):
-    data = {
-        "model": "tts-1",
-        "input": content,
-        "voice": "alloy",
-        "speed": 1.0
-    }
-    response = requests.post(url, json=data)
-    if response.status_code == 200:
-        audio_data = BytesIO(response.content)
-        audio_stream = __create_bytes_stream(audio_data)
-
-        # 保存为新的音频文件
-        sf.write("output_audio.wav", audio_stream, 16000)
-        print("Audio data received and saved to output_audio.wav")
-    else:
-        print("Error:", response.status_code, response.text)
-
-
-def main():
-    # url = "http://localhost:8082/v1/audio/speech"
-    url = "https://tts.mzzsfy.eu.org/v1/audio/speech"
-    content = "写了一个高性能tts(文本转声音)工具,5千字仅需5秒,免费使用"
-    # test_async_tts(url, content)
-    test_sync_tts(url, content)
-
-
 if __name__ == "__main__":
     try:
         t = time.time()
@@ -4,7 +4,6 @@ import logging
 import os
 import shutil
 
-from eventbus import EventBus
 from utils import save_wav
 from human_handler import AudioHandler
 
@@ -17,14 +16,6 @@ class TTSAudioHandle(AudioHandler):
         self._sample_rate = 16000
         self._index = -1
 
-        EventBus().register('stop', self._on_stop)
-
-    def __del__(self):
-        EventBus().unregister('stop', self._on_stop)
-
-    def _on_stop(self, *args, **kwargs):
-        self.stop()
-
     @property
     def sample_rate(self):
         return self._sample_rate
@@ -54,13 +45,9 @@ class TTSAudioSplitHandle(TTSAudioHandle):
         self._chunk = self.sample_rate // self._context.fps
         self._priority_queue = []
         self._current = 0
-        self._is_running = True
         logger.info("TTSAudioSplitHandle init")
 
     def on_handle(self, stream, index):
-        if not self._is_running:
-            logger.info('TTSAudioSplitHandle::on_handle is not running')
-            return
         # heapq.heappush(self._priority_queue, (index, stream))
         if stream is None:
             heapq.heappush(self._priority_queue, (index, None))
@@ -68,7 +55,7 @@ class TTSAudioSplitHandle(TTSAudioHandle):
             stream_len = stream.shape[0]
             idx = 0
             chunks = []
-            while stream_len >= self._chunk and self._is_running:
+            while stream_len >= self._chunk:
                 # self.on_next_handle(stream[idx:idx + self._chunk], 0)
                 chunks.append(stream[idx:idx + self._chunk])
                 stream_len -= self._chunk
@@ -76,7 +63,7 @@ class TTSAudioSplitHandle(TTSAudioHandle):
             heapq.heappush(self._priority_queue, (index, chunks))
         current = self._priority_queue[0][0]
         print('TTSAudioSplitHandle::on_handle', index, current, self._current)
-        if current == self._current and self._is_running:
+        if current == self._current:
             self._current = self._current + 1
             chunks = heapq.heappop(self._priority_queue)[1]
             if chunks is None:
@@ -86,7 +73,7 @@ class TTSAudioSplitHandle(TTSAudioHandle):
                 self.on_next_handle(chunk, 0)
 
     def stop(self):
-        self._is_running = False
+        pass
 
 
 class TTSAudioSaveHandle(TTSAudioHandle):
@@ -3,7 +3,6 @@
 import logging
 import time
 
-from eventbus import EventBus
 from nlp import NLPCallback
 from utils import AsyncTaskQueue
 
@@ -13,15 +12,8 @@ logger = logging.getLogger(__name__)
 class TTSBase(NLPCallback):
     def __init__(self, handle):
         self._handle = handle
-        self._message_queue = AsyncTaskQueue('TTSBaseQueue', 5)
-        self._is_running = True
-        EventBus().register('stop', self.on_stop)
-
-    def __del__(self):
-        EventBus().unregister('stop', self.on_stop)
-
-    def on_stop(self, *args, **kwargs):
-        self.stop()
+        self._message_queue = AsyncTaskQueue(5)
+        self._is_running = False
 
     @property
     def handle(self):
@@ -32,13 +24,13 @@ class TTSBase(NLPCallback):
         self._handle = value
 
     async def _request(self, txt: str, index):
-        # print('_request:', txt)
+        print('_request:', txt)
         t = time.time()
         stream = await self._on_request(txt)
         if stream is None:
-            logger.warn(f'-------stream is None')
+            print(f'-------stream is None')
             return
-        logger.info(f'-------tts time:{time.time() - t:.4f}s, txt:{txt}')
+        print(f'-------tts time:{time.time() - t:.4f}s')
         if self._handle is not None and self._is_running:
             await self._on_handle(stream, index)
         else:
@@ -57,22 +49,23 @@ class TTSBase(NLPCallback):
         self.message(txt)
 
     def message(self, txt):
+        self._is_running = True
         txt = txt.strip()
         if len(txt) == 0:
-            # logger.info(f'message is empty')
+            logger.info(f'message is empty')
             return
         logger.info(f'message:{txt}')
         index = 0
         if self._handle is not None:
             index = self._handle.get_index()
-        # print(f'message txt-index:{txt}, index {index}')
+        print(f'message txt-index:{txt}, index {index}')
         self._message_queue.add_task(self._request, txt, index)
 
     def stop(self):
-        self._is_running = False
         self._message_queue.add_task(self._on_close)
         self._message_queue.stop()
 
     def pause_talk(self):
         logger.info(f'TTSBase pause_talk')
+        self._is_running = False
         self._message_queue.clear()
@@ -4,7 +4,6 @@ from io import BytesIO
 
 import aiohttp
 import numpy as np
-import requests
 import soundfile as sf
 import edge_tts
 import resampy
@@ -22,7 +21,15 @@ class TTSEdgeHttp(TTSBase):
         self._url = 'https://tts.mzzsfy.eu.org/v1/audio/speech'
         logger.info(f"TTSEdge init, {voice}")
 
-    async def _on_async_request(self, data):
+    async def _on_request(self, txt: str):
+        print('TTSEdgeHttp, _on_request, txt:', txt)
+        data = {
+            "model": "tts-1",
+            "input": txt,
+            "voice": "alloy",
+            "speed": 1.0,
+            "thread": 10
+        }
         async with aiohttp.ClientSession() as session:
             async with session.post(self._url, json=data) as response:
                 print('TTSEdgeHttp, _on_request, response:', response)
@@ -31,28 +38,7 @@ class TTSEdgeHttp(TTSBase):
                     return stream
                 else:
                     byte_stream = None
-                    return byte_stream, None
-
-    def _on_sync_request(self, data):
-        response = requests.post(self._url, json=data)
-        if response.status_code == 200:
-            stream = BytesIO(response.content)
-            return stream
-        else:
-            return None
-
-    async def _on_request(self, txt: str):
-        logger.info(f'TTSEdgeHttp, _on_request, txt:{txt}')
-        data = {
-            "model": "tts-1",
-            "input": txt,
-            "voice": "alloy",
-            "speed": 1.0,
-            "thread": 10
-        }
-
-        # return self._on_async_request(data)
-        return self._on_sync_request(data)
+                    return byte_stream
 
     async def _on_handle(self, stream, index):
         print('-------tts _on_handle')
@@ -55,16 +55,12 @@ class PyGameUI:
         if self._queue.empty():
             return
         image = self._queue.get()
-        color_format = "RGB"
-        if 4 == image.shape[2]:
-            color_format = "RGBA"
-
-        self._human_image = pygame.image.frombuffer(image.tobytes(), image.shape[1::-1], color_format)
+        self._human_image = pygame.image.frombuffer(image.tobytes(), image.shape[1::-1], "RGB")
 
     def stop(self):
         logger.info('stop')
         if self._human_context is not None:
-            # self._human_context.pause_talk()
+            self._human_context.pause_talk()
             self._human_context.stop()
 
     def on_render(self, image):
@@ -1,61 +1,48 @@
 #encoding = utf8
 
 import asyncio
-import logging
-from queue import Queue
 import threading
 
-logger = logging.getLogger(__name__)
 
 
 class AsyncTaskQueue:
-    def __init__(self, name, work_num=1):
-        self._queue = Queue()
+    def __init__(self, work_num=1):
+        self._queue = asyncio.Queue()
         self._worker_num = work_num
         self._current_worker_num = work_num
-        self._name = name
-        self._thread = threading.Thread(target=self._run_loop, name=name)
+        self._thread = threading.Thread(target=self._run_loop)
         self._thread.start()
         self.__loop = None
 
     def _run_loop(self):
-        logging.info(f'{self._name}, _run_loop')
+        print('_run_loop')
         self.__loop = asyncio.new_event_loop()
         asyncio.set_event_loop(self.__loop)
         self._tasks = [self.__loop.create_task(self._worker()) for _ in range(self._worker_num)]
-        try:
-            self.__loop.run_forever()
-        finally:
-            logging.info(f'{self._name}, exit run')
-            self.__loop.run_until_complete(asyncio.gather(*asyncio.all_tasks(self.__loop)))
+        self.__loop.run_forever()
+        print("exit run")
+        if not self.__loop.is_closed():
             self.__loop.close()
-            logging.info(f'{self._name}, close loop')
 
     async def _worker(self):
-        logging.info(f'{self._name}, _worker')
+        print('_worker')
         while True:
-            try:
-                task = self._queue.get()
-                if task is None:  # None as a stop signal
-                    break
+            task = await self._queue.get()
+            if task is None:  # None as a stop signal
+                break
 
             func, *args = task  # Unpack task
-                if func is None:  # None as a stop signal
-                    break
-
-                await func(*args)  # Execute async function
-            except Exception as e:
-                logging.error(f'{self._name} error: {e}')
-            finally:
-                self._queue.task_done()
+            print(f"Executing task with args: {args}")
+            await func(*args)  # Execute async function
+            self._queue.task_done()
 
-        logging.info(f'{self._name}, _worker finish')
+            print('_worker finish')
         self._current_worker_num -= 1
         if self._current_worker_num == 0:
-            self.__loop.call_soon_threadsafe(self.__loop.stop)
+            print('loop stop')
+            self.__loop.stop()
 
     def add_task(self, func, *args):
-        self._queue.put((func, *args))
+        self.__loop.call_soon_threadsafe(self._queue.put_nowait, (func, *args))
 
     def stop_workers(self):
         for _ in range(self._worker_num):
@@ -68,4 +55,4 @@ class AsyncTaskQueue:
 
     def stop(self):
         self.stop_workers()
         self._thread.join()
@@ -10,9 +10,6 @@ class SyncQueue:
         self._queue = Queue(maxsize)
         self._condition = threading.Condition()
 
-    def is_empty(self):
-        return self._queue.empty()
-
     def put(self, item):
         with self._condition:
             while self._queue.full():
@@ -194,7 +194,7 @@ def config_logging(file_name: str, console_level: int = logging.INFO, file_level
 
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(logging.Formatter(
-        '[%(asctime)s.%(msecs)03d %(levelname)s] %(message)s',
+        '[%(asctime)s %(levelname)s] %(message)s',
        datefmt="%Y/%m/%d %H:%M:%S"
    ))
    console_handler.setLevel(console_level)