add kimi nlp
This commit is contained in:
parent
322ff33c84
commit
31f9ec50cb
@ -11,7 +11,6 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from eventbus import EventBus
|
from eventbus import EventBus
|
||||||
from human import HumanStatus
|
|
||||||
from human_handler import AudioHandler
|
from human_handler import AudioHandler
|
||||||
from utils import load_model, mirror_index, get_device, SyncQueue
|
from utils import load_model, mirror_index, get_device, SyncQueue
|
||||||
|
|
||||||
@ -75,7 +74,7 @@ class AudioInferenceHandler(AudioHandler):
|
|||||||
count_time = 0
|
count_time = 0
|
||||||
logger.info('start inference')
|
logger.info('start inference')
|
||||||
silence_length = 133
|
silence_length = 133
|
||||||
human_status = HumanStatus(length, silence_length)
|
# human_status = HumanStatus(length, silence_length)
|
||||||
|
|
||||||
device = get_device()
|
device = get_device()
|
||||||
logger.info(f'use device:{device}')
|
logger.info(f'use device:{device}')
|
||||||
@ -110,13 +109,13 @@ class AudioInferenceHandler(AudioHandler):
|
|||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
if not self._is_running:
|
if not self._is_running:
|
||||||
break
|
break
|
||||||
# self.on_next_handle((None, mirror_index(silence_length, index),
|
self.on_next_handle((None, mirror_index(silence_length, index),
|
||||||
self.on_next_handle((None, human_status.get_index(),
|
# self.on_next_handle((None, human_status.get_index(),
|
||||||
audio_frames[i * 2:i * 2 + 2]), 0)
|
audio_frames[i * 2:i * 2 + 2]), 0)
|
||||||
index = index + 1
|
index = index + 1
|
||||||
else:
|
else:
|
||||||
logger.info(f'infer======= {current_text}')
|
logger.info(f'infer======= {current_text}')
|
||||||
human_status.try_to_talk()
|
# human_status.try_to_talk()
|
||||||
t = time.perf_counter()
|
t = time.perf_counter()
|
||||||
img_batch = []
|
img_batch = []
|
||||||
# for i in range(batch_size):
|
# for i in range(batch_size):
|
||||||
|
@ -8,7 +8,7 @@ from .audio_inference_onnx_handler import AudioInferenceOnnxHandler
|
|||||||
from .audio_inference_handler import AudioInferenceHandler
|
from .audio_inference_handler import AudioInferenceHandler
|
||||||
from .audio_mal_handler import AudioMalHandler
|
from .audio_mal_handler import AudioMalHandler
|
||||||
from .human_render import HumanRender
|
from .human_render import HumanRender
|
||||||
from nlp import PunctuationSplit, DouBao
|
from nlp import PunctuationSplit, DouBao, Kimi
|
||||||
from tts import TTSEdge, TTSAudioSplitHandle, TTSEdgeHttp
|
from tts import TTSEdge, TTSAudioSplitHandle, TTSEdgeHttp
|
||||||
from utils import load_avatar, get_device, object_stop, load_avatar_from_processed, load_avatar_from_256_processed
|
from utils import load_avatar, get_device, object_stop, load_avatar_from_processed, load_avatar_from_256_processed
|
||||||
|
|
||||||
@ -125,7 +125,8 @@ class HumanContext:
|
|||||||
self._tts_handle = TTSAudioSplitHandle(self, self._mal_handler)
|
self._tts_handle = TTSAudioSplitHandle(self, self._mal_handler)
|
||||||
self._tts = TTSEdgeHttp(self._tts_handle)
|
self._tts = TTSEdgeHttp(self._tts_handle)
|
||||||
split = PunctuationSplit()
|
split = PunctuationSplit()
|
||||||
self._nlp = DouBao(self, split, self._tts)
|
# self._nlp = DouBao(self, split, self._tts)
|
||||||
|
self._nlp = Kimi(self, split, self._tts)
|
||||||
self._asr = SherpaNcnnAsr()
|
self._asr = SherpaNcnnAsr()
|
||||||
self._asr.attach(self._nlp)
|
self._asr.attach(self._nlp)
|
||||||
|
|
||||||
|
@ -2,4 +2,5 @@
|
|||||||
|
|
||||||
from .nlp_callback import NLPCallback
|
from .nlp_callback import NLPCallback
|
||||||
from .nlp_doubao import DouBao
|
from .nlp_doubao import DouBao
|
||||||
|
from .nlp_kimi import Kimi
|
||||||
from .nlp_split import PunctuationSplit
|
from .nlp_split import PunctuationSplit
|
||||||
|
117
nlp/nlp_kimi.py
Normal file
117
nlp/nlp_kimi.py
Normal file
@ -0,0 +1,117 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from nlp.nlp_base import NLPBase
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class KimiHttp:
    """Streaming HTTP client for the Moonshot (Kimi) chat-completions endpoint.

    A request is sent with ``stream=True``; the reply is read line by line,
    each ``data:`` line is decoded as JSON, and the incremental text is fed
    through a splitter so that complete segments can be forwarded to a
    callback as soon as they are available.
    """

    _API_URL = "https://api.moonshot.cn/v1/chat/completions"
    _SSE_DATA_PREFIX = "data:"

    def __init__(self, token, timeout=None):
        """Create a client.

        Args:
            token: API key sent as the ``Authorization: Bearer`` credential.
            timeout: Optional value passed straight to ``requests.post`` as
                ``timeout``. ``None`` (the default) keeps the previous
                behaviour of waiting indefinitely.
        """
        self.__token = token
        self._timeout = timeout
        # Response of the request currently (or most recently) in flight.
        self._response = None
        # True only while request() is running; close() relies on it.
        self._requesting = False

    def __request(self, question):
        """POST the message list and return the (unread) streaming response."""
        headers = {
            "Authorization": "Bearer " + self.__token,
            "Content-Type": "application/json"
        }

        data = {
            "model": "moonshot-v1-8k",
            "messages": question,
            'stream': True,
            "temperature": 0.3
        }

        return requests.post(self._API_URL, headers=headers, json=data,
                             stream=True, timeout=self._timeout)

    @staticmethod
    def _delta_text(event):
        """Return the text carried by one decoded stream event, or ''.

        Events without ``choices``, without a ``delta`` or without
        ``content`` (or with ``content`` set to null) yield an empty string
        instead of raising.
        """
        if not isinstance(event, dict):
            return ''
        choices = event.get("choices")
        if not choices:
            return ''
        delta = choices[0].get("delta") or {}
        return delta.get("content") or ''

    def request(self, question, handle, callback):
        """Ask a question and stream the answer to ``callback``.

        Args:
            question: User text placed in the ``user`` message.
            handle: Object whose ``handle(text)`` returns a
                ``(remaining_text, message)`` pair; ``message`` is forwarded
                whenever it is non-empty.
            callback: Called with each non-empty message and, once the
                stream ends, with any text still buffered.

        Exceptions raised by the HTTP layer, ``handle`` or ``callback``
        propagate to the caller; the in-flight flag is cleared and the
        response is closed in every case.
        """
        t = time.time()
        self._requesting = True
        # Drop the previous response so close() can never act on a stale one.
        self._response = None
        logger.info(f'-------kimi ask:{question}')
        msg_list = [
            {"role": "system", "content": "你是 Kimi,由 Moonshot AI 提供的人工智能助手,你更擅长中文和英文的对话。"
                                          "你会为用户提供安全,有帮助,准确的回答。同时,你会拒绝一切涉及恐怖主义,种族歧视,"
                                          "黄色暴力等问题的回答。Moonshot AI 为专有名词,不可翻译成其他语言。"},
            {"role": "user", "content": question}
        ]
        try:
            self._response = self.__request(msg_list)
            if not self._response.ok:
                logger.error(f"请求失败,状态码:{self._response.status_code}")
                return

            sec = ''
            prefix = self._SSE_DATA_PREFIX
            for chunk in self._response.iter_lines():
                line = chunk.decode("utf-8").strip()
                # Only "data:" lines carry a payload; blank separators and
                # any other line kinds are skipped.
                if not line.startswith(prefix):
                    continue
                payload = line[len(prefix):].strip()
                if payload == '[DONE]':
                    break

                try:
                    event = json.loads(payload)
                except ValueError as e:
                    logger.error(f"json解析失败,错误信息:{e, payload}")
                    continue

                sec = sec + self._delta_text(event)
                sec, message = handle.handle(sec)
                if len(message) > 0:
                    logger.info(f'-------kimi nlp time:{time.time() - t:.4f}s')
                    callback(message)

            # Flush whatever the splitter has not emitted yet.
            if len(sec) > 0:
                callback(sec)
        finally:
            # Always clear the flag and release the connection, including on
            # the non-OK early return and on exceptions.
            self._requesting = False
            response = self._response
            if response is not None:
                response.close()

        logger.info(f'-------kimi nlp time:{time.time() - t:.4f}s')

    def close(self):
        """Abort the in-flight request, if any, by closing its response."""
        # Read the attribute once: request() may rebind it concurrently.
        response = self._response
        if response is not None and self._requesting:
            response.close()

    def aclose(self):
        """Same as close(), followed by an informational log line."""
        self.close()
        logger.info('KimiHttp close')
|
||||||
|
|
||||||
|
|
||||||
|
class Kimi(NLPBase):
    """NLP backend that answers questions through the Moonshot (Kimi) API.

    The API key is read from the ``MOONSHOT_API_KEY`` environment variable.
    Credentials must not be committed to source control; any key that was
    previously stored in this file should be treated as leaked and revoked.
    """

    _TOKEN_ENV_VAR = 'MOONSHOT_API_KEY'

    def __init__(self, context, split, callback=None):
        """Initialise the base class and create the HTTP client.

        Args:
            context: Forwarded unchanged to ``NLPBase.__init__``.
            split: Forwarded unchanged to ``NLPBase.__init__``.
            callback: Forwarded unchanged to ``NLPBase.__init__``.
        """
        import os  # local import: keeps this block self-contained

        super().__init__(context, split, callback)
        logger.info("Kimi init")
        self.__token = os.environ.get(self._TOKEN_ENV_VAR, '')
        if not self.__token:
            # Do not crash construction; requests will fail with an HTTP
            # error that KimiHttp.request logs.
            logger.error(f'{self._TOKEN_ENV_VAR} is not set; Kimi requests will be rejected')
        # NOTE(review): the attribute keeps its original name `_dou_bao` in
        # case other code refers to it; it holds the Kimi client.
        self._dou_bao = KimiHttp(self.__token)

    def _request(self, question):
        """Send ``question`` to Kimi and stream the reply.

        ``_split_handle`` and ``_on_callback`` are not defined in this class;
        presumably they are provided by ``NLPBase`` - verify there.
        """
        self._dou_bao.request(question, self._split_handle, self._on_callback)

    def _on_close(self):
        """Abort any in-flight request when this backend is closed."""
        if self._dou_bao is not None:
            self._dou_bao.close()
        logger.info('Kimi close')

    def on_clear_cache(self, *args, **kwargs):
        """Run the base-class cache clearing, then abort any in-flight request."""
        super().on_clear_cache(*args, **kwargs)
        if self._dou_bao is not None:
            self._dou_bao.aclose()
        logger.info('Kimi clear_cache')
|
@ -15,6 +15,7 @@ class PunctuationSplit(NLPSplit):
|
|||||||
|
|
||||||
def handle(self, message: str):
|
def handle(self, message: str):
|
||||||
message = message.replace('*', '')
|
message = message.replace('*', '')
|
||||||
|
message = message.replace('#', '')
|
||||||
match = re.search(self._pattern, message)
|
match = re.search(self._pattern, message)
|
||||||
if match:
|
if match:
|
||||||
pos = match.start() + 1
|
pos = match.start() + 1
|
||||||
|
Loading…
Reference in New Issue
Block a user