diff --git a/human/human_context.py b/human/human_context.py index 0f1bc1f..193e0ea 100644 --- a/human/human_context.py +++ b/human/human_context.py @@ -9,7 +9,7 @@ from .audio_mal_handler import AudioMalHandler from .human_render import HumanRender from nlp import PunctuationSplit, DouBao from tts import TTSEdge, TTSAudioSplitHandle, TTSEdgeHttp -from utils import load_avatar, get_device, object_stop, load_avatar_from_processed +from utils import load_avatar, get_device, object_stop, load_avatar_from_processed, load_avatar_from_256_processed logger = logging.getLogger(__name__) current_file_path = os.path.dirname(os.path.abspath(__file__)) @@ -38,11 +38,16 @@ class HumanContext: base_path = os.path.join(current_file_path, '..') logger.info(f'base path:{base_path}') # full_images, face_frames, coord_frames = load_avatar(base_path, self._image_size, self._device) - full_images, face_frames, coord_frames = load_avatar_from_processed(base_path, - 'wav2lip_avatar3') + # full_images, face_frames, coord_frames = load_avatar_from_processed(base_path, + # 'wav2lip_avatar3') + full_images, face_frames, coord_frames, align_frames, m_frames, inv_m_frames = load_avatar_from_256_processed( + base_path, 'wav2lip_avatar4', '26.pkl') self._frame_list_cycle = full_images self._face_list_cycle = face_frames self._coord_list_cycle = coord_frames + self._align_frames = align_frames + self._m_frames = m_frames + self._inv_m_frames = inv_m_frames face_images_length = len(self._face_list_cycle) logging.info(f'face images length: {face_images_length}') print(f'face images length: {face_images_length}') diff --git a/utils/__init__.py b/utils/__init__.py index 21e438c..2f3ca11 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -4,5 +4,5 @@ from .async_task_queue import AsyncTaskQueue from .sync_queue import SyncQueue from .utils import mirror_index, load_model, get_device, load_avatar, config_logging from .utils import read_image, object_stop -from .utils import 
def load_avatar_from_256_processed(base_path, avatar_name, pkl):
    """Load a preprocessed 256-resolution avatar bundle from a pickle file.

    Reads ``<base_path>/data/avatars/<avatar_name>/<pkl>``, which must contain
    a dict with key ``'frame_info_list'``: a sequence of per-frame dicts with
    keys ``'frame'``, ``'img'``, ``'coords'``, ``'align_frame'``, ``'m'`` and
    ``'inv_m'``.

    :param base_path: project root directory that contains ``data/avatars``.
    :param avatar_name: avatar subdirectory name, e.g. ``'wav2lip_avatar4'``.
    :param pkl: pickle file name inside the avatar directory, e.g. ``'26.pkl'``.
    :return: six parallel lists (one entry per frame):
             full frames, face crops, face coords, aligned frames,
             alignment matrices, inverse alignment matrices.
    :raises FileNotFoundError: if the pickle file does not exist.
    :raises KeyError: if a required key is missing from the pickled data.
    """
    avatar_path = os.path.join(base_path, 'data', 'avatars', avatar_name, pkl)
    print(f'load avatar from processed:{avatar_path}')

    # NOTE: pickle is only safe for trusted, locally produced avatar data —
    # never point this at a file from an untrusted source.
    with open(avatar_path, "rb") as f:
        avatar_data = pickle.load(f)

    frame_info_list = avatar_data['frame_info_list']

    # Split the per-frame records into parallel column lists (comprehensions
    # instead of a manual six-way append loop).
    frame_list_cycle = [info['frame'] for info in frame_info_list]
    face_list_cycle = [info['img'] for info in frame_info_list]
    coord_list_frames = [info['coords'] for info in frame_info_list]
    align_frames = [info['align_frame'] for info in frame_info_list]
    m_frames = [info['m'] for info in frame_info_list]
    inv_m_frames = [info['inv_m'] for info in frame_info_list]

    return frame_list_cycle, face_list_cycle, coord_list_frames, align_frames, m_frames, inv_m_frames