From c0d6e01b23e96b3bb40fed61e7d8ae07a30dec75 Mon Sep 17 00:00:00 2001
From: jiegeaiai
Date: Thu, 21 Nov 2024 00:30:15 +0800
Subject: [PATCH] add support 256

---
 human/audio_inference_handler.py |   6 +-
 human/audio_mal_handler.py       |   2 +-
 human/human_context.py           |   4 +-
 human/human_render.py            |   2 +-
 models/__init__.py               |   1 +
 models/wav2lip.py                |   1 +
 models/wav2lip_v2.py             | 221 +++++++++++++++++++++++++++++++
 utils/utils.py                   |   6 +-
 8 files changed, 234 insertions(+), 9 deletions(-)
 create mode 100644 models/wav2lip_v2.py

diff --git a/human/audio_inference_handler.py b/human/audio_inference_handler.py
index 6c78c84..82a8d51 100644
--- a/human/audio_inference_handler.py
+++ b/human/audio_inference_handler.py
@@ -59,7 +59,9 @@ class AudioInferenceHandler(AudioHandler):
         super().on_message(message)
 
     def __on_run(self):
-        wav2lip_path = os.path.join(current_file_path, '..', 'checkpoints', 'wav2lip.pth')
+        # wav2lip_path = os.path.join(current_file_path, '..', 'checkpoints', 'wav2lip.pth')
+        wav2lip_path = os.path.join(current_file_path, '..', 'checkpoints', 'weights', 'wav2lip',
+                                    'ema_checkpoint_step000300000.pth')
         logger.info(f'AudioInferenceHandler init, path:{wav2lip_path}')
         model = load_model(wav2lip_path)
         logger.info("Model loaded")
@@ -130,7 +132,7 @@ class AudioInferenceHandler(AudioHandler):
                 img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
                 mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
 
-                print('img_batch:', img_batch.shape, 'mel_batch:', mel_batch.shape)
+                # print('img_batch:', img_batch.shape, 'mel_batch:', mel_batch.shape)
                 with torch.no_grad():
                     pred = model(mel_batch, img_batch)
diff --git a/human/audio_mal_handler.py b/human/audio_mal_handler.py
index 53fc14b..679f27d 100644
--- a/human/audio_mal_handler.py
+++ b/human/audio_mal_handler.py
@@ -79,7 +79,7 @@ class AudioMalHandler(AudioHandler):
         mel = melspectrogram(inputs)
         # print(mel.shape[0],mel.shape,len(mel[0]),len(self.frames))
         # cut off stride
-        left = max(0, self._context.stride_left_size * 80 / 50)
+        left = max(0, self._context.stride_left_size * 80 / self._context.fps)
         right = min(len(mel[0]), len(mel[0]) - self._context.stride_right_size * 80 / 50)
         mel_idx_multiplier = 80. * 2 / self._context.fps
         mel_step_size = 16
diff --git a/human/human_context.py b/human/human_context.py
index 40eee13..0f1bc1f 100644
--- a/human/human_context.py
+++ b/human/human_context.py
@@ -17,7 +17,7 @@ current_file_path = os.path.dirname(os.path.abspath(__file__))
 
 class HumanContext:
     def __init__(self):
-        self._fps = 50  # 20 ms per frame
+        self._fps = 25  # 40 ms per frame
         self._image_size = 288
         self._batch_size = 16
         self._sample_rate = 16000
@@ -39,7 +39,7 @@ class HumanContext:
         logger.info(f'base path:{base_path}')
         # full_images, face_frames, coord_frames = load_avatar(base_path, self._image_size, self._device)
         full_images, face_frames, coord_frames = load_avatar_from_processed(base_path,
-                                                                             'wav2lip_avatar2')
+                                                                             'wav2lip_avatar3')
         self._frame_list_cycle = full_images
         self._face_list_cycle = face_frames
         self._coord_list_cycle = coord_frames
diff --git a/human/human_render.py b/human/human_render.py
index 87c1666..5334b35 100644
--- a/human/human_render.py
+++ b/human/human_render.py
@@ -50,7 +50,7 @@ class HumanRender(AudioHandler):
             # t = time.time()
            self._run_step()
             # delay = time.time() - t
-            delay = 0.038  # - delay
+            delay = 0.04  # - delay
             # print(delay)
             # if delay <= 0.0:
             #     continue
diff --git a/models/__init__.py b/models/__init__.py
index 1c144be..8b8b3be 100644
--- a/models/__init__.py
+++ b/models/__init__.py
@@ -2,3 +2,4 @@
 from .wav2lip import Wav2Lip, Wav2Lip_disc_qual
 from .syncnet import SyncNet_color
+from .wav2lip_v2 import Wav2LipV2
diff --git a/models/wav2lip.py b/models/wav2lip.py
index ae5d691..0a4fb78 100644
--- a/models/wav2lip.py
+++ b/models/wav2lip.py
@@ -5,6 +5,7 @@ import math
 
 from .conv import Conv2dTranspose, Conv2d, nonorm_Conv2d
 
+
 class Wav2Lip(nn.Module):
     def __init__(self):
         super(Wav2Lip, self).__init__()
diff --git a/models/wav2lip_v2.py b/models/wav2lip_v2.py
new file mode 100644
index 0000000..bf6cbd1
--- /dev/null
+++ b/models/wav2lip_v2.py
@@ -0,0 +1,221 @@
+import torch
+from torch import nn
+from .conv import Conv2dTranspose, Conv2d, nonorm_Conv2d
+
+
+class Wav2LipV2(nn.Module):
+    def __init__(self):
+        super(Wav2LipV2, self).__init__()
+
+        self.face_encoder_blocks = nn.ModuleList([
+            nn.Sequential(Conv2d(6, 16, kernel_size=7, stride=1, padding=3)),
+
+            nn.Sequential(Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
+                          Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True)),
+
+            nn.Sequential(Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
+                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True)),
+
+            nn.Sequential(Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
+                          Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True)),
+
+            nn.Sequential(Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
+                          Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True)),
+
+            nn.Sequential(Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
+                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True), ),
+
+            nn.Sequential(Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
+                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True), ),
+
+            nn.Sequential(Conv2d(512, 512, kernel_size=4, stride=1, padding=0),
+                          Conv2d(512, 512, kernel_size=1, stride=1, padding=0)), ])
+
+        self.audio_encoder = nn.Sequential(
+            Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
+            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+
+            Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
+            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+
+            Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
+            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+
+            Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
+            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+
+            Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
+            Conv2d(512, 512, kernel_size=1, stride=1, padding=0), )
+
+        self.face_decoder_blocks = nn.ModuleList([
+            nn.Sequential(Conv2d(512, 512, kernel_size=1, stride=1, padding=0), ),
+
+            nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=4, stride=1, padding=0),
+                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True), ),
+
+            nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=2, padding=1, output_padding=1),
+                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True), ),
+
+            nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=2, padding=1, output_padding=1),
+                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True), ),
+
+            nn.Sequential(Conv2dTranspose(768, 384, kernel_size=3, stride=2, padding=1, output_padding=1),
+                          Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True), ),
+
+            nn.Sequential(Conv2dTranspose(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1),
+                          Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), ),
+
+            nn.Sequential(Conv2dTranspose(320, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
+                          Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), ),
+
+            nn.Sequential(Conv2dTranspose(160, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
+                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), ), ])
+
+        self.output_block = nn.Sequential(Conv2d(80, 32, kernel_size=3, stride=1, padding=1),
+                                          nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0),
+                                          nn.Sigmoid())
+
+    def audio_forward(self, audio_sequences, a_alpha=1.):
+        audio_embedding = self.audio_encoder(audio_sequences)  # B, 512, 1, 1
+        if a_alpha != 1.:
+            audio_embedding *= a_alpha
+        return audio_embedding
+
+    def inference(self, audio_embedding, face_sequences):
+        feats = []
+        x = face_sequences
+        for f in self.face_encoder_blocks:
+            x = f(x)
+            feats.append(x)
+
+        x = audio_embedding
+        for f in self.face_decoder_blocks:
+            x = f(x)
+            try:
+                x = torch.cat((x, feats[-1]), dim=1)
+            except Exception as e:
+                print(x.size())
+                print(feats[-1].size())
+                raise e
+
+            feats.pop()
+
+        x = self.output_block(x)
+        outputs = x
+
+        return outputs
+
+    def forward(self, audio_sequences, face_sequences, a_alpha=1.):
+        # audio_sequences = (B, T, 1, 80, 16)
+        B = audio_sequences.size(0)
+
+        input_dim_size = len(face_sequences.size())
+        if input_dim_size > 4:
+            audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)  # [bz, 5, 1, 80, 16] -> [bz*5, 1, 80, 16]
+            face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0)  # [bz, 6, 5, 256, 256] -> [bz*5, 6, 256, 256]
+
+        audio_embedding = self.audio_encoder(audio_sequences)  # [bz*5, 1, 80, 16] -> [bz*5, 512, 1, 1]
+        if a_alpha != 1.:
+            audio_embedding *= a_alpha  # amplify the audio embedding strength
+
+        feats = []
+        x = face_sequences
+        for f in self.face_encoder_blocks:
+            x = f(x)
+            feats.append(x)
+
+        x = audio_embedding
+        for f in self.face_decoder_blocks:
+            x = f(x)
+            try:
+                x = torch.cat((x, feats[-1]), dim=1)
+            except Exception as e:
+                print(x.size())
+                print(feats[-1].size())
+                raise e
+
+            feats.pop()
+
+        x = self.output_block(x)  # [bz*5, 80, 256, 256] -> [bz*5, 3, 256, 256]
+
+        if input_dim_size > 4:  # [bz*5, 3, 256, 256] -> [B, 3, 5, 256, 256]
+            x = torch.split(x, B, dim=0)
+            outputs = torch.stack(x, dim=2)
+
+        else:
+            outputs = x
+
+        return outputs
+
+
+class Wav2Lip_disc_qual(nn.Module):
+    def __init__(self):
+        super(Wav2Lip_disc_qual, self).__init__()
+
+        self.face_encoder_blocks = nn.ModuleList([
+            nn.Sequential(nonorm_Conv2d(3, 32, kernel_size=7, stride=1, padding=3)),
+
+            nn.Sequential(nonorm_Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=2),
+                          nonorm_Conv2d(64, 64, kernel_size=5, stride=1, padding=2)),
+
+            nn.Sequential(nonorm_Conv2d(64, 128, kernel_size=5, stride=2, padding=2),
+                          nonorm_Conv2d(128, 128, kernel_size=5, stride=1, padding=2)),
+
+            nn.Sequential(nonorm_Conv2d(128, 256, kernel_size=5, stride=2, padding=2),
+                          nonorm_Conv2d(256, 256, kernel_size=5, stride=1, padding=2)),
+
+            nn.Sequential(nonorm_Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
+                          nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=1)),
+
+            nn.Sequential(nonorm_Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
+                          nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=1), ),
+
+            nn.Sequential(nonorm_Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
+                          nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=1), ),
+
+            nn.Sequential(nonorm_Conv2d(512, 512, kernel_size=4, stride=1, padding=0),
+                          nonorm_Conv2d(512, 512, kernel_size=1, stride=1, padding=0)), ])
+
+        self.binary_pred = nn.Sequential(nn.Conv2d(512, 1, kernel_size=1, stride=1, padding=0), nn.Sigmoid())
+        self.label_noise = .0
+
+    def get_lower_half(self, face_sequences):  # take the lower half of the input images
+        return face_sequences[:, :, face_sequences.size(2) // 2:]
+
+    def to_2d(self, face_sequences):  # concatenate the image sequence into a single 2D batch tensor
+        B = face_sequences.size(0)
+        face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0)
+        return face_sequences
+
+    def perceptual_forward(self, false_face_sequences):  # forward pass on generated (fake) images
+        false_face_sequences = self.to_2d(false_face_sequences)  # [bz, 3, 5, 256, 256] -> [bz*5, 3, 256, 256]
+        false_face_sequences = self.get_lower_half(false_face_sequences)  # [bz*5, 3, 256, 256] -> [bz*5, 3, 128, 256]
+
+        false_feats = false_face_sequences
+        for f in self.face_encoder_blocks:  # [bz*5, 3, 128, 256] -> [bz*5, 512, 1, 1]
+            false_feats = f(false_feats)
+
+        return self.binary_pred(false_feats).view(len(false_feats), -1)  # [bz*5, 512, 1, 1] -> [bz*5, 1, 1]
+
+    def forward(self, face_sequences):  # forward pass on ground-truth (real) images
+        face_sequences = self.to_2d(face_sequences)
+        face_sequences = self.get_lower_half(face_sequences)
+
+        x = face_sequences
+        for f in self.face_encoder_blocks:
+            x = f(x)
+
+        return self.binary_pred(x).view(len(x), -1)
diff --git a/utils/utils.py b/utils/utils.py
index 0a64429..174f74e 100644
--- a/utils/utils.py
+++ b/utils/utils.py
@@ -11,7 +11,7 @@ from tqdm import tqdm
 from PIL import Image
 
 import face_detection
-from models import Wav2Lip
+from models import Wav2Lip, Wav2LipV2
 
 logger = logging.getLogger(__name__)
 
@@ -144,7 +144,7 @@ def get_device():
 
 
 def _load(checkpoint_path):
-    device = get_device
+    device = get_device()
     if device == 'cuda':
         checkpoint = torch.load(checkpoint_path)
     else:
@@ -154,7 +154,7 @@
 
 
 def load_model(path):
-    model = Wav2Lip()
+    model = Wav2LipV2()
     print("Load checkpoint from: {}".format(path))
     logging.info(f'Load checkpoint from {path}')
    checkpoint = _load(path)
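
Below is a minimal shape-check sketch for the new 256x256 model. It is not part of the patch; it assumes the models package from this change is importable, uses random weights, and only verifies tensor shapes. The 6-channel face input is the usual Wav2Lip masked-frame plus reference-frame concatenation, and batch size 16 matches HumanContext.

import torch

from models import Wav2LipV2  # exported via models/__init__.py in this patch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = Wav2LipV2().to(device).eval()

# Random stand-ins with the shapes AudioInferenceHandler feeds the model:
# (B, 6, 256, 256) face crops and (B, 1, 80, 16) mel chunks.
img_batch = torch.rand(16, 6, 256, 256, device=device)
mel_batch = torch.rand(16, 1, 80, 16, device=device)

with torch.no_grad():
    pred = model(mel_batch, img_batch)

print(pred.shape)  # torch.Size([16, 3, 256, 256]), sigmoid output in [0, 1]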