Modify UI, NLP and TTS code

brige 2024-10-17 23:26:21 +08:00
parent 6c0733d6b9
commit ad54248ff3
15 changed files with 192 additions and 201 deletions

View File

@ -1,8 +1,7 @@
#encoding = utf8
import logging
import os
import sys
import time
try:
import sounddevice as sd
@ -16,18 +15,26 @@ except ImportError as e:
import sherpa_ncnn
from asr.asr_base import AsrBase
logger = logging.getLogger(__name__)
current_file_path = os.path.dirname(os.path.abspath(__file__))
class SherpaNcnnAsr(AsrBase):
def __init__(self):
super().__init__()
self._recognizer = self._create_recognizer()
logger.info('SherpaNcnnAsr init')
def __del__(self):
logger.info('SherpaNcnnAsr del')
def _create_recognizer(self):
base_path = os.path.join(os.getcwd(), 'data', 'asr', 'sherpa-ncnn',
base_path = os.path.join(current_file_path, '..', 'data', 'asr', 'sherpa-ncnn',
'sherpa-ncnn-streaming-zipformer-zh-14M-2023-02-23')
logger.info(f'_create_recognizer init, path:{base_path}')
recognizer = sherpa_ncnn.Recognizer(
tokens=base_path + '/tokens.txt',
encoder_param=base_path + '/encoder_jit_trace-pnnx.ncnn.param',
@ -50,6 +57,7 @@ class SherpaNcnnAsr(AsrBase):
def _recognize_loop(self):
segment_id = 0
last_result = ""
logger.info(f'_recognize_loop')
with sd.InputStream(channels=1, dtype="float32", samplerate=self._sample_rate) as s:
while not self._stop_event.is_set():
samples, _ = s.read(self._samples_per_read) # a blocking read
@ -70,74 +78,3 @@ class SherpaNcnnAsr(AsrBase):
self._notify_complete(result)
segment_id += 1
self._recognizer.reset()
def main():
print("Started! Please speak")
asr = SherpaNcnnAsr()
time.sleep(20)
print("Stop! ")
asr.stop()
# print("Started! Please speak")
# recognizer = create_recognizer()
# sample_rate = recognizer.sample_rate
# samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms
# last_result = ""
# segment_id = 0
#
# with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s:
# while True:
# samples, _ = s.read(samples_per_read) # a blocking read
# samples = samples.reshape(-1)
# recognizer.accept_waveform(sample_rate, samples)
#
# is_endpoint = recognizer.is_endpoint
#
# result = recognizer.text
# if result and (last_result != result):
# last_result = result
# print("\r{}:{}".format(segment_id, result), end=".", flush=True)
#
# if is_endpoint:
# if result:
# print("\r{}:{}".format(segment_id, result), flush=True)
# segment_id += 1
# recognizer.reset()
# print("Started! Please speak")
# recognizer = create_recognizer()
# sample_rate = recognizer.sample_rate
# samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms
# last_result = ""
# with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s:
# while True:
# samples, _ = s.read(samples_per_read) # a blocking read
# samples = samples.reshape(-1)
# recognizer.accept_waveform(sample_rate, samples)
# result = recognizer.text
# if last_result != result:
# last_result = result
# print("\r{}".format(result), end="", flush=True)
'''
if __name__ == "__main__":
devices = sd.query_devices()
print(devices)
default_input_device_idx = sd.default.device[0]
print(f'Use default device: {devices[default_input_device_idx]["name"]}')
try:
main()
except KeyboardInterrupt:
print("\nCaught Ctrl + C. Exiting")
# devices = sd.query_devices()
# print(devices)
# default_input_device_idx = sd.default.device[0]
# print(f'Use default device: {devices[default_input_device_idx]["name"]}')
#
# try:
# main()
# except KeyboardInterrupt:
# print("\nCaught Ctrl + C. Exiting")
'''
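The recurring pattern in this commit is path handling: model and data directories are now resolved relative to the module file instead of os.getcwd(). A minimal sketch of the pattern, using only the directory names that appear in this diff:

import os

# Resolve the ASR model directory relative to this source file so the
# lookup no longer depends on the process's working directory.
current_file_path = os.path.dirname(os.path.abspath(__file__))
base_path = os.path.join(current_file_path, '..', 'data', 'asr', 'sherpa-ncnn',
                         'sherpa-ncnn-streaming-zipformer-zh-14M-2023-02-23')
print(os.path.normpath(base_path))  # same result wherever the process is launched from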

View File

@ -55,6 +55,7 @@ def detect(net, img, device):
return bboxlist
def batch_detect(net, imgs, device):
imgs = imgs - np.array([104, 117, 123])
imgs = imgs.transpose(0, 3, 1, 2)
@ -93,6 +94,7 @@ def batch_detect(net, imgs, device):
return bboxlist
def flip_detect(net, img, device):
img = cv2.flip(img, 1)
b = detect(net, img, device)

View File

@ -1,4 +1,6 @@
#encoding = utf8
from .audio_handler import AudioHandler
from .human_context import HumanContext
from .audio_mal_handler import AudioMalHandler
from .audio_inference_handler import AudioInferenceHandler
from .human_render import HumanRender

View File

@ -1,21 +0,0 @@
#encoding = utf8
import logging
from abc import ABC, abstractmethod
logger = logging.getLogger(__name__)
class AudioHandler(ABC):
def __init__(self, context, handler):
self._context = context
self._handler = handler
@abstractmethod
def on_handle(self, stream, index):
pass
def on_next_handle(self, stream, type_):
if self._handler is not None:
self._handler.on_handle(stream, type_)
else:
logging.info(f'_handler is None')

View File

@ -1,4 +1,6 @@
#encoding = utf8
import logging
import os
import queue
import time
from queue import Queue
@ -7,9 +9,12 @@ from threading import Event, Thread
import numpy as np
import torch
from .audio_handler import AudioHandler
from human_handler import AudioHandler
from utils import load_model, mirror_index, get_device
logger = logging.getLogger(__name__)
current_file_path = os.path.dirname(os.path.abspath(__file__))
class AudioInferenceHandler(AudioHandler):
def __init__(self, context, handler):
@ -22,6 +27,7 @@ class AudioInferenceHandler(AudioHandler):
self._run_thread = Thread(target=self.__on_run)
self._exit_event.set()
self._run_thread.start()
logger.info("AudioInferenceHandler init")
def on_handle(self, stream, type_):
if type_ == 1:
@ -30,8 +36,10 @@ class AudioInferenceHandler(AudioHandler):
self._audio_queue.put(stream)
def __on_run(self):
model = load_model(r'.\checkpoints\wav2lip.pth')
print("Model loaded")
wav2lip_path = os.path.join(current_file_path, '..', 'checkpoints', 'wav2lip.pth')
logger.info(f'AudioInferenceHandler init, path:{wav2lip_path}')
model = load_model(wav2lip_path)
logger.info("Model loaded")
face_list_cycle = self._context.face_list_cycle
@ -39,10 +47,10 @@ class AudioInferenceHandler(AudioHandler):
index = 0
count = 0
count_time = 0
print('start inference')
logger.info('start inference')
device = get_device()
print(f'use device:{device}')
logger.info(f'use device:{device}')
while True:
if self._exit_event.is_set():
@ -66,7 +74,7 @@ class AudioInferenceHandler(AudioHandler):
0)
index = index + 1
else:
print('infer=======')
logger.info('infer=======')
t = time.perf_counter()
img_batch = []
for i in range(batch_size):
@ -95,20 +103,21 @@ class AudioInferenceHandler(AudioHandler):
count += batch_size
if count >= 100:
print(f"------actual avg infer fps:{count / count_time:.4f}")
logger.info(f"------actual avg infer fps:{count / count_time:.4f}")
count = 0
count_time = 0
image_index = 0
for i, res_frame in enumerate(pred):
self.on_next_handle(
(res_frame, mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]),
0)
index = index + 1
image_index = image_index + 1
print('batch count', image_index)
print('total batch time:', time.perf_counter() - start_time)
logger.info(f'total batch time: {time.perf_counter() - start_time}')
else:
time.sleep(1)
break
print('musereal inference processor stop')
logger.info('AudioInferenceHandler inference processor stop')
def stop(self):
self._exit_event.clear()
self._run_thread.join()

View File

@ -7,7 +7,7 @@ from threading import Thread, Event
import numpy as np
from .audio_handler import AudioHandler
from human_handler import AudioHandler
from utils import melspectrogram
logger = logging.getLogger(__name__)
@ -25,6 +25,7 @@ class AudioMalHandler(AudioHandler):
self.frames = []
self.chunk = context.sample_rate // context.fps
logger.info("AudioMalHandler init")
def on_handle(self, stream, index):
self._queue.put(stream)

View File

@ -1,15 +1,17 @@
#encoding = utf8
import logging
import os
from asr import SherpaNcnnAsr
from human.audio_inference_handler import AudioInferenceHandler
from human.audio_mal_handler import AudioMalHandler
from human.human_render import HumanRender
from .audio_inference_handler import AudioInferenceHandler
from .audio_mal_handler import AudioMalHandler
from .human_render import HumanRender
from nlp import PunctuationSplit, DouBao
from tts import TTSEdge, TTSAudioSplitHandle
from utils import load_avatar, get_device
logger = logging.getLogger(__name__)
current_file_path = os.path.dirname(os.path.abspath(__file__))
class HumanContext:
@ -23,7 +25,9 @@ class HumanContext:
self._device = get_device()
print(f'device:{self._device}')
full_images, face_frames, coord_frames = load_avatar(r'./face/', self._image_size, self._device)
base_path = os.path.join(current_file_path, '..', 'face')
logger.info(f'HumanContext init, face path:{base_path}')
full_images, face_frames, coord_frames = load_avatar(base_path, self._image_size, self._device)
self._frame_list_cycle = full_images
self._face_list_cycle = face_frames
self._coord_list_cycle = coord_frames
@ -31,14 +35,24 @@ class HumanContext:
logging.info(f'face images length: {face_images_length}')
print(f'face images length: {face_images_length}')
self.asr = None
self.nlp = None
self.tts = None
self.tts_handle = None
self.mal_handler = None
self.infer_handler = None
self._asr = None
self._nlp = None
self._tts = None
self._tts_handle = None
self._mal_handler = None
self._infer_handler = None
self._render_handler = None
def __del__(self):
print(f'HumanContext: __del__')
self._asr.stop()
self._nlp.stop()
self._tts.stop()
self._tts_handle.stop()
self._mal_handler.stop()
self._infer_handler.stop()
self._render_handler.stop()
@property
def fps(self):
return self._fps
@ -81,17 +95,17 @@ class HumanContext:
@property
def render_handler(self):
return self.render_handler
return self._render_handler
def build(self):
self._render_handler = HumanRender(self, None)
self.infer_handler = AudioInferenceHandler(self, self._render_handler)
self.mal_handler = AudioMalHandler(self, self.infer_handler)
self.tts_handle = TTSAudioSplitHandle(self, self.mal_handler)
self.tts = TTSEdge(self.tts_handle)
self._infer_handler = AudioInferenceHandler(self, self._render_handler)
self._mal_handler = AudioMalHandler(self, self._infer_handler)
self._tts_handle = TTSAudioSplitHandle(self, self._mal_handler)
self._tts = TTSEdge(self._tts_handle)
split = PunctuationSplit()
nlp = DouBao(split, self.tts)
self.asr = SherpaNcnnAsr()
self.asr.attach(nlp)
self._nlp = DouBao(split, self._tts)
self._asr = SherpaNcnnAsr()
self._asr.attach(self._nlp)
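For orientation, a minimal usage sketch of the rewired context; the wiring order follows build() above, and the render setters are the ones ui.py calls later in this commit:

from human import HumanContext

# build() chains SherpaNcnnAsr -> DouBao(PunctuationSplit) -> TTSEdge
# -> TTSAudioSplitHandle -> AudioMalHandler -> AudioInferenceHandler -> HumanRender.
context = HumanContext()
context.build()
render = context.render_handler        # the property now returns _render_handler
# render.set_image_render(...) and render.set_audio_render(...) are then
# called by the UI, as shown in the ui.py hunk below.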

View File

@ -9,8 +9,7 @@ from threading import Thread, Event
import cv2
import numpy as np
from audio_render import AudioRender
from .audio_handler import AudioHandler
from human_handler import AudioHandler
class HumanRender(AudioHandler):
@ -27,12 +26,12 @@ class HumanRender(AudioHandler):
self._thread.start()
def _on_run(self):
logging.info('chunk2mal run')
logging.info('human render run')
while self._exit_event.is_set():
self._run_step()
time.sleep(0.002)
logging.info('chunk2mal exit')
logging.info('human render exit')
def _run_step(self):
try:
@ -58,7 +57,7 @@ class HumanRender(AudioHandler):
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
if self._image_render is not None:
self._image_render.render(image)
self._image_render.on_render(image)
for audio_frame in audio_frames:
frame, type_ = audio_frame
@ -69,7 +68,6 @@ class HumanRender(AudioHandler):
# new_frame.planes[0].update(frame.tobytes())
# new_frame.sample_rate = 16000
def set_audio_render(self, render):
self._audio_render = render
@ -79,4 +77,6 @@ class HumanRender(AudioHandler):
def on_handle(self, stream, index):
self._queue.put(stream)
def stop(self):
self._exit_event.clear()
self._thread.join()

View File

@ -12,6 +12,7 @@ logger = logging.getLogger(__name__)
class DouBao(NLPBase):
def __init__(self, split, callback=None):
super().__init__(split, callback)
logger.info("DouBao init")
# Access Key ID
# AKLTYTdmOTBmNWFjODkxNDE2Zjk3MjU0NjRhM2JhM2IyN2Y
# AKLTNDZjNTdhNDlkZGE3NDZjMDlkMzk5YWQ3MDA4MTY1ZDc
@ -30,7 +31,7 @@ class DouBao(NLPBase):
async def _request(self, question):
t = time.time()
logger.info(f'_request:{question}')
print(f'-------dou_bao ask:', question)
logger.info(f'-------dou_bao ask:{question}')
try:
stream = await self.__client.chat.completions.create(
model="ep-20241008152048-fsgzf",
@ -51,38 +52,9 @@ class DouBao(NLPBase):
except Exception as e:
print(e)
logger.info(f'_request:{question}, time:{time.time() - t:.4f}s')
print(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
async def _on_close(self):
print('AsyncArk close')
logger.info('AsyncArk close')
if self.__client is not None and not self.__client.is_closed():
await self.__client.close()
'''
if __name__ == "__main__":
# print(get_dou_bao_api())
dou_bao = DouBao()
dou_bao.ask('你好。')
dou_bao.ask('你好,你是谁?')
dou_bao.ask('你能做什么?')
dou_bao.ask('介绍一下,我自己。')
count = 1000
sec = ''
while count >= 0:
count = count - 1
if nlp_queue.empty():
time.sleep(0.1)
continue
sec = sec + nlp_queue.get(block=True, timeout=0.01)
pattern = r'[,。、;?!,.!?]'
match = re.search(pattern, sec)
if match:
pos = match.start() + 1
print(sec[:pos])
sec = sec[pos:]
print(sec)
dou_bao.stop()
'''

View File

@ -0,0 +1,44 @@
#encoding = utf8
import logging
import os
import sys
import time
from human import HumanContext
from utils import config_logging
# try:
# import sounddevice as sd
# except ImportError as e:
# print("Please install sounddevice first. You can use")
# print()
# print(" pip install sounddevice")
# print()
# print("to install it")
# sys.exit(-1)
#
def main():
print("Started! Please speak")
human = HumanContext()
human.build()
time.sleep(60)
print("Stop! ")
if __name__ == "__main__":
# devices = sd.query_devices()
# print(devices)
# default_input_device_idx = sd.default.device[0]
# print(f'Use default device: {devices[default_input_device_idx]["name"]}')
current_file_path = os.path.dirname(os.path.abspath(__file__))
log_path = os.path.join(current_file_path, '..', 'logs', 'info.log')
config_logging(log_path, logging.INFO, logging.INFO)
try:
main()
except KeyboardInterrupt:
print("\nCaught Ctrl + C. Exiting")

View File

@ -1,9 +1,13 @@
#encoding = utf8
import heapq
import logging
import os
import shutil
from utils import save_wav
from human import AudioHandler
from human_handler import AudioHandler
logger = logging.getLogger(__name__)
class TTSAudioHandle(AudioHandler):
@ -27,26 +31,38 @@ class TTSAudioHandle(AudioHandler):
def on_handle(self, stream, index):
pass
def stop(self):
pass
class TTSAudioSplitHandle(TTSAudioHandle):
def __init__(self, context, handler):
super().__init__(context, handler)
self.sample_rate = self._context.sample_rate
self._chunk = self.sample_rate // self._context.fps
self._priority_queue = []
logger.info("TTSAudioSplitHandle init")
def on_handle(self, stream, index):
# heapq.heappush(self._priority_queue, (index, stream))
if stream is None:
heapq.heappush(self._priority_queue, (index, None))
return
stream_len = stream.shape[0]
idx = 0
while stream_len >= self._chunk:
self._context.put_audio_frame(stream[idx:idx + self._chunk])
self.on_next_handle(stream[idx:idx + self._chunk], 0)
stream_len -= self._chunk
idx += self._chunk
def stop(self):
pass
class TTSAudioSaveHandle(TTSAudioHandle):
def __init__(self):
super().__init__()
def __init__(self, context, handler):
super().__init__(context, handler)
self._save_path_dir = '../temp/audio/'
self._clean()
@ -72,3 +88,6 @@ class TTSAudioSaveHandle(TTSAudioHandle):
file_name = self._save_path_dir + str(index) + '.wav'
save_wav(stream, file_name, self.sample_rate)
def stop(self):
pass
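As a worked example of the chunking in TTSAudioSplitHandle.on_handle: the chunk size is sample_rate // fps, so each slice is exactly one video frame's worth of audio. The concrete numbers below are illustrative assumptions; the real values come from HumanContext:

# Illustrative chunking arithmetic (assumed values, not read from the context).
sample_rate = 16000            # assumed audio sample rate
fps = 50                       # assumed render frame rate
chunk = sample_rate // fps     # 320 samples of audio per video frame

stream_len = 7000              # hypothetical TTS segment length in samples
idx = 0
while stream_len >= chunk:     # same loop shape as on_handle() above
    # in on_handle() this slice, stream[idx:idx + chunk], is forwarded via on_next_handle(..., 0)
    stream_len -= chunk
    idx += chunk
# a tail shorter than one chunk is not forwarded in the lines shown above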

View File

@ -1,5 +1,5 @@
#encoding = utf8
import logging
from io import BytesIO
import numpy as np
@ -9,11 +9,14 @@ import resampy
from .tts_base import TTSBase
logger = logging.getLogger(__name__)
class TTSEdge(TTSBase):
def __init__(self, handle, voice='zh-CN-XiaoyiNeural'):
super().__init__(handle)
self._voice = voice
logger.info(f"TTSEdge init, {voice}")
async def _on_request(self, txt: str):
print('_on_request, txt')
@ -42,6 +45,7 @@ class TTSEdge(TTSBase):
print('-------tts finish push chunk')
except Exception as e:
self._handle.on_handle(None, index)
stream.seek(0)
stream.truncate()
print('-------tts finish error:', e)

ui.py
View File

@ -16,8 +16,11 @@ from PIL import Image, ImageTk
from playsound import playsound
from audio_render import AudioRender
# from Human import Human
from human import HumanContext
from utils import config_logging
# from tts.EdgeTTS import EdgeTTS
logger = logging.getLogger(__name__)
@ -48,7 +51,7 @@ class App(customtkinter.CTk):
# self.logo_label.grid(row=0, column=0, padx=20, pady=(20, 10))
self.entry = customtkinter.CTkEntry(self, placeholder_text="输入内容")
self.entry.insert(0, "大家好,我是九零科技有限公司,虚拟数字人。")
self.entry.insert(0, "大家好,测试虚拟数字人。")
self.entry.grid(row=2, column=0, columnspan=2, padx=(20, 0), pady=(20, 20), sticky="nsew")
self.main_button_1 = customtkinter.CTkButton(master=self, fg_color="transparent", border_width=2,
@ -58,13 +61,14 @@ class App(customtkinter.CTk):
self._init_image_canvas()
self._is_play_audio = False
self._audio_render = AudioRender()
# self._human = Human()
self._queue = Queue()
self._human_context = HumanContext()
self._human_context.build()
render = self._human_context.render_handler
render.set_image_render(self)
render.set_audio_render(self._audio_render)
self._render()
# self.play_audio()
@ -84,13 +88,14 @@ class App(customtkinter.CTk):
self._canvas.pack(fill=customtkinter.BOTH, expand=customtkinter.YES)
def _render(self):
after_time = 29
try:
image = self._queue.get()
image = self._queue.get(block=True, timeout=0.003)
if image is None:
self.after(20, self._render)
self.after(after_time, self._render)
return
except queue.Empty:
self.after(20, self._render)
self.after(after_time, self._render)
return
iheight, iwidth = image.shape[0], image.shape[1]
@ -111,7 +116,7 @@ class App(customtkinter.CTk):
height = self.winfo_height() * 0.5
self._canvas.create_image(width, height, anchor=customtkinter.CENTER, image=imgtk)
self._canvas.update()
self.after(20, self._render)
self.after(after_time, self._render)
def request_tts(self):
content = self.entry.get()
@ -148,27 +153,9 @@ class App(customtkinter.CTk):
sound = AudioSegment.from_mp3('./audio/mp3/' + file_name)
sound.export('./audio/wav/' + file_name + '.wav', format="wav")
# open('./audio/', 'wb') with
def config_logging(file_name: str, console_level: int=logging.INFO, file_level: int=logging.DEBUG):
file_handler = logging.FileHandler(file_name, mode='a', encoding="utf8")
file_handler.setFormatter(logging.Formatter(
'%(asctime)s [%(levelname)s] %(module)s.%(lineno)d %(name)s:\t%(message)s'
))
file_handler.setLevel(file_level)
console_handler = logging.StreamHandler()
console_handler.setFormatter(logging.Formatter(
'[%(asctime)s %(levelname)s] %(message)s',
datefmt="%Y/%m/%d %H:%M:%S"
))
console_handler.setLevel(console_level)
logging.basicConfig(
level=min(console_level, file_level),
handlers=[file_handler, console_handler],
)
def on_render(self, image):
self._queue.put(image)
print('on_render', self._queue.qsize())
if __name__ == "__main__":

View File

@ -1,6 +1,6 @@
#encoding = utf8
from .async_task_queue import AsyncTaskQueue
from .utils import mirror_index, load_model, get_device, load_avatar
from .utils import mirror_index, load_model, get_device, load_avatar, config_logging
from .audio_utils import melspectrogram, save_wav

View File

@ -38,7 +38,7 @@ def read_files_path(path):
files = os.listdir(path)
for file in files:
if not os.path.isdir(file):
file_paths.append(path + file)
file_paths.append(os.path.join(path, file))
return file_paths
@ -160,6 +160,7 @@ def load_model(path):
def load_avatar(path, img_size, device):
print(f'load avatar:{path}')
face_images_path = path
face_images_path = read_files_path(face_images_path)
full_list_cycle = read_images(face_images_path)
@ -174,3 +175,23 @@ def load_avatar(path, img_size, device):
coord_frames.append(coord)
return full_list_cycle, face_frames, coord_frames
def config_logging(file_name: str, console_level: int=logging.INFO, file_level: int=logging.DEBUG):
file_handler = logging.FileHandler(file_name, mode='a', encoding="utf8")
file_handler.setFormatter(logging.Formatter(
'%(asctime)s [%(levelname)s] %(module)s.%(lineno)d %(name)s:\t%(message)s'
))
file_handler.setLevel(file_level)
console_handler = logging.StreamHandler()
console_handler.setFormatter(logging.Formatter(
'[%(asctime)s %(levelname)s] %(message)s',
datefmt="%Y/%m/%d %H:%M:%S"
))
console_handler.setLevel(console_level)
logging.basicConfig(
level=min(console_level, file_level),
handlers=[file_handler, console_handler],
)