modify load avatar

This commit is contained in:
brige 2024-11-18 20:05:44 +08:00
parent 053a2afab5
commit 499dba9bed
5 changed files with 37 additions and 9 deletions

View File

@ -9,7 +9,7 @@ from .audio_mal_handler import AudioMalHandler
from .human_render import HumanRender from .human_render import HumanRender
from nlp import PunctuationSplit, DouBao from nlp import PunctuationSplit, DouBao
from tts import TTSEdge, TTSAudioSplitHandle, TTSEdgeHttp from tts import TTSEdge, TTSAudioSplitHandle, TTSEdgeHttp
from utils import load_avatar, get_device, object_stop from utils import load_avatar, get_device, object_stop, load_avatar_from_processed
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
current_file_path = os.path.dirname(os.path.abspath(__file__)) current_file_path = os.path.dirname(os.path.abspath(__file__))
@ -35,9 +35,11 @@ class HumanContext:
self._device = get_device() self._device = get_device()
print(f'device:{self._device}') print(f'device:{self._device}')
base_path = os.path.join(current_file_path, '..', 'face') base_path = os.path.join(current_file_path, '..')
logger.info(f'_create_recognizer init, path:{base_path}') logger.info(f'base path:{base_path}')
full_images, face_frames, coord_frames = load_avatar(base_path, self._image_size, self._device) # full_images, face_frames, coord_frames = load_avatar(base_path, self._image_size, self._device)
full_images, face_frames, coord_frames = load_avatar_from_processed(base_path,
'wav2lip_avatar1')
self._frame_list_cycle = full_images self._frame_list_cycle = full_images
self._face_list_cycle = face_frames self._face_list_cycle = face_frames
self._coord_list_cycle = coord_frames self._coord_list_cycle = coord_frames

View File

@ -25,7 +25,7 @@ def download_tts(url):
def __create_bytes_stream(byte_stream): def __create_bytes_stream(byte_stream):
stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64 stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64
print(f'[INFO]tts audio stream {sample_rate}: {stream.shape}') # print(f'[INFO]tts audio stream {sample_rate}: {stream.shape}')
stream = stream.astype(np.float32) stream = stream.astype(np.float32)
if stream.ndim > 1: if stream.ndim > 1:

View File

@ -79,15 +79,15 @@ class TTSEdgeHttp(TTSBase):
def __create_bytes_stream(self, byte_stream): def __create_bytes_stream(self, byte_stream):
stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64 stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64
print(f'[INFO]tts audio stream {sample_rate}: {stream.shape}') logging.info(f'tts audio stream {sample_rate}: {stream.shape}')
stream = stream.astype(np.float32) stream = stream.astype(np.float32)
if stream.ndim > 1: if stream.ndim > 1:
print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.') logging.warning(f'audio has {stream.shape[1]} channels, only use the first.')
stream = stream[:, 0] stream = stream[:, 0]
if sample_rate != self._handle.sample_rate and stream.shape[0] > 0: if sample_rate != self._handle.sample_rate and stream.shape[0] > 0:
print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self._handle.sample_rate}.') logging.warning(f'audio sample rate is {sample_rate}, resampling into {self._handle.sample_rate}.')
stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self._handle.sample_rate) stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self._handle.sample_rate)
return stream return stream

View File

@ -4,4 +4,5 @@ from .async_task_queue import AsyncTaskQueue
from .sync_queue import SyncQueue from .sync_queue import SyncQueue
from .utils import mirror_index, load_model, get_device, load_avatar, config_logging from .utils import mirror_index, load_model, get_device, load_avatar, config_logging
from .utils import read_image, object_stop from .utils import read_image, object_stop
from .utils import load_avatar_from_processed
from .audio_utils import melspectrogram, save_wav from .audio_utils import melspectrogram, save_wav

View File

@ -1,6 +1,8 @@
#encoding = utf8 #encoding = utf8
import glob
import logging import logging
import os import os
import pickle
import cv2 import cv2
import numpy as np import numpy as np
@ -33,7 +35,6 @@ def read_images(img_list):
frames = [] frames = []
print('reading images...') print('reading images...')
for img_path in tqdm(img_list): for img_path in tqdm(img_list):
print(f'read image path:{img_path}')
# frame = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) # frame = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
frame = Image.open(img_path) frame = Image.open(img_path)
frame = np.array(frame) frame = np.array(frame)
@ -185,6 +186,30 @@ def load_avatar(path, img_size, device):
return full_list_cycle, face_frames, coord_frames return full_list_cycle, face_frames, coord_frames
def load_avatar_from_processed(base_path, avatar_name):
    """Load an already-processed avatar from disk.

    The avatar is looked up under ``<base_path>/data/avatars/<avatar_name>``,
    which must contain ``coords.pkl`` and the image folders ``face_imgs`` and
    ``full_imgs``.

    Image files are ordered by ``int(<file stem>)``, so every matched file
    must have a purely numeric stem; any other stem raises ``ValueError``.

    :param base_path: directory that contains the ``data/avatars`` tree.
    :param avatar_name: name of the avatar sub-directory to load.
    :return: ``(full_frames, face_frames, coords)`` -- the ``read_images``
        results for ``full_imgs`` and ``face_imgs``, followed by the object
        unpickled from ``coords.pkl``.
    """

    def _frame_number(image_file):
        # Numeric ordering key: '10.png' must come after '9.png'.
        stem, _extension = os.path.splitext(os.path.basename(image_file))
        return int(stem)

    def _read_numbered_images(directory):
        # The pattern keeps files whose extension starts with j/p, then p/n,
        # and ends with g, in either case.
        found = glob.glob(os.path.join(directory, '*.[jpJP][pnPN]*[gG]'))
        found.sort(key=_frame_number)
        return read_images(found)

    avatar_dir = os.path.join(base_path, 'data', 'avatars', avatar_name)
    coords_file = os.path.join(avatar_dir, 'coords.pkl')
    face_dir = os.path.join(avatar_dir, 'face_imgs')
    full_dir = os.path.join(avatar_dir, 'full_imgs')

    print(f'load avatar from processed:{avatar_dir}')
    # NOTE(review): this second message repeats the avatar directory; it may
    # have been meant to report the coords file -- confirm before changing.
    print(f'load avatar_path from processed:{avatar_dir}')
    print(f'load face_image_path from processed:{face_dir}')
    print(f'load full_image_path from processed:{full_dir}')

    # NOTE(review): pickle executes arbitrary code on load; only point this
    # at avatar data produced by a trusted pre-processing step.
    with open(coords_file, 'rb') as coords_handle:
        coords = pickle.load(coords_handle)

    # Face crops are read before the full frames, as in the original order.
    face_frames = _read_numbered_images(face_dir)
    full_frames = _read_numbered_images(full_dir)
    return full_frames, face_frames, coords
def config_logging(file_name: str, console_level: int = logging.INFO, file_level: int = logging.DEBUG): def config_logging(file_name: str, console_level: int = logging.INFO, file_level: int = logging.DEBUG):
file_handler = logging.FileHandler(file_name, mode='a', encoding="utf8") file_handler = logging.FileHandler(file_name, mode='a', encoding="utf8")
file_handler.setFormatter(logging.Formatter( file_handler.setFormatter(logging.Formatter(