From 0ed6249f155b6657686c285a31356e5064000130 Mon Sep 17 00:00:00 2001
From: brige
Date: Tue, 8 Oct 2024 20:15:04 +0800
Subject: [PATCH] add async task queue and add doubao model

---
 nlp/__init__.py           |  2 +
 nlp/nlp_base.py           | 25 ++++++++++++
 nlp/nlp_doubao.py         | 86 +++++++++++++++++++++++++++++++++++++++
 utils/__init__.py         |  3 ++
 utils/async_task_queue.py | 66 ++++++++++++++++++++++++++++++
 5 files changed, 182 insertions(+)
 create mode 100644 nlp/__init__.py
 create mode 100644 nlp/nlp_base.py
 create mode 100644 nlp/nlp_doubao.py
 create mode 100644 utils/__init__.py
 create mode 100644 utils/async_task_queue.py

diff --git a/nlp/__init__.py b/nlp/__init__.py
new file mode 100644
index 0000000..6bb50fc
--- /dev/null
+++ b/nlp/__init__.py
@@ -0,0 +1,2 @@
+#encoding = utf8
+
diff --git a/nlp/nlp_base.py b/nlp/nlp_base.py
new file mode 100644
index 0000000..e2f94aa
--- /dev/null
+++ b/nlp/nlp_base.py
@@ -0,0 +1,25 @@
+#encoding = utf8
+import logging
+
+from utils import AsyncTaskQueue
+
+logger = logging.getLogger(__name__)
+
+
+class NLPBase:
+    """Base class that runs NLP requests one at a time on a background queue."""
+
+    def __init__(self):
+        self._ask_queue = AsyncTaskQueue()
+        self._ask_queue.start_worker()
+
+    async def _request(self, question):
+        """Handle one question; subclasses override this (the default does nothing)."""
+
+    def ask(self, question):
+        """Queue `question`; its `_request` coroutine runs after earlier ones finish."""
+        self._ask_queue.add_task(self._request(question))
+
+    def stop(self):
+        """Stop the background queue and wait for its thread to exit."""
+        self._ask_queue.stop()
diff --git a/nlp/nlp_doubao.py b/nlp/nlp_doubao.py
new file mode 100644
index 0000000..53507d0
--- /dev/null
+++ b/nlp/nlp_doubao.py
@@ -0,0 +1,86 @@
+#encoding = utf8
+import os
+import re
+import time
+from queue import Queue
+
+import volcenginesdkark
+import volcenginesdkcore
+from volcenginesdkcore.rest import ApiException
+
+from nlp.nlp_base import NLPBase
+from volcenginesdkarkruntime import AsyncArk
+
+# Streamed answer fragments, in arrival order; filled by DouBao._request.
+nlp_queue = Queue()
+
+# Name of the environment variable that holds the Ark API key. Credentials
+# must be supplied at runtime and never committed to the repository.
+API_KEY_ENV = 'ARK_API_KEY'
+
+
+class DouBao(NLPBase):
+    """NLP backend that streams chat completions from an Ark endpoint."""
+
+    def __init__(self):
+        # Read the key before the base class starts its worker thread, so a
+        # missing key fails fast without leaving that thread running.
+        token = os.environ.get(API_KEY_ENV)
+        if not token:
+            raise RuntimeError(f'environment variable {API_KEY_ENV} is not set')
+        super().__init__()
+        self.__token = token
+        self.__client = AsyncArk(api_key=self.__token)
+
+    async def _request(self, question):
+        """Send `question` and push each streamed text chunk to `nlp_queue`."""
+        t = time.time()
+        print('-------dou_bao ask:', question)
+        stream = await self.__client.chat.completions.create(
+            model="ep-20241008152048-fsgzf",
+            messages=[
+                {"role": "system", "content": "你是测试客服,是由字节跳动开发的 AI 人工智能助手"},
+                {"role": "user", "content": question},
+            ],
+            stream=True
+        )
+        async for completion in stream:
+            # Guard against chunks that carry no text (no choices, or a
+            # None/empty delta) so consumers only receive non-empty strings.
+            if not completion.choices:
+                continue
+            content = completion.choices[0].delta.content
+            if content:
+                nlp_queue.put(content)
+        print(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
+
+
+if __name__ == "__main__":
+    dou_bao = DouBao()
+    try:
+        dou_bao.ask('你好。')
+        dou_bao.ask('你好,你是谁?')
+        dou_bao.ask('你能做什么?')
+        dou_bao.ask('介绍一下,我自己。')
+        # Punctuation marks at which the streamed text is cut into segments.
+        pattern = re.compile(r'[,。、;?!,.!?]')
+        count = 1000
+        sec = ''
+        while count >= 0:
+            count = count - 1
+            if nlp_queue.empty():
+                time.sleep(0.1)
+                continue
+            sec = sec + nlp_queue.get(block=True, timeout=0.01)
+
+            match = pattern.search(sec)
+            if match:
+                pos = match.start() + 1
+                print(sec[:pos])
+                sec = sec[pos:]
+        # Print whatever text is left after the last segment.
+        print(sec)
+    finally:
+        # Always stop the queue: its worker thread is non-daemon and would
+        # otherwise keep the process alive.
+        dou_bao.stop()
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000..7240263
--- /dev/null
+++ b/utils/__init__.py
@@ -0,0 +1,3 @@
+#encoding = utf8
+
+from .async_task_queue import AsyncTaskQueue
diff --git a/utils/async_task_queue.py b/utils/async_task_queue.py
new file mode 100644
index 0000000..fb7cc83
--- /dev/null
+++ b/utils/async_task_queue.py
@@ -0,0 +1,66 @@
+#encoding = utf8
+
+import asyncio
+import logging
+import threading
+
+logger = logging.getLogger(__name__)
+
+
+class AsyncTaskQueue:
+    """Run queued coroutines one at a time on an event loop in its own thread."""
+
+    def __init__(self):
+        # The queue itself is created on the loop thread (see _run_loop).
+        self._queue = None
+        self._loop = asyncio.new_event_loop()
+        self._thread = threading.Thread(target=self._run_loop)
+        self.worker_task = None
+        self.loop_running = threading.Event()
+        self._thread.start()
+        # Block until the loop thread has created the queue, so add_task()
+        # and stop() can use it right after construction.
+        self.loop_running.wait()
+
+    def _run_loop(self):
+        """Thread target: bind the loop to this thread and run it until stopped."""
+        asyncio.set_event_loop(self._loop)
+        # Create the queue here, with self._loop set as the current loop:
+        # before Python 3.10 an asyncio.Queue binds to the current event loop
+        # at construction time and cannot be used from a different loop.
+        self._queue = asyncio.Queue()
+        self.loop_running.set()  # signal that the queue and loop are ready
+        self._loop.run_forever()  # run the event loop
+
+    async def _worker(self):
+        """Await queued coroutines in order until the `None` sentinel arrives."""
+        while True:
+            task = await self._queue.get()
+            try:
+                if task is None:
+                    break
+                await task
+            except Exception:
+                # One failing task must not end the worker: later tasks would
+                # never run and stop() would re-raise from worker_task.result().
+                logger.exception('async task failed')
+            finally:
+                self._queue.task_done()
+
+    def add_task(self, coro):
+        """Queue the coroutine object `coro`; may be called from another thread."""
+        asyncio.run_coroutine_threadsafe(self._queue.put(coro), self._loop)
+
+    def start_worker(self):
+        """Start the worker on the loop; later calls are no-ops."""
+        if not self.worker_task:
+            self.worker_task = asyncio.run_coroutine_threadsafe(self._worker(), self._loop)
+
+    def stop(self):
+        """Let queued tasks finish, then stop the loop and join its thread."""
+        asyncio.run_coroutine_threadsafe(self._queue.put(None), self._loop).result()
+        if self.worker_task:
+            self.worker_task.result()
+        self._loop.call_soon_threadsafe(self._loop.stop)
+        self._thread.join()  # wait for the loop thread to finish
+        self._loop.close()  # close the event loop