Compare commits
13 Commits (7406552289...c4eb191e18)

| SHA1 |
|---|
| c4eb191e18 |
| 4e98ee0e76 |
| aa554b1209 |
| 7eccc99c2a |
| 88a307ed6a |
| 742340971b |
| b3bbf40d95 |
| d08a74b4e4 |
| 7cbe6d073b |
| dde41769bf |
| a6153b1eb9 |
| 3334918ed1 |
| 92b912e162 |
@@ -17,7 +17,7 @@ class AsrBase:
         self._stop_event = threading.Event()
         self._stop_event.set()
-        self._thread = threading.Thread(target=self._recognize_loop)
+        self._thread = threading.Thread(target=self._recognize_loop, name="AsrBaseThread")
         self._thread.start()

     def __del__(self):

@@ -34,6 +34,7 @@ class AsrBase:
             observer.process(message)

     def _notify_complete(self, message: str):
+        EventBus().post('clear_cache')
         for observer in self._observers:
             observer.completed(message)

@@ -63,6 +63,15 @@ class SherpaNcnnAsr(AsrBase):
         logger.info(f'_recognize_loop')
+        print(f'_recognize_loop')
+
+        while self._stop_event.is_set():
+            logger.info(f'_recognize_loop000')
+            self._notify_complete('介绍中国5000年历史文学')
+            logger.info(f'_recognize_loop111')
+            segment_id += 1
+            time.sleep(150)
+            logger.info(f'_recognize_loop222')
         logger.info(f'_recognize_loop exit')
         '''
         with sd.InputStream(channels=1, dtype="float32", samplerate=self._sample_rate) as s:
             while self._stop_event.is_set():
                 samples, _ = s.read(self._samples_per_read)  # a blocking read

@@ -84,12 +93,4 @@ class SherpaNcnnAsr(AsrBase):
                     segment_id += 1
                     self._recognizer.reset()
-        '''
-        while self._stop_event.is_set():
-            logger.info(f'_recognize_loop000')
-            self._notify_complete('介绍中国5000年历史文学')
-            logger.info(f'_recognize_loop111')
-            segment_id += 1
-            time.sleep(60)
-            logger.info(f'_recognize_loop222')
-        logger.info(f'_recognize_loop exit')
         '''
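Note: the common thread in these commits is a 'clear_cache' event. AsrBase._notify_complete now posts it before observers receive the new utterance, and every downstream handler (inference, mel, render, NLP, TTS) registers an on_clear_cache callback that drops queued work. The EventBus itself is not shown in this diff, so the sketch below only assumes a process-wide singleton exposing register/unregister/post; treat the class as illustrative, not the project's implementation.

```python
# Minimal sketch of the publish/subscribe pattern the diff relies on.
# Assumption: EventBus is a process-wide singleton with register/unregister/post;
# the real implementation lives outside this diff.
import threading
from collections import defaultdict


class EventBus:
    _instance = None
    _lock = threading.Lock()

    def __new__(cls):
        with cls._lock:
            if cls._instance is None:
                cls._instance = super().__new__(cls)
                cls._instance._handlers = defaultdict(list)
            return cls._instance

    def register(self, event, handler):
        self._handlers[event].append(handler)

    def unregister(self, event, handler):
        if handler in self._handlers[event]:
            self._handlers[event].remove(handler)

    def post(self, event, *args, **kwargs):
        # Handlers run synchronously on the posting thread.
        for handler in list(self._handlers[event]):
            handler(*args, **kwargs)


# Usage mirroring the diff: a handler empties its queue when a new utterance arrives.
class DemoHandler:
    def __init__(self):
        self._pending = []
        EventBus().register('clear_cache', self.on_clear_cache)

    def on_clear_cache(self, *args, **kwargs):
        self._pending.clear()


handler = DemoHandler()
handler._pending.append('stale audio chunk')
EventBus().post('clear_cache')   # what AsrBase._notify_complete now does first
assert handler._pending == []
```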
@@ -22,6 +22,7 @@ class AudioInferenceHandler(AudioHandler):
         super().__init__(context, handler)

         EventBus().register('stop', self._on_stop)
+        EventBus().register('clear_cache', self.on_clear_cache)
         self._mal_queue = SyncQueue(1, "AudioInferenceHandler_Mel")
         self._audio_queue = SyncQueue(context.batch_size * 2, "AudioInferenceHandler_Audio")

@@ -35,10 +36,15 @@ class AudioInferenceHandler(AudioHandler):

     def __del__(self):
         EventBus().unregister('stop', self._on_stop)
+        EventBus().unregister('clear_cache', self.on_clear_cache)

     def _on_stop(self, *args, **kwargs):
         self.stop()

+    def on_clear_cache(self, *args, **kwargs):
+        self._mal_queue.clear()
+        self._audio_queue.clear()
+
     def on_handle(self, stream, type_):
         if not self._is_running:
             return

@@ -82,9 +88,11 @@ class AudioInferenceHandler(AudioHandler):
             # print('origin mel_batch:', len(mel_batch))
             is_all_silence = True
             audio_frames = []
+            current_text = ''
             for _ in range(batch_size * 2):
                 frame, type_ = self._audio_queue.get()
                 # print('AudioInferenceHandler type_', type_)
+                current_text = frame[1]
                 audio_frames.append((frame, type_))
                 if type_ == 0:
                     is_all_silence = False

@@ -101,7 +109,7 @@ class AudioInferenceHandler(AudioHandler):
                                   0)
                 index = index + 1
             else:
-                logger.info('infer=======')
+                logger.info(f'infer======= {current_text}')
                 t = time.perf_counter()
                 img_batch = []
                 # for i in range(batch_size):
@@ -19,6 +19,7 @@ class AudioMalHandler(AudioHandler):
         super().__init__(context, handler)

         EventBus().register('stop', self._on_stop)
+        EventBus().register('clear_cache', self.on_clear_cache)

         self._is_running = True
         self._queue = SyncQueue(context.batch_size * 2, "AudioMalHandler_queue")

@@ -34,15 +35,20 @@ class AudioMalHandler(AudioHandler):

     def __del__(self):
         EventBus().unregister('stop', self._on_stop)
+        EventBus().unregister('clear_cache', self.on_clear_cache)

     def _on_stop(self, *args, **kwargs):
         self.stop()

+    def on_clear_cache(self, *args, **kwargs):
+        self.frames.clear()
+        self._queue.clear()
+
     def on_message(self, message):
         super().on_message(message)

     def on_handle(self, stream, index):
         # print('AudioMalHandler on_handle', index)
         # logging.info(f'AudioMalHandler on_handle {index}')
         self._queue.put(stream)

     def _on_run(self):

@@ -57,7 +63,8 @@ class AudioMalHandler(AudioHandler):
         count = 0
         for _ in range(self._context.batch_size * 2):
             frame, _type = self.get_audio_frame()
-            self.frames.append(frame)
+            chunk, txt = frame
+            self.frames.append(chunk)
             self.on_next_handle((frame, _type), 0)
             count = count + 1

@@ -97,9 +104,10 @@ class AudioMalHandler(AudioHandler):
             frame = self._queue.get()
             type_ = 0
         else:
-            frame = np.zeros(self.chunk, dtype=np.float32)
+            chunk = np.zeros(self.chunk, dtype=np.float32)
+            frame = (chunk, '')
             type_ = 1
         # print('AudioMalHandler get_audio_frame type:', type_)
         # logging.info(f'AudioMalHandler get_audio_frame type:{type_}')
         return frame, type_

     def stop(self):
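Note: from this point on the unit passed between audio handlers is no longer a bare sample array but a (chunk, text) pair, so the caption that produced each piece of audio travels with it: get_audio_frame builds (np.zeros(...), '') for silence, AudioInferenceHandler reads frame[1] into current_text, and VoiceRender unpacks chunk, txt. A small, self-contained sketch of that tuple contract; names and constants here are illustrative, not the project's.

```python
# Sketch of the (chunk, text) frame convention introduced by these commits.
# Only the tuple shape mirrors the diff; everything else is a toy example.
import numpy as np

CHUNK = 320  # e.g. 16 kHz audio split into 50 frames per second


def silence_frame():
    # Matches the new get_audio_frame() fallback: zeros paired with an empty caption.
    return (np.zeros(CHUNK, dtype=np.float32), '')


def speech_frame(samples: np.ndarray, text: str):
    return (samples.astype(np.float32), text)


def render(frame):
    chunk, txt = frame                        # downstream handlers unpack both parts
    pcm16 = (chunk * 32767).astype(np.int16)  # same scaling VoiceRender applies
    return pcm16, txt


pcm, caption = render(speech_frame(np.random.rand(CHUNK) - 0.5, '介绍中国5000年历史文学'))
print(pcm.dtype, len(pcm), caption)
```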
@@ -19,6 +19,7 @@ class HumanRender(AudioHandler):
         super().__init__(context, handler)

         EventBus().register('stop', self._on_stop)
+        EventBus().register('clear_cache', self.on_clear_cache)
         play_clock = PlayClock()
         self._voice_render = VoiceRender(play_clock, context)
         self._video_render = VideoRender(play_clock, context, self)

@@ -35,17 +36,21 @@ class HumanRender(AudioHandler):

     def __del__(self):
         EventBus().unregister('stop', self._on_stop)
+        EventBus().unregister('clear_cache', self.on_clear_cache)

     def _on_stop(self, *args, **kwargs):
         self.stop()

+    def on_clear_cache(self, *args, **kwargs):
+        self._queue.clear()
+
     def _on_run(self):
         logging.info('human render run')
         while self._exit_event.is_set() and self._is_running:
             # t = time.time()
             self._run_step()
             # delay = time.time() - t
-            delay = 0.03805  # - delay
+            delay = 0.038  # - delay
             # print(delay)
             # if delay <= 0.0:
             #     continue

@@ -116,52 +121,4 @@ class HumanRender(AudioHandler):
         # self._video_render.stop()
         # self._exit_event.clear()
         # self._thread.join()
-    '''
-    self._exit_event = Event()
-    self._thread = Thread(target=self._on_run)
-    self._exit_event.set()
-    self._thread.start()
-
-    def _on_run(self):
-        logging.info('human render run')
-        while self._exit_event.is_set():
-            self._run_step()
-            time.sleep(0.02)
-
-        logging.info('human render exit')
-
-    def _run_step(self):
-        try:
-            res_frame, idx, audio_frames = self._queue.get(block=True, timeout=.002)
-        except queue.Empty:
-            # print('render queue.Empty:')
-            return None
-        if audio_frames[0][1] != 0 and audio_frames[1][1] != 0:
-            combine_frame = self._context.frame_list_cycle[idx]
-        else:
-            bbox = self._context.coord_list_cycle[idx]
-            combine_frame = copy.deepcopy(self._context.frame_list_cycle[idx])
-            y1, y2, x1, x2 = bbox
-            try:
-                res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
-            except:
-                return
-            # combine_frame = get_image(ori_frame,res_frame,bbox)
-            # t=time.perf_counter()
-            combine_frame[y1:y2, x1:x2] = res_frame
-
-        image = combine_frame
-        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-
-        if self._image_render is not None:
-            self._image_render.on_render(image)
-
-        for audio_frame in audio_frames:
-            frame, type_ = audio_frame
-            frame = (frame * 32767).astype(np.int16)
-            if self._audio_render is not None:
-                self._audio_render.write(frame.tobytes(), int(frame.shape[0]*2))
-            # new_frame = AudioFrame(format='s16', layout='mono', samples=frame.shape[0])
-            # new_frame.planes[0].update(frame.tobytes())
-            # new_frame.sample_rate = 16000
-    '''
@@ -17,13 +17,19 @@ class NLPBase(AsrObserver):
         self._is_running = True

         EventBus().register('stop', self.on_stop)
+        EventBus().register('clear_cache', self.on_clear_cache)

     def __del__(self):
         EventBus().unregister('stop', self.on_stop)
+        EventBus().unregister('clear_cache', self.on_clear_cache)

     def on_stop(self, *args, **kwargs):
         self.stop()

+    def on_clear_cache(self, *args, **kwargs):
+        logger.info('NLPBase clear_cache')
+        self._ask_queue.clear()
+
     @property
     def callback(self):
         return self._callback

@@ -36,10 +42,10 @@ class NLPBase(AsrObserver):
         if self._callback is not None and self._is_running:
             self._callback.on_message(txt)

-    async def _request(self, question):
+    def _request(self, question):
         pass

-    async def _on_close(self):
+    def _on_close(self):
         pass

     def process(self, message: str):
@@ -17,13 +17,13 @@ class DouBaoSDK:
         self.__client = AsyncArk(api_key=token)
         self._stream = None

-    async def request(self, question, handle, callback):
+    def request(self, question, handle, callback):
         if self.__client is None:
             self.__client = AsyncArk(api_key=self._token)
         t = time.time()
         logger.info(f'-------dou_bao ask:{question}')
         try:
-            self._stream = await self.__client.chat.completions.create(
+            self._stream = self.__client.chat.completions.create(
                 model="ep-20241008152048-fsgzf",
                 messages=[
                     {"role": "system", "content": "你是测试客服,是由字节跳动开发的 AI 人工智能助手"},

@@ -33,27 +33,36 @@ class DouBaoSDK:
             )

             sec = ''
-            async for completion in self._stream:
+            for completion in self._stream:
                 sec = sec + completion.choices[0].delta.content
                 sec, message = handle.handle(sec)
                 if len(message) > 0:
-                    logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
+                    # logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
                     callback(message)
             callback(sec)
-            logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
-            await self._stream.close()
+            # logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
+            self._stream.close()
             self._stream = None
         except Exception as e:
-            print(e)
-            logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
+            logger.error(f'-------dou_bao error:{e}')
+            # logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s')

-    async def close(self):
+    def close(self):
         if self._stream is not None:
-            await self._stream.close()
+            self._stream.close()
             self._stream = None
         logger.info('AsyncArk close')
         if self.__client is not None and not self.__client.is_closed():
-            await self.__client.close()
+            self.__client.close()
             self.__client = None

+    def aclose(self):
+        if self._stream is not None:
+            self._stream.close()
+            self._stream = None
+        logger.info('AsyncArk close')
+        if self.__client is not None and not self.__client.is_closed():
+            self.__client.close()
+            self.__client = None
+

@@ -79,7 +88,7 @@ class DouBaoHttp:
         response = requests.post(url, headers=headers, json=data, stream=True)
         return response

-    async def request(self, question, handle, callback):
+    def request(self, question, handle, callback):
         t = time.time()
         self._requesting = True
         logger.info(f'-------dou_bao ask:{question}')

@@ -89,7 +98,7 @@ class DouBaoHttp:
         ]
         self._response = self.__request(msg_list)
         if not self._response.ok:
-            logger.info(f"请求失败,状态码:{self._response.status_code}")
+            logger.error(f"请求失败,状态码:{self._response.status_code}")
             return
         sec = ''
         for chunk in self._response.iter_lines():

@@ -97,20 +106,35 @@ class DouBaoHttp:
             if len(content) < 1:
                 continue
             content = content[5:]
-            content = json.loads(content)
+            content = content.strip()
+            if content == '[DONE]':
+                break
+
+            try:
+                content = json.loads(content)
+            except Exception as e:
+                logger.error(f"json解析失败,错误信息:{e, content}")
+                continue
             sec = sec + content["choices"][0]["delta"]["content"]
             sec, message = handle.handle(sec)
             if len(message) > 0:
                 logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
                 callback(message)
-        callback(sec)
+        if len(sec) > 0:
+            callback(sec)

         self._requesting = False
         logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s')

-    async def close(self):
+    def close(self):
         if self._response is not None and self._requesting:
             self._response.close()

+    def aclose(self):
+        if self._response is not None and self._requesting:
+            self._response.close()
+        logger.info('DouBaoHttp close')
+

 class DouBao(NLPBase):
     def __init__(self, context, split, callback=None):

@@ -131,10 +155,16 @@ class DouBao(NLPBase):
         self.__token = 'c9635f9e-0f9e-4ca1-ac90-8af25a541b74'
         self._dou_bao = DouBaoHttp(self.__token)

-    async def _request(self, question):
-        await self._dou_bao.request(question, self._split_handle, self._on_callback)
+    def _request(self, question):
+        self._dou_bao.request(question, self._split_handle, self._on_callback)

-    async def _on_close(self):
+    def _on_close(self):
         if self._dou_bao is not None:
-            await self._dou_bao.close()
+            self._dou_bao.close()
         logger.info('AsyncArk close')

+    def on_clear_cache(self, *args, **kwargs):
+        super().on_clear_cache(*args, **kwargs)
+        if self._dou_bao is not None:
+            self._dou_bao.aclose()
+        logger.info('DouBao clear_cache')
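Note: the DouBaoHttp hunks harden the streamed-response parser: the leading "data:" prefix is stripped, a "[DONE]" line ends the stream, and a json.loads failure is logged and skipped instead of aborting the whole reply. A standalone sketch of that parsing loop, written against a plain iterable of SSE lines rather than the project's response object:

```python
# Hedged sketch of the hardened SSE parsing loop from the DouBaoHttp hunk.
# It operates on an iterable of text lines; the real code reads response.iter_lines().
import json
import logging


def parse_sse_deltas(lines):
    text = ''
    for raw in lines:
        if not raw or len(raw) < 1:
            continue
        payload = raw[5:].strip()        # drop the leading "data:" prefix
        if payload == '[DONE]':          # end-of-stream sentinel
            break
        try:
            event = json.loads(payload)
        except Exception as e:
            logging.error(f'json parse failed: {e!r}, payload={payload!r}')
            continue                     # skip a bad line instead of failing the reply
        text += event["choices"][0]["delta"]["content"]
    return text


demo = [
    'data:{"choices":[{"delta":{"content":"你好"}}]}',
    'data:not-json',                     # logged and skipped, not fatal
    'data:{"choices":[{"delta":{"content":"!"}}]}',
    'data:[DONE]',
]
print(parse_sse_deltas(demo))
```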
@@ -16,6 +16,7 @@ class VoiceRender(BaseRender):
     def __init__(self, play_clock, context):
         self._audio_render = AudioRender()
         super().__init__(play_clock, context, 'Voice')
+        self._current_text = ''

     def render(self, frame, ps):
         self._play_clock.update_display_time()

@@ -23,12 +24,16 @@ class VoiceRender(BaseRender):

         for audio_frame in frame:
             frame, type_ = audio_frame
-            frame = (frame * 32767).astype(np.int16)
+            chunk, txt = frame
+            if txt != self._current_text:
+                self._current_text = txt
+                logging.info(f'VoiceRender: {txt}')
+            chunk = (chunk * 32767).astype(np.int16)

             if self._audio_render is not None:
                 try:
-                    chunk_len = int(frame.shape[0] * 2)
+                    chunk_len = int(chunk.shape[0] * 2)
                     # print('audio frame:', frame.shape, chunk_len)
-                    self._audio_render.write(frame.tobytes(), chunk_len)
+                    self._audio_render.write(chunk.tobytes(), chunk_len)
                 except Exception as e:
                     logging.error(f'Error writing audio frame: {e}')
@@ -3,6 +3,8 @@ import heapq
 import logging
 import os
 import shutil
+import time
+from threading import Lock, Thread

 from eventbus import EventBus
 from utils import save_wav

@@ -18,13 +20,18 @@ class TTSAudioHandle(AudioHandler):
         self._index = -1

         EventBus().register('stop', self._on_stop)
+        EventBus().register('clear_cache', self.on_clear_cache)

     def __del__(self):
         EventBus().unregister('stop', self._on_stop)
+        EventBus().unregister('clear_cache', self.on_clear_cache)

     def _on_stop(self, *args, **kwargs):
         self.stop()

+    def on_clear_cache(self, *args, **kwargs):
+        self._index = -1
+
     @property
     def sample_rate(self):
         return self._sample_rate

@@ -53,40 +60,54 @@ class TTSAudioSplitHandle(TTSAudioHandle):
         self.sample_rate = self._context.sample_rate
         self._chunk = self.sample_rate // self._context.fps
         self._priority_queue = []
+        self._lock = Lock()
         self._current = 0
         self._is_running = True
+        self._thread = Thread(target=self._process_loop)
+        self._thread.start()
         logger.info("TTSAudioSplitHandle init")

+    def _process_loop(self):
+        while self._is_running:
+            with self._lock:
+                if self._priority_queue and self._priority_queue[0][0] == self._current:
+                    self._current += 1
+                    chunks, txt = heapq.heappop(self._priority_queue)[1]
+                    if chunks is not None:
+                        for chunk in chunks:
+                            self.on_next_handle((chunk, txt), 0)
+            time.sleep(0.01)  # Sleep briefly to prevent busy-waiting
+
     def on_handle(self, stream, index):
         if not self._is_running:
             logger.info('TTSAudioSplitHandle::on_handle is not running')
             return
-        # heapq.heappush(self._priority_queue, (index, stream))
-        if stream is None:
+        logger.info(f'TTSAudioSplitHandle::on_handle {index}')
+        s, txt = stream
+        if s is None:
             heapq.heappush(self._priority_queue, (index, None))
         else:
-            stream_len = stream.shape[0]
+            stream_len = s.shape[0]
             idx = 0
             chunks = []
             while stream_len >= self._chunk and self._is_running:
                 # self.on_next_handle(stream[idx:idx + self._chunk], 0)
-                chunks.append(stream[idx:idx + self._chunk])
+                chunks.append(s[idx:idx + self._chunk])
                 stream_len -= self._chunk
                 idx += self._chunk
-            heapq.heappush(self._priority_queue, (index, chunks))
-        current = self._priority_queue[0][0]
-        print('TTSAudioSplitHandle::on_handle', index, current, self._current)
-        if current == self._current and self._is_running:
-            self._current = self._current + 1
-            chunks = heapq.heappop(self._priority_queue)[1]
-            if chunks is None:
-                pass
-            else:
-                for chunk in chunks:
-                    self.on_next_handle(chunk, 0)
-                    if not self._is_running:
-                        return
+            heapq.heappush(self._priority_queue, (index, (chunks, txt)))

     def stop(self):
         self._is_running = False
+        self._thread.join()

+    def on_clear_cache(self, *args, **kwargs):
+        super().on_clear_cache()
+        with self._lock:
+            self._current = 0
+            self._priority_queue.clear()
+

 class TTSAudioSaveHandle(TTSAudioHandle):
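Note: TTSAudioSplitHandle now pushes every (index, (chunks, txt)) result onto a min-heap while a dedicated thread drains the heap strictly in index order, so TTS chunks that finish out of order are still forwarded in the order the sentences were requested. A toy sketch of that reordering idea (not the project's class):

```python
# Sketch of in-order draining of out-of-order results via a min-heap,
# the idea behind TTSAudioSplitHandle._process_loop. Toy example, not project code.
import heapq
import threading
import time


class OrderedEmitter:
    def __init__(self, emit):
        self._heap = []
        self._next = 0
        self._lock = threading.Lock()
        self._running = True
        self._emit = emit
        self._thread = threading.Thread(target=self._loop)
        self._thread.start()

    def push(self, index, payload):
        with self._lock:
            heapq.heappush(self._heap, (index, payload))

    def _loop(self):
        while self._running:
            with self._lock:
                # Only emit when the smallest pending index is the one expected next.
                while self._heap and self._heap[0][0] == self._next:
                    _, payload = heapq.heappop(self._heap)
                    self._emit(payload)
                    self._next += 1
            time.sleep(0.01)  # avoid busy-waiting, as in the diff

    def stop(self):
        self._running = False
        self._thread.join()


out = []
emitter = OrderedEmitter(out.append)
emitter.push(1, 'second sentence')   # arrives first...
emitter.push(0, 'first sentence')    # ...but index 0 is emitted before it
time.sleep(0.1)
emitter.stop()
print(out)  # ['first sentence', 'second sentence']
```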
@@ -16,13 +16,19 @@ class TTSBase(NLPCallback):
         self._message_queue = AsyncTaskQueue('TTSBaseQueue', 5)
         self._is_running = True
         EventBus().register('stop', self.on_stop)
+        EventBus().register('clear_cache', self.on_clear_cache)

     def __del__(self):
         EventBus().unregister('stop', self.on_stop)
+        EventBus().unregister('clear_cache', self.on_clear_cache)

     def on_stop(self, *args, **kwargs):
         self.stop()

+    def on_clear_cache(self, *args, **kwargs):
+        logger.info('TTSBase clear_cache')
+        self._message_queue.clear()
+
     @property
     def handle(self):
         return self._handle

@@ -31,32 +37,39 @@ class TTSBase(NLPCallback):
     def handle(self, value):
         self._handle = value

-    async def _request(self, txt: str, index):
-        # print('_request:', txt)
-        t = time.time()
-        stream = await self._on_request(txt)
-        if stream is None:
-            logger.warn(f'-------stream is None')
+    def _request(self, txt: str, index):
+        if not self._is_running:
+            logger.info('TTSBase::_request is not running')
+            return
+
+        t = time.time()
+        stream = self._on_request(txt)
+        logger.info(f'-------tts request time:{time.time() - t:.4f}s, txt:{txt}')
+        if stream is None or self._is_running is False:
+            logger.warning(f'-------stream is None or is_running {self._is_running}')
             return
-        logger.info(f'-------tts time:{time.time() - t:.4f}s, txt:{txt}')
         if self._handle is not None and self._is_running:
-            await self._on_handle(stream, index)
+            self._on_handle((stream, txt), index)
         else:
             logger.info(f'handle is None, running:{self._is_running}')
+        logger.info(f'-------tts finish time:{time.time() - t:.4f}s, txt:{txt}')

-    async def _on_request(self, text: str):
+    def _on_request(self, text: str):
         pass

-    async def _on_handle(self, stream, index):
+    def _on_handle(self, stream, index):
         pass

-    async def _on_close(self):
+    def _on_close(self):
         pass

     def on_message(self, txt: str):
         self.message(txt)

     def message(self, txt):
+        if not self._is_running:
+            logger.info('TTSBase::message is not running')
+            return
         txt = txt.strip()
         if len(txt) == 0:
             # logger.info(f'message is empty')

@@ -66,6 +79,7 @@ class TTSBase(NLPCallback):
         if self._handle is not None:
             index = self._handle.get_index()
             # print(f'message txt-index:{txt}, index {index}')
+            logger.info(f'TTSBase::message request:{txt}, index:{index}')
             self._message_queue.add_task(self._request, txt, index)

     def stop(self):
@@ -1,5 +1,6 @@
 #encoding = utf8
 import logging
+import time
 from io import BytesIO

 import aiohttp

@@ -21,13 +22,14 @@ class TTSEdgeHttp(TTSBase):
         # self._url = 'http://localhost:8082/v1/audio/speech'
         self._url = 'https://tts.mzzsfy.eu.org/v1/audio/speech'
         logger.info(f"TTSEdge init, {voice}")
+        self._response_list = []

-    async def _on_async_request(self, data):
-        async with aiohttp.ClientSession() as session:
-            async with session.post(self._url, json=data) as response:
+    def _on_async_request(self, data):
+        with aiohttp.ClientSession() as session:
+            with session.post(self._url, json=data) as response:
                 print('TTSEdgeHttp, _on_request, response:', response)
                 if response.status == 200:
-                    stream = BytesIO(await response.read())
+                    stream = BytesIO(response.read())
                     return stream
                 else:
                     byte_stream = None

@@ -35,13 +37,14 @@ class TTSEdgeHttp(TTSBase):

     def _on_sync_request(self, data):
         response = requests.post(self._url, json=data)
+        self._response_list.append(response)
+        stream = None
         if response.status_code == 200:
             stream = BytesIO(response.content)
-            return stream
-        else:
-            return None
+        self._response_list.remove(response)
+        return stream

-    async def _on_request(self, txt: str):
+    def _on_request(self, txt: str):
         logger.info(f'TTSEdgeHttp, _on_request, txt:{txt}')
         data = {
             "model": "tts-1",

@@ -54,23 +57,25 @@ class TTSEdgeHttp(TTSBase):
         # return self._on_async_request(data)
         return self._on_sync_request(data)

-    async def _on_handle(self, stream, index):
-        print('-------tts _on_handle')
+    def _on_handle(self, stream, index):
+        st, txt = stream
         try:
-            stream.seek(0)
-            byte_stream = self.__create_bytes_stream(stream)
-            print('-------tts start push chunk', index)
-            self._handle.on_handle(byte_stream, index)
-            stream.seek(0)
-            stream.truncate()
-            print('-------tts finish push chunk')
+            st.seek(0)
+            t = time.time()
+            byte_stream = self.__create_bytes_stream(st)
+            logger.info(f'-------tts resample time:{time.time() - t:.4f}s, txt:{txt}')
+            t = time.time()
+            self._handle.on_handle((byte_stream, txt), index)
+            logger.info(f'-------tts handle time:{time.time() - t:.4f}s')
+            st.seek(0)
+            st.truncate()

         except Exception as e:
             self._handle.on_handle(None, index)
-            stream.seek(0)
-            stream.truncate()
-            print('-------tts finish error:', e)
-            stream.close()
+            st.seek(0)
+            st.truncate()
+            logger.error(f'-------tts finish error:{e}')
+            st.close()

     def __create_bytes_stream(self, byte_stream):
         stream, sample_rate = sf.read(byte_stream)  # [T*sample_rate,] float64

@@ -87,7 +92,13 @@ class TTSEdgeHttp(TTSBase):

         return stream

-    async def _on_close(self):
+    def _on_close(self):
         print('TTSEdge close')
         # if self._byte_stream is not None and not self._byte_stream.closed:
         #     self._byte_stream.close()

+    def on_clear_cache(self, *args, **kwargs):
+        logger.info('TTSEdgeHttp clear_cache')
+        super().on_clear_cache(*args, **kwargs)
+        for response in self._response_list:
+            response.close()
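Note: TTSEdgeHttp now keeps each in-flight requests.Response in self._response_list so that on_clear_cache can close it and abandon a stale synthesis request. The sketch below shows the same track-then-close idea; the URL and payload are placeholders, and stream=True is used here so that closing can actually interrupt a download still in progress (the code in the diff posts without it).

```python
# Sketch of tracking in-flight HTTP responses so a cache-clear can abort them,
# mirroring TTSEdgeHttp._on_sync_request / on_clear_cache. URL and payload are placeholders.
import requests


class CancellableFetcher:
    def __init__(self, url):
        self._url = url
        self._in_flight = []

    def fetch(self, payload):
        response = requests.post(self._url, json=payload, stream=True)
        self._in_flight.append(response)
        body = None
        if response.status_code == 200:
            body = response.content       # blocks until the body is fully read
        self._in_flight.remove(response)
        return body

    def cancel_all(self):
        # Called from another thread (e.g. a 'clear_cache' handler);
        # closing the response interrupts a read that is still blocked on it.
        for response in list(self._in_flight):
            response.close()
```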
@@ -1,6 +1,5 @@
 #encoding = utf8

-import asyncio
 import logging
 from queue import Queue
 import threading

@@ -14,24 +13,14 @@ class AsyncTaskQueue:
         self._worker_num = work_num
         self._current_worker_num = work_num
         self._name = name
-        self._thread = threading.Thread(target=self._run_loop, name=name)
-        self._thread.start()
-        self.__loop = None
+        self._threads = []
+        self._lock = threading.Lock()
+        for _ in range(work_num):
+            thread = threading.Thread(target=self._worker, name=f'{name}_worker_{_}')
+            thread.start()
+            self._threads.append(thread)

-    def _run_loop(self):
-        logging.info(f'{self._name}, _run_loop')
-        self.__loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(self.__loop)
-        self._tasks = [self.__loop.create_task(self._worker()) for _ in range(self._worker_num)]
-        try:
-            self.__loop.run_forever()
-        finally:
-            logging.info(f'{self._name}, exit run')
-            self.__loop.run_until_complete(asyncio.gather(*asyncio.all_tasks(self.__loop)))
-            self.__loop.close()
-            logging.info(f'{self._name}, close loop')
-
-    async def _worker(self):
+    def _worker(self):
         logging.info(f'{self._name}, _worker')
         while True:
             try:

@@ -43,16 +32,17 @@ class AsyncTaskQueue:
                 if func is None:  # None as a stop signal
                     break

-                await func(*args)  # Execute async function
+                func(*args)  # Execute function
             except Exception as e:
-                logging.error(f'{self._name} error: {e}')
+                logging.error(f'{self._name} error: {repr(e)}')
             finally:
                 self._queue.task_done()

         logging.info(f'{self._name}, _worker finish')
-        self._current_worker_num -= 1
-        if self._current_worker_num == 0:
-            self.__loop.call_soon_threadsafe(self.__loop.stop)
+        with self._lock:
+            self._current_worker_num -= 1
+            if self._current_worker_num == 0:
+                self._queue.put(None)  # Send stop signal to remaining workers

     def add_task(self, func, *args):
         self._queue.put((func, *args))

@@ -62,10 +52,10 @@ class AsyncTaskQueue:
         self.add_task(None)  # Send stop signal

     def clear(self):
-        while not self._queue.empty():
-            self._queue.get_nowait()
-            self._queue.task_done()
+        with self._queue.mutex:
+            self._queue.queue.clear()

     def stop(self):
         self.stop_workers()
-        self._thread.join()
+        for thread in self._threads:
+            thread.join()
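Note: the last commits strip asyncio out of AsyncTaskQueue entirely: tasks are now plain callables pulled from a shared Queue by worker threads, and shutdown is coordinated with a lock-protected worker counter plus a None sentinel on the queue. A compact, self-contained sketch of one way to build that kind of thread pool with sentinel-based shutdown (toy names, not the project's class):

```python
# Sketch of a thread pool with a sentinel-based shutdown relay,
# the style the rewritten AsyncTaskQueue moves to. Standalone toy, not project code.
import threading
from queue import Queue


class TaskPool:
    def __init__(self, workers=2):
        self._queue = Queue()
        self._lock = threading.Lock()
        self._alive = workers
        self._threads = [threading.Thread(target=self._worker) for _ in range(workers)]
        for t in self._threads:
            t.start()

    def _worker(self):
        while True:
            item = self._queue.get()
            try:
                if item is None:          # stop signal
                    break
                func, args = item
                func(*args)
            finally:
                self._queue.task_done()
        with self._lock:
            self._alive -= 1
            if self._alive > 0:
                self._queue.put(None)     # relay the stop signal to the next worker

    def add_task(self, func, *args):
        self._queue.put((func, args))

    def stop(self):
        self._queue.put(None)             # one sentinel is enough; workers relay it
        for t in self._threads:
            t.join()


pool = TaskPool()
pool.add_task(print, 'hello from a worker thread')
pool.stop()
```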