36 lines
1010 B
Python
36 lines
1010 B
Python
#encoding = utf8
|
|
|
|
import torch
|
|
from models import Wav2Lip
|
|
|
|
|
|
def _load(checkpoint_path, device):
|
|
if device == 'cuda':
|
|
checkpoint = torch.load(checkpoint_path)
|
|
else:
|
|
checkpoint = torch.load(checkpoint_path,
|
|
map_location=lambda storage, loc: storage)
|
|
return checkpoint
|
|
|
|
|
|
class Wav2LipModel:
    """Build a ``Wav2Lip`` network and load its weights from a checkpoint.

    The device is chosen once, at construction: ``'cuda'`` when
    ``torch.cuda.is_available()`` is true, otherwise ``'cpu'``. The loaded
    network is moved to that device, switched to eval mode, and stored in
    ``self._model``.
    """

    def __init__(self, check_points):
        """Create the model and load weights from *check_points*.

        Args:
            check_points: Path of a checkpoint file whose deserialized
                contents contain a ``"state_dict"`` entry.
        """
        self.__checkpoints = check_points
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print('Using {} for inference.'.format(device))
        self._model = self.__load_model(device)

    def __load_model(self, device):
        """Instantiate ``Wav2Lip``, load the checkpoint weights into it, and return it.

        Args:
            device: Device string (``'cuda'`` or ``'cpu'``) used both to load
                the checkpoint and as the target of ``model.to``.

        Returns:
            The ``Wav2Lip`` module on *device*, in eval mode.
        """
        model = Wav2Lip()
        print("Load checkpoint from: {}".format(self.__checkpoints))
        # _load takes the device as its second argument; omitting it (as the
        # original call did) raises TypeError against the two-parameter signature.
        checkpoint = _load(self.__checkpoints, device)
        state_dict = checkpoint["state_dict"]
        # Presumably the checkpoint was saved from an nn.DataParallel-wrapped
        # model, which prefixes keys with 'module.'. Strip only that leading
        # prefix, so a key such as 'submodule.w' is not mangled into 'subw'.
        prefix = 'module.'
        new_state_dict = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }
        model.load_state_dict(new_state_dict)

        model = model.to(device)
        return model.eval()
|