modify path and property access

jiegeaiai 2024-10-17 08:25:53 +08:00
parent 8175c50420
commit 6c0733d6b9
8 changed files with 141 additions and 47 deletions

View File

@@ -26,7 +26,7 @@ class SherpaNcnnAsr(AsrBase):
         self._recognizer = self._create_recognizer()

     def _create_recognizer(self):
-        base_path = os.path.join(os.getcwd(), '..', 'data', 'asr', 'sherpa-ncnn',
+        base_path = os.path.join(os.getcwd(), 'data', 'asr', 'sherpa-ncnn',
                                  'sherpa-ncnn-streaming-zipformer-zh-14M-2023-02-23')
         recognizer = sherpa_ncnn.Recognizer(
             tokens=base_path + '/tokens.txt',
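
The '..' segment is dropped from the join, so the sherpa-ncnn model directory now resolves inside the working directory instead of its parent. A minimal sketch of the difference (paths are illustrative):

    import os

    # Before: '..' walked up one level from the working directory.
    old_base = os.path.join(os.getcwd(), '..', 'data', 'asr', 'sherpa-ncnn')
    # After: model data is expected directly under the working directory.
    new_base = os.path.join(os.getcwd(), 'data', 'asr', 'sherpa-ncnn')

    # normpath collapses the '..' and makes the difference visible:
    print(os.path.normpath(old_base))  # e.g. /home/user/data/asr/sherpa-ncnn (cwd = /home/user/app)
    print(os.path.normpath(new_base))  # e.g. /home/user/app/data/asr/sherpa-ncnn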

View File

@@ -1,4 +1,4 @@
 #encoding = utf8
-from .human_context import HumanContext
 from .audio_handler import AudioHandler
+from .human_context import HumanContext

View File

@@ -7,7 +7,7 @@ from threading import Event, Thread
 import numpy as np
 import torch

-from human import AudioHandler
+from .audio_handler import AudioHandler
 from utils import load_model, mirror_index, get_device

@@ -25,7 +25,7 @@ class AudioInferenceHandler(AudioHandler):
     def on_handle(self, stream, type_):
         if type_ == 1:
             self._mal_queue.put(stream)
         elif type_ == 0:
             self._audio_queue.put(stream)

@@ -33,7 +33,7 @@ class AudioInferenceHandler(AudioHandler):
         model = load_model(r'.\checkpoints\wav2lip.pth')
         print("Model loaded")

-        face_list_cycle = self._context.face_list_cycle()
+        face_list_cycle = self._context.face_list_cycle

         length = len(face_list_cycle)
         index = 0

@@ -47,7 +47,7 @@ class AudioInferenceHandler(AudioHandler):
         while True:
             if self._exit_event.is_set():
                 start_time = time.perf_counter()
-                batch_size = self._context.batch_size()
+                batch_size = self._context.batch_size
                 try:
                     mel_batch = self._mal_queue.get(block=True, timeout=0.1)
                 except queue.Empty:

@@ -78,7 +78,7 @@ class AudioInferenceHandler(AudioHandler):
             mel_batch = np.asarray(mel_batch)
             img_masked = img_batch.copy()
             img_masked[:, face.shape[0] // 2:] = 0
             img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
             mel_batch = np.reshape(mel_batch,
                                    [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
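
For context, the unchanged batch-preparation lines above follow the usual Wav2Lip input layout: the lower half of each face crop is zeroed and stacked with the original along the channel axis, and each mel window gains a trailing channel dimension. A shape-only sketch, assuming illustrative 96x96 crops and 80x16 mel windows:

    import numpy as np

    batch = 4
    img_batch = np.zeros((batch, 96, 96, 3))                # face crops (N, H, W, C)
    mel_batch = [np.zeros((80, 16)) for _ in range(batch)]  # 80 mel bins x 16 frames each

    img_masked = img_batch.copy()
    img_masked[:, 96 // 2:] = 0                             # zero the lower (mouth) half
    img_in = np.concatenate((img_masked, img_batch), axis=3) / 255.
    print(img_in.shape)                                     # (4, 96, 96, 6)

    mel_in = np.reshape(np.asarray(mel_batch), [batch, 80, 16, 1])
    print(mel_in.shape)                                     # (4, 80, 16, 1)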

View File

@@ -7,7 +7,7 @@ from threading import Thread, Event
 import numpy as np

-from human import AudioHandler
+from .audio_handler import AudioHandler
 from utils import melspectrogram

 logger = logging.getLogger(__name__)

@@ -24,7 +24,7 @@ class AudioMalHandler(AudioHandler):
         self._thread.start()

         self.frames = []
-        self.chunk = context.sample_rate() // context.fps()
+        self.chunk = context.sample_rate // context.fps

     def on_handle(self, stream, index):
         self._queue.put(stream)

@@ -38,25 +38,25 @@ class AudioMalHandler(AudioHandler):
         logging.info('chunk2mal exit')

     def _run_step(self):
-        for _ in range(self._context.batch_size() * 2):
+        for _ in range(self._context.batch_size * 2):
             frame, _type = self.get_audio_frame()
             self.frames.append(frame)
             self.on_next_handle((frame, _type), 0)
         # context not enough, do not run network.
-        if len(self.frames) <= self._context.stride_left_size() + self._context.stride_right_size():
+        if len(self.frames) <= self._context.stride_left_size + self._context.stride_right_size:
             return

         inputs = np.concatenate(self.frames)  # [N * chunk]
         mel = melspectrogram(inputs)
         # print(mel.shape[0], mel.shape, len(mel[0]), len(self.frames))
         # cut off stride
-        left = max(0, self._context.stride_left_size() * 80 / 50)
-        right = min(len(mel[0]), len(mel[0]) - self._context.stride_right_size() * 80 / 50)
-        mel_idx_multiplier = 80. * 2 / self._context.fps()
+        left = max(0, self._context.stride_left_size * 80 / 50)
+        right = min(len(mel[0]), len(mel[0]) - self._context.stride_right_size * 80 / 50)
+        mel_idx_multiplier = 80. * 2 / self._context.fps
         mel_step_size = 16
         i = 0
         mel_chunks = []
-        while i < (len(self.frames) - self._context.stride_left_size() - self._context.stride_right_size()) / 2:
+        while i < (len(self.frames) - self._context.stride_left_size - self._context.stride_right_size) / 2:
             start_idx = int(left + i * mel_idx_multiplier)
             # print(start_idx)
             if start_idx + mel_step_size > len(mel[0]):

@@ -67,7 +67,7 @@ class AudioMalHandler(AudioHandler):
         self.on_next_handle(mel_chunks, 1)
         # discard the old part to save memory
-        self.frames = self.frames[-(self._context.stride_left_size() + self._context.stride_right_size()):]
+        self.frames = self.frames[-(self._context.stride_left_size + self._context.stride_right_size):]

     def get_audio_frame(self):
         try:
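
The 80 / 50 factors above are the ratio of mel frames to audio frames: melspectrogram at a 16 kHz sample rate with a 200-sample hop yields 80 mel frames per second, while audio arrives at 50 chunks per second, so each audio chunk spans 1.6 mel frames. A worked sketch under those assumptions (sample rate, fps, and hop are assumed values, not taken from this diff):

    sample_rate = 16000   # assumed
    fps = 50              # assumed audio chunks per second
    hop_length = 200      # assumed melspectrogram hop

    mel_per_second = sample_rate / hop_length   # 80.0 mel frames per second
    mel_per_chunk = mel_per_second / fps        # 80 / 50 = 1.6

    stride_left_size = 10
    left = max(0, stride_left_size * 80 / 50)   # 16.0 mel frames trimmed on the left
    mel_idx_multiplier = 80. * 2 / fps          # 3.2 mel frames per two-chunk step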

View File

@@ -2,6 +2,9 @@
 import logging

 from asr import SherpaNcnnAsr
+from human.audio_inference_handler import AudioInferenceHandler
+from human.audio_mal_handler import AudioMalHandler
+from human.human_render import HumanRender
 from nlp import PunctuationSplit, DouBao
 from tts import TTSEdge, TTSAudioSplitHandle
 from utils import load_avatar, get_device

@@ -19,7 +22,8 @@ class HumanContext:
         self._stride_right_size = 10
         self._device = get_device()
-        full_images, face_frames, coord_frames = load_avatar(r'./face/', self._device, self._image_size)
+        print(f'device:{self._device}')
+        full_images, face_frames, coord_frames = load_avatar(r'./face/', self._image_size, self._device)
         self._frame_list_cycle = full_images
         self._face_list_cycle = face_frames
         self._coord_list_cycle = coord_frames

@@ -27,6 +31,14 @@ class HumanContext:
         logging.info(f'face images length: {face_images_length}')
         print(f'face images length: {face_images_length}')

+        self.asr = None
+        self.nlp = None
+        self.tts = None
+        self.tts_handle = None
+        self.mal_handler = None
+        self.infer_handler = None
+        self._render_handler = None

     @property
     def fps(self):
         return self._fps

@@ -59,12 +71,27 @@ class HumanContext:
     def face_list_cycle(self):
         return self._face_list_cycle

+    @property
+    def frame_list_cycle(self):
+        return self._frame_list_cycle
+
+    @property
+    def coord_list_cycle(self):
+        return self._coord_list_cycle
+
+    @property
+    def render_handler(self):
+        # must return the backing field; returning self.render_handler would recurse
+        return self._render_handler
+
     def build(self):
-        tts_handle = TTSAudioSplitHandle(self, None)
-        tts = TTSEdge(tts_handle)
+        self._render_handler = HumanRender(self, None)
+        self.infer_handler = AudioInferenceHandler(self, self._render_handler)
+        self.mal_handler = AudioMalHandler(self, self.infer_handler)
+        self.tts_handle = TTSAudioSplitHandle(self, self.mal_handler)
+        self.tts = TTSEdge(self.tts_handle)
         split = PunctuationSplit()
-        nlp = DouBao(split, tts)
-        asr = SherpaNcnnAsr()
-        asr.attach(nlp)
+        nlp = DouBao(split, self.tts)
+        self.asr = SherpaNcnnAsr()
+        self.asr.attach(nlp)
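
Most of this commit swaps method calls like context.batch_size() for @property access, so callers read the value as a plain attribute. A minimal standalone illustration of the pattern:

    class Context:
        def __init__(self):
            self._batch_size = 16

        @property
        def batch_size(self):
            # Return the backing field; returning self.batch_size here
            # would re-enter the property and recurse forever.
            return self._batch_size

    ctx = Context()
    print(ctx.batch_size)    # 16 -- no parentheses
    # ctx.batch_size() now raises TypeError: 'int' object is not callable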

View File

@@ -1,7 +1,16 @@
 #encoding = utf8
+import copy
+import logging
+import queue
+import time
 from queue import Queue
+from threading import Thread, Event

-from human import AudioHandler
+import cv2
+import numpy as np
+
+from audio_render import AudioRender
+from .audio_handler import AudioHandler


 class HumanRender(AudioHandler):

@@ -9,6 +18,63 @@ class HumanRender(AudioHandler):
         super().__init__(context, handler)
         self._queue = Queue(context.batch_size * 2)
+        self._audio_render = None
+        self._image_render = None
+        self._exit_event = Event()
+        self._thread = Thread(target=self._on_run)
+        self._exit_event.set()
+        self._thread.start()
+
+    def _on_run(self):
+        logging.info('chunk2mal run')
+        while self._exit_event.is_set():
+            self._run_step()
+            time.sleep(0.002)
+        logging.info('chunk2mal exit')
+
+    def _run_step(self):
+        try:
+            res_frame, idx, audio_frames = self._queue.get(block=True, timeout=.002)
+        except queue.Empty:
+            # print('render queue.Empty:')
+            return None
+
+        if audio_frames[0][1] != 0 and audio_frames[1][1] != 0:
+            combine_frame = self._context.frame_list_cycle[idx]
+        else:
+            bbox = self._context.coord_list_cycle[idx]
+            combine_frame = copy.deepcopy(self._context.frame_list_cycle[idx])
+            y1, y2, x1, x2 = bbox
+            try:
+                res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
+            except:
+                return
+            # combine_frame = get_image(ori_frame, res_frame, bbox)
+            # t = time.perf_counter()
+            combine_frame[y1:y2, x1:x2] = res_frame
+
+        image = combine_frame
+        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+        if self._image_render is not None:
+            self._image_render.render(image)
+
+        for audio_frame in audio_frames:
+            frame, type_ = audio_frame
+            frame = (frame * 32767).astype(np.int16)
+            if self._audio_render is not None:
+                self._audio_render.write(frame.tobytes(), int(frame.shape[0] * 2))
+            # new_frame = AudioFrame(format='s16', layout='mono', samples=frame.shape[0])
+            # new_frame.planes[0].update(frame.tobytes())
+            # new_frame.sample_rate = 16000
+
+    def set_audio_render(self, render):
+        self._audio_render = render
+
+    def set_image_render(self, render):
+        self._image_render = render
+
     def on_handle(self, stream, index):
         self._queue.put(stream)
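
In the audio branch above, float samples in [-1.0, 1.0] are scaled to 16-bit PCM before being handed to the audio renderer, and the byte count is frame.shape[0] * 2 because each s16 sample occupies two bytes. A small sketch of that conversion:

    import numpy as np

    frame = np.array([0.0, 0.5, -0.5, 1.0], dtype=np.float32)  # illustrative samples
    pcm = (frame * 32767).astype(np.int16)
    print(pcm)                   # [     0  16383 -16383  32767]
    print(len(pcm.tobytes()))    # 8 == frame.shape[0] * 2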

View File

@@ -31,8 +31,8 @@ class TTSAudioHandle(AudioHandler):
 class TTSAudioSplitHandle(TTSAudioHandle):
     def __init__(self, context, handler):
         super().__init__(context, handler)
-        self.sample_rate = self._context.get_audio_sample_rate()
-        self._chunk = self.sample_rate // self._context.get_fps()
+        self.sample_rate = self._context.sample_rate
+        self._chunk = self.sample_rate // self._context.fps

     def on_handle(self, stream, index):
         stream_len = stream.shape[0]

ui.py
View File

@@ -2,9 +2,12 @@
 import json
 import logging
 import os
+import queue
 from logging import handlers
 import tkinter
 import tkinter.messagebox
+from queue import Queue

 import customtkinter
 import cv2
 import requests

@@ -13,8 +16,9 @@ from PIL import Image, ImageTk
 from playsound import playsound

-from Human import Human
-from tts.EdgeTTS import EdgeTTS
+# from Human import Human
+from human import HumanContext
+# from tts.EdgeTTS import EdgeTTS

 logger = logging.getLogger(__name__)

@@ -55,7 +59,12 @@ class App(customtkinter.CTk):
         self._init_image_canvas()

         self._is_play_audio = False
-        self._human = Human()
+        # self._human = Human()
+        self._queue = Queue()
+        self._human_context = HumanContext()
+        self._human_context.build()
+        render = self._human_context.render_handler
+        render.set_image_render(self)

         self._render()
         # self.play_audio()

@@ -65,29 +74,25 @@ class App(customtkinter.CTk):
     def on_destroy(self):
         logger.info('------------App destroy------------')
-        self._human.on_destroy()
+        # self._human.on_destroy()

-    def play_audio(self):
-        return
-        # if self._is_play_audio:
-        #     return
-        # self._is_play_audio = True
-        # file = os.path.curdir + '/audio/test1.wav'
-        # print(file)
-        # winsound.PlaySound(file, winsound.SND_ASYNC or winsound.SND_FILENAME)
-        # playsound(file)
+    def render_image(self, image):
+        self._queue.put(image)

     def _init_image_canvas(self):
         self._canvas = customtkinter.CTkCanvas(self.image_frame)
         self._canvas.pack(fill=customtkinter.BOTH, expand=customtkinter.YES)

     def _render(self):
-        image = self._human.render()
-        if image is None:
-            self.after(100, self._render)
-            return
-        # self.play_audio()
+        try:
+            image = self._queue.get()
+            if image is None:
+                self.after(20, self._render)
+                return
+        except queue.Empty:
+            self.after(20, self._render)
+            return

         iheight, iwidth = image.shape[0], image.shape[1]
         width = self.winfo_width()
         height = self.winfo_height()

@@ -95,10 +100,6 @@ class App(customtkinter.CTk):
             image = cv2.resize(image, (int(width), int(iheight * width / iwidth)))
         else:
             image = cv2.resize(image, (int(iwidth * height / iheight), int(height)), interpolation=cv2.INTER_AREA)
-
-        # image = cv2.resize(image, (int(width), int(height)), interpolation=cv2.INTER_AREA)
-        # image = cv2.resize(image, (int(width), int(height)), interpolation=cv2.INTER_AREA)
-
         img = Image.fromarray(image)
         imgtk = ImageTk.PhotoImage(image=img)

@@ -110,7 +111,7 @@ class App(customtkinter.CTk):
         height = self.winfo_height() * 0.5
         self._canvas.create_image(width, height, anchor=customtkinter.CENTER, image=imgtk)
         self._canvas.update()
-        self.after(40, self._render)
+        self.after(20, self._render)

     def request_tts(self):
         content = self.entry.get()
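
The UI now receives frames through render_image(), which may be called from the render thread, and drains them on the Tk main loop via after(), since Tkinter widgets must only be touched from the main thread. A reduced sketch of that handoff (class and widget names are illustrative; a non-blocking get_nowait() is used here so queue.Empty can actually fire, whereas the code above calls a blocking get()):

    import queue
    import tkinter


    class FrameConsumer(tkinter.Tk):
        def __init__(self):
            super().__init__()
            self._queue = queue.Queue()
            self._poll()

        def render_image(self, image):
            # Thread-safe: only the queue is touched off the main thread.
            self._queue.put(image)

        def _poll(self):
            # Runs on the Tk main thread every ~20 ms.
            try:
                image = self._queue.get_nowait()
                print('would draw frame:', image)
            except queue.Empty:
                pass
            self.after(20, self._poll)


    if __name__ == '__main__':
        app = FrameConsumer()
        app.render_image('frame-0')
        app.after(200, app.destroy)  # close quickly; sketch only
        app.mainloop()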