# coding=utf-8
import logging
import multiprocessing as mp
import queue
import time

import numpy as np

import utils
from models import Wav2Lip
from tts.Chunk2Mal import Chunk2Mal
import torch
import cv2
from tqdm import tqdm

device = 'cuda' if torch.cuda.is_available() else 'cpu'


def _load(checkpoint_path):
    """Load a checkpoint, remapping CUDA tensors to the CPU when no GPU is available."""
    if device == 'cuda':
        checkpoint = torch.load(checkpoint_path)
    else:
        checkpoint = torch.load(checkpoint_path,
                                map_location=lambda storage, loc: storage)
    return checkpoint


def load_model(path):
    """Build a Wav2Lip model and load its weights, stripping any DataParallel 'module.' prefix."""
    model = Wav2Lip()
    print(f"Load checkpoint from: {path}")
    logging.info(f'Load checkpoint from {path}')
    checkpoint = _load(path)
    s = checkpoint["state_dict"]
    new_s = {k.replace('module.', ''): v for k, v in s.items()}
    model.load_state_dict(new_s)
    model = model.to(device)
    return model.eval()


def read_images(img_list):
    """Read every image path into memory, skipping (and logging) files OpenCV cannot decode."""
    frames = []
    print('reading images...')
    for img_path in tqdm(img_list):
        frame = cv2.imread(img_path)
        if frame is None:  # cv2.imread returns None instead of raising on failure
            logging.warning(f'failed to read image: {img_path}')
            continue
        frames.append(frame)
    return frames


def __mirror_index(size, index):
    """Map a monotonically increasing index onto a back-and-forth sweep of [0, size)."""
    turn = index // size
    res = index % size
    if turn % 2 == 0:
        return res
    return size - res - 1
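
# For example, with size = 4 the face list is traversed back and forth
# (ping-pong) instead of jumping from the last frame back to the first:
#   index          : 0 1 2 3 4 5 6 7 8 ...
#   __mirror_index : 0 1 2 3 3 2 1 0 0 ...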


#  python.exe .\inference.py --checkpoint_path .\checkpoints\wav2lip.pth --face
#  .\face\img00016.jpg --audio .\audio\audio1.wav
def inference(render_event, batch_size, face_images_path, audio_feat_queue, audio_out_queue, res_frame_queue):
    logging.info(f'Using {device} for inference.')
    print(f'Using {device} for inference.')

    print(f'face_images_path: {face_images_path}')

    model = load_model(r'.\checkpoints\wav2lip.pth')
    face_list_cycle = read_images(face_images_path)
    face_images_length = len(face_list_cycle)
    logging.info(f'face images length: {face_images_length}')
    print(f'face images length: {face_images_length}')

    length = len(face_list_cycle)
    index = 0
    count = 0
    count_time = 0
    logging.info('start inference')
    print(f'start inference: {render_event.is_set()}')
    while render_event.is_set():
        try:
            mel_batch = audio_feat_queue.get(block=True, timeout=1)
        except queue.Empty:
            continue

        # Two audio frames are consumed per rendered video frame, hence batch_size * 2.
        # type_ == 0 marks a frame that carries real speech rather than silence.
        audio_frames = []
        is_all_silence = True
        for _ in range(batch_size * 2):
            frame, type_ = audio_out_queue.get()
            audio_frames.append((frame, type_))
            if type_ == 0:
                is_all_silence = False

        print(f'is_all_silence {is_all_silence}')
        if is_all_silence:
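            # Silence: push None instead of a prediction so no lip-sync inference is run.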
            for i in range(batch_size):
                res_frame_queue.put((None, __mirror_index(length, index), audio_frames[i*2:i*2+2]))
                index = index + 1
        else:
            t = time.perf_counter()
            image_batch = []
            for i in range(batch_size):
                idx = __mirror_index(length, index + i)
                face = face_list_cycle[idx]
                image_batch.append(face)
            image_batch, mel_batch = np.asarray(image_batch), np.asarray(mel_batch)

            # Wav2Lip is conditioned on the same face with its lower half masked out.
            image_masked = image_batch.copy()
            image_masked[:, image_batch.shape[1] // 2:] = 0  # zero the lower half of every face

            image_batch = np.concatenate((image_masked, image_batch), axis=3) / 255.
            mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

            # NHWC -> NCHW, as PyTorch convolutions expect.
            image_batch = torch.FloatTensor(np.transpose(image_batch, (0, 3, 1, 2))).to(device)
            mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)

            with torch.no_grad():
                pred = model(mel_batch, image_batch)
            pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.  # NCHW -> NHWC, back to the 0-255 range

            count_time += (time.perf_counter() - t)
            count += batch_size
            if count >= 100:
                logging.info(f"------actual avg infer fps:{count/count_time:.4f}")
                count = 0
                count_time = 0

            for i, res_frame in enumerate(pred):
                res_frame_queue.put((res_frame, __mirror_index(length, index), audio_frames[i*2 : i*2+2]))
                index = index + 1

    logging.info('finish inference')


class Human:
    def __init__(self):
        self._tts = None
        self._fps = 50  # audio frames per second (20 ms per frame)
        self._batch_size = 16
        self._sample_rate = 16000
        self._chunk = self._sample_rate // self._fps  # 320 samples per chunk (20ms * 16000 / 1000)
        self._chunk_2_mal = Chunk2Mal(self)
        self._stride_left_size = 10
        self._stride_right_size = 10
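        # Three queues connect the audio pipeline to the inference worker process:
        #   _feat_queue      - mel-spectrogram batches (bounded, to cap latency)
        #   _output_queue    - raw (chunk, type) audio frames passed through for playback
        #   _res_frame_queue - rendered face frames for render() to consume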
        self._feat_queue = mp.Queue(2)
        self._output_queue = mp.Queue()
        self._res_frame_queue = mp.Queue(self._batch_size * 2)

        face_images_path = r'./face/'
        self._face_image_paths = utils.read_files_path(face_images_path)
        print(self._face_image_paths)
        self.render_event = mp.Event()
        # Set the event before spawning: the worker loops on render_event.is_set(),
        # so setting it only after start() would race with the worker's first check.
        self.render_event.set()
        mp.Process(target=inference, args=(self.render_event, self._batch_size, self._face_image_paths,
                                           self._feat_queue, self._output_queue, self._res_frame_queue,
                                           )).start()

    def get_fps(self):
        return self._fps

    def get_batch_size(self):
        return self._batch_size

    def get_chunk(self):
        return self._chunk

    def get_stride_left_size(self):
        return self._stride_left_size

    def get_stride_right_size(self):
        return self._stride_right_size

    def on_destroy(self):
        self.render_event.clear()
        self._chunk_2_mal.stop()
        if self._tts is not None:
            self._tts.stop()
        logging.info('human destroy')

    def set_tts(self, tts):
        if self._tts == tts:
            return

        self._tts = tts
        self._tts.start()
        self._chunk_2_mal.start()

    def read(self, txt):
        if self._tts is None:
            logging.warning('tts is none')
            return

        self._tts.push_txt(txt)

    def push_audio_chunk(self, chunk):
        self._chunk_2_mal.push_chunk(chunk)

    def push_feat_queue(self, mel_chunks):
        print("21")
        self._feat_queue.put(mel_chunks)
        print("22")

    def push_audio_frames(self, chunk, type_):
        self._output_queue.put((chunk, type_))

    def render(self):
        try:
            # The worker enqueues (frame, index, audio_frames) triples; unpack all three.
            res_frame, _idx, _audio_frames = self._res_frame_queue.get(block=True, timeout=.3)
        except queue.Empty:
            return None
        return res_frame

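
# A minimal usage sketch (hypothetical TTS object: this class only assumes it
# exposes start(), stop() and push_txt(), as used above):
#
#   human = Human()
#   human.set_tts(some_tts)       # e.g. a wrapper from the tts package
#   human.read('hello')
#   while running:
#       frame = human.render()    # None while no frame is ready
#       ...
#   human.on_destroy()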