Compare commits
1 Commits
f3dcbdc876
...
dac34b0962
Author | SHA1 | Date | |
---|---|---|---|
![]() |
dac34b0962 |
889
utils/util.py
889
utils/util.py
@ -1,889 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
import os
|
||||
import pickle
|
||||
import numpy as np
|
||||
import zipfile
|
||||
|
||||
from config import config
|
||||
import cv2
|
||||
import torch
|
||||
from turbojpeg import TurboJPEG
|
||||
import platform
|
||||
import wave
|
||||
import requests
|
||||
|
||||
from log import logger
|
||||
import time
|
||||
import shutil
|
||||
from utils import create_oss
|
||||
from typing import Union
|
||||
import traceback
|
||||
|
||||
|
||||
def unzip(zip_path, save_path):
    """Extract every member of the zip archive ``zip_path`` into ``save_path``."""
    with zipfile.ZipFile(zip_path, mode="r") as archive:
        archive.extractall(path=save_path)
|
||||
|
||||
|
||||
# Module-level TurboJPEG handle. On Linux the library is located by TurboJPEG's
# default lookup; on every other platform the DLL at the relative path below is
# loaded explicitly.
jpeg = (
    TurboJPEG()
    if platform.system() == "Linux"
    else TurboJPEG(lib_path=r"libjpeg-turbo-gcc64\bin\libturbojpeg.dll")
)
|
||||
|
||||
|
||||
def datagen(
    pool2,
    mel_batch,
    start_idx_list,
    pk_name,
    face_det_results,
    padding,
    is_silent: int = 1,
):
    """Load face crops in parallel and pack them into model-ready batches.

    One ``get_face_image`` task is submitted to ``pool2`` per entry of
    ``start_idx_list``. Results are consumed in submission order and grouped
    into batches of at most ``config.wav2lip_batch_size`` frames.

    Args:
        pool2: Executor exposing ``submit``; each future's ``result()`` yields
            ``(frame_to_save, coords, face, mask)``.
        mel_batch: Mel inputs for this call. They are converted to an array
            once, given a trailing axis, and attached unsliced to every batch.
        start_idx_list: Frame indices to load.
        pk_name: Actor directory name forwarded to ``get_face_image``.
        face_det_results: Face boxes forwarded to ``get_face_image``.
        padding: Forwarded to ``get_face_image`` as ``infer_padding``.
        is_silent: Silent-frame flag, forwarded unchanged.

    Returns:
        A list of ``(img_batch, mel_batch, frame_batch, mask_batch,
        coords_batch, start_idx_list)`` tuples; empty when ``start_idx_list``
        is empty.
    """
    futures = [
        pool2.submit(
            get_face_image,
            pk_name,
            start_idx,
            face_det_results,
            padding,
            is_silent,
        )
        for start_idx in start_idx_list
    ]
    if not futures:
        return []

    # Convert the mel input exactly once. The previous code re-wrapped the
    # already expanded array on every flush, so the second and later batches
    # of one call carried one extra trailing axis each.
    mel_input = np.asarray(mel_batch)[..., np.newaxis]

    def pack(faces, frames, masks, coords):
        """Build one output tuple from the accumulated per-frame lists."""
        faces = np.asarray(faces)
        masked = faces.copy()
        # Zero every row from the vertical midpoint down so the model only
        # sees the upper half of the masked copy.
        masked[:, config.img_size // 2 :] = 0
        # Stack masked + reference crops along axis 3 and scale by 1/255.
        model_input = np.concatenate((masked, faces), axis=3) / 255.0
        return (model_input, mel_input, frames, masks, coords, start_idx_list)

    batches = []
    faces, frames, coords_list, masks = [], [], [], []
    for future in futures:
        frame_to_save, coords, face, mask = future.result()
        faces.append(face)
        frames.append(frame_to_save)
        coords_list.append(coords)
        masks.append(mask)

        if len(faces) >= config.wav2lip_batch_size:
            batches.append(pack(faces, frames, masks, coords_list))
            faces, frames, coords_list, masks = [], [], [], []

    # Flush the trailing partial batch, if any.
    if faces:
        batches.append(pack(faces, frames, masks, coords_list))

    return batches
|
||||
|
||||
|
||||
def get_face_image(
    pkl_name,
    start_idx,
    face_det_results,
    infer_padding,
    is_silent: int = 1,
):
    """Load one stored frame, its mask and the resized face crop.

    Frames are read from ``model_people/<pkl_name>/image/img<idx>.<ext>``.
    The extension is ``jpg`` when ``model_people/mask/<pkl_name>`` exists
    (the actor went through background replacement) and ``png`` otherwise.

    Args:
        pkl_name: Actor directory name.
        start_idx: 1-based frame number; box ``start_idx - 1`` is used.
        face_det_results: Per-frame boxes. An entry is either
            ``(y1, y2, x1, x2)`` or a container whose element ``[1]`` is.
        infer_padding: Rows added to ``y2`` before cropping.
        is_silent: Not used by this function.

    Returns:
        ``(frame_to_save, coords, face, mask)`` where ``coords`` is the padded
        ``(y1, y2, x1, x2)`` and ``face`` is resized to
        ``config.img_size`` x ``config.img_size``.

    Raises:
        FileExistsError: If the frame file is missing (type kept as is for
            existing callers).
    """
    mask_dir = f"model_people/mask/{pkl_name}"
    # A mask directory marks a background-replaced actor stored as JPEG.
    post_suffix = "jpg" if os.path.exists(mask_dir) else "png"
    image_dir = os.path.join("model_people", pkl_name, "image")
    file = os.path.join(image_dir, f"img{start_idx:05d}.{post_suffix}")

    if not os.path.exists(file):
        logger.error(f"Not found image file: {file}")
        raise FileExistsError(f"Not found image file: {file}")

    # PNG frames keep their alpha channel; JPEG frames are read as colour.
    read_flags = cv2.IMREAD_UNCHANGED if post_suffix == "png" else cv2.IMREAD_COLOR
    frame_to_save = read_img(file, flags=read_flags)

    mask = None
    if post_suffix == "jpg":
        if not os.path.exists(mask_dir):
            raise Exception("should use mask,however no mask is found.")
        mask = cv2.imread(
            os.path.join(mask_dir, f"img{start_idx:05d}.jpg"),
            cv2.IMREAD_GRAYSCALE,
        )
    elif post_suffix == "png":
        # The alpha channel of the PNG doubles as the mask.
        mask = frame_to_save[:, :, 3]
    else:
        raise ValueError(f"file type {post_suffix} not in (png,jpeg)")

    box = face_det_results[start_idx - 1]
    if len(box) != 4:
        box = face_det_results[start_idx - 1][1]
    y1, y2, x1, x2 = box

    y2 = y2 + infer_padding
    coords = (y1, y2, x1, x2)
    face = cv2.resize(
        frame_to_save[:, :, :3][y1:y2, x1:x2],
        (config.img_size, config.img_size),
    )
    return frame_to_save, coords, face, mask
|
||||
|
||||
|
||||
def play_in_loop_v2(
    segments,
    startfrom,
    batch_num,
    last_direction,
    is_silent,
    is_silent_,
    first_speak,
    last_speak,
):
    """Walk ``batch_num`` frames through the loop/action segment timeline.

    For every frame the segment containing the current position is looked
    up, ``get_next_direction`` chooses a direction, and ``get_next_frame``
    steps one frame that way. Rules noted by the original author:

    1. While silent, loop inside the silent segment: forward at its left
       edge, reverse at its right edge, otherwise derived from the previous
       direction and position.
    2. Silent -> speaking: head for the nearest speaking segment.
    3. Speaking -> silent: finish the action segment, then enter silence.
    4. At the left end of the whole video only forward is possible.
    5. At the right end of the whole video only reverse is possible.
    6. Collect ``batch_num`` frames and return them with the last direction.

    Args:
        segments: Loop configuration ``[[start, end, loop_flag], ...]``.
        startfrom: Current frame position.
        batch_num: Number of frames to produce; values <= 0 produce none.
        last_direction: 0 = reverse, 1 = forward.
        is_silent: 1 while silent, 0 while speaking.
        is_silent_: Unused here (marked by the author as possibly obsolete).
        first_speak: Truthy when this batch starts a speech.
        last_speak: Truthy when this batch ends a speech.

    Returns:
        ``(frames, direction)``: the visited frame indices and the direction
        of the last step, or ``last_direction`` when nothing was produced.
    """
    frames = []
    cur_pos = startfrom
    cur_direction = last_direction
    is_first_speak_frame = first_speak
    is_last_speak_frame = bool(last_speak and batch_num == 1)

    # ``> 0`` instead of ``!= 0`` so that a negative count cannot loop forever.
    while batch_num > 0:
        # Segment that contains the current frame.
        sub_seg_idx = subseg_judge(cur_pos, segments)
        # Direction to move in.
        next_direction, next_pos = get_next_direction(
            segments,
            cur_pos,
            cur_direction,
            is_silent,
            sub_seg_idx,
            is_first_speak_frame,
            is_last_speak_frame,
        )
        # Step one frame in that direction.
        next_pos = get_next_frame(next_pos, next_direction)
        frames.append(next_pos)
        batch_num -= 1

        is_first_speak_frame = bool(
            first_speak and batch_num == config.wav2lip_batch_size
        )
        is_last_speak_frame = bool(last_speak and batch_num == 1)

        cur_direction = next_direction
        cur_pos = next_pos

    # ``cur_direction`` equals the last chosen direction and, unlike the
    # loop-local ``next_direction``, is defined even when the loop never ran.
    return frames, cur_direction
|
||||
|
||||
|
||||
def get_next_direction(
    segments,
    cur_pos,
    cur_direction,
    is_silent,
    sub_seg_idx,
    is_first_speak_frame: bool = False,
    is_last_speak_frame: bool = False,
):
    """Choose the playback direction for the next frame (3.3.0 loop rules).

    The decision is delegated based on two facts: whether the current
    segment is a loop (silent) segment, which is the third field of
    ``segments[sub_seg_idx]``, and whether the avatar is silent
    (``is_silent == 1``) or speaking (``is_silent == 0``):

    * loop segment, silent      -> ``pure_silent``
    * loop segment, speaking    -> ``silent2action``
    * action segment, silent    -> ``action2silent``
    * action segment, speaking  -> ``pure_action``

    Args:
        segments: Loop configuration ``[[start, end, loop_flag], ...]``.
        cur_pos: Current frame position.
        cur_direction: Previous direction, 0 = reverse, 1 = forward.
        is_silent: 1 while silent, 0 while speaking.
        sub_seg_idx: Index of the segment containing ``cur_pos``.
        is_first_speak_frame: Flag marking the start of speech.
        is_last_speak_frame: Flag marking the end of speech.

    Returns:
        ``(next_direction, next_pos)`` as produced by the selected helper.

    Raises:
        ValueError: If ``is_silent`` is neither 0 nor 1. Previously this case
            fell through every branch and failed on an unbound local.
    """
    left, right, loop_flag = segments[sub_seg_idx]
    if is_silent != 1 and is_silent != 0:
        raise ValueError(f"is_silent must be 0 or 1, got {is_silent!r}")

    if loop_flag:
        if is_silent == 1:
            next_direct, next_pos = pure_silent(
                segments, left, right, cur_pos, cur_direction, sub_seg_idx
            )
            logger.debug(
                f"cur_pos:{cur_pos}, next_direct:{next_direct}, is_first_speak_frame:{is_first_speak_frame}"
            )
        else:
            next_direct, next_pos = silent2action(
                segments,
                left,
                right,
                cur_pos,
                cur_direction,
                sub_seg_idx,
                is_first_speak_frame,
            )
            logger.debug(
                f"cur_pos:{cur_pos}, next_direct:{next_direct}, is_first_speak_frame{is_first_speak_frame}"
            )
    else:
        if is_silent == 1:
            next_direct, next_pos = action2silent(
                segments,
                left,
                right,
                cur_pos,
                cur_direction,
                sub_seg_idx,
                is_last_speak_frame,
            )
            logger.debug(
                f"cur_pos{cur_pos}, next_direct:{next_direct},is_first_speak_frame{is_first_speak_frame},is_last_speak_frame:{is_last_speak_frame}"
            )
        else:
            next_direct, next_pos = pure_action(
                segments,
                left,
                right,
                cur_pos,
                cur_direction,
                sub_seg_idx,
                is_last_speak_frame,
            )
            logger.debug(
                f"cur_pos:{cur_pos}, next_direct:{next_direct},is_first_speak_frame{is_first_speak_frame}, is_last_speak_frame:{is_last_speak_frame}"
            )
    return next_direct, next_pos
|
||||
|
||||
|
||||
def get_next_frame(cur_pos, cur_direction):
    """Return the frame adjacent to ``cur_pos`` in the given direction.

    The caller guarantees that a frame exists in that direction, so no
    boundary check is done here. Direction 1 steps forward, direction 0
    steps backward; any other value yields ``None``.
    """
    moving_forward = cur_direction == 1
    if moving_forward:
        return cur_pos + 1
    if cur_direction == 0:
        return cur_pos - 1
    return None
|
||||
|
||||
|
||||
def pure_silent(segments, left, right, cur_pos, cur_direction, sub_seg_idx):
    """Direction while silent inside a loop segment.

    Checked in order: start of the whole video -> forward, end of the whole
    video -> reverse, right edge of the segment -> reverse, left edge ->
    forward. Inside the segment a forward motion continues; anything else
    becomes reverse.

    Returns:
        ``(direction, cur_pos)`` with direction 1 = forward, 0 = reverse.
    """
    if cur_pos == segments[0][0]:
        heading = 1
    elif cur_pos == segments[-1][1]:
        heading = 0
    elif cur_pos == right:
        heading = 0
    elif cur_pos == left:
        heading = 1
    else:
        keep_forward = cur_pos > left and cur_direction == 1
        heading = 1 if keep_forward else 0
    return heading, cur_pos
|
||||
|
||||
|
||||
def pure_action(
    segments, left, right, cur_pos, cur_direction, sub_seg_idx, is_last_speak_frame
):
    """Direction while speaking inside an action (non-loop) segment.

    The ends of the whole video force forward / reverse. On the last speech
    frame the walk heads for a silent segment: reverse from the final
    segment, forward from the first one, otherwise towards the nearer edge.
    On any other frame the previous direction is kept.

    Args:
        is_last_speak_frame: Truthy at the moment speech ends.

    Returns:
        ``(direction, cur_pos)`` with direction 1 = forward, 0 = reverse.
    """
    if cur_pos == segments[0][0]:
        heading = 1
    elif cur_pos == segments[-1][1]:
        heading = 0
    elif not is_last_speak_frame:
        # Mid-speech: keep going the same way.
        heading = 1 if cur_direction == 1 else 0
    elif sub_seg_idx == len(segments) - 1:
        # Action segment is the last one: silence can only be behind us.
        heading = 0
    elif sub_seg_idx == 0:
        # Action segment is the first one: silence can only be ahead.
        heading = 1
    else:
        # Action segment in the middle: go to the nearer edge.
        midpoint = left + (right - left + 1) // 2
        heading = 0 if cur_pos < midpoint else 1
    return heading, cur_pos
|
||||
|
||||
|
||||
def silent2action(
    segments,
    left,
    right,
    cur_pos,
    cur_direction,
    sub_seg_idx,
    is_first_speak_frame: bool = False,
):
    """Direction while speaking inside a loop (silent) segment.

    The ends of the whole video force forward / reverse. On either edge of
    the segment, and on every interior frame that is not the first speech
    frame, the previous direction is kept. The nearest-edge rule applies
    only to the first speech frame in the interior.

    Any ``cur_direction`` other than 1 is treated as reverse, matching the
    sibling helpers. Previously the interior non-first-speak case returned
    ``None`` for such values, which broke tuple unpacking in the caller.

    Returns:
        ``(direction, cur_pos)`` with direction 1 = forward, 0 = reverse.
    """
    # TODO: confirm the boundary logic below (note kept from the original).
    if cur_pos == segments[0][0]:
        # After a jump the walk still continues forward, whether or not the
        # new segment is an action segment.
        return 1, cur_pos
    if cur_pos == segments[-1][1]:
        return 0, cur_pos

    on_segment_edge = cur_pos == left or cur_pos == right
    if not on_segment_edge and is_first_speak_frame:
        # First segment is a loop segment: action can only lie ahead.
        if sub_seg_idx == 0 and segments[0][2]:
            return 1, cur_pos
        # Last segment is a loop segment: action can only lie behind.
        if sub_seg_idx == len(segments) - 1 and segments[-1][2]:
            return 0, cur_pos
        mid = left + (right - left + 1) // 2
        return (0 if cur_pos < mid else 1), cur_pos

    # Still speaking: keep the direction of the previous frame.
    return (1 if cur_direction == 1 else 0), cur_pos
|
||||
|
||||
|
||||
def action2silent(
    segments, left, right, cur_pos, cur_direction, sub_seg_idx, is_last_speak_frame
):
    """Direction while silent inside an action (non-loop) segment.

    The ends of the whole video force forward / reverse. When speech has
    just ended the walk heads for the nearer edge of the action segment so
    it reaches a silent segment; otherwise the previous direction is kept.

    Returns:
        ``(direction, cur_pos)`` with direction 1 = forward, 0 = reverse.
    """
    if cur_pos == segments[0][0]:
        return 1, cur_pos
    if cur_pos == segments[-1][1]:
        return 0, cur_pos

    if is_last_speak_frame:
        midpoint = left + (right - left + 1) // 2
        heading = 0 if cur_pos < midpoint else 1
    else:
        heading = 1 if cur_direction == 1 else 0
    return heading, cur_pos
|
||||
|
||||
|
||||
def subseg_judge(cur_pos, segments):
    """Return the index of the first segment whose range contains ``cur_pos``.

    Position 0 maps to segment 0 when no range contains it; any other
    unmatched position yields ``None``.
    """
    # NOTE(review): the position-0 fallback is assumed to sit after the loop;
    # the flattened listing does not show the nesting.
    hit = next(
        (
            idx
            for idx, frame_seg in enumerate(segments)
            if frame_seg[0] <= cur_pos <= frame_seg[1]
        ),
        None,
    )
    if hit is not None:
        return hit
    if cur_pos == 0:
        return 0
    return None
|
||||
|
||||
|
||||
def get_oss_config_masx():
    """Return the OSS connection settings for the ``mas-x`` bucket.

    The access key pair is read from the ``OSS_ACCESS_KEY_ID`` and
    ``OSS_ACCESS_KEY_SECRET`` environment variables instead of being
    embedded in the source.

    Returns:
        ``(endpoint, access_key_id, access_key_secret, bucket)``.

    Raises:
        RuntimeError: If either environment variable is not set.
    """
    logger.info("download from mas-x")
    # SECURITY: the key pair that used to be hard-coded here is part of the
    # repository history and must be treated as compromised - rotate it.
    try:
        oss_key = os.environ["OSS_ACCESS_KEY_ID"]
        oss_secret = os.environ["OSS_ACCESS_KEY_SECRET"]
    except KeyError as err:
        raise RuntimeError(
            "OSS credentials are not configured: set OSS_ACCESS_KEY_ID and "
            "OSS_ACCESS_KEY_SECRET"
        ) from err

    oss_endpoint = "oss-cn-hangzhou.aliyuncs.com"
    oss_bucket = "mas-x"
    return oss_endpoint, oss_key, oss_secret, oss_bucket
|
||||
|
||||
|
||||
def get_oss_config_zqinghui():
    """Return the OSS connection settings for the ``zqinghui`` bucket.

    The access key pair is read from the ``OSS_ACCESS_KEY_ID`` and
    ``OSS_ACCESS_KEY_SECRET`` environment variables instead of being
    embedded in the source.

    Returns:
        ``(endpoint, access_key_id, access_key_secret, bucket)``.

    Raises:
        RuntimeError: If either environment variable is not set.
    """
    logger.info("download from zqinghui")
    # SECURITY: the key pair that used to be hard-coded here is part of the
    # repository history and must be treated as compromised - rotate it.
    try:
        oss_key = os.environ["OSS_ACCESS_KEY_ID"]
        oss_secret = os.environ["OSS_ACCESS_KEY_SECRET"]
    except KeyError as err:
        raise RuntimeError(
            "OSS credentials are not configured: set OSS_ACCESS_KEY_ID and "
            "OSS_ACCESS_KEY_SECRET"
        ) from err

    oss_endpoint = "oss-cn-zhangjiakou.aliyuncs.com"
    oss_bucket = "zqinghui"
    return oss_endpoint, oss_key, oss_secret, oss_bucket
|
||||
|
||||
|
||||
def get_oss_config(url: str):
    """Pick the OSS settings that match ``url``.

    URLs containing ``"zqinghui"`` use the zqinghui bucket settings; every
    other URL uses the mas-x settings.
    """
    provider = get_oss_config_zqinghui if "zqinghui" in url else get_oss_config_masx
    return provider()
|
||||
|
||||
|
||||
def download_file_from_oss_url(
    model_dir: str,
    url_path: str,
    download_type: str,
    download_file_len: int,
    downloadQueue,
    downloadFileLen: int,
):
    """Download one OSS object into ``model_dir``, best effort.

    The bucket settings are chosen from ``url_path`` via ``get_oss_config``.
    Any failure is logged and swallowed, and a file named after the last
    URL component is removed from ``model_dir`` if present, so callers must
    check for the downloaded file themselves.

    Args:
        model_dir: Destination directory.
        url_path: Full ``https://<bucket>.<endpoint>/<key>`` URL.
        download_type: Passed to ``oss_get_file`` as ``type``.
        download_file_len: Passed to ``oss_get_file`` as ``downloadAllLen``.
        downloadQueue: Passed through to ``oss_get_file``.
        downloadFileLen: Passed through to ``oss_get_file``.
    """
    try:
        logger.info(f"下载url_path: {url_path}")
        OSS_ENDPOINT, OSS_KEY, OSS_SECRET, OSS_BUCKET = get_oss_config(url_path)
        oss_class = create_oss.Oss_connection(
            OSS_ENDPOINT, OSS_KEY, OSS_SECRET, OSS_BUCKET
        )
        # str.lstrip() removes a *set of characters*, not a prefix, so the
        # old ``url_path.lstrip(prefix)`` also ate leading key characters
        # that occur in the host name (e.g. "model/..." -> "del/...").
        url_prefix = f"https://{OSS_BUCKET}.{OSS_ENDPOINT}"
        object_key = url_path
        if object_key.startswith(url_prefix):
            object_key = object_key[len(url_prefix):]
        # The old call also dropped the separating "/"; keep doing that.
        object_key = object_key.lstrip("/")
        oss_class.oss_get_file(
            model_dir,
            url_path,
            type=download_type,
            keys=object_key,
            downloadQueue=downloadQueue,
            downloadFileLen=downloadFileLen,
            downloadAllLen=download_file_len,
        )
    except Exception:
        logger.warning(f"下载 {url_path} error: {traceback.format_exc()}")
        # Remove a partial download. A direct file check cannot fail on a
        # missing ``model_dir`` the way ``os.listdir`` did.
        partial_file = os.path.join(model_dir, url_path.split("/")[-1])
        if os.path.isfile(partial_file):
            os.remove(partial_file)
|
||||
|
||||
|
||||
def download_checkpoint(model_name, model_url, download_file_len, downloadQueue):
    """Fetch the model checkpoint into ``config.model_folder`` if missing.

    Args:
        model_name: File name expected inside the model folder.
        model_url: OSS URL of the checkpoint.
        download_file_len: Running file counter; incremented when a download
            is started.
        downloadQueue: Passed through to ``oss_get_file``.

    Returns:
        ``(download_file_len, downloadQueue)``.
    """
    logger.info(f"Download checkpoint of model people: {model_url}")
    model_dir = config.model_folder
    os.makedirs(model_dir, exist_ok=True)
    if not os.path.exists(os.path.join(model_dir, model_name)):
        download_file_len += 1
        # Reuse the shared zqinghui settings rather than keeping a second
        # copy of the endpoint, bucket and credentials in this function.
        oss_class = create_oss.Oss_connection(*get_oss_config_zqinghui())
        oss_class.oss_get_file(
            model_dir, model_url, type="模型", downloadQueue=downloadQueue
        )
    return download_file_len, downloadQueue
|
||||
|
||||
|
||||
def download(modelSpeakerId, downloadQueue, res):
    """Download and unpack every asset of one digital human.

    Fetches the checkpoint, then the face-box archive, the frame archive and
    the mask archive into ``model_people/``, and finally renames the
    unpacked directory to ``modelSpeakerId``.

    Known gaps noted by the original author:
        1. Temporary files are not removed on abnormal exit.
        2. An interrupted download is not resumed.
        3. Error interception is unreliable.
        4. Some file creation / cleanup logic is unclear.

    Args:
        modelSpeakerId: Target directory name under ``model_people/``.
        downloadQueue: Progress queue; receives ``"100"`` on completion.
        res: Service response with ``code`` and ``data`` (URLs of the assets).

    Returns:
        ``"ERROR"`` when ``res["code"]`` is non-zero, a "not need to
        download" message when the directory already exists, otherwise
        ``"下载完成"``.
    """
    # NOTE(review): statement nesting was reconstructed from a flattened
    # listing - confirm against the real file before merging.
    # TODO: clarify what this counter is for.
    download_file_len = 3
    if res["code"] != 0:
        logger.warning(f"无法获取当前数字人信息,{res['code']}")
        return "ERROR"

    data = res["data"]
    pkl_url = data["pkl_url"]
    model_url = data["modeling_url"]
    video_needle_url = data["video_needle_url"]
    mask_needle_url = data["mask_needle_url"]

    # The last URL component is the checkpoint file name.
    model_name = model_url.split("/")[-1]
    download_file_len, downloadQueue = download_checkpoint(
        model_name, model_url, download_file_len, downloadQueue
    )

    model_dir = "model_people/"
    os.makedirs(model_dir, exist_ok=True)
    # The last URL component is the archive name; without ".zip" it is the
    # directory the archive is unpacked into.
    pkl_url_zip_name = pkl_url.split("/")[-1]
    pkl_url_name = pkl_url_zip_name.replace(".zip", "")
    logger.info(f"pkl_url_name:{pkl_url_name}")
    if pkl_url_name in os.listdir(model_dir):
        return f"model people of {modelSpeakerId} not need to download"

    pkl_zip_path = os.path.join(model_dir, pkl_url_zip_name)
    actor_dir = os.path.join(model_dir, pkl_url_name)

    # TODO: an error here currently leaves downloaded files behind.
    # Face-box archive.
    logger.info(f"download model people of : {modelSpeakerId}")
    if pkl_url != "none":
        download_file_from_oss_url(
            model_dir,
            pkl_url,
            "人物",
            download_file_len,
            downloadQueue,
            downloadFileLen=1,
        )
        unzip(pkl_zip_path, actor_dir)
        shutil.rmtree(os.path.join(actor_dir, "image"), ignore_errors=True)

    # Frame archive.
    if video_needle_url != "none":
        # TODO: read these values from configuration.
        download_file_from_oss_url(
            model_dir,
            video_needle_url,
            "人物",
            download_file_len,
            downloadQueue,
            downloadFileLen=2,
        )
        # Frames of the digital human before background replacement.
        unzip(
            os.path.join(model_dir, video_needle_url.split('/')[-1]),
            os.path.join(actor_dir, "image"),
        )
        speaker_dir = os.path.join(model_dir, modelSpeakerId)
        shutil.rmtree(speaker_dir, ignore_errors=True)
        os.rename(actor_dir, speaker_dir)
        os.remove(pkl_zip_path)

    # Mask archive.
    if mask_needle_url != "none":
        download_file_from_oss_url(
            model_dir,
            mask_needle_url,
            "mask",
            download_file_len,
            downloadQueue,
            downloadFileLen=3,
        )
        # TODO: this directory holds a pure white background.
        mask_dir = os.path.join(model_dir, "mask", modelSpeakerId)
        os.makedirs(mask_dir, exist_ok=True)

        mask_zip_path = os.path.join(model_dir, mask_needle_url.split('/')[-1])
        unzip(mask_zip_path, mask_dir)
        # Remove the zip archive.
        os.remove(mask_zip_path)

    # The frame archive is removed last because the face-box and frame URLs
    # may share the same final path component.
    video_img_file = os.path.join(model_dir, video_needle_url.split('/')[-1])
    if os.path.exists(video_img_file):
        os.remove(video_img_file)

    downloadQueue.put("100")
    logger.info(f"model people of {modelSpeakerId} 下载完成")
    return "下载完成"
|
||||
|
||||
|
||||
def get_model_local_path(model_url: str):
    """Return the part of ``model_url`` after its last ``/``."""
    return model_url.rsplit("/", 1)[-1]
|
||||
|
||||
|
||||
def get_actor_id(pkl_url):
    """Derive the template id from the last component of ``pkl_url``.

    Every occurrence of ``".zip"`` in that component is removed.
    """
    _, _, pkl_file_name = pkl_url.rpartition("/")
    return pkl_file_name.replace(".zip", "")
|
||||
|
||||
|
||||
def load_face_box(actor_id, pkl_version: int = 1) -> np.ndarray:
    """Load the face boxes stored for one actor.

    Reads ``model_people/<actor_id>/face_det_results.pkl``.

    Args:
        actor_id: Actor directory name (uid).
        pkl_version: 1 for a file made of several pickles written one after
            another; any other value for a single pickled dict whose
            ``"boxes"`` entry holds the boxes.

    Returns:
        The boxes as a NumPy array.

    Raises:
        FileExistsError: If the pickle file is missing (type kept as is for
            existing callers).
    """
    face_det_results_pk_file = os.path.join(
        "model_people", actor_id, "face_det_results.pkl"
    )
    if not os.path.exists(face_det_results_pk_file):
        raise FileExistsError(f"Not found file:{face_det_results_pk_file}")

    # SECURITY: unpickling can execute arbitrary code; only load files that
    # come from a trusted source.
    if pkl_version != 1:
        with open(face_det_results_pk_file, "rb") as f:
            res = pickle.load(f)["boxes"]
        return np.array(res)

    face_det_results = []
    with open(face_det_results_pk_file, mode="rb") as f:
        unpickler = pickle.Unpickler(f)
        while True:
            try:
                face_det_results.append(unpickler.load())
            except EOFError:
                # Normal end of a multi-pickle stream.
                break
            except Exception:
                # Keep the best-effort behaviour (use what was read so far)
                # but make corruption visible instead of hiding it.
                logger.warning(
                    f"Stopped reading {face_det_results_pk_file} early: "
                    f"{traceback.format_exc()}"
                )
                break

    # TODO (from the original): reported not to work on NumPy >= 1.24.
    return np.concatenate(face_det_results)
|
||||
|
||||
|
||||
def load_seg_model(pth_path, model, device):
    """Load a state dict into ``model`` and return it in eval mode.

    Args:
        pth_path: Path to a file containing a plain state dict.
        model: Module to populate; it is moved to ``device``.
        device: Target device for the module and the loaded tensors.

    Returns:
        ``model`` after ``eval()``.
    """
    model.to(device)
    # map_location lets a checkpoint saved on one device (e.g. a GPU) be
    # loaded on a machine where that device is not available.
    state_dict = torch.load(pth_path, map_location=device)
    model.load_state_dict(state_dict)
    return model.eval()
|
||||
|
||||
|
||||
def load_w2l_model(pth_path, model, device):
    """Load a checkpoint's ``"state_dict"`` into ``model``; return it in eval mode.

    Every occurrence of ``"module."`` is removed from the parameter names
    before loading.

    Args:
        pth_path: Path to a checkpoint dict with a ``"state_dict"`` entry.
        model: Module to populate; it is moved to ``device``.
        device: Target device for the module and the loaded tensors.

    Returns:
        ``model`` after ``eval()``.
    """
    model = model.to(device)
    # map_location lets a checkpoint saved on one device (e.g. a GPU) be
    # loaded on a machine where that device is not available.
    checkpoint = torch.load(pth_path, map_location=device)
    state_dict = {
        k.replace("module.", ""): v for k, v in checkpoint["state_dict"].items()
    }
    model.load_state_dict(state_dict)
    return model.eval()
|
||||
|
||||
|
||||
def read_img(file_path: str, flags=cv2.IMREAD_COLOR):
    """Read and decode the image at ``file_path`` with ``cv2.imread``."""
    decoded = cv2.imread(file_path, flags=flags)
    return decoded
|
||||
|
||||
|
||||
def write_img(content, jpg_path: str, mode: str = "wb"):
    """Write ``content`` to ``jpg_path``, opening the file with ``mode``."""
    handle = open(jpg_path, mode)
    try:
        handle.write(content)
    finally:
        handle.close()
|
||||
|
||||
|
||||
def save_raw_video(speech_id, result_frames, tag="infer"):
    """Write ``result_frames`` to ``temp/<tag>_<speech_id>_result.avi``.

    The writer is opened with the ``mp4v`` fourcc at 25 fps and a frame size
    of ``(1080, 1920)``.

    Args:
        speech_id: Identifier used in the file name.
        result_frames: Iterable of frames passed to ``VideoWriter.write``.
        tag: File name prefix.

    Returns:
        The path of the written file.
    """
    # The writer does not create missing directories; make sure the output
    # folder exists, as save_raw_wav does.
    os.makedirs("temp", exist_ok=True)
    name = f"temp/{tag}_{speech_id}_result.avi"
    out = cv2.VideoWriter(name, cv2.VideoWriter_fourcc(*"mp4v"), 25, (1080, 1920))
    try:
        for frame in result_frames:
            out.write(frame)
    finally:
        # Release the writer even when a frame fails to encode.
        out.release()
    return name
|
||||
|
||||
|
||||
def save_raw_wav(speech_id, audio_packets, tag="infer", folder="temp"):
    """Write raw PCM bytes to ``<folder>/<tag>_<speech_id>.wav``.

    The file is mono, 16-bit, 16000 Hz.

    Args:
        speech_id: Identifier used in the file name.
        audio_packets: Raw PCM frame bytes.
        tag: File name prefix.
        folder: Output directory, created if missing. It was previously
            created but ignored: the file always went to ``temp/``.

    Returns:
        The path of the written file.
    """
    os.makedirs(folder, exist_ok=True)
    sample_rate = 16000  # samples per second
    num_channels = 1  # mono
    sample_width = 2  # bytes per sample (16-bit audio)
    saved_wav_name = f"{folder}/{tag}_{speech_id}.wav"

    with wave.open(saved_wav_name, "wb") as wav_file:
        wav_file.setnchannels(num_channels)
        wav_file.setsampwidth(sample_width)
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(audio_packets)

    return saved_wav_name
|
||||
|
||||
|
||||
def morphing(
    ori_frame: np.ndarray,
    pred_frame: np.ndarray,
    box: list,
    mp_ratio: float,
    padding: int = 10,
    file_index: int = 1,
) -> np.ndarray:
    """Blend the lower half of a face region between two frames.

    The lower half of ``ori_frame`` is mixed into the lower half of the
    padded face box of ``pred_frame`` (first three channels only) as
    ``(1 - mp_ratio) * predicted + mp_ratio * original``. ``pred_frame`` is
    modified in place.

    Args:
        ori_frame: Original image whose lower half is blended in.
        pred_frame: Inferred frame, updated in place.
        box: Unpacked as ``(y1, y2, x1, x2)``.
        mp_ratio: Weight of the original image; the original author notes
            the range [0, 1].
        padding: Margin added around the box.
        file_index: Only used in the debug log.

    Returns:
        ``pred_frame`` (the same object that was passed in).
    """
    y1, y2, x1, x2 = box
    cols = slice(x1 - padding, x2 + padding)

    original_lower = ori_frame[ori_frame.shape[0] // 2 :, :]
    pred_bgr = pred_frame[:, :, :3]  # view: writes reach pred_frame
    padded_face = pred_bgr[y1 - padding : y2 + padding, cols]
    predicted_lower = padded_face[padded_face.shape[0] // 2 :, :]

    alpha = np.ones_like(original_lower) * mp_ratio
    pred_bgr[y1 + (y2 - y1) // 2 : y2 + padding, cols] = (
        1 - alpha
    ) * predicted_lower + alpha * original_lower
    logger.debug(f"file_index:{file_index}, ratio:{mp_ratio}")
    return pred_frame
|
||||
|
||||
|
||||
# TODO(@txueduo): this is essentially the same processing as fusion(); because the mask is not binary, the blend has to use a float alpha.
|
||||
def add_alpha(img: np.ndarray, mask: np.ndarray, alpha: int = 0):
    """Return ``img`` as a 4-channel uint8 image with ``mask`` as channel 3.

    The order of the first three channels is reversed. Mask values above
    125 become 255; lower values are kept as they are.

    Args:
        img: Image with 3 or more channels.
        mask: Per-pixel alpha source. It is no longer modified in place;
            previously the thresholding wrote into the caller's array.
        alpha: Unused.

    Returns:
        The combined uint8 image.
    """
    t0 = time.time()
    if img.shape[-1] == 3:
        res_img = np.zeros((img.shape[0], img.shape[1], 4), dtype=np.uint8)
        res_img[:, :, :3] = img[:, :, ::-1]
    else:
        res_img = np.zeros_like(img, dtype=np.uint8)
        res_img[:, :, :3] = img[:, :, :3][:, :, ::-1]

    # Threshold a copy so the caller's mask stays untouched.
    alpha_channel = mask.copy()
    alpha_channel[alpha_channel > 125] = 255
    res_img[:, :, 3] = alpha_channel

    logger.info(f"add_alpha cost:{time.time() - t0}")
    return res_img.astype("uint8")
|
||||
|
||||
|
||||
def fusion(person: np.ndarray, background: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Alpha-blend ``person`` over ``background`` using ``mask`` as weight.

    Args:
        person: Foreground image.
        background: Image to blend underneath.
        mask: Weights in 0..255. A 2-D mask is repeated to four channels and
            the background is placed in a zeroed copy shaped like ``person``
            with channel 3 set to 0. Any other mask is used as given,
            together with the unchanged background.

    Returns:
        ``mask / 255 * person + (1 - mask / 255) * background`` as uint8.
    """
    if mask.ndim == 2:
        # PNG path: output carries transparency.
        backdrop = np.zeros_like(person, dtype=np.uint8)
        backdrop[:, :, :3] = background
        backdrop[:, :, 3] = 0
        weights = np.repeat(mask[..., None], 4, axis=-1).astype("uint8")
    else:
        # JPEG path: mask and background are used unchanged.
        weights, backdrop = mask, background

    weights = weights / 255.0
    blended = weights * person + (1 - weights) * backdrop
    return blended.astype("uint8")
|
||||
|
||||
|
||||
def chg_bg(
    ori_img_path: Union[str, np.ndarray] = None,
    bg_img_path: Union[str, np.ndarray] = None,
    mask: Union[str, np.ndarray] = None,
    default_flags=cv2.IMREAD_COLOR,
):
    """Replace the background of an image via ``fusion``.

    Args:
        ori_img_path: Foreground image, as a path or an array.
        bg_img_path: Background image, as a path or an array.
        mask: Mask path or array; only used for 3-channel foregrounds.
            4-channel foregrounds use their own channel 3 as the mask.
        default_flags: ``cv2.imread`` flags for paths that do not contain
            ``"png"``; such paths are read with ``cv2.IMREAD_UNCHANGED``.

    Returns:
        The blended image returned by ``fusion``.

    Raises:
        ValueError: If the foreground has neither 3 nor 4 channels.
    """
    # Bind ``flags`` up front: the debug log below reads it, and it was
    # unbound when both images were passed as arrays.
    flags = default_flags
    if isinstance(bg_img_path, str):
        flags = cv2.IMREAD_UNCHANGED if "png" in bg_img_path else default_flags
        bg_img = cv2.imread(bg_img_path, flags=flags)
    else:
        bg_img = bg_img_path
    if isinstance(ori_img_path, str):
        flags = cv2.IMREAD_UNCHANGED if "png" in ori_img_path else default_flags
        ori_img = cv2.imread(ori_img_path, flags=flags)
    else:
        ori_img = ori_img_path
    logger.debug(
        f"bg-img shape:{bg_img.shape}, ori_img_path:{ori_img_path}, flags:{flags}"
    )
    if ori_img.shape[-1] == 4:
        return fusion(ori_img, bg_img, ori_img[:, :, 3])
    elif ori_img.shape[-1] == 3:
        if isinstance(mask, str):
            mask = cv2.imread(mask)
        return fusion(ori_img, bg_img, mask)
    else:
        raise ValueError("img shape[-1] must be 3 or 4")
|
||||
|
||||
|
||||
def load_config(model_speaker_id: str):
    """Fetch the playback configuration of one digital human.

    Posts ``{"modelSpeakerId": model_speaker_id}`` to ``config.GET_segments``
    and reads the ``"data"`` object of the JSON response.

    Args:
        model_speaker_id: Identifier of the digital human.

    Returns:
        ``(model_url, frame_config, padding, face_classes, trans_method,
        infer_silent_num, morphing_num, pkl_version)``. Fields missing from
        the response fall back to the defaults coded below, except
        ``modeling_url`` and ``padding``, which are required.

    Raises:
        RuntimeError: If ``frame_config`` is present but empty.
        requests.exceptions.Timeout: If the service does not answer in time.
    """
    logger.info(f"model_speaker_id: {model_speaker_id}")
    # Without a timeout requests waits indefinitely on an unresponsive server.
    res = requests.post(
        config.GET_segments,
        json={"modelSpeakerId": model_speaker_id},
        timeout=30,
    ).json()["data"]
    logger.info(res)
    model_url = res["modeling_url"]
    frame_config = res.get("frame_config", [[1, 200, True]])
    # Too small a padding leaves the chin outside the inferred region.
    padding = res["padding"]
    face_classes = res.get("face_classes", [1, 11, 12, 13])
    if not frame_config:
        raise RuntimeError(f"frame config {frame_config} not json")

    # Newer optional fields: transition method and transition parameters.
    trans_method = res.get("trans_method", 2)
    infer_silent_num = res.get("infer_silent_num", 5)
    morphing_num = res.get("morphing_num", 25)
    pkl_version = res.get("pkl_version", 1)
    logger.info(f"infer_silent_num:{infer_silent_num}, morphing_num:{morphing_num}")
    return (
        model_url,
        frame_config,
        padding,
        face_classes,
        trans_method,
        infer_silent_num,
        morphing_num,
        pkl_version,
    )
|
||||
|
||||
|
||||
def get_trans_idxes(
    need_post_morphing,
    need_infer_silent,
    post_morphint_idx,
    infer_silent_idx,
    file_index,
):
    """Build the index windows used for transition frames.

    For an index of 1 the window is ``[1, 2, 3, 4, 5]``; otherwise it is the
    last (up to) five values of ``range(index)``. A disabled window is five
    zeros.

    Returns:
        ``(file_idxes, post_morphint_idxes, infer_silent_idxes)``.
    """

    def window(index):
        if index == 1:
            return list(range(1, 6))
        return list(range(index))[-5:]

    post_morphint_idxes = window(post_morphint_idx) if need_post_morphing else [0] * 5
    infer_silent_idxes = window(infer_silent_idx) if need_infer_silent else [0] * 5
    return window(file_index), post_morphint_idxes, infer_silent_idxes
|
Loading…
Reference in New Issue
Block a user