Compare commits

..

13 Commits

Author SHA1 Message Date
c4eb191e18 modify tts handle 2024-11-16 14:05:13 +08:00
4e98ee0e76 modify remove asyncio 2024-11-16 10:13:11 +08:00
aa554b1209 modify mutil-thread 2024-11-15 21:23:34 +08:00
7eccc99c2a modify handler 2024-11-15 01:09:47 +08:00
88a307ed6a modify tts logs 2024-11-14 19:14:22 +08:00
742340971b add txt render 2024-11-13 19:29:40 +08:00
b3bbf40d95 add human txt 2024-11-13 12:58:56 +08:00
d08a74b4e4 modify tts and txt 2024-11-12 23:07:55 +08:00
7cbe6d073b modify tts audio handle 2024-11-11 19:07:01 +08:00
dde41769bf modify abort tts hangle 2024-11-10 14:06:47 +08:00
a6153b1eb9 modify abort 2024-11-09 21:00:22 +08:00
3334918ed1 modif clean cacah 2024-11-09 07:39:03 +08:00
92b912e162 modfiy abort 2024-11-08 19:49:53 +08:00
12 changed files with 217 additions and 165 deletions

View File

@ -17,7 +17,7 @@ class AsrBase:
self._stop_event = threading.Event() self._stop_event = threading.Event()
self._stop_event.set() self._stop_event.set()
self._thread = threading.Thread(target=self._recognize_loop) self._thread = threading.Thread(target=self._recognize_loop, name="AsrBaseThread")
self._thread.start() self._thread.start()
def __del__(self): def __del__(self):
@ -34,6 +34,7 @@ class AsrBase:
observer.process(message) observer.process(message)
def _notify_complete(self, message: str): def _notify_complete(self, message: str):
EventBus().post('clear_cache')
for observer in self._observers: for observer in self._observers:
observer.completed(message) observer.completed(message)

View File

@ -63,6 +63,15 @@ class SherpaNcnnAsr(AsrBase):
logger.info(f'_recognize_loop') logger.info(f'_recognize_loop')
print(f'_recognize_loop') print(f'_recognize_loop')
while self._stop_event.is_set():
logger.info(f'_recognize_loop000')
self._notify_complete('介绍中国5000年历史文学')
logger.info(f'_recognize_loop111')
segment_id += 1
time.sleep(150)
logger.info(f'_recognize_loop222')
logger.info(f'_recognize_loop exit')
'''
with sd.InputStream(channels=1, dtype="float32", samplerate=self._sample_rate) as s: with sd.InputStream(channels=1, dtype="float32", samplerate=self._sample_rate) as s:
while self._stop_event.is_set(): while self._stop_event.is_set():
samples, _ = s.read(self._samples_per_read) # a blocking read samples, _ = s.read(self._samples_per_read) # a blocking read
@ -84,12 +93,4 @@ class SherpaNcnnAsr(AsrBase):
segment_id += 1 segment_id += 1
self._recognizer.reset() self._recognizer.reset()
''' '''
while self._stop_event.is_set():
logger.info(f'_recognize_loop000')
self._notify_complete('介绍中国5000年历史文学')
logger.info(f'_recognize_loop111')
segment_id += 1
time.sleep(60)
logger.info(f'_recognize_loop222')
logger.info(f'_recognize_loop exit')
'''

View File

@ -22,6 +22,7 @@ class AudioInferenceHandler(AudioHandler):
super().__init__(context, handler) super().__init__(context, handler)
EventBus().register('stop', self._on_stop) EventBus().register('stop', self._on_stop)
EventBus().register('clear_cache', self.on_clear_cache)
self._mal_queue = SyncQueue(1, "AudioInferenceHandler_Mel") self._mal_queue = SyncQueue(1, "AudioInferenceHandler_Mel")
self._audio_queue = SyncQueue(context.batch_size * 2, "AudioInferenceHandler_Audio") self._audio_queue = SyncQueue(context.batch_size * 2, "AudioInferenceHandler_Audio")
@ -35,10 +36,15 @@ class AudioInferenceHandler(AudioHandler):
def __del__(self): def __del__(self):
EventBus().unregister('stop', self._on_stop) EventBus().unregister('stop', self._on_stop)
EventBus().unregister('clear_cache', self.on_clear_cache)
def _on_stop(self, *args, **kwargs): def _on_stop(self, *args, **kwargs):
self.stop() self.stop()
def on_clear_cache(self, *args, **kwargs):
    """EventBus 'clear_cache' callback: empty both internal queues.

    The mel queue is cleared first, then the audio queue. Any arguments
    delivered with the event are accepted and ignored.
    """
    for pending in (self._mal_queue, self._audio_queue):
        pending.clear()
def on_handle(self, stream, type_): def on_handle(self, stream, type_):
if not self._is_running: if not self._is_running:
return return
@ -82,9 +88,11 @@ class AudioInferenceHandler(AudioHandler):
# print('origin mel_batch:', len(mel_batch)) # print('origin mel_batch:', len(mel_batch))
is_all_silence = True is_all_silence = True
audio_frames = [] audio_frames = []
current_text = ''
for _ in range(batch_size * 2): for _ in range(batch_size * 2):
frame, type_ = self._audio_queue.get() frame, type_ = self._audio_queue.get()
# print('AudioInferenceHandler type_', type_) # print('AudioInferenceHandler type_', type_)
current_text = frame[1]
audio_frames.append((frame, type_)) audio_frames.append((frame, type_))
if type_ == 0: if type_ == 0:
is_all_silence = False is_all_silence = False
@ -101,7 +109,7 @@ class AudioInferenceHandler(AudioHandler):
0) 0)
index = index + 1 index = index + 1
else: else:
logger.info('infer=======') logger.info(f'infer======= {current_text}')
t = time.perf_counter() t = time.perf_counter()
img_batch = [] img_batch = []
# for i in range(batch_size): # for i in range(batch_size):

View File

@ -19,6 +19,7 @@ class AudioMalHandler(AudioHandler):
super().__init__(context, handler) super().__init__(context, handler)
EventBus().register('stop', self._on_stop) EventBus().register('stop', self._on_stop)
EventBus().register('clear_cache', self.on_clear_cache)
self._is_running = True self._is_running = True
self._queue = SyncQueue(context.batch_size * 2, "AudioMalHandler_queue") self._queue = SyncQueue(context.batch_size * 2, "AudioMalHandler_queue")
@ -34,15 +35,20 @@ class AudioMalHandler(AudioHandler):
def __del__(self): def __del__(self):
EventBus().unregister('stop', self._on_stop) EventBus().unregister('stop', self._on_stop)
EventBus().unregister('clear_cache', self.on_clear_cache)
def _on_stop(self, *args, **kwargs): def _on_stop(self, *args, **kwargs):
self.stop() self.stop()
def on_clear_cache(self, *args, **kwargs):
    """EventBus 'clear_cache' callback: discard everything buffered here.

    Empties ``frames`` first and the input ``_queue`` second. Any arguments
    delivered with the event are accepted and ignored.
    """
    for buffered in (self.frames, self._queue):
        buffered.clear()
def on_message(self, message): def on_message(self, message):
super().on_message(message) super().on_message(message)
def on_handle(self, stream, index): def on_handle(self, stream, index):
# print('AudioMalHandler on_handle', index) # logging.info(f'AudioMalHandler on_handle {index}')
self._queue.put(stream) self._queue.put(stream)
def _on_run(self): def _on_run(self):
@ -57,7 +63,8 @@ class AudioMalHandler(AudioHandler):
count = 0 count = 0
for _ in range(self._context.batch_size * 2): for _ in range(self._context.batch_size * 2):
frame, _type = self.get_audio_frame() frame, _type = self.get_audio_frame()
self.frames.append(frame) chunk, txt = frame
self.frames.append(chunk)
self.on_next_handle((frame, _type), 0) self.on_next_handle((frame, _type), 0)
count = count + 1 count = count + 1
@ -97,9 +104,10 @@ class AudioMalHandler(AudioHandler):
frame = self._queue.get() frame = self._queue.get()
type_ = 0 type_ = 0
else: else:
frame = np.zeros(self.chunk, dtype=np.float32) chunk = np.zeros(self.chunk, dtype=np.float32)
frame = (chunk, '')
type_ = 1 type_ = 1
# print('AudioMalHandler get_audio_frame type:', type_) # logging.info(f'AudioMalHandler get_audio_frame type:{type_}')
return frame, type_ return frame, type_
def stop(self): def stop(self):

View File

@ -19,6 +19,7 @@ class HumanRender(AudioHandler):
super().__init__(context, handler) super().__init__(context, handler)
EventBus().register('stop', self._on_stop) EventBus().register('stop', self._on_stop)
EventBus().register('clear_cache', self.on_clear_cache)
play_clock = PlayClock() play_clock = PlayClock()
self._voice_render = VoiceRender(play_clock, context) self._voice_render = VoiceRender(play_clock, context)
self._video_render = VideoRender(play_clock, context, self) self._video_render = VideoRender(play_clock, context, self)
@ -35,17 +36,21 @@ class HumanRender(AudioHandler):
def __del__(self): def __del__(self):
EventBus().unregister('stop', self._on_stop) EventBus().unregister('stop', self._on_stop)
EventBus().unregister('clear_cache', self.on_clear_cache)
def _on_stop(self, *args, **kwargs): def _on_stop(self, *args, **kwargs):
self.stop() self.stop()
def on_clear_cache(self, *args, **kwargs):
    """EventBus 'clear_cache' callback: empty the render queue.

    Any arguments delivered with the event are accepted and ignored.
    """
    render_queue = self._queue
    render_queue.clear()
def _on_run(self): def _on_run(self):
logging.info('human render run') logging.info('human render run')
while self._exit_event.is_set() and self._is_running: while self._exit_event.is_set() and self._is_running:
# t = time.time() # t = time.time()
self._run_step() self._run_step()
# delay = time.time() - t # delay = time.time() - t
delay = 0.03805 # - delay delay = 0.038 # - delay
# print(delay) # print(delay)
# if delay <= 0.0: # if delay <= 0.0:
# continue # continue
@ -116,52 +121,4 @@ class HumanRender(AudioHandler):
# self._video_render.stop() # self._video_render.stop()
# self._exit_event.clear() # self._exit_event.clear()
# self._thread.join() # self._thread.join()
'''
self._exit_event = Event()
self._thread = Thread(target=self._on_run)
self._exit_event.set()
self._thread.start()
def _on_run(self):
logging.info('human render run')
while self._exit_event.is_set():
self._run_step()
time.sleep(0.02)
logging.info('human render exit')
def _run_step(self):
try:
res_frame, idx, audio_frames = self._queue.get(block=True, timeout=.002)
except queue.Empty:
# print('render queue.Empty:')
return None
if audio_frames[0][1] != 0 and audio_frames[1][1] != 0:
combine_frame = self._context.frame_list_cycle[idx]
else:
bbox = self._context.coord_list_cycle[idx]
combine_frame = copy.deepcopy(self._context.frame_list_cycle[idx])
y1, y2, x1, x2 = bbox
try:
res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
except:
return
# combine_frame = get_image(ori_frame,res_frame,bbox)
# t=time.perf_counter()
combine_frame[y1:y2, x1:x2] = res_frame
image = combine_frame
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
if self._image_render is not None:
self._image_render.on_render(image)
for audio_frame in audio_frames:
frame, type_ = audio_frame
frame = (frame * 32767).astype(np.int16)
if self._audio_render is not None:
self._audio_render.write(frame.tobytes(), int(frame.shape[0]*2))
# new_frame = AudioFrame(format='s16', layout='mono', samples=frame.shape[0])
# new_frame.planes[0].update(frame.tobytes())
# new_frame.sample_rate = 16000
'''

View File

@ -17,13 +17,19 @@ class NLPBase(AsrObserver):
self._is_running = True self._is_running = True
EventBus().register('stop', self.on_stop) EventBus().register('stop', self.on_stop)
EventBus().register('clear_cache', self.on_clear_cache)
def __del__(self): def __del__(self):
EventBus().unregister('stop', self.on_stop) EventBus().unregister('stop', self.on_stop)
EventBus().unregister('clear_cache', self.on_clear_cache)
def on_stop(self, *args, **kwargs): def on_stop(self, *args, **kwargs):
self.stop() self.stop()
def on_clear_cache(self, *args, **kwargs):
    """EventBus 'clear_cache' callback: log the event, then empty ``_ask_queue``.

    Any arguments delivered with the event are accepted and ignored.
    """
    logger.info('NLPBase clear_cache')
    ask_queue = self._ask_queue
    ask_queue.clear()
@property @property
def callback(self): def callback(self):
return self._callback return self._callback
@ -36,10 +42,10 @@ class NLPBase(AsrObserver):
if self._callback is not None and self._is_running: if self._callback is not None and self._is_running:
self._callback.on_message(txt) self._callback.on_message(txt)
async def _request(self, question): def _request(self, question):
pass pass
async def _on_close(self): def _on_close(self):
pass pass
def process(self, message: str): def process(self, message: str):

View File

@ -17,13 +17,13 @@ class DouBaoSDK:
self.__client = AsyncArk(api_key=token) self.__client = AsyncArk(api_key=token)
self._stream = None self._stream = None
async def request(self, question, handle, callback): def request(self, question, handle, callback):
if self.__client is None: if self.__client is None:
self.__client = AsyncArk(api_key=self._token) self.__client = AsyncArk(api_key=self._token)
t = time.time() t = time.time()
logger.info(f'-------dou_bao ask:{question}') logger.info(f'-------dou_bao ask:{question}')
try: try:
self._stream = await self.__client.chat.completions.create( self._stream = self.__client.chat.completions.create(
model="ep-20241008152048-fsgzf", model="ep-20241008152048-fsgzf",
messages=[ messages=[
{"role": "system", "content": "你是测试客服,是由字节跳动开发的 AI 人工智能助手"}, {"role": "system", "content": "你是测试客服,是由字节跳动开发的 AI 人工智能助手"},
@ -33,27 +33,36 @@ class DouBaoSDK:
) )
sec = '' sec = ''
async for completion in self._stream: for completion in self._stream:
sec = sec + completion.choices[0].delta.content sec = sec + completion.choices[0].delta.content
sec, message = handle.handle(sec) sec, message = handle.handle(sec)
if len(message) > 0: if len(message) > 0:
logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s') # logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
callback(message) callback(message)
callback(sec) callback(sec)
logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s') # logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
await self._stream.close() self._stream.close()
self._stream = None self._stream = None
except Exception as e: except Exception as e:
print(e) logger.error(f'-------dou_bao error:{e}')
logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s') # logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
async def close(self): def close(self):
if self._stream is not None: if self._stream is not None:
await self._stream.close() self._stream.close()
self._stream = None self._stream = None
logger.info('AsyncArk close') logger.info('AsyncArk close')
if self.__client is not None and not self.__client.is_closed(): if self.__client is not None and not self.__client.is_closed():
await self.__client.close() self.__client.close()
self.__client = None
def aclose(self):
if self._stream is not None:
self._stream.close()
self._stream = None
logger.info('AsyncArk close')
if self.__client is not None and not self.__client.is_closed():
self.__client.close()
self.__client = None self.__client = None
@ -79,7 +88,7 @@ class DouBaoHttp:
response = requests.post(url, headers=headers, json=data, stream=True) response = requests.post(url, headers=headers, json=data, stream=True)
return response return response
async def request(self, question, handle, callback): def request(self, question, handle, callback):
t = time.time() t = time.time()
self._requesting = True self._requesting = True
logger.info(f'-------dou_bao ask:{question}') logger.info(f'-------dou_bao ask:{question}')
@ -89,7 +98,7 @@ class DouBaoHttp:
] ]
self._response = self.__request(msg_list) self._response = self.__request(msg_list)
if not self._response.ok: if not self._response.ok:
logger.info(f"请求失败,状态码:{self._response.status_code}") logger.error(f"请求失败,状态码:{self._response.status_code}")
return return
sec = '' sec = ''
for chunk in self._response.iter_lines(): for chunk in self._response.iter_lines():
@ -97,20 +106,35 @@ class DouBaoHttp:
if len(content) < 1: if len(content) < 1:
continue continue
content = content[5:] content = content[5:]
content = content.strip()
if content == '[DONE]':
break
try:
content = json.loads(content) content = json.loads(content)
except Exception as e:
logger.error(f"json解析失败错误信息{e, content}")
continue
sec = sec + content["choices"][0]["delta"]["content"] sec = sec + content["choices"][0]["delta"]["content"]
sec, message = handle.handle(sec) sec, message = handle.handle(sec)
if len(message) > 0: if len(message) > 0:
logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s') logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
callback(message) callback(message)
if len(sec) > 0:
callback(sec) callback(sec)
self._requesting = False self._requesting = False
logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s') logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
async def close(self): def close(self):
if self._response is not None and self._requesting: if self._response is not None and self._requesting:
self._response.close() self._response.close()
# Close the HTTP response of a request that is still in progress, if any.
# Performs the same check-and-close as close(), plus a log line.
def aclose(self):
# request() sets _requesting to True on entry and back to False once the
# stream has been consumed, so a finished response is left untouched.
# NOTE(review): request() returns early on a non-OK status without resetting
# _requesting — confirm that closing such a response here is intended.
if self._response is not None and self._requesting:
self._response.close()
# NOTE(review): indentation is lost in this view — confirm whether this log
# line runs unconditionally or only when a response was actually closed.
logger.info('DouBaoHttp close')
class DouBao(NLPBase): class DouBao(NLPBase):
def __init__(self, context, split, callback=None): def __init__(self, context, split, callback=None):
@ -131,10 +155,16 @@ class DouBao(NLPBase):
self.__token = 'c9635f9e-0f9e-4ca1-ac90-8af25a541b74' self.__token = 'c9635f9e-0f9e-4ca1-ac90-8af25a541b74'
self._dou_bao = DouBaoHttp(self.__token) self._dou_bao = DouBaoHttp(self.__token)
async def _request(self, question): def _request(self, question):
await self._dou_bao.request(question, self._split_handle, self._on_callback) self._dou_bao.request(question, self._split_handle, self._on_callback)
async def _on_close(self): def _on_close(self):
if self._dou_bao is not None: if self._dou_bao is not None:
await self._dou_bao.close() self._dou_bao.close()
logger.info('AsyncArk close') logger.info('AsyncArk close')
# 'clear_cache' handler: run the base-class handling first, then ask the
# backend client to close a request that is still in progress.
def on_clear_cache(self, *args, **kwargs):
super().on_clear_cache(*args, **kwargs)
# aclose() only closes the response while a request is marked as running.
if self._dou_bao is not None:
self._dou_bao.aclose()
# NOTE(review): indentation is lost in this view — confirm whether this log
# line runs unconditionally or only when _dou_bao is set.
logger.info('DouBao clear_cache')

View File

@ -16,6 +16,7 @@ class VoiceRender(BaseRender):
def __init__(self, play_clock, context): def __init__(self, play_clock, context):
self._audio_render = AudioRender() self._audio_render = AudioRender()
super().__init__(play_clock, context, 'Voice') super().__init__(play_clock, context, 'Voice')
self._current_text = ''
def render(self, frame, ps): def render(self, frame, ps):
self._play_clock.update_display_time() self._play_clock.update_display_time()
@ -23,12 +24,16 @@ class VoiceRender(BaseRender):
for audio_frame in frame: for audio_frame in frame:
frame, type_ = audio_frame frame, type_ = audio_frame
frame = (frame * 32767).astype(np.int16) chunk, txt = frame
if txt != self._current_text:
self._current_text = txt
logging.info(f'VoiceRender: {txt}')
chunk = (chunk * 32767).astype(np.int16)
if self._audio_render is not None: if self._audio_render is not None:
try: try:
chunk_len = int(frame.shape[0] * 2) chunk_len = int(chunk.shape[0] * 2)
# print('audio frame:', frame.shape, chunk_len) # print('audio frame:', frame.shape, chunk_len)
self._audio_render.write(frame.tobytes(), chunk_len) self._audio_render.write(chunk.tobytes(), chunk_len)
except Exception as e: except Exception as e:
logging.error(f'Error writing audio frame: {e}') logging.error(f'Error writing audio frame: {e}')

View File

@ -3,6 +3,8 @@ import heapq
import logging import logging
import os import os
import shutil import shutil
import time
from threading import Lock, Thread
from eventbus import EventBus from eventbus import EventBus
from utils import save_wav from utils import save_wav
@ -18,13 +20,18 @@ class TTSAudioHandle(AudioHandler):
self._index = -1 self._index = -1
EventBus().register('stop', self._on_stop) EventBus().register('stop', self._on_stop)
EventBus().register('clear_cache', self.on_clear_cache)
def __del__(self): def __del__(self):
EventBus().unregister('stop', self._on_stop) EventBus().unregister('stop', self._on_stop)
EventBus().unregister('clear_cache', self.on_clear_cache)
def _on_stop(self, *args, **kwargs): def _on_stop(self, *args, **kwargs):
self.stop() self.stop()
def on_clear_cache(self, *args, **kwargs):
    """EventBus 'clear_cache' callback: reset ``_index`` to -1.

    Any arguments delivered with the event are accepted and ignored.
    """
    reset_value = -1
    self._index = reset_value
@property @property
def sample_rate(self): def sample_rate(self):
return self._sample_rate return self._sample_rate
@ -53,40 +60,54 @@ class TTSAudioSplitHandle(TTSAudioHandle):
self.sample_rate = self._context.sample_rate self.sample_rate = self._context.sample_rate
self._chunk = self.sample_rate // self._context.fps self._chunk = self.sample_rate // self._context.fps
self._priority_queue = [] self._priority_queue = []
self._lock = Lock()
self._current = 0 self._current = 0
self._is_running = True self._is_running = True
self._thread = Thread(target=self._process_loop)
self._thread.start()
logger.info("TTSAudioSplitHandle init") logger.info("TTSAudioSplitHandle init")
def _process_loop(self):
    """Forward queued TTS segments to the next handler in index order.

    Runs until ``_is_running`` becomes false. On every pass it checks whether
    the smallest index in ``_priority_queue`` equals ``_current``; if so the
    entry is popped, ``_current`` advances, and each audio chunk of the
    segment is passed on as ``(chunk, txt)`` via ``on_next_handle``.

    Heap entries are ``(index, payload)`` where ``payload`` is either
    ``(chunks, txt)`` or ``None``. ``on_handle`` pushes ``(index, None)`` for
    a segment that produced no audio; such an entry is consumed so later
    segments are not blocked, but nothing is forwarded for it.
    """
    while self._is_running:
        with self._lock:
            pending = self._priority_queue
            if pending and pending[0][0] == self._current:
                self._current += 1
                payload = heapq.heappop(pending)[1]
                # A bare None payload used to be unpacked directly, raising
                # TypeError and ending this loop for good.
                if payload is not None:
                    chunks, txt = payload
                    if chunks is not None:
                        for chunk in chunks:
                            self.on_next_handle((chunk, txt), 0)
        # Sleep outside the lock so on_clear_cache is not held up while idle.
        time.sleep(0.01)  # Sleep briefly to prevent busy-waiting
def on_handle(self, stream, index): def on_handle(self, stream, index):
if not self._is_running: if not self._is_running:
logger.info('TTSAudioSplitHandle::on_handle is not running') logger.info('TTSAudioSplitHandle::on_handle is not running')
return return
# heapq.heappush(self._priority_queue, (index, stream))
if stream is None: logger.info(f'TTSAudioSplitHandle::on_handle {index}')
s, txt = stream
if s is None:
heapq.heappush(self._priority_queue, (index, None)) heapq.heappush(self._priority_queue, (index, None))
else: else:
stream_len = stream.shape[0] stream_len = s.shape[0]
idx = 0 idx = 0
chunks = [] chunks = []
while stream_len >= self._chunk and self._is_running: while stream_len >= self._chunk and self._is_running:
# self.on_next_handle(stream[idx:idx + self._chunk], 0) chunks.append(s[idx:idx + self._chunk])
chunks.append(stream[idx:idx + self._chunk])
stream_len -= self._chunk stream_len -= self._chunk
idx += self._chunk idx += self._chunk
heapq.heappush(self._priority_queue, (index, chunks)) if not self._is_running:
current = self._priority_queue[0][0] return
print('TTSAudioSplitHandle::on_handle', index, current, self._current) heapq.heappush(self._priority_queue, (index, (chunks, txt)))
if current == self._current and self._is_running:
self._current = self._current + 1
chunks = heapq.heappop(self._priority_queue)[1]
if chunks is None:
pass
else:
for chunk in chunks:
self.on_next_handle(chunk, 0)
def stop(self): def stop(self):
self._is_running = False self._is_running = False
self._thread.join()
# 'clear_cache' handler: let the base class reset its index, then restart
# ordering from 0 and discard every segment still waiting in the heap.
def on_clear_cache(self, *args, **kwargs):
super().on_clear_cache()
# _process_loop takes the same lock before it inspects or pops the heap.
# NOTE(review): on_handle appears to push onto _priority_queue without taking
# this lock — confirm whether that is intended.
with self._lock:
self._current = 0
self._priority_queue.clear()
class TTSAudioSaveHandle(TTSAudioHandle): class TTSAudioSaveHandle(TTSAudioHandle):

View File

@ -16,13 +16,19 @@ class TTSBase(NLPCallback):
self._message_queue = AsyncTaskQueue('TTSBaseQueue', 5) self._message_queue = AsyncTaskQueue('TTSBaseQueue', 5)
self._is_running = True self._is_running = True
EventBus().register('stop', self.on_stop) EventBus().register('stop', self.on_stop)
EventBus().register('clear_cache', self.on_clear_cache)
def __del__(self): def __del__(self):
EventBus().unregister('stop', self.on_stop) EventBus().unregister('stop', self.on_stop)
EventBus().unregister('clear_cache', self.on_clear_cache)
def on_stop(self, *args, **kwargs): def on_stop(self, *args, **kwargs):
self.stop() self.stop()
def on_clear_cache(self, *args, **kwargs):
    """EventBus 'clear_cache' callback: log the event, then empty ``_message_queue``.

    Any arguments delivered with the event are accepted and ignored.
    """
    logger.info('TTSBase clear_cache')
    task_queue = self._message_queue
    task_queue.clear()
@property @property
def handle(self): def handle(self):
return self._handle return self._handle
@ -31,32 +37,39 @@ class TTSBase(NLPCallback):
def handle(self, value): def handle(self, value):
self._handle = value self._handle = value
async def _request(self, txt: str, index): def _request(self, txt: str, index):
# print('_request:', txt) if not self._is_running:
t = time.time() logger.info('TTSBase::_request is not running')
stream = await self._on_request(txt) return
if stream is None:
logger.warn(f'-------stream is None') t = time.time()
stream = self._on_request(txt)
logger.info(f'-------tts request time:{time.time() - t:.4f}s, txt:{txt}')
if stream is None or self._is_running is False:
logger.warning(f'-------stream is None or is_running {self._is_running}')
return return
logger.info(f'-------tts time:{time.time() - t:.4f}s, txt:{txt}')
if self._handle is not None and self._is_running: if self._handle is not None and self._is_running:
await self._on_handle(stream, index) self._on_handle((stream, txt), index)
else: else:
logger.info(f'handle is None, running:{self._is_running}') logger.info(f'handle is None, running:{self._is_running}')
logger.info(f'-------tts finish time:{time.time() - t:.4f}s, txt:{txt}')
async def _on_request(self, text: str): def _on_request(self, text: str):
pass pass
async def _on_handle(self, stream, index): def _on_handle(self, stream, index):
pass pass
async def _on_close(self): def _on_close(self):
pass pass
def on_message(self, txt: str): def on_message(self, txt: str):
self.message(txt) self.message(txt)
def message(self, txt): def message(self, txt):
if not self._is_running:
logger.info('TTSBase::message is not running')
return
txt = txt.strip() txt = txt.strip()
if len(txt) == 0: if len(txt) == 0:
# logger.info(f'message is empty') # logger.info(f'message is empty')
@ -66,6 +79,7 @@ class TTSBase(NLPCallback):
if self._handle is not None: if self._handle is not None:
index = self._handle.get_index() index = self._handle.get_index()
# print(f'message txt-index:{txt}, index {index}') # print(f'message txt-index:{txt}, index {index}')
logger.info(f'TTSBase::message request:{txt}, index:{index}')
self._message_queue.add_task(self._request, txt, index) self._message_queue.add_task(self._request, txt, index)
def stop(self): def stop(self):

View File

@ -1,5 +1,6 @@
#encoding = utf8 #encoding = utf8
import logging import logging
import time
from io import BytesIO from io import BytesIO
import aiohttp import aiohttp
@ -21,13 +22,14 @@ class TTSEdgeHttp(TTSBase):
# self._url = 'http://localhost:8082/v1/audio/speech' # self._url = 'http://localhost:8082/v1/audio/speech'
self._url = 'https://tts.mzzsfy.eu.org/v1/audio/speech' self._url = 'https://tts.mzzsfy.eu.org/v1/audio/speech'
logger.info(f"TTSEdge init, {voice}") logger.info(f"TTSEdge init, {voice}")
self._response_list = []
async def _on_async_request(self, data): def _on_async_request(self, data):
async with aiohttp.ClientSession() as session: with aiohttp.ClientSession() as session:
async with session.post(self._url, json=data) as response: with session.post(self._url, json=data) as response:
print('TTSEdgeHttp, _on_request, response:', response) print('TTSEdgeHttp, _on_request, response:', response)
if response.status == 200: if response.status == 200:
stream = BytesIO(await response.read()) stream = BytesIO(response.read())
return stream return stream
else: else:
byte_stream = None byte_stream = None
@ -35,13 +37,14 @@ class TTSEdgeHttp(TTSBase):
def _on_sync_request(self, data): def _on_sync_request(self, data):
response = requests.post(self._url, json=data) response = requests.post(self._url, json=data)
self._response_list.append(response)
stream = None
if response.status_code == 200: if response.status_code == 200:
stream = BytesIO(response.content) stream = BytesIO(response.content)
self._response_list.remove(response)
return stream return stream
else:
return None
async def _on_request(self, txt: str): def _on_request(self, txt: str):
logger.info(f'TTSEdgeHttp, _on_request, txt:{txt}') logger.info(f'TTSEdgeHttp, _on_request, txt:{txt}')
data = { data = {
"model": "tts-1", "model": "tts-1",
@ -54,23 +57,25 @@ class TTSEdgeHttp(TTSBase):
# return self._on_async_request(data) # return self._on_async_request(data)
return self._on_sync_request(data) return self._on_sync_request(data)
async def _on_handle(self, stream, index): def _on_handle(self, stream, index):
print('-------tts _on_handle') st, txt = stream
try: try:
stream.seek(0) st.seek(0)
byte_stream = self.__create_bytes_stream(stream) t = time.time()
print('-------tts start push chunk', index) byte_stream = self.__create_bytes_stream(st)
self._handle.on_handle(byte_stream, index) logger.info(f'-------tts resample time:{time.time() - t:.4f}s, txt:{txt}')
stream.seek(0) t = time.time()
stream.truncate() self._handle.on_handle((byte_stream, txt), index)
print('-------tts finish push chunk') logger.info(f'-------tts handle time:{time.time() - t:.4f}s')
st.seek(0)
st.truncate()
except Exception as e: except Exception as e:
self._handle.on_handle(None, index) self._handle.on_handle(None, index)
stream.seek(0) st.seek(0)
stream.truncate() st.truncate()
print('-------tts finish error:', e) logger.error(f'-------tts finish error:{e}')
stream.close() st.close()
def __create_bytes_stream(self, byte_stream): def __create_bytes_stream(self, byte_stream):
stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64 stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64
@ -87,7 +92,13 @@ class TTSEdgeHttp(TTSBase):
return stream return stream
async def _on_close(self): def _on_close(self):
print('TTSEdge close') print('TTSEdge close')
# if self._byte_stream is not None and not self._byte_stream.closed: # if self._byte_stream is not None and not self._byte_stream.closed:
# self._byte_stream.close() # self._byte_stream.close()
def on_clear_cache(self, *args, **kwargs):
    """EventBus 'clear_cache' callback: clear queued work and close tracked responses.

    Logs the event, delegates to the base-class handler with the same
    arguments, then closes every response currently held in
    ``_response_list``.
    """
    logger.info('TTSEdgeHttp clear_cache')
    super().on_clear_cache(*args, **kwargs)
    # Walk a snapshot: _on_sync_request appends to and removes from
    # _response_list (presumably on the task-queue worker threads), and
    # iterating a list while it shrinks can skip entries.
    for response in list(self._response_list):
        response.close()

View File

@ -1,6 +1,5 @@
#encoding = utf8 #encoding = utf8
import asyncio
import logging import logging
from queue import Queue from queue import Queue
import threading import threading
@ -14,24 +13,14 @@ class AsyncTaskQueue:
self._worker_num = work_num self._worker_num = work_num
self._current_worker_num = work_num self._current_worker_num = work_num
self._name = name self._name = name
self._thread = threading.Thread(target=self._run_loop, name=name) self._threads = []
self._thread.start() self._lock = threading.Lock()
self.__loop = None for _ in range(work_num):
thread = threading.Thread(target=self._worker, name=f'{name}_worker_{_}')
thread.start()
self._threads.append(thread)
def _run_loop(self): def _worker(self):
logging.info(f'{self._name}, _run_loop')
self.__loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.__loop)
self._tasks = [self.__loop.create_task(self._worker()) for _ in range(self._worker_num)]
try:
self.__loop.run_forever()
finally:
logging.info(f'{self._name}, exit run')
self.__loop.run_until_complete(asyncio.gather(*asyncio.all_tasks(self.__loop)))
self.__loop.close()
logging.info(f'{self._name}, close loop')
async def _worker(self):
logging.info(f'{self._name}, _worker') logging.info(f'{self._name}, _worker')
while True: while True:
try: try:
@ -43,16 +32,17 @@ class AsyncTaskQueue:
if func is None: # None as a stop signal if func is None: # None as a stop signal
break break
await func(*args) # Execute async function func(*args) # Execute function
except Exception as e: except Exception as e:
logging.error(f'{self._name} error: {e}') logging.error(f'{self._name} error: {repr(e)}')
finally: finally:
self._queue.task_done() self._queue.task_done()
logging.info(f'{self._name}, _worker finish') logging.info(f'{self._name}, _worker finish')
with self._lock:
self._current_worker_num -= 1 self._current_worker_num -= 1
if self._current_worker_num == 0: if self._current_worker_num == 0:
self.__loop.call_soon_threadsafe(self.__loop.stop) self._queue.put(None) # Send stop signal to remaining workers
def add_task(self, func, *args): def add_task(self, func, *args):
self._queue.put((func, *args)) self._queue.put((func, *args))
@ -62,10 +52,10 @@ class AsyncTaskQueue:
self.add_task(None) # Send stop signal self.add_task(None) # Send stop signal
def clear(self): def clear(self):
while not self._queue.empty(): with self._queue.mutex:
self._queue.get_nowait() self._queue.queue.clear()
self._queue.task_done()
def stop(self): def stop(self):
self.stop_workers() self.stop_workers()
self._thread.join() for thread in self._threads:
thread.join()