add async task queue and add doubao model
This commit is contained in:
parent
aef7d3d499
commit
0ed6249f15
2
nlp/__init__.py
Normal file
2
nlp/__init__.py
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
#encoding = utf8
|
||||||
|
|
22
nlp/nlp_base.py
Normal file
22
nlp/nlp_base.py
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
#encoding = utf8
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from utils import AsyncTaskQueue
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class NLPBase:
|
||||||
|
def __init__(self):
|
||||||
|
self._ask_queue = AsyncTaskQueue()
|
||||||
|
self._ask_queue.start_worker()
|
||||||
|
|
||||||
|
async def _request(self, question):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def ask(self, question):
|
||||||
|
self._ask_queue.add_task(self._request(question))
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
self._ask_queue.stop()
|
||||||
|
|
78
nlp/nlp_doubao.py
Normal file
78
nlp/nlp_doubao.py
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
#encoding = utf8
|
||||||
|
import re
|
||||||
|
import time
|
||||||
|
from queue import Queue
|
||||||
|
|
||||||
|
import volcenginesdkark
|
||||||
|
import volcenginesdkcore
|
||||||
|
from volcenginesdkcore.rest import ApiException
|
||||||
|
|
||||||
|
from nlp.nlp_base import NLPBase
|
||||||
|
from volcenginesdkarkruntime import AsyncArk
|
||||||
|
|
||||||
|
nlp_queue = Queue()
|
||||||
|
|
||||||
|
|
||||||
|
class DouBao(NLPBase):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
# Access Key ID
|
||||||
|
# AKLTYTdmOTBmNWFjODkxNDE2Zjk3MjU0NjRhM2JhM2IyN2Y
|
||||||
|
# AKLTNDZjNTdhNDlkZGE3NDZjMDlkMzk5YWQ3MDA4MTY1ZDc
|
||||||
|
# Secret Access Key
|
||||||
|
# WmpRelltRXhNbVkyWWpnNU5HRmpNamc0WTJZMFpUWmpOV1E1TTJFME1tTQ==
|
||||||
|
# TkRJMk1tTTFZamt4TkRVNE5HRTNZMkUyTnpFeU5qQmxNMkUwWXpaak1HRQ==
|
||||||
|
# endpoint_id
|
||||||
|
# ep-20241008152048-fsgzf
|
||||||
|
# api_key
|
||||||
|
# c9635f9e-0f9e-4ca1-ac90-8af25a541b74
|
||||||
|
# api_ky
|
||||||
|
# eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJhcmstY29uc29sZSIsImV4cCI6MTczMDk2NTMxOSwiaWF0IjoxNzI4MzczMzE5LCJ0IjoidXNlciIsImt2IjoxLCJhaWQiOiIyMTAyMjc3NDc1IiwidWlkIjoiMCIsImlzX291dGVyX3VzZXIiOnRydWUsInJlc291cmNlX3R5cGUiOiJlbmRwb2ludCIsInJlc291cmNlX2lkcyI6WyJlcC0yMDI0MTAwODE1MjA0OC1mc2d6ZiJdfQ.BHgFj-UKeu7IGG5VL2e6iPQEMNMkQrgmM46zYmTpoNG_ySgSFJLWYzbrIABZmqVDB4Rt58j8kvoORs-RHJUz81rXUlh3BYl9-ZwbggtAU7Z1pm54_qZ00jF0jQ6r-fUSXZo2PVCLxb_clNuEh06NyaV7ullZwUCyLKx3vhCsxPAuEvQvLc_qDBx-IYNT-UApVADaqMs-OyewoxahqQ7RvaHFF14R6ihmg9H0uvl00_JiGThJveszKvy_T-Qk6iPOy-EDI2pwJxdHMZ7By0bWK5EfZoK2hOvOSRD0BNTYnvrTfI0l2JgS0nwCVEPR4KSTXxU_oVVtuUSZp1UHvvkhvA
|
||||||
|
self.__token = 'c9635f9e-0f9e-4ca1-ac90-8af25a541b74'
|
||||||
|
self.__client = AsyncArk(api_key=self.__token)
|
||||||
|
|
||||||
|
async def _request(self, question):
|
||||||
|
t = time.time()
|
||||||
|
print(f'-------dou_bao ask:', question)
|
||||||
|
stream = await self.__client.chat.completions.create(
|
||||||
|
model="ep-20241008152048-fsgzf",
|
||||||
|
messages=[
|
||||||
|
{"role": "system", "content": "你是测试客服,是由字节跳动开发的 AI 人工智能助手"},
|
||||||
|
{"role": "user", "content": question},
|
||||||
|
],
|
||||||
|
stream=True
|
||||||
|
)
|
||||||
|
async for completion in stream:
|
||||||
|
# print(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
|
||||||
|
nlp_queue.put(completion.choices[0].delta.content)
|
||||||
|
# print(completion.choices[0].delta.content, end="")
|
||||||
|
print(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# print(get_dou_bao_api())
|
||||||
|
dou_bao = DouBao()
|
||||||
|
dou_bao.ask('你好。')
|
||||||
|
dou_bao.ask('你好,你是谁?')
|
||||||
|
dou_bao.ask('你能做什么?')
|
||||||
|
dou_bao.ask('介绍一下,我自己。')
|
||||||
|
count = 1000
|
||||||
|
sec = ''
|
||||||
|
while count >= 0:
|
||||||
|
count = count - 1
|
||||||
|
if nlp_queue.empty():
|
||||||
|
time.sleep(0.1)
|
||||||
|
continue
|
||||||
|
sec = sec + nlp_queue.get(block=True, timeout=0.01)
|
||||||
|
|
||||||
|
pattern = r'[,。、;?!,.!?]'
|
||||||
|
match = re.search(pattern, sec)
|
||||||
|
if match:
|
||||||
|
pos = match.start() + 1
|
||||||
|
print(sec[:pos])
|
||||||
|
sec = sec[pos:]
|
||||||
|
print(sec)
|
||||||
|
|
||||||
|
|
||||||
|
dou_bao.stop()
|
||||||
|
|
3
utils/__init__.py
Normal file
3
utils/__init__.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
#encoding = utf8
|
||||||
|
|
||||||
|
from .async_task_queue import AsyncTaskQueue
|
42
utils/async_task_queue.py
Normal file
42
utils/async_task_queue.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
#encoding = utf8
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import threading
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncTaskQueue:
|
||||||
|
def __init__(self):
|
||||||
|
self._queue = asyncio.Queue()
|
||||||
|
self._loop = asyncio.new_event_loop()
|
||||||
|
self._thread = threading.Thread(target=self._run_loop)
|
||||||
|
self.worker_task = None
|
||||||
|
self.loop_running = threading.Event()
|
||||||
|
self._thread.start()
|
||||||
|
|
||||||
|
def _run_loop(self):
|
||||||
|
asyncio.set_event_loop(self._loop)
|
||||||
|
self.loop_running.set() # 设置事件,表明事件循环正在运行
|
||||||
|
self._loop.run_forever() # 启动事件循环
|
||||||
|
|
||||||
|
async def _worker(self):
|
||||||
|
while True:
|
||||||
|
task = await self._queue.get()
|
||||||
|
if task is None:
|
||||||
|
break
|
||||||
|
await task
|
||||||
|
self._queue.task_done()
|
||||||
|
|
||||||
|
def add_task(self, coro):
|
||||||
|
asyncio.run_coroutine_threadsafe(self._queue.put(coro), self._loop)
|
||||||
|
|
||||||
|
def start_worker(self):
|
||||||
|
if not self.worker_task:
|
||||||
|
self.worker_task = asyncio.run_coroutine_threadsafe(self._worker(), self._loop)
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
asyncio.run_coroutine_threadsafe(self._queue.put(None), self._loop).result()
|
||||||
|
if self.worker_task:
|
||||||
|
self.worker_task.result()
|
||||||
|
self._loop.call_soon_threadsafe(self._loop.stop)
|
||||||
|
self._thread.join() # 等待线程结束
|
||||||
|
self._loop.close() # 关闭事件循环
|
Loading…
Reference in New Issue
Block a user