[ADD] add logic of loop frame
This commit is contained in:
parent
2bd94b9680
commit
e219702ee2
119
utils/log.py
Normal file
@@ -0,0 +1,119 @@
import json
import logging
import os
import sys

import redis
import redis.exceptions
from loguru import logger as logurulogger
from redis.backoff import ExponentialBackoff
from redis.retry import Retry

from app.config import config

LOG_FORMAT = (
    "<level>{level: <8}</level> "
    "{process.name} | "  # process name
    "{thread.name} | "  # thread name
    "<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> - "
    "<blue>{process}</blue> "
    "<cyan>{module}</cyan>.<cyan>{function}</cyan>:<cyan>{line}</cyan> - "
    "<level>{message}</level>"
)
LOG_NAME = ["uvicorn", "uvicorn.access", "uvicorn.error", "flask"]

# Configure the Redis connection pool
redis_pool = redis.ConnectionPool(
    host=config.LOG_REDIS_HOST,  # Redis server address
    port=config.LOG_REDIS_PORT,  # Redis server port
    db=config.LOG_REDIS_DB,  # database number
    password=config.LOG_REDIS_AUTH,  # password
    max_connections=config.max_connections,  # maximum number of connections
    socket_connect_timeout=config.socket_connect_timeout,  # connect timeout
    socket_timeout=config.socket_timeout,  # read/write timeout
)


class InterceptHandler(logging.Handler):
    def emit(self, record):
        try:
            level = logurulogger.level(record.levelname).name
        except (AttributeError, ValueError):
            level = logging._levelToName[record.levelno]

        # Walk up the stack past the logging module so loguru reports
        # the real caller instead of this handler.
        frame, depth = logging.currentframe(), 2
        while frame.f_code.co_filename == logging.__file__:
            frame = frame.f_back
            depth += 1

        logurulogger.opt(depth=depth, exception=record.exc_info).log(
            level, record.getMessage()
        )


class Logging:
    """Custom logger."""

    def __init__(self):
        self.log_path = "logs"
        self._connect_redis()
        if config.IS_LOCAL:
            os.makedirs(self.log_path, exist_ok=True)
        self._initlogger()
        self._reset_log_handler()

    def _connect_redis(self):
        retry = Retry(ExponentialBackoff(), 3)  # up to 3 retries with exponential backoff
        self.redis_client = redis.Redis(connection_pool=redis_pool, retry=retry)  # use the pool

    def _initlogger(self):
        """Initialize the loguru configuration."""
        logurulogger.remove()
        if config.IS_LOCAL:
            logurulogger.add(
                os.path.join(self.log_path, "error.log.{time:YYYY-MM-DD}"),
                format=LOG_FORMAT,
                level=logging.ERROR,
                rotation="00:00",
                retention="1 week",
                backtrace=True,
                diagnose=True,
                enqueue=True,
            )
            logurulogger.add(
                os.path.join(self.log_path, "info.log.{time:YYYY-MM-DD}"),
                format=LOG_FORMAT,
                level=logging.INFO,
                rotation="00:00",
                retention="1 week",
                enqueue=True,
            )
        logurulogger.add(
            sys.stdout,
            format=LOG_FORMAT,
            level=logging.DEBUG,
            colorize=True,
        )

        logurulogger.add(self._log_to_redis, level="INFO", format=LOG_FORMAT)
        self.logger = logurulogger

    def _log_to_redis(self, message):
        """Push a log record onto a Redis list."""
        try:
            self.redis_client.rpush(
                f"nlp.logger.{config.env_version}.log", json.dumps({"message": message})
            )
        except redis.exceptions.ConnectionError as e:
            logurulogger.error(f"write {message} Redis connection error: {e}")
        except redis.exceptions.TimeoutError as e:
            logurulogger.error(f"write {message} Redis operation timed out: {e}")
        except Exception as e:
            logurulogger.error(f"write {message} Unexpected error: {e}")

    def _reset_log_handler(self):
        """Route the standard-library loggers through loguru."""
        for log in LOG_NAME:
            logger = logging.getLogger(log)
            logger.handlers = [InterceptHandler()]

    def getlogger(self):
        return self.logger


logger = Logging().getlogger()
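A minimal usage sketch (not part of the commit): InterceptHandler re-routes the standard-library loggers listed in LOG_NAME through loguru, so stdlib and loguru calls end up in the same sinks, including the Redis list written by _log_to_redis. The consumer command at the end is an assumption about how that list might be read.

# Hypothetical usage sketch; assumes utils/log.py is importable and Redis is reachable.
import logging

from utils.log import logger

logger.info("direct loguru call")  # goes to stdout, log files (when IS_LOCAL), and Redis
logging.getLogger("uvicorn").warning("stdlib call")  # intercepted by InterceptHandler

# One way to drain the Redis list filled by _log_to_redis (assumed consumer):
#   redis-cli LRANGE nlp.logger.<env_version>.log 0 -1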
334
utils/loop_frame_tool.py
Normal file
@@ -0,0 +1,334 @@
from app.config import config  # needed for config.batch_size below
from utils.log import logger


def play_in_loop_v2(
    segments,
    startfrom,
    batch_num,
    last_direction,
    is_silent,
    first_speak,
    last_speak,
):
    """
    batch_num: the rules below are checked for every frame, from the first to the last.
    1. While silent, loop inside the silent segment: forward at the left boundary,
       backward at the right one; derive the new direction and position from the previous ones.
    2. Silent -> speaking: move to the nearest speaking segment (pre_flag and post_flag
       both true vs. only one of them true).
    3. Speaking -> silent: finish playing the action segment before entering silence
       (while speech continues, the silent segment does not loop).
    4. At the left end of the whole video: only forward is possible; loop while silent,
       apply rule 2 while speaking.
    5. At the right end of the whole video: only backward is possible; loop while silent,
       apply rule 2 while speaking.
    6. Collect batch_num frames along the chosen direction; return batch_idxes, current_direction.
    Args:
        segments: loop-frame config, e.g. [[st, ed, True], ...]
        startfrom: cur_pos
        batch_num: e.g. 5
        last_direction: 0 backward, 1 forward
        is_silent: 0 speaking, 1 action (silent)
        first_speak: whether this is the first moment of speech
        last_speak: whether this is the moment speech ends
    """
    frames = []
    cur_pos = startfrom
    cur_direction = last_direction
    is_first_speak_frame = first_speak
    is_last_speak_frame = last_speak and batch_num == 1
    while batch_num != 0:
        # locate the sub-segment that contains the current frame
        sub_seg_idx = subseg_judge(cur_pos, segments)
        # decide the movement direction
        next_direction, next_pos = get_next_direction(
            segments,
            cur_pos,
            cur_direction,
            is_silent,
            sub_seg_idx,
            is_first_speak_frame,
            is_last_speak_frame,
        )
        # take one step in that direction
        next_pos = get_next_frame(next_pos, next_direction)
        frames.append(next_pos)
        batch_num -= 1
        is_first_speak_frame = first_speak and batch_num == config.batch_size
        is_last_speak_frame = last_speak and batch_num == 1

        cur_direction = next_direction
        cur_pos = next_pos
    return frames, next_direction


def subseg_judge(cur_pos, segments):
    for idx, frame_seg in enumerate(segments):
        if frame_seg[0] <= cur_pos <= frame_seg[1]:
            return idx
    if cur_pos == 0:
        return 0


def get_next_direction(
    segments,
    cur_pos,
    cur_direction,
    is_silent,
    sub_seg_idx,
    is_first_speak_frame: bool = False,
    is_last_speak_frame: bool = False,
):
    """
    3.3.0 loop-frame requirement: reach the expected state as quickly as possible.
    if action segment:
        if speech starts:
            if at a boundary:
                if forward:
                    pass
                else:
                    pass
            else:
                if forward:
                    pass
                else:
                    pass
        elif silent:
            same as above
        elif speaking:
            same as above
        elif speech just ended:
            same as above
    elif silent segment:
        same as above
    Args:
        is_first_speak_frame: speech-start flag
        is_last_speak_frame: speech-end flag
    """
    left, right, loop_flag = segments[sub_seg_idx]
    if loop_flag:
        if is_silent == 1:
            next_direct, next_pos = pure_silent(
                segments, left, right, cur_pos, cur_direction, sub_seg_idx
            )
            logger.debug(
                f"cur_pos:{cur_pos}, next_direct:{next_direct}, is_first_speak_frame:{is_first_speak_frame}"
            )
        elif is_silent == 0:
            next_direct, next_pos = silent2action(
                segments,
                left,
                right,
                cur_pos,
                cur_direction,
                sub_seg_idx,
                is_first_speak_frame,
            )
            logger.debug(
                f"cur_pos:{cur_pos}, next_direct:{next_direct}, is_first_speak_frame:{is_first_speak_frame}"
            )
    else:
        if is_silent == 1:
            next_direct, next_pos = action2silent(
                segments,
                left,
                right,
                cur_pos,
                cur_direction,
                sub_seg_idx,
                is_last_speak_frame,
            )
            logger.debug(
                f"cur_pos:{cur_pos}, next_direct:{next_direct}, is_first_speak_frame:{is_first_speak_frame}, is_last_speak_frame:{is_last_speak_frame}"
            )
        elif is_silent == 0:
            next_direct, next_pos = pure_action(
                segments,
                left,
                right,
                cur_pos,
                cur_direction,
                sub_seg_idx,
                is_last_speak_frame,
            )
            logger.debug(
                f"cur_pos:{cur_pos}, next_direct:{next_direct}, is_first_speak_frame:{is_first_speak_frame}, is_last_speak_frame:{is_last_speak_frame}"
            )
    return next_direct, next_pos


def get_next_frame(cur_pos, cur_direction):
    """Return the next frame given the current frame and direction.
    Callers guarantee that a frame exists in that direction, so no extra
    boundary check is needed here.
    """
    # forward
    if cur_direction == 1:
        return cur_pos + 1
    # backward
    elif cur_direction == 0:
        return cur_pos - 1


def pure_silent(segments, left, right, cur_pos, cur_direction, sub_seg_idx):
    """
    loop_flag == True and is_silent == 1
    whether at a border
    whether moving forward
    Return:
        next_direction
    """
    # forward at the left border, backward at the right border
    if cur_pos == segments[0][0]:
        return 1, cur_pos
    if cur_pos == segments[-1][1]:
        return 0, cur_pos
    # right border of the sub-segment: backward
    if cur_pos == right:
        return 0, cur_pos
    # left border of the sub-segment: forward
    if cur_pos == left:
        return 1, cur_pos
    # not at a border: keep forward if we were moving forward, otherwise backward
    if cur_pos > left and cur_direction == 1:
        return 1, cur_pos
    else:
        return 0, cur_pos


def pure_action(
    segments, left, right, cur_pos, cur_direction, sub_seg_idx, is_last_speak_frame
):
    """
    loop_flag == False and is_silent == 0
    Once the action segment finishes, move forward to a silent segment
    (segment jumps can happen).
    whether at a border
    whether moving forward  # playing forward or backward
    Args:
        is_last_speak_frame: the moment speech ends
    Return: next_direction
    """
    if cur_pos == segments[0][0]:
        return 1, cur_pos
    if cur_pos == segments[-1][1]:
        return 0, cur_pos

    if is_last_speak_frame:
        # action segment at the end: search backward for silence
        if sub_seg_idx == len(segments) - 1:
            return 0, cur_pos
        # action segment at the start: go forward
        if sub_seg_idx == 0:
            return 1, cur_pos
        # action segment in the middle: nearest side wins
        mid = left + (right - left + 1) // 2
        if cur_pos < mid:
            return 0, cur_pos
        else:
            return 1, cur_pos

    else:
        # otherwise keep the current playback direction
        if cur_direction == 1:
            return 1, cur_pos
        else:
            return 0, cur_pos


def silent2action(
    segments,
    left,
    right,
    cur_pos,
    cur_direction,
    sub_seg_idx,
    is_first_speak_frame: bool = False,
):
    """
    In a silent segment but currently speaking.
    loop_flag == True and is_silent == 0
    whether at a border
    whether moving forward

    Return: next_direction
    """
    # Move toward the nearest action segment (there may be none on the left).
    # TODO: verify the logic below
    if (
        cur_pos == segments[0][0]
    ):  # after a jump, keep moving forward regardless of whether the new segment is an action segment
        return 1, cur_pos
    if cur_pos == segments[-1][1]:
        return 0, cur_pos
    # at the left border of the silent segment, still speaking
    if cur_pos == left:
        if cur_direction == 1:
            return 1, cur_pos
        else:
            return 0, cur_pos
    # at the right border of the silent segment, still speaking
    elif cur_pos == right:
        if cur_direction == 1:
            return 1, cur_pos
        else:
            return 0, cur_pos
    else:
        mid = left + (right - left + 1) // 2
        # NOTE: the nearest-side rule only applies to the first speaking frame;
        # every other case follows the previous direction.
        if is_first_speak_frame:
            # first segment
            if sub_seg_idx == 0 and segments[0][2]:
                return 1, cur_pos
            # last segment
            elif sub_seg_idx == len(segments) - 1 and segments[-1][2]:
                return 0, cur_pos

            if cur_pos < mid:
                return 0, cur_pos
            else:
                return 1, cur_pos
        else:
            if cur_direction == 1:
                return 1, cur_pos
            elif cur_direction == 0:
                return 0, cur_pos


def action2silent(
    segments, left, right, cur_pos, cur_direction, sub_seg_idx, is_last_speak_frame
):
    """
    loop_flag == False and is_silent == 1
    whether at a border
    Return: next_direction
    """
    if cur_pos == segments[0][0]:
        return 1, cur_pos
    if cur_pos == segments[-1][1]:
        return 0, cur_pos
    # action segment, speech just ended: enter silence via the nearest side
    if is_last_speak_frame:
        mid = left + (right - left + 1) // 2
        if cur_pos < mid:
            return 0, cur_pos
        else:
            return 1, cur_pos

    else:
        if cur_direction == 1:
            return 1, cur_pos
        else:
            return 0, cur_pos


if __name__ == "__main__":
    startfrom = 0  # last frame of the previous batch
    frame_config = [[0, 100, True], [101, 200, False]]  # example config: [st, ed, loop_flag]
    audio_frame_length = 5  # stands in for len(mel_chunks); TODO: confirm it equals batch_size
    startfrom = startfrom if startfrom >= frame_config[0][0] else frame_config[0][0]
    first_speak, last_speak = True, False
    is_silent = True  # whether the current batch is silent
    last_direction = 1  # 0 means backward
    start_idx_list, current_direction = play_in_loop_v2(
        frame_config,
        startfrom,
        audio_frame_length,
        last_direction,
        is_silent,
        first_speak,
        last_speak,
    )
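An illustrative trace (not part of the commit) of rule 1 from the play_in_loop_v2 docstring: inside a single loopable segment, pure_silent plus get_next_frame ping-pong the position between the two boundaries. The segment values are made up for the example.

# Hypothetical trace of the silent-loop rule; importing this module pulls in
# utils.log and app.config, so it assumes that environment is available.
from utils.loop_frame_tool import get_next_frame, pure_silent

segments = [[0, 3, True]]  # one loopable (silent) segment, frames 0..3
pos, direction = 0, 1
for _ in range(6):
    direction, pos = pure_silent(segments, 0, 3, pos, direction, 0)
    pos = get_next_frame(pos, direction)
    print(pos, direction)
# Prints 1 1, 2 1, 3 1, then bounces back: 2 0, 1 0, 0 0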
669
utils/wav2lip_processor.py
Normal file
@@ -0,0 +1,669 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
@File    :   self.py
@Time    :   2023/10/11 18:23:49
@Author  :   LiWei
@Version :   1.0
"""

import os
import threading
import time
import traceback
from concurrent.futures import ThreadPoolExecutor

import cv2
import numpy as np
import torch

import constants
from config import config
from face_parsing.model import BiSeNet
from face_parsing.swap import cal_mask_single_img, image_to_parsing_ori
from high_perf.buffer.infer_buffer import Infer_Cache
from high_perf.buffer.video_buff import Video_Cache
from log import logger
from models.wav2lipv2 import Wav2Lip
from utils.util import (
    add_alpha,
    datagen,
    get_actor_id,
    get_model_local_path,
    load_face_box,
    load_seg_model,
    load_w2l_model,
    morphing,
    play_in_loop_v2,
    save_raw_video,
    save_raw_wav,
    get_trans_idxes,
    load_config,
)

USING_JIT = False
LOCAL_MODE = False
index = 1


def save_(current_speechid, result_frames, audio_segments):
    raw_vid = save_raw_video(current_speechid, result_frames)
    raw_wav = save_raw_wav(current_speechid, audio_segments)
    command = "ffmpeg -y -i {} -i {} -c:v copy -c:a aac -strict experimental -shortest {}".format(
        raw_wav, raw_vid, f"temp/{current_speechid}_result_final.mp4"
    )
    import platform
    import subprocess

    subprocess.call(command, shell=platform.system() != "Windows")
    os.remove(raw_wav)


class Wav2lip_Processor:
    def __init__(
        self,
        task_queue,
        result_queue,
        stop_event,
        resolution,
        channel,
        device: str = "cuda:0",
    ) -> None:
        self.task_queue = task_queue
        self.result_queue = result_queue
        self.stop_event = stop_event
        self.write_lock = threading.Lock()
        self.device = device
        self.resolution = resolution
        self.channel = channel

        pool_size = config.read_frame_pool_size
        self.using_pool = True
        if pool_size <= 0:
            self.using_pool = False

        self.img_reading_pool = None

        if self.using_pool:
            self.img_reading_pool = ThreadPoolExecutor(max_workers=pool_size)

        self.face_det_results = {}
        self.w2l_models = {}
        self.seg_net = None  # segmentation model

        mel_block_size = (config.wav2lip_batch_size, 80, 16)
        audio_block_size = (np.prod(mel_block_size).tolist(), 1)
        # The cache holds at most two seconds of video; anything beyond that is overwritten.
        self.video_cache = Video_Cache(
            "video_cache",
            "audio_audio_cache",
            "speech_id_cache",
            self.resolution,
            channel=self.channel,
            audio_block_size=audio_block_size[0],
            create=True,
        )

        self.data_cache = Infer_Cache(
            "infer_cache",
            create=True,
            mel_block_shape=mel_block_size,
            audio_block_shape=audio_block_size,
        )
        # TODO: this version field has no effect
        self.version = "v2"
        self.current_speechid = ""
        # torch.Size([5, 288, 288, 6]) torch.Size([5, 80, 16, 1])
        self.model_mel_batch = torch.rand((5, 1, 80, 16), dtype=torch.float).to(
            self.device
        )
        self.model_img_batch = torch.rand((5, 6, 288, 288), dtype=torch.float).to(
            self.device
        )
        self.infer_silent_idx = 1
        self.post_morphint_idx = 1
        self.first_start = True  # whether this is the silence right after the stream starts
        self.file_index = 1

    def prepare(self, model_speaker_id: str, need_update: bool = True):
        """Load the model files."""
        if not need_update:
            return
        (
            self.model_url,
            self.frame_config,
            self.padding,
            self.face_classes,
            self.trans_method,
            self.infer_silent_num,
            self.morphing_num,
            pkl_version,
        ) = load_config(model_speaker_id)
        models_url_list = [self.model_url]
        actor_info = [model_speaker_id]
        for actor_url in actor_info:
            # Resources are no longer downloaded here; that work moved to the main process.
            actor_id = get_actor_id(actor_url)
            self.face_det_results[actor_id] = load_face_box(actor_id, pkl_version)
        self.w2l_process_state = "READY"
        self.__init_seg_model(f"{config.model_folder}/79999_iter.pth")
        # Pre-load every w2l model.
        # TODO: try the old loading path to check whether the model weights are wrong
        for model_url in models_url_list:
            w2l_name = get_model_local_path(model_url)
            self.__init_w2l_model(w2l_name)
        logger.debug("Prepare w2l data.")

    def __init_w2l_model(self, w2l_name):
        """ """
        # TODO: ideally this happens before the live room starts; the model may be
        # swapped at runtime since different avatars can use different models.
        if w2l_name in self.w2l_models:
            return

        w2l_path = f"{config.model_folder}/{w2l_name}"
        if USING_JIT:
            net = self.__jit_script(f"{config.model_folder}/w2l.pt", model=Wav2Lip())
        else:
            net = Wav2Lip()
        logger.debug(f"loading model {w2l_name}")
        self.w2l_models[w2l_name] = load_w2l_model(w2l_path, net, self.device)

    def __jit_script(self, jit_module_path: str, model):
        """
        Convert a pytorch module to a runnable script.
        """
        if not os.path.exists(jit_module_path):
            scripted_module = torch.jit.script(model)
            torch.jit.save(scripted_module, jit_module_path)
        net = torch.jit.load(jit_module_path)
        return net

    def __init_seg_model(self, seg_path: str):
        """ """
        # This model can be treated as immutable.
        n_classes = 19
        net = BiSeNet(n_classes=n_classes)
        if USING_JIT:
            net = self.__jit_script(f"{config.model_folder}/swap.pt", model=net)
        self.seg_net = load_seg_model(seg_path, net, self.device)

    def __using_w2lmodel(self, model_url):
        """ """
        # Fetch the corresponding model, loading it on demand.
        model_path = get_model_local_path(model_url)
        if model_path in self.w2l_models:
            return self.w2l_models[model_path]
        else:
            self.__init_w2l_model(model_path)
            return self.w2l_models[model_path]

    def infer(
        self,
        is_silent,
        inference_id,
        pkl_name,
        startfrom=0,
        last_direction=1,
        duration=0,
        end_with_silent=False,
        first_speak=False,
        last_speak=False,
        before_speak=False,
        model_url=None,
        debug: bool = False,
        test_data: tuple = (),
        video_idx: list = [],
    ):
        """ """
        start = time.perf_counter()
        # TODO: clarify this logic
        is_silent_ = 0 if duration < 5 else 1
        if not model_url:
            model_url = self.model_url
        model = self.__using_w2lmodel(model_url)
        # Controls how many avatars are loaded.
        self.current_speechid = inference_id
        if debug:
            audio_segments, mel_chunks = test_data
        else:
            audio_segments, mel_chunks = self.data_cache.get_data()
        audio_frame_length = len(mel_chunks)

        if video_idx:
            mel_chunks, start_idx_list, current_direction = mel_chunks, video_idx, 1
        else:
            startfrom = startfrom if startfrom >= self.frame_config[0][0] else self.frame_config[0][0]
            start_idx_list, current_direction = play_in_loop_v2(
                self.frame_config,
                startfrom,
                audio_frame_length,
                last_direction,
                is_silent,
                is_silent_,
                first_speak,
                last_speak,
            )
        logger.info(
            f"start_idx_list :{start_idx_list}, speechid: {self.current_speechid}"
        )
        gan = datagen(
            self.img_reading_pool,
            mel_chunks,
            start_idx_list,
            pkl_name,
            self.face_det_results[pkl_name],
            self.padding,
            is_silent,
        )
        self.file_index += 5
        # 3.4.0 adds a silent-inference transition: the n frames after speech blend toward silence.
        need_pre_morphing = False
        result_frames = []
        if self.trans_method == constants.TRANS_METHOD.mophing_infer:
            if self.infer_silent_num or self.morphing_num or config.PRE_MORPHING_NUM:
                logger.info(f"self.first_start:{self.first_start}, {before_speak}, {self.current_speechid}")
                if (
                    is_silent == 0
                    or before_speak
                    or (
                        not self.first_start
                        and is_silent == 1
                        and ((self.infer_silent_idx < self.infer_silent_num) or (self.post_morphint_idx < self.morphing_num))
                    )
                ):
                    # When speech starts, reset the counters from the previous round.
                    need_infer_silent = (
                        is_silent == 1
                        and self.infer_silent_idx < self.infer_silent_num
                        and not before_speak
                        and "&" not in self.current_speechid
                    )
                    need_pre_morphing = bool(before_speak)
                    need_post_morphing = (
                        is_silent == 1
                        and bool(self.morphing_num)
                        and self.post_morphint_idx < self.morphing_num
                        and not need_infer_silent
                        and not need_pre_morphing
                    )

                    if is_silent == 0 or "&" in self.current_speechid:
                        self.infer_silent_idx = 1
                        self.post_morphint_idx = 1

                    self.first_start = False
                    logger.debug(
                        f"{start_idx_list} is_silent:{is_silent}, infer_silent_idx: {self.infer_silent_idx} morphint_idx:{self.post_morphint_idx} need_post_morphing: {need_post_morphing}, need_infer_silent:{need_infer_silent}, speech:{self.current_speechid}, need_pre_morphing:{need_pre_morphing}"
                    )
                    if is_silent == 1 and need_post_morphing and self.post_morphint_idx < self.morphing_num:
                        self.post_morphint_idx += 5
                    result_frames = self.prediction(
                        gan,
                        model,
                        pkl_name,
                        need_post_morphing=need_post_morphing,
                        need_infer_silent=need_infer_silent,
                        is_silent=is_silent,
                        need_pre_morphing=need_pre_morphing,
                    )
                    if is_silent == 1 and need_infer_silent and self.infer_silent_idx < self.infer_silent_num:
                        self.infer_silent_idx += 5

                # The very first silence goes straight into this branch.
                else:
                    self.no_prediction(gan, model, pkl_name)
            else:
                if is_silent == 0:
                    result_frames = self.prediction(
                        gan,
                        model,
                        pkl_name,
                        is_silent=is_silent,
                        need_pre_morphing=need_pre_morphing,
                    )
                else:
                    self.no_prediction(gan, model, pkl_name)
        elif self.trans_method == constants.TRANS_METHOD.all_infer:
            result_frames = self.prediction(gan, model, pkl_name, is_silent=is_silent)
        else:
            raise ValueError(f"not supported {self.trans_method}")

        if LOCAL_MODE:
            saveThread = threading.Thread(
                target=save_,
                args=(self.current_speechid, result_frames, audio_segments),
            )
            saveThread.start()
        self.video_cache.put_audio(audio_segments, self.current_speechid)

        logger.debug(
            f"video synthesis took {time.perf_counter() - start}s, inference_id:{self.current_speechid}"
        )
        return [
            [0, 1, 2, 3, 4],
            start_idx_list[-1],
            current_direction,
        ]  # frames_return_list: video frames; res_index: start position of the next frame; direction: playback order (forward/backward)

    @torch.no_grad()
    def batch_to_tensor(self, img_batch, mel_batch, model, padding, frames, coords):
        padding = 10  # NOTE: overrides the padding argument
        logger.debug(f"speaking, inference: {self.current_speechid} preparing inference data {id(model)}")
        img_batch_tensor = torch.as_tensor(img_batch, dtype=torch.float32).to(
            self.device, non_blocking=True
        )
        mel_batch_tensor = torch.as_tensor(mel_batch, dtype=torch.float32).to(
            self.device, non_blocking=True
        )

        img_batch_tensor = img_batch_tensor.permute(0, 3, 1, 2)
        mel_batch_tensor = mel_batch_tensor.permute(0, 3, 1, 2)
        logger.debug(
            f"speaking, inference: {self.current_speechid} about to generate mouth shapes {img_batch_tensor.shape} {mel_batch_tensor.shape}"
        )  # torch.Size([5, 6, 288, 288]) torch.Size([5, 1, 80, 16])

        pred_batch = model(mel_batch_tensor, img_batch_tensor) * 255.0

        logger.debug(f"speaking, inference: {self.current_speechid} mouth generation done")
        pred_clone_batch = pred_batch.clone()
        pred_batch_cpu = pred_batch.to(torch.uint8).to("cpu", non_blocking=True)
        del pred_batch
        del mel_batch_tensor
        logger.debug(f"speaking, inference: {self.current_speechid} generated faces are ready")
        large_faces = [
            frame.copy()[:, :, :3][
                y1 - padding : y2 + padding, x1 - padding : x2 + padding
            ]
            for frame, box in zip(frames, coords)
            for y1, y2, x1, x2 in [box]
        ]
        logger.debug(f"speaking, inference: {self.current_speechid} about to segment the generated faces")
        seg_out = image_to_parsing_ori(
            pred_clone_batch, self.seg_net
        )  # assumed to run on the GPU and return a GPU tensor
        del img_batch_tensor
        del pred_clone_batch

        torch.cuda.synchronize()  # wait for async work to finish
        seg_out_cpu = seg_out.to(
            "cpu", non_blocking=True
        )  # start the async transfer of seg_out
        del seg_out
        logger.debug(f"speaking, inference: {self.current_speechid} face segmentation finished")
        pred_batch_cpu = pred_batch_cpu.numpy().transpose(
            0, 2, 3, 1
        )  # computed lazily, so the GPU-to-CPU copy time can hopefully be hidden
        for predict, frame, box in zip(pred_batch_cpu, frames, coords):
            y1, y2, x1, x2 = box
            width, height = x2 - x1, y2 - y1
            frame[:, :, :3][y1:y2, x1:x2] = cv2.resize(
                predict, (width, height)
            )  # no uint8 cast needed here; it is done on the GPU so the transfer is smaller
        torch.cuda.synchronize()  # wait for async work to finish
        return large_faces, seg_out_cpu, frames

    def no_prediction(self, gan, model, pkl_name):
        logger.debug(
            f"no inference needed: {self.current_speechid} {self.model_img_batch.shape} {self.model_mel_batch.shape}"
        )
        with torch.no_grad():
            pred_ = model(self.model_mel_batch, self.model_img_batch)
            seg_out = image_to_parsing_ori(
                pred_, self.seg_net
            )  # unclear whether the second network also needs this warmup
        for _, _, frames, full_masks, _, end_idx in gan:
            speech_ids = [
                f"{self.current_speechid}_{idx}" for idx in range(len(end_idx))
            ]
            offset_list = self.video_cache.occupy_video_pos(len(end_idx))
            file_idxes, _, _ = get_trans_idxes(False, False, 0, 0, self.file_index)
            logger.info(f"self.file_index:{self.file_index}, file_idxes:{file_idxes}")
            param_list = [
                (
                    speech_id,
                    frame,
                    body_mask,
                    write_pos,
                    pkl_name,
                    self.video_cache,
                    file_idx,
                )
                for speech_id, frame, body_mask, write_pos, file_idx in zip(
                    speech_ids,
                    frames,
                    full_masks,
                    offset_list,
                    file_idxes,
                )
            ]
            if self.using_pool:
                futures = [
                    self.img_reading_pool.submit(self.silent_paste_back, *param)
                    for param in param_list
                ]
                _ = [future.result() for future in futures]
                self.video_cache._inject_real_pos(len(end_idx))
            # if config.debug_mode:
            #     for param in param_list:
            #         self.silent_paste_back(*param)
        del pred_
        del seg_out
        torch.cuda.empty_cache()

    def silent_paste_back(self, speech_id, frame, body_mask, write_pos, pkl_name, video_cache, file_index):
        global index
        if frame.shape[-1] == 4:
            frame[:, :, :3] = frame[:, :, :3][:, :, ::-1]
        elif config.output_alpha:
            frame = add_alpha(frame, body_mask, config.alpha)
        logger.debug(f"self.file_index:{self.file_index}, file_index:{file_index}")
        if config.debug_mode:
            logger.info(f"no inference: {file_index} {frame.shape}")
            if not cv2.imwrite(f"temp/{pkl_name}/new_img{file_index:05d}.jpg", frame):
                logger.error(f"save {file_index} err")
        index += 1
        logger.info(f"silent frame shape:{frame.shape}")
        video_cache._put_raw_frame(frame, write_pos, speech_id)
        return frame

    def prediction(
        self,
        gan,
        model,
        pkl_name: str,
        need_post_morphing: bool = False,
        need_infer_silent: bool = False,
        is_silent: int = 1,
        need_pre_morphing: bool = False,
    ):
        """ """
        logger.debug(f"speaking, inference: {self.current_speechid}")
        result_frames = []
        for img_batch, mel_batch, frames, full_masks, coords, end_idx in gan:
            large_faces, seg_out_cpu, frames = self.batch_to_tensor(
                img_batch, mel_batch, model, self.padding, frames, coords
            )
            offset_list = self.video_cache.occupy_video_pos(len(end_idx))
            speech_ids = [
                f"{self.current_speechid}_{idx}" for idx in range(len(end_idx))
            ]
            file_idxes, post_morphint_idxes, infer_silent_idxes = get_trans_idxes(
                need_post_morphing, need_infer_silent, self.post_morphint_idx, self.infer_silent_idx, self.file_index
            )
            logger.info(
                f"self.file_index:{self.file_index}, infer_silent_idxes:{infer_silent_idxes}, post_morphint_idxes:{post_morphint_idxes}, file_idxes:{file_idxes}"
            )
            param_list = [
                (
                    speech_id,
                    padded_image,
                    frame,
                    body_mask,
                    seg_mask,
                    boxes,
                    self.padding,
                    self.video_cache,
                    write_pos,
                    pkl_name,
                    need_infer_silent,
                    need_post_morphing,
                    is_silent,
                    need_pre_morphing,
                    pre_morphing_idx,
                    infer_silent_idx,
                    post_morphint_idx,
                    file_idx,
                )
                for speech_id, padded_image, frame, body_mask, seg_mask, boxes, write_pos, pre_morphing_idx, infer_silent_idx, post_morphint_idx, file_idx in zip(
                    speech_ids,
                    large_faces,
                    frames,
                    full_masks,
                    seg_out_cpu,
                    coords,
                    offset_list,
                    list(range(config.PRE_MORPHING_NUM)),
                    infer_silent_idxes,
                    post_morphint_idxes,
                    file_idxes,
                )
            ]
            if self.using_pool:  # and not config.debug_mode:
                futures = [
                    self.img_reading_pool.submit(self.paste_back, *param)
                    for param in param_list
                ]
                _ = [future.result() for future in futures]
                self.video_cache._inject_real_pos(len(end_idx))
            # if config.debug_mode:
            #     for param in param_list:
            #         frame = self.paste_back(*param)
            #         if LOCAL_MODE:
            #             result_frames.append(frame)
        return result_frames

    def paste_back(
        self,
        current_speechid,
        large_face,
        frame,
        body_mask,
        mask,
        boxes,
        padding,
        video_cache,
        cache_start_offset,
        pkl_name,
        need_infer_silent: bool = False,
        need_post_morphing: bool = False,
        is_silent: bool = False,
        need_pre_morphing: bool = False,
        pre_morphing_idx: int = 0,
        infer_silent_idx: int = 0,
        post_morphint_idx: int = 0,
        file_index: int = 1,
    ):
        """
        Blend the model prediction back into the original frame.
        Args:
        """
        global index
        padding = 10  # NOTE: overrides the padding argument
        y1, y2, x1, x2 = boxes
        width, height = x2 - x1, y2 - y1
        mask = cal_mask_single_img(
            mask, use_old_mode=True, face_classes=self.face_classes
        )
        mask = np.repeat(mask[..., None], 3, axis=-1).astype("uint8")

        mask_temp = np.zeros_like(large_face)
        mask_out = cv2.resize(mask.astype(float) * 255.0, (width, height)).astype(
            np.uint8
        )
        mask_temp[padding : height + padding, padding : width + padding] = mask_out
        kernel = np.ones((9, 9), np.uint8)
        mask_temp = cv2.erode(mask_temp, kernel, iterations=1)  # still binary here
        # gaosi_kernel = int(0.1 * large_face.shape[0] // 2 * 2) + 1
        # mask_temp = cv2.GaussianBlur(
        #     mask_temp, (gaosi_kernel, gaosi_kernel), 0, 0, cv2.BORDER_DEFAULT
        # )
        mask_temp = cv2.GaussianBlur(mask_temp, (15, 15), 0, 0, cv2.BORDER_DEFAULT)
        mask_temp = cv2.GaussianBlur(mask_temp, (5, 5), 0, 0, cv2.BORDER_DEFAULT)

        f_background = large_face.copy()

        frame[:, :, :3][y1 - padding : y2 + padding, x1 - padding : x2 + padding] = (
            f_background * (1 - mask_temp / 255.0)
            + frame[:, :, :3][y1 - padding : y2 + padding, x1 - padding : x2 + padding]
            * (mask_temp / 255.0)
        )  # .astype("uint8")
        if self.trans_method == constants.TRANS_METHOD.mophing_infer:
            if is_silent == 1:
                if need_pre_morphing:
                    frame = morphing(
                        large_face,
                        frame,
                        boxes,
                        mp_ratio=1 - ((pre_morphing_idx + 1) / config.PRE_MORPHING_NUM),
                        file_index=file_index,
                    )
                    logger.debug(
                        f"file_index:{file_index}, pre morphing_idx {pre_morphing_idx}, speech_id:{current_speechid}"
                    )
                    # TODO: @txueduo handle the pre-transition problem
                elif need_post_morphing and post_morphint_idx:
                    mp_ratio = post_morphint_idx / self.morphing_num
                    frame = morphing(
                        large_face,
                        frame,
                        boxes,
                        mp_ratio=mp_ratio,
                        file_index=file_index,
                    )
                    logger.debug(
                        f"post_morphint_idx:{post_morphint_idx}, mp_ratio:{mp_ratio}, file_index:{file_index}, speech_id:{current_speechid}"
                    )
        if frame.shape[-1] == 4:  # and not config.output_alpha:
            frame[:, :, :3] = frame[:, :, :3][:, :, ::-1]
        if config.output_alpha and frame.shape[-1] != 4:
            frame = add_alpha(frame, body_mask, config.alpha)
        if config.debug_mode:
            logger.info(f"inference: {file_index}")
            if not cv2.imwrite(f"temp/{pkl_name}/new_img{file_index:05d}.jpg", frame):
                logger.error(f"save {file_index} err")
        video_cache._put_raw_frame(frame, cache_start_offset, current_speechid)
        index += 1
        return frame

    def destroy(self):
        """ """
        self.data_cache.destroy()
        self.video_cache.destroy()
        self.data_cache = None
        self.video_cache = None
        if self.using_pool and self.img_reading_pool is not None:
            self.img_reading_pool.shutdown()
            self.img_reading_pool = None
        del self.model_img_batch
        del self.model_mel_batch


def process_wav2lip_predict(task_queue, stop_event, output_queue, resolution, channel):
    # Instantiate the inference class.
    w2l_processor = Wav2lip_Processor(
        task_queue, output_queue, stop_event, resolution, channel
    )
    need_update = True
    while True:
        if stop_event.is_set():
            print("----------------------stop..")
            w2l_processor.destroy()  # possibly buggy: no synchronization here, so readers may hit an exception
            break
        try:
            params = task_queue.get()
            # TODO: this should happen earlier; confirm the condition for updating
            start = time.perf_counter()
            w2l_processor.prepare(params[2], need_update)
            need_update = False
            result = w2l_processor.infer(*params)
            logger.debug(
                f"inference finished: {time.perf_counter() - start}, inference_id:{params[1]}"
            )
            output_queue.put(result)
            logger.debug(
                f"result delivered to the main process {time.perf_counter() - start}, inference_id:{params[1]}"
            )
        except Exception:
            logger.error(f"process_wav2lip_predict :{traceback.format_exc()}")
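For reference, a stand-alone sketch (not part of the commit) of the feathered blend that paste_back performs: erode the face-parsing mask, Gaussian-blur it into a soft matte, then alpha-blend the generated face over the original crop. blend_face and its arguments are illustrative names.

# Illustrative reimplementation of the paste_back feathering, under the assumption
# that background, generated, and mask are aligned HxWx3 uint8 arrays (mask in {0, 1}).
import cv2
import numpy as np


def blend_face(background, generated, mask):
    kernel = np.ones((9, 9), np.uint8)
    matte = cv2.erode(mask * 255, kernel, iterations=1)  # shrink the hard mask edge
    matte = cv2.GaussianBlur(matte, (15, 15), 0)  # feather it into a soft matte
    matte = cv2.GaussianBlur(matte, (5, 5), 0)
    alpha = matte.astype(np.float32) / 255.0
    return (background * (1 - alpha) + generated * alpha).astype(np.uint8)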
Loading…
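The comments in batch_to_tensor describe overlapping GPU-to-CPU copies with CPU work via non-blocking transfers. A stripped-down sketch of that pattern (illustrative, not the commit's code):

# Minimal sketch of the async-transfer pattern batch_to_tensor relies on; assumes a
# CUDA device. Note: the copy only truly overlaps when the destination is pinned memory.
import torch

device = "cuda:0"
pred = torch.rand((5, 3, 288, 288), device=device)
pred_cpu = pred.to("cpu", non_blocking=True)  # copy starts; may not be finished yet

# ... unrelated CPU-side work can run here, hiding the transfer latency ...

torch.cuda.synchronize()  # guarantee the copy finished before reading pred_cpu
print(pred_cpu.mean())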