modify load avatar
This commit is contained in:
parent
053a2afab5
commit
499dba9bed
@ -9,7 +9,7 @@ from .audio_mal_handler import AudioMalHandler
|
|||||||
from .human_render import HumanRender
|
from .human_render import HumanRender
|
||||||
from nlp import PunctuationSplit, DouBao
|
from nlp import PunctuationSplit, DouBao
|
||||||
from tts import TTSEdge, TTSAudioSplitHandle, TTSEdgeHttp
|
from tts import TTSEdge, TTSAudioSplitHandle, TTSEdgeHttp
|
||||||
from utils import load_avatar, get_device, object_stop
|
from utils import load_avatar, get_device, object_stop, load_avatar_from_processed
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
current_file_path = os.path.dirname(os.path.abspath(__file__))
|
current_file_path = os.path.dirname(os.path.abspath(__file__))
|
||||||
@ -35,9 +35,11 @@ class HumanContext:
|
|||||||
|
|
||||||
self._device = get_device()
|
self._device = get_device()
|
||||||
print(f'device:{self._device}')
|
print(f'device:{self._device}')
|
||||||
base_path = os.path.join(current_file_path, '..', 'face')
|
base_path = os.path.join(current_file_path, '..')
|
||||||
logger.info(f'_create_recognizer init, path:{base_path}')
|
logger.info(f'base path:{base_path}')
|
||||||
full_images, face_frames, coord_frames = load_avatar(base_path, self._image_size, self._device)
|
# full_images, face_frames, coord_frames = load_avatar(base_path, self._image_size, self._device)
|
||||||
|
full_images, face_frames, coord_frames = load_avatar_from_processed(base_path,
|
||||||
|
'wav2lip_avatar1')
|
||||||
self._frame_list_cycle = full_images
|
self._frame_list_cycle = full_images
|
||||||
self._face_list_cycle = face_frames
|
self._face_list_cycle = face_frames
|
||||||
self._coord_list_cycle = coord_frames
|
self._coord_list_cycle = coord_frames
|
||||||
|
@ -25,7 +25,7 @@ def download_tts(url):
|
|||||||
|
|
||||||
def __create_bytes_stream(byte_stream):
|
def __create_bytes_stream(byte_stream):
|
||||||
stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64
|
stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64
|
||||||
print(f'[INFO]tts audio stream {sample_rate}: {stream.shape}')
|
# print(f'[INFO]tts audio stream {sample_rate}: {stream.shape}')
|
||||||
stream = stream.astype(np.float32)
|
stream = stream.astype(np.float32)
|
||||||
|
|
||||||
if stream.ndim > 1:
|
if stream.ndim > 1:
|
||||||
|
@ -79,15 +79,15 @@ class TTSEdgeHttp(TTSBase):
|
|||||||
|
|
||||||
def __create_bytes_stream(self, byte_stream):
|
def __create_bytes_stream(self, byte_stream):
|
||||||
stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64
|
stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64
|
||||||
print(f'[INFO]tts audio stream {sample_rate}: {stream.shape}')
|
logging.info(f'tts audio stream {sample_rate}: {stream.shape}')
|
||||||
stream = stream.astype(np.float32)
|
stream = stream.astype(np.float32)
|
||||||
|
|
||||||
if stream.ndim > 1:
|
if stream.ndim > 1:
|
||||||
print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
|
logging.warning(f'audio has {stream.shape[1]} channels, only use the first.')
|
||||||
stream = stream[:, 0]
|
stream = stream[:, 0]
|
||||||
|
|
||||||
if sample_rate != self._handle.sample_rate and stream.shape[0] > 0:
|
if sample_rate != self._handle.sample_rate and stream.shape[0] > 0:
|
||||||
print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self._handle.sample_rate}.')
|
logging.warning(f'audio sample rate is {sample_rate}, resampling into {self._handle.sample_rate}.')
|
||||||
stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self._handle.sample_rate)
|
stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self._handle.sample_rate)
|
||||||
|
|
||||||
return stream
|
return stream
|
||||||
|
@ -4,4 +4,5 @@ from .async_task_queue import AsyncTaskQueue
|
|||||||
from .sync_queue import SyncQueue
|
from .sync_queue import SyncQueue
|
||||||
from .utils import mirror_index, load_model, get_device, load_avatar, config_logging
|
from .utils import mirror_index, load_model, get_device, load_avatar, config_logging
|
||||||
from .utils import read_image, object_stop
|
from .utils import read_image, object_stop
|
||||||
|
from .utils import load_avatar_from_processed
|
||||||
from .audio_utils import melspectrogram, save_wav
|
from .audio_utils import melspectrogram, save_wav
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
#encoding = utf8
|
#encoding = utf8
|
||||||
|
import glob
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import pickle
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -33,7 +35,6 @@ def read_images(img_list):
|
|||||||
frames = []
|
frames = []
|
||||||
print('reading images...')
|
print('reading images...')
|
||||||
for img_path in tqdm(img_list):
|
for img_path in tqdm(img_list):
|
||||||
print(f'read image path:{img_path}')
|
|
||||||
# frame = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
|
# frame = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
|
||||||
frame = Image.open(img_path)
|
frame = Image.open(img_path)
|
||||||
frame = np.array(frame)
|
frame = np.array(frame)
|
||||||
@ -185,6 +186,30 @@ def load_avatar(path, img_size, device):
|
|||||||
return full_list_cycle, face_frames, coord_frames
|
return full_list_cycle, face_frames, coord_frames
|
||||||
|
|
||||||
|
|
||||||
|
def _list_numbered_images(image_dir):
    """Return the image paths found in *image_dir*, ordered by frame number.

    Files are matched with the same extension pattern the loader has always
    used (jpg / jpeg / png in either case) and sorted by the integer value of
    their file stem, so ``2.png`` comes before ``10.png``.

    Raises:
        ValueError: if a matched file name is not a plain integer.
    """
    # Escape the directory part so characters such as '[' or '*' in the
    # avatar path are taken literally instead of being read as wildcards.
    pattern = os.path.join(glob.escape(image_dir), '*.[jpJP][pnPN]*[gG]')

    def frame_number(image_path):
        stem = os.path.splitext(os.path.basename(image_path))[0]
        try:
            return int(stem)
        except ValueError as err:
            # Same exception type as before, but naming the offending file.
            raise ValueError(
                f'image file name is not a frame number: {image_path}') from err

    return sorted(glob.glob(pattern), key=frame_number)


def load_avatar_from_processed(base_path, avatar_name):
    """Load a pre-processed avatar from ``<base_path>/data/avatars/<avatar_name>``.

    The avatar directory is expected to contain:

    * ``coords.pkl``  - pickled per-frame coordinates,
    * ``face_imgs/``  - face crops named by frame number,
    * ``full_imgs/``  - full frames named by frame number.

    Args:
        base_path: directory that contains the ``data/avatars`` tree.
        avatar_name: name of the avatar sub-directory to load.

    Returns:
        A tuple ``(frame_list_cycle, face_list_cycle, coord_list_frames)``:
        the results of ``read_images`` for the full frames and for the face
        crops, followed by the object unpickled from ``coords.pkl``.

    Raises:
        FileNotFoundError: if ``coords.pkl`` does not exist.
        ValueError: if an image file name is not a plain integer.
    """
    log = logging.getLogger(__name__)

    avatar_path = os.path.join(base_path, 'data', 'avatars', avatar_name)
    coord_path = os.path.join(avatar_path, 'coords.pkl')
    face_image_path = os.path.join(avatar_path, 'face_imgs')
    full_image_path = os.path.join(avatar_path, 'full_imgs')

    log.info('load avatar from processed:%s', avatar_path)
    # The original message repeated avatar_path here; report the file that
    # is actually about to be opened.
    log.info('load coord_path from processed:%s', coord_path)
    log.info('load face_image_path from processed:%s', face_image_path)
    log.info('load full_image_path from processed:%s', full_image_path)

    # SECURITY NOTE: pickle.load executes arbitrary code embedded in the file.
    # Only load avatar data that was produced by a trusted pipeline.
    with open(coord_path, 'rb') as f:
        coord_list_frames = pickle.load(f)

    face_image_list = _list_numbered_images(face_image_path)
    if not face_image_list:
        log.warning('no face images found in %s', face_image_path)
    face_list_cycle = read_images(face_image_list)

    full_image_list = _list_numbered_images(full_image_path)
    if not full_image_list:
        log.warning('no full images found in %s', full_image_path)
    frame_list_cycle = read_images(full_image_list)

    return frame_list_cycle, face_list_cycle, coord_list_frames
|
||||||
|
|
||||||
|
|
||||||
def config_logging(file_name: str, console_level: int = logging.INFO, file_level: int = logging.DEBUG):
|
def config_logging(file_name: str, console_level: int = logging.INFO, file_level: int = logging.DEBUG):
|
||||||
file_handler = logging.FileHandler(file_name, mode='a', encoding="utf8")
|
file_handler = logging.FileHandler(file_name, mode='a', encoding="utf8")
|
||||||
file_handler.setFormatter(logging.Formatter(
|
file_handler.setFormatter(logging.Formatter(
|
||||||
|
Loading…
Reference in New Issue
Block a user