# coding: utf-8
"""Load a Wav2Lip checkpoint into a model instance prepared for inference."""
import torch

from models import Wav2Lip


def _load(checkpoint_path, device):
    """Deserialize the checkpoint stored at ``checkpoint_path``.

    Args:
        checkpoint_path: Location of the checkpoint, passed to ``torch.load``.
        device: ``'cuda'`` loads with torch's default tensor placement; any
            other value keeps every storage on the CPU.

    Returns:
        The object produced by ``torch.load``.
    """
    # NOTE(security): torch.load unpickles the file, so only checkpoints from
    # a trusted source should ever be passed in here.
    if device == 'cuda':
        return torch.load(checkpoint_path)
    # Returning the storage unchanged from map_location leaves it on the CPU,
    # which lets a checkpoint be opened on a machine without CUDA.
    return torch.load(checkpoint_path,
                      map_location=lambda storage, loc: storage)


class Wav2LipModel:
    """Wav2Lip network restored from a checkpoint and switched to eval mode.

    The loaded network is stored on ``self._model``.
    """

    def __init__(self, check_points):
        """Pick the inference device and load the model from ``check_points``.

        Args:
            check_points: Path of the checkpoint file to restore.
        """
        self.__checkpoints = check_points
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print('Using {} for inference.'.format(device))
        self._model = self.__load_model(device)

    def __load_model(self, device):
        """Build a ``Wav2Lip`` network and populate it from the checkpoint.

        Args:
            device: ``'cuda'`` or ``'cpu'``; used both for deserializing the
                checkpoint and as the target device of the model.

        Returns:
            The model moved to ``device`` and put in evaluation mode.
        """
        model = Wav2Lip()
        print("Load checkpoint from: {}".format(self.__checkpoints))
        # _load needs the device to decide how tensors are mapped; omitting
        # it raised TypeError before any checkpoint could be read.
        checkpoint = _load(self.__checkpoints, device)
        # Drop every 'module.' occurrence from the parameter names
        # (presumably the prefix added by nn.DataParallel -- TODO confirm)
        # so the keys line up with a bare Wav2Lip instance.
        state_dict = {
            key.replace('module.', ''): value
            for key, value in checkpoint["state_dict"].items()
        }
        model.load_state_dict(state_dict)
        model = model.to(device)
        return model.eval()