add audio inference handler and related code
This commit is contained in:
parent dadfaf4eaf
commit da37374232
106 human/audio_inference_handler.py Normal file
@@ -0,0 +1,106 @@
+#encoding = utf8
+import queue
+import time
+from threading import Event, Thread
+
+import numpy as np
+import torch
+
+from human import AudioHandler
+from utils import load_model, mirror_index, get_device
+
+
+class AudioInferenceHandler(AudioHandler):
+    def __init__(self, context, handler):
+        super().__init__(context, handler)
+
+        self._exit_event = Event()
+        self._run_thread = Thread(target=self.__on_run)
+        self._exit_event.set()
+        self._run_thread.start()
+
+    def on_handle(self, stream, index):
+        if self._handler is not None:
+            self._handler.on_handle(stream, index)
+
+    def __on_run(self):
+        model = load_model(r'.\checkpoints\wav2lip.pth')
+        print("Model loaded")
+
+        face_list_cycle = self._human.get_face_list_cycle()
+
+        length = len(face_list_cycle)
+        index = 0
+        count = 0
+        count_time = 0
+        print('start inference')
+
+        device = get_device()
+        print(f'use device:{device}')
+
+        while True:
+            if self._exit_event.is_set():
+                start_time = time.perf_counter()
+                batch_size = self._context.batch_size()
+                try:
+                    mel_batch = self._feat_queue.get(block=True, timeout=0.1)
+                except queue.Empty:
+                    continue
+
+                is_all_silence = True
+                audio_frames = []
+                for _ in range(batch_size * 2):
+                    frame, type_ = self._audio_out_queue.get()
+                    audio_frames.append((frame, type_))
+                    if type_ == 0:
+                        is_all_silence = False
+                if is_all_silence:
+                    for i in range(batch_size):
+                        self._human.push_res_frame(None, mirror_index(length, index), audio_frames[i * 2:i * 2 + 2])
+                        index = index + 1
+                else:
+                    print('infer=======')
+                    t = time.perf_counter()
+                    img_batch = []
+                    for i in range(batch_size):
+                        idx = mirror_index(length, index + i)
+                        face = face_list_cycle[idx]
+                        img_batch.append(face)
+
+                    img_batch = np.asarray(img_batch)
+                    mel_batch = np.asarray(mel_batch)
+                    img_masked = img_batch.copy()
+                    img_masked[:, face.shape[0] // 2:] = 0
+                    img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
+                    mel_batch = np.reshape(mel_batch,
+                                           [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
+
+                    img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
+                    mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
+
+                    with torch.no_grad():
+                        pred = model(mel_batch, img_batch)
+
+                    pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
+
+                    count_time += (time.perf_counter() - t)
+                    count += batch_size
+
+                    if count >= 100:
+                        print(f"------actual avg infer fps:{count / count_time:.4f}")
+                        count = 0
+                        count_time = 0
+
+                    image_index = 0
+                    for i, res_frame in enumerate(pred):
+                        self._human.push_res_frame(res_frame, mirror_index(length, index),
+                                                   audio_frames[i * 2:i * 2 + 2])
+                        index = index + 1
+                        image_index = image_index + 1
+                    print('batch count', image_index)
+                    print('total batch time:', time.perf_counter() - start_time)
+            else:
+                time.sleep(1)
+                break
+        print('musereal inference processor stop')
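For orientation, a hedged shape check of the batching done in __on_run above. The 96x96 crop size and the resulting (B, 6, 96, 96) image and (B, 1, 80, 16) mel tensors are assumptions based on the standard Wav2Lip setup; this commit does not pin them down.

import numpy as np
import torch

# Hypothetical shape check; batch_size, img_size and mel_step_size are assumed values.
batch_size, img_size, mel_step_size = 4, 96, 16
img_batch = np.zeros((batch_size, img_size, img_size, 3))   # face crops, (H, W, BGR)
mel_batch = np.zeros((batch_size, 80, mel_step_size))       # 80 mel bins per 16-frame window

img_masked = img_batch.copy()
img_masked[:, img_size // 2:] = 0                                     # zero the lower half (mouth region)
img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.    # -> (B, 96, 96, 6)
mel_batch = np.reshape(mel_batch, [len(mel_batch), 80, mel_step_size, 1])

img_t = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2)))      # -> (B, 6, 96, 96)
mel_t = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2)))      # -> (B, 1, 80, 16)
print(img_t.shape, mel_t.shape)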
@@ -8,6 +8,7 @@ from threading import Thread, Event
 import numpy as np
 
 from human import AudioHandler
+from utils import melspectrogram
 
 logger = logging.getLogger(__name__)
 
@@ -45,20 +46,20 @@ class AudioMalHandler(AudioHandler):
         # self.output_queue.put((frame, _type))
         self._human.push_out_put(frame, _type)
         # context not enough, do not run network.
-        if len(self.frames) <= self.stride_left_size + self.stride_right_size:
+        if len(self.frames) <= self._context.stride_left_size() + self._context.stride_right_size():
             return
 
         inputs = np.concatenate(self.frames)  # [N * chunk]
-        mel = audio.melspectrogram(inputs)
+        mel = melspectrogram(inputs)
         # print(mel.shape[0],mel.shape,len(mel[0]),len(self.frames))
         # cut off stride
-        left = max(0, self.stride_left_size * 80 / 50)
-        right = min(len(mel[0]), len(mel[0]) - self.stride_right_size * 80 / 50)
-        mel_idx_multiplier = 80. * 2 / self.fps
+        left = max(0, self._context.stride_left_size() * 80 / 50)
+        right = min(len(mel[0]), len(mel[0]) - self._context.stride_right_size() * 80 / 50)
+        mel_idx_multiplier = 80. * 2 / self._context.fps()
         mel_step_size = 16
         i = 0
         mel_chunks = []
-        while i < (len(self.frames) - self.stride_left_size - self.stride_right_size) / 2:
+        while i < (len(self.frames) - self._context.stride_left_size() - self._context.stride_right_size()) / 2:
             start_idx = int(left + i * mel_idx_multiplier)
             # print(start_idx)
             if start_idx + mel_step_size > len(mel[0]):
@@ -70,7 +71,7 @@ class AudioMalHandler(AudioHandler):
         self._human.push_mel_chunks(mel_chunks)
 
         # discard the old part to save memory
-        self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]
+        self.frames = self.frames[-(self._context.stride_left_size() + self._context.stride_right_size()):]
 
     def get_audio_frame(self):
         try:
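To trace the arithmetic in the two hunks above, here is a standalone sketch of how the mel window start indices advance. The numbers are illustrative assumptions: fps=50, 80 mel frames per second of audio, stride_left_size = stride_right_size = 10 (the defaults visible in HumanContext), and 30 buffered audio chunks.

# Illustrative only: reproduces the index math from AudioMalHandler with assumed values.
fps, stride_left, stride_right = 50, 10, 10      # assumed, matching the HumanContext defaults
mel_step_size = 16                               # mel frames per Wav2Lip window
num_frames = 30                                  # assumed number of buffered 20 ms audio chunks

left = max(0, stride_left * 80 / 50)             # skip the left context: 16 mel frames
mel_idx_multiplier = 80. * 2 / fps               # mel frames advanced per iteration: 3.2

i, starts = 0, []
while i < (num_frames - stride_left - stride_right) / 2:
    starts.append(int(left + i * mel_idx_multiplier))
    i += 1
print(starts)                                    # [16, 19, 22, 25, 28] for the assumed numbers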
@@ -1,8 +1,12 @@
 #encoding = utf8
+import logging
+
 from asr import SherpaNcnnAsr
 from nlp import PunctuationSplit, DouBao
 from tts import TTSEdge, TTSAudioSplitHandle
 
+logger = logging.getLogger(__name__)
+
 
 class HumanContext:
     def __init__(self):
@@ -12,6 +16,14 @@ class HumanContext:
         self._stride_left_size = 10
         self._stride_right_size = 10
 
+        full_images, face_frames, coord_frames = load_avatar(r'./face/')
+        self._frame_list_cycle = full_images
+        self._face_list_cycle = face_frames
+        self._coord_list_cycle = coord_frames
+        face_images_length = len(self._face_list_cycle)
+        logging.info(f'face images length: {face_images_length}')
+        print(f'face images length: {face_images_length}')
+
     @property
     def fps(self):
         return self._fps
@@ -33,7 +45,7 @@ class HumanContext:
         return self._stride_right_size
 
     def build(self):
-        tts_handle = TTSAudioSplitHandle(self)
+        tts_handle = TTSAudioSplitHandle(self, None)
         tts = TTSEdge(tts_handle)
         split = PunctuationSplit()
         nlp = DouBao(split, tts)
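The build() change above threads a (context, handler) pair through each stage, the same pattern the new AudioInferenceHandler and the reworked TTS handles follow. A simplified, self-contained sketch of that chain-of-responsibility wiring follows; the class names are stand-ins, not the project's real classes.

# Simplified stand-in for the (context, handler) chaining used by AudioHandler subclasses in this commit.
class Handler:
    def __init__(self, context, handler=None):
        self._context = context
        self._handler = handler

    def on_handle(self, stream, index):
        if self._handler is not None:
            self._handler.on_handle(stream, index)


class Sink(Handler):
    def on_handle(self, stream, index):
        print(f'sink got chunk {index}: {len(stream)} bytes')


class Splitter(Handler):
    def on_handle(self, stream, index):
        # forward the chunk downstream, as TTSAudioSplitHandle would to the next handler
        super().on_handle(stream, index)


chain = Splitter(context=None, handler=Sink(context=None))
chain.on_handle(b'\x00' * 320, 0)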
@@ -2,12 +2,13 @@
 import os
 import shutil
 
-from audio import save_wav
+from utils import save_wav
 from human import AudioHandler
 
 
 class TTSAudioHandle(AudioHandler):
-    def __init__(self):
+    def __init__(self, context, handler):
+        super().__init__(context, handler)
         self._sample_rate = 16000
         self._index = 1
 
@@ -23,11 +24,13 @@ class TTSAudioHandle(AudioHandler):
         self._index = self._index + 1
         return self._index
 
+    def on_handle(self, stream, index):
+        pass
+
 
 class TTSAudioSplitHandle(TTSAudioHandle):
-    def __init__(self, context):
-        super().__init__()
-        self._context = context
+    def __init__(self, context, handler):
+        super().__init__(context, handler)
         self.sample_rate = self._context.get_audio_sample_rate()
         self._chunk = self.sample_rate // self._context.get_fps()
 
@@ -1,4 +1,6 @@
 #encoding = utf8
 
 from .async_task_queue import AsyncTaskQueue
-from .utils import mirror_index
+from .utils import mirror_index, load_model, get_device, load_avatar
+from .audio_utils import melspectrogram, save_wav
+
@@ -1,34 +1,41 @@
+#encoding = utf8
 import librosa
 import librosa.filters
 import numpy as np
-# import tensorflow as tf
 from scipy import signal
 from scipy.io import wavfile
 from hparams import hparams as hp
 import soundfile as sf
 from IPython.display import Audio
 
+
 def load_wav(path, sr):
     return librosa.core.load(path, sr=sr)[0]
 
+
 def save_wav(wav, path, sr):
     wav *= 32767 / max(0.01, np.max(np.abs(wav)))
-    #proposed by @dsmiller
+    # proposed by @dsmiller
     wavfile.write(path, sr, wav.astype(np.int16))
 
+
 def save_wavenet_wav(wav, path, sr):
     librosa.output.write_wav(path, wav, sr=sr)
 
+
 def preemphasis(wav, k, preemphasize=True):
     if preemphasize:
         return signal.lfilter([1, -k], [1], wav)
     return wav
 
+
 def inv_preemphasis(wav, k, inv_preemphasize=True):
     if inv_preemphasize:
         return signal.lfilter([1], [1, -k], wav)
     return wav
 
+
 def get_hop_size():
     hop_size = hp.hop_size
     if hop_size is None:
@@ -36,6 +43,7 @@ def get_hop_size():
         hop_size = int(hp.frame_shift_ms / 1000 * hp.sample_rate)
     return hop_size
 
+
 def linearspectrogram(wav):
     D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
     S = _amp_to_db(np.abs(D)) - hp.ref_level_db
@@ -44,6 +52,7 @@ def linearspectrogram(wav):
         return _normalize(S)
     return S
 
+
 def melspectrogram(wav):
     D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
     S = _amp_to_db(_linear_to_mel(np.abs(D))) - hp.ref_level_db
@@ -52,18 +61,21 @@ def melspectrogram(wav):
         return _normalize(S)
     return S
 
+
 def _lws_processor():
     import lws
     return lws.lws(hp.n_fft, get_hop_size(), fftsize=hp.win_size, mode="speech")
 
+
 def _stft(y):
     if hp.use_lws:
         return _lws_processor(hp).stft(y).T
     else:
         return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=get_hop_size(), win_length=hp.win_size)
 
+
 ##########################################################
-#Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
+# Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
 def num_frames(length, fsize, fshift):
     """Compute number of time frames of spectrogram
     """
@@ -83,32 +95,40 @@ def pad_lr(x, fsize, fshift):
     T = len(x) + 2 * pad
     r = (M - 1) * fshift + fsize - T
     return pad, pad + r
 
+
 ##########################################################
-#Librosa correct padding
+# Librosa correct padding
 def librosa_pad_lr(x, fsize, fshift):
     return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]
 
+
 # Conversions
 _mel_basis = None
 
+
 def _linear_to_mel(spectogram):
     global _mel_basis
     if _mel_basis is None:
         _mel_basis = _build_mel_basis()
     return np.dot(_mel_basis, spectogram)
 
+
 def _build_mel_basis():
     assert hp.fmax <= hp.sample_rate // 2
     return librosa.filters.mel(sr=hp.sample_rate, n_fft=hp.n_fft, n_mels=hp.num_mels,
                                fmin=hp.fmin, fmax=hp.fmax)
 
+
 def _amp_to_db(x):
     min_level = np.exp(hp.min_level_db / 20 * np.log(10))
     return 20 * np.log10(np.maximum(min_level, x))
 
+
 def _db_to_amp(x):
     return np.power(10.0, (x) * 0.05)
 
+
 def _normalize(S):
     if hp.allow_clipping_in_normalization:
         if hp.symmetric_mels:
@@ -123,6 +143,7 @@ def _normalize(S):
     else:
         return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db))
 
+
 def _denormalize(D):
     if hp.allow_clipping_in_normalization:
         if hp.symmetric_mels:
166 utils/utils.py
@@ -1,4 +1,17 @@
 #encoding = utf8
+import logging
+import os
+
+import cv2
+import numpy as np
+import torch
+from tqdm import tqdm
+
+import face_detection
+from models import Wav2Lip
+
+logger = logging.getLogger(__name__)
+
 
 def mirror_index(size, index):
     # size = len(self.coord_list_cycle)
@@ -8,3 +21,156 @@ def mirror_index(size, index):
         return res
     else:
         return size - res - 1
+
+
+def read_images(img_list):
+    frames = []
+    print('reading images...')
+    for img_path in tqdm(img_list):
+        print(f'read image path:{img_path}')
+        frame = cv2.imread(img_path)
+        frames.append(frame)
+    return frames
+
+
+def read_files_path(path):
+    file_paths = []
+    files = os.listdir(path)
+    for file in files:
+        if not os.path.isdir(file):
+            file_paths.append(path + file)
+    return file_paths
+
+
+def get_smoothened_boxes(boxes, t):
+    for i in range(len(boxes)):
+        if i + t > len(boxes):
+            window = boxes[len(boxes) - t:]
+        else:
+            window = boxes[i: i + t]
+        boxes[i] = np.mean(window, axis=0)
+    return boxes
+
+
+def datagen_signal(frame, mel, face_det_results, img_size, wav2lip_batch_size=128):
+    img_batch, mel_batch, frame_batch, coord_batch = [], [], [], []
+
+    idx = 0
+    frame_to_save = frame.copy()
+    face, coord = face_det_results[idx].copy()
+
+    face = cv2.resize(face, (img_size, img_size))
+
+    for i, m in enumerate(mel):
+        img_batch.append(face)
+        mel_batch.append(m)
+        frame_batch.append(frame_to_save)
+        coord_batch.append(coord)
+
+        if len(img_batch) >= wav2lip_batch_size:
+            img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
+
+            img_masked = img_batch.copy()
+            img_masked[:, img_size // 2:] = 0
+            img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
+            mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
+
+            return img_batch, mel_batch, frame_batch, coord_batch
+
+    if len(img_batch) > 0:
+        img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
+        img_masked = img_batch.copy()
+        img_masked[:, img_size//2:] = 0
+
+        img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
+        mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
+
+    return img_batch, mel_batch, frame_batch, coord_batch
+
+
+def face_detect(images, device):
+    detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
+                                            flip_input=False, device=device)
+    batch_size = 16
+
+    while 1:
+        predictions = []
+        try:
+            for i in range(0, len(images), batch_size):
+                predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
+        except RuntimeError:
+            if batch_size == 1:
+                raise RuntimeError(
+                    'Image too big to run face detection on GPU. Please use the --resize_factor argument')
+            batch_size //= 2
+            print('Recovering from OOM error; New batch size: {}'.format(batch_size))
+            continue
+        break
+
+    results = []
+    pad_y1, pad_y2, pad_x1, pad_x2 = [0, 10, 0, 0]
+    for rect, image in zip(predictions, images):
+        if rect is None:
+            cv2.imwrite('temp/faulty_frame.jpg', image)  # check this frame where the face was not detected.
+            raise ValueError('Face not detected! Ensure the video contains a face in all the frames.')
+
+        y1 = max(0, rect[1] - pad_y1)
+        y2 = min(image.shape[0], rect[3] + pad_y2)
+        x1 = max(0, rect[0] - pad_x1)
+        x2 = min(image.shape[1], rect[2] + pad_x2)
+
+        results.append([x1, y1, x2, y2])
+
+    boxes = np.array(results)
+    if not False:
+        boxes = get_smoothened_boxes(boxes, t=5)
+    results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)]
+
+    del detector
+    return results
+
+
+def get_device():
+    return 'cuda' if torch.cuda.is_available() else 'cpu'
+
+
+def _load(checkpoint_path):
+    # pick map_location from the available device
+    device = get_device()
+    if device == 'cuda':
+        checkpoint = torch.load(checkpoint_path)
+    else:
+        checkpoint = torch.load(checkpoint_path,
+                                map_location=lambda storage, loc: storage)
+    return checkpoint
+
+
+def load_model(path):
+    model = Wav2Lip()
+    print("Load checkpoint from: {}".format(path))
+    logging.info(f'Load checkpoint from {path}')
+    checkpoint = _load(path)
+    s = checkpoint["state_dict"]
+    new_s = {}
+    for k, v in s.items():
+        new_s[k.replace('module.', '')] = v
+    model.load_state_dict(new_s)
+    device = get_device()
+    model = model.to(device)
+    return model.eval()
+
+
+def load_avatar(path, img_size, device):
+    face_images_path = path
+    face_images_path = read_files_path(face_images_path)
+    full_list_cycle = read_images(face_images_path)
+
+    face_det_results = face_detect(full_list_cycle, device)
+
+    face_frames = []
+    coord_frames = []
+    for face, coord in face_det_results:
+        resized_crop_frame = cv2.resize(face, (img_size, img_size))
+        face_frames.append(resized_crop_frame)
+        coord_frames.append(coord)
+
+    return full_list_cycle, face_frames, coord_frames
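Finally, mirror_index (whose tail appears at the top of this hunk) is what lets the avatar frame lists from load_avatar loop back and forth instead of jumping from the last frame to the first. A hedged sketch of the assumed full implementation and its ping-pong output:

# Assumed implementation of mirror_index; only its else-branch tail is visible in the diff above.
def mirror_index(size, index):
    turn = index // size
    res = index % size
    if turn % 2 == 0:
        return res
    else:
        return size - res - 1

print([mirror_index(3, i) for i in range(8)])   # [0, 1, 2, 2, 1, 0, 0, 1], a ping-pong order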