add 255 avatars

This commit is contained in:
jiegeaiai 2024-11-23 01:22:47 +08:00
parent e3646f4e71
commit c01ec04cd3
3 changed files with 35 additions and 4 deletions

View File

@ -9,7 +9,7 @@ from .audio_mal_handler import AudioMalHandler
from .human_render import HumanRender from .human_render import HumanRender
from nlp import PunctuationSplit, DouBao from nlp import PunctuationSplit, DouBao
from tts import TTSEdge, TTSAudioSplitHandle, TTSEdgeHttp from tts import TTSEdge, TTSAudioSplitHandle, TTSEdgeHttp
from utils import load_avatar, get_device, object_stop, load_avatar_from_processed from utils import load_avatar, get_device, object_stop, load_avatar_from_processed, load_avatar_from_256_processed
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
current_file_path = os.path.dirname(os.path.abspath(__file__)) current_file_path = os.path.dirname(os.path.abspath(__file__))
@ -38,11 +38,16 @@ class HumanContext:
base_path = os.path.join(current_file_path, '..') base_path = os.path.join(current_file_path, '..')
logger.info(f'base path:{base_path}') logger.info(f'base path:{base_path}')
# full_images, face_frames, coord_frames = load_avatar(base_path, self._image_size, self._device) # full_images, face_frames, coord_frames = load_avatar(base_path, self._image_size, self._device)
full_images, face_frames, coord_frames = load_avatar_from_processed(base_path, # full_images, face_frames, coord_frames = load_avatar_from_processed(base_path,
'wav2lip_avatar3') # 'wav2lip_avatar3')
full_images, face_frames, coord_frames, align_frames, m_frames, inv_m_frames = load_avatar_from_256_processed(
base_path, 'wav2lip_avatar4', '26.pkl')
self._frame_list_cycle = full_images self._frame_list_cycle = full_images
self._face_list_cycle = face_frames self._face_list_cycle = face_frames
self._coord_list_cycle = coord_frames self._coord_list_cycle = coord_frames
self._align_frames = align_frames
self._m_frames = m_frames
self._inv_m_frames = inv_m_frames
face_images_length = len(self._face_list_cycle) face_images_length = len(self._face_list_cycle)
logging.info(f'face images length: {face_images_length}') logging.info(f'face images length: {face_images_length}')
print(f'face images length: {face_images_length}') print(f'face images length: {face_images_length}')

View File

@ -4,5 +4,5 @@ from .async_task_queue import AsyncTaskQueue
from .sync_queue import SyncQueue from .sync_queue import SyncQueue
from .utils import mirror_index, load_model, get_device, load_avatar, config_logging from .utils import mirror_index, load_model, get_device, load_avatar, config_logging
from .utils import read_image, object_stop from .utils import read_image, object_stop
from .utils import load_avatar_from_processed from .utils import load_avatar_from_processed, load_avatar_from_256_processed
from .audio_utils import melspectrogram, save_wav from .audio_utils import melspectrogram, save_wav

View File

@ -210,6 +210,32 @@ def load_avatar_from_processed(base_path, avatar_name):
return frame_list_cycle, face_list_cycle, coord_list_frames return frame_list_cycle, face_list_cycle, coord_list_frames
def load_avatar_from_256_processed(base_path, avatar_name, pkl):
avatar_path = os.path.join(base_path, 'data', 'avatars', avatar_name, pkl)
print(f'load avatar from processed:{avatar_path}')
with open(avatar_path, "rb") as f:
avatar_data = pickle.load(f)
face_list_cycle = []
frame_list_cycle = []
coord_list_frames = []
align_frames = []
m_frames = []
inv_m_frames = []
frame_info_list = avatar_data['frame_info_list']
for frame_info in frame_info_list:
face_list_cycle.append(frame_info['img'])
frame_list_cycle.append(frame_info['frame'])
coord_list_frames.append(frame_info['coords'])
align_frames.append(frame_info['align_frame'])
m_frames.append(frame_info['m'])
inv_m_frames.append(frame_info['inv_m'])
return frame_list_cycle, face_list_cycle, coord_list_frames, align_frames, m_frames, inv_m_frames
def config_logging(file_name: str, console_level: int = logging.INFO, file_level: int = logging.DEBUG): def config_logging(file_name: str, console_level: int = logging.INFO, file_level: int = logging.DEBUG):
file_handler = logging.FileHandler(file_name, mode='a', encoding="utf8") file_handler = logging.FileHandler(file_name, mode='a', encoding="utf8")
file_handler.setFormatter(logging.Formatter( file_handler.setFormatter(logging.Formatter(