diff --git a/human/human_context.py b/human/human_context.py
index cca030d..75a6df2 100644
--- a/human/human_context.py
+++ b/human/human_context.py
@@ -9,7 +9,7 @@ from .audio_mal_handler import AudioMalHandler
 from .human_render import HumanRender
 from nlp import PunctuationSplit, DouBao
 from tts import TTSEdge, TTSAudioSplitHandle, TTSEdgeHttp
-from utils import load_avatar, get_device, object_stop
+from utils import load_avatar, get_device, object_stop, load_avatar_from_processed
 
 logger = logging.getLogger(__name__)
 current_file_path = os.path.dirname(os.path.abspath(__file__))
@@ -35,9 +35,11 @@ class HumanContext:
         self._device = get_device()
         print(f'device:{self._device}')
 
-        base_path = os.path.join(current_file_path, '..', 'face')
-        logger.info(f'_create_recognizer init, path:{base_path}')
-        full_images, face_frames, coord_frames = load_avatar(base_path, self._image_size, self._device)
+        base_path = os.path.join(current_file_path, '..')
+        logger.info(f'base path:{base_path}')
+        # full_images, face_frames, coord_frames = load_avatar(base_path, self._image_size, self._device)
+        full_images, face_frames, coord_frames = load_avatar_from_processed(base_path,
+                                                                            'wav2lip_avatar1')
         self._frame_list_cycle = full_images
         self._face_list_cycle = face_frames
         self._coord_list_cycle = coord_frames
diff --git a/test/test_mzzsfy_tts.py b/test/test_mzzsfy_tts.py
index f62c643..ecd3895 100644
--- a/test/test_mzzsfy_tts.py
+++ b/test/test_mzzsfy_tts.py
@@ -25,7 +25,7 @@ def download_tts(url):
 
 def __create_bytes_stream(byte_stream):
     stream, sample_rate = sf.read(byte_stream)  # [T*sample_rate,] float64
-    print(f'[INFO]tts audio stream {sample_rate}: {stream.shape}')
+    # print(f'[INFO]tts audio stream {sample_rate}: {stream.shape}')
     stream = stream.astype(np.float32)
 
     if stream.ndim > 1:
diff --git a/tts/tts_edge_http.py b/tts/tts_edge_http.py
index 84b5301..c3ad86a 100644
--- a/tts/tts_edge_http.py
+++ b/tts/tts_edge_http.py
@@ -79,15 +79,15 @@ class TTSEdgeHttp(TTSBase):
 
     def __create_bytes_stream(self, byte_stream):
         stream, sample_rate = sf.read(byte_stream)  # [T*sample_rate,] float64
-        print(f'[INFO]tts audio stream {sample_rate}: {stream.shape}')
+        logging.info(f'tts audio stream {sample_rate}: {stream.shape}')
         stream = stream.astype(np.float32)
 
         if stream.ndim > 1:
-            print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
+            logging.warning(f'audio has {stream.shape[1]} channels, only use the first.')
             stream = stream[:, 0]
 
         if sample_rate != self._handle.sample_rate and stream.shape[0] > 0:
-            print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self._handle.sample_rate}.')
+            logging.warning(f'audio sample rate is {sample_rate}, resampling into {self._handle.sample_rate}.')
             stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self._handle.sample_rate)
 
         return stream
diff --git a/utils/__init__.py b/utils/__init__.py
index 3a9238c..21e438c 100644
--- a/utils/__init__.py
+++ b/utils/__init__.py
@@ -4,4 +4,5 @@ from .async_task_queue import AsyncTaskQueue
 from .sync_queue import SyncQueue
 from .utils import mirror_index, load_model, get_device, load_avatar, config_logging
 from .utils import read_image, object_stop
+from .utils import load_avatar_from_processed
 from .audio_utils import melspectrogram, save_wav
diff --git a/utils/utils.py b/utils/utils.py
index 28c3a19..8968245 100644
--- a/utils/utils.py
+++ b/utils/utils.py
@@ -1,6 +1,8 @@
 #encoding = utf8
+import glob
 import logging
 import os
+import pickle
 
 import cv2
 import numpy as np
@@ -33,7 +35,6 @@ def read_images(img_list):
     frames = []
     print('reading images...')
     for img_path in tqdm(img_list):
-        print(f'read image path:{img_path}')
         # frame = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
         frame = Image.open(img_path)
         frame = np.array(frame)
@@ -185,6 +186,30 @@ def load_avatar(path, img_size, device):
     return full_list_cycle, face_frames, coord_frames
 
 
+def load_avatar_from_processed(base_path, avatar_name):
+    avatar_path = os.path.join(base_path, 'data', 'avatars', avatar_name)
+    print(f'load avatar from processed:{avatar_path}')
+    coord_path = os.path.join(avatar_path, 'coords.pkl')
+    print(f'load coord_path from processed:{coord_path}')
+    face_image_path = os.path.join(avatar_path, 'face_imgs')
+    print(f'load face_image_path from processed:{face_image_path}')
+    full_image_path = os.path.join(avatar_path, 'full_imgs')
+    print(f'load full_image_path from processed:{full_image_path}')
+
+    with open(coord_path, 'rb') as f:
+        coord_list_frames = pickle.load(f)
+
+    face_image_list = glob.glob(os.path.join(face_image_path, '*.[jpJP][pnPN]*[gG]'))
+    face_image_list = sorted(face_image_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
+    face_list_cycle = read_images(face_image_list)
+
+    full_image_list = glob.glob(os.path.join(full_image_path, '*.[jpJP][pnPN]*[gG]'))
+    full_image_list = sorted(full_image_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
+    frame_list_cycle = read_images(full_image_list)
+
+    return frame_list_cycle, face_list_cycle, coord_list_frames
+
+
 def config_logging(file_name: str, console_level: int = logging.INFO, file_level: int = logging.DEBUG):
     file_handler = logging.FileHandler(file_name, mode='a', encoding="utf8")
     file_handler.setFormatter(logging.Formatter(
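The new loader assumes a pre-processed avatar directory instead of raw face images: <base_path>/data/avatars/<avatar_name>/ containing coords.pkl (the pickled per-frame face coordinates), plus face_imgs/ and full_imgs/ with frames named by integer index (0.jpg, 1.jpg, ...), as implied by the glob-and-numeric-sort logic above. A minimal usage sketch mirroring the call in HumanContext (paths here are illustrative):

import os

from utils import load_avatar_from_processed

# HumanContext passes the repo root, i.e. os.path.join(current_file_path, '..').
base_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..')

# Returns (full frames, cropped face frames, per-frame coords), one entry per frame.
full_frames, face_frames, coords = load_avatar_from_processed(base_path, 'wav2lip_avatar1')

# All three sequences describe the same frame cycle, so their lengths should match.
assert len(full_frames) == len(face_frames) == len(coords)
print(f'loaded {len(full_frames)} frames for the avatar cycle')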