diff --git a/utils/log.py b/utils/log.py
new file mode 100644
index 0000000..a531729
--- /dev/null
+++ b/utils/log.py
@@ -0,0 +1,119 @@
+import json
+import logging
+import os
+import sys
+
+import redis
+import redis.exceptions
+from loguru import logger as logurulogger
+from redis.backoff import ExponentialBackoff
+from redis.retry import Retry
+
+from app.config import config
+
+LOG_FORMAT = (
+    "{level: <8} "
+    "{process.name} | "  # process name
+    "{thread.name} | "
+    "{time:YYYY-MM-DD HH:mm:ss.SSS} - "
+    "{process} "
+    "{module}.{function}:{line} - "
+    "{message}"
+)
+# Standard-library loggers whose handlers are replaced by InterceptHandler below.
+LOG_NAME = ["uvicorn", "uvicorn.access", "uvicorn.error", "flask"]
+
+# Redis connection pool used by the Redis log sink.
+redis_pool = redis.ConnectionPool(
+    host=config.LOG_REDIS_HOST,  # Redis server host
+    port=config.LOG_REDIS_PORT,  # Redis server port
+    db=config.LOG_REDIS_DB,  # database index
+    password=config.LOG_REDIS_AUTH,  # password
+    max_connections=config.max_connections,  # maximum pool size
+    socket_connect_timeout=config.socket_connect_timeout,  # connect timeout (seconds)
+    socket_timeout=config.socket_timeout,  # read/write timeout (seconds)
+)
+
+
+class InterceptHandler(logging.Handler):
+    """Forward records from the standard logging module to loguru."""
+
+    def emit(self, record):
+        try:
+            level = logurulogger.level(record.levelname).name
+        except ValueError:
+            # Unknown level name: fall back to the numeric level, which loguru accepts.
+            level = record.levelno
+
+        frame, depth = logging.currentframe(), 2
+        while frame.f_code.co_filename == logging.__file__:
+            frame = frame.f_back
+            depth += 1
+
+        logurulogger.opt(depth=depth, exception=record.exc_info).log(
+            level, record.getMessage()
+        )
+
+
+class Logging:
+    """Application logger built on loguru, with file, stdout and Redis sinks."""
+
+    def __init__(self):
+        self.log_path = "logs"
+        self._connect_redis()
+        if config.IS_LOCAL:
+            os.makedirs(self.log_path, exist_ok=True)
+        self._initlogger()
+        self._reset_log_handler()
+
+    def _connect_redis(self):
+        retry = Retry(ExponentialBackoff(), 3)  # retry 3 times with exponential backoff
+        self.redis_client = redis.Redis(connection_pool=redis_pool, retry=retry)  # reuse the shared pool
+
+    def _initlogger(self):
+        """Configure the loguru sinks."""
+        logurulogger.remove()
+        if config.IS_LOCAL:
+            logurulogger.add(
+                os.path.join(self.log_path, "error.log.{time:YYYY-MM-DD}"),
+                format=LOG_FORMAT,
+                level=logging.ERROR,
+                rotation="00:00",
+                retention="1 week",
+                backtrace=True,
+                diagnose=True,
+                enqueue=True,
+            )
+            logurulogger.add(
+                os.path.join(self.log_path, "info.log.{time:YYYY-MM-DD}"),
+                format=LOG_FORMAT,
+                level=logging.INFO,
+                rotation="00:00",
+                retention="1 week",
+                enqueue=True,
+            )
+        logurulogger.add(
+            sys.stdout,
+            format=LOG_FORMAT,
+            level=logging.DEBUG,
+            colorize=True,
+        )
+
+        # Send INFO+ records to Redis, but skip records flagged by the sink itself so a
+        # Redis outage cannot trigger recursive error logging through the same sink.
+        logurulogger.add(
+            self._log_to_redis,
+            level="INFO",
+            format=LOG_FORMAT,
+            filter=lambda record: not record["extra"].get("skip_redis"),
+        )
+        self.logger = logurulogger
+
+    def _log_to_redis(self, message):
+        """Append the formatted log record to a Redis list."""
+        try:
+            self.redis_client.rpush(
+                f"nlp.logger.{config.env_version}.log", json.dumps({"message": message})
+            )
+        except redis.exceptions.ConnectionError as e:
+            logurulogger.bind(skip_redis=True).error(f"write {message} Redis connection error: {e}")
+        except redis.exceptions.TimeoutError as e:
+            logurulogger.bind(skip_redis=True).error(f"write {message} Redis operation timed out: {e}")
+        except Exception as e:
+            logurulogger.bind(skip_redis=True).error(f"write {message} Unexpected error: {e}")
+
+    def _reset_log_handler(self):
+        for log in LOG_NAME:
+            std_logger = logging.getLogger(log)
+            std_logger.handlers = [InterceptHandler()]
+
+    def getlogger(self):
+        return self.logger
+
+
+logger = Logging().getlogger()
+
diff --git a/utils/loop_frame_tool.py b/utils/loop_frame_tool.py
new file mode 100644
index 0000000..1234644
--- /dev/null
+++ b/utils/loop_frame_tool.py
@@ -0,0 +1,334 @@
+from app.config import config  # provides batch_size; import path assumed to mirror utils/log.py
+from utils.log import logger
+
+
+def play_in_loop_v2(
+    segments,
+    startfrom,
+    batch_num,
+    last_direction,
+    is_silent,
+    first_speak,
+    last_speak,
+):
+    """
+    Pick the next batch of frame indexes for loop playback.
+
+    batch_num is re-evaluated for every frame, at both the start and the end of a batch:
+    1. Silent: loop inside the silent segment - forward at the left boundary, backward at the
+       right boundary - deriving the new direction and position from the previous ones.
+    2. Silent -> speaking: move to the nearest speaking (action) segment; pre_flag / post_flag
+       may both be true, or only one of them.
+    3. Speaking -> silent: finish the action segment before entering silence (while speech
+       continues, the silent segment does not loop).
+    4. At the left end of the whole video: only forward is possible - loop when silent,
+       otherwise apply rule 2.
+    5. At the right end of the whole video: only backward is possible - loop when silent,
+       otherwise apply rule 2.
+    6. Collect batch_num frame indexes along the chosen direction and return
+       (batch_idxes, current_direction).
+    Args:
+        segments: loop-frame config, [[st, ed, loopable], ...]
+        startfrom: current position (cur_pos)
+        batch_num: number of frames to pick, e.g. 5
+        last_direction: 0 backward, 1 forward
+        is_silent: 0 speaking, 1 silent (idle motion)
+        first_speak: whether this is the first speaking batch
+        last_speak: whether this batch ends the speech
+    """
+    frames = []
+    cur_pos = startfrom
+    cur_direction = last_direction
+    is_first_speak_frame = first_speak
+    is_last_speak_frame = True if last_speak and batch_num == 1 else False
+    while batch_num != 0:
+        # Find the sub-segment that contains the current frame
+        sub_seg_idx = subseg_judge(cur_pos, segments)
+        # Decide the movement direction
+        next_direction, next_pos = get_next_direction(
+            segments,
+            cur_pos,
+            cur_direction,
+            is_silent,
+            sub_seg_idx,
+            is_first_speak_frame,
+            is_last_speak_frame,
+        )
+        # Step one frame in that direction
+        next_pos = get_next_frame(next_pos, next_direction)
+        frames.append(next_pos)
+        batch_num -= 1
+        is_first_speak_frame = (
+            True if first_speak and batch_num == config.batch_size else False
+        )
+        is_last_speak_frame = True if last_speak and batch_num == 1 else False
+
+        cur_direction = next_direction
+        cur_pos = next_pos
+    return frames, next_direction
+
+
+def subseg_judge(cur_pos, segments):
+    """Return the index of the segment that contains cur_pos (0 as a fallback for position 0)."""
+    for idx, frame_seg in enumerate(segments):
+        if frame_seg[0] <= cur_pos <= frame_seg[1]:
+            return idx
+    if cur_pos == 0:
+        return 0
+
+
+def get_next_direction(
+    segments,
+    cur_pos,
+    cur_direction,
+    is_silent,
+    sub_seg_idx,
+    is_first_speak_frame: bool = False,
+    is_last_speak_frame: bool = False,
+):
+    """
+    Loop-frame requirement (v3.3.0): reach the desired playback state as quickly as possible.
+
+    Decision tree: first split on the segment type (loopable/silent vs. action), then on the
+    speech state (starting to speak, silent, still speaking, finished speaking); in each case
+    the direction depends on whether cur_pos sits on a boundary and on the previous playback
+    direction.
+    Args:
+        is_first_speak_frame: flag for the first frame of speech
+        is_last_speak_frame: flag for the end of speech
+    """
+    left, right, loop_flag = segments[sub_seg_idx]
+    if loop_flag:
+        if is_silent == 1:
+            next_direct, next_pos = pure_silent(
+                segments, left, right, cur_pos, cur_direction, sub_seg_idx
+            )
+            logger.debug(
+                f"cur_pos:{cur_pos}, next_direct:{next_direct}, is_first_speak_frame:{is_first_speak_frame}"
+            )
+        elif is_silent == 0:
+            next_direct, next_pos = silent2action(
+                segments,
+                left,
+                right,
+                cur_pos,
+                cur_direction,
+                sub_seg_idx,
+                is_first_speak_frame,
+            )
+            logger.debug(
+                f"cur_pos:{cur_pos}, next_direct:{next_direct}, is_first_speak_frame:{is_first_speak_frame}"
+            )
+    else:
+        if is_silent == 1:
+            next_direct, next_pos = action2silent(
+                segments,
+                left,
+                right,
+                cur_pos,
+                cur_direction,
+                sub_seg_idx,
+                is_last_speak_frame,
+            )
+            logger.debug(
+                f"cur_pos:{cur_pos}, next_direct:{next_direct}, is_first_speak_frame:{is_first_speak_frame}, is_last_speak_frame:{is_last_speak_frame}"
+            )
+        elif is_silent == 0:
+            next_direct, next_pos = pure_action(
+                segments,
+                left,
+                right,
+                cur_pos,
+                cur_direction,
+                sub_seg_idx,
+                is_last_speak_frame,
+            )
+            logger.debug(
+                f"cur_pos:{cur_pos}, next_direct:{next_direct}, is_first_speak_frame:{is_first_speak_frame}, is_last_speak_frame:{is_last_speak_frame}"
+            )
+    return next_direct, next_pos
+
+
+def get_next_frame(cur_pos, cur_direction):
+    """Step one frame in the given direction.
+
+    Callers guarantee that a frame exists in that direction, so no extra boundary check is done here.
+    """
+    # forward
+    if cur_direction == 1:
+        return cur_pos + 1
+    # backward
+    elif cur_direction == 0:
+        return cur_pos - 1
+
+
+def pure_silent(segments, left, right, cur_pos, cur_direction, sub_seg_idx):
+    """
+    Case: loop_flag == True and
+    is_silent == 1 (silent playback inside a loopable segment).
+
+    The result depends on whether cur_pos sits on a boundary and on the previous direction.
+    Return:
+        (next_direction, next_pos)
+    """
+    # Forward at the global left boundary, backward at the global right boundary
+    if cur_pos == segments[0][0]:
+        return 1, cur_pos
+    if cur_pos == segments[-1][1]:
+        return 0, cur_pos
+    # Right boundary of this segment: go backward
+    if cur_pos == right:
+        return 0, cur_pos
+    # Left boundary of this segment: go forward
+    if cur_pos == left:
+        return 1, cur_pos
+    # Not at a boundary: keep going forward if we were, otherwise go backward
+    if cur_pos > left and cur_direction == 1:
+        return 1, cur_pos
+    else:
+        return 0, cur_pos
+
+
+def pure_action(
+    segments, left, right, cur_pos, cur_direction, sub_seg_idx, is_last_speak_frame
+):
+    """
+    Case: loop_flag == False and is_silent == 0 (speaking inside an action segment).
+
+    Play the action segment to the end, then move forward into a silent segment (this may
+    jump across segments). The result depends on boundaries and on the playback direction.
+    Args:
+        is_last_speak_frame: set at the moment speech ends
+    Return:
+        (next_direction, next_pos)
+    """
+    if cur_pos == segments[0][0]:
+        return 1, cur_pos
+    if cur_pos == segments[-1][1]:
+        return 0, cur_pos
+
+    if is_last_speak_frame:
+        # Action segment at the very end of the video: look backward for silence
+        if sub_seg_idx == len(segments) - 1:
+            return 0, cur_pos
+        # Action segment at the very beginning: go forward
+        if sub_seg_idx == 0:
+            return 1, cur_pos
+        # Action segment in the middle: move toward the nearer side
+        mid = left + (right - left + 1) // 2
+        # nearest-side rule takes priority
+        if cur_pos < mid:
+            return 0, cur_pos
+        else:
+            return 1, cur_pos
+
+    else:
+        # Otherwise keep the previous playback direction
+        if cur_direction == 1:
+            return 1, cur_pos
+        else:
+            return 0, cur_pos
+
+
+def silent2action(
+    segments,
+    left,
+    right,
+    cur_pos,
+    cur_direction,
+    sub_seg_idx,
+    is_first_speak_frame: bool = False,
+):
+    """
+    Case: loop_flag == True and is_silent == 0 (speaking while inside a silent, loopable segment).
+
+    The result depends on boundaries and on the playback direction.
+    Return:
+        (next_direction, next_pos)
+    """
+    # Move toward the nearest action segment (there may be none on the left).
+    # TODO: verify the logic below
+    if (
+        cur_pos == segments[0][0]
+    ):  # after a jump, keep moving forward in the new segment whether or not it is an action segment
+        return 1, cur_pos
+    if cur_pos == segments[-1][1]:
+        return 0, cur_pos
+    # At the left boundary of the silent segment while still speaking
+    if cur_pos == left:
+        if cur_direction == 1:
+            return 1, cur_pos
+        else:
+            return 0, cur_pos
+    # At the right boundary of the silent segment while still speaking
+    elif cur_pos == right:
+        if cur_direction == 1:
+            return 1, cur_pos
+        else:
+            return 0, cur_pos
+    else:
+        mid = left + (right - left + 1) // 2
+        # NOTE: the nearest-side rule only applies to the first speaking batch;
+        # otherwise follow the previous direction.
+        if is_first_speak_frame:
+            # first segment
+            if sub_seg_idx == 0 and segments[0][2]:
+                return 1, cur_pos
+            # last segment
+            elif sub_seg_idx == len(segments) - 1 and segments[-1][2]:
+                return 0, cur_pos
+
+            if cur_pos < mid:
+                return 0, cur_pos
+            else:
+                return 1, cur_pos
+        else:
+            if cur_direction == 1:
+                return 1, cur_pos
+            elif cur_direction == 0:
+                return 0, cur_pos
+
+
+def action2silent(
+    segments, left, right, cur_pos, cur_direction, sub_seg_idx, is_last_speak_frame
+):
+    """
+    Case: loop_flag == False and is_silent == 1 (silent while inside an action segment).
+
+    The result depends on boundaries.
+    Return:
+        (next_direction, next_pos)
+    """
+    if cur_pos == segments[0][0]:
+        return 1, cur_pos
+    if cur_pos == segments[-1][1]:
+        return 0, cur_pos
+    # Speech just ended inside an action segment: move toward the nearer side to reach silence
+    if is_last_speak_frame:
+        mid = left + (right - left + 1) // 2
+        if cur_pos < mid:
+            return 0, cur_pos
+        else:
+            return 1, cur_pos
+
+    else:
+        if cur_direction == 1:
+            return 1, cur_pos
+        else:
+            return 0, cur_pos
+
+
+if __name__ == "__main__":
+    # Minimal smoke test; the segment table below is an illustrative placeholder.
+    startfrom = 0  # last frame of the previous batch
+    frame_config = [[0, 49, True], [50, 99, False], [100, 149, True]]  # [[st, ed, loopable], ...]
+    audio_frame_length = 5  # number of mel chunks in the batch (TODO: confirm it equals batch_size)
+    startfrom = startfrom if startfrom >= frame_config[0][0] else frame_config[0][0]
+    first_speak, last_speak = True, False
+    is_silent = 1  # whether the current batch is silent (1 silent, 0 speaking)
+    last_direction = 1  # 1 forward, 0 backward
+    start_idx_list, current_direction = play_in_loop_v2(
+        frame_config,
+        startfrom,
+        audio_frame_length,
+        last_direction,
+        is_silent,
+        first_speak,
+        last_speak,
+    )
+    logger.info(f"start_idx_list:{start_idx_list}, current_direction:{current_direction}")
\ No newline at end of file
diff --git a/utils/wav2lip_processor.py b/utils/wav2lip_processor.py
new file mode 100644
index 0000000..03dc5c9
--- /dev/null
+++ 
b/utils/wav2lip_processor.py @@ -0,0 +1,669 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +@File : self.py +@Time : 2023/10/11 18:23:49 +@Author : LiWei +@Version : 1.0 +""" + +import os +import threading +import time +import traceback +from concurrent.futures import ThreadPoolExecutor + +import cv2 +import numpy as np +import torch + +import constants +from config import config +from face_parsing.model import BiSeNet +from face_parsing.swap import cal_mask_single_img, image_to_parsing_ori +from high_perf.buffer.infer_buffer import Infer_Cache +from high_perf.buffer.video_buff import Video_Cache +from log import logger +from models.wav2lipv2 import Wav2Lip +from utils.util import ( + add_alpha, + datagen, + get_actor_id, + get_model_local_path, + load_face_box, + load_seg_model, + load_w2l_model, + morphing, + play_in_loop_v2, + save_raw_video, + save_raw_wav, + get_trans_idxes, + load_config +) + +USING_JIT = False +LOCAL_MODE = False +index = 1 + + +def save_(current_speechid, result_frames, audio_segments): + raw_vid = save_raw_video(current_speechid, result_frames) + raw_wav = save_raw_wav(current_speechid, audio_segments) + command = "ffmpeg -y -i {} -i {} -c:v copy -c:a aac -strict experimental -shortest {}".format( + raw_wav, raw_vid, f"temp/{current_speechid}_result_final.mp4" + ) + import platform + import subprocess + + subprocess.call(command, shell=platform.system() != "Windows") + os.remove(raw_wav) + + +class Wav2lip_Processor: + def __init__( + self, + task_queue, + result_queue, + stop_event, + resolution, + channel, + device: str = "cuda:0", + ) -> None: + self.task_queue = task_queue + self.result_queue = result_queue + self.stop_event = stop_event + self.write_lock = threading.Lock() + self.device = device + self.resolution = resolution + self.channel = channel + + pool_size = config.read_frame_pool_size + self.using_pool = True + if pool_size <= 0: + self.using_pool = False + + self.img_reading_pool = None + + if self.using_pool: + self.img_reading_pool = ThreadPoolExecutor(max_workers=pool_size) + + self.face_det_results = {} + self.w2l_models = {} + self.seg_net = None # 分割模型 + + mel_block_size = (config.wav2lip_batch_size, 80, 16) + audio_block_size = (np.prod(mel_block_size).tolist(), 1) + # 缓存区最大,缓存两秒的视频,超出的会被覆盖掉 + self.video_cache = Video_Cache( + "video_cache", + "audio_audio_cache", + "speech_id_cache", + self.resolution, + channel=self.channel, + audio_block_size=audio_block_size[0], + create=True, + ) + + self.data_cache = Infer_Cache( + "infer_cache", + create=True, + mel_block_shape=mel_block_size, + audio_block_shape=audio_block_size, + ) + # TODO: 这个verision 不生效 + self.version = "v2" + self.current_speechid = "" + # torch.Size([5, 288, 288, 6]) torch.Size([5, 80, 16, 1]) + # torch.Size([5, 288, 288, 6]) torch.Size([5, 80, 16, 1]) + self.model_mel_batch = torch.rand((5, 1, 80, 16), dtype=torch.float).to( + self.device + ) + self.model_img_batch = torch.rand((5, 6, 288, 288), dtype=torch.float).to( + self.device + ) + self.infer_silent_idx = 1 + self.post_morphint_idx = 1 + self.first_start = True # 判断是否为开播瞬间的静默 + self.file_index = 1 + + def prepare(self, model_speaker_id: str, need_update: bool = True): + """加载好模型文件""" + if not need_update: + return + self.model_url, self.frame_config, self.padding,self.face_classes,self.trans_method, self.infer_silent_num, self.morphing_num, pkl_version = load_config(model_speaker_id) + models_url_list = [self.model_url] + actor_info = [model_speaker_id] + for actor_url in actor_info: + # 这里不再下载资源,相关工作移到主进程 + 
actor_id = get_actor_id(actor_url) + self.face_det_results[actor_id] = load_face_box(actor_id, pkl_version) + self.w2l_process_state = "READY" + self.__init_seg_model(f"{config.model_folder}/79999_iter.pth") + # 预先加载所有model 模型 + # TODO: 采用旧的加载方式看下模型model 是否有错 + for model_url in models_url_list: + w2l_name = get_model_local_path(model_url) + self.__init_w2l_model(w2l_name) + logger.debug("Prepare w2l data.") + + def __init_w2l_model(self, w2l_name): + """ """ + # TODO 这个最好是可以放在启动直播间之前----本模型运行时可能会切换,目前咱们不同的模特可能模型不同 + if w2l_name in self.w2l_models: + return + + w2l_path = f"{config.model_folder}/{w2l_name}" + if USING_JIT: + net = self.__jit_script(f"{config.model_folder}/w2l.pt", model=Wav2Lip()) + else: + net = Wav2Lip() + logger.debug(f"加载模型 {w2l_name}") + self.w2l_models[w2l_name] = load_w2l_model(w2l_path, net, self.device) + + def __jit_script(self, jit_module_path: str, model): + """ + convert pytorch module to a runnable scripts. + """ + if not os.path.exists(jit_module_path): + scripted_module = torch.jit.script(model) + torch.jit.save(scripted_module, jit_module_path) + net = torch.jit.load(jit_module_path) + return net + + def __init_seg_model(self, seg_path: str): + """ """ + # 本模型可以视作永不改变 + n_classes = 19 + net = BiSeNet(n_classes=n_classes) + if USING_JIT: + net = self.__jit_script(f"{config.model_folder}/swap.pt", model=net) + self.seg_net = load_seg_model(seg_path, net, self.device) + + def __using_w2lmodel(self, model_url): + """ """ + # 取出对应的模型 + model_path = get_model_local_path(model_url) + if model_path in self.w2l_models: + return self.w2l_models[model_path] + else: + self.__init_w2l_model(model_path) + return self.w2l_models[model_path] + + def infer( + self, + is_silent, + inference_id, + pkl_name, + startfrom=0, + last_direction=1, + duration=0, + end_with_silent=False, + first_speak=False, + last_speak=False, + before_speak=False, + model_url=None, + debug: bool = False, + test_data: tuple = (), + video_idx: list = [], + ): + """ """ + start = time.perf_counter() + # TODO: 这是啥逻辑 + is_silent_ = 0 if duration < 5 else 1 + if not model_url: + model_url = self.model_url + model = self.__using_w2lmodel(model_url) + # 控制加载数据人的个数 + self.current_speechid = inference_id + if debug: + audio_segments, mel_chunks = test_data + else: + audio_segments, mel_chunks = self.data_cache.get_data() + audio_frame_length = len(mel_chunks) + + if video_idx: + mel_chunks, start_idx_list, current_direction = mel_chunks, video_idx, 1 + else: + startfrom = startfrom if startfrom>= self.frame_config[0][0] else self.frame_config[0][0] + start_idx_list, current_direction = play_in_loop_v2( + self.frame_config, + startfrom, + audio_frame_length, + last_direction, + is_silent, + is_silent_, + first_speak, + last_speak, + ) + logger.info( + f"start_idx_list :{start_idx_list}, speechid: {self.current_speechid}" + ) + gan = datagen( + self.img_reading_pool, + mel_chunks, + start_idx_list, + pkl_name, + self.face_det_results[pkl_name], + self.padding, + is_silent, + ) + self.file_index += 5 + # 3.4.0 增加静默推理过渡融合, 说话后的n帧向静默过渡 + need_pre_morphing=False + if self.trans_method == constants.TRANS_METHOD.mophing_infer: + if self.infer_silent_num or self.morphing_num or config.PRE_MORPHING_NUM: + logger.info(f"self.first_start{self.first_start}, {before_speak}, {self.current_speechid}") + if ( + is_silent == 0 + or ( + before_speak) + or ( + not self.first_start + and is_silent == 1 + and ((self.infer_silent_idx < self.infer_silent_num) or (self.post_morphint_idx < self.morphing_num)) + ) + ): + # 
开始说话状态时候,要把上次计数清理 + need_infer_silent = ( + True + if is_silent == 1 + and self.infer_silent_idx < self.infer_silent_num and not before_speak and "&" not in self.current_speechid + else False + ) + need_pre_morphing = True if before_speak else False + need_post_morphing = ( + True + if is_silent == 1 + and self.morphing_num + and self.post_morphint_idx < self.morphing_num + and not need_infer_silent + and not need_pre_morphing + else False + ) + + if is_silent == 0 or "&" in self.current_speechid: + self.infer_silent_idx = 1 + self.post_morphint_idx = 1 + + self.first_start = False + logger.debug( + f"{start_idx_list} is_silent:{is_silent}, infer_silent_idx: {self.infer_silent_idx} morphint_idx:{self.post_morphint_idx} need_post_morphing: {need_post_morphing}, need_infer_silent:{need_infer_silent}, speech:{self.current_speechid},need_pre_morphing:{need_pre_morphing}" + ) + if is_silent == 1 and need_post_morphing and self.post_morphint_idx < self.morphing_num: + self.post_morphint_idx += 5 + result_frames = self.prediction( + gan, + model, + pkl_name, + need_post_morphing=need_post_morphing, + need_infer_silent=need_infer_silent, + is_silent=is_silent, + need_pre_morphing=need_pre_morphing, + ) + if is_silent == 1 and need_infer_silent and self.infer_silent_idx < self.infer_silent_num: + self.infer_silent_idx += 5 + + # 第一次静默时候,直接进入这个逻辑 + else: + self.no_prediction(gan, model, pkl_name) + else: + if is_silent==0: + self.prediction( + gan, + model, + pkl_name, + is_silent=is_silent, + need_pre_morphing=need_pre_morphing, + ) + else: + self.no_prediction(gan, model, pkl_name) + elif self.trans_method == constants.TRANS_METHOD.all_infer: + result_frames = self.prediction(gan, model, pkl_name,is_silent=is_silent) + else: + raise ValueError(f"not supported {self.trans_method}") + + if LOCAL_MODE: + saveThread = threading.Thread( + target=save_, + args=(self.current_speechid, result_frames, audio_segments), + ) + saveThread.start() + self.video_cache.put_audio(audio_segments, self.current_speechid) + + logger.debug( + f"视频合成时间{time.perf_counter() - start},inference_id:{self.current_speechid}" + ) + return [ + [0, 1, 2, 3, 4], + start_idx_list[-1], + current_direction, + ] # frames_return_list: 视频帧数据 res_index: 控制下一帧的开始位置 direction: 播放的顺序 正 反 + + @torch.no_grad() + def batch_to_tensor(self, img_batch, mel_batch, model, padding, frames, coords): + padding = 10 + logger.debug(f"非静默,推理: {self.current_speechid} 准备推理数据 {id(model)}") + img_batch_tensor = torch.as_tensor(img_batch, dtype=torch.float32).to( + self.device, non_blocking=True + ) + mel_batch_tensor = torch.as_tensor(mel_batch, dtype=torch.float32).to( + self.device, non_blocking=True + ) + + img_batch_tensor = img_batch_tensor.permute(0, 3, 1, 2) + mel_batch_tensor = mel_batch_tensor.permute(0, 3, 1, 2) + logger.debug( + f"非静默,推理: {self.current_speechid} 即将嘴型生成 {img_batch_tensor.shape} {mel_batch_tensor.shape}" + ) # torch.Size([5, 6, 288, 288]) torch.Size([5, 1, 80, 16]) + + pred_batch = model(mel_batch_tensor, img_batch_tensor) * 255.0 + + logger.debug(f"非静默,推理: {self.current_speechid} 嘴型生成完成") + pred_clone_batch = pred_batch.clone() + pred_batch_cpu = pred_batch.to(torch.uint8).to("cpu", non_blocking=True) + del pred_batch + del mel_batch_tensor + logger.debug(f"非静默,推理: {self.current_speechid} 生成面部准备好了") + large_faces = [ + frame.copy()[:, :, :3][ + y1 - padding : y2 + padding, x1 - padding : x2 + padding + ] + for frame, box in zip(frames, coords) + for y1, y2, x1, x2 in [box] + ] + logger.debug(f"非静默,推理: {self.current_speechid} 
即将对生成的面部进行分割") + seg_out = image_to_parsing_ori( + pred_clone_batch, self.seg_net + ) # 假设这个函数在 GPU 上执行并输出 GPU tensor + del img_batch_tensor + del pred_clone_batch + + torch.cuda.synchronize() # 等待异步结束 + seg_out_cpu = seg_out.to( + "cpu", non_blocking=True + ) # 发起 seg_out 的异步数据传输 + del seg_out + logger.debug(f"非静默,推理: {self.current_speechid} 对生成的面部进行分割结束") + pred_batch_cpu = pred_batch_cpu.numpy().transpose( + 0, 2, 3, 1 + ) # 按需计算,用到了才计算,这样能不能使得GPU往cpu做数据copy的时间被隐藏? + for predict, frame, box in zip(pred_batch_cpu, frames, coords): + y1, y2, x1, x2 = box + width, height = x2 - x1, y2 - y1 + frame[:, :, :3][y1:y2, x1:x2] = cv2.resize( + predict, (width, height) + ) # 无需再转为unit8,gpu上直接转为unit8,这样传输规模小一些 + torch.cuda.synchronize() # 等待异步结束 + return large_faces, seg_out_cpu, frames + + def no_prediction(self, gan, model, pkl_name): + logger.debug( + f"无需推理: {self.current_speechid} {self.model_img_batch.shape} {self.model_mel_batch.shape}" + ) + with torch.no_grad(): + pred_ = model(self.model_mel_batch, self.model_img_batch) + seg_out = image_to_parsing_ori( + pred_, self.seg_net + ) # 不确定是不是需要也warmup第二个网络 + for _, _, frames, full_masks, _, end_idx in gan: + speech_ids = [ + f"{self.current_speechid}_{idx}" for idx in range(len(end_idx)) + ] + offset_list = self.video_cache.occupy_video_pos(len(end_idx)) + file_idxes, _, _ = get_trans_idxes(False, False, 0,0, self.file_index) + logger.info(f"self.file_index:{self.file_index}, file_idxes:{file_idxes}") + param_list = [ + ( + speech_id, + frame, + body_mask, + write_pos, + pkl_name, + self.video_cache, + file_idx + ) + for speech_id, frame, body_mask, write_pos, file_idx in zip( + speech_ids, + frames, + full_masks, + offset_list, + file_idxes + ) + ] + if self.using_pool: + futures = [ + self.img_reading_pool.submit(self.silent_paste_back, *param) + for param in param_list + ] + _ = [future.result() for future in futures] + self.video_cache._inject_real_pos(len(end_idx)) + # if config.debug_mode: + # for param in param_list: + # self.silent_paste_back(*param) + del pred_ + del seg_out + torch.cuda.empty_cache() + + + def silent_paste_back(self, speech_id, frame, body_mask,write_pos,pkl_name, video_cache, file_index): + global index + if frame.shape[-1] == 4: + frame[:, :, :3] = frame[:, :, :3][:, :, ::-1] + elif config.output_alpha: + frame = add_alpha(frame, body_mask, config.alpha) + logger.debug(f"self.file_index:{self.file_index}, file_index:{file_index}") + if config.debug_mode: + logger.info(f"不用推理: {file_index} {frame.shape}") + if not cv2.imwrite(f"temp/{pkl_name}/new_img{file_index:05d}.jpg", frame): + logger.error(f"save {file_index} err") + index += 1 + logger.info(f"silent frame shape:{frame.shape}") + video_cache._put_raw_frame(frame, write_pos, speech_id) + return frame + + + def prediction( + self, + gan, + model, + pkl_name: str, + need_post_morphing: bool = False, + need_infer_silent: bool = False, + is_silent: int = 1, + need_pre_morphing: bool = False, + ): + """ """ + logger.debug(f"非静默,推理: {self.current_speechid}") + for img_batch, mel_batch, frames, full_masks, coords, end_idx in gan: + large_faces, seg_out_cpu, frames = self.batch_to_tensor( + img_batch, mel_batch, model, self.padding, frames, coords + ) + offset_list = self.video_cache.occupy_video_pos(len(end_idx)) + speech_ids = [ + f"{self.current_speechid}_{idx}" for idx in range(len(end_idx)) + ] + file_idxes, post_morphint_idxes, infer_silent_idxes = get_trans_idxes(need_post_morphing, need_infer_silent,self.post_morphint_idx,self.infer_silent_idx, self.file_index) 
+ logger.info(f"self.file_index:{self.file_index},infer_silent_idxes:{infer_silent_idxes},post_morphint_idxes:{post_morphint_idxes}, file_idxes:{file_idxes}") + param_list = [ + ( + speech_id, + padded_image, + frame, + body_mask, + seg_mask, + boxes, + self.padding, + self.video_cache, + write_pos, + pkl_name, + need_infer_silent, + need_post_morphing, + is_silent, + need_pre_morphing, + pre_morphing_idx, + infer_silent_idx, + post_morphint_idx, + file_idx + + ) + for speech_id, padded_image, frame, body_mask, seg_mask, boxes, write_pos, pre_morphing_idx,infer_silent_idx, post_morphint_idx,file_idx in zip( + speech_ids, + large_faces, + frames, + full_masks, + seg_out_cpu, + coords, + offset_list, + list(range(config.PRE_MORPHING_NUM)), + infer_silent_idxes, + post_morphint_idxes, + file_idxes + ) + ] + result_frames = [] + if self.using_pool:# and not config.debug_mode: + futures = [ + self.img_reading_pool.submit(self.paste_back, *param) + for param in param_list + ] + _ = [future.result() for future in futures] + self.video_cache._inject_real_pos(len(end_idx)) + # if config.debug_mode: + # for param in param_list: + # frame = self.paste_back(*param) + # if LOCAL_MODE: + # result_frames.append(frame) + return result_frames + + def paste_back( + self, + current_speechid, + large_face, + frame, + body_mask, + mask, + boxes, + padding, + video_cache, + cache_start_offset, + pkl_name, + need_infer_silent: bool = False, + need_post_morphing: bool = False, + is_silent: bool = False, + need_pre_morphing: bool = False, + pre_morphing_idx: int = 0, + infer_silent_idx: int = 0, + post_morphint_idx: int = 0, + file_index: int = 1, + ): + """ + 根据模型预测结果和原始图片进行融合 + Args: + """ + global index + padding = 10 + y1, y2, x1, x2 = boxes + width, height = x2 - x1, y2 - y1 + mask = cal_mask_single_img( + mask, use_old_mode=True, face_classes=self.face_classes + ) + mask = np.repeat(mask[..., None], 3, axis=-1).astype("uint8") + + mask_temp = np.zeros_like(large_face) + mask_out = cv2.resize(mask.astype(np.float) * 255.0, (width, height)).astype( + np.uint8 + ) + mask_temp[padding : height + padding, padding : width + padding] = mask_out + kernel = np.ones((9, 9), np.uint8) + mask_temp = cv2.erode(mask_temp, kernel, iterations=1) # 二值的 + # gaosi_kernel = int(0.1 * large_face.shape[0] // 2 * 2) + 1 + # mask_temp = cv2.GaussianBlur( + # mask_temp, (gaosi_kernel, gaosi_kernel), 0, 0, cv2.BORDER_DEFAULT + # ) + mask_temp = cv2.GaussianBlur(mask_temp, (15, 15), 0, 0, cv2.BORDER_DEFAULT) + mask_temp = cv2.GaussianBlur(mask_temp, (5, 5), 0, 0, cv2.BORDER_DEFAULT) + + f_background = large_face.copy() + + frame[:, :, :3][y1 - padding : y2 + padding, x1 - padding : x2 + padding] = ( + f_background * (1 - mask_temp / 255.0) + + frame[:, :, :3][y1 - padding : y2 + padding, x1 - padding : x2 + padding] + * (mask_temp / 255.0) + )#.astype("uint8") + if self.trans_method == constants.TRANS_METHOD.mophing_infer: + if is_silent == 1: + if need_pre_morphing: + frame = morphing( + large_face, + frame, + boxes, + mp_ratio=1 - ((pre_morphing_idx + 1) / config.PRE_MORPHING_NUM), + file_index=file_index, + ) + logger.debug( + f"file_index:{file_index},pre morphing_idx {pre_morphing_idx}, speech_id:{current_speechid}" + ) + logger.debug(f"pre morphing_idx {pre_morphing_idx}, speech_id:{current_speechid}") + #TODO: @txueduo 处理前过渡问题 + elif need_post_morphing and post_morphint_idx: + mp_ratio = (post_morphint_idx) / self.morphing_num + frame = morphing( + large_face, + frame, + boxes, + mp_ratio=mp_ratio, + file_index=file_index + ) + 
logger.debug(f"post_morphint_idx:{post_morphint_idx}, mp_ratio:{mp_ratio}, file_index:{file_index}, speech_id:{current_speechid}") + if frame.shape[-1]==4:# and not config.output_alpha: + frame[:,:,:3] = frame[:,:,:3][:,:,::-1] + if config.output_alpha and frame.shape[-1]!=4: + frame = add_alpha(frame, body_mask, config.alpha) + if config.debug_mode: + logger.info(f"推理:{file_index}") + if not cv2.imwrite(f"temp/{pkl_name}/new_img{file_index:05d}.jpg", frame): + logger.error(f"save {file_index} err") + video_cache._put_raw_frame(frame, cache_start_offset, current_speechid) + index += 1 + return frame + + def destroy(self): + """ """ + self.data_cache.destroy() + self.video_cache.destroy() + self.data_cache = None + self.video_cache = None + if self.using_pool and self.img_reading_pool is not None: + self.img_reading_pool.shutdown() + self.img_reading_pool = None + del self.model_img_batch + del self.model_mel_batch + + +def process_wav2lip_predict(task_queue, stop_event, output_queue, resolution, channel): + # 实例化推理类 + w2l_processor = Wav2lip_Processor( + task_queue, output_queue, stop_event, resolution, channel + ) + need_update = True + while True: + if stop_event.is_set(): + print("----------------------stop..") + w2l_processor.destroy() # 这一行代码可能有bug,没有同步的控制代码,可能在读的地方会有异常 + break + try: + params = task_queue.get() + # TODO: 这个操作应该放在更前面, 确认是否更新的判定条件 + start = time.perf_counter() + w2l_processor.prepare(params[2], need_update) + need_update = False + result = w2l_processor.infer(*params) + logger.debug( + f"推理结束 :{time.perf_counter() - start},inference_id:{params[1]}" + ) + output_queue.put(result) + logger.debug( + f"结果通知到主进程{time.perf_counter() - start},inference_id:{params[1]}" + ) + except Exception: + logger.error(f"process_wav2lip_predict :{traceback.format_exc()}")