diff --git a/utils/log.py b/utils/log.py
new file mode 100644
index 0000000..a531729
--- /dev/null
+++ b/utils/log.py
@@ -0,0 +1,119 @@
+import logging
+import os
+import sys
+import redis
+from loguru import logger as logurulogger
+import redis.exceptions
+from app.config import config
+import json
+from redis.retry import Retry
+from redis.backoff import ExponentialBackoff
+
+LOG_FORMAT = (
+    "{level: <8} "
+    "{process.name} | "  # process name
+    "{thread.name} | "
+    "{time:YYYY-MM-DD HH:mm:ss.SSS} - "
+    "{process} "
+    "{module}.{function}:{line} - "
+    "{message}"
+)
+LOG_NAME = ["uvicorn", "uvicorn.access", "uvicorn.error", "flask"]
+
+# Configure the Redis connection pool
+redis_pool = redis.ConnectionPool(
+    host=config.LOG_REDIS_HOST,  # Redis server address
+    port=config.LOG_REDIS_PORT,  # Redis server port
+    db=config.LOG_REDIS_DB,  # database number
+    password=config.LOG_REDIS_AUTH,  # password
+    max_connections=config.max_connections,  # maximum number of connections
+    socket_connect_timeout=config.socket_connect_timeout,  # connect timeout
+    socket_timeout=config.socket_timeout,  # read/write timeout
+)
+
+
+class InterceptHandler(logging.Handler):
+    def emit(self, record):
+        # Map the stdlib level name to loguru's; loguru raises ValueError for
+        # unknown level names, in which case fall back to the numeric level.
+        try:
+            level = logurulogger.level(record.levelname).name
+        except ValueError:
+            level = record.levelno
+
+        # Walk back past the logging module frames so loguru reports the real call site.
+        frame, depth = logging.currentframe(), 2
+        while frame and frame.f_code.co_filename == logging.__file__:
+            frame = frame.f_back
+            depth += 1
+
+        logurulogger.opt(depth=depth, exception=record.exc_info).log(
+            level, record.getMessage()
+        )
+
+class Logging:
+    """Custom logger that fans out to files, stdout, and Redis."""
+
+ def __init__(self):
+ self.log_path = "logs"
+ self._connect_redis()
+ if config.IS_LOCAL:
+ os.makedirs(self.log_path, exist_ok=True)
+ self._initlogger()
+ self._reset_log_handler()
+
+    def _connect_redis(self):
+        retry = Retry(ExponentialBackoff(), 3)  # retry 3 times with exponential backoff
+        self.redis_client = redis.Redis(connection_pool=redis_pool, retry=retry)  # reuse the pool
+
+    def _initlogger(self):
+        """Initialize the loguru sinks."""
+ logurulogger.remove()
+ if config.IS_LOCAL:
+ logurulogger.add(
+ os.path.join(self.log_path, "error.log.{time:YYYY-MM-DD}"),
+ format=LOG_FORMAT,
+ level=logging.ERROR,
+ rotation="00:00",
+ retention="1 week",
+ backtrace=True,
+ diagnose=True,
+ enqueue=True
+ )
+ logurulogger.add(
+ os.path.join(self.log_path, "info.log.{time:YYYY-MM-DD}"),
+ format=LOG_FORMAT,
+ level=logging.INFO,
+ rotation="00:00",
+ retention="1 week",
+ enqueue=True
+ )
+ logurulogger.add(
+ sys.stdout,
+ format=LOG_FORMAT,
+ level=logging.DEBUG,
+ colorize=True,
+ )
+
+ logurulogger.add(self._log_to_redis, level="INFO", format=LOG_FORMAT)
+ self.logger = logurulogger
+
+
+    def _log_to_redis(self, message):
+        """Push a log record onto a Redis list."""
+        try:
+            self.redis_client.rpush(
+                f"nlp.logger.{config.env_version}.log", json.dumps({"message": message})
+            )
+        except redis.exceptions.ConnectionError as e:
+            # Report sink failures to stderr; logging them here would re-enter
+            # this very sink and could recurse while Redis stays down.
+            print(f"write {message} Redis connection error: {e}", file=sys.stderr)
+        except redis.exceptions.TimeoutError as e:
+            print(f"write {message} Redis operation timed out: {e}", file=sys.stderr)
+        except Exception as e:
+            print(f"write {message} Unexpected error: {e}", file=sys.stderr)
+
+    def _reset_log_handler(self):
+        # Route stdlib/uvicorn/flask records through loguru.
+        for name in LOG_NAME:
+            logging.getLogger(name).handlers = [InterceptHandler()]
+
+ def getlogger(self):
+ return self.logger
+
+logger = Logging().getlogger()
+
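+# Minimal usage sketch (illustrative; assumes `app.config` is importable and
+# the Redis settings point at a reachable server):
+if __name__ == "__main__":
+    logger.debug("debug record: stdout sink only")
+    logger.info("info record: stdout + Redis, plus info.log when IS_LOCAL is set")
+    logger.error("error record: also lands in error.log when IS_LOCAL is set")
+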
diff --git a/utils/loop_frame_tool.py b/utils/loop_frame_tool.py
new file mode 100644
index 0000000..9343cab
--- /dev/null
+++ b/utils/loop_frame_tool.py
@@ -0,0 +1,318 @@
+from utils.log import logger
+# NOTE: `config` is referenced below (config.batch_size); this import path is
+# an assumption mirroring utils/log.py.
+from app.config import config
+
+
+def play_in_loop_v2(
+ segments,
+ startfrom,
+ batch_num,
+ last_direction,
+ is_silent,
+ is_silent_,
+ first_speak,
+ last_speak,
+):
+ """
+ batch_num: 初始和结束,每一帧都这么判断
+ 1、静默时,在静默段循环, 左边界正向,右边界反向, 根据上一次方向和位置,给出新的方向和位置
+ 2、静默转说话: 就近到说话段,pre_falg, post_flag, 都为true VS 其中一个为true
+ 3、说话转静默: 动作段播完,再进入静默(如果还在持续说话,静默段不循环)
+ 4、在整个视频左端点: 开始端只能正向,静默时循环,说话时走2
+ 5、在整个视频右端点: 开始时只能反向,静默时循环,说话时走2
+ 6、根据方向获取batch_num 数量的视频帧,return batch_idxes, current_direction
+ Args:
+ segments: 循环帧配置 [[st, ed, True], ...]
+ startfrom: cur_pos
+ batch_num: 5
+ last_direction: 0反向1正向
+ is_silent: 0说话态1动作态
+ is_silent_: 目前不明确,后面可能废弃
+ first_speak: 记录是不是第一次讲话
+ last_speak: 记录是不是讲话结束那一刻
+ """
+ frames = []
+ cur_pos = startfrom
+ cur_direction = last_direction
+    is_first_speak_frame = first_speak
+    is_last_speak_frame = bool(last_speak and batch_num == 1)
+ while batch_num != 0:
+        # Find the sub-segment that contains the current frame
+        sub_seg_idx = subseg_judge(cur_pos, segments)
+        # Decide the movement direction
+ next_direction, next_pos = get_next_direction(
+ segments,
+ cur_pos,
+ cur_direction,
+ is_silent,
+ sub_seg_idx,
+ is_first_speak_frame,
+ is_last_speak_frame,
+ )
+        # Step to the next frame along the chosen direction
+        next_pos = get_next_frame(next_pos, next_direction)
+        frames.append(next_pos)
+        batch_num -= 1
+        is_first_speak_frame = bool(first_speak and batch_num == config.batch_size)
+        is_last_speak_frame = bool(last_speak and batch_num == 1)
+
+ cur_direction = next_direction
+ cur_pos = next_pos
+ return frames, next_direction
+
+
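+# Illustrative call (hypothetical segments): one looping silent segment [1, 100]
+# followed by one action segment [101, 150]. Asking for 5 frames while silent,
+# starting at frame 10 and moving forward, keeps looping inside the silence:
+#
+#   frames, direction = play_in_loop_v2(
+#       segments=[[1, 100, True], [101, 150, False]],
+#       startfrom=10, batch_num=5, last_direction=1,
+#       is_silent=1, is_silent_=1, first_speak=False, last_speak=False,
+#   )
+#   frames == [11, 12, 13, 14, 15]; direction == 1
+
+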
+def subseg_judge(cur_pos, segments):
+    """Return the index of the segment [st, ed, flag] that contains cur_pos."""
+    for idx, frame_seg in enumerate(segments):
+        if frame_seg[0] <= cur_pos <= frame_seg[1]:
+            return idx
+    # fall back to the first segment when starting from position 0
+    if cur_pos == 0:
+        return 0
+
+def get_next_direction(
+ segments,
+ cur_pos,
+ cur_direction,
+ is_silent,
+ sub_seg_idx,
+ is_first_speak_frame: bool = False,
+ is_last_speak_frame: bool = False,
+):
+ """
+ 3.3.0 循环帧需求,想尽快走到预期状态
+ if 动作段:
+ if 开始说话:
+ if 边界:
+ if 正向:
+ pass
+ else:
+ pass
+ else:
+ if 正向:
+ pass
+ else:
+ pass
+ elif 静默:
+ 同上
+ elif 说话中:
+ 同上
+ elif 说话结束:
+ 同上
+ elif 静默段:
+ 同上
+ Args:
+ is_first_speak_frame: 开始说话flag
+ is_last_speak_frame: 说话结束flag
+ """
+ left, right, loop_flag = segments[sub_seg_idx]
+ if loop_flag:
+ if is_silent == 1:
+ next_direct, next_pos = pure_silent(
+ segments, left, right, cur_pos, cur_direction, sub_seg_idx
+ )
+ logger.debug(
+ f"cur_pos:{cur_pos}, next_direct:{next_direct}, is_first_speak_frame:{is_first_speak_frame}"
+ )
+ elif is_silent == 0:
+ next_direct, next_pos = silent2action(
+ segments,
+ left,
+ right,
+ cur_pos,
+ cur_direction,
+ sub_seg_idx,
+ is_first_speak_frame,
+ )
+ logger.debug(
+ f"cur_pos:{cur_pos}, next_direct:{next_direct}, is_first_speak_frame{is_first_speak_frame}"
+ )
+ else:
+ if is_silent == 1:
+ next_direct, next_pos = action2silent(
+ segments,
+ left,
+ right,
+ cur_pos,
+ cur_direction,
+ sub_seg_idx,
+ is_last_speak_frame,
+ )
+ logger.debug(
+ f"cur_pos{cur_pos}, next_direct:{next_direct},is_first_speak_frame{is_first_speak_frame},is_last_speak_frame:{is_last_speak_frame}"
+ )
+ elif is_silent == 0:
+ next_direct, next_pos = pure_action(
+ segments,
+ left,
+ right,
+ cur_pos,
+ cur_direction,
+ sub_seg_idx,
+ is_last_speak_frame,
+ )
+ logger.debug(
+ f"cur_pos:{cur_pos}, next_direct:{next_direct},is_first_speak_frame{is_first_speak_frame}, is_last_speak_frame:{is_last_speak_frame}"
+ )
+ return next_direct, next_pos
+
+def get_next_frame(cur_pos, cur_direction):
+    """Given the current frame and direction, return the next frame index.
+    Callers must guarantee that a frame exists in that direction, so no extra
+    boundary check is needed here.
+    """
+    # forward
+    if cur_direction == 1:
+        return cur_pos + 1
+    # backward
+    elif cur_direction == 0:
+        return cur_pos - 1
+
+def pure_silent(segments, left, right, cur_pos, cur_direction, sub_seg_idx):
+    """
+    loop_flag == True and is_silent == 1
+    whether border
+    whether forward
+    Return:
+        next_direction, next_pos
+    """
+    # forward at the left border, backward at the right border
+    if cur_pos == segments[0][0]:
+        return 1, cur_pos
+    if cur_pos == segments[-1][1]:
+        return 0, cur_pos
+    # right border: go backward
+    if cur_pos == right:
+        return 0, cur_pos
+    # left border: go forward
+    if cur_pos == left:
+        return 1, cur_pos
+    # not at a border: keep forward if we were moving forward, else backward
+    if cur_pos > left and cur_direction == 1:
+        return 1, cur_pos
+    else:
+        return 0, cur_pos
+
+
+def pure_action(
+    segments, left, right, cur_pos, cur_direction, sub_seg_idx, is_last_speak_frame
+):
+    """
+    loop_flag == False and is_silent == 0
+    Play the action segment to the end, then move forward into silence
+    (segment jumps can happen).
+    whether border
+    whether forward  # playing forward vs. backward
+    Args:
+        is_last_speak_frame: the moment speech ends
+    Return: next_direction, next_pos
+    """
+    if cur_pos == segments[0][0]:
+        return 1, cur_pos
+    if cur_pos == segments[-1][1]:
+        return 0, cur_pos
+
+    if is_last_speak_frame:
+        # action segment at the very end: search backward for silence
+        if sub_seg_idx == len(segments) - 1:
+            return 0, cur_pos
+        # action segment at the very start: go forward
+        if sub_seg_idx == 0:
+            return 1, cur_pos
+        # action segment in the middle: the nearer side wins
+        mid = left + (right - left + 1) // 2
+        if cur_pos < mid:
+            return 0, cur_pos
+        else:
+            return 1, cur_pos
+
+    else:
+        # otherwise keep the previous playback direction
+        if cur_direction == 1:
+            return 1, cur_pos
+        else:
+            return 0, cur_pos
+
+
+def silent2action(
+ segments,
+ left,
+ right,
+ cur_pos,
+ cur_direction,
+ sub_seg_idx,
+ is_first_speak_frame: bool = False,
+):
+ """
+ 在静默区间但是在讲话
+ loop_flag=True and is_silent == 0
+ whether border
+ whether forward
+
+ Return: next_direction
+ """
+ # 向最近的动作段移动, 如果左面没有动作段
+ # TODO: 确认下面逻辑是否正确
+ if (
+ cur_pos == segments[0][0]
+ ): # 如果发生过跳跃,新段无论是不是动作段,仍然都向后执行
+ return 1, cur_pos
+ if cur_pos == segments[-1][1]:
+ return 0, cur_pos
+ # 在静默左边界处,且仍在讲话
+ if cur_pos == left:
+ if cur_direction == 1:
+ return 1, cur_pos
+ else:
+ return 0, cur_pos
+ # 在静默右边界处,且仍在讲话
+ elif cur_pos == right:
+ if cur_direction == 1:
+ return 1, cur_pos
+ else:
+ return 0, cur_pos
+ else:
+ mid = left + (right - left + 1) // 2
+ # !!就近原则只对第一次说话有效,其他情况遵循上一次状态
+ if is_first_speak_frame:
+ # 如果第一段
+ if sub_seg_idx == 0 and segments[0][2]:
+ return 1, cur_pos
+ # 如果最后一段
+ elif sub_seg_idx == len(segments) - 1 and segments[-1][2]:
+ return 0, cur_pos
+
+ if cur_pos < mid:
+ return 0, cur_pos
+ else:
+ return 1, cur_pos
+ else:
+ if cur_direction == 1:
+ return 1, cur_pos
+ elif cur_direction == 0:
+ return 0, cur_pos
+
+
+def action2silent(
+ segments, left, right, cur_pos, cur_direction, sub_seg_idx, is_last_speak_frame
+):
+ """
+ loop_flag=False and is_silent==1
+ whether border
+ Return: next_direction
+ """
+ if cur_pos == segments[0][0]:
+ return 1, cur_pos
+ if cur_pos == segments[-1][1]:
+ return 0, cur_pos
+ # 动作段,说话结束转静默情况下,就近原则,进入静默
+ if is_last_speak_frame:
+ mid = left + (right - left + 1) // 2
+ if cur_pos < mid:
+ return 0, cur_pos
+ else:
+ return 1, cur_pos
+
+ else:
+ if cur_direction == 1:
+ return 1, cur_pos
+ else:
+ return 0, cur_pos
+
diff --git a/utils/util.py b/utils/util.py
new file mode 100644
index 0000000..577b1b8
--- /dev/null
+++ b/utils/util.py
@@ -0,0 +1,889 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+import os
+import pickle
+import numpy as np
+import zipfile
+
+from app.config import config  # assumed import path, mirroring utils/log.py
+import cv2
+import torch
+from turbojpeg import TurboJPEG
+import platform
+import wave
+import requests
+
+from utils.log import logger
+import time
+import shutil
+from utils import create_oss
+from typing import Union
+import traceback
+
+
+def unzip(zip_path, save_path):
+ with zipfile.ZipFile(zip_path, "r") as file:
+ file.extractall(save_path)
+
+
+if platform.system() == "Linux":
+ jpeg = TurboJPEG()
+else:
+ jpeg = TurboJPEG(lib_path=r"libjpeg-turbo-gcc64\bin\libturbojpeg.dll")
+
+
+def datagen(
+ pool2,
+ mel_batch,
+ start_idx_list,
+ pk_name,
+ face_det_results,
+ padding,
+ is_silent: int = 1,
+):
+ """
+ 数据批量预处理
+ Args:
+ is_silent: 是否为静默帧,静默帧读取后处理过的帧
+ """
+ img_batch, frame_batch, coords_batch, mask_batch = [], [], [], []
+ futures = []
+
+ return_gan = []
+ for start_idx in start_idx_list:
+ future = pool2.submit(
+ get_face_image,
+ pk_name,
+ start_idx,
+ face_det_results,
+ padding,
+ is_silent,
+ )
+ futures.append(future)
+
+ for future in futures:
+ frame_to_save, coords, face, mask = future.result()
+ img_batch.append(face)
+ frame_batch.append(frame_to_save)
+ coords_batch.append(coords)
+ mask_batch.append(mask)
+
+ if len(img_batch) >= config.wav2lip_batch_size:
+ img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
+ img_masked = img_batch.copy()
+ img_masked[:, config.img_size // 2 :] = 0
+
+ img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.0
+ mel_batch = mel_batch[..., np.newaxis]
+ return_gan.append(
+ (
+ img_batch,
+ mel_batch,
+ frame_batch,
+ mask_batch,
+ coords_batch,
+ start_idx_list,
+ )
+ )
+ img_batch, frame_batch, mask_batch, coords_batch = [], [], [], []
+
+ if len(img_batch) > 0:
+ img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
+ img_masked = img_batch.copy()
+ img_masked[:, config.img_size // 2 :] = 0
+
+ img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.0
+ mel_batch = mel_batch[..., np.newaxis]
+ return_gan.append(
+ (
+ img_batch,
+ mel_batch,
+ frame_batch,
+ mask_batch,
+ coords_batch,
+ start_idx_list,
+ )
+ )
+
+ return return_gan
+
+
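+# Shape sketch (illustrative; assumes wav2lip-style 96x96 crops, i.e.
+# config.img_size == 96, with B == config.wav2lip_batch_size): each tuple in
+# return_gan carries img_batch of shape (B, 96, 96, 6) -- the lower-half-masked
+# face concatenated with the original face along the channel axis, scaled to
+# [0, 1] -- and mel_batch of shape (B, mel_h, mel_w, 1).
+
+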
+def get_face_image(
+ pkl_name,
+ start_idx,
+ face_det_results,
+ infer_padding,
+ is_silent: int = 1,
+):
+    # Decide whether the model went through background replacement: a mask
+    # directory means jpg frames plus separate masks; otherwise RGBA png frames.
+    post_suffix = (
+        "png" if not os.path.exists(f"model_people/mask/{pkl_name}") else "jpg"
+    )
+ pre_suffix = os.path.join("model_people", pkl_name, "image")
+ file = os.path.join(pre_suffix, f"img{start_idx:05d}.{post_suffix}")
+
+    if not os.path.exists(file):
+        logger.error(f"Not found image file: {file}")
+        raise FileNotFoundError(f"Not found image file: {file}")
+ frame_to_save = read_img(
+ file, flags=cv2.IMREAD_UNCHANGED if post_suffix == "png" else cv2.IMREAD_COLOR
+ )
+ mask = None
+ if post_suffix == "jpg":
+ has_mask = os.path.exists(f"model_people/mask/{pkl_name}")
+        if not has_mask:
+            raise FileNotFoundError("mask expected, but no mask directory was found.")
+ mask_file = os.path.join(
+ f"model_people/mask/{pkl_name}", f"img{start_idx:05d}.jpg"
+ )
+ mask = cv2.imread(mask_file, cv2.IMREAD_GRAYSCALE)
+ elif post_suffix == "png":
+ mask = frame_to_save[:, :, 3]
+ else:
+        raise ValueError(f"file type {post_suffix} not in (png, jpg)")
+ y1, y2, x1, x2 = (
+ face_det_results[start_idx - 1]
+ if len(face_det_results[start_idx - 1]) == 4
+ else face_det_results[start_idx - 1][1]
+ )
+
+ y2 = y2 + infer_padding
+ face, coords = frame_to_save[:, :, :3][y1:y2, x1:x2], (y1, y2, x1, x2)
+ face = cv2.resize(face, (config.img_size, config.img_size))
+ return frame_to_save, coords, face, mask
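+
+
+# Note: get_face_image returns (frame_to_save, coords, face, mask) -- the full
+# frame, the padded face box (y1, y2, x1, x2), the face crop resized to
+# config.img_size x config.img_size, and the alpha/segmentation mask.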
+
+
+def play_in_loop_v2(
+ segments,
+ startfrom,
+ batch_num,
+ last_direction,
+ is_silent,
+ is_silent_,
+ first_speak,
+ last_speak,
+):
+ """
+ batch_num: 初始和结束,每一帧都这么判断
+ 1、静默时,在静默段循环, 左边界正向,右边界反向, 根据上一次方向和位置,给出新的方向和位置
+ 2、静默转说话: 就近到说话段,pre_falg, post_flag, 都为true VS 其中一个为true
+ 3、说话转静默: 动作段播完,再进入静默(如果还在持续说话,静默段不循环)
+ 4、在整个视频左端点: 开始端只能正向,静默时循环,说话时走2
+ 5、在整个视频右端点: 开始时只能反向,静默时循环,说话时走2
+ 6、根据方向获取batch_num 数量的视频帧,return batch_idxes, current_direction
+ Args:
+ segments: 循环帧配置 [[st, ed, True], ...]
+ startfrom: cur_pos
+ batch_num: 5
+ last_direction: 0反向1正向
+ is_silent: 0说话态1动作态
+ is_silent_: 目前不明确,后面可能废弃
+ first_speak: 记录是不是第一次讲话
+ last_speak: 记录是不是讲话结束那一刻
+ """
+ frames = []
+ cur_pos = startfrom
+ cur_direction = last_direction
+    is_first_speak_frame = first_speak
+    is_last_speak_frame = bool(last_speak and batch_num == 1)
+ while batch_num != 0:
+        # Find the sub-segment that contains the current frame
+        sub_seg_idx = subseg_judge(cur_pos, segments)
+        # Decide the movement direction
+ next_direction, next_pos = get_next_direction(
+ segments,
+ cur_pos,
+ cur_direction,
+ is_silent,
+ sub_seg_idx,
+ is_first_speak_frame,
+ is_last_speak_frame,
+ )
+        # Step to the next frame along the chosen direction
+        next_pos = get_next_frame(next_pos, next_direction)
+        frames.append(next_pos)
+        batch_num -= 1
+        is_first_speak_frame = bool(
+            first_speak and batch_num == config.wav2lip_batch_size
+        )
+        is_last_speak_frame = bool(last_speak and batch_num == 1)
+
+ cur_direction = next_direction
+ cur_pos = next_pos
+ return frames, next_direction
+
+
+def get_next_direction(
+ segments,
+ cur_pos,
+ cur_direction,
+ is_silent,
+ sub_seg_idx,
+ is_first_speak_frame: bool = False,
+ is_last_speak_frame: bool = False,
+):
+ """
+ 3.3.0 循环帧需求,想尽快走到预期状态
+ if 动作段:
+ if 开始说话:
+ if 边界:
+ if 正向:
+ pass
+ else:
+ pass
+ else:
+ if 正向:
+ pass
+ else:
+ pass
+ elif 静默:
+ 同上
+ elif 说话中:
+ 同上
+ elif 说话结束:
+ 同上
+ elif 静默段:
+ 同上
+ Args:
+ is_first_speak_frame: 开始说话flag
+ is_last_speak_frame: 说话结束flag
+ """
+ left, right, loop_flag = segments[sub_seg_idx]
+ if loop_flag:
+ if is_silent == 1:
+ next_direct, next_pos = pure_silent(
+ segments, left, right, cur_pos, cur_direction, sub_seg_idx
+ )
+ logger.debug(
+ f"cur_pos:{cur_pos}, next_direct:{next_direct}, is_first_speak_frame:{is_first_speak_frame}"
+ )
+ elif is_silent == 0:
+ next_direct, next_pos = silent2action(
+ segments,
+ left,
+ right,
+ cur_pos,
+ cur_direction,
+ sub_seg_idx,
+ is_first_speak_frame,
+ )
+ logger.debug(
+ f"cur_pos:{cur_pos}, next_direct:{next_direct}, is_first_speak_frame{is_first_speak_frame}"
+ )
+ else:
+ if is_silent == 1:
+ next_direct, next_pos = action2silent(
+ segments,
+ left,
+ right,
+ cur_pos,
+ cur_direction,
+ sub_seg_idx,
+ is_last_speak_frame,
+ )
+ logger.debug(
+ f"cur_pos{cur_pos}, next_direct:{next_direct},is_first_speak_frame{is_first_speak_frame},is_last_speak_frame:{is_last_speak_frame}"
+ )
+ elif is_silent == 0:
+ next_direct, next_pos = pure_action(
+ segments,
+ left,
+ right,
+ cur_pos,
+ cur_direction,
+ sub_seg_idx,
+ is_last_speak_frame,
+ )
+ logger.debug(
+ f"cur_pos:{cur_pos}, next_direct:{next_direct},is_first_speak_frame{is_first_speak_frame}, is_last_speak_frame:{is_last_speak_frame}"
+ )
+ return next_direct, next_pos
+
+
+def get_next_frame(cur_pos, cur_direction):
+    """Given the current frame and direction, return the next frame index.
+    Callers must guarantee that a frame exists in that direction, so no extra
+    boundary check is needed here.
+    """
+    # forward
+    if cur_direction == 1:
+        return cur_pos + 1
+    # backward
+    elif cur_direction == 0:
+        return cur_pos - 1
+
+
+def pure_silent(segments, left, right, cur_pos, cur_direction, sub_seg_idx):
+    """
+    loop_flag == True and is_silent == 1
+    whether border
+    whether forward
+    Return:
+        next_direction, next_pos
+    """
+    # forward at the left border, backward at the right border
+    if cur_pos == segments[0][0]:
+        return 1, cur_pos
+    if cur_pos == segments[-1][1]:
+        return 0, cur_pos
+    # right border: go backward
+    if cur_pos == right:
+        return 0, cur_pos
+    # left border: go forward
+    if cur_pos == left:
+        return 1, cur_pos
+    # not at a border: keep forward if we were moving forward, else backward
+    if cur_pos > left and cur_direction == 1:
+        return 1, cur_pos
+    else:
+        return 0, cur_pos
+
+
+def pure_action(
+    segments, left, right, cur_pos, cur_direction, sub_seg_idx, is_last_speak_frame
+):
+    """
+    loop_flag == False and is_silent == 0
+    Play the action segment to the end, then move forward into silence
+    (segment jumps can happen).
+    whether border
+    whether forward  # playing forward vs. backward
+    Args:
+        is_last_speak_frame: the moment speech ends
+    Return: next_direction, next_pos
+    """
+    if cur_pos == segments[0][0]:
+        return 1, cur_pos
+    if cur_pos == segments[-1][1]:
+        return 0, cur_pos
+
+    if is_last_speak_frame:
+        # action segment at the very end: search backward for silence
+        if sub_seg_idx == len(segments) - 1:
+            return 0, cur_pos
+        # action segment at the very start: go forward
+        if sub_seg_idx == 0:
+            return 1, cur_pos
+        # action segment in the middle: the nearer side wins
+        mid = left + (right - left + 1) // 2
+        if cur_pos < mid:
+            return 0, cur_pos
+        else:
+            return 1, cur_pos
+
+    else:
+        # otherwise keep the previous playback direction
+        if cur_direction == 1:
+            return 1, cur_pos
+        else:
+            return 0, cur_pos
+
+
+def silent2action(
+ segments,
+ left,
+ right,
+ cur_pos,
+ cur_direction,
+ sub_seg_idx,
+ is_first_speak_frame: bool = False,
+):
+ """
+ 在静默区间但是在讲话
+ loop_flag=True and is_silent == 0
+ whether border
+ whether forward
+
+ Return: next_direction
+ """
+ # 向最近的动作段移动, 如果左面没有动作段
+ # TODO: 确认下面逻辑是否正确
+ if (
+ cur_pos == segments[0][0]
+ ): # 如果发生过跳跃,新段无论是不是动作段,仍然都向后执行
+ return 1, cur_pos
+ if cur_pos == segments[-1][1]:
+ return 0, cur_pos
+ # 在静默左边界处,且仍在讲话
+ if cur_pos == left:
+ if cur_direction == 1:
+ return 1, cur_pos
+ else:
+ return 0, cur_pos
+ # 在静默右边界处,且仍在讲话
+ elif cur_pos == right:
+ if cur_direction == 1:
+ return 1, cur_pos
+ else:
+ return 0, cur_pos
+ else:
+ mid = left + (right - left + 1) // 2
+ # !!就近原则只对第一次说话有效,其他情况遵循上一次状态
+ if is_first_speak_frame:
+ # 如果第一段
+ if sub_seg_idx == 0 and segments[0][2]:
+ return 1, cur_pos
+ # 如果最后一段
+ elif sub_seg_idx == len(segments) - 1 and segments[-1][2]:
+ return 0, cur_pos
+
+ if cur_pos < mid:
+ return 0, cur_pos
+ else:
+ return 1, cur_pos
+ else:
+ if cur_direction == 1:
+ return 1, cur_pos
+ elif cur_direction == 0:
+ return 0, cur_pos
+
+
+def action2silent(
+ segments, left, right, cur_pos, cur_direction, sub_seg_idx, is_last_speak_frame
+):
+ """
+ loop_flag=False and is_silent==1
+ whether border
+ Return: next_direction
+ """
+ if cur_pos == segments[0][0]:
+ return 1, cur_pos
+ if cur_pos == segments[-1][1]:
+ return 0, cur_pos
+ # 动作段,说话结束转静默情况下,就近原则,进入静默
+ if is_last_speak_frame:
+ mid = left + (right - left + 1) // 2
+ if cur_pos < mid:
+ return 0, cur_pos
+ else:
+ return 1, cur_pos
+
+ else:
+ if cur_direction == 1:
+ return 1, cur_pos
+ else:
+ return 0, cur_pos
+
+
+def subseg_judge(cur_pos, segments):
+    """Return the index of the segment [st, ed, flag] that contains cur_pos."""
+    for idx, frame_seg in enumerate(segments):
+        if frame_seg[0] <= cur_pos <= frame_seg[1]:
+            return idx
+    # fall back to the first segment when starting from position 0
+    if cur_pos == 0:
+        return 0
+
+
+def get_oss_config_masx():
+ logger.info("download from mas-x")
+ OSS_ENDPOINT = "oss-cn-hangzhou.aliyuncs.com"
+ OSS_KEY = "LTAIgPS3rCV89vYO"
+ OSS_SECRET = "p9wBFUAC5e29rtunkbnPTefWYLQpxT"
+ OSS_BUCKET = "mas-x"
+
+ return OSS_ENDPOINT, OSS_KEY, OSS_SECRET, OSS_BUCKET
+
+
+def get_oss_config_zqinghui():
+ logger.info("download from zqinghui")
+ OSS_ENDPOINT = "oss-cn-zhangjiakou.aliyuncs.com"
+ OSS_KEY = "LTAIgPS3rCV89vYO"
+ OSS_SECRET = "p9wBFUAC5e29rtunkbnPTefWYLQpxT"
+ OSS_BUCKET = "zqinghui"
+
+ return OSS_ENDPOINT, OSS_KEY, OSS_SECRET, OSS_BUCKET
+
+
+def get_oss_config(url: str):
+ if "zqinghui" in url:
+ return get_oss_config_zqinghui()
+ return get_oss_config_masx()
+
+
+def download_file_from_oss_url(
+ model_dir: str,
+ url_path: str,
+ download_type: str,
+ download_file_len: int,
+ downloadQueue,
+ downloadFileLen: int,
+):
+ """
+ 默认从oss上根据url下载不同的type数据,如果下载失败,获取zqinghui配置下的数据
+ Args:
+
+ """
+ try:
+ logger.info(f"下载url_path: {url_path}")
+ OSS_ENDPOINT, OSS_KEY, OSS_SECRET, OSS_BUCKET = get_oss_config(url_path)
+ oss_class = create_oss.Oss_connection(
+ OSS_ENDPOINT, OSS_KEY, OSS_SECRET, OSS_BUCKET
+ )
+        oss_class.oss_get_file(
+            model_dir,
+            url_path,
+            type=download_type,
+            # str.lstrip strips a character set, not a prefix, so remove the
+            # host part of the URL explicitly to obtain the object key.
+            keys=url_path.replace(f"https://{OSS_BUCKET}.{OSS_ENDPOINT}/", "", 1),
+            downloadQueue=downloadQueue,
+            downloadFileLen=downloadFileLen,
+            downloadAllLen=download_file_len,
+        )
+ except Exception:
+        logger.warning(f"download {url_path} error: {traceback.format_exc()}")
+ if url_path.split("/")[-1] in os.listdir(model_dir):
+ os.remove(os.path.join(model_dir, url_path.split('/')[-1]))
+
+
+def download_checkpoint(model_name, model_url, download_file_len, downloadQueue):
+ logger.info(f"Download checkpoint of model people: {model_url}")
+ model_dir = config.model_folder
+ os.makedirs(model_dir, exist_ok=True)
+ if not os.path.exists(os.path.join(model_dir, model_name)):
+ download_file_len += 1
+ OSS_ENDPOINT = "oss-cn-zhangjiakou.aliyuncs.com"
+ OSS_KEY = "LTAIgPS3rCV89vYO"
+ OSS_SECRET = "p9wBFUAC5e29rtunkbnPTefWYLQpxT"
+ OSS_BUCKET = "zqinghui"
+ oss_class = create_oss.Oss_connection(
+ OSS_ENDPOINT, OSS_KEY, OSS_SECRET, OSS_BUCKET
+ )
+ oss_class.oss_get_file(
+ model_dir, model_url, type="模型", downloadQueue=downloadQueue
+ )
+ return download_file_len, downloadQueue
+
+
+def download(modelSpeakerId, downloadQueue, res):
+ """
+ 1、异常退出时,删除所有已经下载的临时文件
+ 2、重新下载时,根据上次下载一半的结果继续下载
+ 3、错误拦截有问题
+ 4、一些文件的创建和清理逻辑模糊
+ """
+    # TODO: what is this variable for?
+    download_file_len = 3
+ if res["code"] != 0:
+        logger.warning(f"failed to fetch current digital human info, code: {res['code']}")
+ return "ERROR"
+
+ res = res["data"]
+ pkl_url, model_url, video_needle_url, mask_needle_url = (
+ res["pkl_url"],
+ res["modeling_url"],
+ res["video_needle_url"],
+ res["mask_needle_url"],
+ )
+    model_name = model_url.split("/")[-1]  # last URL path component as the file name
+ download_file_len, downloadQueue = download_checkpoint(
+ model_name, model_url, download_file_len, downloadQueue
+ )
+
+ model_dir = "model_people/"
+ os.makedirs(model_dir, exist_ok=True)
+    pkl_url_zip_name = pkl_url.split("/")[-1]  # last URL path component as the file name
+ pkl_url_name = pkl_url_zip_name.replace(".zip", "")
+ logger.info(f"pkl_url_name:{pkl_url_name}")
+ if pkl_url_name in os.listdir(model_dir):
+ return f"model people of {modelSpeakerId} not need to download"
+
+    # TODO: currently an error here leaves already-downloaded files undeleted
+    # pkl
+    logger.info(f"download model people of: {modelSpeakerId}")
+ if pkl_url != "none":
+ download_file_from_oss_url(
+ model_dir,
+ pkl_url,
+ "人物",
+ download_file_len,
+ downloadQueue,
+ downloadFileLen=1,
+ )
+ unzip(
+ os.path.join(model_dir, pkl_url_zip_name),
+ os.path.join(model_dir, pkl_url_name),
+ )
+ shutil.rmtree(os.path.join(model_dir, pkl_url_name, "image"), ignore_errors=True)
+
+ # video
+ if video_needle_url != "none":
+        # TODO: read this from the config
+ download_file_from_oss_url(
+ model_dir,
+ video_needle_url,
+ "人物",
+ download_file_len,
+ downloadQueue,
+ downloadFileLen=2,
+ )
+        # digital human frames before background replacement
+ unzip(
+ os.path.join(model_dir, video_needle_url.split('/')[-1]),
+ os.path.join(model_dir, pkl_url_name, "image"),
+ )
+ shutil.rmtree(os.path.join(model_dir, modelSpeakerId), ignore_errors=True)
+ os.rename(os.path.join(model_dir, pkl_url_name), os.path.join(model_dir, modelSpeakerId))
+ os.remove(os.path.join(model_dir, pkl_url_zip_name))
+
+ # mask
+ if mask_needle_url != "none":
+ download_file_from_oss_url(
+ model_dir,
+ mask_needle_url,
+ "mask",
+ download_file_len,
+ downloadQueue,
+ downloadFileLen=3,
+ )
+        # TODO: this directory holds the pure-white backgrounds
+ mask_dir = os.path.join(model_dir, "mask", modelSpeakerId)
+ os.makedirs(mask_dir, exist_ok=True)
+
+        unzip(os.path.join(model_dir, mask_needle_url.split('/')[-1]), mask_dir)
+        # remove the zip archive
+ os.remove(os.path.join(model_dir, mask_needle_url.split('/')[-1]))
+        # handled this way because the pkl and image URLs may share the same final path component
+ video_img_file = os.path.join(model_dir, video_needle_url.split('/')[-1])
+ if os.path.exists(video_img_file):
+ os.remove(video_img_file)
+
+ downloadQueue.put("100")
+ logger.info(f"model people of {modelSpeakerId} 下载完成")
+ return "下载完成"
+
+
+def get_model_local_path(model_url: str):
+ return model_url.split("/")[-1]
+
+
+def get_actor_id(pkl_url):
+ """
+ 取URL的最后一部分作为template_id
+ """
+ pkl_file_name = pkl_url.split("/")[-1]
+ return pkl_file_name.replace(".zip", "")
+
+
+def load_face_box(actor_id, pkl_version: int = 1) -> np.ndarray:
+ """
+ Args: actor_id: str : uid
+ pkl_version: 1, 2
+ """
+ face_det_results_pk_file = os.path.join(
+ "model_people", actor_id, "face_det_results.pkl"
+ )
+    if not os.path.exists(face_det_results_pk_file):
+        raise FileNotFoundError(f"Not found file: {face_det_results_pk_file}")
+ if pkl_version != 1:
+ with open(face_det_results_pk_file, "rb") as f:
+ res = pickle.load(f)["boxes"]
+ return np.array(res)
+
+ with open(face_det_results_pk_file, mode="rb") as f:
+ unpickler = pickle.Unpickler(f)
+ face_det_results = []
+        try:
+            # TODO: odd -- is reading until EOF really the only way to stop?
+            while True:
+                data = unpickler.load()
+                face_det_results.append(data)
+        except EOFError:
+            # pickle.Unpickler.load raises EOFError once the stream is exhausted
+            pass
+    # TODO: rework this; it does not work on numpy 1.24 and above.
+    final_result = np.concatenate(face_det_results)
+    return final_result
+
+
+def load_seg_model(pth_path, model, device):
+    model.to(device)
+    # map_location keeps the load working on CPU-only hosts
+    model.load_state_dict(torch.load(pth_path, map_location=device))
+    return model.eval()
+
+
+def load_w2l_model(pth_path, model, device):
+    model = model.to(device)
+    # strip the DataParallel "module." prefix from checkpoint keys
+    model.load_state_dict(
+        {
+            k.replace("module.", ""): v
+            for k, v in torch.load(pth_path, map_location=device)["state_dict"].items()
+        }
+    )
+    return model.eval()
+
+
+def read_img(file_path: str, flags=cv2.IMREAD_COLOR):
+ """read jpg img and decode"""
+ return cv2.imread(file_path, flags=flags)
+
+
+def write_img(content, jpg_path: str, mode: str = "wb"):
+ with open(jpg_path, mode) as file:
+ file.write(content)
+
+
+def save_raw_video(speech_id, result_frames, tag="infer"):
+ name = f"temp/{tag}_{speech_id}_result.avi"
+ out = cv2.VideoWriter(name, cv2.VideoWriter_fourcc(*"mp4v"), 25, (1080, 1920))
+
+    for frame in result_frames:
+        out.write(frame)
+ out.release()
+ return name
+
+
+def save_raw_wav(speech_id, audio_packets, tag="infer", folder="temp"):
+    os.makedirs(folder, exist_ok=True)
+    sample_rate = 16000  # samples per second
+    num_channels = 1  # mono
+    sample_width = 2  # bytes per sample, i.e. 16-bit audio
+    saved_wav_name = os.path.join(folder, f"{tag}_{speech_id}.wav")
+
+    with wave.open(saved_wav_name, "wb") as wav_file:
+        # set the WAV file parameters
+        wav_file.setnchannels(num_channels)
+        wav_file.setsampwidth(sample_width)  # sample width in bytes
+        wav_file.setframerate(sample_rate)
+        wav_file.writeframes(audio_packets)
+
+    return saved_wav_name
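+
+
+# Usage sketch (illustrative): pcm = b"\x00\x00" * 16000 is one second of
+# silence at 16 kHz / 16-bit mono; save_raw_wav("abc123", pcm) then writes
+# temp/infer_abc123.wav.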
+
+
+def morphing(
+ ori_frame: np.ndarray,
+ pred_frame: np.ndarray,
+ box: list,
+ mp_ratio: float,
+ padding: int = 10,
+ file_index: int = 1,
+) -> np.ndarray:
+ """Tool to 图片融合
+ Args:
+ ori_frame: 原始视频帧 01.png - np
+ pred_frame: infer_01.png - np
+ box: box, # 确认[x1,y1,x2,y2]
+ mp_ratio: [0,1] / morphing_frams_num
+ Return:
+ merged numpy array
+ """
+ (y1, y2, x1, x2) = box
+ ori_face = ori_frame[ori_frame.shape[0] // 2 :, :]
+ pred_face = pred_frame[:, :, :3][
+ y1 - padding : y2 + padding, x1 - padding : x2 + padding
+ ]
+ pred_half_face = pred_face[pred_face.shape[0] // 2 :, :]
+ alpha = np.ones_like(ori_face) * mp_ratio
+ pred_frame[:, :, :3][
+ y1 + (y2 - y1) // 2 : y2 + padding, x1 - padding : x2 + padding
+ ] = (1 - alpha) * pred_half_face + alpha * ori_face
+ logger.debug(f"file_index:{file_index}, ratio:{mp_ratio}")
+ return pred_frame
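+
+
+# Blend sketch: morphing() mixes only the lower half of the face; with mp_ratio
+# stepping from 0 toward 1 across the morphing frames, that region moves
+# linearly from the predicted face (mp_ratio == 0) to the original frame
+# (mp_ratio == 1).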
+
+
+# TODO: @txueduo this is similar to fusion(); the mask is not binary, so a
+# float alpha is needed for blending
+def add_alpha(img: np.ndarray, mask: np.ndarray, alpha: int = 0):
+    # NOTE: the `alpha` parameter is currently unused
+    t0 = time.time()
+    if img.shape[-1] == 3:
+        res_img = np.zeros((img.shape[0], img.shape[1], 4), dtype=np.uint8)
+        res_img[:, :, :3] = img[:, :, ::-1]
+    else:
+        res_img = np.zeros_like(img, dtype=np.uint8)
+        res_img[:, :, :3] = img[:, :, :3][:, :, ::-1]
+    mask[mask > 125] = 255
+    # attach the mask as the new alpha channel
+    res_img[:, :, 3] = mask
+
+    logger.info(f"add_alpha cost:{time.time() - t0}")
+    return res_img.astype("uint8")
+
+
+def fusion(person: np.ndarray, background: np.ndarray, mask: np.ndarray) -> np.ndarray:
+    """
+    Args:
+        person: original image
+        background: image to fuse in
+        mask: filter mask
+    """
+    # png: carry transparency through to the output
+    if len(mask.shape) == 2:
+        background_alpha = np.zeros_like(person, dtype=np.uint8)
+        background_alpha[:, :, :3] = background
+        # background_alpha[:, :, 3] = 255 if not config.output_alpha else 0
+        background_alpha[:, :, 3] = 0
+        mask_rgba = np.repeat(mask[..., None], 4, axis=-1).astype("uint8")
+    else:
+        # jpeg: transparency is determined by the computation below
+        mask_rgba = mask
+        background_alpha = background
+    mask_rgba = mask_rgba / 255.0
+    result = mask_rgba * person + (1 - mask_rgba) * background_alpha
+    return result.astype("uint8")
+
+
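+# Worked example (synthetic arrays, 3-channel mask branch): mask 255 keeps the
+# person pixel, 0 keeps the background, intermediate values blend linearly.
+#
+#   person = np.full((2, 2, 3), 200, dtype=np.uint8)
+#   background = np.full((2, 2, 3), 100, dtype=np.uint8)
+#   mask = np.full((2, 2, 3), 127, dtype=np.uint8)  # ~50% opacity
+#   fusion(person, background, mask)                # values near 150
+
+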
+def chg_bg(
+    ori_img_path: Union[str, np.ndarray] = None,
+    bg_img_path: Union[str, np.ndarray] = None,
+    mask: Union[str, np.ndarray] = None,
+    default_flags=cv2.IMREAD_COLOR,
+):
+    """Change the background: read images when paths are given, then fuse them."""
+    flags = default_flags  # keep defined even when both inputs are ndarrays
+    if isinstance(bg_img_path, str):
+        flags = cv2.IMREAD_UNCHANGED if "png" in bg_img_path else default_flags
+        bg_img = cv2.imread(bg_img_path, flags=flags)
+    else:
+        bg_img = bg_img_path
+    if isinstance(ori_img_path, str):
+        flags = cv2.IMREAD_UNCHANGED if "png" in ori_img_path else default_flags
+        ori_img = cv2.imread(ori_img_path, flags=flags)
+    else:
+        ori_img = ori_img_path
+    logger.debug(
+        f"bg-img shape:{bg_img.shape}, ori_img_path:{ori_img_path}, flags:{flags}"
+    )
+ if ori_img.shape[-1] == 4:
+ return fusion(ori_img, bg_img, ori_img[:, :, 3])
+ elif ori_img.shape[-1] == 3:
+ if isinstance(mask, str):
+ mask = cv2.imread(mask)
+ return fusion(ori_img, bg_img, mask)
+ else:
+ raise ValueError("img shape[-1] must be 3 or 4")
+
+
+def load_config(model_speaker_id: str):
+    """Fetch the per-speaker inference config from the segments service."""
+    logger.info(f"model_speaker_id: {model_speaker_id}")
+    res = requests.post(
+        config.GET_segments, json={"modelSpeakerId": model_speaker_id}
+    ).json()["data"]
+    logger.info(res)
+    model_url = res["modeling_url"]
+    frame_config = res.get("frame_config", [[1, 200, True]])
+    padding = res["padding"]  # too small a padding can leave the chin out of inference
+    face_classes = res.get("face_classes", [1, 11, 12, 13])
+    if not frame_config:
+        raise RuntimeError(f"frame config {frame_config} is empty or not valid JSON")
+
+    # new field: trans_method
+    trans_method = res.get("trans_method", 2)
+    # new transition-parameter fields
+    infer_silent_num = res.get("infer_silent_num", 5)
+    morphing_num = res.get("morphing_num", 25)
+    pkl_version = res.get("pkl_version", 1)
+    logger.info(f"infer_silent_num:{infer_silent_num}, morphing_num:{morphing_num}")
+    return (
+        model_url,
+        frame_config,
+        padding,
+        face_classes,
+        trans_method,
+        infer_silent_num,
+        morphing_num,
+        pkl_version,
+    )
+
+
+def get_trans_idxes(
+    need_post_morphing, need_infer_silent, post_morphint_idx, infer_silent_idx, file_index
+):
+    """Build the 5-frame sliding windows of indexes used for transitions."""
+    if need_post_morphing:
+        post_morphint_idxes = (
+            list(range(1, 6)) if post_morphint_idx == 1 else list(range(post_morphint_idx))[-5:]
+        )
+    else:
+        post_morphint_idxes = [0, 0, 0, 0, 0]  # placeholder when no post-morphing is needed
+    if need_infer_silent:
+        infer_silent_idxes = (
+            list(range(1, 6)) if infer_silent_idx == 1 else list(range(infer_silent_idx))[-5:]
+        )
+    else:
+        infer_silent_idxes = [0, 0, 0, 0, 0]
+    file_idxes = list(range(1, 6)) if file_index == 1 else list(range(file_index))[-5:]
+    return file_idxes, post_morphint_idxes, infer_silent_idxes