# human/models/wav2lip_v2.py
import torch
from torch import nn

from .conv import Conv2dTranspose, Conv2d, nonorm_Conv2d


class Wav2LipV2(nn.Module):
    def __init__(self):
        super(Wav2LipV2, self).__init__()

        # Face encoder: the 6-channel input is two RGB face images concatenated
        # channel-wise (in Wav2Lip, a masked target frame plus a reference
        # frame), downsampled from 256x256 to a 1x1 feature map.
        self.face_encoder_blocks = nn.ModuleList([
            nn.Sequential(Conv2d(6, 16, kernel_size=7, stride=1, padding=3)),    # 256x256

            nn.Sequential(Conv2d(16, 32, kernel_size=3, stride=2, padding=1),    # 128x128
                          Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True)),

            nn.Sequential(Conv2d(32, 64, kernel_size=3, stride=2, padding=1),    # 64x64
                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True)),

            nn.Sequential(Conv2d(64, 128, kernel_size=3, stride=2, padding=1),   # 32x32
                          Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True)),

            nn.Sequential(Conv2d(128, 256, kernel_size=3, stride=2, padding=1),  # 16x16
                          Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True)),

            nn.Sequential(Conv2d(256, 512, kernel_size=3, stride=2, padding=1),  # 8x8
                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True)),

            nn.Sequential(Conv2d(512, 512, kernel_size=3, stride=2, padding=1),  # 4x4
                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True)),

            nn.Sequential(Conv2d(512, 512, kernel_size=4, stride=1, padding=0),  # 1x1
                          Conv2d(512, 512, kernel_size=1, stride=1, padding=0))])

        # Audio encoder: reduces a (1, 80, 16) mel-spectrogram chunk to a
        # (512, 1, 1) embedding.
        self.audio_encoder = nn.Sequential(
            Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
            Conv2d(512, 512, kernel_size=1, stride=1, padding=0))

        # U-Net decoder: starts from the audio embedding; after each block the
        # output is concatenated with the matching face-encoder feature map.
        self.face_decoder_blocks = nn.ModuleList([
            nn.Sequential(Conv2d(512, 512, kernel_size=1, stride=1, padding=0)),

            nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=4, stride=1, padding=0),  # 4x4
                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True)),

            nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=2, padding=1, output_padding=1),  # 8x8
                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True)),

            nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=2, padding=1, output_padding=1),  # 16x16
                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True)),

            nn.Sequential(Conv2dTranspose(768, 384, kernel_size=3, stride=2, padding=1, output_padding=1),  # 32x32
                          Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True)),

            nn.Sequential(Conv2dTranspose(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1),  # 64x64
                          Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True)),

            nn.Sequential(Conv2dTranspose(320, 128, kernel_size=3, stride=2, padding=1, output_padding=1),  # 128x128
                          Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True)),

            nn.Sequential(Conv2dTranspose(160, 64, kernel_size=3, stride=2, padding=1, output_padding=1),  # 256x256
                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True))])
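
        # Skip-connection channel arithmetic, derived from the layers above:
        # each decoder block's input width equals the previous decoder output
        # plus the matching encoder feature: 512+512=1024 (three times),
        # 512+256=768, 384+128=512, 256+64=320, 128+32=160; the final
        # concatenation (64+16=80) feeds output_block below.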

        # Project the fused 80-channel features to RGB; Sigmoid keeps pixel
        # values in [0, 1].
        self.output_block = nn.Sequential(Conv2d(80, 32, kernel_size=3, stride=1, padding=1),
                                          nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0),
                                          nn.Sigmoid())

    def audio_forward(self, audio_sequences, a_alpha=1.):
        audio_embedding = self.audio_encoder(audio_sequences)  # (B, 512, 1, 1)
        if a_alpha != 1.:
            audio_embedding *= a_alpha  # scale the audio embedding's influence
        return audio_embedding

    def inference(self, audio_embedding, face_sequences):
        # Encode the face, keeping each intermediate feature map for the skip
        # connections.
        feats = []
        x = face_sequences
        for f in self.face_encoder_blocks:
            x = f(x)
            feats.append(x)

        # Decode from the audio embedding, consuming the skip features in
        # reverse order.
        x = audio_embedding
        for f in self.face_decoder_blocks:
            x = f(x)
            try:
                x = torch.cat((x, feats[-1]), dim=1)
            except Exception as e:
                print(x.size())
                print(feats[-1].size())
                raise e
            feats.pop()

        outputs = self.output_block(x)
        return outputs
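
    # `audio_forward` and `inference` appear to split `forward` into two
    # stages so the audio embedding can be computed once per mel chunk and
    # reused. A sketch (variable names are illustrative, not from this file):
    #     emb = model.audio_forward(mel_chunk)      # (N, 1, 80, 16) -> (N, 512, 1, 1)
    #     frame = model.inference(emb, face_batch)  # -> (N, 3, 256, 256)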

    def forward(self, audio_sequences, face_sequences, a_alpha=1.):
        # audio_sequences: (B, T, 1, 80, 16); face_sequences: (B, 6, T, 256, 256)
        B = audio_sequences.size(0)
        input_dim_size = len(face_sequences.size())
        if input_dim_size > 4:
            # Fold the time dimension into the batch dimension:
            # (B, T, 1, 80, 16) -> (B*T, 1, 80, 16)
            audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
            # (B, 6, T, 256, 256) -> (B*T, 6, 256, 256)
            face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0)

        audio_embedding = self.audio_encoder(audio_sequences)  # (B*T, 1, 80, 16) -> (B*T, 512, 1, 1)
        if a_alpha != 1.:
            audio_embedding *= a_alpha  # amplify the audio's influence

        feats = []
        x = face_sequences
        for f in self.face_encoder_blocks:
            x = f(x)
            feats.append(x)

        x = audio_embedding
        for f in self.face_decoder_blocks:
            x = f(x)
            try:
                x = torch.cat((x, feats[-1]), dim=1)
            except Exception as e:
                print(x.size())
                print(feats[-1].size())
                raise e
            feats.pop()

        x = self.output_block(x)  # (B*T, 80, 256, 256) -> (B*T, 3, 256, 256)

        if input_dim_size > 4:
            # Unfold back into a video tensor: (B*T, 3, 256, 256) -> (B, 3, T, 256, 256)
            x = torch.split(x, B, dim=0)
            outputs = torch.stack(x, dim=2)
        else:
            outputs = x
        return outputs


class Wav2Lip_disc_qual(nn.Module):
    def __init__(self):
        super(Wav2Lip_disc_qual, self).__init__()

        # Visual-quality discriminator: scores the lower half (mouth region) of
        # each frame, downsampling a 128x256 crop to a 1x1 feature map.
        self.face_encoder_blocks = nn.ModuleList([
            nn.Sequential(nonorm_Conv2d(3, 32, kernel_size=7, stride=1, padding=3)),       # 128x256

            nn.Sequential(nonorm_Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=2),  # 128x128
                          nonorm_Conv2d(64, 64, kernel_size=5, stride=1, padding=2)),

            nn.Sequential(nonorm_Conv2d(64, 128, kernel_size=5, stride=2, padding=2),      # 64x64
                          nonorm_Conv2d(128, 128, kernel_size=5, stride=1, padding=2)),

            nn.Sequential(nonorm_Conv2d(128, 256, kernel_size=5, stride=2, padding=2),     # 32x32
                          nonorm_Conv2d(256, 256, kernel_size=5, stride=1, padding=2)),

            nn.Sequential(nonorm_Conv2d(256, 512, kernel_size=3, stride=2, padding=1),     # 16x16
                          nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=1)),

            nn.Sequential(nonorm_Conv2d(512, 512, kernel_size=3, stride=2, padding=1),     # 8x8
                          nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=1)),

            nn.Sequential(nonorm_Conv2d(512, 512, kernel_size=3, stride=2, padding=1),     # 4x4
                          nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=1)),

            nn.Sequential(nonorm_Conv2d(512, 512, kernel_size=4, stride=1, padding=0),     # 1x1
                          nonorm_Conv2d(512, 512, kernel_size=1, stride=1, padding=0))])

        self.binary_pred = nn.Sequential(nn.Conv2d(512, 1, kernel_size=1, stride=1, padding=0),
                                         nn.Sigmoid())
        self.label_noise = 0.0

    def get_lower_half(self, face_sequences):
        # Keep the lower half (mouth region) of each frame.
        return face_sequences[:, :, face_sequences.size(2) // 2:]

    def to_2d(self, face_sequences):
        # Fold the time dimension into the batch dimension:
        # (B, 3, T, H, W) -> (B*T, 3, H, W)
        return torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0)

    def perceptual_forward(self, false_face_sequences):
        # Score generated frames.
        false_face_sequences = self.to_2d(false_face_sequences)           # (B, 3, T, 256, 256) -> (B*T, 3, 256, 256)
        false_face_sequences = self.get_lower_half(false_face_sequences)  # (B*T, 3, 256, 256) -> (B*T, 3, 128, 256)

        false_feats = false_face_sequences
        for f in self.face_encoder_blocks:  # (B*T, 3, 128, 256) -> (B*T, 512, 1, 1)
            false_feats = f(false_feats)

        return self.binary_pred(false_feats).view(len(false_feats), -1)  # -> (B*T, 1)

    def forward(self, face_sequences):
        # Score ground-truth frames.
        face_sequences = self.to_2d(face_sequences)
        face_sequences = self.get_lower_half(face_sequences)

        x = face_sequences
        for f in self.face_encoder_blocks:
            x = f(x)

        return self.binary_pred(x).view(len(x), -1)
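

if __name__ == "__main__":
    # Minimal smoke test: a sketch, not part of the original module. It assumes
    # the shapes documented in the comments above (256x256 face crops,
    # (1, 80, 16) mel chunks, T = 5 frames per sample). Because of the relative
    # import of `.conv`, run it as a module from the package root, e.g.
    # `python -m human.models.wav2lip_v2` (the exact path depends on layout).
    B, T = 2, 5
    audio = torch.randn(B, T, 1, 80, 16)
    faces = torch.randn(B, 6, T, 256, 256)

    model = Wav2LipV2().eval()
    with torch.no_grad():
        out = model(audio, faces)
    print(out.shape)    # expected: torch.Size([2, 3, 5, 256, 256])

    disc = Wav2Lip_disc_qual().eval()
    with torch.no_grad():
        score = disc(out)
    print(score.shape)  # expected: torch.Size([10, 1])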