modify human load face

parent b0a600c7b7
commit 90ccaa222b

Human.py (40 lines changed)
@@ -7,16 +7,14 @@ import time
import numpy as np
import utils
from models import Wav2Lip
from tts.Chunk2Mal import Chunk2Mal
import torch
import cv2
from tqdm import tqdm

logger = logging.getLogger(__name__)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using {} for inference.'.format(device))


def _load(checkpoint_path):
@@ -31,6 +29,7 @@ def _load(checkpoint_path):
def load_model(path):
    model = Wav2Lip()
    print("Load checkpoint from: {}".format(path))
    logging.info(f'Load checkpoint from {path}')
    checkpoint = _load(path)
    s = checkpoint["state_dict"]
    new_s = {}
@@ -45,6 +44,7 @@ def read_images(img_list):
    frames = []
    print('reading images...')
    for img_path in tqdm(img_list):
        print(f'read image path:{img_path}')
        frame = cv2.imread(img_path)
        frames.append(frame)
    return frames
@@ -63,17 +63,25 @@ def __mirror_index(size, index):
# python.exe .\inference.py --checkpoint_path .\checkpoints\wav2lip.pth --face
# .\face\img00016.jpg --audio .\audio\audio1.wav
def inference(render_event, batch_size, face_images_path, audio_feat_queue, audio_out_queue, res_frame_queue):
    logging.info(f'Using {device} for inference.')
    print(f'Using {device} for inference.')

    print(f'face_images_path: {face_images_path}')

    model = load_model(r'.\checkpoints\wav2lip.pth')
    face_list_cycle = read_images(face_images_path)
    face_images_length = len(face_list_cycle)
    logger.info(f'face images length: {face_images_length}')
    logging.info(f'face images length: {face_images_length}')
    print(f'face images length: {face_images_length}')

    length = len(face_list_cycle)
    index = 0
    count = 0
    count_time = 0
    logger.info('start inference')
    logging.info('start inference')
    print(f'start inference: {render_event.is_set()}')
    while render_event.is_set():
        print('start inference')
        try:
            mel_batch = audio_feat_queue.get(block=True, timeout=1)
        except queue.Empty:
@@ -88,6 +96,7 @@ def inference(render_event, batch_size, face_images_path, audio_feat_queue, audio_out_queue, res_frame_queue):
            if type == 0:
                is_all_silence = False

        print(f'is_all_silence {is_all_silence}')
        if is_all_silence:
            for i in range(batch_size):
                res_frame_queue.put((None, __mirror_index(length, index), audio_frames[i*2:i*2+2]))
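The __mirror_index helper called above is only named in the hunk header, so its body is not part of this diff. A minimal sketch of a typical ping-pong index, assuming that is what the helper does (hypothetical, for orientation only):

def __mirror_index(size, index):
    # Walk forward through the face frames, then backward, so the loop point
    # does not produce a visible jump: 0, 1, ..., size-1, size-1, ..., 1, 0, ...
    turn = index // size
    res = index % size
    if turn % 2 == 0:
        return res
    return size - res - 1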
@@ -117,7 +126,7 @@ def inference(render_event, batch_size, face_images_path, audio_feat_queue, audio_out_queue, res_frame_queue):
            count_time += (time.perf_counter() - t)
            count += batch_size
            if count >= 100:
                logger.info(f"------actual avg infer fps:{count/count_time:.4f}")
                logging.info(f"------actual avg infer fps:{count/count_time:.4f}")
                count = 0
                count_time = 0
@@ -125,7 +134,7 @@ def inference(render_event, batch_size, face_images_path, audio_feat_queue, audio_out_queue, res_frame_queue):
                res_frame_queue.put((res_frame, __mirror_index(length, index), audio_frames[i*2 : i*2+2]))
                index = index + 1

    logger.info('finish inference')
    logging.info('finish inference')


class Human:
@@ -142,9 +151,11 @@ class Human:
        self._output_queue = mp.Queue()
        self._res_frame_queue = mp.Queue(self._batch_size * 2)

        self.face_images_path = r'.\face'
        face_images_path = r'./face/'
        self._face_image_paths = utils.read_files_path(face_images_path)
        print(self._face_image_paths)
        self.render_event = mp.Event()
        mp.Process(target=inference, args=(self.render_event, self._batch_size, self.face_images_path,
        mp.Process(target=inference, args=(self.render_event, self._batch_size, self._face_image_paths,
                                           self._feat_queue, self._output_queue, self._res_frame_queue,
                                           )).start()
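The new loading path hands the directory r'./face/' to utils.read_files_path and passes the resulting list of image paths to the inference process. That helper lives in utils and is not shown in this diff; a minimal sketch of what it is assumed to do (hypothetical implementation, only os required):

import os

def read_files_path(dir_path):
    # Collect the regular files in the face directory (e.g. ./face/img00016.jpg),
    # sorted by name so the frame cycle stays in a stable order.
    return sorted(
        os.path.join(dir_path, name)
        for name in os.listdir(dir_path)
        if os.path.isfile(os.path.join(dir_path, name))
    )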
@@ -168,7 +179,7 @@ class Human:
            self._chunk_2_mal.stop()
        if self._tts is not None:
            self._tts.stop()
        logger.info('human destroy')
        logging.info('human destroy')

    def set_tts(self, tts):
        if self._tts == tts:
@@ -180,7 +191,7 @@ class Human:

    def read(self, txt):
        if self._tts is None:
            logger.warning('tts is none')
            logging.warning('tts is none')
            return

        self._tts.push_txt(txt)
@@ -193,6 +204,13 @@ class Human:
        self._feat_queue.put(mel_chunks)
        print("22")

    def render(self):
        try:
            img, aud = self._res_frame_queue.get(block=True, timeout=.3)
        except queue.Empty:
            return None
        return img

    # def pull_audio_chunk(self):
    #     try:
    #         chunk = self._audio_chunk_queue.get(block=True, timeout=1.0)
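For orientation, a hedged sketch of the queue wiring Human.py sets up after this commit, read off the argument order of the mp.Process call above (not itself part of the diff):

# Human._feat_queue       --mel chunks-->                 inference(audio_feat_queue)
# Human._output_queue     --audio frames-->               inference(audio_out_queue)
# inference               --(frame, index, audio chunk)--> Human._res_frame_queue
# Human.render()          pulls frames from Human._res_frame_queue
# App._render() in ui.py  polls Human.render() and draws the frame on the canvas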
ui.py (24 lines changed)
@@ -6,7 +6,7 @@ import tkinter
import tkinter.messagebox
import customtkinter
import requests
from urllib.parse import urlencode
from PIL import Image, ImageTk

from Human import Human
from tts.EdgeTTS import EdgeTTS
@@ -49,9 +49,9 @@ class App(customtkinter.CTk):
        self._init_image_canvas()

        self._human = Human()

        tts = EdgeTTS(self._human)
        self._human.set_tts(tts)
        self._render()

    def on_destroy(self):
        logger.info('------------App destroy------------')
@@ -61,6 +61,26 @@ class App(customtkinter.CTk):
        self._canvas = customtkinter.CTkCanvas(self.image_frame)
        self._canvas.pack(fill=customtkinter.BOTH, expand=customtkinter.YES)

    def _render(self):
        image = self._human.render()
        if image is None:
            self.after(100, self._render)
            return

        img = Image.fromarray(image)
        imgtk = ImageTk.PhotoImage(image=img)

        self._canvas.delete("all")

        self._canvas.imgtk = imgtk

        width = self.winfo_width() * 0.5
        height = self.winfo_height() * 0.5

        self._canvas.create_image(width, height, anchor=customtkinter.CENTER, image=imgtk)
        self._canvas.update()
        self.after(30, self._render)

    def request_tts(self):
        content = self.entry.get()
        print('content:', content)
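A minimal, hedged usage sketch (not part of this commit) of how the UI is expected to be launched, assuming App takes no constructor arguments as the diff suggests. Note that cv2.imread returns BGR frames, so a cv2.cvtColor(image, cv2.COLOR_BGR2RGB) before Image.fromarray may be needed if the rendered colors look swapped; that conversion is an assumption, not something this diff adds.

if __name__ == '__main__':
    # Assumption: App() builds the window, starts Human, and schedules _render itself.
    app = App()
    app.mainloop()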