modify human load face

jiegeaiai 2024-09-12 08:15:09 +08:00
parent b0a600c7b7
commit 90ccaa222b
3 changed files with 63 additions and 13 deletions

Human.py

@@ -7,16 +7,14 @@ import time
import numpy as np
+import utils
from models import Wav2Lip
from tts.Chunk2Mal import Chunk2Mal
import torch
import cv2
from tqdm import tqdm
-logger = logging.getLogger(__name__)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
-print('Using {} for inference.'.format(device))

def _load(checkpoint_path):
@@ -31,6 +29,7 @@ def _load(checkpoint_path):
def load_model(path):
    model = Wav2Lip()
    print("Load checkpoint from: {}".format(path))
+    logging.info(f'Load checkpoint from {path}')
    checkpoint = _load(path)
    s = checkpoint["state_dict"]
    new_s = {}
@@ -45,6 +44,7 @@ def read_images(img_list):
    frames = []
    print('reading images...')
    for img_path in tqdm(img_list):
+        print(f'read image path:{img_path}')
        frame = cv2.imread(img_path)
        frames.append(frame)
    return frames
@@ -63,17 +63,25 @@ def __mirror_index(size, index):
# python.exe .\inference.py --checkpoint_path .\checkpoints\wav2lip.pth --face
# .\face\img00016.jpg --audio .\audio\audio1.wav
def inference(render_event, batch_size, face_images_path, audio_feat_queue, audio_out_queue, res_frame_queue):
+    logging.info(f'Using {device} for inference.')
+    print(f'Using {device} for inference.')
+    print(f'face_images_path: {face_images_path}')
    model = load_model(r'.\checkpoints\wav2lip.pth')
    face_list_cycle = read_images(face_images_path)
    face_images_length = len(face_list_cycle)
-    logger.info(f'face images length: {face_images_length}')
+    logging.info(f'face images length: {face_images_length}')
+    print(f'face images length: {face_images_length}')
    length = len(face_list_cycle)
    index = 0
    count = 0
    count_time = 0
-    logger.info('start inference')
+    logging.info('start inference')
+    print(f'start inference: {render_event.is_set()}')
    while render_event.is_set():
+        print('start inference')
        try:
            mel_batch = audio_feat_queue.get(block=True, timeout=1)
        except queue.Empty:
@@ -88,6 +96,7 @@ def inference(render_event, batch_size, face_images_path, audio_feat_queue, audio_out_queue, res_frame_queue):
            if type == 0:
                is_all_silence = False
+        print(f'is_all_silence {is_all_silence}')
        if is_all_silence:
            for i in range(batch_size):
                res_frame_queue.put((None, __mirror_index(length, index), audio_frames[i*2:i*2+2]))
@@ -117,7 +126,7 @@ def inference(render_event, batch_size, face_images_path, audio_feat_queue, audio_out_queue, res_frame_queue):
            count_time += (time.perf_counter() - t)
            count += batch_size
            if count >= 100:
-                logger.info(f"------actual avg infer fps:{count/count_time:.4f}")
+                logging.info(f"------actual avg infer fps:{count/count_time:.4f}")
                count = 0
                count_time = 0
@@ -125,7 +134,7 @@ def inference(render_event, batch_size, face_images_path, audio_feat_queue, audio_out_queue, res_frame_queue):
                res_frame_queue.put((res_frame, __mirror_index(length, index), audio_frames[i*2 : i*2+2]))
                index = index + 1
-    logger.info('finish inference')
+    logging.info('finish inference')

class Human:
@@ -142,9 +151,11 @@ class Human:
        self._output_queue = mp.Queue()
        self._res_frame_queue = mp.Queue(self._batch_size * 2)
-        self.face_images_path = r'.\face'
+        face_images_path = r'./face/'
+        self._face_image_paths = utils.read_files_path(face_images_path)
+        print(self._face_image_paths)
        self.render_event = mp.Event()
-        mp.Process(target=inference, args=(self.render_event, self._batch_size, self.face_images_path,
+        mp.Process(target=inference, args=(self.render_event, self._batch_size, self._face_image_paths,
                                            self._feat_queue, self._output_queue, self._res_frame_queue,
                                            )).start()
@@ -168,7 +179,7 @@ class Human:
        self._chunk_2_mal.stop()
        if self._tts is not None:
            self._tts.stop()
-        logger.info('human destroy')
+        logging.info('human destroy')

    def set_tts(self, tts):
        if self._tts == tts:
@@ -180,7 +191,7 @@ class Human:
    def read(self, txt):
        if self._tts is None:
-            logger.warning('tts is none')
+            logging.warning('tts is none')
            return

        self._tts.push_txt(txt)
@@ -193,6 +204,13 @@ class Human:
        self._feat_queue.put(mel_chunks)
        print("22")

+    def render(self):
+        try:
+            img, aud = self._res_frame_queue.get(block=True, timeout=.3)
+        except queue.Empty:
+            return None
+        return img
+
    # def pull_audio_chunk(self):
    #     try:
    #         chunk = self._audio_chunk_queue.get(block=True, timeout=1.0)
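A note on the new Human.render(): inference() enqueues three-element tuples, (res_frame, __mirror_index(length, index), audio_frames[...]), onto _res_frame_queue, but render() unpacks only two values, so the first successful get() raises ValueError. Below is a minimal sketch of a consumer that matches the producer's tuple layout; the function name and default timeout are illustrative, not part of the commit. (__mirror_index itself is not shown in these hunks; from its usage it presumably ping-pongs the index over face_list_cycle so the avatar loops smoothly.)

```python
import queue

def pop_result(res_frame_queue, timeout=0.3):
    """Fetch one result produced by inference(), or None if nothing is ready yet."""
    try:
        # inference() puts (res_frame, mirror_index, audio_frames) triples on the queue
        res_frame, idx, audio_frames = res_frame_queue.get(block=True, timeout=timeout)
    except queue.Empty:
        return None
    # res_frame is None for all-silence batches; the caller decides what to show then
    return res_frame
```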

ui.py

@@ -6,7 +6,7 @@ import tkinter
import tkinter.messagebox
import customtkinter
import requests
-from urllib.parse import urlencode
+from PIL import Image, ImageTk

from Human import Human
from tts.EdgeTTS import EdgeTTS
@@ -49,9 +49,9 @@ class App(customtkinter.CTk):
        self._init_image_canvas()

        self._human = Human()
        tts = EdgeTTS(self._human)
        self._human.set_tts(tts)
+        self._render()

    def on_destroy(self):
        logger.info('------------App destroy------------')
@@ -61,6 +61,26 @@
        self._canvas = customtkinter.CTkCanvas(self.image_frame)
        self._canvas.pack(fill=customtkinter.BOTH, expand=customtkinter.YES)

+    def _render(self):
+        image = self._human.render()
+        if image is None:
+            self.after(100, self._render)
+            return
+
+        img = Image.fromarray(image)
+        imgtk = ImageTk.PhotoImage(image=img)
+
+        self._canvas.delete("all")
+        self._canvas.imgtk = imgtk
+        width = self.winfo_width() * 0.5
+        height = self.winfo_height() * 0.5
+        self._canvas.create_image(width, height, anchor=customtkinter.CENTER, image=imgtk)
+        self._canvas.update()
+        self.after(30, self._render)
+
    def request_tts(self):
        content = self.entry.get()
        print('content:', content)
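One thing to watch in the new _render(): the frames originate from cv2.imread(), which returns BGR arrays, while Image.fromarray() treats the array as RGB, so the avatar's colours will appear swapped unless the frame is converted first. A small sketch of that conversion, assuming the frame really is an OpenCV-style BGR array (the helper name is illustrative):

```python
import cv2
from PIL import Image, ImageTk

def to_photoimage(frame_bgr):
    """Turn an OpenCV BGR frame into a Tk-compatible PhotoImage."""
    frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)  # cv2.imread() frames are BGR
    return ImageTk.PhotoImage(image=Image.fromarray(frame_rgb))
```

Also worth noting: Human.render() blocks for up to 0.3 s on an empty queue and is invoked from the Tk after() callback, i.e. on the UI thread; a non-blocking get (or a much shorter timeout) would keep the window responsive while inference is idle.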

utils.py (new file)

@@ -0,0 +1,12 @@
+#encoding = utf8
+import os
+
+
+def read_files_path(path):
+    file_paths = []
+    files = os.listdir(path)
+    for file in files:
+        if not os.path.isdir(file):
+            file_paths.append(path + file)
+    return file_paths
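Two small caveats in read_files_path(): os.path.isdir(file) tests the bare filename against the current working directory rather than against path, and path + file only yields a valid path because the caller happens to pass './face/' with a trailing slash. A slightly more defensive sketch under the same assumptions (a flat directory of face frames, deterministic ordering wanted):

```python
import os

def read_files_path(path):
    """Return sorted full paths of the files directly under `path`."""
    file_paths = []
    for name in sorted(os.listdir(path)):    # sorted: os.listdir() order is arbitrary
        full = os.path.join(path, name)      # safe with or without a trailing slash
        if not os.path.isdir(full):          # check the joined path, not the bare name
            file_paths.append(full)
    return file_paths
```

Sorting matters here because the images act as the avatar's animation cycle; relying on os.listdir() order can shuffle the frames between runs.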