Compare commits


1 commit

jocelyn · f3dcbdc876 · [ADD] add logic of loop frame · 2025-06-10 15:11:01 +08:00

2 changed files with 1207 additions and 0 deletions


@@ -0,0 +1,318 @@
from config import config
from utils.log import logger
def play_in_loop_v2(
segments,
startfrom,
batch_num,
last_direction,
is_silent,
is_silent_,
first_speak,
last_speak,
):
"""
batch_num: 初始和结束每一帧都这么判断
1静默时在静默段循环 左边界正向右边界反向, 根据上一次方向和位置给出新的方向和位置
2静默转说话 就近到说话段pre_falg, post_flag, 都为true VS 其中一个为true
3说话转静默 动作段播完再进入静默(如果还在持续说话静默段不循环)
4在整个视频左端点 开始端只能正向静默时循环说话时走2
5在整个视频右端点 开始时只能反向静默时循环说话时走2
6根据方向获取batch_num 数量的视频帧return batch_idxes, current_direction
Args:
segments: 循环帧配置 [[st, ed, True], ...]
startfrom: cur_pos
batch_num: 5
last_direction: 0反向1正向
is_silent: 0说话态1动作态
is_silent_: 目前不明确后面可能废弃
first_speak: 记录是不是第一次讲话
last_speak: 记录是不是讲话结束那一刻
"""
frames = []
cur_pos = startfrom
cur_direction = last_direction
is_first_speak_frame = first_speak
    is_last_speak_frame = last_speak and batch_num == 1
while batch_num != 0:
        # find the sub-segment that contains the current frame
sub_seg_idx = subseg_judge(cur_pos, segments)
        # decide the movement direction
next_direction, next_pos = get_next_direction(
segments,
cur_pos,
cur_direction,
is_silent,
sub_seg_idx,
is_first_speak_frame,
is_last_speak_frame,
)
        # fetch the next frame along that direction
next_pos = get_next_frame(next_pos, next_direction)
frames.append(next_pos)
batch_num -= 1
        is_first_speak_frame = first_speak and batch_num == config.batch_size
        is_last_speak_frame = last_speak and batch_num == 1
cur_direction = next_direction
cur_pos = next_pos
return frames, next_direction
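# Illustrative usage sketch (not from the original commit; segment values are
# hypothetical). With three segments where only the middle one loops, a silent
# batch of 5 frames starting at 150 keeps walking forward inside the loop
# segment:
#
#   segments = [[1, 100, False], [101, 200, True], [201, 300, False]]
#   frames, direction = play_in_loop_v2(
#       segments, startfrom=150, batch_num=5, last_direction=1,
#       is_silent=1, is_silent_=1, first_speak=False, last_speak=False,
#   )
#   # frames == [151, 152, 153, 154, 155], direction == 1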
def subseg_judge(cur_pos, segments):
for idx, frame_seg in enumerate(segments):
if cur_pos >= frame_seg[0] and cur_pos <= frame_seg[1]:
return idx
if cur_pos == 0:
return 0
def get_next_direction(
segments,
cur_pos,
cur_direction,
is_silent,
sub_seg_idx,
is_first_speak_frame: bool = False,
is_last_speak_frame: bool = False,
):
"""
3.3.0 循环帧需求想尽快走到预期状态
if 动作段
if 开始说话
if 边界
if 正向
pass
else:
pass
else:
if 正向
pass
else:
pass
elif 静默:
同上
elif 说话中
同上
elif 说话结束
同上
elif 静默段
同上
Args:
is_first_speak_frame: 开始说话flag
is_last_speak_frame 说话结束flag
"""
left, right, loop_flag = segments[sub_seg_idx]
if loop_flag:
if is_silent == 1:
next_direct, next_pos = pure_silent(
segments, left, right, cur_pos, cur_direction, sub_seg_idx
)
            logger.debug(
                f"cur_pos:{cur_pos}, next_direct:{next_direct}, is_first_speak_frame:{is_first_speak_frame}"
            )
elif is_silent == 0:
next_direct, next_pos = silent2action(
segments,
left,
right,
cur_pos,
cur_direction,
sub_seg_idx,
is_first_speak_frame,
)
            logger.debug(
                f"cur_pos:{cur_pos}, next_direct:{next_direct}, is_first_speak_frame:{is_first_speak_frame}"
            )
else:
if is_silent == 1:
next_direct, next_pos = action2silent(
segments,
left,
right,
cur_pos,
cur_direction,
sub_seg_idx,
is_last_speak_frame,
)
            logger.debug(
                f"cur_pos:{cur_pos}, next_direct:{next_direct}, is_first_speak_frame:{is_first_speak_frame}, is_last_speak_frame:{is_last_speak_frame}"
            )
elif is_silent == 0:
next_direct, next_pos = pure_action(
segments,
left,
right,
cur_pos,
cur_direction,
sub_seg_idx,
is_last_speak_frame,
)
            logger.debug(
                f"cur_pos:{cur_pos}, next_direct:{next_direct}, is_first_speak_frame:{is_first_speak_frame}, is_last_speak_frame:{is_last_speak_frame}"
            )
return next_direct, next_pos
def get_next_frame(cur_pos, cur_direction):
"""根据当前帧和方向,获取下一帧,这里应该保证方向上的帧是一定能取到的
不需要再做额外的边界判断
"""
# 正向
if cur_direction == 1:
return cur_pos + 1
# 反向
elif cur_direction == 0:
return cur_pos - 1
def pure_silent(segments, left, right, cur_pos, cur_direction, sub_seg_idx):
"""
loop_flag == True and is_silent==1
whether border
whether forward
Return:
next_direction
"""
# 左边界正向,右边界反向
if cur_pos == segments[0][0]:
return 1, cur_pos
if cur_pos == segments[-1][1]:
return 0, cur_pos
# 右边界,反向
if cur_pos == right:
return 0, cur_pos
# 左边界,正向
if cur_pos == left:
return 1, cur_pos
# 非边界,之前正向,则继续正向,否则反向
if cur_pos > left and cur_direction == 1:
return 1, cur_pos
else:
return 0, cur_pos
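# For instance (hedged illustration): with segments == [[1, 100, True]],
# pure_silent at cur_pos == 100 returns (0, 100) -- the right end of the video
# reverses playback -- while cur_pos == 1 returns (1, 1) and plays forward.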
def pure_action(
segments, left, right, cur_pos, cur_direction, sub_seg_idx, is_last_speak_frame
):
"""
loop_flag ==False and is_silent == 0
动作播完正向到静默段 (存在跳段行为)
whether border
whether forward # 正播反播
Args:
is_last_speak_frame: 最后说话结束时刻
Return: next_direction
"""
if cur_pos == segments[0][0]:
return 1, cur_pos
if cur_pos == segments[-1][1]:
return 0, cur_pos
    if is_last_speak_frame:
        # action segment is the last one: move backward to find silence
        if sub_seg_idx == len(segments) - 1:
            return 0, cur_pos
        # action segment is the first one: move forward
        if sub_seg_idx == 0:
            return 1, cur_pos
        # action segment in the middle: take the nearer side
        mid = left + (right - left + 1) // 2
        # nearest-side rule takes priority
        if cur_pos < mid:
            return 0, cur_pos
        else:
            return 1, cur_pos
    else:
        # otherwise keep the previous playback direction
        if cur_direction == 1:
            return 1, cur_pos
        else:
            return 0, cur_pos
def silent2action(
segments,
left,
right,
cur_pos,
cur_direction,
sub_seg_idx,
is_first_speak_frame: bool = False,
):
"""
在静默区间但是在讲话
loop_flag=True and is_silent == 0
whether border
whether forward
Return: next_direction
"""
    # move toward the nearest action segment (what if there is no action segment on the left?)
    # TODO: verify the logic below
    if (
        cur_pos == segments[0][0]
    ):  # after a jump, keep moving forward in the new segment whether or not it is an action segment
        return 1, cur_pos
    if cur_pos == segments[-1][1]:
        return 0, cur_pos
    # at the left border of the silent segment, still speaking
if cur_pos == left:
if cur_direction == 1:
return 1, cur_pos
else:
return 0, cur_pos
    # at the right border of the silent segment, still speaking
elif cur_pos == right:
if cur_direction == 1:
return 1, cur_pos
else:
return 0, cur_pos
else:
mid = left + (right - left + 1) // 2
        # !! the nearest-side rule only applies to the first speech; otherwise follow the previous state
        if is_first_speak_frame:
            # if this is the first segment
            if sub_seg_idx == 0 and segments[0][2]:
                return 1, cur_pos
            # if this is the last segment
elif sub_seg_idx == len(segments) - 1 and segments[-1][2]:
return 0, cur_pos
if cur_pos < mid:
return 0, cur_pos
else:
return 1, cur_pos
else:
if cur_direction == 1:
return 1, cur_pos
elif cur_direction == 0:
return 0, cur_pos
def action2silent(
segments, left, right, cur_pos, cur_direction, sub_seg_idx, is_last_speak_frame
):
"""
loop_flag=False and is_silent==1
whether border
Return: next_direction
"""
if cur_pos == segments[0][0]:
return 1, cur_pos
if cur_pos == segments[-1][1]:
return 0, cur_pos
    # in an action segment, speech has just ended: take the nearer side into silence
if is_last_speak_frame:
mid = left + (right - left + 1) // 2
if cur_pos < mid:
return 0, cur_pos
else:
return 1, cur_pos
else:
if cur_direction == 1:
return 1, cur_pos
else:
return 0, cur_pos
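# Worked example (illustrative values): for a middle action segment with
# left=101, right=200 and is_last_speak_frame=True, mid is 151, so
# action2silent(cur_pos=120) returns (0, 120): playback backs out of the
# action segment toward the nearer silent side.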

utils/util.py Normal file

@@ -0,0 +1,889 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import pickle
import numpy as np
import zipfile
from config import config
import cv2
import torch
from turbojpeg import TurboJPEG
import platform
import wave
import requests
from log import logger
import time
import shutil
from utils import create_oss
from typing import Union
import traceback
def unzip(zip_path, save_path):
with zipfile.ZipFile(zip_path, "r") as file:
file.extractall(save_path)
if platform.system() == "Linux":
jpeg = TurboJPEG()
else:
jpeg = TurboJPEG(lib_path=r"libjpeg-turbo-gcc64\bin\libturbojpeg.dll")
def datagen(
pool2,
mel_batch,
start_idx_list,
pk_name,
face_det_results,
padding,
is_silent: int = 1,
):
"""
数据批量预处理
Args:
is_silent: 是否为静默帧静默帧读取后处理过的帧
"""
img_batch, frame_batch, coords_batch, mask_batch = [], [], [], []
futures = []
return_gan = []
for start_idx in start_idx_list:
future = pool2.submit(
get_face_image,
pk_name,
start_idx,
face_det_results,
padding,
is_silent,
)
futures.append(future)
for future in futures:
frame_to_save, coords, face, mask = future.result()
img_batch.append(face)
frame_batch.append(frame_to_save)
coords_batch.append(coords)
mask_batch.append(mask)
if len(img_batch) >= config.wav2lip_batch_size:
img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
img_masked = img_batch.copy()
img_masked[:, config.img_size // 2 :] = 0
img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.0
mel_batch = mel_batch[..., np.newaxis]
return_gan.append(
(
img_batch,
mel_batch,
frame_batch,
mask_batch,
coords_batch,
start_idx_list,
)
)
img_batch, frame_batch, mask_batch, coords_batch = [], [], [], []
if len(img_batch) > 0:
img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
img_masked = img_batch.copy()
img_masked[:, config.img_size // 2 :] = 0
img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.0
mel_batch = mel_batch[..., np.newaxis]
return_gan.append(
(
img_batch,
mel_batch,
frame_batch,
mask_batch,
coords_batch,
start_idx_list,
)
)
return return_gan
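# Shape sketch (inferred from the code above, not stated in the commit): each
# collected tuple holds img_batch of shape (B, img_size, img_size, 6) -- the
# face with its lower half zeroed, concatenated with the full face on the
# channel axis -- and mel_batch of shape (B, mel_h, mel_w, 1), plus the
# original frames, masks, coords and start indices needed to paste
# predictions back.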
def get_face_image(
pkl_name,
start_idx,
face_det_results,
infer_padding,
is_silent: int = 1,
):
    # check whether background replacement has been applied
post_suffix = (
"png" if not os.path.exists(f"model_people/mask/{pkl_name}") else "jpg"
)
pre_suffix = os.path.join("model_people", pkl_name, "image")
file = os.path.join(pre_suffix, f"img{start_idx:05d}.{post_suffix}")
if not os.path.exists(file):
        logger.error(f"Image file not found: {file}")
        raise FileNotFoundError(f"Image file not found: {file}")
frame_to_save = read_img(
file, flags=cv2.IMREAD_UNCHANGED if post_suffix == "png" else cv2.IMREAD_COLOR
)
mask = None
if post_suffix == "jpg":
has_mask = os.path.exists(f"model_people/mask/{pkl_name}")
if not has_mask:
            raise Exception("A mask is required, but no mask was found.")
mask_file = os.path.join(
f"model_people/mask/{pkl_name}", f"img{start_idx:05d}.jpg"
)
mask = cv2.imread(mask_file, cv2.IMREAD_GRAYSCALE)
elif post_suffix == "png":
mask = frame_to_save[:, :, 3]
else:
raise ValueError(f"file type {post_suffix} not in (png,jpeg)")
y1, y2, x1, x2 = (
face_det_results[start_idx - 1]
if len(face_det_results[start_idx - 1]) == 4
else face_det_results[start_idx - 1][1]
)
y2 = y2 + infer_padding
face, coords = frame_to_save[:, :, :3][y1:y2, x1:x2], (y1, y2, x1, x2)
face = cv2.resize(face, (config.img_size, config.img_size))
return frame_to_save, coords, face, mask
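# Expected on-disk layout (inferred from the paths above; not documented in
# the original commit):
#   model_people/<pkl_name>/image/img00001.png   # frames with an alpha channel
#   model_people/mask/<pkl_name>/img00001.jpg    # grayscale masks (after background replacement)
# If model_people/mask/<pkl_name> exists, frames are read as jpg with an
# external mask; otherwise the mask comes from the png's alpha channel.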
def play_in_loop_v2(
segments,
startfrom,
batch_num,
last_direction,
is_silent,
is_silent_,
first_speak,
last_speak,
):
"""
batch_num: 初始和结束每一帧都这么判断
1静默时在静默段循环 左边界正向右边界反向, 根据上一次方向和位置给出新的方向和位置
2静默转说话 就近到说话段pre_falg, post_flag, 都为true VS 其中一个为true
3说话转静默 动作段播完再进入静默(如果还在持续说话静默段不循环)
4在整个视频左端点 开始端只能正向静默时循环说话时走2
5在整个视频右端点 开始时只能反向静默时循环说话时走2
6根据方向获取batch_num 数量的视频帧return batch_idxes, current_direction
Args:
segments: 循环帧配置 [[st, ed, True], ...]
startfrom: cur_pos
batch_num: 5
last_direction: 0反向1正向
is_silent: 0说话态1动作态
is_silent_: 目前不明确后面可能废弃
first_speak: 记录是不是第一次讲话
last_speak: 记录是不是讲话结束那一刻
"""
frames = []
cur_pos = startfrom
cur_direction = last_direction
is_first_speak_frame = first_speak
    is_last_speak_frame = last_speak and batch_num == 1
while batch_num != 0:
        # find the sub-segment that contains the current frame
sub_seg_idx = subseg_judge(cur_pos, segments)
        # decide the movement direction
next_direction, next_pos = get_next_direction(
segments,
cur_pos,
cur_direction,
is_silent,
sub_seg_idx,
is_first_speak_frame,
is_last_speak_frame,
)
        # fetch the next frame along that direction
next_pos = get_next_frame(next_pos, next_direction)
frames.append(next_pos)
batch_num -= 1
        is_first_speak_frame = first_speak and batch_num == config.wav2lip_batch_size
        is_last_speak_frame = last_speak and batch_num == 1
cur_direction = next_direction
cur_pos = next_pos
return frames, next_direction
def get_next_direction(
segments,
cur_pos,
cur_direction,
is_silent,
sub_seg_idx,
is_first_speak_frame: bool = False,
is_last_speak_frame: bool = False,
):
"""
3.3.0 循环帧需求想尽快走到预期状态
if 动作段
if 开始说话
if 边界
if 正向
pass
else:
pass
else:
if 正向
pass
else:
pass
elif 静默:
同上
elif 说话中
同上
elif 说话结束
同上
elif 静默段
同上
Args:
is_first_speak_frame: 开始说话flag
is_last_speak_frame 说话结束flag
"""
left, right, loop_flag = segments[sub_seg_idx]
if loop_flag:
if is_silent == 1:
next_direct, next_pos = pure_silent(
segments, left, right, cur_pos, cur_direction, sub_seg_idx
)
            logger.debug(
                f"cur_pos:{cur_pos}, next_direct:{next_direct}, is_first_speak_frame:{is_first_speak_frame}"
            )
elif is_silent == 0:
next_direct, next_pos = silent2action(
segments,
left,
right,
cur_pos,
cur_direction,
sub_seg_idx,
is_first_speak_frame,
)
            logger.debug(
                f"cur_pos:{cur_pos}, next_direct:{next_direct}, is_first_speak_frame:{is_first_speak_frame}"
            )
else:
if is_silent == 1:
next_direct, next_pos = action2silent(
segments,
left,
right,
cur_pos,
cur_direction,
sub_seg_idx,
is_last_speak_frame,
)
            logger.debug(
                f"cur_pos:{cur_pos}, next_direct:{next_direct}, is_first_speak_frame:{is_first_speak_frame}, is_last_speak_frame:{is_last_speak_frame}"
            )
elif is_silent == 0:
next_direct, next_pos = pure_action(
segments,
left,
right,
cur_pos,
cur_direction,
sub_seg_idx,
is_last_speak_frame,
)
            logger.debug(
                f"cur_pos:{cur_pos}, next_direct:{next_direct}, is_first_speak_frame:{is_first_speak_frame}, is_last_speak_frame:{is_last_speak_frame}"
            )
return next_direct, next_pos
def get_next_frame(cur_pos, cur_direction):
"""根据当前帧和方向,获取下一帧,这里应该保证方向上的帧是一定能取到的
不需要再做额外的边界判断
"""
# 正向
if cur_direction == 1:
return cur_pos + 1
# 反向
elif cur_direction == 0:
return cur_pos - 1
def pure_silent(segments, left, right, cur_pos, cur_direction, sub_seg_idx):
"""
loop_flag == True and is_silent==1
whether border
whether forward
Return:
next_direction
"""
# 左边界正向,右边界反向
if cur_pos == segments[0][0]:
return 1, cur_pos
if cur_pos == segments[-1][1]:
return 0, cur_pos
# 右边界,反向
if cur_pos == right:
return 0, cur_pos
# 左边界,正向
if cur_pos == left:
return 1, cur_pos
# 非边界,之前正向,则继续正向,否则反向
if cur_pos > left and cur_direction == 1:
return 1, cur_pos
else:
return 0, cur_pos
def pure_action(
segments, left, right, cur_pos, cur_direction, sub_seg_idx, is_last_speak_frame
):
"""
loop_flag ==False and is_silent == 0
动作播完正向到静默段 (存在跳段行为)
whether border
whether forward # 正播反播
Args:
is_last_speak_frame: 最后说话结束时刻
Return: next_direction
"""
if cur_pos == segments[0][0]:
return 1, cur_pos
if cur_pos == segments[-1][1]:
return 0, cur_pos
    if is_last_speak_frame:
        # action segment is the last one: move backward to find silence
        if sub_seg_idx == len(segments) - 1:
            return 0, cur_pos
        # action segment is the first one: move forward
        if sub_seg_idx == 0:
            return 1, cur_pos
        # action segment in the middle: take the nearer side
        mid = left + (right - left + 1) // 2
        # nearest-side rule takes priority
        if cur_pos < mid:
            return 0, cur_pos
        else:
            return 1, cur_pos
    else:
        # otherwise keep the previous playback direction
        if cur_direction == 1:
            return 1, cur_pos
        else:
            return 0, cur_pos
def silent2action(
segments,
left,
right,
cur_pos,
cur_direction,
sub_seg_idx,
is_first_speak_frame: bool = False,
):
"""
在静默区间但是在讲话
loop_flag=True and is_silent == 0
whether border
whether forward
Return: next_direction
"""
    # move toward the nearest action segment (what if there is no action segment on the left?)
    # TODO: verify the logic below
    if (
        cur_pos == segments[0][0]
    ):  # after a jump, keep moving forward in the new segment whether or not it is an action segment
        return 1, cur_pos
    if cur_pos == segments[-1][1]:
        return 0, cur_pos
    # at the left border of the silent segment, still speaking
if cur_pos == left:
if cur_direction == 1:
return 1, cur_pos
else:
return 0, cur_pos
    # at the right border of the silent segment, still speaking
elif cur_pos == right:
if cur_direction == 1:
return 1, cur_pos
else:
return 0, cur_pos
else:
mid = left + (right - left + 1) // 2
        # !! the nearest-side rule only applies to the first speech; otherwise follow the previous state
        if is_first_speak_frame:
            # if this is the first segment
            if sub_seg_idx == 0 and segments[0][2]:
                return 1, cur_pos
            # if this is the last segment
elif sub_seg_idx == len(segments) - 1 and segments[-1][2]:
return 0, cur_pos
if cur_pos < mid:
return 0, cur_pos
else:
return 1, cur_pos
else:
if cur_direction == 1:
return 1, cur_pos
elif cur_direction == 0:
return 0, cur_pos
def action2silent(
segments, left, right, cur_pos, cur_direction, sub_seg_idx, is_last_speak_frame
):
"""
loop_flag=False and is_silent==1
whether border
Return: next_direction
"""
if cur_pos == segments[0][0]:
return 1, cur_pos
if cur_pos == segments[-1][1]:
return 0, cur_pos
    # in an action segment, speech has just ended: take the nearer side into silence
if is_last_speak_frame:
mid = left + (right - left + 1) // 2
if cur_pos < mid:
return 0, cur_pos
else:
return 1, cur_pos
else:
if cur_direction == 1:
return 1, cur_pos
else:
return 0, cur_pos
def subseg_judge(cur_pos, segments):
for idx, frame_seg in enumerate(segments):
if cur_pos >= frame_seg[0] and cur_pos <= frame_seg[1]:
return idx
if cur_pos == 0:
return 0
def get_oss_config_masx():
logger.info("download from mas-x")
OSS_ENDPOINT = "oss-cn-hangzhou.aliyuncs.com"
OSS_KEY = "LTAIgPS3rCV89vYO"
OSS_SECRET = "p9wBFUAC5e29rtunkbnPTefWYLQpxT"
OSS_BUCKET = "mas-x"
return OSS_ENDPOINT, OSS_KEY, OSS_SECRET, OSS_BUCKET
def get_oss_config_zqinghui():
logger.info("download from zqinghui")
OSS_ENDPOINT = "oss-cn-zhangjiakou.aliyuncs.com"
OSS_KEY = "LTAIgPS3rCV89vYO"
OSS_SECRET = "p9wBFUAC5e29rtunkbnPTefWYLQpxT"
OSS_BUCKET = "zqinghui"
return OSS_ENDPOINT, OSS_KEY, OSS_SECRET, OSS_BUCKET
def get_oss_config(url: str):
if "zqinghui" in url:
return get_oss_config_zqinghui()
return get_oss_config_masx()
def download_file_from_oss_url(
model_dir: str,
url_path: str,
download_type: str,
download_file_len: int,
downloadQueue,
downloadFileLen: int,
):
"""
默认从oss上根据url下载不同的type数据如果下载失败获取zqinghui配置下的数据
Args:
"""
try:
logger.info(f"下载url_path: {url_path}")
OSS_ENDPOINT, OSS_KEY, OSS_SECRET, OSS_BUCKET = get_oss_config(url_path)
oss_class = create_oss.Oss_connection(
OSS_ENDPOINT, OSS_KEY, OSS_SECRET, OSS_BUCKET
)
oss_class.oss_get_file(
model_dir,
url_path,
type=download_type,
            # str.lstrip strips a set of characters, not a prefix; removeprefix yields the intended object key
            keys=url_path.removeprefix(f"https://{OSS_BUCKET}.{OSS_ENDPOINT}/"),
downloadQueue=downloadQueue,
downloadFileLen=downloadFileLen,
downloadAllLen=download_file_len,
)
except Exception:
logger.warning(f"下载 {url_path} error: {traceback.format_exc()}")
if url_path.split("/")[-1] in os.listdir(model_dir):
os.remove(os.path.join(model_dir, url_path.split('/')[-1]))
def download_checkpoint(model_name, model_url, download_file_len, downloadQueue):
logger.info(f"Download checkpoint of model people: {model_url}")
model_dir = config.model_folder
os.makedirs(model_dir, exist_ok=True)
if not os.path.exists(os.path.join(model_dir, model_name)):
download_file_len += 1
OSS_ENDPOINT = "oss-cn-zhangjiakou.aliyuncs.com"
OSS_KEY = "LTAIgPS3rCV89vYO"
OSS_SECRET = "p9wBFUAC5e29rtunkbnPTefWYLQpxT"
OSS_BUCKET = "zqinghui"
oss_class = create_oss.Oss_connection(
OSS_ENDPOINT, OSS_KEY, OSS_SECRET, OSS_BUCKET
)
oss_class.oss_get_file(
model_dir, model_url, type="模型", downloadQueue=downloadQueue
)
return download_file_len, downloadQueue
def download(modelSpeakerId, downloadQueue, res):
"""
1异常退出时,删除所有已经下载的临时文件
2重新下载时根据上次下载一半的结果继续下载
3错误拦截有问题
4一些文件的创建和清理逻辑模糊
"""
    # TODO: what is this variable actually for?
download_file_len = 3
if res["code"] != 0:
logger.warning(f"无法获取当前数字人信息,{res['code']}")
return "ERROR"
res = res["data"]
pkl_url, model_url, video_needle_url, mask_needle_url = (
res["pkl_url"],
res["modeling_url"],
res["video_needle_url"],
res["mask_needle_url"],
)
    model_name = model_url.split("/")[-1]  # the last URL path component is used as the file name
download_file_len, downloadQueue = download_checkpoint(
model_name, model_url, download_file_len, downloadQueue
)
model_dir = "model_people/"
os.makedirs(model_dir, exist_ok=True)
    pkl_url_zip_name = pkl_url.split("/")[-1]  # the last URL path component is used as the file name
pkl_url_name = pkl_url_zip_name.replace(".zip", "")
logger.info(f"pkl_url_name{pkl_url_name}")
if pkl_url_name in os.listdir(model_dir):
return f"model people of {modelSpeakerId} not need to download"
    # TODO: currently this errors out, and already-downloaded files cannot be deleted
    # pkl
logger.info(f"download model people of : {modelSpeakerId}")
if pkl_url != "none":
download_file_from_oss_url(
model_dir,
pkl_url,
"人物",
download_file_len,
downloadQueue,
downloadFileLen=1,
)
unzip(
os.path.join(model_dir, pkl_url_zip_name),
os.path.join(model_dir, pkl_url_name),
)
shutil.rmtree(os.path.join(model_dir, pkl_url_name, "image"), ignore_errors=True)
# video
if video_needle_url != "none":
        # TODO: read this from config
download_file_from_oss_url(
model_dir,
video_needle_url,
"人物",
download_file_len,
downloadQueue,
downloadFileLen=2,
)
        # the digital human before background replacement
unzip(
os.path.join(model_dir, video_needle_url.split('/')[-1]),
os.path.join(model_dir, pkl_url_name, "image"),
)
shutil.rmtree(os.path.join(model_dir, modelSpeakerId), ignore_errors=True)
os.rename(os.path.join(model_dir, pkl_url_name), os.path.join(model_dir, modelSpeakerId))
os.remove(os.path.join(model_dir, pkl_url_zip_name))
# mask
if mask_needle_url != "none":
download_file_from_oss_url(
model_dir,
mask_needle_url,
"mask",
download_file_len,
downloadQueue,
downloadFileLen=3,
)
        # TODO: this directory holds a pure-white background
mask_dir = os.path.join(model_dir, "mask", modelSpeakerId)
os.makedirs(mask_dir, exist_ok=True)
# unzip(f"{model_dir}/{mask_needle_url.split('/')[-1]}", mask_dir)
unzip(os.path.join(model_dir, mask_needle_url.split('/')[-1]), mask_dir)
        # delete the zip archive
os.remove(os.path.join(model_dir, mask_needle_url.split('/')[-1]))
    # handled this way because the pkl and img URLs may share the same final path component
video_img_file = os.path.join(model_dir, video_needle_url.split('/')[-1])
if os.path.exists(video_img_file):
os.remove(video_img_file)
downloadQueue.put("100")
logger.info(f"model people of {modelSpeakerId} 下载完成")
return "下载完成"
def get_model_local_path(model_url: str):
return model_url.split("/")[-1]
def get_actor_id(pkl_url):
"""
取URL的最后一部分作为template_id
"""
pkl_file_name = pkl_url.split("/")[-1]
return pkl_file_name.replace(".zip", "")
def load_face_box(actor_id, pkl_version: int = 1) -> np.ndarray:
"""
Args: actor_id: str : uid
pkl_version: 1, 2
"""
face_det_results_pk_file = os.path.join(
"model_people", actor_id, "face_det_results.pkl"
)
if not os.path.exists(face_det_results_pk_file):
        raise FileNotFoundError(f"File not found: {face_det_results_pk_file}")
if pkl_version != 1:
with open(face_det_results_pk_file, "rb") as f:
res = pickle.load(f)["boxes"]
return np.array(res)
with open(face_det_results_pk_file, mode="rb") as f:
unpickler = pickle.Unpickler(f)
face_det_results = []
try:
            # TODO: odd that the loop can only end by exhausting the stream
            # (pickle.Unpickler.load raises EOFError at end of file)
            while True:
data = unpickler.load()
face_det_results.append(data)
except Exception:
pass
    # TODO: rework this; it does not work on numpy 1.24 and above
    final_result = np.concatenate(face_det_results)
return final_result
def load_seg_model(pth_path, model, device):
model.to(device)
model.load_state_dict(torch.load(pth_path))
return model.eval()
def load_w2l_model(pth_path, model, device):
model = model.to(device)
model.load_state_dict(
{
k.replace("module.", ""): v
for k, v in torch.load(pth_path)["state_dict"].items()
}
)
return model.eval()
def read_img(file_path: str, flags=cv2.IMREAD_COLOR):
"""read jpg img and decode"""
return cv2.imread(file_path, flags=flags)
def write_img(content, jpg_path: str, mode: str = "wb"):
with open(jpg_path, mode) as file:
file.write(content)
def save_raw_video(speech_id, result_frames, tag="infer"):
name = f"temp/{tag}_{speech_id}_result.avi"
out = cv2.VideoWriter(name, cv2.VideoWriter_fourcc(*"mp4v"), 25, (1080, 1920))
for i, frame in enumerate(result_frames):
out.write(frame)
out.release()
return name
def save_raw_wav(speech_id, audio_packets, tag="infer", folder="temp"):
os.makedirs(folder, exist_ok=True)
    sample_rate = 16000  # sample rate: 16,000 samples per second
    # duration = 1  # audio length in seconds
    num_channels = 1  # number of channels: mono
    sample_width = 2  # sample width in bytes, i.e. 16-bit audio
    saved_wav_name = os.path.join(folder, f"{tag}_{speech_id}.wav")
with wave.open(saved_wav_name, "wb") as wav_file:
        # set the WAV file parameters
        wav_file.setnchannels(num_channels)
        wav_file.setsampwidth(sample_width)  # sample width in bytes
wav_file.setframerate(sample_rate)
wav_file.writeframes(audio_packets)
return saved_wav_name
def morphing(
ori_frame: np.ndarray,
pred_frame: np.ndarray,
box: list,
mp_ratio: float,
padding: int = 10,
file_index: int = 1,
) -> np.ndarray:
"""Tool to 图片融合
Args:
ori_frame: 原始视频帧 01.png - np
pred_frame: infer_01.png - np
box: box, # 确认[x1,y1,x2,y2]
mp_ratio: [0,1] / morphing_frams_num
Return:
merged numpy array
"""
(y1, y2, x1, x2) = box
ori_face = ori_frame[ori_frame.shape[0] // 2 :, :]
pred_face = pred_frame[:, :, :3][
y1 - padding : y2 + padding, x1 - padding : x2 + padding
]
pred_half_face = pred_face[pred_face.shape[0] // 2 :, :]
alpha = np.ones_like(ori_face) * mp_ratio
pred_frame[:, :, :3][
y1 + (y2 - y1) // 2 : y2 + padding, x1 - padding : x2 + padding
] = (1 - alpha) * pred_half_face + alpha * ori_face
logger.debug(f"file_index:{file_index}, ratio:{mp_ratio}")
return pred_frame
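# Hedged usage sketch: callers presumably ramp mp_ratio from 0 to 1 over
# morphing_num frames so the inferred mouth fades back into the original
# footage, e.g.:
#
#   for i in range(morphing_num):
#       out = morphing(ori, pred, box, mp_ratio=(i + 1) / morphing_num)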
# TODO: @txueduo this is essentially the same as the fusion path; the mask is non-binary, so blending must use a float alpha
def add_alpha(img: np.ndarray, mask: np.ndarray, alpha: int = 0):
t0 = time.time()
if img.shape[-1] == 3:
res_img = np.zeros((img.shape[0], img.shape[1], 4), dtype=np.uint8)
res_img[:, :, :3] = img[:, :, ::-1]
else:
res_img = np.zeros_like(img, dtype=np.uint8)
res_img[:, :, :3] = img[:, :, :3][:, :, ::-1]
mask[mask > 125] = 255
res_img[:, :, 3] = mask
result = res_img
    # the new alpha channel is now attached to the image
logger.info(f"add_alpha cost:{time.time() - t0}")
return result.astype("uint8")
def fusion(person: np.ndarray, background: np.ndarray, mask: np.ndarray) -> np.ndarray:
"""
Args:
person: ori_img
background: need to fusion img
mask: filter mask
"""
result = np.zeros_like(person, dtype=np.uint8)
    # png: output with transparency
if len(mask.shape) == 2:
background_alpha = np.zeros_like(person, dtype=np.uint8)
background_alpha[:, :, :3] = background
# background_alpha[:, :, 3] = 255 if not config.output_alpha else 0
background_alpha[:, :, 3] = 0
mask_rgba = np.repeat(mask[..., None], 4, axis=-1).astype("uint8")
else:
        # jpeg: output transparency is determined by the computation below
mask_rgba = mask
background_alpha = background
mask_rgba = mask_rgba / 255.0
result = mask_rgba * person + (1 - mask_rgba) * background_alpha
return result.astype("uint8")
def chg_bg(
ori_img_path: Union[str, np.ndarray] = None,
bg_img_path: Union[str, np.ndarray] = None,
mask: Union[str, np.ndarray] = None,
default_flags=cv2.IMREAD_COLOR,
):
""""""
if isinstance(bg_img_path, str):
flags = cv2.IMREAD_UNCHANGED if "png" in bg_img_path else default_flags
bg_img = cv2.imread(bg_img_path, flags=flags)
else:
bg_img = bg_img_path
if isinstance(ori_img_path, str):
flags = cv2.IMREAD_UNCHANGED if "png" in ori_img_path else default_flags
ori_img = cv2.imread(ori_img_path, flags=flags)
else:
ori_img = ori_img_path
logger.debug(
f"bg-img shape:{bg_img.shape}, ori_img_path:{ori_img_path}, flags:{flags}"
)
if ori_img.shape[-1] == 4:
return fusion(ori_img, bg_img, ori_img[:, :, 3])
elif ori_img.shape[-1] == 3:
if isinstance(mask, str):
mask = cv2.imread(mask)
return fusion(ori_img, bg_img, mask)
else:
raise ValueError("img shape[-1] must be 3 or 4")
def load_config(model_speaker_id: str):
    """Fetch the per-speaker inference config from the segments service."""
logger.info(f"model_speaker_id: {model_speaker_id}")
res = requests.post(
config.GET_segments, json={"modelSpeakerId": model_speaker_id}
).json()["data"]
logger.info(res)
model_url = res["modeling_url"]
    frame_config = res.get("frame_config", [[1, 200, True]])
    # frame_config = [[1, 591, True]]
    padding = res["padding"]  # if padding is too small, the chin may be left out of inference
    face_classes = res.get("face_classes", [1, 11, 12, 13])
    if not frame_config:
        raise RuntimeError(f"frame config {frame_config} is empty or not valid JSON")
    # new field: trans_method
trans_method = res.get("trans_method", 2)
    # new transition-parameter fields
infer_silent_num = res.get("infer_silent_num", 5)
morphing_num = res.get("morphing_num", 25)
pkl_version = res.get("pkl_version", 1)
logger.info(f"infer_silent_num:{infer_silent_num} morphing_num{morphing_num}")
return model_url, frame_config, padding, face_classes, trans_method, infer_silent_num, morphing_num, pkl_version
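# The segments-service response is assumed (not specified in this commit) to
# look roughly like:
#   {
#     "modeling_url": "...", "padding": 10,
#     "frame_config": [[1, 200, True]], "face_classes": [1, 11, 12, 13],
#     "trans_method": 2, "infer_silent_num": 5, "morphing_num": 25,
#     "pkl_version": 1,
#   }
# Optional fields missing from the response fall back to the defaults above.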
def get_trans_idxes(
    need_post_morphing,
    need_infer_silent,
    post_morphint_idx,
    infer_silent_idx,
    file_index,
):
    if need_post_morphing:
        post_morphint_idxes = (
            list(range(1, 6)) if post_morphint_idx == 1 else list(range(post_morphint_idx))[-5:]
        )
    else:
        post_morphint_idxes = [0, 0, 0, 0, 0]  # type: ignore
    if need_infer_silent:
        infer_silent_idxes = (
            list(range(1, 6)) if infer_silent_idx == 1 else list(range(infer_silent_idx))[-5:]
        )
    else:
        infer_silent_idxes = [0, 0, 0, 0, 0]
    file_idxes = list(range(1, 6)) if file_index == 1 else list(range(file_index))[-5:]
    return file_idxes, post_morphint_idxes, infer_silent_idxes
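# Example (illustrative): for file_index == 7 with both need flags False,
# get_trans_idxes returns ([2, 3, 4, 5, 6], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]):
# the last five frame indices plus zero-filled placeholders.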