add support for 256 resolution
Parent: f86368bc37
Commit: c0d6e01b23
@@ -59,7 +59,9 @@ class AudioInferenceHandler(AudioHandler):
         super().on_message(message)
 
     def __on_run(self):
-        wav2lip_path = os.path.join(current_file_path, '..', 'checkpoints', 'wav2lip.pth')
+        # wav2lip_path = os.path.join(current_file_path, '..', 'checkpoints', 'wav2lip.pth')
+        wav2lip_path = os.path.join(current_file_path, '..', 'checkpoints', 'weights', 'wav2lip',
+                                    'ema_checkpoint_step000300000.pth')
         logger.info(f'AudioInferenceHandler init, path:{wav2lip_path}')
         model = load_model(wav2lip_path)
         logger.info("Model loaded")
@@ -130,7 +132,7 @@ class AudioInferenceHandler(AudioHandler):
 
             img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
             mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
-            print('img_batch:', img_batch.shape, 'mel_batch:', mel_batch.shape)
+            # print('img_batch:', img_batch.shape, 'mel_batch:', mel_batch.shape)
             with torch.no_grad():
                 pred = model(mel_batch, img_batch)
 
@@ -79,7 +79,7 @@ class AudioMalHandler(AudioHandler):
         mel = melspectrogram(inputs)
         # print(mel.shape[0],mel.shape,len(mel[0]),len(self.frames))
         # cut off stride
-        left = max(0, self._context.stride_left_size * 80 / 50)
+        left = max(0, self._context.stride_left_size * 80 / self._context.fps)
         right = min(len(mel[0]), len(mel[0]) - self._context.stride_right_size * 80 / 50)
         mel_idx_multiplier = 80. * 2 / self._context.fps
         mel_step_size = 16
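The hard-coded division by 50 assumed the old 50 fps; with 80 mel frames per second of audio, the left stride measured in mel frames is stride_left_size * 80 / fps (note that the right bound in this hunk still divides by the hard-coded 50). A minimal sketch of the video-frame-to-mel mapping, assuming the usual 80 mel frames per second (16 kHz audio, hop size 200 samples); the function name is illustrative, not from the repo:

# Illustrative only: how a video frame index maps onto the mel spectrogram,
# assuming 80 mel frames per second (16 kHz sample rate, hop size 200 samples).
def mel_window_for_frame(frame_idx, fps=25, mel_step_size=16):
    mel_per_video_frame = 80.0 / fps          # 3.2 mel frames per video frame at 25 fps
    start = int(frame_idx * mel_per_video_frame)
    return start, start + mel_step_size       # 16-frame mel chunk fed to the model

print(mel_window_for_frame(0))    # (0, 16)
print(mel_window_for_frame(10))   # (32, 48)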
@@ -17,7 +17,7 @@ current_file_path = os.path.dirname(os.path.abspath(__file__))
 
 class HumanContext:
     def __init__(self):
-        self._fps = 50  # 20 ms per frame
+        self._fps = 25  # 40 ms per frame
         self._image_size = 288
         self._batch_size = 16
         self._sample_rate = 16000
@@ -39,7 +39,7 @@ class HumanContext:
         logger.info(f'base path:{base_path}')
         # full_images, face_frames, coord_frames = load_avatar(base_path, self._image_size, self._device)
         full_images, face_frames, coord_frames = load_avatar_from_processed(base_path,
-                                                                             'wav2lip_avatar2')
+                                                                             'wav2lip_avatar3')
         self._frame_list_cycle = full_images
         self._face_list_cycle = face_frames
         self._coord_list_cycle = coord_frames
@@ -50,7 +50,7 @@ class HumanRender(AudioHandler):
             # t = time.time()
             self._run_step()
             # delay = time.time() - t
-            delay = 0.038  # - delay
+            delay = 0.04  # - delay
             # print(delay)
             # if delay <= 0.0:
             #     continue
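The new 0.04 s sleep matches the frame interval at the new 25 fps (1/25 s). A hypothetical alternative, not in this commit, would derive the delay from the configured fps so the render loop stays in sync if the frame rate changes again:

# Hypothetical (not in the commit): compute the per-frame delay from the context's fps
# inside HumanRender's render loop, assuming it holds a _context like the other handlers.
delay = 1.0 / self._context.fps  # 0.04 s at 25 fps
time.sleep(delay)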
@@ -2,3 +2,4 @@
 
 from .wav2lip import Wav2Lip, Wav2Lip_disc_qual
 from .syncnet import SyncNet_color
+from .wav2lip_v2 import Wav2LipV2
@@ -5,6 +5,7 @@ import math
 from .conv import Conv2dTranspose, Conv2d, nonorm_Conv2d
 
 
 class Wav2Lip(nn.Module):
     def __init__(self):
         super(Wav2Lip, self).__init__()
models/wav2lip_v2.py (new file, 221 lines)
@@ -0,0 +1,221 @@
+import torch
+from torch import nn
+from .conv import Conv2dTranspose, Conv2d, nonorm_Conv2d
+
+
+class Wav2LipV2(nn.Module):
+    def __init__(self):
+        super(Wav2LipV2, self).__init__()
+
+        self.face_encoder_blocks = nn.ModuleList([
+            nn.Sequential(Conv2d(6, 16, kernel_size=7, stride=1, padding=3)),
+
+            nn.Sequential(Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
+                          Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True)),
+
+            nn.Sequential(Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
+                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True)),
+
+            nn.Sequential(Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
+                          Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True)),
+
+            nn.Sequential(Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
+                          Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True)),
+
+            nn.Sequential(Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
+                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True), ),
+
+            nn.Sequential(Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
+                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True), ),
+
+            nn.Sequential(Conv2d(512, 512, kernel_size=4, stride=1, padding=0),
+                          Conv2d(512, 512, kernel_size=1, stride=1, padding=0)), ])
+
+        self.audio_encoder = nn.Sequential(
+            Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
+            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+
+            Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
+            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+
+            Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
+            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+
+            Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
+            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+
+            Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
+            Conv2d(512, 512, kernel_size=1, stride=1, padding=0), )
+
+        self.face_decoder_blocks = nn.ModuleList([
+            nn.Sequential(Conv2d(512, 512, kernel_size=1, stride=1, padding=0), ),
+
+            nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=4, stride=1, padding=0),
+                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True), ),
+
+            nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=2, padding=1, output_padding=1),
+                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True), ),
+
+            nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=2, padding=1, output_padding=1),
+                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True), ),
+
+            nn.Sequential(Conv2dTranspose(768, 384, kernel_size=3, stride=2, padding=1, output_padding=1),
+                          Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True), ),
+
+            nn.Sequential(Conv2dTranspose(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1),
+                          Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), ),
+
+            nn.Sequential(Conv2dTranspose(320, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
+                          Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), ),
+
+            nn.Sequential(Conv2dTranspose(160, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
+                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), ), ])
+
+        self.output_block = nn.Sequential(Conv2d(80, 32, kernel_size=3, stride=1, padding=1),
+                                          nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0),
+                                          nn.Sigmoid())
+
+    def audio_forward(self, audio_sequences, a_alpha=1.):
+        audio_embedding = self.audio_encoder(audio_sequences)  # B, 512, 1, 1
+        if a_alpha != 1.:
+            audio_embedding *= a_alpha
+        return audio_embedding
+
+    def inference(self, audio_embedding, face_sequences):
+        feats = []
+        x = face_sequences
+        for f in self.face_encoder_blocks:
+            x = f(x)
+            feats.append(x)
+
+        x = audio_embedding
+        for f in self.face_decoder_blocks:
+            x = f(x)
+            try:
+                x = torch.cat((x, feats[-1]), dim=1)
+            except Exception as e:
+                print(x.size())
+                print(feats[-1].size())
+                raise e
+
+            feats.pop()
+
+        x = self.output_block(x)
+        outputs = x
+
+        return outputs
+
+    def forward(self, audio_sequences, face_sequences, a_alpha=1.):
+        # audio_sequences = (B, T, 1, 80, 16)
+        B = audio_sequences.size(0)
+
+        input_dim_size = len(face_sequences.size())
+        if input_dim_size > 4:
+            audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)  # [bz, 5, 1, 80, 16] -> [bz*5, 1, 80, 16]
+            face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0)  # [bz, 6, 5, 256, 256] -> [bz*5, 6, 256, 256]
+
+        audio_embedding = self.audio_encoder(audio_sequences)  # [bz*5, 1, 80, 16] -> [bz*5, 512, 1, 1]
+        if a_alpha != 1.:
+            audio_embedding *= a_alpha  # amplify the audio embedding
+
+        feats = []
+        x = face_sequences
+        for f in self.face_encoder_blocks:
+            x = f(x)
+            feats.append(x)
+
+        x = audio_embedding
+        for f in self.face_decoder_blocks:
+            x = f(x)
+            try:
+                x = torch.cat((x, feats[-1]), dim=1)
+            except Exception as e:
+                print(x.size())
+                print(feats[-1].size())
+                raise e
+
+            feats.pop()
+
+        x = self.output_block(x)  # [bz*5, 80, 256, 256] -> [bz*5, 3, 256, 256]
+
+        if input_dim_size > 4:  # [bz*5, 3, 256, 256] -> [B, 3, 5, 256, 256]
+            x = torch.split(x, B, dim=0)
+            outputs = torch.stack(x, dim=2)
+
+        else:
+            outputs = x
+
+        return outputs
+
+
+class Wav2Lip_disc_qual(nn.Module):
+    def __init__(self):
+        super(Wav2Lip_disc_qual, self).__init__()
+
+        self.face_encoder_blocks = nn.ModuleList([
+            nn.Sequential(nonorm_Conv2d(3, 32, kernel_size=7, stride=1, padding=3)),
+
+            nn.Sequential(nonorm_Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=2),
+                          nonorm_Conv2d(64, 64, kernel_size=5, stride=1, padding=2)),
+
+            nn.Sequential(nonorm_Conv2d(64, 128, kernel_size=5, stride=2, padding=2),
+                          nonorm_Conv2d(128, 128, kernel_size=5, stride=1, padding=2)),
+
+            nn.Sequential(nonorm_Conv2d(128, 256, kernel_size=5, stride=2, padding=2),
+                          nonorm_Conv2d(256, 256, kernel_size=5, stride=1, padding=2)),
+
+            nn.Sequential(nonorm_Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
+                          nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=1)),
+
+            nn.Sequential(nonorm_Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
+                          nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=1), ),
+
+            nn.Sequential(nonorm_Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
+                          nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=1), ),
+
+            nn.Sequential(nonorm_Conv2d(512, 512, kernel_size=4, stride=1, padding=0),
+                          nonorm_Conv2d(512, 512, kernel_size=1, stride=1, padding=0)), ])
+
+        self.binary_pred = nn.Sequential(nn.Conv2d(512, 1, kernel_size=1, stride=1, padding=0), nn.Sigmoid())
+        self.label_noise = .0
+
+    def get_lower_half(self, face_sequences):  # take the lower half of the input images
+        return face_sequences[:, :, face_sequences.size(2) // 2:]
+
+    def to_2d(self, face_sequences):  # flatten the image sequence into one batched tensor
+        B = face_sequences.size(0)
+        face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0)
+        return face_sequences
+
+    def perceptual_forward(self, false_face_sequences):  # forward pass on generated images
+        false_face_sequences = self.to_2d(false_face_sequences)  # [bz, 3, 5, 256, 256] -> [bz*5, 3, 256, 256]
+        false_face_sequences = self.get_lower_half(false_face_sequences)  # [bz*5, 3, 256, 256] -> [bz*5, 3, 128, 256]
+
+        false_feats = false_face_sequences
+        for f in self.face_encoder_blocks:  # [bz*5, 3, 128, 256] -> [bz*5, 512, 1, 1]
+            false_feats = f(false_feats)
+
+        return self.binary_pred(false_feats).view(len(false_feats), -1)  # [bz*5, 512, 1, 1] -> [bz*5, 1, 1]
+
+    def forward(self, face_sequences):  # forward pass on ground-truth images
+        face_sequences = self.to_2d(face_sequences)
+        face_sequences = self.get_lower_half(face_sequences)
+
+        x = face_sequences
+        for f in self.face_encoder_blocks:
+            x = f(x)
+
+        return self.binary_pred(x).view(len(x), -1)
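For reference, a minimal smoke test of the new 256x256 path (illustrative only, not part of the commit; shapes follow the comments in wav2lip_v2.py above, and the import assumes the repo's models package is on the path):

# Illustrative shape check for Wav2LipV2 (not part of the commit).
import torch
from models import Wav2LipV2

model = Wav2LipV2().eval()
mel = torch.randn(1, 5, 1, 80, 16)       # batch of 5 mel chunks, 80 x 16 each
faces = torch.randn(1, 6, 5, 256, 256)   # masked + reference frames at 256 x 256
with torch.no_grad():
    out = model(mel, faces)
print(out.shape)  # torch.Size([1, 3, 5, 256, 256])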
@@ -11,7 +11,7 @@ from tqdm import tqdm
 from PIL import Image
 
 import face_detection
-from models import Wav2Lip
+from models import Wav2Lip, Wav2LipV2
 
 logger = logging.getLogger(__name__)
 
@@ -144,7 +144,7 @@ def get_device():
 
 
 def _load(checkpoint_path):
-    device = get_device
+    device = get_device()
     if device == 'cuda':
         checkpoint = torch.load(checkpoint_path)
     else:
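Without the parentheses, device was bound to the function object itself, so device == 'cuda' was always False and the CPU branch always ran. A sketch of the corrected helper, assuming get_device() returns 'cuda' or 'cpu' and that the else branch follows the stock Wav2Lip loader (only the lines above appear in this diff):

# Sketch, not the repo's exact code beyond the lines shown in the diff.
def _load(checkpoint_path):
    device = get_device()  # call the function; comparing the function object to 'cuda' was always False
    if device == 'cuda':
        checkpoint = torch.load(checkpoint_path)
    else:
        checkpoint = torch.load(checkpoint_path,
                                map_location=lambda storage, loc: storage)
    return checkpoint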
@@ -154,7 +154,7 @@ def _load(checkpoint_path):
 
 
 def load_model(path):
-    model = Wav2Lip()
+    model = Wav2LipV2()
     print("Load checkpoint from: {}".format(path))
     logging.info(f'Load checkpoint from {path}')
    checkpoint = _load(path)
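The rest of load_model is not shown in this diff; a hedged sketch of the full loader, assuming the standard Wav2Lip checkpoint layout (a "state_dict" key with optional "module." prefixes from DataParallel):

# Hedged sketch (assumed, not shown in the diff): how load_model typically finishes.
def load_model(path):
    model = Wav2LipV2()
    logging.info(f'Load checkpoint from {path}')
    checkpoint = _load(path)
    state = {k.replace('module.', ''): v for k, v in checkpoint["state_dict"].items()}
    model.load_state_dict(state)
    return model.to(get_device()).eval()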