Merge branch 'render_sync' into develop

This commit is contained in:
brige 2024-11-08 09:33:41 +08:00
commit d9f55d1ba1
23 changed files with 307 additions and 124 deletions

View File

@ -2,6 +2,7 @@
import threading import threading
from eventbus import EventBus
from .asr_observer import AsrObserver from .asr_observer import AsrObserver
@ -12,11 +13,19 @@ class AsrBase:
self._samples_per_read = 100 self._samples_per_read = 100
self._observers = [] self._observers = []
EventBus().register('stop', self._on_stop)
self._stop_event = threading.Event() self._stop_event = threading.Event()
self._stop_event.set() self._stop_event.set()
self._thread = threading.Thread(target=self._recognize_loop) self._thread = threading.Thread(target=self._recognize_loop)
self._thread.start() self._thread.start()
def __del__(self):
    """Finalizer: remove this object's 'stop' listener from the shared EventBus.

    NOTE(review): EventBus.register stores the bound method in a plain list,
    and a bound method references its instance, so this finalizer presumably
    cannot run while the listener is still registered -- confirm that relying
    on ``__del__`` for unregistration is intended.
    """
    EventBus().unregister('stop', self._on_stop)
def _on_stop(self, *args, **kwargs):
    """EventBus 'stop' callback: ignore any event payload and call ``self.stop()``."""
    self.stop()
def _recognize_loop(self): def _recognize_loop(self):
pass pass

View File

@ -27,6 +27,7 @@ class SherpaNcnnAsr(AsrBase):
super().__init__() super().__init__()
self._recognizer = self._create_recognizer() self._recognizer = self._create_recognizer()
logger.info('SherpaNcnnAsr init') logger.info('SherpaNcnnAsr init')
print('SherpaNcnnAsr init')
def __del__(self): def __del__(self):
self.__del__() self.__del__()
@ -60,17 +61,10 @@ class SherpaNcnnAsr(AsrBase):
time.sleep(3) time.sleep(3)
last_result = "" last_result = ""
logger.info(f'_recognize_loop') logger.info(f'_recognize_loop')
while self._stop_event.is_set(): print(f'_recognize_loop')
logger.info(f'_recognize_loop000')
self._notify_complete('介绍中国5000年历史文学')
logger.info(f'_recognize_loop111')
segment_id += 1
time.sleep(50)
logger.info(f'_recognize_loop222')
logger.info(f'_recognize_loop exit')
'''
with sd.InputStream(channels=1, dtype="float32", samplerate=self._sample_rate) as s: with sd.InputStream(channels=1, dtype="float32", samplerate=self._sample_rate) as s:
while not self._stop_event.is_set(): while self._stop_event.is_set():
samples, _ = s.read(self._samples_per_read) # a blocking read samples, _ = s.read(self._samples_per_read) # a blocking read
samples = samples.reshape(-1) samples = samples.reshape(-1)
self._recognizer.accept_waveform(self._sample_rate, samples) self._recognizer.accept_waveform(self._sample_rate, samples)
@ -89,4 +83,13 @@ class SherpaNcnnAsr(AsrBase):
self._notify_complete(result) self._notify_complete(result)
segment_id += 1 segment_id += 1
self._recognizer.reset() self._recognizer.reset()
'''
while self._stop_event.is_set():
logger.info(f'_recognize_loop000')
self._notify_complete('介绍中国5000年历史文学')
logger.info(f'_recognize_loop111')
segment_id += 1
time.sleep(60)
logger.info(f'_recognize_loop222')
logger.info(f'_recognize_loop exit')
''' '''

Binary file not shown.

After

Width:  |  Height:  |  Size: 61 KiB

4
eventbus/__init__.py Normal file
View File

@ -0,0 +1,4 @@
#encoding = utf8
from .event_bus import EventBus

39
eventbus/event_bus.py Normal file
View File

@ -0,0 +1,39 @@
#encoding = utf8
import threading
class EventBus:
    """Process-wide publish/subscribe hub implemented as a singleton.

    Every ``EventBus()`` call returns the same object. Listeners are plain
    callables stored per event type and invoked synchronously, in
    registration order, on the thread that calls :meth:`post`.

    The bus keeps strong references to its listeners. Registering a bound
    method therefore keeps its instance alive until :meth:`unregister` is
    called for it.
    """

    _instance = None
    # Guards creation of the singleton only. The listener table is protected
    # by the per-instance ``_lock`` created in ``__new__``.
    _instance_lock = threading.Lock()

    def __new__(cls, *args, **kwargs):
        """Return the shared instance, creating and initialising it on first use."""
        if cls._instance is None:
            with cls._instance_lock:
                # Re-check under the lock: another thread may have created
                # the instance while this one was waiting.
                if cls._instance is None:
                    # object.__new__ rejects extra arguments, so none are
                    # forwarded to it.
                    instance = super().__new__(cls)
                    # Initialise all state before publishing the instance so
                    # no thread can observe a half-built bus, and a racing
                    # first call can never reset already-registered listeners.
                    instance._listeners = {}
                    instance._lock = threading.Lock()
                    instance._initialized = True
                    cls._instance = instance
        return cls._instance

    def __init__(self):
        """No-op: the singleton's state is set up exactly once in ``__new__``."""

    def register(self, event_type, listener):
        """Subscribe *listener* to *event_type*.

        Args:
            event_type: Hashable key identifying the event (e.g. ``'stop'``).
            listener: Callable invoked as ``listener(*args, **kwargs)`` with
                the arguments given to :meth:`post`. Registering the same
                callable twice makes it fire twice.
        """
        with self._lock:
            self._listeners.setdefault(event_type, []).append(listener)

    def unregister(self, event_type, listener):
        """Remove one registration of *listener* for *event_type*.

        Unknown event types and listeners that are not registered are
        ignored, so the call is safe from finalizers and may be repeated.
        """
        with self._lock:
            listeners = self._listeners.get(event_type)
            if listeners is None:
                return
            try:
                listeners.remove(listener)
            except ValueError:
                # Listener was never registered (or was already removed).
                return
            if not listeners:
                # Drop empty entries so the table does not accumulate keys.
                del self._listeners[event_type]

    def post(self, event_type, *args, **kwargs):
        """Invoke every listener registered for *event_type*.

        The listener list is snapshotted under the lock and the callbacks run
        outside it, so a listener may call :meth:`register` or
        :meth:`unregister` without deadlocking. An exception raised by a
        listener propagates to the caller and the remaining listeners are
        not invoked.
        """
        with self._lock:
            listeners = list(self._listeners.get(event_type, ()))
        for listener in listeners:
            listener(*args, **kwargs)

BIN
face/img00016.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 452 KiB

BIN
face/img00016.jpg.bak1 Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 452 KiB

BIN
face/img00020.png.bak Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 258 KiB

View File

@ -9,6 +9,7 @@ from threading import Event, Thread
import numpy as np import numpy as np
import torch import torch
from eventbus import EventBus
from human_handler import AudioHandler from human_handler import AudioHandler
from utils import load_model, mirror_index, get_device, SyncQueue from utils import load_model, mirror_index, get_device, SyncQueue
@ -20,16 +21,28 @@ class AudioInferenceHandler(AudioHandler):
def __init__(self, context, handler): def __init__(self, context, handler):
super().__init__(context, handler) super().__init__(context, handler)
EventBus().register('stop', self._on_stop)
self._mal_queue = SyncQueue(1, "AudioInferenceHandler_Mel") self._mal_queue = SyncQueue(1, "AudioInferenceHandler_Mel")
self._audio_queue = SyncQueue(context.batch_size * 2, "AudioInferenceHandler_Audio") self._audio_queue = SyncQueue(context.batch_size * 2, "AudioInferenceHandler_Audio")
self._is_running = True
self._exit_event = Event() self._exit_event = Event()
self._run_thread = Thread(target=self.__on_run, name="AudioInferenceHandlerThread") self._run_thread = Thread(target=self.__on_run, name="AudioInferenceHandlerThread")
self._exit_event.set() self._exit_event.set()
self._run_thread.start() self._run_thread.start()
logger.info("AudioInferenceHandler init") logger.info("AudioInferenceHandler init")
def __del__(self):
    """Finalizer: remove this handler's 'stop' listener from the shared EventBus.

    NOTE(review): the bus holds the bound method (and through it this
    instance) in a plain list, so this presumably only runs once the
    listener has been removed by other means -- confirm.
    """
    EventBus().unregister('stop', self._on_stop)
def _on_stop(self, *args, **kwargs):
    """EventBus 'stop' callback: ignore any event payload and call ``self.stop()``."""
    self.stop()
def on_handle(self, stream, type_): def on_handle(self, stream, type_):
if not self._is_running:
return
if type_ == 1: if type_ == 1:
self._mal_queue.put(stream) self._mal_queue.put(stream)
elif type_ == 0: elif type_ == 0:
@ -56,14 +69,13 @@ class AudioInferenceHandler(AudioHandler):
device = get_device() device = get_device()
logger.info(f'use device:{device}') logger.info(f'use device:{device}')
while True: while self._is_running:
if self._exit_event.is_set(): if self._exit_event.is_set():
start_time = time.perf_counter() start_time = time.perf_counter()
batch_size = self._context.batch_size batch_size = self._context.batch_size
try: try:
mel_batch = self._mal_queue.get() mel_batch = self._mal_queue.get(timeout=0.02)
size = self._audio_queue.size() # print('AudioInferenceHandler mel_batch:', len(mel_batch), 'size:', self._mal_queue.size())
# print('AudioInferenceHandler mel_batch:', len(mel_batch), 'size:', size)
except queue.Empty: except queue.Empty:
continue continue
@ -76,8 +88,15 @@ class AudioInferenceHandler(AudioHandler):
audio_frames.append((frame, type_)) audio_frames.append((frame, type_))
if type_ == 0: if type_ == 0:
is_all_silence = False is_all_silence = False
if not self._is_running:
print('AudioInferenceHandler not running')
break
if is_all_silence: if is_all_silence:
for i in range(batch_size): for i in range(batch_size):
if not self._is_running:
break
self.on_next_handle((None, mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]), self.on_next_handle((None, mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]),
0) 0)
index = index + 1 index = index + 1
@ -118,6 +137,8 @@ class AudioInferenceHandler(AudioHandler):
count_time = 0 count_time = 0
for i, res_frame in enumerate(pred): for i, res_frame in enumerate(pred):
if not self._is_running:
break
self.on_next_handle( self.on_next_handle(
(res_frame, mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]), (res_frame, mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]),
0) 0)
@ -129,10 +150,17 @@ class AudioInferenceHandler(AudioHandler):
logger.info('AudioInferenceHandler inference processor stop') logger.info('AudioInferenceHandler inference processor stop')
def stop(self): def stop(self):
logger.info('AudioInferenceHandler stop')
self._is_running = False
self._exit_event.clear() self._exit_event.clear()
self._run_thread.join() if self._run_thread.is_alive():
logger.info('AudioInferenceHandler stop join')
self._run_thread.join()
logger.info('AudioInferenceHandler stop exit')
def pause_talk(self): def pause_talk(self):
print('AudioInferenceHandler pause_talk', self._audio_queue.size(), self._mal_queue.size()) print('AudioInferenceHandler pause_talk', self._audio_queue.size(), self._mal_queue.size())
self._audio_queue.clear() self._audio_queue.clear()
print('AudioInferenceHandler111')
self._mal_queue.clear() self._mal_queue.clear()
print('AudioInferenceHandler222')

View File

@ -3,7 +3,8 @@ import logging
import queue import queue
import time import time
from threading import Thread, Event, Condition from threading import Thread, Event
from eventbus import EventBus
import numpy as np import numpy as np
@ -17,16 +18,26 @@ class AudioMalHandler(AudioHandler):
def __init__(self, context, handler): def __init__(self, context, handler):
super().__init__(context, handler) super().__init__(context, handler)
self._queue = SyncQueue(context.batch_size, "AudioMalHandler_queue") EventBus().register('stop', self._on_stop)
self._exit_event = Event()
self._thread = Thread(target=self._on_run, name="AudioMalHandlerThread") self._is_running = True
self._exit_event.set() self._queue = SyncQueue(context.batch_size * 2, "AudioMalHandler_queue")
self._thread.start()
self.frames = [] self.frames = []
self.chunk = context.sample_rate // context.fps self.chunk = context.sample_rate // context.fps
self._exit_event = Event()
self._exit_event.set()
self._thread = Thread(target=self._on_run, name="AudioMalHandlerThread")
self._thread.start()
logger.info("AudioMalHandler init") logger.info("AudioMalHandler init")
def __del__(self):
    """Finalizer: remove this handler's 'stop' listener from the shared EventBus.

    NOTE(review): the bus holds the bound method (and through it this
    instance) in a plain list, so this presumably only runs once the
    listener has been removed by other means -- confirm.
    """
    EventBus().unregister('stop', self._on_stop)
def _on_stop(self, *args, **kwargs):
    """EventBus 'stop' callback: ignore any event payload and call ``self.stop()``."""
    self.stop()
def on_message(self, message): def on_message(self, message):
super().on_message(message) super().on_message(message)
@ -36,7 +47,7 @@ class AudioMalHandler(AudioHandler):
def _on_run(self): def _on_run(self):
logging.info('chunk2mal run') logging.info('chunk2mal run')
while self._exit_event.is_set(): while self._exit_event.is_set() and self._is_running:
self._run_step() self._run_step()
time.sleep(0.02) time.sleep(0.02)
@ -49,6 +60,9 @@ class AudioMalHandler(AudioHandler):
self.frames.append(frame) self.frames.append(frame)
self.on_next_handle((frame, _type), 0) self.on_next_handle((frame, _type), 0)
count = count + 1 count = count + 1
if self._is_running is False:
return
# context not enough, do not run network. # context not enough, do not run network.
if len(self.frames) <= self._context.stride_left_size + self._context.stride_right_size: if len(self.frames) <= self._context.stride_left_size + self._context.stride_right_size:
return return
@ -64,7 +78,8 @@ class AudioMalHandler(AudioHandler):
mel_step_size = 16 mel_step_size = 16
i = 0 i = 0
mel_chunks = [] mel_chunks = []
while i < (len(self.frames) - self._context.stride_left_size - self._context.stride_right_size) / 2: while i < (len(self.frames) - self._context.stride_left_size - self._context.stride_right_size) / 2\
and self._is_running:
start_idx = int(left + i * mel_idx_multiplier) start_idx = int(left + i * mel_idx_multiplier)
# print(start_idx) # print(start_idx)
if start_idx + mel_step_size > len(mel[0]): if start_idx + mel_step_size > len(mel[0]):
@ -78,11 +93,10 @@ class AudioMalHandler(AudioHandler):
self.frames = self.frames[-(self._context.stride_left_size + self._context.stride_right_size):] self.frames = self.frames[-(self._context.stride_left_size + self._context.stride_right_size):]
def get_audio_frame(self): def get_audio_frame(self):
try: if not self._queue.is_empty():
# print('AudioMalHandler get_audio_frame')
frame = self._queue.get() frame = self._queue.get()
type_ = 0 type_ = 0
except queue.Empty: else:
frame = np.zeros(self.chunk, dtype=np.float32) frame = np.zeros(self.chunk, dtype=np.float32)
type_ = 1 type_ = 1
# print('AudioMalHandler get_audio_frame type:', type_) # print('AudioMalHandler get_audio_frame type:', type_)
@ -90,6 +104,7 @@ class AudioMalHandler(AudioHandler):
def stop(self): def stop(self):
logging.info('stop') logging.info('stop')
self._is_running = False
if self._exit_event is None: if self._exit_event is None:
return return

View File

@ -3,6 +3,7 @@ import logging
import os import os
from asr import SherpaNcnnAsr from asr import SherpaNcnnAsr
from eventbus import EventBus
from .audio_inference_handler import AudioInferenceHandler from .audio_inference_handler import AudioInferenceHandler
from .audio_mal_handler import AudioMalHandler from .audio_mal_handler import AudioMalHandler
from .human_render import HumanRender from .human_render import HumanRender
@ -113,13 +114,7 @@ class HumanContext:
self._asr.attach(self._nlp) self._asr.attach(self._nlp)
def stop(self): def stop(self):
object_stop(self._asr) EventBus().post('stop')
object_stop(self._nlp)
object_stop(self._tts)
object_stop(self._tts_handle)
object_stop(self._mal_handler)
object_stop(self._infer_handler)
object_stop(self._render_handler)
def pause_talk(self): def pause_talk(self):
self._nlp.pause_talk() self._nlp.pause_talk()

View File

@ -5,6 +5,7 @@ import time
from queue import Empty from queue import Empty
from threading import Event, Thread from threading import Event, Thread
from eventbus import EventBus
from human.message_type import MessageType from human.message_type import MessageType
from human_handler import AudioHandler from human_handler import AudioHandler
from render import VoiceRender, VideoRender, PlayClock from render import VoiceRender, VideoRender, PlayClock
@ -17,9 +18,11 @@ class HumanRender(AudioHandler):
def __init__(self, context, handler): def __init__(self, context, handler):
super().__init__(context, handler) super().__init__(context, handler)
EventBus().register('stop', self._on_stop)
play_clock = PlayClock() play_clock = PlayClock()
self._voice_render = VoiceRender(play_clock, context) self._voice_render = VoiceRender(play_clock, context)
self._video_render = VideoRender(play_clock, context, self) self._video_render = VideoRender(play_clock, context, self)
self._is_running = True
self._queue = SyncQueue(context.batch_size, "HumanRender_queue") self._queue = SyncQueue(context.batch_size, "HumanRender_queue")
self._exit_event = Event() self._exit_event = Event()
self._thread = Thread(target=self._on_run, name="AudioMalHandlerThread") self._thread = Thread(target=self._on_run, name="AudioMalHandlerThread")
@ -28,24 +31,43 @@ class HumanRender(AudioHandler):
self._image_render = None self._image_render = None
self._last_audio_ps = 0 self._last_audio_ps = 0
self._last_video_ps = 0 self._last_video_ps = 0
self._empty_log = True
def __del__(self):
    """Finalizer: remove this renderer's 'stop' listener from the shared EventBus.

    NOTE(review): the bus holds the bound method (and through it this
    instance) in a plain list, so this presumably only runs once the
    listener has been removed by other means -- confirm.
    """
    EventBus().unregister('stop', self._on_stop)
def _on_stop(self, *args, **kwargs):
    """EventBus 'stop' callback: ignore any event payload and call ``self.stop()``."""
    self.stop()
def _on_run(self): def _on_run(self):
logging.info('human render run') logging.info('human render run')
while self._exit_event.is_set(): while self._exit_event.is_set() and self._is_running:
# t = time.time()
self._run_step() self._run_step()
time.sleep(0.038) # delay = time.time() - t
delay = 0.03805 # - delay
# print(delay)
# if delay <= 0.0:
# continue
time.sleep(delay)
logging.info('human render exit') logging.info('human render exit')
def _run_step(self): def _run_step(self):
try: try:
value = self._queue.get() value = self._queue.get(timeout=.005)
if value is None: if value is None:
return return
res_frame, idx, audio_frames = value res_frame, idx, audio_frames = value
# print('render queue size', self._queue.size())
if not self._empty_log:
self._empty_log = True
logging.info('render render:')
# print('voice render queue size', self._queue.size()) # print('voice render queue size', self._queue.size())
except Empty: except Empty:
print('render queue.Empty:') if self._empty_log:
self._empty_log = False
logging.info('render queue.Empty:')
return return
type_ = 1 type_ = 1
@ -69,32 +91,23 @@ class HumanRender(AudioHandler):
super().on_message(message) super().on_message(message)
def on_handle(self, stream, index): def on_handle(self, stream, index):
self._queue.put(stream) if not self._is_running:
# res_frame, idx, audio_frames = stream return
# self._voice_render.put(audio_frames, self._last_audio_ps)
# self._last_audio_ps = self._last_audio_ps + 0.4
# type_ = 1
# if audio_frames[0][1] != 0 and audio_frames[1][1] != 0:
# type_ = 0
# self._video_render.put((res_frame, idx, type_), self._last_video_ps)
# self._last_video_ps = self._last_video_ps + 0.4
#
# if self._voice_render.is_full():
# self._context.notify({'msg_id': MessageType.Video_Render_Queue_Full})
# def get_audio_queue_size(self): self._queue.put(stream)
# return self._voice_render.size()
def pause_talk(self): def pause_talk(self):
pass logging.info('hunan pause_talk')
# self._voice_render.pause_talk() # self._voice_render.pause_talk()
# self._video_render.pause_talk() # self._video_render.pause_talk()
def stop(self): def stop(self):
logging.info('hunan render stop') logging.info('hunan render stop')
self._is_running = False
if self._exit_event is None: if self._exit_event is None:
return return
self._queue.clear()
self._exit_event.clear() self._exit_event.clear()
if self._thread.is_alive(): if self._thread.is_alive():
self._thread.join() self._thread.join()

View File

@ -2,6 +2,7 @@
import logging import logging
from asr import AsrObserver from asr import AsrObserver
from eventbus import EventBus
from utils import AsyncTaskQueue from utils import AsyncTaskQueue
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -9,11 +10,19 @@ logger = logging.getLogger(__name__)
class NLPBase(AsrObserver): class NLPBase(AsrObserver):
def __init__(self, context, split, callback=None): def __init__(self, context, split, callback=None):
self._ask_queue = AsyncTaskQueue() self._ask_queue = AsyncTaskQueue('NLPBaseQueue')
self._context = context self._context = context
self._split_handle = split self._split_handle = split
self._callback = callback self._callback = callback
self._is_running = False self._is_running = True
EventBus().register('stop', self.on_stop)
def __del__(self):
    """Finalizer: remove this object's 'stop' listener from the shared EventBus.

    NOTE(review): the bus holds the bound method (and through it this
    instance) in a plain list, so this presumably only runs once the
    listener has been removed by other means -- confirm.
    """
    EventBus().unregister('stop', self.on_stop)
def on_stop(self, *args, **kwargs):
    """EventBus 'stop' callback: ignore any event payload and call ``self.stop()``."""
    self.stop()
@property @property
def callback(self): def callback(self):
@ -37,6 +46,8 @@ class NLPBase(AsrObserver):
pass pass
def completed(self, message: str): def completed(self, message: str):
if not self._is_running:
return
logger.info(f'complete:{message}') logger.info(f'complete:{message}')
# self._context.pause_talk() # self._context.pause_talk()
self.ask(message) self.ask(message)
@ -50,8 +61,11 @@ class NLPBase(AsrObserver):
def stop(self): def stop(self):
logger.info('NLPBase stop') logger.info('NLPBase stop')
self._is_running = False self._is_running = False
self._ask_queue.clear()
self._ask_queue.add_task(self._on_close) self._ask_queue.add_task(self._on_close)
logger.info('NLPBase add close')
self._ask_queue.stop() self._ask_queue.stop()
logger.info('NLPBase _ask_queue stop')
def pause_talk(self): def pause_talk(self):
logger.info('NLPBase pause_talk') logger.info('NLPBase pause_talk')

View File

@ -30,7 +30,6 @@ class DouBao(NLPBase):
async def _request(self, question): async def _request(self, question):
t = time.time() t = time.time()
logger.info(f'_request:{question}')
logger.info(f'-------dou_bao ask:{question}') logger.info(f'-------dou_bao ask:{question}')
try: try:
stream = await self.__client.chat.completions.create( stream = await self.__client.chat.completions.create(
@ -41,13 +40,16 @@ class DouBao(NLPBase):
], ],
stream=True stream=True
) )
sec = '' sec = ''
async for completion in stream: async for completion in stream:
sec = sec + completion.choices[0].delta.content sec = sec + completion.choices[0].delta.content
sec, message = self._split_handle.handle(sec) sec, message = self._split_handle.handle(sec)
if len(message) > 0: if len(message) > 0:
logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
self._on_callback(message) self._on_callback(message)
self._on_callback(sec) self._on_callback(sec)
logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
await stream.close() await stream.close()
# sec = "你是测试客服,是由字节跳动开发的 AI 人工智能助手" # sec = "你是测试客服,是由字节跳动开发的 AI 人工智能助手"
@ -59,7 +61,6 @@ class DouBao(NLPBase):
# self._on_callback(sec) # self._on_callback(sec)
except Exception as e: except Exception as e:
print(e) print(e)
logger.info(f'_request:{question}, time:{time.time() - t:.4f}s')
logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s') logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
async def _on_close(self): async def _on_close(self):

View File

@ -15,28 +15,9 @@ class VideoRender(BaseRender):
def __init__(self, play_clock, context, human_render): def __init__(self, play_clock, context, human_render):
super().__init__(play_clock, context, 'Video') super().__init__(play_clock, context, 'Video')
self._human_render = human_render self._human_render = human_render
self._diff_avg_count = 0
def render(self, frame, ps): def render(self, frame, ps):
res_frame, idx, type_ = frame res_frame, idx, type_ = frame
clock_time = self._play_clock.clock_time()
time_difference = clock_time - ps
if abs(time_difference) > self._play_clock.audio_diff_threshold:
if self._diff_avg_count < 5:
self._diff_avg_count += 1
else:
if time_difference < -self._play_clock.audio_diff_threshold:
sleep_time = abs(time_difference)
print("Video frame waiting to catch up with audio", sleep_time)
if sleep_time <= 1.0:
time.sleep(sleep_time)
# elif time_difference > self._play_clock.audio_diff_threshold: # 视频比音频快超过10ms
# print("Video frame dropped to catch up with audio")
# continue
else:
self._diff_avg_count = 0
if type_ == 0: if type_ == 0:
combine_frame = self._context.frame_list_cycle[idx] combine_frame = self._context.frame_list_cycle[idx]

View File

@ -6,6 +6,7 @@ import numpy as np
import requests import requests
import resampy import resampy
import soundfile as sf import soundfile as sf
from io import BytesIO
def download_tts(url): def download_tts(url):
@ -38,16 +39,15 @@ def __create_bytes_stream(byte_stream):
return stream return stream
def main(): def test_async_tts(url, content):
import aiohttp import aiohttp
import asyncio import asyncio
from io import BytesIO
async def fetch_audio(): async def fetch_audio():
url = "http://localhost:8082/v1/audio/speech"
data = { data = {
"model": "tts-1", "model": "tts-1",
"input": "写了一个高性能tts(文本转声音)工具,5千字仅需5秒,免费使用", "input": content,
"voice": "alloy", "voice": "alloy",
"speed": 1.0 "speed": 1.0
} }
@ -68,6 +68,33 @@ def main():
asyncio.run(fetch_audio()) asyncio.run(fetch_audio())
def test_sync_tts(url, content, timeout=30):
    """Request speech for *content* from a TTS endpoint and save it as a WAV file.

    Sends a blocking JSON POST to *url*. On HTTP 200 the response body is
    passed through ``__create_bytes_stream`` and written to
    ``output_audio.wav`` with a sample rate of 16000. Any other status is
    printed together with the response text.

    Args:
        url: Address of the speech endpoint.
        content: Text to synthesise (sent as the ``input`` field).
        timeout: Seconds passed to ``requests.post`` so a stalled server
            cannot hang the script indefinitely.

    Raises:
        Whatever ``requests.post`` raises for transport failures, including
        its timeout exception once *timeout* elapses.
    """
    data = {
        "model": "tts-1",
        "input": content,
        "voice": "alloy",
        "speed": 1.0
    }

    response = requests.post(url, json=data, timeout=timeout)
    if response.status_code == 200:
        audio_data = BytesIO(response.content)
        audio_stream = __create_bytes_stream(audio_data)
        # Save as a new audio file.
        sf.write("output_audio.wav", audio_stream, 16000)
        print("Audio data received and saved to output_audio.wav")
    else:
        print("Error:", response.status_code, response.text)
def main():
    """Run the synchronous TTS smoke test against the hosted endpoint."""
    # Alternative local endpoint: "http://localhost:8082/v1/audio/speech"
    endpoint = "https://tts.mzzsfy.eu.org/v1/audio/speech"
    text = "写了一个高性能tts(文本转声音)工具,5千字仅需5秒,免费使用"
    # The async variant, test_async_tts, takes the same two arguments.
    test_sync_tts(endpoint, text)
if __name__ == "__main__": if __name__ == "__main__":
try: try:
t = time.time() t = time.time()

View File

@ -4,6 +4,7 @@ import logging
import os import os
import shutil import shutil
from eventbus import EventBus
from utils import save_wav from utils import save_wav
from human_handler import AudioHandler from human_handler import AudioHandler
@ -16,6 +17,14 @@ class TTSAudioHandle(AudioHandler):
self._sample_rate = 16000 self._sample_rate = 16000
self._index = -1 self._index = -1
EventBus().register('stop', self._on_stop)
def __del__(self):
    """Finalizer: remove this handle's 'stop' listener from the shared EventBus.

    NOTE(review): the bus holds the bound method (and through it this
    instance) in a plain list, so this presumably only runs once the
    listener has been removed by other means -- confirm.
    """
    EventBus().unregister('stop', self._on_stop)
def _on_stop(self, *args, **kwargs):
    """EventBus 'stop' callback: ignore any event payload and call ``self.stop()``."""
    self.stop()
@property @property
def sample_rate(self): def sample_rate(self):
return self._sample_rate return self._sample_rate
@ -45,9 +54,13 @@ class TTSAudioSplitHandle(TTSAudioHandle):
self._chunk = self.sample_rate // self._context.fps self._chunk = self.sample_rate // self._context.fps
self._priority_queue = [] self._priority_queue = []
self._current = 0 self._current = 0
self._is_running = True
logger.info("TTSAudioSplitHandle init") logger.info("TTSAudioSplitHandle init")
def on_handle(self, stream, index): def on_handle(self, stream, index):
if not self._is_running:
logger.info('TTSAudioSplitHandle::on_handle is not running')
return
# heapq.heappush(self._priority_queue, (index, stream)) # heapq.heappush(self._priority_queue, (index, stream))
if stream is None: if stream is None:
heapq.heappush(self._priority_queue, (index, None)) heapq.heappush(self._priority_queue, (index, None))
@ -55,7 +68,7 @@ class TTSAudioSplitHandle(TTSAudioHandle):
stream_len = stream.shape[0] stream_len = stream.shape[0]
idx = 0 idx = 0
chunks = [] chunks = []
while stream_len >= self._chunk: while stream_len >= self._chunk and self._is_running:
# self.on_next_handle(stream[idx:idx + self._chunk], 0) # self.on_next_handle(stream[idx:idx + self._chunk], 0)
chunks.append(stream[idx:idx + self._chunk]) chunks.append(stream[idx:idx + self._chunk])
stream_len -= self._chunk stream_len -= self._chunk
@ -63,7 +76,7 @@ class TTSAudioSplitHandle(TTSAudioHandle):
heapq.heappush(self._priority_queue, (index, chunks)) heapq.heappush(self._priority_queue, (index, chunks))
current = self._priority_queue[0][0] current = self._priority_queue[0][0]
print('TTSAudioSplitHandle::on_handle', index, current, self._current) print('TTSAudioSplitHandle::on_handle', index, current, self._current)
if current == self._current: if current == self._current and self._is_running:
self._current = self._current + 1 self._current = self._current + 1
chunks = heapq.heappop(self._priority_queue)[1] chunks = heapq.heappop(self._priority_queue)[1]
if chunks is None: if chunks is None:
@ -73,7 +86,7 @@ class TTSAudioSplitHandle(TTSAudioHandle):
self.on_next_handle(chunk, 0) self.on_next_handle(chunk, 0)
def stop(self): def stop(self):
pass self._is_running = False
class TTSAudioSaveHandle(TTSAudioHandle): class TTSAudioSaveHandle(TTSAudioHandle):

View File

@ -3,6 +3,7 @@
import logging import logging
import time import time
from eventbus import EventBus
from nlp import NLPCallback from nlp import NLPCallback
from utils import AsyncTaskQueue from utils import AsyncTaskQueue
@ -12,8 +13,15 @@ logger = logging.getLogger(__name__)
class TTSBase(NLPCallback): class TTSBase(NLPCallback):
def __init__(self, handle): def __init__(self, handle):
self._handle = handle self._handle = handle
self._message_queue = AsyncTaskQueue(5) self._message_queue = AsyncTaskQueue('TTSBaseQueue', 5)
self._is_running = False self._is_running = True
EventBus().register('stop', self.on_stop)
def __del__(self):
    """Finalizer: remove this object's 'stop' listener from the shared EventBus.

    NOTE(review): the bus holds the bound method (and through it this
    instance) in a plain list, so this presumably only runs once the
    listener has been removed by other means -- confirm.
    """
    EventBus().unregister('stop', self.on_stop)
def on_stop(self, *args, **kwargs):
    """EventBus 'stop' callback: ignore any event payload and call ``self.stop()``."""
    self.stop()
@property @property
def handle(self): def handle(self):
@ -24,13 +32,13 @@ class TTSBase(NLPCallback):
self._handle = value self._handle = value
async def _request(self, txt: str, index): async def _request(self, txt: str, index):
print('_request:', txt) # print('_request:', txt)
t = time.time() t = time.time()
stream = await self._on_request(txt) stream = await self._on_request(txt)
if stream is None: if stream is None:
print(f'-------stream is None') logger.warn(f'-------stream is None')
return return
print(f'-------tts time:{time.time() - t:.4f}s') logger.info(f'-------tts time:{time.time() - t:.4f}s, txt:{txt}')
if self._handle is not None and self._is_running: if self._handle is not None and self._is_running:
await self._on_handle(stream, index) await self._on_handle(stream, index)
else: else:
@ -49,23 +57,22 @@ class TTSBase(NLPCallback):
self.message(txt) self.message(txt)
def message(self, txt): def message(self, txt):
self._is_running = True
txt = txt.strip() txt = txt.strip()
if len(txt) == 0: if len(txt) == 0:
logger.info(f'message is empty') # logger.info(f'message is empty')
return return
logger.info(f'message:{txt}') logger.info(f'message:{txt}')
index = 0 index = 0
if self._handle is not None: if self._handle is not None:
index = self._handle.get_index() index = self._handle.get_index()
print(f'message txt-index:{txt}, index {index}') # print(f'message txt-index:{txt}, index {index}')
self._message_queue.add_task(self._request, txt, index) self._message_queue.add_task(self._request, txt, index)
def stop(self): def stop(self):
self._is_running = False
self._message_queue.add_task(self._on_close) self._message_queue.add_task(self._on_close)
self._message_queue.stop() self._message_queue.stop()
def pause_talk(self): def pause_talk(self):
logger.info(f'TTSBase pause_talk') logger.info(f'TTSBase pause_talk')
self._is_running = False
self._message_queue.clear() self._message_queue.clear()

View File

@ -4,6 +4,7 @@ from io import BytesIO
import aiohttp import aiohttp
import numpy as np import numpy as np
import requests
import soundfile as sf import soundfile as sf
import edge_tts import edge_tts
import resampy import resampy
@ -21,15 +22,7 @@ class TTSEdgeHttp(TTSBase):
self._url = 'https://tts.mzzsfy.eu.org/v1/audio/speech' self._url = 'https://tts.mzzsfy.eu.org/v1/audio/speech'
logger.info(f"TTSEdge init, {voice}") logger.info(f"TTSEdge init, {voice}")
async def _on_request(self, txt: str): async def _on_async_request(self, data):
print('TTSEdgeHttp, _on_request, txt:', txt)
data = {
"model": "tts-1",
"input": txt,
"voice": "alloy",
"speed": 1.0,
"thread": 10
}
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.post(self._url, json=data) as response: async with session.post(self._url, json=data) as response:
print('TTSEdgeHttp, _on_request, response:', response) print('TTSEdgeHttp, _on_request, response:', response)
@ -38,7 +31,28 @@ class TTSEdgeHttp(TTSBase):
return stream return stream
else: else:
byte_stream = None byte_stream = None
return byte_stream return byte_stream, None
def _on_sync_request(self, data, timeout=30):
    """POST *data* as JSON to ``self._url`` and return the audio payload.

    Args:
        data: JSON-serialisable request body for the speech endpoint.
        timeout: Seconds passed to ``requests.post``. Without it a stalled
            connection would block the calling thread indefinitely.

    Returns:
        A ``BytesIO`` wrapping the response body when the server answers
        with HTTP 200, otherwise ``None``.

    Raises:
        Whatever ``requests.post`` raises for transport failures, including
        its timeout exception once *timeout* elapses.
    """
    response = requests.post(self._url, json=data, timeout=timeout)
    if response.status_code != 200:
        return None
    return BytesIO(response.content)
async def _on_request(self, txt: str):
    """Build the speech request for *txt* and return the synchronous helper's result.

    The request body is fixed apart from the ``input`` text. The coroutine
    contains no ``await``: it calls ``_on_sync_request`` directly and returns
    whatever that returns.

    NOTE(review): because the helper is a blocking call, the HTTP round-trip
    presumably occupies the event loop thread for its duration -- confirm
    this is acceptable for the queue that drives it.
    """
    logger.info(f'TTSEdgeHttp, _on_request, txt:{txt}')
    payload = dict(model="tts-1", input=txt, voice="alloy", speed=1.0, thread=10)
    # The aiohttp-based path (_on_async_request) is currently not used.
    return self._on_sync_request(payload)
async def _on_handle(self, stream, index): async def _on_handle(self, stream, index):
print('-------tts _on_handle') print('-------tts _on_handle')

View File

@ -55,12 +55,16 @@ class PyGameUI:
if self._queue.empty(): if self._queue.empty():
return return
image = self._queue.get() image = self._queue.get()
self._human_image = pygame.image.frombuffer(image.tobytes(), image.shape[1::-1], "RGB") color_format = "RGB"
if 4 == image.shape[2]:
color_format = "RGBA"
self._human_image = pygame.image.frombuffer(image.tobytes(), image.shape[1::-1], color_format)
def stop(self): def stop(self):
logger.info('stop') logger.info('stop')
if self._human_context is not None: if self._human_context is not None:
self._human_context.pause_talk() # self._human_context.pause_talk()
self._human_context.stop() self._human_context.stop()
def on_render(self, image): def on_render(self, image):

View File

@ -1,48 +1,61 @@
#encoding = utf8 #encoding = utf8
import asyncio import asyncio
import logging
from queue import Queue
import threading import threading
logger = logging.getLogger(__name__)
class AsyncTaskQueue: class AsyncTaskQueue:
def __init__(self, work_num=1): def __init__(self, name, work_num=1):
self._queue = asyncio.Queue() self._queue = Queue()
self._worker_num = work_num self._worker_num = work_num
self._current_worker_num = work_num self._current_worker_num = work_num
self._thread = threading.Thread(target=self._run_loop) self._name = name
self._thread = threading.Thread(target=self._run_loop, name=name)
self._thread.start() self._thread.start()
self.__loop = None self.__loop = None
def _run_loop(self): def _run_loop(self):
print('_run_loop') logging.info(f'{self._name}, _run_loop')
self.__loop = asyncio.new_event_loop() self.__loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.__loop) asyncio.set_event_loop(self.__loop)
self._tasks = [self.__loop.create_task(self._worker()) for _ in range(self._worker_num)] self._tasks = [self.__loop.create_task(self._worker()) for _ in range(self._worker_num)]
self.__loop.run_forever() try:
print("exit run") self.__loop.run_forever()
if not self.__loop.is_closed(): finally:
logging.info(f'{self._name}, exit run')
self.__loop.run_until_complete(asyncio.gather(*asyncio.all_tasks(self.__loop)))
self.__loop.close() self.__loop.close()
logging.info(f'{self._name}, close loop')
async def _worker(self): async def _worker(self):
print('_worker') logging.info(f'{self._name}, _worker')
while True: while True:
task = await self._queue.get() try:
if task is None: # None as a stop signal task = self._queue.get()
break if task is None: # None as a stop signal
break
func, *args = task # Unpack task func, *args = task # Unpack task
print(f"Executing task with args: {args}") if func is None: # None as a stop signal
await func(*args) # Execute async function break
self._queue.task_done()
print('_worker finish') await func(*args) # Execute async function
except Exception as e:
logging.error(f'{self._name} error: {e}')
finally:
self._queue.task_done()
logging.info(f'{self._name}, _worker finish')
self._current_worker_num -= 1 self._current_worker_num -= 1
if self._current_worker_num == 0: if self._current_worker_num == 0:
print('loop stop') self.__loop.call_soon_threadsafe(self.__loop.stop)
self.__loop.stop()
def add_task(self, func, *args): def add_task(self, func, *args):
self.__loop.call_soon_threadsafe(self._queue.put_nowait, (func, *args)) self._queue.put((func, *args))
def stop_workers(self): def stop_workers(self):
for _ in range(self._worker_num): for _ in range(self._worker_num):
@ -55,4 +68,4 @@ class AsyncTaskQueue:
def stop(self): def stop(self):
self.stop_workers() self.stop_workers()
self._thread.join() self._thread.join()

View File

@ -10,6 +10,9 @@ class SyncQueue:
self._queue = Queue(maxsize) self._queue = Queue(maxsize)
self._condition = threading.Condition() self._condition = threading.Condition()
def is_empty(self):
    """Return True when the wrapped queue currently reports no items.

    The check does not acquire ``self._condition``, so the answer is only a
    snapshot: another thread may put or take an item right after it returns.
    """
    return self._queue.empty()
def put(self, item): def put(self, item):
with self._condition: with self._condition:
while self._queue.full(): while self._queue.full():

View File

@ -194,7 +194,7 @@ def config_logging(file_name: str, console_level: int = logging.INFO, file_level
console_handler = logging.StreamHandler() console_handler = logging.StreamHandler()
console_handler.setFormatter(logging.Formatter( console_handler.setFormatter(logging.Formatter(
'[%(asctime)s %(levelname)s] %(message)s', '[%(asctime)s.%(msecs)03d %(levelname)s] %(message)s',
datefmt="%Y/%m/%d %H:%M:%S" datefmt="%Y/%m/%d %H:%M:%S"
)) ))
console_handler.setLevel(console_level) console_handler.setLevel(console_level)