modify load avatar
This commit is contained in:
parent
053a2afab5
commit
499dba9bed
@ -9,7 +9,7 @@ from .audio_mal_handler import AudioMalHandler
|
||||
from .human_render import HumanRender
|
||||
from nlp import PunctuationSplit, DouBao
|
||||
from tts import TTSEdge, TTSAudioSplitHandle, TTSEdgeHttp
|
||||
from utils import load_avatar, get_device, object_stop
|
||||
from utils import load_avatar, get_device, object_stop, load_avatar_from_processed
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
current_file_path = os.path.dirname(os.path.abspath(__file__))
|
||||
@ -35,9 +35,11 @@ class HumanContext:
|
||||
|
||||
self._device = get_device()
|
||||
print(f'device:{self._device}')
|
||||
base_path = os.path.join(current_file_path, '..', 'face')
|
||||
logger.info(f'_create_recognizer init, path:{base_path}')
|
||||
full_images, face_frames, coord_frames = load_avatar(base_path, self._image_size, self._device)
|
||||
base_path = os.path.join(current_file_path, '..')
|
||||
logger.info(f'base path:{base_path}')
|
||||
# full_images, face_frames, coord_frames = load_avatar(base_path, self._image_size, self._device)
|
||||
full_images, face_frames, coord_frames = load_avatar_from_processed(base_path,
|
||||
'wav2lip_avatar1')
|
||||
self._frame_list_cycle = full_images
|
||||
self._face_list_cycle = face_frames
|
||||
self._coord_list_cycle = coord_frames
|
||||
|
@ -25,7 +25,7 @@ def download_tts(url):
|
||||
|
||||
def __create_bytes_stream(byte_stream):
|
||||
stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64
|
||||
print(f'[INFO]tts audio stream {sample_rate}: {stream.shape}')
|
||||
# print(f'[INFO]tts audio stream {sample_rate}: {stream.shape}')
|
||||
stream = stream.astype(np.float32)
|
||||
|
||||
if stream.ndim > 1:
|
||||
|
@ -79,15 +79,15 @@ class TTSEdgeHttp(TTSBase):
|
||||
|
||||
def __create_bytes_stream(self, byte_stream):
|
||||
stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64
|
||||
print(f'[INFO]tts audio stream {sample_rate}: {stream.shape}')
|
||||
logging.info(f'tts audio stream {sample_rate}: {stream.shape}')
|
||||
stream = stream.astype(np.float32)
|
||||
|
||||
if stream.ndim > 1:
|
||||
print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
|
||||
logging.warning(f'audio has {stream.shape[1]} channels, only use the first.')
|
||||
stream = stream[:, 0]
|
||||
|
||||
if sample_rate != self._handle.sample_rate and stream.shape[0] > 0:
|
||||
print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self._handle.sample_rate}.')
|
||||
logging.warning(f'audio sample rate is {sample_rate}, resampling into {self._handle.sample_rate}.')
|
||||
stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self._handle.sample_rate)
|
||||
|
||||
return stream
|
||||
|
@ -4,4 +4,5 @@ from .async_task_queue import AsyncTaskQueue
|
||||
from .sync_queue import SyncQueue
|
||||
from .utils import mirror_index, load_model, get_device, load_avatar, config_logging
|
||||
from .utils import read_image, object_stop
|
||||
from .utils import load_avatar_from_processed
|
||||
from .audio_utils import melspectrogram, save_wav
|
||||
|
@ -1,6 +1,8 @@
|
||||
#encoding = utf8
|
||||
import glob
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
@ -33,7 +35,6 @@ def read_images(img_list):
|
||||
frames = []
|
||||
print('reading images...')
|
||||
for img_path in tqdm(img_list):
|
||||
print(f'read image path:{img_path}')
|
||||
# frame = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
|
||||
frame = Image.open(img_path)
|
||||
frame = np.array(frame)
|
||||
@ -185,6 +186,30 @@ def load_avatar(path, img_size, device):
|
||||
return full_list_cycle, face_frames, coord_frames
|
||||
|
||||
|
||||
def load_avatar_from_processed(base_path, avatar_name):
|
||||
avatar_path = os.path.join(base_path, 'data', 'avatars', avatar_name)
|
||||
print(f'load avatar from processed:{avatar_path}')
|
||||
coord_path = os.path.join(avatar_path, 'coords.pkl')
|
||||
print(f'load avatar_path from processed:{avatar_path}')
|
||||
face_image_path = os.path.join(avatar_path, 'face_imgs')
|
||||
print(f'load face_image_path from processed:{face_image_path}')
|
||||
full_image_path = os.path.join(avatar_path, 'full_imgs')
|
||||
print(f'load full_image_path from processed:{full_image_path}')
|
||||
|
||||
with open(coord_path, 'rb') as f:
|
||||
coord_list_frames = pickle.load(f)
|
||||
|
||||
face_image_list = glob.glob(os.path.join(face_image_path, '*.[jpJP][pnPN]*[gG]'))
|
||||
face_image_list = sorted(face_image_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
|
||||
face_list_cycle = read_images(face_image_list)
|
||||
|
||||
full_image_list = glob.glob(os.path.join(full_image_path, '*.[jpJP][pnPN]*[gG]'))
|
||||
full_image_list = sorted(full_image_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
|
||||
frame_list_cycle = read_images(full_image_list)
|
||||
|
||||
return frame_list_cycle, face_list_cycle, coord_list_frames
|
||||
|
||||
|
||||
def config_logging(file_name: str, console_level: int = logging.INFO, file_level: int = logging.DEBUG):
|
||||
file_handler = logging.FileHandler(file_name, mode='a', encoding="utf8")
|
||||
file_handler.setFormatter(logging.Formatter(
|
||||
|
Loading…
Reference in New Issue
Block a user