human/wav2lip.py

36 lines
1010 B
Python
Raw Permalink Normal View History

2024-09-02 00:13:34 +00:00
# -*- coding: utf-8 -*-
import torch
from models import Wav2Lip
def _load(checkpoint_path, device):
if device == 'cuda':
checkpoint = torch.load(checkpoint_path)
else:
checkpoint = torch.load(checkpoint_path,
map_location=lambda storage, loc: storage)
return checkpoint
class Wav2LipModel:
    """Build a ``Wav2Lip`` network and populate it from a checkpoint file.

    On construction the class picks ``'cuda'`` when
    ``torch.cuda.is_available()`` reports a GPU and ``'cpu'`` otherwise,
    loads the checkpoint, and stores the network (switched to evaluation
    mode) in ``self._model``.
    """

    def __init__(self, check_points):
        """Select the inference device and load the model.

        Args:
            check_points: Path of the checkpoint file handed to ``_load``.
        """
        self.__checkpoints = check_points
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print('Using {} for inference.'.format(device))
        self._model = self.__load_model(device)

    def __load_model(self, device):
        """Create the network, load its weights and move it to ``device``.

        Args:
            device: Device name (``'cuda'`` or ``'cpu'``) to load onto.

        Returns:
            The ``Wav2Lip`` model on ``device``, in evaluation mode.

        Raises:
            KeyError: If the checkpoint has no ``"state_dict"`` entry.
        """
        model = Wav2Lip()
        print("Load checkpoint from: {}".format(self.__checkpoints))
        # The device must be forwarded: _load requires it to decide whether
        # tensors have to be remapped onto the CPU. Omitting it raised
        # TypeError on every construction.
        checkpoint = _load(self.__checkpoints, device)
        state_dict = checkpoint["state_dict"]
        # Drop the 'module.' fragment from parameter names so they match the
        # bare model's keys (presumably a torch.nn.DataParallel wrapper added
        # it when the checkpoint was saved -- confirm against training code).
        renamed = {
            key.replace('module.', ''): value
            for key, value in state_dict.items()
        }
        model.load_state_dict(renamed)
        model = model.to(device)
        return model.eval()