diff --git a/human/audio_inference_handler.py b/human/audio_inference_handler.py
index 6755a4b..537b3a8 100644
--- a/human/audio_inference_handler.py
+++ b/human/audio_inference_handler.py
@@ -11,7 +11,6 @@ import numpy as np
 import torch
 
 from eventbus import EventBus
-from human import HumanStatus
 from human_handler import AudioHandler
 from utils import load_model, mirror_index, get_device, SyncQueue
 
@@ -75,7 +74,7 @@ class AudioInferenceHandler(AudioHandler):
         count_time = 0
         logger.info('start inference')
         silence_length = 133
-        human_status = HumanStatus(length, silence_length)
+        # human_status = HumanStatus(length, silence_length)
 
         device = get_device()
         logger.info(f'use device:{device}')
@@ -110,13 +109,13 @@ class AudioInferenceHandler(AudioHandler):
                 for i in range(batch_size):
                     if not self._is_running:
                         break
-                    # self.on_next_handle((None, mirror_index(silence_length, index),
-                    self.on_next_handle((None, human_status.get_index(),
+                    self.on_next_handle((None, mirror_index(silence_length, index),
+                    # self.on_next_handle((None, human_status.get_index(),
                                          audio_frames[i * 2:i * 2 + 2]), 0)
                     index = index + 1
             else:
                 logger.info(f'infer======= {current_text}')
-                human_status.try_to_talk()
+                # human_status.try_to_talk()
                 t = time.perf_counter()
                 img_batch = []
                 # for i in range(batch_size):
diff --git a/human/human_context.py b/human/human_context.py
index e76e8ad..cc28cc0 100644
--- a/human/human_context.py
+++ b/human/human_context.py
@@ -8,7 +8,7 @@ from .audio_inference_onnx_handler import AudioInferenceOnnxHandler
 from .audio_inference_handler import AudioInferenceHandler
 from .audio_mal_handler import AudioMalHandler
 from .human_render import HumanRender
-from nlp import PunctuationSplit, DouBao
+from nlp import PunctuationSplit, DouBao, Kimi
 from tts import TTSEdge, TTSAudioSplitHandle, TTSEdgeHttp
 from utils import load_avatar, get_device, object_stop, load_avatar_from_processed, load_avatar_from_256_processed
 
@@ -125,7 +125,8 @@ class HumanContext:
         self._tts_handle = TTSAudioSplitHandle(self, self._mal_handler)
         self._tts = TTSEdgeHttp(self._tts_handle)
         split = PunctuationSplit()
-        self._nlp = DouBao(self, split, self._tts)
+        # self._nlp = DouBao(self, split, self._tts)
+        self._nlp = Kimi(self, split, self._tts)
         self._asr = SherpaNcnnAsr()
         self._asr.attach(self._nlp)
 
diff --git a/nlp/__init__.py b/nlp/__init__.py
index 3ebb749..9da22f8 100644
--- a/nlp/__init__.py
+++ b/nlp/__init__.py
@@ -2,4 +2,5 @@
 
 from .nlp_callback import NLPCallback
 from .nlp_doubao import DouBao
+from .nlp_kimi import Kimi
 from .nlp_split import PunctuationSplit
diff --git a/nlp/nlp_kimi.py b/nlp/nlp_kimi.py
new file mode 100644
index 0000000..833bc03
--- /dev/null
+++ b/nlp/nlp_kimi.py
@@ -0,0 +1,156 @@
+#encoding = utf8
+import json
+import logging
+import os
+import time
+
+import requests
+
+from nlp.nlp_base import NLPBase
+
+logger = logging.getLogger(__name__)
+
+# Name of the environment variable that supplies the Moonshot (Kimi) API key.
+KIMI_API_KEY_ENV = 'MOONSHOT_API_KEY'
+
+
+class KimiHttp:
+    """Streaming HTTP client for the Moonshot (Kimi) chat-completions API."""
+
+    _URL = "https://api.moonshot.cn/v1/chat/completions"
+    # (connect, read) timeouts in seconds, so a stalled connection cannot
+    # block the calling thread forever.
+    _TIMEOUT = (10, 60)
+
+    def __init__(self, token):
+        self.__token = token
+        self._response = None  # in-flight response, kept so close() can abort it
+        self._requesting = False
+
+    def __request(self, question):
+        """POST the chat message list and return the streaming response."""
+        headers = {
+            "Authorization": "Bearer " + self.__token,
+            "Content-Type": "application/json"
+        }
+
+        data = {
+            "model": "moonshot-v1-8k",
+            "messages": question,
+            'stream': True,
+            "temperature": 0.3
+        }
+
+        return requests.post(self._URL, headers=headers, json=data, stream=True,
+                             timeout=self._TIMEOUT)
+
+    @staticmethod
+    def _delta_text(payload):
+        """Return the text fragment carried by one stream chunk, or ''.
+
+        Chunks without text (for example a closing chunk whose ``delta`` is
+        empty) yield '' instead of raising KeyError.
+        """
+        if not isinstance(payload, dict):
+            return ''
+        choices = payload.get("choices") or [{}]
+        delta = choices[0].get("delta") or {}
+        return delta.get("content") or ''
+
+    def request(self, question, handle, callback):
+        """Send ``question`` to Kimi and stream the answer to ``callback``.
+
+        :param question: user text placed in the "user" message.
+        :param handle: object whose ``handle(text)`` returns
+            ``(remaining_text, message)``; a non-empty ``message`` is
+            forwarded to ``callback`` immediately.
+        :param callback: callable receiving each message, plus whatever
+            text remains when the stream ends.
+        """
+        t = time.time()
+        self._requesting = True
+        logger.info(f'-------kimi ask:{question}')
+        msg_list = [
+            {"role": "system", "content": "你是 Kimi,由 Moonshot AI 提供的人工智能助手,你更擅长中文和英文的对话。"
+                                          "你会为用户提供安全,有帮助,准确的回答。同时,你会拒绝一切涉及恐怖主义,种族歧视,"
+                                          "黄色暴力等问题的回答。Moonshot AI 为专有名词,不可翻译成其他语言。"},
+            {"role": "user", "content": question}
+        ]
+        try:
+            self._response = self.__request(msg_list)
+            if not self._response.ok:
+                logger.error(f"请求失败,状态码:{self._response.status_code}")
+                return
+            sec = ''
+            for chunk in self._response.iter_lines():
+                content = chunk.decode("utf-8").strip()
+                # Only "data:" lines carry payloads; skip blanks and other
+                # event-stream fields instead of blindly cutting 5 chars.
+                if not content.startswith('data:'):
+                    continue
+                content = content[5:].strip()
+                if content == '[DONE]':
+                    break
+
+                try:
+                    payload = json.loads(content)
+                except json.JSONDecodeError as e:
+                    logger.error(f"json解析失败,错误信息:{e, content}")
+                    continue
+                sec = sec + self._delta_text(payload)
+                sec, message = handle.handle(sec)
+                if len(message) > 0:
+                    logger.info(f'-------kimi nlp time:{time.time() - t:.4f}s')
+                    callback(message)
+            if len(sec) > 0:
+                callback(sec)
+        finally:
+            # Always clear the in-flight flag and release the connection,
+            # including on HTTP errors and on exceptions from the stream.
+            self._requesting = False
+            if self._response is not None:
+                self._response.close()
+
+        logger.info(f'-------kimi nlp time:{time.time() - t:.4f}s')
+
+    def close(self):
+        """Abort the in-flight request, if there is one."""
+        if self._response is not None and self._requesting:
+            self._response.close()
+
+    def aclose(self):
+        """Same as close(), and additionally logs the abort."""
+        if self._response is not None and self._requesting:
+            self._response.close()
+            logger.info('KimiHttp close')
+
+
+class Kimi(NLPBase):
+    """NLP backend that answers questions through KimiHttp."""
+
+    def __init__(self, context, split, callback=None):
+        super().__init__(context, split, callback)
+        logger.info("Kimi init")
+        # Credentials must not live in source control: the API key is read
+        # from the environment instead of being hard-coded here.
+        self.__token = os.environ.get(KIMI_API_KEY_ENV, '')
+        if not self.__token:
+            logger.error(f'{KIMI_API_KEY_ENV} is not set, Kimi requests will fail')
+        self._kimi = KimiHttp(self.__token)
+
+    def _request(self, question):
+        """Forward ``question`` to the HTTP client with this object's split handle and callback."""
+        self._kimi.request(question, self._split_handle, self._on_callback)
+
+    def _on_close(self):
+        """Abort any in-flight HTTP request."""
+        if self._kimi is not None:
+            self._kimi.close()
+            logger.info('Kimi close')
+
+    def on_clear_cache(self, *args, **kwargs):
+        """Run the base-class cache clearing, then abort any in-flight request."""
+        super().on_clear_cache(*args, **kwargs)
+        if self._kimi is not None:
+            self._kimi.aclose()
+            logger.info('Kimi clear_cache')
diff --git a/nlp/nlp_split.py b/nlp/nlp_split.py
index 75f6327..d938088 100644
--- a/nlp/nlp_split.py
+++ b/nlp/nlp_split.py
@@ -15,6 +15,7 @@ class PunctuationSplit(NLPSplit):
 
     def handle(self, message: str):
         message = message.replace('*', '')
+        message = message.replace('#', '')
         match = re.search(self._pattern, message)
         if match:
             pos = match.start() + 1