#encoding = utf8

import torch
from models import Wav2Lip


def _load(checkpoint_path, device):
    if device == 'cuda':
        checkpoint = torch.load(checkpoint_path)
    else:
        checkpoint = torch.load(checkpoint_path,
                                map_location=lambda storage, loc: storage)
    return checkpoint


class Wav2LipModel:
    """Hold a Wav2Lip network loaded from a checkpoint, ready for inference.

    The network is placed on CUDA when ``torch.cuda.is_available()`` is true
    and on the CPU otherwise. It is switched to eval mode and stored in
    ``self._model``.
    """

    def __init__(self, check_points):
        """Build the model and load its weights from ``check_points``.

        Args:
            check_points: Path of the checkpoint file. The loaded object must
                contain a ``"state_dict"`` entry.
        """
        self.__checkpoints = check_points
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print('Using {} for inference.'.format(device))

        self._model = self.__load_model(device)

    def __load_model(self, device):
        """Create a Wav2Lip network and load the checkpoint weights into it.

        Args:
            device: ``'cuda'`` or ``'cpu'``. It is used both when reading the
                checkpoint and as the device the model is moved to.

        Returns:
            The Wav2Lip module on ``device``, in eval mode.
        """
        model = Wav2Lip()
        print("Load checkpoint from: {}".format(self.__checkpoints))
        # _load requires the device; omitting it raised TypeError here.
        checkpoint = _load(self.__checkpoints, device)
        state_dict = checkpoint["state_dict"]
        # Drop 'module.' from parameter names (presumably added by
        # nn.DataParallel when the checkpoint was saved) so the keys match a
        # plain Wav2Lip instance. Note: replace() removes every occurrence,
        # not only a leading prefix.
        cleaned = {key.replace('module.', ''): value
                   for key, value in state_dict.items()}
        model.load_state_dict(cleaned)

        model = model.to(device)
        return model.eval()