modify human load face

This commit is contained in:
jiegeaiai 2024-09-12 08:15:09 +08:00
parent b0a600c7b7
commit 90ccaa222b
3 changed files with 63 additions and 13 deletions

View File

@ -7,16 +7,14 @@ import time
import numpy as np
import utils
from models import Wav2Lip
from tts.Chunk2Mal import Chunk2Mal
import torch
import cv2
from tqdm import tqdm
logger = logging.getLogger(__name__)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using {} for inference.'.format(device))
def _load(checkpoint_path):
@ -31,6 +29,7 @@ def _load(checkpoint_path):
def load_model(path):
model = Wav2Lip()
print("Load checkpoint from: {}".format(path))
logging.info(f'Load checkpoint from {path}')
checkpoint = _load(path)
s = checkpoint["state_dict"]
new_s = {}
@ -45,6 +44,7 @@ def read_images(img_list):
frames = []
print('reading images...')
for img_path in tqdm(img_list):
print(f'read image path:{img_path}')
frame = cv2.imread(img_path)
frames.append(frame)
return frames
@ -63,17 +63,25 @@ def __mirror_index(size, index):
# python.exe .\inference.py --checkpoint_path .\checkpoints\wav2lip.pth --face
# .\face\img00016.jpg --audio .\audio\audio1.wav
def inference(render_event, batch_size, face_images_path, audio_feat_queue, audio_out_queue, res_frame_queue):
logging.info(f'Using {device} for inference.')
print(f'Using {device} for inference.')
print(f'face_images_path: {face_images_path}')
model = load_model(r'.\checkpoints\wav2lip.pth')
face_list_cycle = read_images(face_images_path)
face_images_length = len(face_list_cycle)
logger.info(f'face images length: {face_images_length}')
logging.info(f'face images length: {face_images_length}')
print(f'face images length: {face_images_length}')
length = len(face_list_cycle)
index = 0
count = 0
count_time = 0
logger.info('start inference')
logging.info('start inference')
print(f'start inference: {render_event.is_set()}')
while render_event.is_set():
print('start inference')
try:
mel_batch = audio_feat_queue.get(block=True, timeout=1)
except queue.Empty:
@ -88,6 +96,7 @@ def inference(render_event, batch_size, face_images_path, audio_feat_queue, audi
if type == 0:
is_all_silence = False
print(f'is_all_silence {is_all_silence}')
if is_all_silence:
for i in range(batch_size):
res_frame_queue.put((None, __mirror_index(length, index), audio_frames[i*2:i*2+2]))
@ -117,7 +126,7 @@ def inference(render_event, batch_size, face_images_path, audio_feat_queue, audi
count_time += (time.perf_counter() - t)
count += batch_size
if count >= 100:
logger.info(f"------actual avg infer fps:{count/count_time:.4f}")
logging.info(f"------actual avg infer fps:{count/count_time:.4f}")
count = 0
count_time = 0
@ -125,7 +134,7 @@ def inference(render_event, batch_size, face_images_path, audio_feat_queue, audi
res_frame_queue.put((res_frame, __mirror_index(length, index), audio_frames[i*2 : i*2+2]))
index = index + 1
logger.info('finish inference')
logging.info('finish inference')
class Human:
@ -142,9 +151,11 @@ class Human:
self._output_queue = mp.Queue()
self._res_frame_queue = mp.Queue(self._batch_size * 2)
self.face_images_path = r'.\face'
face_images_path = r'./face/'
self._face_image_paths = utils.read_files_path(face_images_path)
print(self._face_image_paths)
self.render_event = mp.Event()
mp.Process(target=inference, args=(self.render_event, self._batch_size, self.face_images_path,
mp.Process(target=inference, args=(self.render_event, self._batch_size, self._face_image_paths,
self._feat_queue, self._output_queue, self._res_frame_queue,
)).start()
@ -168,7 +179,7 @@ class Human:
self._chunk_2_mal.stop()
if self._tts is not None:
self._tts.stop()
logger.info('human destroy')
logging.info('human destroy')
def set_tts(self, tts):
if self._tts == tts:
@ -180,7 +191,7 @@ class Human:
def read(self, txt):
if self._tts is None:
logger.warning('tts is none')
logging.warning('tts is none')
return
self._tts.push_txt(txt)
@ -193,6 +204,13 @@ class Human:
self._feat_queue.put(mel_chunks)
print("22")
def render(self):
try:
img, aud = self._res_frame_queue.get(block=True, timeout=.3)
except queue.Empty:
return None
return img
# def pull_audio_chunk(self):
# try:
# chunk = self._audio_chunk_queue.get(block=True, timeout=1.0)

24
ui.py
View File

@ -6,7 +6,7 @@ import tkinter
import tkinter.messagebox
import customtkinter
import requests
from urllib.parse import urlencode
from PIL import Image, ImageTk
from Human import Human
from tts.EdgeTTS import EdgeTTS
@ -49,9 +49,9 @@ class App(customtkinter.CTk):
self._init_image_canvas()
self._human = Human()
tts = EdgeTTS(self._human)
self._human.set_tts(tts)
self._render()
def on_destroy(self):
logger.info('------------App destroy------------')
@ -61,6 +61,26 @@ class App(customtkinter.CTk):
self._canvas = customtkinter.CTkCanvas(self.image_frame)
self._canvas.pack(fill=customtkinter.BOTH, expand=customtkinter.YES)
def _render(self):
image = self._human.render()
if image is None:
self.after(100, self._render)
return
img = Image.fromarray(image)
imgtk = ImageTk.PhotoImage(image=img)
self._canvas.delete("all")
self._canvas.imgtk = imgtk
width = self.winfo_width() * 0.5
height = self.winfo_height() * 0.5
self._canvas.create_image(width, height, anchor=customtkinter.CENTER, image=imgtk)
self._canvas.update()
self.after(30, self._render)
def request_tts(self):
content = self.entry.get()
print('content:', content)

12
utils.py Normal file
View File

@ -0,0 +1,12 @@
#encoding = utf8
import os
def read_files_path(path):
file_paths = []
files = os.listdir(path)
for file in files:
if not os.path.isdir(file):
file_paths.append(path + file)
return file_paths