Compare commits

...

1 Commit

Author   SHA1         Message                         Date
jocelyn  e219702ee2   [ADD] add logic of loop frame   2025-06-10 15:20:17 +08:00
3 changed files with 1122 additions and 0 deletions

utils/log.py (Normal file, 119 lines added)

@@ -0,0 +1,119 @@
import logging
import os
import sys
import redis
from loguru import logger as logurulogger
import redis.exceptions
from app.config import config
import json
from redis.retry import Retry
from redis.backoff import ExponentialBackoff
LOG_FORMAT = (
"<level>{level: <8}</level> "
"{process.name} | " # 进程名
"{thread.name} | "
"<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> - "
"<blue>{process}</blue> "
"<cyan>{module}</cyan>.<cyan>{function}</cyan>:<cyan>{line}</cyan> - "
"<level>{message}</level>"
)
LOG_NAME = ["uvicorn", "uvicorn.access", "uvicorn.error", "flask"]
# Configure the Redis connection pool
redis_pool = redis.ConnectionPool(
host=config.LOG_REDIS_HOST, # Redis server address
port=config.LOG_REDIS_PORT, # Redis server port
db=config.LOG_REDIS_DB, # database number
password=config.LOG_REDIS_AUTH, # password
max_connections=config.max_connections, # maximum number of connections
socket_connect_timeout=config.socket_connect_timeout, # connect timeout
socket_timeout=config.socket_timeout, # read/write timeout
)
class InterceptHandler(logging.Handler):
def emit(self, record):
try:
level = logurulogger.level(record.levelname).name
except ValueError:
# loguru raises ValueError for unknown level names; fall back to the numeric level
level = record.levelno
frame, depth = logging.currentframe(), 2
while frame.f_code.co_filename == logging.__file__:
frame = frame.f_back
depth += 1
logurulogger.opt(depth=depth, exception=record.exc_info).log(
level, record.getMessage()
)
class Logging:
"""自定义日志"""
def __init__(self):
self.log_path = "logs"
self._connect_redis()
if config.IS_LOCAL:
os.makedirs(self.log_path, exist_ok=True)
self._initlogger()
self._reset_log_handler()
def _connect_redis(self):
retry = Retry(ExponentialBackoff(), 3) # retry up to 3 times with exponential backoff
self.redis_client = redis.Redis(connection_pool=redis_pool, retry=retry) # use the connection pool
def _initlogger(self):
"""初始化loguru配置"""
logurulogger.remove()
if config.IS_LOCAL:
logurulogger.add(
os.path.join(self.log_path, "error.log.{time:YYYY-MM-DD}"),
format=LOG_FORMAT,
level=logging.ERROR,
rotation="00:00",
retention="1 week",
backtrace=True,
diagnose=True,
enqueue=True
)
logurulogger.add(
os.path.join(self.log_path, "info.log.{time:YYYY-MM-DD}"),
format=LOG_FORMAT,
level=logging.INFO,
rotation="00:00",
retention="1 week",
enqueue=True
)
logurulogger.add(
sys.stdout,
format=LOG_FORMAT,
level=logging.DEBUG,
colorize=True,
)
logurulogger.add(self._log_to_redis, level="INFO", format=LOG_FORMAT)
self.logger = logurulogger
def _log_to_redis(self, message):
"""将日志写入 Redis 列表"""
try:
self.redis_client.rpush(f"nlp.logger.{config.env_version}.log", json.dumps({"message": message}))
except redis.exceptions.ConnectionError as e:
logger.error(f"write {message} Redis connection error: {e}")
except redis.exceptions.TimeoutError as e:
logger.error(f"write {message} Redis operation timed out: {e}")
except Exception as e:
logger.error(f"write {message} Unexpected error: {e}")
def _reset_log_handler(self):
for log in LOG_NAME:
logger = logging.getLogger(log)
logger.handlers = [InterceptHandler()]
def getlogger(self):
return self.logger
logger = Logging().getlogger()
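A minimal usage sketch (not part of the commit): downstream modules import the shared logger instance directly, as utils/loop_frame_tool.py and utils/wav2lip_processor.py do below.

from utils.log import logger

logger.info("service started")        # stdout sink, Redis sink, and the info file sink when config.IS_LOCAL is set
logger.error("something went wrong")  # additionally captured by the error file sink when config.IS_LOCAL is set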

utils/loop_frame_tool.py (Normal file, 334 lines added)

@@ -0,0 +1,334 @@
from utils.log import logger
from app.config import config  # config.batch_size is referenced below; import path assumed to mirror utils/log.py
def play_in_loop_v2(
segments,
startfrom,
batch_num,
last_direction,
is_silent,
first_speak,
last_speak,
):
"""
batch_num: these rules are applied to every frame of the batch, from first to last.
1. When silent, loop inside the silent segment: forward at the left boundary,
backward at the right boundary; derive the new direction and position from the
previous direction and position.
2. Silent -> speaking: move to the nearest speaking segment (pre_flag and
post_flag both True vs. only one of them True).
3. Speaking -> silent: finish playing the action segment before entering the
silent segment (if speech is still going on, do not loop in the silent segment).
4. At the left end of the whole video only forward playback is allowed at the
start; loop when silent, apply rule 2 when speaking.
5. At the right end of the whole video only backward playback is allowed at the
start; loop when silent, apply rule 2 when speaking.
6. Fetch batch_num frame indexes along the chosen direction and
return batch_idxes, current_direction.
Args:
segments: loop-frame config, e.g. [[st, ed, True], ...]
startfrom: cur_pos, the frame to start from
batch_num: e.g. 5
last_direction: 0 backward, 1 forward
is_silent: 0 speaking, 1 action/silent
first_speak: whether this is the first speaking batch
last_speak: whether this is the moment speech ends
"""
frames = []
cur_pos = startfrom
cur_direction = last_direction
is_first_speak_frame = first_speak
is_last_speak_frame = True if last_speak and batch_num == 1 else False
while batch_num != 0:
# find the sub-segment that contains the current frame
sub_seg_idx = subseg_judge(cur_pos, segments)
# decide the movement direction and position
next_direction, next_pos = get_next_direction(
segments,
cur_pos,
cur_direction,
is_silent,
sub_seg_idx,
is_first_speak_frame,
is_last_speak_frame,
)
# step one frame in the chosen direction
next_pos = get_next_frame(next_pos, next_direction)
frames.append(next_pos)
batch_num -= 1
is_first_speak_frame = (
True if first_speak and batch_num == config.batch_size else False
)
is_last_speak_frame = True if last_speak and batch_num == 1 else False
cur_direction = next_direction
cur_pos = next_pos
return frames, next_direction
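# Worked example (illustrative values, not production config): with
# segments = [[0, 99, True], [100, 149, False]], a silent batch starting at frame 0
# with forward direction steps 0 -> 1 -> 2 -> 3, so
# play_in_loop_v2(segments, 0, 3, 1, 1, False, False) returns ([1, 2, 3], 1).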
def subseg_judge(cur_pos, segments):
for idx, frame_seg in enumerate(segments):
if cur_pos >= frame_seg[0] and cur_pos <= frame_seg[1]:
return idx
if cur_pos == 0:
return 0
def get_next_direction(
segments,
cur_pos,
cur_direction,
is_silent,
sub_seg_idx,
is_first_speak_frame: bool = False,
is_last_speak_frame: bool = False,
):
"""
3.3.0 loop-frame requirement: reach the expected state as quickly as possible.
if action segment:
if speech just started:
if at a boundary:
if forward:
pass
else:
pass
else:
if forward:
pass
else:
pass
elif silent:
same as above
elif still speaking:
same as above
elif speech just ended:
same as above
elif silent segment:
same as above
Args:
is_first_speak_frame: flag for the start of speech
is_last_speak_frame: flag for the end of speech
"""
left, right, loop_flag = segments[sub_seg_idx]
if loop_flag:
if is_silent == 1:
next_direct, next_pos = pure_silent(
segments, left, right, cur_pos, cur_direction, sub_seg_idx
)
logger.debug(
f"cur_pos{cur_pos}, next_direct:{next_direct}, is_first_speak_frame:{is_first_speak_frame}"
)
elif is_silent == 0:
next_direct, next_pos = silent2action(
segments,
left,
right,
cur_pos,
cur_direction,
sub_seg_idx,
is_first_speak_frame,
)
logger.debug(
f"cur_pos:{cur_pos}, next_direct:{next_direct}, is_first_speak_frame{is_first_speak_frame}"
)
else:
if is_silent == 1:
next_direct, next_pos = action2silent(
segments,
left,
right,
cur_pos,
cur_direction,
sub_seg_idx,
is_last_speak_frame,
)
logger.debug(
f"cur_pos{cur_pos} next_direct:{next_direct},is_first_speak_frame{is_first_speak_frame},is_last_speak_frame:{is_last_speak_frame}"
)
elif is_silent == 0:
next_direct, next_pos = pure_action(
segments,
left,
right,
cur_pos,
cur_direction,
sub_seg_idx,
is_last_speak_frame,
)
logger.debug(
f"cur_pos{cur_pos}, next_direct:{next_direct},is_first_speak_frame{is_first_speak_frame}, is_last_speak_frame:{is_last_speak_frame}"
)
return next_direct, next_pos
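# Dispatch summary of the branches above:
#   loop_flag=True,  is_silent=1 -> pure_silent    (keep looping inside the silent segment)
#   loop_flag=True,  is_silent=0 -> silent2action  (speaking while inside a silent segment)
#   loop_flag=False, is_silent=1 -> action2silent  (speech ended, leave the action segment)
#   loop_flag=False, is_silent=0 -> pure_action    (keep playing the action segment)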
def get_next_frame(cur_pos, cur_direction):
"""根据当前帧和方向,获取下一帧,这里应该保证方向上的帧是一定能取到的
不需要再做额外的边界判断
"""
# forward
if cur_direction == 1:
return cur_pos + 1
# backward
elif cur_direction == 0:
return cur_pos - 1
def pure_silent(segments, left, right, cur_pos, cur_direction, sub_seg_idx):
"""
loop_flag == True and is_silent == 1
Decide by: is the position at a boundary, and was the last direction forward?
Return:
(next_direction, next_pos)
"""
# forward at the global left end, backward at the global right end
if cur_pos == segments[0][0]:
return 1, cur_pos
if cur_pos == segments[-1][1]:
return 0, cur_pos
# at the right boundary of this segment: go backward
if cur_pos == right:
return 0, cur_pos
# at the left boundary of this segment: go forward
if cur_pos == left:
return 1, cur_pos
# not at a boundary: keep going forward if we were, otherwise go backward
if cur_pos > left and cur_direction == 1:
return 1, cur_pos
else:
return 0, cur_pos
def pure_action(
segments, left, right, cur_pos, cur_direction, sub_seg_idx, is_last_speak_frame
):
"""
loop_flag == False and is_silent == 0
After the action segment finishes, keep playing forward into a silent segment (may jump across segments).
Decide by: is the position at a boundary, and is playback forward or backward?
Args:
is_last_speak_frame: the moment speech has just ended
Return: (next_direction, next_pos)
"""
if cur_pos == segments[0][0]:
return 1, cur_pos
if cur_pos == segments[-1][1]:
return 0, cur_pos
if is_last_speak_frame:
# the action segment is at the very end: go backward to find a silent segment
if sub_seg_idx == len(segments) - 1:
return 0, cur_pos
# the action segment is at the very start: go forward
if sub_seg_idx == 0:
return 1, cur_pos
# the action segment is in the middle: head toward the nearer side
mid = left + (right - left + 1) // 2
# the nearest-side rule takes priority
if cur_pos < mid:
return 0, cur_pos
else:
return 1, cur_pos
else:
# otherwise keep the previous playback direction
if cur_direction == 1:
return 1, cur_pos
else:
return 0, cur_pos
def silent2action(
segments,
left,
right,
cur_pos,
cur_direction,
sub_seg_idx,
is_first_speak_frame: bool = False,
):
"""
Inside a silent (loopable) segment while speech is ongoing.
loop_flag == True and is_silent == 0
Decide by: is the position at a boundary, and was the last direction forward?
Return: (next_direction, next_pos)
"""
# move toward the nearest action segment; also handles the case where there is none on the left
# TODO: confirm the logic below is correct
if (
cur_pos == segments[0][0]
): # if a jump has happened, keep moving forward regardless of whether the new segment is an action segment
return 1, cur_pos
if cur_pos == segments[-1][1]:
return 0, cur_pos
# at the left boundary of the silent segment, still speaking
if cur_pos == left:
if cur_direction == 1:
return 1, cur_pos
else:
return 0, cur_pos
# at the right boundary of the silent segment, still speaking
elif cur_pos == right:
if cur_direction == 1:
return 1, cur_pos
else:
return 0, cur_pos
else:
mid = left + (right - left + 1) // 2
# !! the nearest-side rule only applies to the first speaking batch; otherwise follow the previous state
if is_first_speak_frame:
# if this is the first segment
if sub_seg_idx == 0 and segments[0][2]:
return 1, cur_pos
# if this is the last segment
elif sub_seg_idx == len(segments) - 1 and segments[-1][2]:
return 0, cur_pos
if cur_pos < mid:
return 0, cur_pos
else:
return 1, cur_pos
else:
if cur_direction == 1:
return 1, cur_pos
elif cur_direction == 0:
return 0, cur_pos
def action2silent(
segments, left, right, cur_pos, cur_direction, sub_seg_idx, is_last_speak_frame
):
"""
loop_flag == False and is_silent == 1
Decide by: is the position at a boundary?
Return: (next_direction, next_pos)
"""
if cur_pos == segments[0][0]:
return 1, cur_pos
if cur_pos == segments[-1][1]:
return 0, cur_pos
# in an action segment and speech has just ended: take the nearer side to enter the silent segment
if is_last_speak_frame:
mid = left + (right - left + 1) // 2
if cur_pos < mid:
return 0, cur_pos
else:
return 1, cur_pos
else:
if cur_direction == 1:
return 1, cur_pos
else:
return 0, cur_pos
if __name__ == "__main__":
# Minimal illustrative driver; the values below are placeholders, not production config.
startfrom = 0 # last frame of the previous batch
frame_config = [[0, 99, True], [100, 149, False]] # example loop-frame segments
audio_frame_length = 5 # stands in for len(mel_chunks); TODO: confirm whether this should be batch_size
startfrom = startfrom if startfrom >= frame_config[0][0] else frame_config[0][0]
first_speak, last_speak = True, False
is_silent = 1 # whether the current batch is silent (1 silent, 0 speaking)
last_direction = 1 # 0 means backward
start_idx_list, current_direction = play_in_loop_v2(
frame_config,
startfrom,
audio_frame_length,
last_direction,
is_silent,
first_speak,
last_speak,
)
logger.info(f"start_idx_list={start_idx_list}, current_direction={current_direction}")
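# With the placeholder config above, the silent batch walks forward from frame 0,
# so the call is expected to return start_idx_list=[1, 2, 3, 4, 5] and current_direction=1.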

utils/wav2lip_processor.py (Normal file, 669 lines added)

@@ -0,0 +1,669 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
@File : self.py
@Time : 2023/10/11 18:23:49
@Author : LiWei
@Version : 1.0
"""
import os
import threading
import time
import traceback
from concurrent.futures import ThreadPoolExecutor
import cv2
import numpy as np
import torch
import constants
from config import config
from face_parsing.model import BiSeNet
from face_parsing.swap import cal_mask_single_img, image_to_parsing_ori
from high_perf.buffer.infer_buffer import Infer_Cache
from high_perf.buffer.video_buff import Video_Cache
from log import logger
from models.wav2lipv2 import Wav2Lip
from utils.util import (
add_alpha,
datagen,
get_actor_id,
get_model_local_path,
load_face_box,
load_seg_model,
load_w2l_model,
morphing,
play_in_loop_v2,
save_raw_video,
save_raw_wav,
get_trans_idxes,
load_config
)
USING_JIT = False
LOCAL_MODE = False
index = 1
def save_(current_speechid, result_frames, audio_segments):
raw_vid = save_raw_video(current_speechid, result_frames)
raw_wav = save_raw_wav(current_speechid, audio_segments)
command = "ffmpeg -y -i {} -i {} -c:v copy -c:a aac -strict experimental -shortest {}".format(
raw_wav, raw_vid, f"temp/{current_speechid}_result_final.mp4"
)
import platform
import subprocess
subprocess.call(command, shell=platform.system() != "Windows")
os.remove(raw_wav)
class Wav2lip_Processor:
def __init__(
self,
task_queue,
result_queue,
stop_event,
resolution,
channel,
device: str = "cuda:0",
) -> None:
self.task_queue = task_queue
self.result_queue = result_queue
self.stop_event = stop_event
self.write_lock = threading.Lock()
self.device = device
self.resolution = resolution
self.channel = channel
pool_size = config.read_frame_pool_size
self.using_pool = True
if pool_size <= 0:
self.using_pool = False
self.img_reading_pool = None
if self.using_pool:
self.img_reading_pool = ThreadPoolExecutor(max_workers=pool_size)
self.face_det_results = {}
self.w2l_models = {}
self.seg_net = None # face-parsing (segmentation) model
mel_block_size = (config.wav2lip_batch_size, 80, 16)
audio_block_size = (np.prod(mel_block_size).tolist(), 1)
# the cache holds at most two seconds of video; anything beyond that is overwritten
self.video_cache = Video_Cache(
"video_cache",
"audio_audio_cache",
"speech_id_cache",
self.resolution,
channel=self.channel,
audio_block_size=audio_block_size[0],
create=True,
)
self.data_cache = Infer_Cache(
"infer_cache",
create=True,
mel_block_shape=mel_block_size,
audio_block_shape=audio_block_size,
)
# TODO: this version field currently has no effect
self.version = "v2"
self.current_speechid = ""
# expected tensor shapes: torch.Size([5, 288, 288, 6]) torch.Size([5, 80, 16, 1])
self.model_mel_batch = torch.rand((5, 1, 80, 16), dtype=torch.float).to(
self.device
)
self.model_img_batch = torch.rand((5, 6, 288, 288), dtype=torch.float).to(
self.device
)
self.infer_silent_idx = 1
self.post_morphint_idx = 1
self.first_start = True # whether this is the silent period right at stream start
self.file_index = 1
def prepare(self, model_speaker_id: str, need_update: bool = True):
"""加载好模型文件"""
if not need_update:
return
self.model_url, self.frame_config, self.padding, self.face_classes, self.trans_method, self.infer_silent_num, self.morphing_num, pkl_version = load_config(model_speaker_id)
models_url_list = [self.model_url]
actor_info = [model_speaker_id]
for actor_url in actor_info:
# resources are no longer downloaded here; that work has moved to the main process
actor_id = get_actor_id(actor_url)
self.face_det_results[actor_id] = load_face_box(actor_id, pkl_version)
self.w2l_process_state = "READY"
self.__init_seg_model(f"{config.model_folder}/79999_iter.pth")
# preload all wav2lip models
# TODO: use the old loading path to check whether the model is broken
for model_url in models_url_list:
w2l_name = get_model_local_path(model_url)
self.__init_w2l_model(w2l_name)
logger.debug("Prepare w2l data.")
def __init_w2l_model(self, w2l_name):
""" """
# TODO 这个最好是可以放在启动直播间之前----本模型运行时可能会切换,目前咱们不同的模特可能模型不同
if w2l_name in self.w2l_models:
return
w2l_path = f"{config.model_folder}/{w2l_name}"
if USING_JIT:
net = self.__jit_script(f"{config.model_folder}/w2l.pt", model=Wav2Lip())
else:
net = Wav2Lip()
logger.debug(f"加载模型 {w2l_name}")
self.w2l_models[w2l_name] = load_w2l_model(w2l_path, net, self.device)
def __jit_script(self, jit_module_path: str, model):
"""
convert pytorch module to a runnable scripts.
"""
if not os.path.exists(jit_module_path):
scripted_module = torch.jit.script(model)
torch.jit.save(scripted_module, jit_module_path)
net = torch.jit.load(jit_module_path)
return net
def __init_seg_model(self, seg_path: str):
""" """
# 本模型可以视作永不改变
n_classes = 19
net = BiSeNet(n_classes=n_classes)
if USING_JIT:
net = self.__jit_script(f"{config.model_folder}/swap.pt", model=net)
self.seg_net = load_seg_model(seg_path, net, self.device)
def __using_w2lmodel(self, model_url):
""" """
# 取出对应的模型
model_path = get_model_local_path(model_url)
if model_path in self.w2l_models:
return self.w2l_models[model_path]
else:
self.__init_w2l_model(model_path)
return self.w2l_models[model_path]
def infer(
self,
is_silent,
inference_id,
pkl_name,
startfrom=0,
last_direction=1,
duration=0,
end_with_silent=False,
first_speak=False,
last_speak=False,
before_speak=False,
model_url=None,
debug: bool = False,
test_data: tuple = (),
video_idx: list = [],
):
""" """
start = time.perf_counter()
# TODO: what is this logic supposed to do?
is_silent_ = 0 if duration < 5 else 1
if not model_url:
model_url = self.model_url
model = self.__using_w2lmodel(model_url)
# controls how many avatars are loaded
self.current_speechid = inference_id
if debug:
audio_segments, mel_chunks = test_data
else:
audio_segments, mel_chunks = self.data_cache.get_data()
audio_frame_length = len(mel_chunks)
if video_idx:
mel_chunks, start_idx_list, current_direction = mel_chunks, video_idx, 1
else:
startfrom = startfrom if startfrom>= self.frame_config[0][0] else self.frame_config[0][0]
start_idx_list, current_direction = play_in_loop_v2(
self.frame_config,
startfrom,
audio_frame_length,
last_direction,
is_silent,
is_silent_,
first_speak,
last_speak,
)
logger.info(
f"start_idx_list :{start_idx_list}, speechid: {self.current_speechid}"
)
gan = datagen(
self.img_reading_pool,
mel_chunks,
start_idx_list,
pkl_name,
self.face_det_results[pkl_name],
self.padding,
is_silent,
)
self.file_index += 5
# 3.4.0 adds transition blending for silent inference: the n frames after speech blend toward silence
need_pre_morphing = False
if self.trans_method == constants.TRANS_METHOD.mophing_infer:
if self.infer_silent_num or self.morphing_num or config.PRE_MORPHING_NUM:
logger.info(f"self.first_start{self.first_start}, {before_speak}, {self.current_speechid}")
if (
is_silent == 0
or (
before_speak)
or (
not self.first_start
and is_silent == 1
and ((self.infer_silent_idx < self.infer_silent_num) or (self.post_morphint_idx < self.morphing_num))
)
):
# when speech starts, reset the counters from the previous round
need_infer_silent = (
True
if is_silent == 1
and self.infer_silent_idx < self.infer_silent_num and not before_speak and "&" not in self.current_speechid
else False
)
need_pre_morphing = True if before_speak else False
need_post_morphing = (
True
if is_silent == 1
and self.morphing_num
and self.post_morphint_idx < self.morphing_num
and not need_infer_silent
and not need_pre_morphing
else False
)
if is_silent == 0 or "&" in self.current_speechid:
self.infer_silent_idx = 1
self.post_morphint_idx = 1
self.first_start = False
logger.debug(
f"{start_idx_list} is_silent:{is_silent}, infer_silent_idx: {self.infer_silent_idx} morphint_idx:{self.post_morphint_idx} need_post_morphing: {need_post_morphing}, need_infer_silent:{need_infer_silent}, speech:{self.current_speechid},need_pre_morphing:{need_pre_morphing}"
)
if is_silent == 1 and need_post_morphing and self.post_morphint_idx < self.morphing_num:
self.post_morphint_idx += 5
result_frames = self.prediction(
gan,
model,
pkl_name,
need_post_morphing=need_post_morphing,
need_infer_silent=need_infer_silent,
is_silent=is_silent,
need_pre_morphing=need_pre_morphing,
)
if is_silent == 1 and need_infer_silent and self.infer_silent_idx < self.infer_silent_num:
self.infer_silent_idx += 5
# on the very first silent period we go straight into this branch
else:
self.no_prediction(gan, model, pkl_name)
else:
if is_silent==0:
self.prediction(
gan,
model,
pkl_name,
is_silent=is_silent,
need_pre_morphing=need_pre_morphing,
)
else:
self.no_prediction(gan, model, pkl_name)
elif self.trans_method == constants.TRANS_METHOD.all_infer:
result_frames = self.prediction(gan, model, pkl_name,is_silent=is_silent)
else:
raise ValueError(f"not supported {self.trans_method}")
if LOCAL_MODE:
saveThread = threading.Thread(
target=save_,
args=(self.current_speechid, result_frames, audio_segments),
)
saveThread.start()
self.video_cache.put_audio(audio_segments, self.current_speechid)
logger.debug(
f"视频合成时间{time.perf_counter() - start},inference_id:{self.current_speechid}"
)
return [
[0, 1, 2, 3, 4],
start_idx_list[-1],
current_direction,
] # frames_return_list: video frame data; res_index: start position for the next frame; direction: playback order (forward / backward)
@torch.no_grad()
def batch_to_tensor(self, img_batch, mel_batch, model, padding, frames, coords):
padding = 10
logger.debug(f"非静默,推理: {self.current_speechid} 准备推理数据 {id(model)}")
img_batch_tensor = torch.as_tensor(img_batch, dtype=torch.float32).to(
self.device, non_blocking=True
)
mel_batch_tensor = torch.as_tensor(mel_batch, dtype=torch.float32).to(
self.device, non_blocking=True
)
img_batch_tensor = img_batch_tensor.permute(0, 3, 1, 2)
mel_batch_tensor = mel_batch_tensor.permute(0, 3, 1, 2)
logger.debug(
f"非静默,推理: {self.current_speechid} 即将嘴型生成 {img_batch_tensor.shape} {mel_batch_tensor.shape}"
) # torch.Size([5, 6, 288, 288]) torch.Size([5, 1, 80, 16])
pred_batch = model(mel_batch_tensor, img_batch_tensor) * 255.0
logger.debug(f"非静默,推理: {self.current_speechid} 嘴型生成完成")
pred_clone_batch = pred_batch.clone()
pred_batch_cpu = pred_batch.to(torch.uint8).to("cpu", non_blocking=True)
del pred_batch
del mel_batch_tensor
logger.debug(f"非静默,推理: {self.current_speechid} 生成面部准备好了")
large_faces = [
frame.copy()[:, :, :3][
y1 - padding : y2 + padding, x1 - padding : x2 + padding
]
for frame, box in zip(frames, coords)
for y1, y2, x1, x2 in [box]
]
logger.debug(f"非静默,推理: {self.current_speechid} 即将对生成的面部进行分割")
seg_out = image_to_parsing_ori(
pred_clone_batch, self.seg_net
) # assumed to run on the GPU and return a GPU tensor
del img_batch_tensor
del pred_clone_batch
torch.cuda.synchronize() # wait for async work to finish
seg_out_cpu = seg_out.to(
"cpu", non_blocking=True
) # start the async transfer of seg_out to the CPU
del seg_out
logger.debug(f"非静默,推理: {self.current_speechid} 对生成的面部进行分割结束")
pred_batch_cpu = pred_batch_cpu.numpy().transpose(
0, 2, 3, 1
) # computed lazily, only when needed, hoping to hide the GPU-to-CPU copy time
for predict, frame, box in zip(pred_batch_cpu, frames, coords):
y1, y2, x1, x2 = box
width, height = x2 - x1, y2 - y1
frame[:, :, :3][y1:y2, x1:x2] = cv2.resize(
predict, (width, height)
) # no extra uint8 conversion needed; converting to uint8 on the GPU keeps the transfer smaller
torch.cuda.synchronize() # wait for async work to finish
return large_faces, seg_out_cpu, frames
def no_prediction(self, gan, model, pkl_name):
logger.debug(
f"无需推理: {self.current_speechid} {self.model_img_batch.shape} {self.model_mel_batch.shape}"
)
with torch.no_grad():
pred_ = model(self.model_mel_batch, self.model_img_batch)
seg_out = image_to_parsing_ori(
pred_, self.seg_net
) # not sure whether the second network also needs a warmup
for _, _, frames, full_masks, _, end_idx in gan:
speech_ids = [
f"{self.current_speechid}_{idx}" for idx in range(len(end_idx))
]
offset_list = self.video_cache.occupy_video_pos(len(end_idx))
file_idxes, _, _ = get_trans_idxes(False, False, 0,0, self.file_index)
logger.info(f"self.file_index:{self.file_index} file_idxes:{file_idxes}")
param_list = [
(
speech_id,
frame,
body_mask,
write_pos,
pkl_name,
self.video_cache,
file_idx
)
for speech_id, frame, body_mask, write_pos, file_idx in zip(
speech_ids,
frames,
full_masks,
offset_list,
file_idxes
)
]
if self.using_pool:
futures = [
self.img_reading_pool.submit(self.silent_paste_back, *param)
for param in param_list
]
_ = [future.result() for future in futures]
self.video_cache._inject_real_pos(len(end_idx))
# if config.debug_mode:
# for param in param_list:
# self.silent_paste_back(*param)
del pred_
del seg_out
torch.cuda.empty_cache()
def silent_paste_back(self, speech_id, frame, body_mask,write_pos,pkl_name, video_cache, file_index):
global index
if frame.shape[-1] == 4:
frame[:, :, :3] = frame[:, :, :3][:, :, ::-1]
elif config.output_alpha:
frame = add_alpha(frame, body_mask, config.alpha)
logger.debug(f"self.file_index:{self.file_index}, file_index{file_index}")
if config.debug_mode:
logger.info(f"不用推理: {file_index} {frame.shape}")
if not cv2.imwrite(f"temp/{pkl_name}/new_img{file_index:05d}.jpg", frame):
logger.error(f"save {file_index} err")
index += 1
logger.info(f"silent frame shape:{frame.shape}")
video_cache._put_raw_frame(frame, write_pos, speech_id)
return frame
def prediction(
self,
gan,
model,
pkl_name: str,
need_post_morphing: bool = False,
need_infer_silent: bool = False,
is_silent: int = 1,
need_pre_morphing: bool = False,
):
""" """
logger.debug(f"非静默,推理: {self.current_speechid}")
for img_batch, mel_batch, frames, full_masks, coords, end_idx in gan:
large_faces, seg_out_cpu, frames = self.batch_to_tensor(
img_batch, mel_batch, model, self.padding, frames, coords
)
offset_list = self.video_cache.occupy_video_pos(len(end_idx))
speech_ids = [
f"{self.current_speechid}_{idx}" for idx in range(len(end_idx))
]
file_idxes, post_morphint_idxes, infer_silent_idxes = get_trans_idxes(need_post_morphing, need_infer_silent,self.post_morphint_idx,self.infer_silent_idx, self.file_index)
logger.info(f"self.file_index:{self.file_index}infer_silent_idxes:{infer_silent_idxes},post_morphint_idxes:{post_morphint_idxes} file_idxes:{file_idxes}")
param_list = [
(
speech_id,
padded_image,
frame,
body_mask,
seg_mask,
boxes,
self.padding,
self.video_cache,
write_pos,
pkl_name,
need_infer_silent,
need_post_morphing,
is_silent,
need_pre_morphing,
pre_morphing_idx,
infer_silent_idx,
post_morphint_idx,
file_idx
)
for speech_id, padded_image, frame, body_mask, seg_mask, boxes, write_pos, pre_morphing_idx,infer_silent_idx, post_morphint_idx,file_idx in zip(
speech_ids,
large_faces,
frames,
full_masks,
seg_out_cpu,
coords,
offset_list,
list(range(config.PRE_MORPHING_NUM)),
infer_silent_idxes,
post_morphint_idxes,
file_idxes
)
]
result_frames = []
if self.using_pool: # and not config.debug_mode:
futures = [
self.img_reading_pool.submit(self.paste_back, *param)
for param in param_list
]
_ = [future.result() for future in futures]
self.video_cache._inject_real_pos(len(end_idx))
# if config.debug_mode:
# for param in param_list:
# frame = self.paste_back(*param)
# if LOCAL_MODE:
# result_frames.append(frame)
return result_frames
def paste_back(
self,
current_speechid,
large_face,
frame,
body_mask,
mask,
boxes,
padding,
video_cache,
cache_start_offset,
pkl_name,
need_infer_silent: bool = False,
need_post_morphing: bool = False,
is_silent: bool = False,
need_pre_morphing: bool = False,
pre_morphing_idx: int = 0,
infer_silent_idx: int = 0,
post_morphint_idx: int = 0,
file_index: int = 1,
):
"""
Blend the model prediction back into the original full frame.
"""
global index
padding = 10
y1, y2, x1, x2 = boxes
width, height = x2 - x1, y2 - y1
mask = cal_mask_single_img(
mask, use_old_mode=True, face_classes=self.face_classes
)
mask = np.repeat(mask[..., None], 3, axis=-1).astype("uint8")
mask_temp = np.zeros_like(large_face)
mask_out = cv2.resize(mask.astype(np.float32) * 255.0, (width, height)).astype(
np.uint8
)
mask_temp[padding : height + padding, padding : width + padding] = mask_out
kernel = np.ones((9, 9), np.uint8)
mask_temp = cv2.erode(mask_temp, kernel, iterations=1) # binary mask
# gaosi_kernel = int(0.1 * large_face.shape[0] // 2 * 2) + 1
# mask_temp = cv2.GaussianBlur(
# mask_temp, (gaosi_kernel, gaosi_kernel), 0, 0, cv2.BORDER_DEFAULT
# )
mask_temp = cv2.GaussianBlur(mask_temp, (15, 15), 0, 0, cv2.BORDER_DEFAULT)
mask_temp = cv2.GaussianBlur(mask_temp, (5, 5), 0, 0, cv2.BORDER_DEFAULT)
f_background = large_face.copy()
frame[:, :, :3][y1 - padding : y2 + padding, x1 - padding : x2 + padding] = (
f_background * (1 - mask_temp / 255.0)
+ frame[:, :, :3][y1 - padding : y2 + padding, x1 - padding : x2 + padding]
* (mask_temp / 255.0)
) # .astype("uint8")
if self.trans_method == constants.TRANS_METHOD.mophing_infer:
if is_silent == 1:
if need_pre_morphing:
frame = morphing(
large_face,
frame,
boxes,
mp_ratio=1 - ((pre_morphing_idx + 1) / config.PRE_MORPHING_NUM),
file_index=file_index,
)
logger.debug(
f"file_index{file_index},pre morphing_idx {pre_morphing_idx}, speech_id:{current_speechid}"
)
logger.debug(f"pre morphing_idx {pre_morphing_idx}, speech_id:{current_speechid}")
# TODO: @txueduo handle the pre-transition blending issue
elif need_post_morphing and post_morphint_idx:
mp_ratio = (post_morphint_idx) / self.morphing_num
frame = morphing(
large_face,
frame,
boxes,
mp_ratio=mp_ratio,
file_index=file_index
)
logger.debug(f"post_morphint_idx:{post_morphint_idx}, mp_ratio:{mp_ratio}, file_index:{file_index}, speech_id:{current_speechid}")
if frame.shape[-1] == 4: # and not config.output_alpha:
frame[:, :, :3] = frame[:, :, :3][:, :, ::-1]
if config.output_alpha and frame.shape[-1] != 4:
frame = add_alpha(frame, body_mask, config.alpha)
if config.debug_mode:
logger.info(f"推理:{file_index}")
if not cv2.imwrite(f"temp/{pkl_name}/new_img{file_index:05d}.jpg", frame):
logger.error(f"save {file_index} err")
video_cache._put_raw_frame(frame, cache_start_offset, current_speechid)
index += 1
return frame
def destroy(self):
""" """
self.data_cache.destroy()
self.video_cache.destroy()
self.data_cache = None
self.video_cache = None
if self.using_pool and self.img_reading_pool is not None:
self.img_reading_pool.shutdown()
self.img_reading_pool = None
del self.model_img_batch
del self.model_mel_batch
def process_wav2lip_predict(task_queue, stop_event, output_queue, resolution, channel):
# instantiate the inference worker
w2l_processor = Wav2lip_Processor(
task_queue, output_queue, stop_event, resolution, channel
)
need_update = True
while True:
if stop_event.is_set():
print("----------------------stop..")
w2l_processor.destroy() # this line may be buggy: there is no synchronization, so readers may hit exceptions
break
try:
params = task_queue.get()
# TODO: this should happen earlier; also confirm the condition that decides whether an update is needed
start = time.perf_counter()
w2l_processor.prepare(params[2], need_update)
need_update = False
result = w2l_processor.infer(*params)
logger.debug(
f"推理结束 :{time.perf_counter() - start},inference_id:{params[1]}"
)
output_queue.put(result)
logger.debug(
f"结果通知到主进程{time.perf_counter() - start},inference_id:{params[1]}"
)
except Exception:
logger.error(f"process_wav2lip_predict :{traceback.format_exc()}")