modify async task
This commit is contained in:
parent
9cbeb58f69
commit
5d98af51de
@ -10,7 +10,6 @@ logger = logging.getLogger(__name__)
|
|||||||
class NLPBase(AsrObserver):
|
class NLPBase(AsrObserver):
|
||||||
def __init__(self, split, callback=None):
|
def __init__(self, split, callback=None):
|
||||||
self._ask_queue = AsyncTaskQueue()
|
self._ask_queue = AsyncTaskQueue()
|
||||||
self._ask_queue.start_worker()
|
|
||||||
self._split_handle = split
|
self._split_handle = split
|
||||||
self._callback = callback
|
self._callback = callback
|
||||||
|
|
||||||
@ -29,6 +28,9 @@ class NLPBase(AsrObserver):
|
|||||||
async def _request(self, question):
|
async def _request(self, question):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
async def _on_close(self):
|
||||||
|
pass
|
||||||
|
|
||||||
def process(self, message: str):
|
def process(self, message: str):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -38,8 +40,9 @@ class NLPBase(AsrObserver):
|
|||||||
|
|
||||||
def ask(self, question):
|
def ask(self, question):
|
||||||
logger.info(f'ask:{question}')
|
logger.info(f'ask:{question}')
|
||||||
self._ask_queue.add_task(self._request(question))
|
self._ask_queue.add_task(self._request, question)
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
|
self._ask_queue.add_task(self._on_close)
|
||||||
self._ask_queue.stop()
|
self._ask_queue.stop()
|
||||||
|
|
@ -31,6 +31,7 @@ class DouBao(NLPBase):
|
|||||||
t = time.time()
|
t = time.time()
|
||||||
logger.info(f'_request:{question}')
|
logger.info(f'_request:{question}')
|
||||||
print(f'-------dou_bao ask:', question)
|
print(f'-------dou_bao ask:', question)
|
||||||
|
try:
|
||||||
stream = await self.__client.chat.completions.create(
|
stream = await self.__client.chat.completions.create(
|
||||||
model="ep-20241008152048-fsgzf",
|
model="ep-20241008152048-fsgzf",
|
||||||
messages=[
|
messages=[
|
||||||
@ -41,19 +42,22 @@ class DouBao(NLPBase):
|
|||||||
)
|
)
|
||||||
sec = ''
|
sec = ''
|
||||||
async for completion in stream:
|
async for completion in stream:
|
||||||
# print(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
|
|
||||||
# nlp_queue.put(completion.choices[0].delta.content)
|
|
||||||
# print(completion.choices[0].delta.content, end="")
|
|
||||||
sec = sec + completion.choices[0].delta.content
|
sec = sec + completion.choices[0].delta.content
|
||||||
sec, message = self._split_handle.handle(sec)
|
sec, message = self._split_handle.handle(sec)
|
||||||
if len(message) > 0:
|
if len(message) > 0:
|
||||||
print(message)
|
|
||||||
self._on_callback(message)
|
self._on_callback(message)
|
||||||
print(sec)
|
|
||||||
self._on_callback(sec)
|
self._on_callback(sec)
|
||||||
|
await stream.close()
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
logger.info(f'_request:{question}, time:{time.time() - t:.4f}s')
|
logger.info(f'_request:{question}, time:{time.time() - t:.4f}s')
|
||||||
print(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
|
print(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
|
||||||
|
|
||||||
|
async def _on_close(self):
|
||||||
|
print('AsyncArk close')
|
||||||
|
if self.__client is not None and not self.__client.is_closed():
|
||||||
|
await self.__client.close()
|
||||||
|
|
||||||
'''
|
'''
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# print(get_dou_bao_api())
|
# print(get_dou_bao_api())
|
||||||
|
@ -12,9 +12,8 @@ def main():
|
|||||||
nlp.ask('你好')
|
nlp.ask('你好')
|
||||||
nlp.ask('你是谁')
|
nlp.ask('你是谁')
|
||||||
nlp.ask('能做什么')
|
nlp.ask('能做什么')
|
||||||
time.sleep(20)
|
|
||||||
print("Stop! ")
|
|
||||||
nlp.stop()
|
nlp.stop()
|
||||||
|
print("stop")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
@ -22,3 +21,5 @@ if __name__ == "__main__":
|
|||||||
main()
|
main()
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
print("\nCaught Ctrl + C. Exiting")
|
print("\nCaught Ctrl + C. Exiting")
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
44
test/test_nlp_tts.py
Normal file
44
test/test_nlp_tts.py
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
#encoding = utf8
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
|
||||||
|
from asr import SherpaNcnnAsr
|
||||||
|
from nlp import PunctuationSplit
|
||||||
|
from nlp.nlp_doubao import DouBao
|
||||||
|
from tts import TTSEdge, TTSAudioSaveHandle
|
||||||
|
|
||||||
|
try:
|
||||||
|
import sounddevice as sd
|
||||||
|
except ImportError as e:
|
||||||
|
print("Please install sounddevice first. You can use")
|
||||||
|
print()
|
||||||
|
print(" pip install sounddevice")
|
||||||
|
print()
|
||||||
|
print("to install it")
|
||||||
|
sys.exit(-1)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
print("Started! Please speak")
|
||||||
|
handle = TTSAudioSaveHandle()
|
||||||
|
tts = TTSEdge(handle)
|
||||||
|
split = PunctuationSplit()
|
||||||
|
nlp = DouBao(split, tts)
|
||||||
|
nlp.ask('你好,你是谁?')
|
||||||
|
nlp.ask('可以帮我做什么?')
|
||||||
|
nlp.stop()
|
||||||
|
tts.stop()
|
||||||
|
print("Stop! ")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
devices = sd.query_devices()
|
||||||
|
print(devices)
|
||||||
|
default_input_device_idx = sd.default.device[0]
|
||||||
|
print(f'Use default device: {devices[default_input_device_idx]["name"]}')
|
||||||
|
|
||||||
|
try:
|
||||||
|
main()
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\nCaught Ctrl + C. Exiting")
|
@ -2,20 +2,17 @@
|
|||||||
|
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from tts import TTSEdge
|
from tts import TTSEdge, TTSAudioSaveHandle
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
print("Started! Please speak")
|
handle = TTSAudioSaveHandle()
|
||||||
|
tts = TTSEdge(handle)
|
||||||
tts = TTSEdge()
|
|
||||||
tts.message('你好,')
|
tts.message('你好,')
|
||||||
tts.message('请问有什么可以帮到您,')
|
tts.message('请问有什么可以帮到您,')
|
||||||
tts.message('很高兴为您服务。')
|
tts.message('很高兴为您服务。')
|
||||||
tts.message('祝您平安,')
|
tts.message('祝您平安,')
|
||||||
tts.message('再见')
|
tts.message('再见')
|
||||||
|
|
||||||
time.sleep(20)
|
|
||||||
tts.stop()
|
tts.stop()
|
||||||
|
|
||||||
print("Stop! ")
|
print("Stop! ")
|
||||||
|
@ -9,6 +9,7 @@ from audio import save_wav
|
|||||||
class TTSAudioHandle(ABC):
|
class TTSAudioHandle(ABC):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._sample_rate = 16000
|
self._sample_rate = 16000
|
||||||
|
self._index = 1
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def sample_rate(self):
|
def sample_rate(self):
|
||||||
@ -19,9 +20,13 @@ class TTSAudioHandle(ABC):
|
|||||||
self._sample_rate = value
|
self._sample_rate = value
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def on_handle(self, stream):
|
def on_handle(self, stream, index):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def get_index(self):
|
||||||
|
self._index = self._index + 1
|
||||||
|
return self._index
|
||||||
|
|
||||||
|
|
||||||
class TTSAudioSplitHandle(TTSAudioHandle):
|
class TTSAudioSplitHandle(TTSAudioHandle):
|
||||||
def __init__(self, human):
|
def __init__(self, human):
|
||||||
@ -30,7 +35,7 @@ class TTSAudioSplitHandle(TTSAudioHandle):
|
|||||||
self.sample_rate = self._human.get_audio_sample_rate()
|
self.sample_rate = self._human.get_audio_sample_rate()
|
||||||
self._chunk = self.sample_rate // self._human.get_fps()
|
self._chunk = self.sample_rate // self._human.get_fps()
|
||||||
|
|
||||||
def on_handle(self, stream):
|
def on_handle(self, stream, index):
|
||||||
stream_len = stream.shape[0]
|
stream_len = stream.shape[0]
|
||||||
idx = 0
|
idx = 0
|
||||||
|
|
||||||
@ -43,7 +48,6 @@ class TTSAudioSplitHandle(TTSAudioHandle):
|
|||||||
class TTSAudioSaveHandle(TTSAudioHandle):
|
class TTSAudioSaveHandle(TTSAudioHandle):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._count = 1
|
|
||||||
self._save_path_dir = '../temp/audio/'
|
self._save_path_dir = '../temp/audio/'
|
||||||
self._clean()
|
self._clean()
|
||||||
|
|
||||||
@ -65,8 +69,7 @@ class TTSAudioSaveHandle(TTSAudioHandle):
|
|||||||
shutil.rmtree(file_path)
|
shutil.rmtree(file_path)
|
||||||
print(f"Deleted directory and its contents: {file_path}")
|
print(f"Deleted directory and its contents: {file_path}")
|
||||||
|
|
||||||
def on_handle(self, stream):
|
def on_handle(self, stream, index):
|
||||||
file_name = self._save_path_dir + str(self._count) + '.wav'
|
file_name = self._save_path_dir + str(index) + '.wav'
|
||||||
save_wav(stream, file_name, self.sample_rate)
|
save_wav(stream, file_name, self.sample_rate)
|
||||||
self._count = self._count + 1
|
|
||||||
|
|
||||||
|
@ -12,8 +12,7 @@ logger = logging.getLogger(__name__)
|
|||||||
class TTSBase(NLPCallback):
|
class TTSBase(NLPCallback):
|
||||||
def __init__(self, handle):
|
def __init__(self, handle):
|
||||||
self._handle = handle
|
self._handle = handle
|
||||||
self._message_queue = AsyncTaskQueue()
|
self._message_queue = AsyncTaskQueue(5)
|
||||||
self._message_queue.start_worker()
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def handle(self):
|
def handle(self):
|
||||||
@ -23,17 +22,23 @@ class TTSBase(NLPCallback):
|
|||||||
def handle(self, value):
|
def handle(self, value):
|
||||||
self._handle = value
|
self._handle = value
|
||||||
|
|
||||||
async def _request(self, txt: str):
|
async def _request(self, txt: str, index):
|
||||||
print('_request:', txt)
|
print('_request:', txt)
|
||||||
t = time.time()
|
t = time.time()
|
||||||
await self._on_request(txt)
|
stream = await self._on_request(txt)
|
||||||
|
if stream is None:
|
||||||
|
print(f'-------stream is None')
|
||||||
|
return
|
||||||
print(f'-------tts time:{time.time() - t:.4f}s')
|
print(f'-------tts time:{time.time() - t:.4f}s')
|
||||||
await self._on_handle()
|
await self._on_handle(stream, index)
|
||||||
|
|
||||||
async def _on_request(self, text: str):
|
async def _on_request(self, text: str):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def _on_handle(self):
|
async def _on_handle(self, stream, index):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def _on_close(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def on_message(self, txt: str):
|
def on_message(self, txt: str):
|
||||||
@ -42,7 +47,11 @@ class TTSBase(NLPCallback):
|
|||||||
def message(self, txt):
|
def message(self, txt):
|
||||||
logger.info(f'message:{txt}')
|
logger.info(f'message:{txt}')
|
||||||
print(f'message:{txt}')
|
print(f'message:{txt}')
|
||||||
self._message_queue.add_task(self._request(txt))
|
index = 0
|
||||||
|
if self._handle is not None:
|
||||||
|
index = self._handle.get_index()
|
||||||
|
self._message_queue.add_task(self._request, txt, index)
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
|
self._message_queue.add_task(self._on_close)
|
||||||
self._message_queue.stop()
|
self._message_queue.stop()
|
||||||
|
@ -14,34 +14,38 @@ class TTSEdge(TTSBase):
|
|||||||
def __init__(self, handle, voice='zh-CN-XiaoyiNeural'):
|
def __init__(self, handle, voice='zh-CN-XiaoyiNeural'):
|
||||||
super().__init__(handle)
|
super().__init__(handle)
|
||||||
self._voice = voice
|
self._voice = voice
|
||||||
self._byte_stream = BytesIO()
|
|
||||||
|
|
||||||
async def _on_request(self, txt: str):
|
async def _on_request(self, txt: str):
|
||||||
|
print('_on_request, txt')
|
||||||
communicate = edge_tts.Communicate(txt, self._voice)
|
communicate = edge_tts.Communicate(txt, self._voice)
|
||||||
first = True
|
first = True
|
||||||
|
byte_stream = BytesIO()
|
||||||
async for chunk in communicate.stream():
|
async for chunk in communicate.stream():
|
||||||
if first:
|
if first:
|
||||||
first = False
|
first = False
|
||||||
if chunk["type"] == "audio":
|
if chunk["type"] == "audio":
|
||||||
self._byte_stream.write(chunk["data"])
|
byte_stream.write(chunk["data"])
|
||||||
elif chunk["type"] == "WordBoundary":
|
elif chunk["type"] == "WordBoundary":
|
||||||
pass
|
pass
|
||||||
|
await communicate.stream().aclose()
|
||||||
|
return byte_stream
|
||||||
|
|
||||||
async def _on_handle(self):
|
async def _on_handle(self, stream, index):
|
||||||
self._byte_stream.seek(0)
|
stream.seek(0)
|
||||||
try:
|
try:
|
||||||
self._byte_stream.seek(0)
|
stream.seek(0)
|
||||||
stream = self.__create_bytes_stream(self._byte_stream)
|
byte_stream = self.__create_bytes_stream(stream)
|
||||||
print('-------tts start push chunk')
|
print('-------tts start push chunk')
|
||||||
self._handle.on_handle(stream)
|
self._handle.on_handle(byte_stream, index)
|
||||||
self._byte_stream.seek(0)
|
stream.seek(0)
|
||||||
self._byte_stream.truncate()
|
stream.truncate()
|
||||||
print('-------tts finish push chunk')
|
print('-------tts finish push chunk')
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._byte_stream.seek(0)
|
stream.seek(0)
|
||||||
self._byte_stream.truncate()
|
stream.truncate()
|
||||||
print('-------tts finish error:', e)
|
print('-------tts finish error:', e)
|
||||||
|
stream.close()
|
||||||
|
|
||||||
def __create_bytes_stream(self, byte_stream):
|
def __create_bytes_stream(self, byte_stream):
|
||||||
stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64
|
stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64
|
||||||
@ -57,3 +61,8 @@ class TTSEdge(TTSBase):
|
|||||||
stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self._handle.sample_rate)
|
stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self._handle.sample_rate)
|
||||||
|
|
||||||
return stream
|
return stream
|
||||||
|
|
||||||
|
async def _on_close(self):
|
||||||
|
print('TTSEdge close')
|
||||||
|
# if self._byte_stream is not None and not self._byte_stream.closed:
|
||||||
|
# self._byte_stream.close()
|
||||||
|
@ -5,48 +5,54 @@ import threading
|
|||||||
|
|
||||||
|
|
||||||
class AsyncTaskQueue:
|
class AsyncTaskQueue:
|
||||||
def __init__(self):
|
def __init__(self, work_num=1):
|
||||||
self._queue = asyncio.Queue()
|
self._queue = asyncio.Queue()
|
||||||
self._loop = asyncio.new_event_loop()
|
self._worker_num = work_num
|
||||||
|
self._current_worker_num = work_num
|
||||||
|
self._condition = threading.Condition()
|
||||||
self._thread = threading.Thread(target=self._run_loop)
|
self._thread = threading.Thread(target=self._run_loop)
|
||||||
self._worker_task = None
|
|
||||||
# self._loop_running = threading.Event()
|
|
||||||
# self._loop_running.set()
|
|
||||||
self._thread.start()
|
self._thread.start()
|
||||||
|
self.__loop = None
|
||||||
|
|
||||||
def _run_loop(self):
|
def _run_loop(self):
|
||||||
print('_run_loop')
|
print('_run_loop')
|
||||||
asyncio.set_event_loop(self._loop)
|
self.__loop = asyncio.new_event_loop()
|
||||||
self._loop.run_forever()
|
asyncio.set_event_loop(self.__loop)
|
||||||
|
self._tasks = [self.__loop.create_task(self._worker()) for _ in range(self._worker_num)]
|
||||||
|
self.__loop.run_forever()
|
||||||
|
print("exit run")
|
||||||
|
if not self.__loop.is_closed():
|
||||||
|
self.__loop.close()
|
||||||
|
|
||||||
async def _worker(self):
|
async def _worker(self):
|
||||||
print('_worker')
|
print('_worker')
|
||||||
while True:
|
while True:
|
||||||
print('_worker1')
|
with self._condition:
|
||||||
|
self._condition.wait_for(lambda: not self._queue.empty())
|
||||||
task = await self._queue.get()
|
task = await self._queue.get()
|
||||||
print('_worker2')
|
func, *args = task # 解包任务
|
||||||
if task is None:
|
if func is None: # None 作为结束信号
|
||||||
break
|
break
|
||||||
print('run task')
|
|
||||||
await task
|
print(f"Executing task with args: {args}")
|
||||||
|
await func(*args) # 执行异步函数
|
||||||
self._queue.task_done()
|
self._queue.task_done()
|
||||||
|
|
||||||
print('_worker finish')
|
print('_worker finish')
|
||||||
|
self._current_worker_num = self._current_worker_num - 1
|
||||||
|
if self._current_worker_num == 0:
|
||||||
|
print('loop stop')
|
||||||
|
self.__loop.stop()
|
||||||
|
|
||||||
def add_task(self, coro):
|
def add_task(self, func, *args):
|
||||||
print('add_task')
|
with self._condition:
|
||||||
asyncio.run_coroutine_threadsafe(self._queue.put(coro), self._loop)
|
self._queue.put_nowait((func, *args))
|
||||||
print('add_task1')
|
self._condition.notify()
|
||||||
|
|
||||||
def start_worker(self):
|
def stop_workers(self):
|
||||||
print('start_worker')
|
for _ in range(self._worker_num):
|
||||||
if not self._worker_task:
|
self.add_task(None) # 发送结束信号
|
||||||
self._worker_task = asyncio.run_coroutine_threadsafe(self._worker(), self._loop)
|
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
# self._loop_running.clear()
|
self.stop_workers()
|
||||||
asyncio.run_coroutine_threadsafe(self._queue.put(None), self._loop).result()
|
|
||||||
if self._worker_task:
|
|
||||||
self._worker_task.result()
|
|
||||||
self._loop.call_soon_threadsafe(self._loop.stop)
|
|
||||||
self._thread.join()
|
self._thread.join()
|
||||||
self._loop.close()
|
|
||||||
|
Loading…
Reference in New Issue
Block a user