modify async task

This commit is contained in:
jiegeaiai 2024-10-14 18:20:55 +08:00
parent 9cbeb58f69
commit 5d98af51de
9 changed files with 159 additions and 83 deletions

View File

@ -10,7 +10,6 @@ logger = logging.getLogger(__name__)
class NLPBase(AsrObserver): class NLPBase(AsrObserver):
def __init__(self, split, callback=None): def __init__(self, split, callback=None):
self._ask_queue = AsyncTaskQueue() self._ask_queue = AsyncTaskQueue()
self._ask_queue.start_worker()
self._split_handle = split self._split_handle = split
self._callback = callback self._callback = callback
@ -29,6 +28,9 @@ class NLPBase(AsrObserver):
async def _request(self, question): async def _request(self, question):
pass pass
async def _on_close(self):
pass
def process(self, message: str): def process(self, message: str):
pass pass
@ -38,8 +40,9 @@ class NLPBase(AsrObserver):
def ask(self, question): def ask(self, question):
logger.info(f'ask:{question}') logger.info(f'ask:{question}')
self._ask_queue.add_task(self._request(question)) self._ask_queue.add_task(self._request, question)
def stop(self): def stop(self):
self._ask_queue.add_task(self._on_close)
self._ask_queue.stop() self._ask_queue.stop()

View File

@ -31,29 +31,33 @@ class DouBao(NLPBase):
t = time.time() t = time.time()
logger.info(f'_request:{question}') logger.info(f'_request:{question}')
print(f'-------dou_bao ask:', question) print(f'-------dou_bao ask:', question)
stream = await self.__client.chat.completions.create( try:
model="ep-20241008152048-fsgzf", stream = await self.__client.chat.completions.create(
messages=[ model="ep-20241008152048-fsgzf",
{"role": "system", "content": "你是测试客服,是由字节跳动开发的 AI 人工智能助手"}, messages=[
{"role": "user", "content": question}, {"role": "system", "content": "你是测试客服,是由字节跳动开发的 AI 人工智能助手"},
], {"role": "user", "content": question},
stream=True ],
) stream=True
sec = '' )
async for completion in stream: sec = ''
# print(f'-------dou_bao nlp time:{time.time() - t:.4f}s') async for completion in stream:
# nlp_queue.put(completion.choices[0].delta.content) sec = sec + completion.choices[0].delta.content
# print(completion.choices[0].delta.content, end="") sec, message = self._split_handle.handle(sec)
sec = sec + completion.choices[0].delta.content if len(message) > 0:
sec, message = self._split_handle.handle(sec) self._on_callback(message)
if len(message) > 0: self._on_callback(sec)
print(message) await stream.close()
self._on_callback(message) except Exception as e:
print(sec) print(e)
self._on_callback(sec)
logger.info(f'_request:{question}, time:{time.time() - t:.4f}s') logger.info(f'_request:{question}, time:{time.time() - t:.4f}s')
print(f'-------dou_bao nlp time:{time.time() - t:.4f}s') print(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
async def _on_close(self):
print('AsyncArk close')
if self.__client is not None and not self.__client.is_closed():
await self.__client.close()
''' '''
if __name__ == "__main__": if __name__ == "__main__":
# print(get_dou_bao_api()) # print(get_dou_bao_api())

View File

@ -12,9 +12,8 @@ def main():
nlp.ask('你好') nlp.ask('你好')
nlp.ask('你是谁') nlp.ask('你是谁')
nlp.ask('能做什么') nlp.ask('能做什么')
time.sleep(20)
print("Stop! ")
nlp.stop() nlp.stop()
print("stop")
if __name__ == "__main__": if __name__ == "__main__":
@ -22,3 +21,5 @@ if __name__ == "__main__":
main() main()
except KeyboardInterrupt: except KeyboardInterrupt:
print("\nCaught Ctrl + C. Exiting") print("\nCaught Ctrl + C. Exiting")
except Exception as e:
print(e)

44
test/test_nlp_tts.py Normal file
View File

@ -0,0 +1,44 @@
#encoding = utf8
import sys
import time
from asr import SherpaNcnnAsr
from nlp import PunctuationSplit
from nlp.nlp_doubao import DouBao
from tts import TTSEdge, TTSAudioSaveHandle
try:
import sounddevice as sd
except ImportError as e:
print("Please install sounddevice first. You can use")
print()
print(" pip install sounddevice")
print()
print("to install it")
sys.exit(-1)
def main():
    """Send two questions through the DouBao -> TTSEdge -> wav-file pipeline.

    A ``TTSAudioSaveHandle`` is given to ``TTSEdge``, which in turn is passed
    to ``DouBao`` as its callback, so answers produced by the NLP stage are
    forwarded to the TTS stage.
    """
    print("Started! Please speak")
    handle = TTSAudioSaveHandle()
    tts = TTSEdge(handle)
    split = PunctuationSplit()
    nlp = DouBao(split, tts)
    try:
        nlp.ask('你好,你是谁?')
        nlp.ask('可以帮我做什么?')
    finally:
        # Both stages own a background worker thread; always stop them so a
        # failure while queueing questions cannot leave the process hanging.
        # The NLP stage is stopped first because it feeds the TTS stage.
        nlp.stop()
        tts.stop()
    print("Stop! ")
if __name__ == "__main__":
    # Show the audio devices known to sounddevice and the default input
    # device before running the pipeline.
    devices = sd.query_devices()
    print(devices)
    default_input_device_idx = sd.default.device[0]
    print(f'Use default device: {devices[default_input_device_idx]["name"]}')

    # Ctrl + C is reported with a short message instead of a traceback.
    try:
        main()
    except KeyboardInterrupt:
        print("\nCaught Ctrl + C. Exiting")

View File

@ -2,20 +2,17 @@
import time import time
from tts import TTSEdge from tts import TTSEdge, TTSAudioSaveHandle
def main(): def main():
print("Started! Please speak") handle = TTSAudioSaveHandle()
tts = TTSEdge(handle)
tts = TTSEdge()
tts.message('你好,') tts.message('你好,')
tts.message('请问有什么可以帮到您,') tts.message('请问有什么可以帮到您,')
tts.message('很高兴为您服务。') tts.message('很高兴为您服务。')
tts.message('祝您平安,') tts.message('祝您平安,')
tts.message('再见') tts.message('再见')
time.sleep(20)
tts.stop() tts.stop()
print("Stop! ") print("Stop! ")

View File

@ -9,6 +9,7 @@ from audio import save_wav
class TTSAudioHandle(ABC): class TTSAudioHandle(ABC):
def __init__(self): def __init__(self):
self._sample_rate = 16000 self._sample_rate = 16000
self._index = 1
@property @property
def sample_rate(self): def sample_rate(self):
@ -19,9 +20,13 @@ class TTSAudioHandle(ABC):
self._sample_rate = value self._sample_rate = value
@abstractmethod @abstractmethod
def on_handle(self, stream): def on_handle(self, stream, index):
pass pass
def get_index(self):
    """Return the current sequence number and advance the counter.

    The counter is seeded with 1 in ``__init__``; the value is handed out
    first and incremented afterwards, so the first caller receives 1
    (previously the increment happened first and numbering started at 2).

    Returns:
        int: the sequence number reserved for the caller.
    """
    index = self._index
    self._index = index + 1
    return index
class TTSAudioSplitHandle(TTSAudioHandle): class TTSAudioSplitHandle(TTSAudioHandle):
def __init__(self, human): def __init__(self, human):
@ -30,7 +35,7 @@ class TTSAudioSplitHandle(TTSAudioHandle):
self.sample_rate = self._human.get_audio_sample_rate() self.sample_rate = self._human.get_audio_sample_rate()
self._chunk = self.sample_rate // self._human.get_fps() self._chunk = self.sample_rate // self._human.get_fps()
def on_handle(self, stream): def on_handle(self, stream, index):
stream_len = stream.shape[0] stream_len = stream.shape[0]
idx = 0 idx = 0
@ -43,7 +48,6 @@ class TTSAudioSplitHandle(TTSAudioHandle):
class TTSAudioSaveHandle(TTSAudioHandle): class TTSAudioSaveHandle(TTSAudioHandle):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self._count = 1
self._save_path_dir = '../temp/audio/' self._save_path_dir = '../temp/audio/'
self._clean() self._clean()
@ -65,8 +69,7 @@ class TTSAudioSaveHandle(TTSAudioHandle):
shutil.rmtree(file_path) shutil.rmtree(file_path)
print(f"Deleted directory and its contents: {file_path}") print(f"Deleted directory and its contents: {file_path}")
def on_handle(self, stream): def on_handle(self, stream, index):
file_name = self._save_path_dir + str(self._count) + '.wav' file_name = self._save_path_dir + str(index) + '.wav'
save_wav(stream, file_name, self.sample_rate) save_wav(stream, file_name, self.sample_rate)
self._count = self._count + 1

View File

@ -12,8 +12,7 @@ logger = logging.getLogger(__name__)
class TTSBase(NLPCallback): class TTSBase(NLPCallback):
def __init__(self, handle): def __init__(self, handle):
self._handle = handle self._handle = handle
self._message_queue = AsyncTaskQueue() self._message_queue = AsyncTaskQueue(5)
self._message_queue.start_worker()
@property @property
def handle(self): def handle(self):
@ -23,17 +22,23 @@ class TTSBase(NLPCallback):
def handle(self, value): def handle(self, value):
self._handle = value self._handle = value
async def _request(self, txt: str): async def _request(self, txt: str, index):
print('_request:', txt) print('_request:', txt)
t = time.time() t = time.time()
await self._on_request(txt) stream = await self._on_request(txt)
if stream is None:
print(f'-------stream is None')
return
print(f'-------tts time:{time.time() - t:.4f}s') print(f'-------tts time:{time.time() - t:.4f}s')
await self._on_handle() await self._on_handle(stream, index)
async def _on_request(self, text: str): async def _on_request(self, text: str):
pass pass
async def _on_handle(self): async def _on_handle(self, stream, index):
pass
async def _on_close(self):
pass pass
def on_message(self, txt: str): def on_message(self, txt: str):
@ -42,7 +47,11 @@ class TTSBase(NLPCallback):
def message(self, txt): def message(self, txt):
logger.info(f'message:{txt}') logger.info(f'message:{txt}')
print(f'message:{txt}') print(f'message:{txt}')
self._message_queue.add_task(self._request(txt)) index = 0
if self._handle is not None:
index = self._handle.get_index()
self._message_queue.add_task(self._request, txt, index)
def stop(self): def stop(self):
self._message_queue.add_task(self._on_close)
self._message_queue.stop() self._message_queue.stop()

View File

@ -14,34 +14,38 @@ class TTSEdge(TTSBase):
def __init__(self, handle, voice='zh-CN-XiaoyiNeural'): def __init__(self, handle, voice='zh-CN-XiaoyiNeural'):
super().__init__(handle) super().__init__(handle)
self._voice = voice self._voice = voice
self._byte_stream = BytesIO()
async def _on_request(self, txt: str):
    """Synthesize ``txt`` with edge-tts and collect the audio in memory.

    Args:
        txt: text to synthesize with the voice configured on this instance.

    Returns:
        BytesIO: a fresh buffer holding every "audio" chunk in arrival
        order. The position is left at the end of the written data, so the
        consumer must rewind before reading.
    """
    # The text is printed as a value; it used to be part of the literal.
    print('_on_request', txt)
    communicate = edge_tts.Communicate(txt, self._voice)
    # A per-request buffer keeps concurrent requests from sharing audio data.
    byte_stream = BytesIO()
    async for chunk in communicate.stream():
        # Only audio payloads are kept; other chunk types such as
        # "WordBoundary" are ignored.
        if chunk["type"] == "audio":
            byte_stream.write(chunk["data"])
    # The generator is exhausted by the loop above, so there is nothing left
    # to close; creating a second stream() just to close it had no effect.
    return byte_stream
async def _on_handle(self, stream, index):
    """Convert the synthesized audio and pass it to the audio handle.

    Args:
        stream: buffer returned by ``_on_request``; it is rewound, read and
            always closed by this method.
        index: sequence number forwarded unchanged to the handle.

    Errors raised while converting or handing over the audio are printed and
    swallowed so that one bad chunk does not stop the caller.
    """
    try:
        stream.seek(0)
        byte_stream = self.__create_bytes_stream(stream)
        print('-------tts start push chunk')
        self._handle.on_handle(byte_stream, index)
        print('-------tts finish push chunk')
    except Exception as e:
        print('-------tts finish error:', e)
    finally:
        # Closing releases the buffer, so truncating it first is unnecessary.
        stream.close()
def __create_bytes_stream(self, byte_stream): def __create_bytes_stream(self, byte_stream):
stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64 stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64
@ -57,3 +61,8 @@ class TTSEdge(TTSBase):
stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self._handle.sample_rate) stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self._handle.sample_rate)
return stream return stream
async def _on_close(self):
print('TTSEdge close')
# if self._byte_stream is not None and not self._byte_stream.closed:
# self._byte_stream.close()

View File

@ -5,48 +5,54 @@ import threading
class AsyncTaskQueue:
    """Run queued coroutine functions on a private event loop thread.

    ``work_num`` worker coroutines take ``(func, *args)`` items from an
    ``asyncio.Queue`` and await ``func(*args)``. Items whose ``func`` is
    ``None`` are end signals: each one makes a single worker exit, and the
    loop is stopped when the last worker has exited.

    The queue and the loop are only touched from the loop thread; other
    threads hand items over with ``call_soon_threadsafe``. Workers therefore
    never block the loop while waiting for work, so tasks awaited by other
    workers keep making progress.
    """

    def __init__(self, work_num=1):
        """Create the queue and start the background loop thread.

        Args:
            work_num: number of worker coroutines consuming the queue.
        """
        self._worker_num = work_num
        self._current_worker_num = work_num
        self._queue = None
        self._tasks = []
        self.__loop = None
        # Set by the loop thread once the loop, queue and workers exist.
        self._ready = threading.Event()
        # Every attribute is assigned before the thread starts so the thread
        # can never have its own assignments overwritten by this method.
        self._thread = threading.Thread(target=self._run_loop)
        self._thread.start()

    def _run_loop(self):
        """Thread target: build the loop, start the workers, run until stopped."""
        print('_run_loop')
        self.__loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.__loop)
        # Created here so the queue belongs to this thread's loop.
        self._queue = asyncio.Queue()
        self._tasks = [self.__loop.create_task(self._worker()) for _ in range(self._worker_num)]
        self._ready.set()
        try:
            self.__loop.run_forever()
        finally:
            print("exit run")
            if not self.__loop.is_closed():
                self.__loop.close()

    async def _worker(self):
        """Consume queue items until an end signal arrives."""
        print('_worker')
        while True:
            func, *args = await self._queue.get()  # unpack the task
            try:
                if func is None:  # None is the end signal
                    break
                print(f"Executing task with args: {args}")
                await func(*args)  # run the async function
            except Exception as e:
                # A failing task must not kill the worker; otherwise the
                # worker count never reaches zero and stop() would hang.
                print(f'task failed: {e!r}')
            finally:
                self._queue.task_done()

        print('_worker finish')
        self._current_worker_num = self._current_worker_num - 1
        if self._current_worker_num == 0:
            print('loop stop')
            self.__loop.stop()

    def add_task(self, func, *args):
        """Queue ``func(*args)`` for execution; callable from any thread.

        Args:
            func: coroutine function to await, or ``None`` as an end signal.
            *args: positional arguments passed to ``func``.
        """
        # Wait until the loop thread has created the loop and the queue.
        self._ready.wait()
        try:
            self.__loop.call_soon_threadsafe(self._queue.put_nowait, (func, *args))
        except RuntimeError:
            # The loop is already closed (the queue was stopped); drop the item.
            print('add_task ignored: loop is closed')

    def stop_workers(self):
        """Queue one end signal per worker."""
        for _ in range(self._worker_num):
            self.add_task(None)  # send the end signal

    def stop(self):
        """Let the workers drain the queue, then wait for the thread to exit."""
        self.stop_workers()
        self._thread.join()