human/nlp/nlp_doubao.py
2024-11-16 10:13:11 +08:00

171 lines
6.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
import json
import logging
import time
import requests
from nlp.nlp_base import NLPBase
from volcenginesdkarkruntime import AsyncArk
logger = logging.getLogger(__name__)
class DouBaoSDK:
    """Streaming chat client for a Doubao endpoint, built on the Volcengine Ark SDK.

    ``request`` is a plain synchronous method, so the synchronous ``Ark``
    client is used here. The module header only imports ``AsyncArk``, whose
    ``chat.completions.create`` has to be awaited and whose stream has to be
    consumed with ``async for``; driving it from a regular ``for`` loop fails
    on the first iteration.
    """

    # Endpoint id used when the caller does not pass ``model``.
    DEFAULT_MODEL = "ep-20241008152048-fsgzf"
    # System message sent ahead of every user question.
    SYSTEM_PROMPT = "你是测试客服,是由字节跳动开发的 AI 人工智能助手"

    def __init__(self, token, model=DEFAULT_MODEL):
        """Create the client.

        Args:
            token: Ark API key.
            model: Endpoint/model id to query; defaults to ``DEFAULT_MODEL``.
        """
        self._token = token
        self._model = model
        self.__client = self._new_client()
        self._stream = None

    def _new_client(self):
        """Build a synchronous Ark client for the stored API key."""
        # Local import: the file header only brings ``AsyncArk`` into scope.
        from volcenginesdkarkruntime import Ark

        return Ark(api_key=self._token)

    @staticmethod
    def _delta_text(completion):
        """Return the text carried by one streamed chunk, or None if it has none.

        Chunks with an empty ``choices`` list or a ``None`` delta content
        carry no text and must not be concatenated.
        """
        choices = getattr(completion, "choices", None)
        if not choices:
            return None
        delta = getattr(choices[0], "delta", None)
        return getattr(delta, "content", None)

    def _close_stream(self):
        """Close and forget the active stream, if there is one."""
        stream, self._stream = self._stream, None
        if stream is not None:
            stream.close()

    def request(self, question, handle, callback):
        """Ask ``question`` and feed the streamed answer to ``callback``.

        Args:
            question: User text sent as the ``user`` message.
            handle: Object whose ``handle(text)`` returns ``(remainder, message)``;
                ``message`` is forwarded when non-empty and ``remainder`` is
                carried over to the next chunk.
            callback: Called with every non-empty message and, at the end,
                with the non-empty leftover text.

        Errors are logged, not raised. The stream is always closed.
        """
        if self.__client is None:
            # close() drops the client; recreate it lazily.
            self.__client = self._new_client()
        logger.info(f'-------dou_bao ask:{question}')
        try:
            self._stream = self.__client.chat.completions.create(
                model=self._model,
                messages=[
                    {"role": "system", "content": self.SYSTEM_PROMPT},
                    {"role": "user", "content": question},
                ],
                stream=True
            )
            sec = ''
            for completion in self._stream:
                text = self._delta_text(completion)
                if text is None:
                    continue
                sec = sec + text
                sec, message = handle.handle(sec)
                if len(message) > 0:
                    callback(message)
            # Same guard as DouBaoHttp.request: never emit an empty tail.
            if len(sec) > 0:
                callback(sec)
        except Exception as e:
            logger.error(f'-------dou_bao error:{e}')
        finally:
            # Runs on success and on failure so the stream never leaks.
            self._close_stream()

    def close(self):
        """Close the active stream (if any) and the underlying client."""
        self._close_stream()
        logger.info('AsyncArk close')
        client = self.__client
        if client is not None and not client.is_closed():
            client.close()
        self.__client = None

    def aclose(self):
        """Alias of ``close``; despite the name it is synchronous."""
        self.close()
class DouBaoHttp:
    """Streaming chat client that talks to the Ark chat-completions HTTP API.

    The answer is read as a stream of ``data: <json>`` lines terminated by
    ``data: [DONE]``; each text delta is passed through a splitter and the
    resulting messages are handed to a callback.
    """

    API_URL = "https://ark.cn-beijing.volces.com/api/v3/chat/completions"
    # Endpoint id used when the caller does not pass ``model``.
    DEFAULT_MODEL = "ep-20241008152048-fsgzf"
    # (connect, read) timeout in seconds so a dead connection cannot hang forever.
    DEFAULT_TIMEOUT = (10, 60)
    # System message sent ahead of every user question.
    SYSTEM_PROMPT = "你是测试客服,是由字节跳动开发的 AI 人工智能助手"

    def __init__(self, token, model=DEFAULT_MODEL, timeout=DEFAULT_TIMEOUT):
        """Store the credentials and request settings.

        Args:
            token: API key sent as the ``Bearer`` token.
            model: Endpoint/model id to query; defaults to ``DEFAULT_MODEL``.
            timeout: Value passed to ``requests.post(timeout=...)``.
        """
        self.__token = token
        self._model = model
        self._timeout = timeout
        self._response = None
        self._requesting = False

    def __request(self, question):
        """POST the message list ``question`` and return the streaming response."""
        headers = {
            "Authorization": "Bearer " + self.__token,
            "Content-Type": "application/json"
        }
        data = {
            "model": self._model,
            "messages": question,
            'stream': True
        }
        return requests.post(
            self.API_URL,
            headers=headers,
            json=data,
            stream=True,
            timeout=self._timeout,
        )

    @staticmethod
    def _parse_sse_line(raw):
        """Decode one raw stream line into ``(done, text)``.

        Args:
            raw: One line of the response body as ``bytes``.

        Returns:
            ``(True, None)`` for the ``[DONE]`` terminator, ``(False, text)``
            when the line carries a text delta, and ``(False, None)`` for
            lines with nothing to process (blank lines, lines without a
            ``data:`` prefix, chunks without choices or with null content).

        Raises:
            json.JSONDecodeError: The ``data:`` payload is not valid JSON.
        """
        line = raw.decode("utf-8").strip()
        # Only strip the prefix when it is really there; other lines are skipped
        # instead of being mangled and fed to the JSON parser.
        if not line.startswith("data:"):
            return False, None
        payload = line[len("data:"):].strip()
        if not payload:
            return False, None
        if payload == '[DONE]':
            return True, None
        chunk = json.loads(payload)
        if not isinstance(chunk, dict):
            return False, None
        choices = chunk.get("choices")
        if not isinstance(choices, list) or not choices:
            return False, None
        first = choices[0]
        if not isinstance(first, dict):
            return False, None
        delta = first.get("delta")
        if not isinstance(delta, dict):
            return False, None
        text = delta.get("content")
        return False, text if isinstance(text, str) else None

    def request(self, question, handle, callback):
        """Ask ``question`` and feed the streamed answer to ``callback``.

        Args:
            question: User text sent as the ``user`` message.
            handle: Object whose ``handle(text)`` returns ``(remainder, message)``;
                ``message`` is forwarded when non-empty and ``remainder`` is
                carried over to the next chunk.
            callback: Called with every non-empty message and, at the end,
                with the non-empty leftover text.

        A non-OK HTTP status is logged and the call returns. Exceptions from
        the transport, ``handle`` or ``callback`` propagate to the caller,
        but the in-flight flag is always cleared and the response closed.
        """
        t = time.time()
        logger.info(f'-------dou_bao ask:{question}')
        msg_list = [
            {"role": "system", "content": self.SYSTEM_PROMPT},
            {"role": "user", "content": question}
        ]
        self._requesting = True
        try:
            response = self._response = self.__request(msg_list)
            if not response.ok:
                logger.error(f"请求失败,状态码:{response.status_code}")
                return
            sec = ''
            for raw in response.iter_lines():
                try:
                    done, text = self._parse_sse_line(raw)
                except json.JSONDecodeError as e:
                    # e.doc is the payload that failed to parse.
                    logger.error(f"json解析失败错误信息{e, e.doc}")
                    continue
                if done:
                    break
                if text is None:
                    continue
                sec = sec + text
                sec, message = handle.handle(sec)
                if len(message) > 0:
                    logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
                    callback(message)
            if len(sec) > 0:
                callback(sec)
            logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
        finally:
            # Reset state on every exit path (early return, break, exception)
            # and release the streamed connection.
            self._requesting = False
            response, self._response = self._response, None
            if response is not None:
                response.close()

    def close(self):
        """Abort the in-flight request, if any, by closing its response."""
        # Snapshot the attribute: request() clears it when it finishes.
        response = self._response
        if response is not None and self._requesting:
            response.close()

    def aclose(self):
        """Same as ``close`` plus a log line; despite the name it is synchronous."""
        response = self._response
        if response is not None and self._requesting:
            response.close()
        logger.info('DouBaoHttp close')
class DouBao(NLPBase):
    """NLP backend that answers questions through the Doubao HTTP client.

    The API key is read from the ``ARK_API_KEY`` environment variable. For
    backward compatibility the key that used to be hard-coded here is still
    used when the variable is unset.

    NOTE(review): access keys, secret keys and a console JWT were previously
    pasted into this class as comments. They have been removed from the
    source, but they remain in version-control history and should be
    revoked/rotated, as should the legacy key below.
    """

    # Environment variable consulted for the API key.
    API_KEY_ENV = "ARK_API_KEY"
    # TODO(review): legacy hard-coded key kept only so existing deployments keep
    # working; rotate it and delete this fallback.
    _LEGACY_API_KEY = 'c9635f9e-0f9e-4ca1-ac90-8af25a541b74'

    def __init__(self, context, split, callback=None):
        """Initialise the base class and create the HTTP client.

        Args:
            context, split, callback: Forwarded unchanged to ``NLPBase.__init__``.
        """
        # Local import: ``os`` is not imported at the top of this file.
        import os

        super().__init__(context, split, callback)
        logger.info("DouBao init")
        token = os.environ.get(self.API_KEY_ENV)
        if not token:
            logger.warning(
                f'{self.API_KEY_ENV} is not set; falling back to the built-in legacy API key'
            )
            token = self._LEGACY_API_KEY
        self.__token = token
        self._dou_bao = DouBaoHttp(self.__token)

    def _request(self, question):
        """Send ``question`` to Doubao.

        ``_split_handle`` and ``_on_callback`` are not defined in this class;
        presumably they are provided by ``NLPBase`` - verify against the base class.
        """
        self._dou_bao.request(question, self._split_handle, self._on_callback)

    def _on_close(self):
        """Abort any in-flight request when this backend is closed."""
        if self._dou_bao is not None:
            self._dou_bao.close()
        logger.info('AsyncArk close')

    def on_clear_cache(self, *args, **kwargs):
        """Run the base-class cache clearing, then abort any in-flight request."""
        super().on_clear_cache(*args, **kwargs)
        if self._dou_bao is not None:
            self._dou_bao.aclose()
        logger.info('DouBao clear_cache')