#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
@File : wav2lip_processor.py
@Time : 2023/10/11 18:23:49
@Author : LiWei
@Version : 1.0
"""
import os
import platform
import subprocess
import threading
import time
import traceback
from concurrent.futures import ThreadPoolExecutor
import cv2
import numpy as np
import torch
import constants
from config import config
from face_parsing.model import BiSeNet
from face_parsing.swap import cal_mask_single_img, image_to_parsing_ori
from high_perf.buffer.infer_buffer import Infer_Cache
from high_perf.buffer.video_buff import Video_Cache
from log import logger
from models.wav2lipv2 import Wav2Lip
from utils.util import (
add_alpha,
datagen,
get_actor_id,
get_model_local_path,
load_face_box,
load_seg_model,
load_w2l_model,
morphing,
play_in_loop_v2,
save_raw_video,
save_raw_wav,
get_trans_idxes,
load_config
)
USING_JIT = False
LOCAL_MODE = False
index = 1  # global frame counter, bumped in the paste_back helpers
def save_(current_speechid, result_frames, audio_segments):
    """Mux the raw video and audio of one speech segment into an MP4 via ffmpeg."""
    raw_vid = save_raw_video(current_speechid, result_frames)
    raw_wav = save_raw_wav(current_speechid, audio_segments)
    command = "ffmpeg -y -i {} -i {} -c:v copy -c:a aac -strict experimental -shortest {}".format(
        raw_wav, raw_vid, f"temp/{current_speechid}_result_final.mp4"
    )
    # On Windows, pass the command string directly; elsewhere run it through the shell.
    subprocess.call(command, shell=platform.system() != "Windows")
    os.remove(raw_wav)
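# For reference, with a speech id of "abc" the command in save_() expands to
# roughly the following (the exact temp paths come from save_raw_wav /
# save_raw_video, which are not shown in this file):
#   ffmpeg -y -i <raw_wav> -i <raw_vid> -c:v copy -c:a aac \
#       -strict experimental -shortest temp/abc_result_final.mp4
# -c:v copy passes the video stream through untouched while the WAV audio is
# re-encoded to AAC; -shortest trims the output to the shorter input stream.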
class Wav2lip_Processor:
def __init__(
self,
task_queue,
result_queue,
stop_event,
resolution,
channel,
device: str = "cuda:0",
) -> None:
self.task_queue = task_queue
self.result_queue = result_queue
self.stop_event = stop_event
self.write_lock = threading.Lock()
self.device = device
self.resolution = resolution
self.channel = channel
pool_size = config.read_frame_pool_size
self.using_pool = True
if pool_size <= 0:
self.using_pool = False
self.img_reading_pool = None
if self.using_pool:
self.img_reading_pool = ThreadPoolExecutor(max_workers=pool_size)
self.face_det_results = {}
self.w2l_models = {}
        self.seg_net = None  # face-parsing (segmentation) model
        mel_block_size = (config.wav2lip_batch_size, 80, 16)
        audio_block_size = (int(np.prod(mel_block_size)), 1)
        # The cache holds at most two seconds of video; anything beyond that is overwritten.
self.video_cache = Video_Cache(
"video_cache",
"audio_audio_cache",
"speech_id_cache",
self.resolution,
channel=self.channel,
audio_block_size=audio_block_size[0],
create=True,
)
self.data_cache = Infer_Cache(
"infer_cache",
create=True,
mel_block_shape=mel_block_size,
audio_block_shape=audio_block_size,
)
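        # Worked example of the block sizes above, assuming wav2lip_batch_size
        # is 5 (as the warm-up tensor shapes below suggest): each mel block is
        # (5, 80, 16), so the paired audio block holds 5 * 80 * 16 = 6400
        # entries laid out as a (6400, 1) column.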
        # TODO: this version field currently has no effect
        self.version = "v2"
        self.current_speechid = ""
        # torch.Size([5, 288, 288, 6]) torch.Size([5, 80, 16, 1])
self.model_mel_batch = torch.rand((5, 1, 80, 16), dtype=torch.float).to(
self.device
)
self.model_img_batch = torch.rand((5, 6, 288, 288), dtype=torch.float).to(
self.device
)
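        # Dummy inputs: no_prediction() runs these random tensors through the
        # model during silent periods, which appears to serve as a warm-up so
        # the first real inference avoids a cold start (see the note there).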
        self.infer_silent_idx = 1
        self.post_morphint_idx = 1
        self.first_start = True  # whether this is the silence right at stream start
        self.file_index = 1
    def prepare(self, model_speaker_id: str, need_update: bool = True):
        """Load the model files."""
        if not need_update:
            return
        (
            self.model_url,
            self.frame_config,
            self.padding,
            self.face_classes,
            self.trans_method,
            self.infer_silent_num,
            self.morphing_num,
            pkl_version,
        ) = load_config(model_speaker_id)
        models_url_list = [self.model_url]
        actor_info = [model_speaker_id]
        for actor_url in actor_info:
            # Resources are no longer downloaded here; that work moved to the main process.
            actor_id = get_actor_id(actor_url)
            self.face_det_results[actor_id] = load_face_box(actor_id, pkl_version)
        self.w2l_process_state = "READY"
        self.__init_seg_model(f"{config.model_folder}/79999_iter.pth")
        # Preload every wav2lip model up front.
        # TODO: use the old loading path to check whether the model weights are wrong
        for model_url in models_url_list:
            w2l_name = get_model_local_path(model_url)
            self.__init_w2l_model(w2l_name)
        logger.debug("Prepare w2l data.")
    def __init_w2l_model(self, w2l_name):
        """Load a wav2lip model by name, caching it in self.w2l_models."""
        # TODO: ideally this happens before the live room starts. The model may be
        # switched at runtime, since different avatars can use different models.
        if w2l_name in self.w2l_models:
            return
        w2l_path = f"{config.model_folder}/{w2l_name}"
        if USING_JIT:
            net = self.__jit_script(f"{config.model_folder}/w2l.pt", model=Wav2Lip())
        else:
            net = Wav2Lip()
        logger.debug(f"Loading model {w2l_name}")
        self.w2l_models[w2l_name] = load_w2l_model(w2l_path, net, self.device)
    def __jit_script(self, jit_module_path: str, model):
        """
        Convert a PyTorch module to a runnable TorchScript module.
        """
        if not os.path.exists(jit_module_path):
            scripted_module = torch.jit.script(model)
            torch.jit.save(scripted_module, jit_module_path)
        net = torch.jit.load(jit_module_path)
        return net
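        # Caching behavior of the method above: the first call scripts `model`
        # and saves it to jit_module_path; every later call loads that saved
        # artifact and ignores the freshly constructed `model` argument. Delete
        # the .pt file to force a re-script after changing the model definition.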
    def __init_seg_model(self, seg_path: str):
        """Load the BiSeNet face-parsing model."""
        # This model can be treated as never changing.
        n_classes = 19
        net = BiSeNet(n_classes=n_classes)
        if USING_JIT:
            net = self.__jit_script(f"{config.model_folder}/swap.pt", model=net)
        self.seg_net = load_seg_model(seg_path, net, self.device)
    def __using_w2lmodel(self, model_url):
        """Return the wav2lip model for model_url, loading it on first use."""
        model_path = get_model_local_path(model_url)
        if model_path not in self.w2l_models:
            self.__init_w2l_model(model_path)
        return self.w2l_models[model_path]
def infer(
self,
is_silent,
inference_id,
pkl_name,
startfrom=0,
last_direction=1,
duration=0,
end_with_silent=False,
first_speak=False,
last_speak=False,
before_speak=False,
model_url=None,
debug: bool = False,
test_data: tuple = (),
        video_idx: list = None,  # avoid the mutable-default-argument pitfall
):
""" """
start = time.perf_counter()
# TODO: 这是啥逻辑
is_silent_ = 0 if duration < 5 else 1
if not model_url:
model_url = self.model_url
model = self.__using_w2lmodel(model_url)
# 控制加载数据人的个数
self.current_speechid = inference_id
if debug:
audio_segments, mel_chunks = test_data
else:
audio_segments, mel_chunks = self.data_cache.get_data()
audio_frame_length = len(mel_chunks)
        if video_idx:
            start_idx_list, current_direction = video_idx, 1
        else:
            startfrom = max(startfrom, self.frame_config[0][0])
start_idx_list, current_direction = play_in_loop_v2(
self.frame_config,
startfrom,
audio_frame_length,
last_direction,
is_silent,
is_silent_,
first_speak,
last_speak,
)
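            # play_in_loop_v2 plays the template frames back and forth within the
            # segments described by frame_config: given the starting frame, the
            # number of mel chunks to cover, the previous playback direction, and
            # the silence flags, it returns the template-frame indices to render
            # and the direction to resume from on the next call (only the
            # direction's sign convention is visible from this file).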
logger.info(
f"start_idx_list :{start_idx_list}, speechid: {self.current_speechid}"
)
gan = datagen(
self.img_reading_pool,
mel_chunks,
start_idx_list,
pkl_name,
self.face_det_results[pkl_name],
self.padding,
is_silent,
)
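        # datagen yields batches of
        #   (img_batch, mel_batch, frames, full_masks, coords, end_idx):
        # the stacked face crops for the model, the mel-spectrogram chunks, the
        # full template frames, the body masks, the face bounding boxes, and the
        # frame indices covered by the batch (see prediction()/no_prediction()).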
self.file_index += 5
        # v3.4.0: transition blending for silent inference; the n frames after
        # speech morph gradually back into the silent template.
        need_pre_morphing = False
        result_frames = []  # ensure it's bound even on the no_prediction paths
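        # Transition state machine implemented below: the first infer_silent_num
        # frames after speech still run full inference (need_infer_silent), the
        # next morphing_num frames blend the inferred face back toward the silent
        # template (need_post_morphing), and the PRE_MORPHING_NUM frames just
        # before speech blend the template toward the inferred face
        # (need_pre_morphing). Outside these windows, silence takes the cheap
        # no_prediction() path.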
        if self.trans_method == constants.TRANS_METHOD.mophing_infer:
            if self.infer_silent_num or self.morphing_num or config.PRE_MORPHING_NUM:
                logger.info(
                    f"first_start: {self.first_start}, before_speak: {before_speak}, {self.current_speechid}"
                )
                if (
                    is_silent == 0
                    or before_speak
                    or (
                        not self.first_start
                        and is_silent == 1
                        and (
                            self.infer_silent_idx < self.infer_silent_num
                            or self.post_morphint_idx < self.morphing_num
                        )
                    )
                ):
                    # When speech starts, clear the counters left over from last time.
                    need_infer_silent = (
                        is_silent == 1
                        and self.infer_silent_idx < self.infer_silent_num
                        and not before_speak
                        and "&" not in self.current_speechid
                    )
                    need_pre_morphing = bool(before_speak)
                    need_post_morphing = bool(
                        is_silent == 1
                        and self.morphing_num
                        and self.post_morphint_idx < self.morphing_num
                        and not need_infer_silent
                        and not need_pre_morphing
                    )
if is_silent == 0 or "&" in self.current_speechid:
self.infer_silent_idx = 1
self.post_morphint_idx = 1
self.first_start = False
logger.debug(
f"{start_idx_list} is_silent:{is_silent}, infer_silent_idx: {self.infer_silent_idx} morphint_idx:{self.post_morphint_idx} need_post_morphing: {need_post_morphing}, need_infer_silent:{need_infer_silent}, speech:{self.current_speechid},need_pre_morphing:{need_pre_morphing}"
)
if is_silent == 1 and need_post_morphing and self.post_morphint_idx < self.morphing_num:
self.post_morphint_idx += 5
result_frames = self.prediction(
gan,
model,
pkl_name,
need_post_morphing=need_post_morphing,
need_infer_silent=need_infer_silent,
is_silent=is_silent,
need_pre_morphing=need_pre_morphing,
)
                    if is_silent == 1 and need_infer_silent and self.infer_silent_idx < self.infer_silent_num:
                        self.infer_silent_idx += 5
                # The very first silence goes straight into this branch.
                else:
                    self.no_prediction(gan, model, pkl_name)
            else:
                if is_silent == 0:
                    result_frames = self.prediction(
                        gan,
                        model,
                        pkl_name,
                        is_silent=is_silent,
                        need_pre_morphing=need_pre_morphing,
                    )
else:
self.no_prediction(gan, model, pkl_name)
elif self.trans_method == constants.TRANS_METHOD.all_infer:
            result_frames = self.prediction(gan, model, pkl_name, is_silent=is_silent)
        else:
            raise ValueError(f"unsupported trans_method: {self.trans_method}")
if LOCAL_MODE:
saveThread = threading.Thread(
target=save_,
args=(self.current_speechid, result_frames, audio_segments),
)
saveThread.start()
        self.video_cache.put_audio(audio_segments, self.current_speechid)
        logger.debug(
            f"Video synthesis took {time.perf_counter() - start}s, inference_id: {self.current_speechid}"
        )
        return [
            [0, 1, 2, 3, 4],
            start_idx_list[-1],
            current_direction,
        ]  # [frame indices, res_index: where the next segment starts, direction: playback order (forward/backward)]
    @torch.no_grad()
    def batch_to_tensor(self, img_batch, mel_batch, model, padding, frames, coords):
        padding = 10  # NOTE: hard-coded, overriding the padding argument
        logger.debug(
            f"Non-silent inference: {self.current_speechid} preparing input data {id(model)}"
        )
        img_batch_tensor = torch.as_tensor(img_batch, dtype=torch.float32).to(
            self.device, non_blocking=True
        )
        mel_batch_tensor = torch.as_tensor(mel_batch, dtype=torch.float32).to(
            self.device, non_blocking=True
        )
        img_batch_tensor = img_batch_tensor.permute(0, 3, 1, 2)
        mel_batch_tensor = mel_batch_tensor.permute(0, 3, 1, 2)
        logger.debug(
            f"Non-silent inference: {self.current_speechid} generating mouth shapes "
            f"{img_batch_tensor.shape} {mel_batch_tensor.shape}"
        )  # torch.Size([5, 6, 288, 288]) torch.Size([5, 1, 80, 16])
        pred_batch = model(mel_batch_tensor, img_batch_tensor) * 255.0
        logger.debug(f"Non-silent inference: {self.current_speechid} mouth generation done")
        pred_clone_batch = pred_batch.clone()
        pred_batch_cpu = pred_batch.to(torch.uint8).to("cpu", non_blocking=True)
        del pred_batch
        del mel_batch_tensor
        logger.debug(f"Non-silent inference: {self.current_speechid} generated face is ready")
        # Crop the original (pre-paste) faces; paste_back blends against these later.
        large_faces = [
            frame.copy()[:, :, :3][
                y1 - padding : y2 + padding, x1 - padding : x2 + padding
            ]
            for frame, box in zip(frames, coords)
            for y1, y2, x1, x2 in [box]  # unpack the box inside the comprehension
        ]
logger.debug(f"非静默,推理: {self.current_speechid} 即将对生成的面部进行分割")
seg_out = image_to_parsing_ori(
pred_clone_batch, self.seg_net
) # 假设这个函数在 GPU 上执行并输出 GPU tensor
del img_batch_tensor
del pred_clone_batch
torch.cuda.synchronize() # 等待异步结束
seg_out_cpu = seg_out.to(
"cpu", non_blocking=True
) # 发起 seg_out 的异步数据传输
del seg_out
logger.debug(f"非静默,推理: {self.current_speechid} 对生成的面部进行分割结束")
pred_batch_cpu = pred_batch_cpu.numpy().transpose(
0, 2, 3, 1
) # 按需计算,用到了才计算,这样能不能使得GPU往cpu做数据copy的时间被隐藏
for predict, frame, box in zip(pred_batch_cpu, frames, coords):
y1, y2, x1, x2 = box
width, height = x2 - x1, y2 - y1
frame[:, :, :3][y1:y2, x1:x2] = cv2.resize(
predict, (width, height)
) # 无需再转为unit8gpu上直接转为unit8这样传输规模小一些
torch.cuda.synchronize() # 等待异步结束
return large_faces, seg_out_cpu, frames
    def no_prediction(self, gan, model, pkl_name):
        logger.debug(
            f"No inference needed: {self.current_speechid} {self.model_img_batch.shape} {self.model_mel_batch.shape}"
        )
        with torch.no_grad():
            pred_ = model(self.model_mel_batch, self.model_img_batch)
            seg_out = image_to_parsing_ori(
                pred_, self.seg_net
            )  # unclear whether the second network also needs this warm-up
        for _, _, frames, full_masks, _, end_idx in gan:
            speech_ids = [
                f"{self.current_speechid}_{idx}" for idx in range(len(end_idx))
            ]
            offset_list = self.video_cache.occupy_video_pos(len(end_idx))
            file_idxes, _, _ = get_trans_idxes(False, False, 0, 0, self.file_index)
            logger.info(f"self.file_index:{self.file_index} file_idxes:{file_idxes}")
param_list = [
(
speech_id,
frame,
body_mask,
write_pos,
pkl_name,
self.video_cache,
file_idx
)
for speech_id, frame, body_mask, write_pos, file_idx in zip(
speech_ids,
frames,
full_masks,
offset_list,
file_idxes
)
]
if self.using_pool:
futures = [
self.img_reading_pool.submit(self.silent_paste_back, *param)
for param in param_list
]
_ = [future.result() for future in futures]
self.video_cache._inject_real_pos(len(end_idx))
# if config.debug_mode:
# for param in param_list:
# self.silent_paste_back(*param)
del pred_
del seg_out
torch.cuda.empty_cache()
    def silent_paste_back(self, speech_id, frame, body_mask, write_pos, pkl_name, video_cache, file_index):
        global index
        if frame.shape[-1] == 4:
            frame[:, :, :3] = frame[:, :, :3][:, :, ::-1]
        elif config.output_alpha:
            frame = add_alpha(frame, body_mask, config.alpha)
        logger.debug(f"self.file_index:{self.file_index}, file_index:{file_index}")
        if config.debug_mode:
            logger.info(f"No inference: {file_index} {frame.shape}")
            if not cv2.imwrite(f"temp/{pkl_name}/new_img{file_index:05d}.jpg", frame):
                logger.error(f"save {file_index} err")
        index += 1
        logger.info(f"silent frame shape:{frame.shape}")
        video_cache._put_raw_frame(frame, write_pos, speech_id)
return frame
def prediction(
self,
gan,
model,
pkl_name: str,
need_post_morphing: bool = False,
need_infer_silent: bool = False,
is_silent: int = 1,
need_pre_morphing: bool = False,
):
""" """
logger.debug(f"非静默,推理: {self.current_speechid}")
for img_batch, mel_batch, frames, full_masks, coords, end_idx in gan:
large_faces, seg_out_cpu, frames = self.batch_to_tensor(
img_batch, mel_batch, model, self.padding, frames, coords
)
offset_list = self.video_cache.occupy_video_pos(len(end_idx))
speech_ids = [
f"{self.current_speechid}_{idx}" for idx in range(len(end_idx))
]
            file_idxes, post_morphint_idxes, infer_silent_idxes = get_trans_idxes(
                need_post_morphing,
                need_infer_silent,
                self.post_morphint_idx,
                self.infer_silent_idx,
                self.file_index,
            )
            logger.info(
                f"self.file_index:{self.file_index} infer_silent_idxes:{infer_silent_idxes}, "
                f"post_morphint_idxes:{post_morphint_idxes} file_idxes:{file_idxes}"
            )
param_list = [
(
speech_id,
padded_image,
frame,
body_mask,
seg_mask,
boxes,
self.padding,
self.video_cache,
write_pos,
pkl_name,
need_infer_silent,
need_post_morphing,
is_silent,
need_pre_morphing,
pre_morphing_idx,
infer_silent_idx,
post_morphint_idx,
file_idx
)
                for speech_id, padded_image, frame, body_mask, seg_mask, boxes, write_pos, pre_morphing_idx, infer_silent_idx, post_morphint_idx, file_idx in zip(
speech_ids,
large_faces,
frames,
full_masks,
seg_out_cpu,
coords,
offset_list,
list(range(config.PRE_MORPHING_NUM)),
infer_silent_idxes,
post_morphint_idxes,
file_idxes
)
]
result_frames = []
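            # NOTE: result_frames is only populated by the commented-out debug
            # path below, so callers (e.g. the LOCAL_MODE save thread) receive
            # an empty list unless that path is re-enabled.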
            if self.using_pool:  # and not config.debug_mode:
futures = [
self.img_reading_pool.submit(self.paste_back, *param)
for param in param_list
]
_ = [future.result() for future in futures]
self.video_cache._inject_real_pos(len(end_idx))
# if config.debug_mode:
# for param in param_list:
# frame = self.paste_back(*param)
# if LOCAL_MODE:
# result_frames.append(frame)
return result_frames
def paste_back(
self,
current_speechid,
large_face,
frame,
body_mask,
mask,
boxes,
padding,
video_cache,
cache_start_offset,
pkl_name,
need_infer_silent: bool = False,
need_post_morphing: bool = False,
is_silent: bool = False,
need_pre_morphing: bool = False,
pre_morphing_idx: int = 0,
infer_silent_idx: int = 0,
post_morphint_idx: int = 0,
file_index: int = 1,
):
"""
根据模型预测结果和原始图片进行融合
Args:
"""
global index
padding = 10
y1, y2, x1, x2 = boxes
width, height = x2 - x1, y2 - y1
mask = cal_mask_single_img(
mask, use_old_mode=True, face_classes=self.face_classes
)
mask = np.repeat(mask[..., None], 3, axis=-1).astype("uint8")
mask_temp = np.zeros_like(large_face)
        # np.float was removed from NumPy; use the concrete float32 dtype instead.
        mask_out = cv2.resize(mask.astype(np.float32) * 255.0, (width, height)).astype(
            np.uint8
        )
mask_temp[padding : height + padding, padding : width + padding] = mask_out
kernel = np.ones((9, 9), np.uint8)
        mask_temp = cv2.erode(mask_temp, kernel, iterations=1)  # still a binary mask
# gaosi_kernel = int(0.1 * large_face.shape[0] // 2 * 2) + 1
# mask_temp = cv2.GaussianBlur(
# mask_temp, (gaosi_kernel, gaosi_kernel), 0, 0, cv2.BORDER_DEFAULT
# )
mask_temp = cv2.GaussianBlur(mask_temp, (15, 15), 0, 0, cv2.BORDER_DEFAULT)
mask_temp = cv2.GaussianBlur(mask_temp, (5, 5), 0, 0, cv2.BORDER_DEFAULT)
f_background = large_face.copy()
frame[:, :, :3][y1 - padding : y2 + padding, x1 - padding : x2 + padding] = (
f_background * (1 - mask_temp / 255.0)
+ frame[:, :, :3][y1 - padding : y2 + padding, x1 - padding : x2 + padding]
* (mask_temp / 255.0)
)#.astype("uint8")
        if self.trans_method == constants.TRANS_METHOD.mophing_infer:
            if is_silent == 1:
                if need_pre_morphing:
                    frame = morphing(
                        large_face,
                        frame,
                        boxes,
                        mp_ratio=1 - ((pre_morphing_idx + 1) / config.PRE_MORPHING_NUM),
                        file_index=file_index,
                    )
                    logger.debug(
                        f"file_index:{file_index}, pre morphing_idx {pre_morphing_idx}, speech_id:{current_speechid}"
                    )
                    # TODO: @txueduo handle the pre-speech transition issue
                elif need_post_morphing and post_morphint_idx:
                    mp_ratio = post_morphint_idx / self.morphing_num
                    frame = morphing(
                        large_face,
                        frame,
                        boxes,
                        mp_ratio=mp_ratio,
                        file_index=file_index,
                    )
                    logger.debug(
                        f"post_morphint_idx:{post_morphint_idx}, mp_ratio:{mp_ratio}, file_index:{file_index}, speech_id:{current_speechid}"
                    )
        if frame.shape[-1] == 4:  # and not config.output_alpha:
            frame[:, :, :3] = frame[:, :, :3][:, :, ::-1]
        if config.output_alpha and frame.shape[-1] != 4:
            frame = add_alpha(frame, body_mask, config.alpha)
        if config.debug_mode:
            logger.info(f"Inference: {file_index}")
            if not cv2.imwrite(f"temp/{pkl_name}/new_img{file_index:05d}.jpg", frame):
                logger.error(f"save {file_index} err")
video_cache._put_raw_frame(frame, cache_start_offset, current_speechid)
index += 1
return frame
    def destroy(self):
        """Release the shared caches, the thread pool, and the GPU tensors."""
self.data_cache.destroy()
self.video_cache.destroy()
self.data_cache = None
self.video_cache = None
if self.using_pool and self.img_reading_pool is not None:
self.img_reading_pool.shutdown()
self.img_reading_pool = None
del self.model_img_batch
del self.model_mel_batch
def process_wav2lip_predict(task_queue, stop_event, output_queue, resolution, channel):
    # Instantiate the inference worker
    w2l_processor = Wav2lip_Processor(
        task_queue, output_queue, stop_event, resolution, channel
    )
    need_update = True
    while True:
        if stop_event.is_set():
            logger.info("----------------------stop..")
            # Possibly buggy: nothing synchronizes this teardown with readers,
            # which may hit an exception while the caches are being destroyed.
            w2l_processor.destroy()
            break
        try:
            params = task_queue.get()
            # TODO: this should happen earlier; confirm the condition for deciding whether to update
            start = time.perf_counter()
            w2l_processor.prepare(params[2], need_update)
            need_update = False
            result = w2l_processor.infer(*params)
            logger.debug(
                f"Inference finished in {time.perf_counter() - start}s, inference_id:{params[1]}"
            )
            output_queue.put(result)
            logger.debug(
                f"Result delivered to the main process after {time.perf_counter() - start}s, inference_id:{params[1]}"
            )
        except Exception:
            logger.error(f"process_wav2lip_predict :{traceback.format_exc()}")