From 740655228962385b7ba157f19398c0383efd6841 Mon Sep 17 00:00:00 2001 From: brige Date: Fri, 8 Nov 2024 15:23:44 +0800 Subject: [PATCH] modify nlp abort --- nlp/nlp_base.py | 1 - nlp/nlp_doubao.py | 145 +++++++++++++++++++++++++++++++----------- test/test_nlp_only.py | 67 +++++++++++++++++-- 3 files changed, 171 insertions(+), 42 deletions(-) diff --git a/nlp/nlp_base.py b/nlp/nlp_base.py index 62f91fc..47936a9 100644 --- a/nlp/nlp_base.py +++ b/nlp/nlp_base.py @@ -49,7 +49,6 @@ class NLPBase(AsrObserver): if not self._is_running: return logger.info(f'complete:{message}') - # self._context.pause_talk() self.ask(message) def ask(self, question): diff --git a/nlp/nlp_doubao.py b/nlp/nlp_doubao.py index 6935c10..5402f62 100644 --- a/nlp/nlp_doubao.py +++ b/nlp/nlp_doubao.py @@ -1,14 +1,117 @@ #encoding = utf8 - +import json import logging import time +import requests + from nlp.nlp_base import NLPBase from volcenginesdkarkruntime import AsyncArk logger = logging.getLogger(__name__) +class DouBaoSDK: + def __init__(self, token): + self._token = token + self.__client = AsyncArk(api_key=token) + self._stream = None + + async def request(self, question, handle, callback): + if self.__client is None: + self.__client = AsyncArk(api_key=self._token) + t = time.time() + logger.info(f'-------dou_bao ask:{question}') + try: + self._stream = await self.__client.chat.completions.create( + model="ep-20241008152048-fsgzf", + messages=[ + {"role": "system", "content": "你是测试客服,是由字节跳动开发的 AI 人工智能助手"}, + {"role": "user", "content": question}, + ], + stream=True + ) + + sec = '' + async for completion in self._stream: + sec = sec + completion.choices[0].delta.content + sec, message = handle.handle(sec) + if len(message) > 0: + logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s') + callback(message) + callback(sec) + logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s') + await self._stream.close() + self._stream = None + except Exception as e: + print(e) + logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s') + + async def close(self): + if self._stream is not None: + await self._stream.close() + self._stream = None + logger.info('AsyncArk close') + if self.__client is not None and not self.__client.is_closed(): + await self.__client.close() + self.__client = None + + +class DouBaoHttp: + def __init__(self, token): + self.__token = token + self._response = None + self._requesting = False + + def __request(self, question): + url = "https://ark.cn-beijing.volces.com/api/v3/chat/completions" + headers = { + "Authorization": "Bearer " + self.__token, + "Content-Type": "application/json" + } + + data = { + "model": "ep-20241008152048-fsgzf", + "messages": question, + 'stream': True + } + + response = requests.post(url, headers=headers, json=data, stream=True) + return response + + async def request(self, question, handle, callback): + t = time.time() + self._requesting = True + logger.info(f'-------dou_bao ask:{question}') + msg_list = [ + {"role": "system", "content": "你是测试客服,是由字节跳动开发的 AI 人工智能助手"}, + {"role": "user", "content": question} + ] + self._response = self.__request(msg_list) + if not self._response.ok: + logger.info(f"请求失败,状态码:{self._response.status_code}") + return + sec = '' + for chunk in self._response.iter_lines(): + content = chunk.decode("utf-8").strip() + if len(content) < 1: + continue + content = content[5:] + content = json.loads(content) + sec = sec + content["choices"][0]["delta"]["content"] + sec, message = handle.handle(sec) + if len(message) > 0: + logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s') + callback(message) + callback(sec) + self._requesting = False + logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s') + + async def close(self): + if self._response is not None and self._requesting: + self._response.close() + + class DouBao(NLPBase): def __init__(self, context, split, callback=None): super().__init__(context, split, callback) @@ -26,44 +129,12 @@ class DouBao(NLPBase): # api_ky # eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJhcmstY29uc29sZSIsImV4cCI6MTczMDk2NTMxOSwiaWF0IjoxNzI4MzczMzE5LCJ0IjoidXNlciIsImt2IjoxLCJhaWQiOiIyMTAyMjc3NDc1IiwidWlkIjoiMCIsImlzX291dGVyX3VzZXIiOnRydWUsInJlc291cmNlX3R5cGUiOiJlbmRwb2ludCIsInJlc291cmNlX2lkcyI6WyJlcC0yMDI0MTAwODE1MjA0OC1mc2d6ZiJdfQ.BHgFj-UKeu7IGG5VL2e6iPQEMNMkQrgmM46zYmTpoNG_ySgSFJLWYzbrIABZmqVDB4Rt58j8kvoORs-RHJUz81rXUlh3BYl9-ZwbggtAU7Z1pm54_qZ00jF0jQ6r-fUSXZo2PVCLxb_clNuEh06NyaV7ullZwUCyLKx3vhCsxPAuEvQvLc_qDBx-IYNT-UApVADaqMs-OyewoxahqQ7RvaHFF14R6ihmg9H0uvl00_JiGThJveszKvy_T-Qk6iPOy-EDI2pwJxdHMZ7By0bWK5EfZoK2hOvOSRD0BNTYnvrTfI0l2JgS0nwCVEPR4KSTXxU_oVVtuUSZp1UHvvkhvA self.__token = 'c9635f9e-0f9e-4ca1-ac90-8af25a541b74' - self.__client = AsyncArk(api_key=self.__token) + self._dou_bao = DouBaoHttp(self.__token) async def _request(self, question): - t = time.time() - logger.info(f'-------dou_bao ask:{question}') - try: - stream = await self.__client.chat.completions.create( - model="ep-20241008152048-fsgzf", - messages=[ - {"role": "system", "content": "你是测试客服,是由字节跳动开发的 AI 人工智能助手"}, - {"role": "user", "content": question}, - ], - stream=True - ) - - sec = '' - async for completion in stream: - sec = sec + completion.choices[0].delta.content - sec, message = self._split_handle.handle(sec) - if len(message) > 0: - logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s') - self._on_callback(message) - self._on_callback(sec) - logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s') - await stream.close() - - # sec = "你是测试客服,是由字节跳动开发的 AI 人工智能助手" - # sec, message = self._split_handle.handle(sec) - # sec, message = self._split_handle.handle(sec) - # if len(message) > 0: - # self._on_callback(message) - # if len(sec) > 0: - # self._on_callback(sec) - except Exception as e: - print(e) - logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s') + await self._dou_bao.request(question, self._split_handle, self._on_callback) async def _on_close(self): + if self._dou_bao is not None: + await self._dou_bao.close() logger.info('AsyncArk close') - if self.__client is not None and not self.__client.is_closed(): - await self.__client.close() diff --git a/test/test_nlp_only.py b/test/test_nlp_only.py index 9de74fb..d8af213 100644 --- a/test/test_nlp_only.py +++ b/test/test_nlp_only.py @@ -1,22 +1,81 @@ #encoding = utf8 - +import json +import logging +import os import time +import requests -from nlp import PunctuationSplit, DouBao +from nlp import PunctuationSplit, DouBao, NLPCallback +from utils import config_logging + +logger = logging.getLogger(__name__) +current_file_path = os.path.dirname(os.path.abspath(__file__)) + + +# 接入点和apiKey生成、模型选择部分请访问豆包官方文档,本文仅给出请求体和访问路径 +# 豆包大模型接口官方文档 https://console.volcengine.com/ark/region:ark+cn-beijing/endpoint/detail?Id=ep-20240826182225-kp7rp&Tab=api +# 部分内容参考通义千问文档实现 https://help.aliyun.com/zh/model-studio/developer-reference/use-qwen-by-calling-api?spm=a2c4g.11186623.0.0.28b919a1NbCP4i#e7932c7e33gvv + + +def __request(key, question): + url = "https://ark.cn-beijing.volces.com/api/v3/chat/completions" + headers = { + "Authorization": "Bearer " + key, + "Content-Type": "application/json" + } + data = { + "model": "ep-20241008152048-fsgzf", + "messages": question, + 'stream': True + } + + response = requests.post(url, headers=headers, json=data, stream=True) + return response + + +class DisplayNLP(NLPCallback): + def on_message(self, txt: str): + print(txt) def main(): - print("Started! Please speak") + # 你的API_KEY + # api_key = "c9635f9e-0f9e-4ca1-ac90-8af25a541b74" + # __token = 'c9635f9e-0f9e-4ca1-ac90-8af25a541b74' + # # 问题列表 + # msg_list = [ + # {"role": "system", "content": "你是测试客服,是由字节跳动开发的 AI 人工智能助手"}, + # {"role": "user", "content": "你好"} + # ] + # t = time.time() + # response = __request(__token, msg_list) + # if response.status_code != 200: + # print(f"请求失败,状态码:{response.status_code}") + # return + # + # for chunk in response.iter_lines(): + # content = chunk.decode("utf-8").strip() + # if len(content) < 1: + # continue + # + # content = content[5:] + # content = json.loads(content) + # print(f'-------dou_bao ask time:{time.time() - t:.4f}s, response:{content["choices"][0]["delta"]["content"]}') + # + # print("文件下载完成") + split = PunctuationSplit() - nlp = DouBao(split) + nlp = DouBao(None, split, DisplayNLP()) nlp.ask('你好') nlp.ask('你是谁') nlp.ask('能做什么') + time.sleep(5) nlp.stop() print("stop") if __name__ == "__main__": + config_logging('../logs/info.log', logging.INFO, logging.INFO) try: main() except KeyboardInterrupt: