From f3dcbdc876524aa65e668dc041c6fd4ae6af58b2 Mon Sep 17 00:00:00 2001
From: jocelyn
Date: Tue, 10 Jun 2025 15:04:35 +0800
Subject: [PATCH] [ADD] add logic of loop frame

---
 utils/log.py             | 119 ++++++
 utils/loop_frame_tool.py | 318 ++++++++++++++
 utils/util.py            | 889 +++++++++++++++++++++++++++++++++++++++
 3 files changed, 1326 insertions(+)
 create mode 100644 utils/log.py
 create mode 100644 utils/loop_frame_tool.py
 create mode 100644 utils/util.py

diff --git a/utils/log.py b/utils/log.py
new file mode 100644
index 0000000..a531729
--- /dev/null
+++ b/utils/log.py
@@ -0,0 +1,119 @@
+import json
+import logging
+import os
+import sys
+
+import redis
+import redis.exceptions
+from loguru import logger as logurulogger
+from redis.backoff import ExponentialBackoff
+from redis.retry import Retry
+
+from app.config import config
+
+LOG_FORMAT = (
+    "{level: <8} "
+    "{process.name} | "  # process name
+    "{thread.name} | "  # thread name
+    "{time:YYYY-MM-DD HH:mm:ss.SSS} - "
+    "{process} "
+    "{module}.{function}:{line} - "
+    "{message}"
+)
+LOG_NAME = ["uvicorn", "uvicorn.access", "uvicorn.error", "flask"]
+
+# Redis connection pool for the log sink
+redis_pool = redis.ConnectionPool(
+    host=config.LOG_REDIS_HOST,
+    port=config.LOG_REDIS_PORT,
+    db=config.LOG_REDIS_DB,
+    password=config.LOG_REDIS_AUTH,
+    max_connections=config.max_connections,
+    socket_connect_timeout=config.socket_connect_timeout,
+    socket_timeout=config.socket_timeout,
+)
+
+
+class InterceptHandler(logging.Handler):
+    """Forward stdlib logging records to loguru."""
+
+    def emit(self, record):
+        try:
+            level = logurulogger.level(record.levelname).name
+        except ValueError:
+            # Unknown level name: fall back to the numeric level
+            level = record.levelno
+
+        # Walk back to the frame that actually emitted the record
+        frame, depth = logging.currentframe(), 2
+        while frame.f_code.co_filename == logging.__file__:
+            frame = frame.f_back
+            depth += 1
+
+        logurulogger.opt(depth=depth, exception=record.exc_info).log(
+            level, record.getMessage()
+        )
+
+
+class Logging:
+    """Custom logger: optional local files, stdout, and a Redis sink."""
+
+    def __init__(self):
+        self.log_path = "logs"
+        self._connect_redis()
+        if config.IS_LOCAL:
+            os.makedirs(self.log_path, exist_ok=True)
+        self._initlogger()
+        self._reset_log_handler()
+
+    def _connect_redis(self):
+        retry = Retry(ExponentialBackoff(), 3)  # 3 retries with exponential backoff
+        self.redis_client = redis.Redis(connection_pool=redis_pool, retry=retry)
+
+    def _initlogger(self):
+        """Initialize the loguru sinks."""
+        logurulogger.remove()
+        if config.IS_LOCAL:
+            logurulogger.add(
+                os.path.join(self.log_path, "error.log.{time:YYYY-MM-DD}"),
+                format=LOG_FORMAT,
+                level=logging.ERROR,
+                rotation="00:00",
+                retention="1 week",
+                backtrace=True,
+                diagnose=True,
+                enqueue=True,
+            )
+            logurulogger.add(
+                os.path.join(self.log_path, "info.log.{time:YYYY-MM-DD}"),
+                format=LOG_FORMAT,
+                level=logging.INFO,
+                rotation="00:00",
+                retention="1 week",
+                enqueue=True,
+            )
+        logurulogger.add(
+            sys.stdout,
+            format=LOG_FORMAT,
+            level=logging.DEBUG,
+            colorize=True,
+        )
+
+        logurulogger.add(self._log_to_redis, level="INFO", format=LOG_FORMAT)
+        self.logger = logurulogger
+
+    def _log_to_redis(self, message):
+        """Push a formatted log record onto a Redis list."""
+        try:
+            self.redis_client.rpush(
+                f"nlp.logger.{config.env_version}.log",
+                json.dumps({"message": message}),
+            )
+        except redis.exceptions.ConnectionError as e:
+            # Report sink failures to stderr; logging them through the logger
+            # itself would recurse straight back into this sink.
+            sys.stderr.write(f"write {message} Redis connection error: {e}\n")
+        except redis.exceptions.TimeoutError as e:
+            sys.stderr.write(f"write {message} Redis operation timed out: {e}\n")
+        except Exception as e:
+            sys.stderr.write(f"write {message} Unexpected error: {e}\n")
+
+    def _reset_log_handler(self):
+        for name in LOG_NAME:
+            std_logger = logging.getLogger(name)
+            std_logger.handlers = [InterceptHandler()]
+
+    def getlogger(self):
+        return self.logger
+
+
+logger = Logging().getlogger()
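+
+# Usage sketch (illustrative only, not executed here): the records pushed by
+# `_log_to_redis` can be drained by a separate consumer process, e.g.
+#
+#     import json, redis
+#     r = redis.Redis(host=config.LOG_REDIS_HOST, port=config.LOG_REDIS_PORT,
+#                     db=config.LOG_REDIS_DB, password=config.LOG_REDIS_AUTH)
+#     while True:
+#         _, raw = r.blpop(f"nlp.logger.{config.env_version}.log")
+#         print(json.loads(raw)["message"])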
diff --git a/utils/loop_frame_tool.py b/utils/loop_frame_tool.py
new file mode 100644
index 0000000..9343cab
--- /dev/null
+++ b/utils/loop_frame_tool.py
@@ -0,0 +1,318 @@
+from config import config
+from utils.log import logger
+
+
+def play_in_loop_v2(
+    segments,
+    startfrom,
+    batch_num,
+    last_direction,
+    is_silent,
+    is_silent_,
+    first_speak,
+    last_speak,
+):
+    """Pick the next `batch_num` frame indexes plus the playback direction.
+
+    The rules below are checked for every frame, including the first and last
+    of a batch:
+    1. While silent, loop inside the silent segment: move forward from the
+       left boundary, backward from the right boundary; otherwise derive the
+       new position and direction from the previous ones.
+    2. Silent -> speaking: move to the nearest speaking segment
+       (pre_flag / post_flag: both true vs. only one true).
+    3. Speaking -> silent: finish playing the action segment, then enter the
+       silent segment (while speech continues, the silent segment does not loop).
+    4. At the left endpoint of the whole video only forward playback is
+       possible: loop while silent, apply rule 2 while speaking.
+    5. At the right endpoint of the whole video only backward playback is
+       possible: loop while silent, apply rule 2 while speaking.
+    6. Collect `batch_num` frames along the chosen direction and return
+       (batch_idxes, current_direction).
+
+    Args:
+        segments: loop-frame config, e.g. [[st, ed, True], ...]
+        startfrom: current position (cur_pos)
+        batch_num: number of frames to fetch, e.g. 5
+        last_direction: 0 = backward, 1 = forward
+        is_silent: 0 = speaking state, 1 = silent/action state
+        is_silent_: semantics unclear for now, may be dropped later
+        first_speak: whether this is the first moment of speech
+        last_speak: whether this is the moment speech ends
+    """
+    frames = []
+    cur_pos = startfrom
+    cur_direction = last_direction
+    is_first_speak_frame = first_speak
+    is_last_speak_frame = bool(last_speak and batch_num == 1)
+    while batch_num != 0:
+        # Locate the sub-segment that contains the current frame
+        sub_seg_idx = subseg_judge(cur_pos, segments)
+        # Decide the movement direction
+        next_direction, next_pos = get_next_direction(
+            segments,
+            cur_pos,
+            cur_direction,
+            is_silent,
+            sub_seg_idx,
+            is_first_speak_frame,
+            is_last_speak_frame,
+        )
+        # Step one frame in that direction
+        next_pos = get_next_frame(next_pos, next_direction)
+        frames.append(next_pos)
+        batch_num -= 1
+        is_first_speak_frame = bool(
+            first_speak and batch_num == config.wav2lip_batch_size
+        )
+        is_last_speak_frame = bool(last_speak and batch_num == 1)
+
+        cur_direction = next_direction
+        cur_pos = next_pos
+    return frames, next_direction
+
+
+def subseg_judge(cur_pos, segments):
+    """Return the index of the segment [st, ed] that contains cur_pos."""
+    for idx, frame_seg in enumerate(segments):
+        if frame_seg[0] <= cur_pos <= frame_seg[1]:
+            return idx
+    # Before the first configured frame, fall back to segment 0
+    if cur_pos == 0:
+        return 0
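+
+# Toy walkthrough (a sketch; the segment values are made up): with
+#     segments = [[1, 100, True], [101, 150, False], [151, 300, True]]
+# a silent state starting at frame 99 and moving forward hits the loop
+# boundary at 100 and turns around, so
+#     play_in_loop_v2(segments, 99, 3, 1, 1, 1, False, False)
+# yields frames [100, 99, 98] with direction 0.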
f"cur_pos:{cur_pos}, next_direct:{next_direct},is_first_speak_frame{is_first_speak_frame}, is_last_speak_frame:{is_last_speak_frame}" + ) + return next_direct, next_pos + +def get_next_frame(cur_pos, cur_direction): + """根据当前帧和方向,获取下一帧,这里应该保证方向上的帧是一定能取到的 + 不需要再做额外的边界判断 + """ + # 正向 + if cur_direction == 1: + return cur_pos + 1 + # 反向 + elif cur_direction == 0: + return cur_pos - 1 + +def pure_silent(segments, left, right, cur_pos, cur_direction, sub_seg_idx): + """ + loop_flag == True and is_silent==1 + whether border + whether forward + Return: + next_direction + """ + # 左边界正向,右边界反向 + if cur_pos == segments[0][0]: + return 1, cur_pos + if cur_pos == segments[-1][1]: + return 0, cur_pos + # 右边界,反向 + if cur_pos == right: + return 0, cur_pos + # 左边界,正向 + if cur_pos == left: + return 1, cur_pos + # 非边界,之前正向,则继续正向,否则反向 + if cur_pos > left and cur_direction == 1: + return 1, cur_pos + else: + return 0, cur_pos + + +def pure_action( + segments, left, right, cur_pos, cur_direction, sub_seg_idx, is_last_speak_frame +): + """ + loop_flag ==False and is_silent == 0 + 动作播完,正向到静默段 (存在跳段行为) + whether border + whether forward # 正播反播 + Args: + is_last_speak_frame: 最后说话结束时刻 + Return: next_direction + """ + if cur_pos == segments[0][0]: + return 1, cur_pos + if cur_pos == segments[-1][1]: + return 0, cur_pos + + if is_last_speak_frame: + # 动作段在末尾,向前找静默 + if sub_seg_idx == len(segments) - 1: + return 0, cur_pos + # 动作段在开始, 向后 + if sub_seg_idx == 0: + return 1, cur_pos + # 动作段在中间,就近原则 + mid = left + (right - left + 1) // 2 + # 就近原则优先 + if cur_pos < mid: + return 0, cur_pos + else: + return 1, cur_pos + + else: + # 其他情况,播放方向一致 + if cur_direction == 1: + return 1, cur_pos + else: + return 0, cur_pos + + +def silent2action( + segments, + left, + right, + cur_pos, + cur_direction, + sub_seg_idx, + is_first_speak_frame: bool = False, +): + """ + 在静默区间但是在讲话 + loop_flag=True and is_silent == 0 + whether border + whether forward + + Return: next_direction + """ + # 向最近的动作段移动, 如果左面没有动作段 + # TODO: 确认下面逻辑是否正确 + if ( + cur_pos == segments[0][0] + ): # 如果发生过跳跃,新段无论是不是动作段,仍然都向后执行 + return 1, cur_pos + if cur_pos == segments[-1][1]: + return 0, cur_pos + # 在静默左边界处,且仍在讲话 + if cur_pos == left: + if cur_direction == 1: + return 1, cur_pos + else: + return 0, cur_pos + # 在静默右边界处,且仍在讲话 + elif cur_pos == right: + if cur_direction == 1: + return 1, cur_pos + else: + return 0, cur_pos + else: + mid = left + (right - left + 1) // 2 + # !!就近原则只对第一次说话有效,其他情况遵循上一次状态 + if is_first_speak_frame: + # 如果第一段 + if sub_seg_idx == 0 and segments[0][2]: + return 1, cur_pos + # 如果最后一段 + elif sub_seg_idx == len(segments) - 1 and segments[-1][2]: + return 0, cur_pos + + if cur_pos < mid: + return 0, cur_pos + else: + return 1, cur_pos + else: + if cur_direction == 1: + return 1, cur_pos + elif cur_direction == 0: + return 0, cur_pos + + +def action2silent( + segments, left, right, cur_pos, cur_direction, sub_seg_idx, is_last_speak_frame +): + """ + loop_flag=False and is_silent==1 + whether border + Return: next_direction + """ + if cur_pos == segments[0][0]: + return 1, cur_pos + if cur_pos == segments[-1][1]: + return 0, cur_pos + # 动作段,说话结束转静默情况下,就近原则,进入静默 + if is_last_speak_frame: + mid = left + (right - left + 1) // 2 + if cur_pos < mid: + return 0, cur_pos + else: + return 1, cur_pos + + else: + if cur_direction == 1: + return 1, cur_pos + else: + return 0, cur_pos + diff --git a/utils/util.py b/utils/util.py new file mode 100644 index 0000000..577b1b8 --- /dev/null +++ b/utils/util.py @@ -0,0 +1,889 @@ +#!/usr/bin/env python3 +# -*- coding: 
diff --git a/utils/util.py b/utils/util.py
new file mode 100644
index 0000000..577b1b8
--- /dev/null
+++ b/utils/util.py
@@ -0,0 +1,889 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+import os
+import pickle
+import platform
+import shutil
+import time
+import traceback
+import wave
+import zipfile
+from typing import Union
+
+import cv2
+import numpy as np
+import requests
+import torch
+from turbojpeg import TurboJPEG
+
+from config import config
+from utils import create_oss
+from utils.log import logger
+# Re-exported so existing `from utils.util import play_in_loop_v2` call sites
+# keep working; the implementations live in utils.loop_frame_tool instead of
+# being duplicated here.
+from utils.loop_frame_tool import (  # noqa: F401
+    play_in_loop_v2,
+    subseg_judge,
+    get_next_direction,
+    get_next_frame,
+    pure_silent,
+    pure_action,
+    silent2action,
+    action2silent,
+)
+
+
+def unzip(zip_path, save_path):
+    with zipfile.ZipFile(zip_path, "r") as file:
+        file.extractall(save_path)
+
+
+if platform.system() == "Linux":
+    jpeg = TurboJPEG()
+else:
+    jpeg = TurboJPEG(lib_path=r"libjpeg-turbo-gcc64\bin\libturbojpeg.dll")
+
+
+def datagen(
+    pool2,
+    mel_batch,
+    start_idx_list,
+    pk_name,
+    face_det_results,
+    padding,
+    is_silent: int = 1,
+):
+    """Batch preprocessing for inference.
+
+    Args:
+        is_silent: whether these are silent frames; silent frames are read
+            from the post-processed images
+    """
+    img_batch, frame_batch, coords_batch, mask_batch = [], [], [], []
+    futures = []
+
+    return_gan = []
+    for start_idx in start_idx_list:
+        future = pool2.submit(
+            get_face_image,
+            pk_name,
+            start_idx,
+            face_det_results,
+            padding,
+            is_silent,
+        )
+        futures.append(future)
+
+    for future in futures:
+        frame_to_save, coords, face, mask = future.result()
+        img_batch.append(face)
+        frame_batch.append(frame_to_save)
+        coords_batch.append(coords)
+        mask_batch.append(mask)
+
+        if len(img_batch) >= config.wav2lip_batch_size:
+            img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
+            # Zero the lower half of the face; the model re-synthesizes it
+            img_masked = img_batch.copy()
+            img_masked[:, config.img_size // 2 :] = 0
+
+            img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.0
+            mel_batch = mel_batch[..., np.newaxis]
+            return_gan.append(
+                (
+                    img_batch,
+                    mel_batch,
+                    frame_batch,
+                    mask_batch,
+                    coords_batch,
+                    start_idx_list,
+                )
+            )
+            # NOTE: mel_batch is an ndarray from here on; this path assumes
+            # start_idx_list yields at most one full batch.
+            img_batch, frame_batch, mask_batch, coords_batch = [], [], [], []
+
+    if len(img_batch) > 0:
+        img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
+        img_masked = img_batch.copy()
+        img_masked[:, config.img_size // 2 :] = 0
+
+        img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.0
+        mel_batch = mel_batch[..., np.newaxis]
+        return_gan.append(
+            (
+                img_batch,
+                mel_batch,
+                frame_batch,
+                mask_batch,
+                coords_batch,
+                start_idx_list,
+            )
+        )
+
+    return return_gan
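+
+# Shape sketch for datagen's output (assuming config.img_size = 96 and N
+# frames per batch, mel chunk sizes as in standard Wav2Lip):
+#     face crops  -> img_batch (N, 96, 96, 3)
+#     masked+orig -> channel concat (N, 96, 96, 6), scaled to [0, 1]
+#     mel chunks  -> mel_batch (N, 80, 16, 1) after the trailing newaxis
+# The six-channel tensor is the generator input: the zeroed lower half is
+# re-synthesized conditioned on the mel spectrogram.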
+
+
+def get_face_image(
+    pkl_name,
+    start_idx,
+    face_det_results,
+    infer_padding,
+    is_silent: int = 1,
+):
+    # Whether the background was replaced decides the on-disk image format
+    post_suffix = (
+        "png" if not os.path.exists(f"model_people/mask/{pkl_name}") else "jpg"
+    )
+    pre_suffix = os.path.join("model_people", pkl_name, "image")
+    file = os.path.join(pre_suffix, f"img{start_idx:05d}.{post_suffix}")
+
+    if not os.path.exists(file):
+        logger.error(f"Not found image file: {file}")
+        raise FileNotFoundError(f"Not found image file: {file}")
+    frame_to_save = read_img(
+        file, flags=cv2.IMREAD_UNCHANGED if post_suffix == "png" else cv2.IMREAD_COLOR
+    )
+    mask = None
+    if post_suffix == "jpg":
+        has_mask = os.path.exists(f"model_people/mask/{pkl_name}")
+        if not has_mask:
+            raise RuntimeError("should use mask, however no mask is found.")
+        mask_file = os.path.join(
+            f"model_people/mask/{pkl_name}", f"img{start_idx:05d}.jpg"
+        )
+        mask = cv2.imread(mask_file, cv2.IMREAD_GRAYSCALE)
+    elif post_suffix == "png":
+        mask = frame_to_save[:, :, 3]
+    else:
+        raise ValueError(f"file type {post_suffix} not in (png, jpg)")
+    # v1 pickle rows are plain boxes, v2 rows are (label, box) tuples
+    y1, y2, x1, x2 = (
+        face_det_results[start_idx - 1]
+        if len(face_det_results[start_idx - 1]) == 4
+        else face_det_results[start_idx - 1][1]
+    )
+
+    y2 = y2 + infer_padding
+    face, coords = frame_to_save[:, :, :3][y1:y2, x1:x2], (y1, y2, x1, x2)
+    face = cv2.resize(face, (config.img_size, config.img_size))
+    return frame_to_save, coords, face, mask
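+
+# Mask convention implied above: characters with a replaced background ship as
+# jpg frames plus a grayscale mask directory
+# model_people/mask/<pkl_name>/imgNNNNN.jpg, while untouched characters ship
+# as png frames whose alpha channel doubles as the mask.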
+
+
+def get_oss_config_masx():
+    logger.info("download from mas-x")
+    OSS_ENDPOINT = "oss-cn-hangzhou.aliyuncs.com"
+    OSS_KEY = "LTAIgPS3rCV89vYO"
+    OSS_SECRET = "p9wBFUAC5e29rtunkbnPTefWYLQpxT"
+    OSS_BUCKET = "mas-x"
+
+    return OSS_ENDPOINT, OSS_KEY, OSS_SECRET, OSS_BUCKET
+
+
+def get_oss_config_zqinghui():
+    logger.info("download from zqinghui")
+    OSS_ENDPOINT = "oss-cn-zhangjiakou.aliyuncs.com"
+    OSS_KEY = "LTAIgPS3rCV89vYO"
+    OSS_SECRET = "p9wBFUAC5e29rtunkbnPTefWYLQpxT"
+    OSS_BUCKET = "zqinghui"
+
+    return OSS_ENDPOINT, OSS_KEY, OSS_SECRET, OSS_BUCKET
+
+
+def get_oss_config(url: str):
+    if "zqinghui" in url:
+        return get_oss_config_zqinghui()
+    return get_oss_config_masx()
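+
+# Key derivation note for download_file_from_oss_url below:
+# str.removeprefix (Python 3.9+) strips a literal prefix, whereas str.lstrip
+# strips a character set and silently over-trims object keys, e.g.
+#     "banana.zip".lstrip("ban")       -> ".zip"
+#     "banana.zip".removeprefix("ban") -> "ana.zip"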
+
+
+def download_file_from_oss_url(
+    model_dir: str,
+    url_path: str,
+    download_type: str,
+    download_file_len: int,
+    downloadQueue,
+    downloadFileLen: int,
+):
+    """Download one `download_type` asset from OSS, picking the bucket config
+    from the URL. If the download fails, remove any partially downloaded file.
+    """
+    try:
+        logger.info(f"downloading url_path: {url_path}")
+        OSS_ENDPOINT, OSS_KEY, OSS_SECRET, OSS_BUCKET = get_oss_config(url_path)
+        oss_class = create_oss.Oss_connection(
+            OSS_ENDPOINT, OSS_KEY, OSS_SECRET, OSS_BUCKET
+        )
+        oss_class.oss_get_file(
+            model_dir,
+            url_path,
+            type=download_type,
+            keys=url_path.removeprefix(f"https://{OSS_BUCKET}.{OSS_ENDPOINT}/"),
+            downloadQueue=downloadQueue,
+            downloadFileLen=downloadFileLen,
+            downloadAllLen=download_file_len,
+        )
+    except Exception:
+        logger.warning(f"download {url_path} error: {traceback.format_exc()}")
+        if url_path.split("/")[-1] in os.listdir(model_dir):
+            os.remove(os.path.join(model_dir, url_path.split("/")[-1]))
+
+
+def download_checkpoint(model_name, model_url, download_file_len, downloadQueue):
+    logger.info(f"Download checkpoint of model people: {model_url}")
+    model_dir = config.model_folder
+    os.makedirs(model_dir, exist_ok=True)
+    if not os.path.exists(os.path.join(model_dir, model_name)):
+        download_file_len += 1
+        OSS_ENDPOINT = "oss-cn-zhangjiakou.aliyuncs.com"
+        OSS_KEY = "LTAIgPS3rCV89vYO"
+        OSS_SECRET = "p9wBFUAC5e29rtunkbnPTefWYLQpxT"
+        OSS_BUCKET = "zqinghui"
+        oss_class = create_oss.Oss_connection(
+            OSS_ENDPOINT, OSS_KEY, OSS_SECRET, OSS_BUCKET
+        )
+        oss_class.oss_get_file(
+            model_dir, model_url, type="model", downloadQueue=downloadQueue
+        )
+    return download_file_len, downloadQueue
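+
+# Payload sketch for download() below (field names taken from the code, the
+# values are made-up examples):
+#     res = {
+#         "code": 0,
+#         "data": {
+#             "pkl_url": "https://zqinghui.oss-cn-zhangjiakou.aliyuncs.com/a.zip",
+#             "modeling_url": "https://.../checkpoint.pth",
+#             "video_needle_url": "https://.../frames.zip or 'none'",
+#             "mask_needle_url": "https://.../mask.zip or 'none'",
+#         },
+#     }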
+
+
+def download(modelSpeakerId, downloadQueue, res):
+    """Known rough edges:
+    1. On abnormal exit, all partially downloaded temp files should be removed.
+    2. On re-download, resume from the previous partial result.
+    3. Error interception is still unreliable.
+    4. Some file creation/cleanup logic is fuzzy.
+    """
+    # TODO: clarify what this variable is for
+    download_file_len = 3
+    if res["code"] != 0:
+        logger.warning(f"cannot fetch current digital-human info, code={res['code']}")
+        return "ERROR"
+
+    res = res["data"]
+    pkl_url, model_url, video_needle_url, mask_needle_url = (
+        res["pkl_url"],
+        res["modeling_url"],
+        res["video_needle_url"],
+        res["mask_needle_url"],
+    )
+    model_name = model_url.split("/")[-1]  # last URL path component as file name
+    download_file_len, downloadQueue = download_checkpoint(
+        model_name, model_url, download_file_len, downloadQueue
+    )
+
+    model_dir = "model_people/"
+    os.makedirs(model_dir, exist_ok=True)
+    pkl_url_zip_name = pkl_url.split("/")[-1]  # last URL path component as file name
+    pkl_url_name = pkl_url_zip_name.replace(".zip", "")
+    logger.info(f"pkl_url_name:{pkl_url_name}")
+    if pkl_url_name in os.listdir(model_dir):
+        return f"model people of {modelSpeakerId} not need to download"
+
+    # TODO: a failure here currently leaves already-downloaded files undeleted
+    # pkl
+    logger.info(f"download model people of: {modelSpeakerId}")
+    if pkl_url != "none":
+        download_file_from_oss_url(
+            model_dir,
+            pkl_url,
+            "character",
+            download_file_len,
+            downloadQueue,
+            downloadFileLen=1,
+        )
+        unzip(
+            os.path.join(model_dir, pkl_url_zip_name),
+            os.path.join(model_dir, pkl_url_name),
+        )
+        shutil.rmtree(
+            os.path.join(model_dir, pkl_url_name, "image"), ignore_errors=True
+        )
+
+    # video
+    if video_needle_url != "none":
+        # TODO: read from config
+        download_file_from_oss_url(
+            model_dir,
+            video_needle_url,
+            "character",
+            download_file_len,
+            downloadQueue,
+            downloadFileLen=2,
+        )
+        # The digital human before background replacement
+        unzip(
+            os.path.join(model_dir, video_needle_url.split("/")[-1]),
+            os.path.join(model_dir, pkl_url_name, "image"),
+        )
+    shutil.rmtree(os.path.join(model_dir, modelSpeakerId), ignore_errors=True)
+    os.rename(
+        os.path.join(model_dir, pkl_url_name),
+        os.path.join(model_dir, modelSpeakerId),
+    )
+    os.remove(os.path.join(model_dir, pkl_url_zip_name))
+
+    # mask
+    if mask_needle_url != "none":
+        download_file_from_oss_url(
+            model_dir,
+            mask_needle_url,
+            "mask",
+            download_file_len,
+            downloadQueue,
+            downloadFileLen=3,
+        )
+        # TODO: this directory holds the pure-white background masks
+        mask_dir = os.path.join(model_dir, "mask", modelSpeakerId)
+        os.makedirs(mask_dir, exist_ok=True)
+
+        unzip(os.path.join(model_dir, mask_needle_url.split("/")[-1]), mask_dir)
+        # Remove the zip archive
+        os.remove(os.path.join(model_dir, mask_needle_url.split("/")[-1]))
+    # Handled this way because the pkl and img URLs may share the same final
+    # path component.
+    video_img_file = os.path.join(model_dir, video_needle_url.split("/")[-1])
+    if os.path.exists(video_img_file):
+        os.remove(video_img_file)
+
+    downloadQueue.put("100")
+    logger.info(f"model people of {modelSpeakerId} download finished")
+    return "download finished"
+
+
+def get_model_local_path(model_url: str):
+    return model_url.split("/")[-1]
+
+
+def get_actor_id(pkl_url):
+    """Take the last URL path component (minus .zip) as the template_id."""
+    pkl_file_name = pkl_url.split("/")[-1]
+    return pkl_file_name.replace(".zip", "")
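+
+# Format sketch for the v1 face_det_results.pkl consumed below: the file is
+# assumed to hold many pickle records appended back to back, hence the
+# Unpickler loop terminated by EOFError. A matching writer would look like:
+#     with open("face_det_results.pkl", "ab") as f:
+#         for chunk in box_chunks:  # each chunk: an ndarray of boxes
+#             pickle.dump(chunk, f)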
+
+
+def load_face_box(actor_id, pkl_version: int = 1) -> np.ndarray:
+    """
+    Args:
+        actor_id: uid of the character
+        pkl_version: 1 or 2
+    """
+    face_det_results_pk_file = os.path.join(
+        "model_people", actor_id, "face_det_results.pkl"
+    )
+    if not os.path.exists(face_det_results_pk_file):
+        raise FileNotFoundError(f"Not found file:{face_det_results_pk_file}")
+    if pkl_version != 1:
+        with open(face_det_results_pk_file, "rb") as f:
+            res = pickle.load(f)["boxes"]
+        return np.array(res)
+
+    with open(face_det_results_pk_file, mode="rb") as f:
+        unpickler = pickle.Unpickler(f)
+        face_det_results = []
+        try:
+            # Read records until the stream is exhausted; EOFError is the only
+            # clean termination signal for a concatenated pickle stream.
+            while True:
+                data = unpickler.load()
+                face_det_results.append(data)
+        except EOFError:
+            pass
+    # TODO: rework this; it does not work on numpy 1.24 and above.
+    final_result = np.concatenate(face_det_results)
+    return final_result
+
+
+def load_seg_model(pth_path, model, device):
+    model.to(device)
+    model.load_state_dict(torch.load(pth_path))
+    return model.eval()
+
+
+def load_w2l_model(pth_path, model, device):
+    model = model.to(device)
+    model.load_state_dict(
+        {
+            k.replace("module.", ""): v
+            for k, v in torch.load(pth_path)["state_dict"].items()
+        }
+    )
+    return model.eval()
+
+
+def read_img(file_path: str, flags=cv2.IMREAD_COLOR):
+    """Read an image file and decode it."""
+    return cv2.imread(file_path, flags=flags)
+
+
+def write_img(content, jpg_path: str, mode: str = "wb"):
+    with open(jpg_path, mode) as file:
+        file.write(content)
+
+
+def save_raw_video(speech_id, result_frames, tag="infer"):
+    name = f"temp/{tag}_{speech_id}_result.avi"
+    out = cv2.VideoWriter(name, cv2.VideoWriter_fourcc(*"mp4v"), 25, (1080, 1920))
+
+    for frame in result_frames:
+        out.write(frame)
+    out.release()
+    return name
+
+
+def save_raw_wav(speech_id, audio_packets, tag="infer", folder="temp"):
+    os.makedirs(folder, exist_ok=True)
+    sample_rate = 16000  # sampling rate: 16000 samples per second
+    num_channels = 1  # mono
+    sample_width = 2  # bytes per sample, i.e. 16-bit audio
+    saved_wav_name = os.path.join(folder, f"{tag}_{speech_id}.wav")
+
+    with wave.open(saved_wav_name, "wb") as wav_file:
+        # Configure the WAV header
+        wav_file.setnchannels(num_channels)
+        wav_file.setsampwidth(sample_width)
+        wav_file.setframerate(sample_rate)
+        wav_file.writeframes(audio_packets)
+
+    return saved_wav_name
+
+
+def morphing(
+    ori_frame: np.ndarray,
+    pred_frame: np.ndarray,
+    box: list,
+    mp_ratio: float,
+    padding: int = 10,
+    file_index: int = 1,
+) -> np.ndarray:
+    """Blend the predicted frame back toward the original frame.
+
+    Args:
+        ori_frame: original video frame, e.g. 01.png, as an array
+        pred_frame: inferred frame, e.g. infer_01.png, as an array
+        box: face box, unpacked as (y1, y2, x1, x2)
+        mp_ratio: blend weight in [0, 1], i.e. step / morphing_frames_num
+    Returns:
+        merged numpy array
+    """
+    (y1, y2, x1, x2) = box
+    ori_face = ori_frame[ori_frame.shape[0] // 2 :, :]
+    pred_face = pred_frame[:, :, :3][
+        y1 - padding : y2 + padding, x1 - padding : x2 + padding
+    ]
+    pred_half_face = pred_face[pred_face.shape[0] // 2 :, :]
+    alpha = np.ones_like(ori_face) * mp_ratio
+    pred_frame[:, :, :3][
+        y1 + (y2 - y1) // 2 : y2 + padding, x1 - padding : x2 + padding
+    ] = (1 - alpha) * pred_half_face + alpha * ori_face
+    logger.debug(f"file_index:{file_index}, ratio:{mp_ratio}")
+    return pred_frame
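+
+# Blend sketch for morphing above: with mp_ratio = 0.2 the written region is
+#     0.8 * pred_half_face + 0.2 * ori_face
+# e.g. per pixel (1 - 0.2) * 100 + 0.2 * 200 = 120, so ramping mp_ratio
+# through 1/N, 2/N, ..., 1 fades the synthesized mouth back into the original
+# frame over N steps.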
+
+
+# TODO: @txueduo this is similar to the fusion path below; the mask is not
+# binary, so a float alpha blend is needed.
+def add_alpha(img: np.ndarray, mask: np.ndarray, alpha: int = 0):
+    t0 = time.time()
+    if img.shape[-1] == 3:
+        res_img = np.zeros((img.shape[0], img.shape[1], 4), dtype=np.uint8)
+        res_img[:, :, :3] = img[:, :, ::-1]
+    else:
+        res_img = np.zeros_like(img, dtype=np.uint8)
+        res_img[:, :, :3] = img[:, :, :3][:, :, ::-1]
+    mask[mask > 125] = 255
+    # Attach the mask as the new alpha channel
+    res_img[:, :, 3] = mask
+    result = res_img
+
+    logger.info(f"add_alpha cost:{time.time() - t0}")
+    return result.astype("uint8")
+
+
+def fusion(person: np.ndarray, background: np.ndarray, mask: np.ndarray) -> np.ndarray:
+    """
+    Args:
+        person: original image
+        background: image to fuse in
+        mask: filter mask
+    """
+    result = np.zeros_like(person, dtype=np.uint8)
+    # png: the output carries transparency
+    if len(mask.shape) == 2:
+        background_alpha = np.zeros_like(person, dtype=np.uint8)
+        background_alpha[:, :, :3] = background
+        # background_alpha[:, :, 3] = 255 if not config.output_alpha else 0
+        background_alpha[:, :, 3] = 0
+        mask_rgba = np.repeat(mask[..., None], 4, axis=-1).astype("uint8")
+    else:
+        # jpeg: transparency is decided by the computation below
+        mask_rgba = mask
+        background_alpha = background
+    mask_rgba = mask_rgba / 255.0
+    result = mask_rgba * person + (1 - mask_rgba) * background_alpha
+    return result.astype("uint8")
+
+
+def chg_bg(
+    ori_img_path: Union[str, np.ndarray] = None,
+    bg_img_path: Union[str, np.ndarray] = None,
+    mask: Union[str, np.ndarray] = None,
+    default_flags=cv2.IMREAD_COLOR,
+):
+    """Replace the background of ori_img with bg_img, guided by mask."""
+    flags = default_flags
+    if isinstance(bg_img_path, str):
+        flags = cv2.IMREAD_UNCHANGED if "png" in bg_img_path else default_flags
+        bg_img = cv2.imread(bg_img_path, flags=flags)
+    else:
+        bg_img = bg_img_path
+    if isinstance(ori_img_path, str):
+        flags = cv2.IMREAD_UNCHANGED if "png" in ori_img_path else default_flags
+        ori_img = cv2.imread(ori_img_path, flags=flags)
+    else:
+        ori_img = ori_img_path
+    logger.debug(
+        f"bg-img shape:{bg_img.shape}, ori_img_path:{ori_img_path}, flags:{flags}"
+    )
+    if ori_img.shape[-1] == 4:
+        return fusion(ori_img, bg_img, ori_img[:, :, 3])
+    elif ori_img.shape[-1] == 3:
+        if isinstance(mask, str):
+            mask = cv2.imread(mask)
+        return fusion(ori_img, bg_img, mask)
+    else:
+        raise ValueError("img shape[-1] must be 3 or 4")
+
+
+def load_config(model_speaker_id: str):
+    """Fetch the per-speaker inference config from the segments service."""
+    logger.info(f"model_speaker_id: {model_speaker_id}")
+    res = requests.post(
+        config.GET_segments, json={"modelSpeakerId": model_speaker_id}
+    ).json()["data"]
+    logger.info(res)
+    model_url = res["modeling_url"]
+    frame_config = res.get("frame_config", [[1, 200, True]])
+    # If padding is too small, the chin may be left out of inference
+    padding = res["padding"]
+    face_classes = res.get("face_classes", [1, 11, 12, 13])
+    if not frame_config:
+        raise RuntimeError(f"frame config {frame_config} is empty or invalid")
+
+    # New field: trans_method
+    trans_method = res.get("trans_method", 2)
+    # New transition parameters
+    infer_silent_num = res.get("infer_silent_num", 5)
+    morphing_num = res.get("morphing_num", 25)
+    pkl_version = res.get("pkl_version", 1)
+    logger.info(f"infer_silent_num:{infer_silent_num}, morphing_num:{morphing_num}")
+    return (
+        model_url,
+        frame_config,
+        padding,
+        face_classes,
+        trans_method,
+        infer_silent_num,
+        morphing_num,
+        pkl_version,
+    )
+
+
+def get_trans_idxes(
+    need_post_morphing,
+    need_infer_silent,
+    post_morphing_idx,
+    infer_silent_idx,
+    file_index,
+):
+    """Return the last five indexes of each transition window."""
+    if need_post_morphing:
+        post_morphing_idxes = (
+            list(range(1, 6))
+            if post_morphing_idx == 1
+            else list(range(post_morphing_idx))[-5:]
+        )
+    else:
+        post_morphing_idxes = [0, 0, 0, 0, 0]  # placeholder
+    if need_infer_silent:
+        infer_silent_idxes = (
+            list(range(1, 6))
+            if infer_silent_idx == 1
+            else list(range(infer_silent_idx))[-5:]
+        )
+    else:
+        infer_silent_idxes = [0, 0, 0, 0, 0]  # placeholder
+    file_idxes = list(range(1, 6)) if file_index == 1 else list(range(file_index))[-5:]
+    return file_idxes, post_morphing_idxes, infer_silent_idxes
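+
+# Example for get_trans_idxes (illustrative values): each returned list is the
+# last five indexes before the current one, or [1..5] on the very first frame:
+#     get_trans_idxes(True, True, 1, 30, 12)
+#     -> ([7, 8, 9, 10, 11], [1, 2, 3, 4, 5], [25, 26, 27, 28, 29])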