modify nlp abort
This commit is contained in:
parent
d9f55d1ba1
commit
7406552289
@ -49,7 +49,6 @@ class NLPBase(AsrObserver):
|
||||
if not self._is_running:
|
||||
return
|
||||
logger.info(f'complete:{message}')
|
||||
# self._context.pause_talk()
|
||||
self.ask(message)
|
||||
|
||||
def ask(self, question):
|
||||
|
@ -1,14 +1,117 @@
|
||||
#encoding = utf8
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
|
||||
import requests
|
||||
|
||||
from nlp.nlp_base import NLPBase
|
||||
from volcenginesdkarkruntime import AsyncArk
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DouBaoSDK:
|
||||
def __init__(self, token):
|
||||
self._token = token
|
||||
self.__client = AsyncArk(api_key=token)
|
||||
self._stream = None
|
||||
|
||||
async def request(self, question, handle, callback):
|
||||
if self.__client is None:
|
||||
self.__client = AsyncArk(api_key=self._token)
|
||||
t = time.time()
|
||||
logger.info(f'-------dou_bao ask:{question}')
|
||||
try:
|
||||
self._stream = await self.__client.chat.completions.create(
|
||||
model="ep-20241008152048-fsgzf",
|
||||
messages=[
|
||||
{"role": "system", "content": "你是测试客服,是由字节跳动开发的 AI 人工智能助手"},
|
||||
{"role": "user", "content": question},
|
||||
],
|
||||
stream=True
|
||||
)
|
||||
|
||||
sec = ''
|
||||
async for completion in self._stream:
|
||||
sec = sec + completion.choices[0].delta.content
|
||||
sec, message = handle.handle(sec)
|
||||
if len(message) > 0:
|
||||
logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
|
||||
callback(message)
|
||||
callback(sec)
|
||||
logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
|
||||
await self._stream.close()
|
||||
self._stream = None
|
||||
except Exception as e:
|
||||
print(e)
|
||||
logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
|
||||
|
||||
async def close(self):
|
||||
if self._stream is not None:
|
||||
await self._stream.close()
|
||||
self._stream = None
|
||||
logger.info('AsyncArk close')
|
||||
if self.__client is not None and not self.__client.is_closed():
|
||||
await self.__client.close()
|
||||
self.__client = None
|
||||
|
||||
|
||||
class DouBaoHttp:
|
||||
def __init__(self, token):
|
||||
self.__token = token
|
||||
self._response = None
|
||||
self._requesting = False
|
||||
|
||||
def __request(self, question):
|
||||
url = "https://ark.cn-beijing.volces.com/api/v3/chat/completions"
|
||||
headers = {
|
||||
"Authorization": "Bearer " + self.__token,
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
data = {
|
||||
"model": "ep-20241008152048-fsgzf",
|
||||
"messages": question,
|
||||
'stream': True
|
||||
}
|
||||
|
||||
response = requests.post(url, headers=headers, json=data, stream=True)
|
||||
return response
|
||||
|
||||
async def request(self, question, handle, callback):
|
||||
t = time.time()
|
||||
self._requesting = True
|
||||
logger.info(f'-------dou_bao ask:{question}')
|
||||
msg_list = [
|
||||
{"role": "system", "content": "你是测试客服,是由字节跳动开发的 AI 人工智能助手"},
|
||||
{"role": "user", "content": question}
|
||||
]
|
||||
self._response = self.__request(msg_list)
|
||||
if not self._response.ok:
|
||||
logger.info(f"请求失败,状态码:{self._response.status_code}")
|
||||
return
|
||||
sec = ''
|
||||
for chunk in self._response.iter_lines():
|
||||
content = chunk.decode("utf-8").strip()
|
||||
if len(content) < 1:
|
||||
continue
|
||||
content = content[5:]
|
||||
content = json.loads(content)
|
||||
sec = sec + content["choices"][0]["delta"]["content"]
|
||||
sec, message = handle.handle(sec)
|
||||
if len(message) > 0:
|
||||
logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
|
||||
callback(message)
|
||||
callback(sec)
|
||||
self._requesting = False
|
||||
logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
|
||||
|
||||
async def close(self):
|
||||
if self._response is not None and self._requesting:
|
||||
self._response.close()
|
||||
|
||||
|
||||
class DouBao(NLPBase):
|
||||
def __init__(self, context, split, callback=None):
|
||||
super().__init__(context, split, callback)
|
||||
@ -26,44 +129,12 @@ class DouBao(NLPBase):
|
||||
# api_ky
|
||||
# eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJhcmstY29uc29sZSIsImV4cCI6MTczMDk2NTMxOSwiaWF0IjoxNzI4MzczMzE5LCJ0IjoidXNlciIsImt2IjoxLCJhaWQiOiIyMTAyMjc3NDc1IiwidWlkIjoiMCIsImlzX291dGVyX3VzZXIiOnRydWUsInJlc291cmNlX3R5cGUiOiJlbmRwb2ludCIsInJlc291cmNlX2lkcyI6WyJlcC0yMDI0MTAwODE1MjA0OC1mc2d6ZiJdfQ.BHgFj-UKeu7IGG5VL2e6iPQEMNMkQrgmM46zYmTpoNG_ySgSFJLWYzbrIABZmqVDB4Rt58j8kvoORs-RHJUz81rXUlh3BYl9-ZwbggtAU7Z1pm54_qZ00jF0jQ6r-fUSXZo2PVCLxb_clNuEh06NyaV7ullZwUCyLKx3vhCsxPAuEvQvLc_qDBx-IYNT-UApVADaqMs-OyewoxahqQ7RvaHFF14R6ihmg9H0uvl00_JiGThJveszKvy_T-Qk6iPOy-EDI2pwJxdHMZ7By0bWK5EfZoK2hOvOSRD0BNTYnvrTfI0l2JgS0nwCVEPR4KSTXxU_oVVtuUSZp1UHvvkhvA
|
||||
self.__token = 'c9635f9e-0f9e-4ca1-ac90-8af25a541b74'
|
||||
self.__client = AsyncArk(api_key=self.__token)
|
||||
self._dou_bao = DouBaoHttp(self.__token)
|
||||
|
||||
async def _request(self, question):
|
||||
t = time.time()
|
||||
logger.info(f'-------dou_bao ask:{question}')
|
||||
try:
|
||||
stream = await self.__client.chat.completions.create(
|
||||
model="ep-20241008152048-fsgzf",
|
||||
messages=[
|
||||
{"role": "system", "content": "你是测试客服,是由字节跳动开发的 AI 人工智能助手"},
|
||||
{"role": "user", "content": question},
|
||||
],
|
||||
stream=True
|
||||
)
|
||||
|
||||
sec = ''
|
||||
async for completion in stream:
|
||||
sec = sec + completion.choices[0].delta.content
|
||||
sec, message = self._split_handle.handle(sec)
|
||||
if len(message) > 0:
|
||||
logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
|
||||
self._on_callback(message)
|
||||
self._on_callback(sec)
|
||||
logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
|
||||
await stream.close()
|
||||
|
||||
# sec = "你是测试客服,是由字节跳动开发的 AI 人工智能助手"
|
||||
# sec, message = self._split_handle.handle(sec)
|
||||
# sec, message = self._split_handle.handle(sec)
|
||||
# if len(message) > 0:
|
||||
# self._on_callback(message)
|
||||
# if len(sec) > 0:
|
||||
# self._on_callback(sec)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
|
||||
await self._dou_bao.request(question, self._split_handle, self._on_callback)
|
||||
|
||||
async def _on_close(self):
|
||||
if self._dou_bao is not None:
|
||||
await self._dou_bao.close()
|
||||
logger.info('AsyncArk close')
|
||||
if self.__client is not None and not self.__client.is_closed():
|
||||
await self.__client.close()
|
||||
|
@ -1,22 +1,81 @@
|
||||
#encoding = utf8
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import requests
|
||||
|
||||
from nlp import PunctuationSplit, DouBao
|
||||
from nlp import PunctuationSplit, DouBao, NLPCallback
|
||||
from utils import config_logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
current_file_path = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
|
||||
# 接入点和apiKey生成、模型选择部分请访问豆包官方文档,本文仅给出请求体和访问路径
|
||||
# 豆包大模型接口官方文档 https://console.volcengine.com/ark/region:ark+cn-beijing/endpoint/detail?Id=ep-20240826182225-kp7rp&Tab=api
|
||||
# 部分内容参考通义千问文档实现 https://help.aliyun.com/zh/model-studio/developer-reference/use-qwen-by-calling-api?spm=a2c4g.11186623.0.0.28b919a1NbCP4i#e7932c7e33gvv
|
||||
|
||||
|
||||
def __request(key, question):
|
||||
url = "https://ark.cn-beijing.volces.com/api/v3/chat/completions"
|
||||
headers = {
|
||||
"Authorization": "Bearer " + key,
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
data = {
|
||||
"model": "ep-20241008152048-fsgzf",
|
||||
"messages": question,
|
||||
'stream': True
|
||||
}
|
||||
|
||||
response = requests.post(url, headers=headers, json=data, stream=True)
|
||||
return response
|
||||
|
||||
|
||||
class DisplayNLP(NLPCallback):
|
||||
def on_message(self, txt: str):
|
||||
print(txt)
|
||||
|
||||
|
||||
def main():
|
||||
print("Started! Please speak")
|
||||
# 你的API_KEY
|
||||
# api_key = "c9635f9e-0f9e-4ca1-ac90-8af25a541b74"
|
||||
# __token = 'c9635f9e-0f9e-4ca1-ac90-8af25a541b74'
|
||||
# # 问题列表
|
||||
# msg_list = [
|
||||
# {"role": "system", "content": "你是测试客服,是由字节跳动开发的 AI 人工智能助手"},
|
||||
# {"role": "user", "content": "你好"}
|
||||
# ]
|
||||
# t = time.time()
|
||||
# response = __request(__token, msg_list)
|
||||
# if response.status_code != 200:
|
||||
# print(f"请求失败,状态码:{response.status_code}")
|
||||
# return
|
||||
#
|
||||
# for chunk in response.iter_lines():
|
||||
# content = chunk.decode("utf-8").strip()
|
||||
# if len(content) < 1:
|
||||
# continue
|
||||
#
|
||||
# content = content[5:]
|
||||
# content = json.loads(content)
|
||||
# print(f'-------dou_bao ask time:{time.time() - t:.4f}s, response:{content["choices"][0]["delta"]["content"]}')
|
||||
#
|
||||
# print("文件下载完成")
|
||||
|
||||
split = PunctuationSplit()
|
||||
nlp = DouBao(split)
|
||||
nlp = DouBao(None, split, DisplayNLP())
|
||||
nlp.ask('你好')
|
||||
nlp.ask('你是谁')
|
||||
nlp.ask('能做什么')
|
||||
time.sleep(5)
|
||||
nlp.stop()
|
||||
print("stop")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
config_logging('../logs/info.log', logging.INFO, logging.INFO)
|
||||
try:
|
||||
main()
|
||||
except KeyboardInterrupt:
|
||||
|
Loading…
Reference in New Issue
Block a user