modify path and proptry access

This commit is contained in:
jiegeaiai 2024-10-17 08:25:53 +08:00
parent 8175c50420
commit 6c0733d6b9
8 changed files with 141 additions and 47 deletions

View File

@ -26,7 +26,7 @@ class SherpaNcnnAsr(AsrBase):
self._recognizer = self._create_recognizer()
def _create_recognizer(self):
base_path = os.path.join(os.getcwd(), '..', 'data', 'asr', 'sherpa-ncnn',
base_path = os.path.join(os.getcwd(), 'data', 'asr', 'sherpa-ncnn',
'sherpa-ncnn-streaming-zipformer-zh-14M-2023-02-23')
recognizer = sherpa_ncnn.Recognizer(
tokens=base_path + '/tokens.txt',

View File

@ -1,4 +1,4 @@
#encoding = utf8
from .human_context import HumanContext
from .audio_handler import AudioHandler
from .human_context import HumanContext

View File

@ -7,7 +7,7 @@ from threading import Event, Thread
import numpy as np
import torch
from human import AudioHandler
from .audio_handler import AudioHandler
from utils import load_model, mirror_index, get_device
@ -33,7 +33,7 @@ class AudioInferenceHandler(AudioHandler):
model = load_model(r'.\checkpoints\wav2lip.pth')
print("Model loaded")
face_list_cycle = self._context.face_list_cycle()
face_list_cycle = self._context.face_list_cycle
length = len(face_list_cycle)
index = 0
@ -47,7 +47,7 @@ class AudioInferenceHandler(AudioHandler):
while True:
if self._exit_event.is_set():
start_time = time.perf_counter()
batch_size = self._context.batch_size()
batch_size = self._context.batch_size
try:
mel_batch = self._mal_queue.get(block=True, timeout=0.1)
except queue.Empty:

View File

@ -7,7 +7,7 @@ from threading import Thread, Event
import numpy as np
from human import AudioHandler
from .audio_handler import AudioHandler
from utils import melspectrogram
logger = logging.getLogger(__name__)
@ -24,7 +24,7 @@ class AudioMalHandler(AudioHandler):
self._thread.start()
self.frames = []
self.chunk = context.sample_rate() // context.fps()
self.chunk = context.sample_rate // context.fps
def on_handle(self, stream, index):
self._queue.put(stream)
@ -38,25 +38,25 @@ class AudioMalHandler(AudioHandler):
logging.info('chunk2mal exit')
def _run_step(self):
for _ in range(self._context.batch_size() * 2):
for _ in range(self._context.batch_size * 2):
frame, _type = self.get_audio_frame()
self.frames.append(frame)
self.on_next_handle((frame, _type), 0)
# context not enough, do not run network.
if len(self.frames) <= self._context.stride_left_size() + self._context.stride_right_size():
if len(self.frames) <= self._context.stride_left_size + self._context.stride_right_size:
return
inputs = np.concatenate(self.frames) # [N * chunk]
mel = melspectrogram(inputs)
# print(mel.shape[0],mel.shape,len(mel[0]),len(self.frames))
# cut off stride
left = max(0, self._context.stride_left_size() * 80 / 50)
right = min(len(mel[0]), len(mel[0]) - self._context.stride_right_size() * 80 / 50)
mel_idx_multiplier = 80. * 2 / self._context.fps()
left = max(0, self._context.stride_left_size * 80 / 50)
right = min(len(mel[0]), len(mel[0]) - self._context.stride_right_size * 80 / 50)
mel_idx_multiplier = 80. * 2 / self._context.fps
mel_step_size = 16
i = 0
mel_chunks = []
while i < (len(self.frames) - self._context.stride_left_size() - self._context.stride_right_size()) / 2:
while i < (len(self.frames) - self._context.stride_left_size - self._context.stride_right_size) / 2:
start_idx = int(left + i * mel_idx_multiplier)
# print(start_idx)
if start_idx + mel_step_size > len(mel[0]):
@ -67,7 +67,7 @@ class AudioMalHandler(AudioHandler):
self.on_next_handle(mel_chunks, 1)
# discard the old part to save memory
self.frames = self.frames[-(self._context.stride_left_size() + self._context.stride_right_size()):]
self.frames = self.frames[-(self._context.stride_left_size + self._context.stride_right_size):]
def get_audio_frame(self):
try:

View File

@ -2,6 +2,9 @@
import logging
from asr import SherpaNcnnAsr
from human.audio_inference_handler import AudioInferenceHandler
from human.audio_mal_handler import AudioMalHandler
from human.human_render import HumanRender
from nlp import PunctuationSplit, DouBao
from tts import TTSEdge, TTSAudioSplitHandle
from utils import load_avatar, get_device
@ -19,7 +22,8 @@ class HumanContext:
self._stride_right_size = 10
self._device = get_device()
full_images, face_frames, coord_frames = load_avatar(r'./face/', self._device, self._image_size)
print(f'device:{self._device}')
full_images, face_frames, coord_frames = load_avatar(r'./face/', self._image_size, self._device)
self._frame_list_cycle = full_images
self._face_list_cycle = face_frames
self._coord_list_cycle = coord_frames
@ -27,6 +31,14 @@ class HumanContext:
logging.info(f'face images length: {face_images_length}')
print(f'face images length: {face_images_length}')
self.asr = None
self.nlp = None
self.tts = None
self.tts_handle = None
self.mal_handler = None
self.infer_handler = None
self._render_handler = None
@property
def fps(self):
return self._fps
@ -59,12 +71,27 @@ class HumanContext:
def face_list_cycle(self):
return self._face_list_cycle
@property
def frame_list_cycle(self):
return self._frame_list_cycle
@property
def coord_list_cycle(self):
return self._coord_list_cycle
@property
def render_handler(self):
return self.render_handler
def build(self):
tts_handle = TTSAudioSplitHandle(self, None)
tts = TTSEdge(tts_handle)
self._render_handler = HumanRender(self, None)
self.infer_handler = AudioInferenceHandler(self, self._render_handler)
self.mal_handler = AudioMalHandler(self, self.infer_handler)
self.tts_handle = TTSAudioSplitHandle(self, self.mal_handler)
self.tts = TTSEdge(self.tts_handle)
split = PunctuationSplit()
nlp = DouBao(split, tts)
asr = SherpaNcnnAsr()
asr.attach(nlp)
nlp = DouBao(split, self.tts)
self.asr = SherpaNcnnAsr()
self.asr.attach(nlp)

View File

@ -1,7 +1,16 @@
#encoding = utf8
import copy
import logging
import queue
import time
from queue import Queue
from threading import Thread, Event
from human import AudioHandler
import cv2
import numpy as np
from audio_render import AudioRender
from .audio_handler import AudioHandler
class HumanRender(AudioHandler):
@ -9,6 +18,63 @@ class HumanRender(AudioHandler):
super().__init__(context, handler)
self._queue = Queue(context.batch_size * 2)
self._audio_render = None
self._image_render = None
self._exit_event = Event()
self._thread = Thread(target=self._on_run)
self._exit_event.set()
self._thread.start()
def _on_run(self):
logging.info('chunk2mal run')
while self._exit_event.is_set():
self._run_step()
time.sleep(0.002)
logging.info('chunk2mal exit')
def _run_step(self):
try:
res_frame, idx, audio_frames = self._queue.get(block=True, timeout=.002)
except queue.Empty:
# print('render queue.Empty:')
return None
if audio_frames[0][1] != 0 and audio_frames[1][1] != 0:
combine_frame = self._context.frame_list_cycle[idx]
else:
bbox = self._context.coord_list_cycle[idx]
combine_frame = copy.deepcopy(self._context.frame_list_cycle[idx])
y1, y2, x1, x2 = bbox
try:
res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
except:
return
# combine_frame = get_image(ori_frame,res_frame,bbox)
# t=time.perf_counter()
combine_frame[y1:y2, x1:x2] = res_frame
image = combine_frame
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
if self._image_render is not None:
self._image_render.render(image)
for audio_frame in audio_frames:
frame, type_ = audio_frame
frame = (frame * 32767).astype(np.int16)
if self._audio_render is not None:
self._audio_render.write(frame.tobytes(), int(frame.shape[0]*2))
# new_frame = AudioFrame(format='s16', layout='mono', samples=frame.shape[0])
# new_frame.planes[0].update(frame.tobytes())
# new_frame.sample_rate = 16000
def set_audio_render(self, render):
self._audio_render = render
def set_image_render(self, render):
self._image_render = render
def on_handle(self, stream, index):
self._queue.put(stream)

View File

@ -31,8 +31,8 @@ class TTSAudioHandle(AudioHandler):
class TTSAudioSplitHandle(TTSAudioHandle):
def __init__(self, context, handler):
super().__init__(context, handler)
self.sample_rate = self._context.get_audio_sample_rate()
self._chunk = self.sample_rate // self._context.get_fps()
self.sample_rate = self._context.sample_rate
self._chunk = self.sample_rate // self._context.fps
def on_handle(self, stream, index):
stream_len = stream.shape[0]

43
ui.py
View File

@ -2,9 +2,12 @@
import json
import logging
import os
import queue
from logging import handlers
import tkinter
import tkinter.messagebox
from queue import Queue
import customtkinter
import cv2
import requests
@ -13,8 +16,9 @@ from PIL import Image, ImageTk
from playsound import playsound
from Human import Human
from tts.EdgeTTS import EdgeTTS
# from Human import Human
from human import HumanContext
# from tts.EdgeTTS import EdgeTTS
logger = logging.getLogger(__name__)
@ -55,7 +59,12 @@ class App(customtkinter.CTk):
self._init_image_canvas()
self._is_play_audio = False
self._human = Human()
# self._human = Human()
self._queue = Queue()
self._human_context = HumanContext()
self._human_context.build()
render = self._human_context.render_handler
render.set_image_render(self)
self._render()
# self.play_audio()
@ -65,29 +74,25 @@ class App(customtkinter.CTk):
def on_destroy(self):
logger.info('------------App destroy------------')
self._human.on_destroy()
# self._human.on_destroy()
def play_audio(self):
return
# if self._is_play_audio:
# return
# self._is_play_audio = True
# file = os.path.curdir + '/audio/test1.wav'
# print(file)
# winsound.PlaySound(file, winsound.SND_ASYNC or winsound.SND_FILENAME)
# playsound(file)
def render_image(self, image):
self._queue.put(image)
def _init_image_canvas(self):
self._canvas = customtkinter.CTkCanvas(self.image_frame)
self._canvas.pack(fill=customtkinter.BOTH, expand=customtkinter.YES)
def _render(self):
image = self._human.render()
try:
image = self._queue.get()
if image is None:
self.after(100, self._render)
self.after(20, self._render)
return
except queue.Empty:
self.after(20, self._render)
return
# self.play_audio()
iheight, iwidth = image.shape[0], image.shape[1]
width = self.winfo_width()
height = self.winfo_height()
@ -95,10 +100,6 @@ class App(customtkinter.CTk):
image = cv2.resize(image, (int(width), int(iheight * width / iwidth)))
else:
image = cv2.resize(image, (int(iwidth * height / iheight), int(height)), interpolation=cv2.INTER_AREA)
# image = cv2.resize(image, (int(width), int(height)), interpolation=cv2.INTER_AREA)
# image = cv2.resize(image, (int(width), int(height)), interpolation=cv2.INTER_AREA)
img = Image.fromarray(image)
imgtk = ImageTk.PhotoImage(image=img)
@ -110,7 +111,7 @@ class App(customtkinter.CTk):
height = self.winfo_height() * 0.5
self._canvas.create_image(width, height, anchor=customtkinter.CENTER, image=imgtk)
self._canvas.update()
self.after(40, self._render)
self.after(20, self._render)
def request_tts(self):
content = self.entry.get()