From 90ccaa222b1e28a58c6409cc437e8da35c2986bc Mon Sep 17 00:00:00 2001 From: jiegeaiai Date: Thu, 12 Sep 2024 08:15:09 +0800 Subject: [PATCH] modify human load face --- Human.py | 40 +++++++++++++++++++++++++++++----------- ui.py | 24 ++++++++++++++++++++++-- utils.py | 12 ++++++++++++ 3 files changed, 63 insertions(+), 13 deletions(-) create mode 100644 utils.py diff --git a/Human.py b/Human.py index a2ecf69..d73fdcd 100644 --- a/Human.py +++ b/Human.py @@ -7,16 +7,14 @@ import time import numpy as np +import utils from models import Wav2Lip from tts.Chunk2Mal import Chunk2Mal import torch import cv2 from tqdm import tqdm -logger = logging.getLogger(__name__) - device = 'cuda' if torch.cuda.is_available() else 'cpu' -print('Using {} for inference.'.format(device)) def _load(checkpoint_path): @@ -31,6 +29,7 @@ def _load(checkpoint_path): def load_model(path): model = Wav2Lip() print("Load checkpoint from: {}".format(path)) + logging.info(f'Load checkpoint from {path}') checkpoint = _load(path) s = checkpoint["state_dict"] new_s = {} @@ -45,6 +44,7 @@ def read_images(img_list): frames = [] print('reading images...') for img_path in tqdm(img_list): + print(f'read image path:{img_path}') frame = cv2.imread(img_path) frames.append(frame) return frames @@ -63,17 +63,25 @@ def __mirror_index(size, index): # python.exe .\inference.py --checkpoint_path .\checkpoints\wav2lip.pth --face # .\face\img00016.jpg --audio .\audio\audio1.wav def inference(render_event, batch_size, face_images_path, audio_feat_queue, audio_out_queue, res_frame_queue): + logging.info(f'Using {device} for inference.') + print(f'Using {device} for inference.') + + print(f'face_images_path: {face_images_path}') + model = load_model(r'.\checkpoints\wav2lip.pth') face_list_cycle = read_images(face_images_path) face_images_length = len(face_list_cycle) - logger.info(f'face images length: {face_images_length}') + logging.info(f'face images length: {face_images_length}') + print(f'face images length: {face_images_length}') length = len(face_list_cycle) index = 0 count = 0 count_time = 0 - logger.info('start inference') + logging.info('start inference') + print(f'start inference: {render_event.is_set()}') while render_event.is_set(): + print('start inference') try: mel_batch = audio_feat_queue.get(block=True, timeout=1) except queue.Empty: @@ -88,6 +96,7 @@ def inference(render_event, batch_size, face_images_path, audio_feat_queue, audi if type == 0: is_all_silence = False + print(f'is_all_silence {is_all_silence}') if is_all_silence: for i in range(batch_size): res_frame_queue.put((None, __mirror_index(length, index), audio_frames[i*2:i*2+2])) @@ -117,7 +126,7 @@ def inference(render_event, batch_size, face_images_path, audio_feat_queue, audi count_time += (time.perf_counter() - t) count += batch_size if count >= 100: - logger.info(f"------actual avg infer fps:{count/count_time:.4f}") + logging.info(f"------actual avg infer fps:{count/count_time:.4f}") count = 0 count_time = 0 @@ -125,7 +134,7 @@ def inference(render_event, batch_size, face_images_path, audio_feat_queue, audi res_frame_queue.put((res_frame, __mirror_index(length, index), audio_frames[i*2 : i*2+2])) index = index + 1 - logger.info('finish inference') + logging.info('finish inference') class Human: @@ -142,9 +151,11 @@ class Human: self._output_queue = mp.Queue() self._res_frame_queue = mp.Queue(self._batch_size * 2) - self.face_images_path = r'.\face' + face_images_path = r'./face/' + self._face_image_paths = utils.read_files_path(face_images_path) + print(self._face_image_paths) self.render_event = mp.Event() - mp.Process(target=inference, args=(self.render_event, self._batch_size, self.face_images_path, + mp.Process(target=inference, args=(self.render_event, self._batch_size, self._face_image_paths, self._feat_queue, self._output_queue, self._res_frame_queue, )).start() @@ -168,7 +179,7 @@ class Human: self._chunk_2_mal.stop() if self._tts is not None: self._tts.stop() - logger.info('human destroy') + logging.info('human destroy') def set_tts(self, tts): if self._tts == tts: @@ -180,7 +191,7 @@ class Human: def read(self, txt): if self._tts is None: - logger.warning('tts is none') + logging.warning('tts is none') return self._tts.push_txt(txt) @@ -193,6 +204,13 @@ class Human: self._feat_queue.put(mel_chunks) print("22") + def render(self): + try: + img, aud = self._res_frame_queue.get(block=True, timeout=.3) + except queue.Empty: + return None + return img + # def pull_audio_chunk(self): # try: # chunk = self._audio_chunk_queue.get(block=True, timeout=1.0) diff --git a/ui.py b/ui.py index 1ace440..e6684ba 100644 --- a/ui.py +++ b/ui.py @@ -6,7 +6,7 @@ import tkinter import tkinter.messagebox import customtkinter import requests -from urllib.parse import urlencode +from PIL import Image, ImageTk from Human import Human from tts.EdgeTTS import EdgeTTS @@ -49,9 +49,9 @@ class App(customtkinter.CTk): self._init_image_canvas() self._human = Human() - tts = EdgeTTS(self._human) self._human.set_tts(tts) + self._render() def on_destroy(self): logger.info('------------App destroy------------') @@ -61,6 +61,26 @@ class App(customtkinter.CTk): self._canvas = customtkinter.CTkCanvas(self.image_frame) self._canvas.pack(fill=customtkinter.BOTH, expand=customtkinter.YES) + def _render(self): + image = self._human.render() + if image is None: + self.after(100, self._render) + return + + img = Image.fromarray(image) + imgtk = ImageTk.PhotoImage(image=img) + + self._canvas.delete("all") + + self._canvas.imgtk = imgtk + + width = self.winfo_width() * 0.5 + height = self.winfo_height() * 0.5 + + self._canvas.create_image(width, height, anchor=customtkinter.CENTER, image=imgtk) + self._canvas.update() + self.after(30, self._render) + def request_tts(self): content = self.entry.get() print('content:', content) diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..f5b614f --- /dev/null +++ b/utils.py @@ -0,0 +1,12 @@ +#encoding = utf8 +import os + + +def read_files_path(path): + file_paths = [] + files = os.listdir(path) + for file in files: + if not os.path.isdir(file): + file_paths.append(path + file) + return file_paths +