diff --git a/utils/log.py b/utils/log.py
new file mode 100644
index 0000000..a531729
--- /dev/null
+++ b/utils/log.py
@@ -0,0 +1,119 @@
+import logging
+import os
+import sys
+import redis
+from loguru import logger as logurulogger
+import redis.exceptions
+from app.config import config
+import json
+from redis.retry import Retry
+from redis.backoff import ExponentialBackoff
+
+LOG_FORMAT = (
+ "{level: <8} "
+ "{process.name} | " # 进程名
+ "{thread.name} | "
+ "{time:YYYY-MM-DD HH:mm:ss.SSS} - "
+ "{process} "
+ "{module}.{function}:{line} - "
+ "{message}"
+)
+LOG_NAME = ["uvicorn", "uvicorn.access", "uvicorn.error", "flask"]
+
+# 配置 Redis 连接池
+redis_pool = redis.ConnectionPool(
+ host=config.LOG_REDIS_HOST, # Redis 服务器地址
+ port=config.LOG_REDIS_PORT, # Redis 服务器端口
+ db=config.LOG_REDIS_DB, # 数据库编号
+ password=config.LOG_REDIS_AUTH, # 密码
+ max_connections=config.max_connections, # 最大连接数
+ socket_connect_timeout=config.socket_connect_timeout, # 连接超时时间
+ socket_timeout=config.socket_timeout, # 等待超时时间
+)
+
+
+class InterceptHandler(logging.Handler):
+ def emit(self, record):
+ try:
+ level = logurulogger.level(record.levelname).name
+ except AttributeError:
+ level = logging._levelToName[record.levelno]
+
+ frame, depth = logging.currentframe(), 2
+ while frame.f_code.co_filename == logging.__file__:
+ frame = frame.f_back
+ depth += 1
+
+ logurulogger.opt(depth=depth, exception=record.exc_info).log(
+ level, record.getMessage()
+ )
+
+class Logging:
+ """自定义日志"""
+
+ def __init__(self):
+ self.log_path = "logs"
+ self._connect_redis()
+ if config.IS_LOCAL:
+ os.makedirs(self.log_path, exist_ok=True)
+ self._initlogger()
+ self._reset_log_handler()
+
+ def _connect_redis(self):
+ retry = Retry(ExponentialBackoff(), 3) # 重试3次,指数退避
+ self.redis_client = redis.Redis(connection_pool=redis_pool,retry=retry) # 使用连接池
+
+ def _initlogger(self):
+ """初始化loguru配置"""
+ logurulogger.remove()
+ if config.IS_LOCAL:
+ logurulogger.add(
+ os.path.join(self.log_path, "error.log.{time:YYYY-MM-DD}"),
+ format=LOG_FORMAT,
+ level=logging.ERROR,
+ rotation="00:00",
+ retention="1 week",
+ backtrace=True,
+ diagnose=True,
+ enqueue=True
+ )
+ logurulogger.add(
+ os.path.join(self.log_path, "info.log.{time:YYYY-MM-DD}"),
+ format=LOG_FORMAT,
+ level=logging.INFO,
+ rotation="00:00",
+ retention="1 week",
+ enqueue=True
+ )
+ logurulogger.add(
+ sys.stdout,
+ format=LOG_FORMAT,
+ level=logging.DEBUG,
+ colorize=True,
+ )
+
+ logurulogger.add(self._log_to_redis, level="INFO", format=LOG_FORMAT)
+ self.logger = logurulogger
+
+
+ def _log_to_redis(self, message):
+ """将日志写入 Redis 列表"""
+ try:
+ self.redis_client.rpush(f"nlp.logger.{config.env_version}.log", json.dumps({"message": message}))
+ except redis.exceptions.ConnectionError as e:
+ logger.error(f"write {message} Redis connection error: {e}")
+ except redis.exceptions.TimeoutError as e:
+ logger.error(f"write {message} Redis operation timed out: {e}")
+ except Exception as e:
+ logger.error(f"write {message} Unexpected error: {e}")
+
+ def _reset_log_handler(self):
+ for log in LOG_NAME:
+ logger = logging.getLogger(log)
+ logger.handlers = [InterceptHandler()]
+
+ def getlogger(self):
+ return self.logger
+
+logger = Logging().getlogger()
+
diff --git a/utils/loop_frame_tool.py b/utils/loop_frame_tool.py
new file mode 100644
index 0000000..1234644
--- /dev/null
+++ b/utils/loop_frame_tool.py
@@ -0,0 +1,334 @@
+from utils.log import logger
+
+
+def play_in_loop_v2(
+ segments,
+ startfrom,
+ batch_num,
+ last_direction,
+ is_silent,
+ first_speak,
+ last_speak,
+):
+ """
+ batch_num: 初始和结束,每一帧都这么判断
+ 1、静默时,在静默段循环, 左边界正向,右边界反向, 根据上一次方向和位置,给出新的方向和位置
+ 2、静默转说话: 就近到说话段,pre_falg, post_flag, 都为true VS 其中一个为true
+ 3、说话转静默: 动作段播完,再进入静默(如果还在持续说话,静默段不循环)
+ 4、在整个视频左端点: 开始端只能正向,静默时循环,说话时走2
+ 5、在整个视频右端点: 开始时只能反向,静默时循环,说话时走2
+ 6、根据方向获取batch_num 数量的视频帧,return batch_idxes, current_direction
+ Args:
+ segments: 循环帧配置 [[st, ed, True], ...]
+ startfrom: cur_pos
+ batch_num: 5
+ last_direction: 0反向1正向
+ is_silent: 0说话态1动作态
+ first_speak: 记录是不是第一次讲话
+ last_speak: 记录是不是讲话结束那一刻
+ """
+ frames = []
+ cur_pos = startfrom
+ cur_direction = last_direction
+ is_first_speak_frame = first_speak
+ is_last_speak_frame = True if last_speak and batch_num == 1 else False
+ while batch_num != 0:
+ # 获取当前帧的所在子分割区间
+ sub_seg_idx = subseg_judge(cur_pos, segments)
+ # 获取移动方向
+ next_direction, next_pos = get_next_direction(
+ segments,
+ cur_pos,
+ cur_direction,
+ is_silent,
+ sub_seg_idx,
+ is_first_speak_frame,
+ is_last_speak_frame,
+ )
+ # 获取指定方向的帧
+ next_pos = get_next_frame(next_pos, next_direction)
+ frames.append(next_pos)
+ batch_num -= 1
+ is_first_speak_frame = (
+ True if first_speak and batch_num == config.batch_size else False
+ )
+ is_last_speak_frame = True if last_speak and batch_num == 1 else False
+
+ cur_direction = next_direction
+ cur_pos = next_pos
+ return frames, next_direction
+
+
+def subseg_judge(cur_pos, segments):
+ for idx, frame_seg in enumerate(segments):
+ if cur_pos >= frame_seg[0] and cur_pos <= frame_seg[1]:
+ return idx
+ if cur_pos == 0:
+ return 0
+
+def get_next_direction(
+ segments,
+ cur_pos,
+ cur_direction,
+ is_silent,
+ sub_seg_idx,
+ is_first_speak_frame: bool = False,
+ is_last_speak_frame: bool = False,
+):
+ """
+ 3.3.0 循环帧需求,想尽快走到预期状态
+ if 动作段:
+ if 开始说话:
+ if 边界:
+ if 正向:
+ pass
+ else:
+ pass
+ else:
+ if 正向:
+ pass
+ else:
+ pass
+ elif 静默:
+ 同上
+ elif 说话中:
+ 同上
+ elif 说话结束:
+ 同上
+ elif 静默段:
+ 同上
+ Args:
+ is_first_speak_frame: 开始说话flag
+ is_last_speak_frame: 说话结束flag
+ """
+ left, right, loop_flag = segments[sub_seg_idx]
+ if loop_flag:
+ if is_silent == 1:
+ next_direct, next_pos = pure_silent(
+ segments, left, right, cur_pos, cur_direction, sub_seg_idx
+ )
+ logger.debug(
+ f"cur_pos:{cur_pos}, next_direct:{next_direct}, is_first_speak_frame:{is_first_speak_frame}"
+ )
+ elif is_silent == 0:
+ next_direct, next_pos = silent2action(
+ segments,
+ left,
+ right,
+ cur_pos,
+ cur_direction,
+ sub_seg_idx,
+ is_first_speak_frame,
+ )
+ logger.debug(
+ f"cur_pos:{cur_pos}, next_direct:{next_direct}, is_first_speak_frame{is_first_speak_frame}"
+ )
+ else:
+ if is_silent == 1:
+ next_direct, next_pos = action2silent(
+ segments,
+ left,
+ right,
+ cur_pos,
+ cur_direction,
+ sub_seg_idx,
+ is_last_speak_frame,
+ )
+ logger.debug(
+ f"cur_pos{cur_pos}, next_direct:{next_direct},is_first_speak_frame{is_first_speak_frame},is_last_speak_frame:{is_last_speak_frame}"
+ )
+ elif is_silent == 0:
+ next_direct, next_pos = pure_action(
+ segments,
+ left,
+ right,
+ cur_pos,
+ cur_direction,
+ sub_seg_idx,
+ is_last_speak_frame,
+ )
+ logger.debug(
+ f"cur_pos:{cur_pos}, next_direct:{next_direct},is_first_speak_frame{is_first_speak_frame}, is_last_speak_frame:{is_last_speak_frame}"
+ )
+ return next_direct, next_pos
+
+def get_next_frame(cur_pos, cur_direction):
+ """根据当前帧和方向,获取下一帧,这里应该保证方向上的帧是一定能取到的
+ 不需要再做额外的边界判断
+ """
+ # 正向
+ if cur_direction == 1:
+ return cur_pos + 1
+ # 反向
+ elif cur_direction == 0:
+ return cur_pos - 1
+
+def pure_silent(segments, left, right, cur_pos, cur_direction, sub_seg_idx):
+ """
+ loop_flag == True and is_silent==1
+ whether border
+ whether forward
+ Return:
+ next_direction
+ """
+ # 左边界正向,右边界反向
+ if cur_pos == segments[0][0]:
+ return 1, cur_pos
+ if cur_pos == segments[-1][1]:
+ return 0, cur_pos
+ # 右边界,反向
+ if cur_pos == right:
+ return 0, cur_pos
+ # 左边界,正向
+ if cur_pos == left:
+ return 1, cur_pos
+ # 非边界,之前正向,则继续正向,否则反向
+ if cur_pos > left and cur_direction == 1:
+ return 1, cur_pos
+ else:
+ return 0, cur_pos
+
+
+def pure_action(
+ segments, left, right, cur_pos, cur_direction, sub_seg_idx, is_last_speak_frame
+):
+ """
+ loop_flag ==False and is_silent == 0
+ 动作播完,正向到静默段 (存在跳段行为)
+ whether border
+ whether forward # 正播反播
+ Args:
+ is_last_speak_frame: 最后说话结束时刻
+ Return: next_direction
+ """
+ if cur_pos == segments[0][0]:
+ return 1, cur_pos
+ if cur_pos == segments[-1][1]:
+ return 0, cur_pos
+
+ if is_last_speak_frame:
+ # 动作段在末尾,向前找静默
+ if sub_seg_idx == len(segments) - 1:
+ return 0, cur_pos
+ # 动作段在开始, 向后
+ if sub_seg_idx == 0:
+ return 1, cur_pos
+ # 动作段在中间,就近原则
+ mid = left + (right - left + 1) // 2
+ # 就近原则优先
+ if cur_pos < mid:
+ return 0, cur_pos
+ else:
+ return 1, cur_pos
+
+ else:
+ # 其他情况,播放方向一致
+ if cur_direction == 1:
+ return 1, cur_pos
+ else:
+ return 0, cur_pos
+
+
+def silent2action(
+ segments,
+ left,
+ right,
+ cur_pos,
+ cur_direction,
+ sub_seg_idx,
+ is_first_speak_frame: bool = False,
+):
+ """
+ 在静默区间但是在讲话
+ loop_flag=True and is_silent == 0
+ whether border
+ whether forward
+
+ Return: next_direction
+ """
+ # 向最近的动作段移动, 如果左面没有动作段
+ # TODO: 确认下面逻辑是否正确
+ if (
+ cur_pos == segments[0][0]
+ ): # 如果发生过跳跃,新段无论是不是动作段,仍然都向后执行
+ return 1, cur_pos
+ if cur_pos == segments[-1][1]:
+ return 0, cur_pos
+ # 在静默左边界处,且仍在讲话
+ if cur_pos == left:
+ if cur_direction == 1:
+ return 1, cur_pos
+ else:
+ return 0, cur_pos
+ # 在静默右边界处,且仍在讲话
+ elif cur_pos == right:
+ if cur_direction == 1:
+ return 1, cur_pos
+ else:
+ return 0, cur_pos
+ else:
+ mid = left + (right - left + 1) // 2
+ # !!就近原则只对第一次说话有效,其他情况遵循上一次状态
+ if is_first_speak_frame:
+ # 如果第一段
+ if sub_seg_idx == 0 and segments[0][2]:
+ return 1, cur_pos
+ # 如果最后一段
+ elif sub_seg_idx == len(segments) - 1 and segments[-1][2]:
+ return 0, cur_pos
+
+ if cur_pos < mid:
+ return 0, cur_pos
+ else:
+ return 1, cur_pos
+ else:
+ if cur_direction == 1:
+ return 1, cur_pos
+ elif cur_direction == 0:
+ return 0, cur_pos
+
+
+def action2silent(
+ segments, left, right, cur_pos, cur_direction, sub_seg_idx, is_last_speak_frame
+):
+ """
+ loop_flag=False and is_silent==1
+ whether border
+ Return: next_direction
+ """
+ if cur_pos == segments[0][0]:
+ return 1, cur_pos
+ if cur_pos == segments[-1][1]:
+ return 0, cur_pos
+ # 动作段,说话结束转静默情况下,就近原则,进入静默
+ if is_last_speak_frame:
+ mid = left + (right - left + 1) // 2
+ if cur_pos < mid:
+ return 0, cur_pos
+ else:
+ return 1, cur_pos
+
+ else:
+ if cur_direction == 1:
+ return 1, cur_pos
+ else:
+ return 0, cur_pos
+
+
+if __name__ == "__main__":
+ startfrom = 0 # 上一个batch的最后一帧
+ frame_config= []
+ audio_frame_length = len(mel_chunks) # TODO: 确认是否为 batch_size
+ startfrom = startfrom if startfrom>= frame_config[0][0] else frame_config[0][0]
+ first_speak, last_speak = True, False
+ is_silent= True # 当前batch是否为静默
+ last_direction = 1 # -1 为反方向
+ start_idx_list, current_direction = play_in_loop_v2(
+ frame_config,
+ startfrom,
+ audio_frame_length,
+ last_direction,
+ is_silent,
+ first_speak,
+ last_speak,
+ )
\ No newline at end of file
diff --git a/utils/wav2lip_processor.py b/utils/wav2lip_processor.py
new file mode 100644
index 0000000..03dc5c9
--- /dev/null
+++ b/utils/wav2lip_processor.py
@@ -0,0 +1,669 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+@File : self.py
+@Time : 2023/10/11 18:23:49
+@Author : LiWei
+@Version : 1.0
+"""
+
+import os
+import threading
+import time
+import traceback
+from concurrent.futures import ThreadPoolExecutor
+
+import cv2
+import numpy as np
+import torch
+
+import constants
+from config import config
+from face_parsing.model import BiSeNet
+from face_parsing.swap import cal_mask_single_img, image_to_parsing_ori
+from high_perf.buffer.infer_buffer import Infer_Cache
+from high_perf.buffer.video_buff import Video_Cache
+from log import logger
+from models.wav2lipv2 import Wav2Lip
+from utils.util import (
+ add_alpha,
+ datagen,
+ get_actor_id,
+ get_model_local_path,
+ load_face_box,
+ load_seg_model,
+ load_w2l_model,
+ morphing,
+ play_in_loop_v2,
+ save_raw_video,
+ save_raw_wav,
+ get_trans_idxes,
+ load_config
+)
+
+USING_JIT = False
+LOCAL_MODE = False
+index = 1
+
+
+def save_(current_speechid, result_frames, audio_segments):
+ raw_vid = save_raw_video(current_speechid, result_frames)
+ raw_wav = save_raw_wav(current_speechid, audio_segments)
+ command = "ffmpeg -y -i {} -i {} -c:v copy -c:a aac -strict experimental -shortest {}".format(
+ raw_wav, raw_vid, f"temp/{current_speechid}_result_final.mp4"
+ )
+ import platform
+ import subprocess
+
+ subprocess.call(command, shell=platform.system() != "Windows")
+ os.remove(raw_wav)
+
+
+class Wav2lip_Processor:
+ def __init__(
+ self,
+ task_queue,
+ result_queue,
+ stop_event,
+ resolution,
+ channel,
+ device: str = "cuda:0",
+ ) -> None:
+ self.task_queue = task_queue
+ self.result_queue = result_queue
+ self.stop_event = stop_event
+ self.write_lock = threading.Lock()
+ self.device = device
+ self.resolution = resolution
+ self.channel = channel
+
+ pool_size = config.read_frame_pool_size
+ self.using_pool = True
+ if pool_size <= 0:
+ self.using_pool = False
+
+ self.img_reading_pool = None
+
+ if self.using_pool:
+ self.img_reading_pool = ThreadPoolExecutor(max_workers=pool_size)
+
+ self.face_det_results = {}
+ self.w2l_models = {}
+ self.seg_net = None # 分割模型
+
+ mel_block_size = (config.wav2lip_batch_size, 80, 16)
+ audio_block_size = (np.prod(mel_block_size).tolist(), 1)
+ # 缓存区最大,缓存两秒的视频,超出的会被覆盖掉
+ self.video_cache = Video_Cache(
+ "video_cache",
+ "audio_audio_cache",
+ "speech_id_cache",
+ self.resolution,
+ channel=self.channel,
+ audio_block_size=audio_block_size[0],
+ create=True,
+ )
+
+ self.data_cache = Infer_Cache(
+ "infer_cache",
+ create=True,
+ mel_block_shape=mel_block_size,
+ audio_block_shape=audio_block_size,
+ )
+ # TODO: 这个verision 不生效
+ self.version = "v2"
+ self.current_speechid = ""
+ # torch.Size([5, 288, 288, 6]) torch.Size([5, 80, 16, 1])
+ # torch.Size([5, 288, 288, 6]) torch.Size([5, 80, 16, 1])
+ self.model_mel_batch = torch.rand((5, 1, 80, 16), dtype=torch.float).to(
+ self.device
+ )
+ self.model_img_batch = torch.rand((5, 6, 288, 288), dtype=torch.float).to(
+ self.device
+ )
+ self.infer_silent_idx = 1
+ self.post_morphint_idx = 1
+ self.first_start = True # 判断是否为开播瞬间的静默
+ self.file_index = 1
+
+ def prepare(self, model_speaker_id: str, need_update: bool = True):
+ """加载好模型文件"""
+ if not need_update:
+ return
+ self.model_url, self.frame_config, self.padding,self.face_classes,self.trans_method, self.infer_silent_num, self.morphing_num, pkl_version = load_config(model_speaker_id)
+ models_url_list = [self.model_url]
+ actor_info = [model_speaker_id]
+ for actor_url in actor_info:
+ # 这里不再下载资源,相关工作移到主进程
+ actor_id = get_actor_id(actor_url)
+ self.face_det_results[actor_id] = load_face_box(actor_id, pkl_version)
+ self.w2l_process_state = "READY"
+ self.__init_seg_model(f"{config.model_folder}/79999_iter.pth")
+ # 预先加载所有model 模型
+ # TODO: 采用旧的加载方式看下模型model 是否有错
+ for model_url in models_url_list:
+ w2l_name = get_model_local_path(model_url)
+ self.__init_w2l_model(w2l_name)
+ logger.debug("Prepare w2l data.")
+
+ def __init_w2l_model(self, w2l_name):
+ """ """
+ # TODO 这个最好是可以放在启动直播间之前----本模型运行时可能会切换,目前咱们不同的模特可能模型不同
+ if w2l_name in self.w2l_models:
+ return
+
+ w2l_path = f"{config.model_folder}/{w2l_name}"
+ if USING_JIT:
+ net = self.__jit_script(f"{config.model_folder}/w2l.pt", model=Wav2Lip())
+ else:
+ net = Wav2Lip()
+ logger.debug(f"加载模型 {w2l_name}")
+ self.w2l_models[w2l_name] = load_w2l_model(w2l_path, net, self.device)
+
+ def __jit_script(self, jit_module_path: str, model):
+ """
+ convert pytorch module to a runnable scripts.
+ """
+ if not os.path.exists(jit_module_path):
+ scripted_module = torch.jit.script(model)
+ torch.jit.save(scripted_module, jit_module_path)
+ net = torch.jit.load(jit_module_path)
+ return net
+
+ def __init_seg_model(self, seg_path: str):
+ """ """
+ # 本模型可以视作永不改变
+ n_classes = 19
+ net = BiSeNet(n_classes=n_classes)
+ if USING_JIT:
+ net = self.__jit_script(f"{config.model_folder}/swap.pt", model=net)
+ self.seg_net = load_seg_model(seg_path, net, self.device)
+
+ def __using_w2lmodel(self, model_url):
+ """ """
+ # 取出对应的模型
+ model_path = get_model_local_path(model_url)
+ if model_path in self.w2l_models:
+ return self.w2l_models[model_path]
+ else:
+ self.__init_w2l_model(model_path)
+ return self.w2l_models[model_path]
+
+ def infer(
+ self,
+ is_silent,
+ inference_id,
+ pkl_name,
+ startfrom=0,
+ last_direction=1,
+ duration=0,
+ end_with_silent=False,
+ first_speak=False,
+ last_speak=False,
+ before_speak=False,
+ model_url=None,
+ debug: bool = False,
+ test_data: tuple = (),
+ video_idx: list = [],
+ ):
+ """ """
+ start = time.perf_counter()
+ # TODO: 这是啥逻辑
+ is_silent_ = 0 if duration < 5 else 1
+ if not model_url:
+ model_url = self.model_url
+ model = self.__using_w2lmodel(model_url)
+ # 控制加载数据人的个数
+ self.current_speechid = inference_id
+ if debug:
+ audio_segments, mel_chunks = test_data
+ else:
+ audio_segments, mel_chunks = self.data_cache.get_data()
+ audio_frame_length = len(mel_chunks)
+
+ if video_idx:
+ mel_chunks, start_idx_list, current_direction = mel_chunks, video_idx, 1
+ else:
+ startfrom = startfrom if startfrom>= self.frame_config[0][0] else self.frame_config[0][0]
+ start_idx_list, current_direction = play_in_loop_v2(
+ self.frame_config,
+ startfrom,
+ audio_frame_length,
+ last_direction,
+ is_silent,
+ is_silent_,
+ first_speak,
+ last_speak,
+ )
+ logger.info(
+ f"start_idx_list :{start_idx_list}, speechid: {self.current_speechid}"
+ )
+ gan = datagen(
+ self.img_reading_pool,
+ mel_chunks,
+ start_idx_list,
+ pkl_name,
+ self.face_det_results[pkl_name],
+ self.padding,
+ is_silent,
+ )
+ self.file_index += 5
+ # 3.4.0 增加静默推理过渡融合, 说话后的n帧向静默过渡
+ need_pre_morphing=False
+ if self.trans_method == constants.TRANS_METHOD.mophing_infer:
+ if self.infer_silent_num or self.morphing_num or config.PRE_MORPHING_NUM:
+ logger.info(f"self.first_start{self.first_start}, {before_speak}, {self.current_speechid}")
+ if (
+ is_silent == 0
+ or (
+ before_speak)
+ or (
+ not self.first_start
+ and is_silent == 1
+ and ((self.infer_silent_idx < self.infer_silent_num) or (self.post_morphint_idx < self.morphing_num))
+ )
+ ):
+ # 开始说话状态时候,要把上次计数清理
+ need_infer_silent = (
+ True
+ if is_silent == 1
+ and self.infer_silent_idx < self.infer_silent_num and not before_speak and "&" not in self.current_speechid
+ else False
+ )
+ need_pre_morphing = True if before_speak else False
+ need_post_morphing = (
+ True
+ if is_silent == 1
+ and self.morphing_num
+ and self.post_morphint_idx < self.morphing_num
+ and not need_infer_silent
+ and not need_pre_morphing
+ else False
+ )
+
+ if is_silent == 0 or "&" in self.current_speechid:
+ self.infer_silent_idx = 1
+ self.post_morphint_idx = 1
+
+ self.first_start = False
+ logger.debug(
+ f"{start_idx_list} is_silent:{is_silent}, infer_silent_idx: {self.infer_silent_idx} morphint_idx:{self.post_morphint_idx} need_post_morphing: {need_post_morphing}, need_infer_silent:{need_infer_silent}, speech:{self.current_speechid},need_pre_morphing:{need_pre_morphing}"
+ )
+ if is_silent == 1 and need_post_morphing and self.post_morphint_idx < self.morphing_num:
+ self.post_morphint_idx += 5
+ result_frames = self.prediction(
+ gan,
+ model,
+ pkl_name,
+ need_post_morphing=need_post_morphing,
+ need_infer_silent=need_infer_silent,
+ is_silent=is_silent,
+ need_pre_morphing=need_pre_morphing,
+ )
+ if is_silent == 1 and need_infer_silent and self.infer_silent_idx < self.infer_silent_num:
+ self.infer_silent_idx += 5
+
+ # 第一次静默时候,直接进入这个逻辑
+ else:
+ self.no_prediction(gan, model, pkl_name)
+ else:
+ if is_silent==0:
+ self.prediction(
+ gan,
+ model,
+ pkl_name,
+ is_silent=is_silent,
+ need_pre_morphing=need_pre_morphing,
+ )
+ else:
+ self.no_prediction(gan, model, pkl_name)
+ elif self.trans_method == constants.TRANS_METHOD.all_infer:
+ result_frames = self.prediction(gan, model, pkl_name,is_silent=is_silent)
+ else:
+ raise ValueError(f"not supported {self.trans_method}")
+
+ if LOCAL_MODE:
+ saveThread = threading.Thread(
+ target=save_,
+ args=(self.current_speechid, result_frames, audio_segments),
+ )
+ saveThread.start()
+ self.video_cache.put_audio(audio_segments, self.current_speechid)
+
+ logger.debug(
+ f"视频合成时间{time.perf_counter() - start},inference_id:{self.current_speechid}"
+ )
+ return [
+ [0, 1, 2, 3, 4],
+ start_idx_list[-1],
+ current_direction,
+ ] # frames_return_list: 视频帧数据 res_index: 控制下一帧的开始位置 direction: 播放的顺序 正 反
+
+ @torch.no_grad()
+ def batch_to_tensor(self, img_batch, mel_batch, model, padding, frames, coords):
+ padding = 10
+ logger.debug(f"非静默,推理: {self.current_speechid} 准备推理数据 {id(model)}")
+ img_batch_tensor = torch.as_tensor(img_batch, dtype=torch.float32).to(
+ self.device, non_blocking=True
+ )
+ mel_batch_tensor = torch.as_tensor(mel_batch, dtype=torch.float32).to(
+ self.device, non_blocking=True
+ )
+
+ img_batch_tensor = img_batch_tensor.permute(0, 3, 1, 2)
+ mel_batch_tensor = mel_batch_tensor.permute(0, 3, 1, 2)
+ logger.debug(
+ f"非静默,推理: {self.current_speechid} 即将嘴型生成 {img_batch_tensor.shape} {mel_batch_tensor.shape}"
+ ) # torch.Size([5, 6, 288, 288]) torch.Size([5, 1, 80, 16])
+
+ pred_batch = model(mel_batch_tensor, img_batch_tensor) * 255.0
+
+ logger.debug(f"非静默,推理: {self.current_speechid} 嘴型生成完成")
+ pred_clone_batch = pred_batch.clone()
+ pred_batch_cpu = pred_batch.to(torch.uint8).to("cpu", non_blocking=True)
+ del pred_batch
+ del mel_batch_tensor
+ logger.debug(f"非静默,推理: {self.current_speechid} 生成面部准备好了")
+ large_faces = [
+ frame.copy()[:, :, :3][
+ y1 - padding : y2 + padding, x1 - padding : x2 + padding
+ ]
+ for frame, box in zip(frames, coords)
+ for y1, y2, x1, x2 in [box]
+ ]
+ logger.debug(f"非静默,推理: {self.current_speechid} 即将对生成的面部进行分割")
+ seg_out = image_to_parsing_ori(
+ pred_clone_batch, self.seg_net
+ ) # 假设这个函数在 GPU 上执行并输出 GPU tensor
+ del img_batch_tensor
+ del pred_clone_batch
+
+ torch.cuda.synchronize() # 等待异步结束
+ seg_out_cpu = seg_out.to(
+ "cpu", non_blocking=True
+ ) # 发起 seg_out 的异步数据传输
+ del seg_out
+ logger.debug(f"非静默,推理: {self.current_speechid} 对生成的面部进行分割结束")
+ pred_batch_cpu = pred_batch_cpu.numpy().transpose(
+ 0, 2, 3, 1
+ ) # 按需计算,用到了才计算,这样能不能使得GPU往cpu做数据copy的时间被隐藏?
+ for predict, frame, box in zip(pred_batch_cpu, frames, coords):
+ y1, y2, x1, x2 = box
+ width, height = x2 - x1, y2 - y1
+ frame[:, :, :3][y1:y2, x1:x2] = cv2.resize(
+ predict, (width, height)
+ ) # 无需再转为unit8,gpu上直接转为unit8,这样传输规模小一些
+ torch.cuda.synchronize() # 等待异步结束
+ return large_faces, seg_out_cpu, frames
+
+ def no_prediction(self, gan, model, pkl_name):
+ logger.debug(
+ f"无需推理: {self.current_speechid} {self.model_img_batch.shape} {self.model_mel_batch.shape}"
+ )
+ with torch.no_grad():
+ pred_ = model(self.model_mel_batch, self.model_img_batch)
+ seg_out = image_to_parsing_ori(
+ pred_, self.seg_net
+ ) # 不确定是不是需要也warmup第二个网络
+ for _, _, frames, full_masks, _, end_idx in gan:
+ speech_ids = [
+ f"{self.current_speechid}_{idx}" for idx in range(len(end_idx))
+ ]
+ offset_list = self.video_cache.occupy_video_pos(len(end_idx))
+ file_idxes, _, _ = get_trans_idxes(False, False, 0,0, self.file_index)
+ logger.info(f"self.file_index:{self.file_index}, file_idxes:{file_idxes}")
+ param_list = [
+ (
+ speech_id,
+ frame,
+ body_mask,
+ write_pos,
+ pkl_name,
+ self.video_cache,
+ file_idx
+ )
+ for speech_id, frame, body_mask, write_pos, file_idx in zip(
+ speech_ids,
+ frames,
+ full_masks,
+ offset_list,
+ file_idxes
+ )
+ ]
+ if self.using_pool:
+ futures = [
+ self.img_reading_pool.submit(self.silent_paste_back, *param)
+ for param in param_list
+ ]
+ _ = [future.result() for future in futures]
+ self.video_cache._inject_real_pos(len(end_idx))
+ # if config.debug_mode:
+ # for param in param_list:
+ # self.silent_paste_back(*param)
+ del pred_
+ del seg_out
+ torch.cuda.empty_cache()
+
+
+ def silent_paste_back(self, speech_id, frame, body_mask,write_pos,pkl_name, video_cache, file_index):
+ global index
+ if frame.shape[-1] == 4:
+ frame[:, :, :3] = frame[:, :, :3][:, :, ::-1]
+ elif config.output_alpha:
+ frame = add_alpha(frame, body_mask, config.alpha)
+ logger.debug(f"self.file_index:{self.file_index}, file_index:{file_index}")
+ if config.debug_mode:
+ logger.info(f"不用推理: {file_index} {frame.shape}")
+ if not cv2.imwrite(f"temp/{pkl_name}/new_img{file_index:05d}.jpg", frame):
+ logger.error(f"save {file_index} err")
+ index += 1
+ logger.info(f"silent frame shape:{frame.shape}")
+ video_cache._put_raw_frame(frame, write_pos, speech_id)
+ return frame
+
+
+ def prediction(
+ self,
+ gan,
+ model,
+ pkl_name: str,
+ need_post_morphing: bool = False,
+ need_infer_silent: bool = False,
+ is_silent: int = 1,
+ need_pre_morphing: bool = False,
+ ):
+ """ """
+ logger.debug(f"非静默,推理: {self.current_speechid}")
+ for img_batch, mel_batch, frames, full_masks, coords, end_idx in gan:
+ large_faces, seg_out_cpu, frames = self.batch_to_tensor(
+ img_batch, mel_batch, model, self.padding, frames, coords
+ )
+ offset_list = self.video_cache.occupy_video_pos(len(end_idx))
+ speech_ids = [
+ f"{self.current_speechid}_{idx}" for idx in range(len(end_idx))
+ ]
+ file_idxes, post_morphint_idxes, infer_silent_idxes = get_trans_idxes(need_post_morphing, need_infer_silent,self.post_morphint_idx,self.infer_silent_idx, self.file_index)
+ logger.info(f"self.file_index:{self.file_index},infer_silent_idxes:{infer_silent_idxes},post_morphint_idxes:{post_morphint_idxes}, file_idxes:{file_idxes}")
+ param_list = [
+ (
+ speech_id,
+ padded_image,
+ frame,
+ body_mask,
+ seg_mask,
+ boxes,
+ self.padding,
+ self.video_cache,
+ write_pos,
+ pkl_name,
+ need_infer_silent,
+ need_post_morphing,
+ is_silent,
+ need_pre_morphing,
+ pre_morphing_idx,
+ infer_silent_idx,
+ post_morphint_idx,
+ file_idx
+
+ )
+ for speech_id, padded_image, frame, body_mask, seg_mask, boxes, write_pos, pre_morphing_idx,infer_silent_idx, post_morphint_idx,file_idx in zip(
+ speech_ids,
+ large_faces,
+ frames,
+ full_masks,
+ seg_out_cpu,
+ coords,
+ offset_list,
+ list(range(config.PRE_MORPHING_NUM)),
+ infer_silent_idxes,
+ post_morphint_idxes,
+ file_idxes
+ )
+ ]
+ result_frames = []
+ if self.using_pool:# and not config.debug_mode:
+ futures = [
+ self.img_reading_pool.submit(self.paste_back, *param)
+ for param in param_list
+ ]
+ _ = [future.result() for future in futures]
+ self.video_cache._inject_real_pos(len(end_idx))
+ # if config.debug_mode:
+ # for param in param_list:
+ # frame = self.paste_back(*param)
+ # if LOCAL_MODE:
+ # result_frames.append(frame)
+ return result_frames
+
+ def paste_back(
+ self,
+ current_speechid,
+ large_face,
+ frame,
+ body_mask,
+ mask,
+ boxes,
+ padding,
+ video_cache,
+ cache_start_offset,
+ pkl_name,
+ need_infer_silent: bool = False,
+ need_post_morphing: bool = False,
+ is_silent: bool = False,
+ need_pre_morphing: bool = False,
+ pre_morphing_idx: int = 0,
+ infer_silent_idx: int = 0,
+ post_morphint_idx: int = 0,
+ file_index: int = 1,
+ ):
+ """
+ 根据模型预测结果和原始图片进行融合
+ Args:
+ """
+ global index
+ padding = 10
+ y1, y2, x1, x2 = boxes
+ width, height = x2 - x1, y2 - y1
+ mask = cal_mask_single_img(
+ mask, use_old_mode=True, face_classes=self.face_classes
+ )
+ mask = np.repeat(mask[..., None], 3, axis=-1).astype("uint8")
+
+ mask_temp = np.zeros_like(large_face)
+ mask_out = cv2.resize(mask.astype(np.float) * 255.0, (width, height)).astype(
+ np.uint8
+ )
+ mask_temp[padding : height + padding, padding : width + padding] = mask_out
+ kernel = np.ones((9, 9), np.uint8)
+ mask_temp = cv2.erode(mask_temp, kernel, iterations=1) # 二值的
+ # gaosi_kernel = int(0.1 * large_face.shape[0] // 2 * 2) + 1
+ # mask_temp = cv2.GaussianBlur(
+ # mask_temp, (gaosi_kernel, gaosi_kernel), 0, 0, cv2.BORDER_DEFAULT
+ # )
+ mask_temp = cv2.GaussianBlur(mask_temp, (15, 15), 0, 0, cv2.BORDER_DEFAULT)
+ mask_temp = cv2.GaussianBlur(mask_temp, (5, 5), 0, 0, cv2.BORDER_DEFAULT)
+
+ f_background = large_face.copy()
+
+ frame[:, :, :3][y1 - padding : y2 + padding, x1 - padding : x2 + padding] = (
+ f_background * (1 - mask_temp / 255.0)
+ + frame[:, :, :3][y1 - padding : y2 + padding, x1 - padding : x2 + padding]
+ * (mask_temp / 255.0)
+ )#.astype("uint8")
+ if self.trans_method == constants.TRANS_METHOD.mophing_infer:
+ if is_silent == 1:
+ if need_pre_morphing:
+ frame = morphing(
+ large_face,
+ frame,
+ boxes,
+ mp_ratio=1 - ((pre_morphing_idx + 1) / config.PRE_MORPHING_NUM),
+ file_index=file_index,
+ )
+ logger.debug(
+ f"file_index:{file_index},pre morphing_idx {pre_morphing_idx}, speech_id:{current_speechid}"
+ )
+ logger.debug(f"pre morphing_idx {pre_morphing_idx}, speech_id:{current_speechid}")
+ #TODO: @txueduo 处理前过渡问题
+ elif need_post_morphing and post_morphint_idx:
+ mp_ratio = (post_morphint_idx) / self.morphing_num
+ frame = morphing(
+ large_face,
+ frame,
+ boxes,
+ mp_ratio=mp_ratio,
+ file_index=file_index
+ )
+ logger.debug(f"post_morphint_idx:{post_morphint_idx}, mp_ratio:{mp_ratio}, file_index:{file_index}, speech_id:{current_speechid}")
+ if frame.shape[-1]==4:# and not config.output_alpha:
+ frame[:,:,:3] = frame[:,:,:3][:,:,::-1]
+ if config.output_alpha and frame.shape[-1]!=4:
+ frame = add_alpha(frame, body_mask, config.alpha)
+ if config.debug_mode:
+ logger.info(f"推理:{file_index}")
+ if not cv2.imwrite(f"temp/{pkl_name}/new_img{file_index:05d}.jpg", frame):
+ logger.error(f"save {file_index} err")
+ video_cache._put_raw_frame(frame, cache_start_offset, current_speechid)
+ index += 1
+ return frame
+
+ def destroy(self):
+ """ """
+ self.data_cache.destroy()
+ self.video_cache.destroy()
+ self.data_cache = None
+ self.video_cache = None
+ if self.using_pool and self.img_reading_pool is not None:
+ self.img_reading_pool.shutdown()
+ self.img_reading_pool = None
+ del self.model_img_batch
+ del self.model_mel_batch
+
+
+def process_wav2lip_predict(task_queue, stop_event, output_queue, resolution, channel):
+ # 实例化推理类
+ w2l_processor = Wav2lip_Processor(
+ task_queue, output_queue, stop_event, resolution, channel
+ )
+ need_update = True
+ while True:
+ if stop_event.is_set():
+ print("----------------------stop..")
+ w2l_processor.destroy() # 这一行代码可能有bug,没有同步的控制代码,可能在读的地方会有异常
+ break
+ try:
+ params = task_queue.get()
+ # TODO: 这个操作应该放在更前面, 确认是否更新的判定条件
+ start = time.perf_counter()
+ w2l_processor.prepare(params[2], need_update)
+ need_update = False
+ result = w2l_processor.infer(*params)
+ logger.debug(
+ f"推理结束 :{time.perf_counter() - start},inference_id:{params[1]}"
+ )
+ output_queue.put(result)
+ logger.debug(
+ f"结果通知到主进程{time.perf_counter() - start},inference_id:{params[1]}"
+ )
+ except Exception:
+ logger.error(f"process_wav2lip_predict :{traceback.format_exc()}")