#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ @File : self.py @Time : 2023/10/11 18:23:49 @Author : LiWei @Version : 1.0 """ import os import threading import time import traceback from concurrent.futures import ThreadPoolExecutor import cv2 import numpy as np import torch import constants from config import config from face_parsing.model import BiSeNet from face_parsing.swap import cal_mask_single_img, image_to_parsing_ori from high_perf.buffer.infer_buffer import Infer_Cache from high_perf.buffer.video_buff import Video_Cache from log import logger from models.wav2lipv2 import Wav2Lip from utils.util import ( add_alpha, datagen, get_actor_id, get_model_local_path, load_face_box, load_seg_model, load_w2l_model, morphing, play_in_loop_v2, save_raw_video, save_raw_wav, get_trans_idxes, load_config ) USING_JIT = False LOCAL_MODE = False index = 1 def save_(current_speechid, result_frames, audio_segments): raw_vid = save_raw_video(current_speechid, result_frames) raw_wav = save_raw_wav(current_speechid, audio_segments) command = "ffmpeg -y -i {} -i {} -c:v copy -c:a aac -strict experimental -shortest {}".format( raw_wav, raw_vid, f"temp/{current_speechid}_result_final.mp4" ) import platform import subprocess subprocess.call(command, shell=platform.system() != "Windows") os.remove(raw_wav) class Wav2lip_Processor: def __init__( self, task_queue, result_queue, stop_event, resolution, channel, device: str = "cuda:0", ) -> None: self.task_queue = task_queue self.result_queue = result_queue self.stop_event = stop_event self.write_lock = threading.Lock() self.device = device self.resolution = resolution self.channel = channel pool_size = config.read_frame_pool_size self.using_pool = True if pool_size <= 0: self.using_pool = False self.img_reading_pool = None if self.using_pool: self.img_reading_pool = ThreadPoolExecutor(max_workers=pool_size) self.face_det_results = {} self.w2l_models = {} self.seg_net = None # 分割模型 mel_block_size = (config.wav2lip_batch_size, 80, 16) audio_block_size = (np.prod(mel_block_size).tolist(), 1) # 缓存区最大,缓存两秒的视频,超出的会被覆盖掉 self.video_cache = Video_Cache( "video_cache", "audio_audio_cache", "speech_id_cache", self.resolution, channel=self.channel, audio_block_size=audio_block_size[0], create=True, ) self.data_cache = Infer_Cache( "infer_cache", create=True, mel_block_shape=mel_block_size, audio_block_shape=audio_block_size, ) # TODO: 这个verision 不生效 self.version = "v2" self.current_speechid = "" # torch.Size([5, 288, 288, 6]) torch.Size([5, 80, 16, 1]) # torch.Size([5, 288, 288, 6]) torch.Size([5, 80, 16, 1]) self.model_mel_batch = torch.rand((5, 1, 80, 16), dtype=torch.float).to( self.device ) self.model_img_batch = torch.rand((5, 6, 288, 288), dtype=torch.float).to( self.device ) self.infer_silent_idx = 1 self.post_morphint_idx = 1 self.first_start = True # 判断是否为开播瞬间的静默 self.file_index = 1 def prepare(self, model_speaker_id: str, need_update: bool = True): """加载好模型文件""" if not need_update: return self.model_url, self.frame_config, self.padding,self.face_classes,self.trans_method, self.infer_silent_num, self.morphing_num, pkl_version = load_config(model_speaker_id) models_url_list = [self.model_url] actor_info = [model_speaker_id] for actor_url in actor_info: # 这里不再下载资源,相关工作移到主进程 actor_id = get_actor_id(actor_url) self.face_det_results[actor_id] = load_face_box(actor_id, pkl_version) self.w2l_process_state = "READY" self.__init_seg_model(f"{config.model_folder}/79999_iter.pth") # 预先加载所有model 模型 # TODO: 采用旧的加载方式看下模型model 是否有错 for model_url in models_url_list: w2l_name = get_model_local_path(model_url) self.__init_w2l_model(w2l_name) logger.debug("Prepare w2l data.") def __init_w2l_model(self, w2l_name): """ """ # TODO 这个最好是可以放在启动直播间之前----本模型运行时可能会切换,目前咱们不同的模特可能模型不同 if w2l_name in self.w2l_models: return w2l_path = f"{config.model_folder}/{w2l_name}" if USING_JIT: net = self.__jit_script(f"{config.model_folder}/w2l.pt", model=Wav2Lip()) else: net = Wav2Lip() logger.debug(f"加载模型 {w2l_name}") self.w2l_models[w2l_name] = load_w2l_model(w2l_path, net, self.device) def __jit_script(self, jit_module_path: str, model): """ convert pytorch module to a runnable scripts. """ if not os.path.exists(jit_module_path): scripted_module = torch.jit.script(model) torch.jit.save(scripted_module, jit_module_path) net = torch.jit.load(jit_module_path) return net def __init_seg_model(self, seg_path: str): """ """ # 本模型可以视作永不改变 n_classes = 19 net = BiSeNet(n_classes=n_classes) if USING_JIT: net = self.__jit_script(f"{config.model_folder}/swap.pt", model=net) self.seg_net = load_seg_model(seg_path, net, self.device) def __using_w2lmodel(self, model_url): """ """ # 取出对应的模型 model_path = get_model_local_path(model_url) if model_path in self.w2l_models: return self.w2l_models[model_path] else: self.__init_w2l_model(model_path) return self.w2l_models[model_path] def infer( self, is_silent, inference_id, pkl_name, startfrom=0, last_direction=1, duration=0, end_with_silent=False, first_speak=False, last_speak=False, before_speak=False, model_url=None, debug: bool = False, test_data: tuple = (), video_idx: list = [], ): """ """ start = time.perf_counter() # TODO: 这是啥逻辑 is_silent_ = 0 if duration < 5 else 1 if not model_url: model_url = self.model_url model = self.__using_w2lmodel(model_url) # 控制加载数据人的个数 self.current_speechid = inference_id if debug: audio_segments, mel_chunks = test_data else: audio_segments, mel_chunks = self.data_cache.get_data() audio_frame_length = len(mel_chunks) if video_idx: mel_chunks, start_idx_list, current_direction = mel_chunks, video_idx, 1 else: startfrom = startfrom if startfrom>= self.frame_config[0][0] else self.frame_config[0][0] start_idx_list, current_direction = play_in_loop_v2( self.frame_config, startfrom, audio_frame_length, last_direction, is_silent, is_silent_, first_speak, last_speak, ) logger.info( f"start_idx_list :{start_idx_list}, speechid: {self.current_speechid}" ) gan = datagen( self.img_reading_pool, mel_chunks, start_idx_list, pkl_name, self.face_det_results[pkl_name], self.padding, is_silent, ) self.file_index += 5 # 3.4.0 增加静默推理过渡融合, 说话后的n帧向静默过渡 need_pre_morphing=False if self.trans_method == constants.TRANS_METHOD.mophing_infer: if self.infer_silent_num or self.morphing_num or config.PRE_MORPHING_NUM: logger.info(f"self.first_start{self.first_start}, {before_speak}, {self.current_speechid}") if ( is_silent == 0 or ( before_speak) or ( not self.first_start and is_silent == 1 and ((self.infer_silent_idx < self.infer_silent_num) or (self.post_morphint_idx < self.morphing_num)) ) ): # 开始说话状态时候,要把上次计数清理 need_infer_silent = ( True if is_silent == 1 and self.infer_silent_idx < self.infer_silent_num and not before_speak and "&" not in self.current_speechid else False ) need_pre_morphing = True if before_speak else False need_post_morphing = ( True if is_silent == 1 and self.morphing_num and self.post_morphint_idx < self.morphing_num and not need_infer_silent and not need_pre_morphing else False ) if is_silent == 0 or "&" in self.current_speechid: self.infer_silent_idx = 1 self.post_morphint_idx = 1 self.first_start = False logger.debug( f"{start_idx_list} is_silent:{is_silent}, infer_silent_idx: {self.infer_silent_idx} morphint_idx:{self.post_morphint_idx} need_post_morphing: {need_post_morphing}, need_infer_silent:{need_infer_silent}, speech:{self.current_speechid},need_pre_morphing:{need_pre_morphing}" ) if is_silent == 1 and need_post_morphing and self.post_morphint_idx < self.morphing_num: self.post_morphint_idx += 5 result_frames = self.prediction( gan, model, pkl_name, need_post_morphing=need_post_morphing, need_infer_silent=need_infer_silent, is_silent=is_silent, need_pre_morphing=need_pre_morphing, ) if is_silent == 1 and need_infer_silent and self.infer_silent_idx < self.infer_silent_num: self.infer_silent_idx += 5 # 第一次静默时候,直接进入这个逻辑 else: self.no_prediction(gan, model, pkl_name) else: if is_silent==0: self.prediction( gan, model, pkl_name, is_silent=is_silent, need_pre_morphing=need_pre_morphing, ) else: self.no_prediction(gan, model, pkl_name) elif self.trans_method == constants.TRANS_METHOD.all_infer: result_frames = self.prediction(gan, model, pkl_name,is_silent=is_silent) else: raise ValueError(f"not supported {self.trans_method}") if LOCAL_MODE: saveThread = threading.Thread( target=save_, args=(self.current_speechid, result_frames, audio_segments), ) saveThread.start() self.video_cache.put_audio(audio_segments, self.current_speechid) logger.debug( f"视频合成时间{time.perf_counter() - start},inference_id:{self.current_speechid}" ) return [ [0, 1, 2, 3, 4], start_idx_list[-1], current_direction, ] # frames_return_list: 视频帧数据 res_index: 控制下一帧的开始位置 direction: 播放的顺序 正 反 @torch.no_grad() def batch_to_tensor(self, img_batch, mel_batch, model, padding, frames, coords): padding = 10 logger.debug(f"非静默,推理: {self.current_speechid} 准备推理数据 {id(model)}") img_batch_tensor = torch.as_tensor(img_batch, dtype=torch.float32).to( self.device, non_blocking=True ) mel_batch_tensor = torch.as_tensor(mel_batch, dtype=torch.float32).to( self.device, non_blocking=True ) img_batch_tensor = img_batch_tensor.permute(0, 3, 1, 2) mel_batch_tensor = mel_batch_tensor.permute(0, 3, 1, 2) logger.debug( f"非静默,推理: {self.current_speechid} 即将嘴型生成 {img_batch_tensor.shape} {mel_batch_tensor.shape}" ) # torch.Size([5, 6, 288, 288]) torch.Size([5, 1, 80, 16]) pred_batch = model(mel_batch_tensor, img_batch_tensor) * 255.0 logger.debug(f"非静默,推理: {self.current_speechid} 嘴型生成完成") pred_clone_batch = pred_batch.clone() pred_batch_cpu = pred_batch.to(torch.uint8).to("cpu", non_blocking=True) del pred_batch del mel_batch_tensor logger.debug(f"非静默,推理: {self.current_speechid} 生成面部准备好了") large_faces = [ frame.copy()[:, :, :3][ y1 - padding : y2 + padding, x1 - padding : x2 + padding ] for frame, box in zip(frames, coords) for y1, y2, x1, x2 in [box] ] logger.debug(f"非静默,推理: {self.current_speechid} 即将对生成的面部进行分割") seg_out = image_to_parsing_ori( pred_clone_batch, self.seg_net ) # 假设这个函数在 GPU 上执行并输出 GPU tensor del img_batch_tensor del pred_clone_batch torch.cuda.synchronize() # 等待异步结束 seg_out_cpu = seg_out.to( "cpu", non_blocking=True ) # 发起 seg_out 的异步数据传输 del seg_out logger.debug(f"非静默,推理: {self.current_speechid} 对生成的面部进行分割结束") pred_batch_cpu = pred_batch_cpu.numpy().transpose( 0, 2, 3, 1 ) # 按需计算,用到了才计算,这样能不能使得GPU往cpu做数据copy的时间被隐藏? for predict, frame, box in zip(pred_batch_cpu, frames, coords): y1, y2, x1, x2 = box width, height = x2 - x1, y2 - y1 frame[:, :, :3][y1:y2, x1:x2] = cv2.resize( predict, (width, height) ) # 无需再转为unit8,gpu上直接转为unit8,这样传输规模小一些 torch.cuda.synchronize() # 等待异步结束 return large_faces, seg_out_cpu, frames def no_prediction(self, gan, model, pkl_name): logger.debug( f"无需推理: {self.current_speechid} {self.model_img_batch.shape} {self.model_mel_batch.shape}" ) with torch.no_grad(): pred_ = model(self.model_mel_batch, self.model_img_batch) seg_out = image_to_parsing_ori( pred_, self.seg_net ) # 不确定是不是需要也warmup第二个网络 for _, _, frames, full_masks, _, end_idx in gan: speech_ids = [ f"{self.current_speechid}_{idx}" for idx in range(len(end_idx)) ] offset_list = self.video_cache.occupy_video_pos(len(end_idx)) file_idxes, _, _ = get_trans_idxes(False, False, 0,0, self.file_index) logger.info(f"self.file_index:{self.file_index}, file_idxes:{file_idxes}") param_list = [ ( speech_id, frame, body_mask, write_pos, pkl_name, self.video_cache, file_idx ) for speech_id, frame, body_mask, write_pos, file_idx in zip( speech_ids, frames, full_masks, offset_list, file_idxes ) ] if self.using_pool: futures = [ self.img_reading_pool.submit(self.silent_paste_back, *param) for param in param_list ] _ = [future.result() for future in futures] self.video_cache._inject_real_pos(len(end_idx)) # if config.debug_mode: # for param in param_list: # self.silent_paste_back(*param) del pred_ del seg_out torch.cuda.empty_cache() def silent_paste_back(self, speech_id, frame, body_mask,write_pos,pkl_name, video_cache, file_index): global index if frame.shape[-1] == 4: frame[:, :, :3] = frame[:, :, :3][:, :, ::-1] elif config.output_alpha: frame = add_alpha(frame, body_mask, config.alpha) logger.debug(f"self.file_index:{self.file_index}, file_index:{file_index}") if config.debug_mode: logger.info(f"不用推理: {file_index} {frame.shape}") if not cv2.imwrite(f"temp/{pkl_name}/new_img{file_index:05d}.jpg", frame): logger.error(f"save {file_index} err") index += 1 logger.info(f"silent frame shape:{frame.shape}") video_cache._put_raw_frame(frame, write_pos, speech_id) return frame def prediction( self, gan, model, pkl_name: str, need_post_morphing: bool = False, need_infer_silent: bool = False, is_silent: int = 1, need_pre_morphing: bool = False, ): """ """ logger.debug(f"非静默,推理: {self.current_speechid}") for img_batch, mel_batch, frames, full_masks, coords, end_idx in gan: large_faces, seg_out_cpu, frames = self.batch_to_tensor( img_batch, mel_batch, model, self.padding, frames, coords ) offset_list = self.video_cache.occupy_video_pos(len(end_idx)) speech_ids = [ f"{self.current_speechid}_{idx}" for idx in range(len(end_idx)) ] file_idxes, post_morphint_idxes, infer_silent_idxes = get_trans_idxes(need_post_morphing, need_infer_silent,self.post_morphint_idx,self.infer_silent_idx, self.file_index) logger.info(f"self.file_index:{self.file_index},infer_silent_idxes:{infer_silent_idxes},post_morphint_idxes:{post_morphint_idxes}, file_idxes:{file_idxes}") param_list = [ ( speech_id, padded_image, frame, body_mask, seg_mask, boxes, self.padding, self.video_cache, write_pos, pkl_name, need_infer_silent, need_post_morphing, is_silent, need_pre_morphing, pre_morphing_idx, infer_silent_idx, post_morphint_idx, file_idx ) for speech_id, padded_image, frame, body_mask, seg_mask, boxes, write_pos, pre_morphing_idx,infer_silent_idx, post_morphint_idx,file_idx in zip( speech_ids, large_faces, frames, full_masks, seg_out_cpu, coords, offset_list, list(range(config.PRE_MORPHING_NUM)), infer_silent_idxes, post_morphint_idxes, file_idxes ) ] result_frames = [] if self.using_pool:# and not config.debug_mode: futures = [ self.img_reading_pool.submit(self.paste_back, *param) for param in param_list ] _ = [future.result() for future in futures] self.video_cache._inject_real_pos(len(end_idx)) # if config.debug_mode: # for param in param_list: # frame = self.paste_back(*param) # if LOCAL_MODE: # result_frames.append(frame) return result_frames def paste_back( self, current_speechid, large_face, frame, body_mask, mask, boxes, padding, video_cache, cache_start_offset, pkl_name, need_infer_silent: bool = False, need_post_morphing: bool = False, is_silent: bool = False, need_pre_morphing: bool = False, pre_morphing_idx: int = 0, infer_silent_idx: int = 0, post_morphint_idx: int = 0, file_index: int = 1, ): """ 根据模型预测结果和原始图片进行融合 Args: """ global index padding = 10 y1, y2, x1, x2 = boxes width, height = x2 - x1, y2 - y1 mask = cal_mask_single_img( mask, use_old_mode=True, face_classes=self.face_classes ) mask = np.repeat(mask[..., None], 3, axis=-1).astype("uint8") mask_temp = np.zeros_like(large_face) mask_out = cv2.resize(mask.astype(np.float) * 255.0, (width, height)).astype( np.uint8 ) mask_temp[padding : height + padding, padding : width + padding] = mask_out kernel = np.ones((9, 9), np.uint8) mask_temp = cv2.erode(mask_temp, kernel, iterations=1) # 二值的 # gaosi_kernel = int(0.1 * large_face.shape[0] // 2 * 2) + 1 # mask_temp = cv2.GaussianBlur( # mask_temp, (gaosi_kernel, gaosi_kernel), 0, 0, cv2.BORDER_DEFAULT # ) mask_temp = cv2.GaussianBlur(mask_temp, (15, 15), 0, 0, cv2.BORDER_DEFAULT) mask_temp = cv2.GaussianBlur(mask_temp, (5, 5), 0, 0, cv2.BORDER_DEFAULT) f_background = large_face.copy() frame[:, :, :3][y1 - padding : y2 + padding, x1 - padding : x2 + padding] = ( f_background * (1 - mask_temp / 255.0) + frame[:, :, :3][y1 - padding : y2 + padding, x1 - padding : x2 + padding] * (mask_temp / 255.0) )#.astype("uint8") if self.trans_method == constants.TRANS_METHOD.mophing_infer: if is_silent == 1: if need_pre_morphing: frame = morphing( large_face, frame, boxes, mp_ratio=1 - ((pre_morphing_idx + 1) / config.PRE_MORPHING_NUM), file_index=file_index, ) logger.debug( f"file_index:{file_index},pre morphing_idx {pre_morphing_idx}, speech_id:{current_speechid}" ) logger.debug(f"pre morphing_idx {pre_morphing_idx}, speech_id:{current_speechid}") #TODO: @txueduo 处理前过渡问题 elif need_post_morphing and post_morphint_idx: mp_ratio = (post_morphint_idx) / self.morphing_num frame = morphing( large_face, frame, boxes, mp_ratio=mp_ratio, file_index=file_index ) logger.debug(f"post_morphint_idx:{post_morphint_idx}, mp_ratio:{mp_ratio}, file_index:{file_index}, speech_id:{current_speechid}") if frame.shape[-1]==4:# and not config.output_alpha: frame[:,:,:3] = frame[:,:,:3][:,:,::-1] if config.output_alpha and frame.shape[-1]!=4: frame = add_alpha(frame, body_mask, config.alpha) if config.debug_mode: logger.info(f"推理:{file_index}") if not cv2.imwrite(f"temp/{pkl_name}/new_img{file_index:05d}.jpg", frame): logger.error(f"save {file_index} err") video_cache._put_raw_frame(frame, cache_start_offset, current_speechid) index += 1 return frame def destroy(self): """ """ self.data_cache.destroy() self.video_cache.destroy() self.data_cache = None self.video_cache = None if self.using_pool and self.img_reading_pool is not None: self.img_reading_pool.shutdown() self.img_reading_pool = None del self.model_img_batch del self.model_mel_batch def process_wav2lip_predict(task_queue, stop_event, output_queue, resolution, channel): # 实例化推理类 w2l_processor = Wav2lip_Processor( task_queue, output_queue, stop_event, resolution, channel ) need_update = True while True: if stop_event.is_set(): print("----------------------stop..") w2l_processor.destroy() # 这一行代码可能有bug,没有同步的控制代码,可能在读的地方会有异常 break try: params = task_queue.get() # TODO: 这个操作应该放在更前面, 确认是否更新的判定条件 start = time.perf_counter() w2l_processor.prepare(params[2], need_update) need_update = False result = w2l_processor.infer(*params) logger.debug( f"推理结束 :{time.perf_counter() - start},inference_id:{params[1]}" ) output_queue.put(result) logger.debug( f"结果通知到主进程{time.perf_counter() - start},inference_id:{params[1]}" ) except Exception: logger.error(f"process_wav2lip_predict :{traceback.format_exc()}")