# -*- coding: utf-8 -*-
import os
import glob
import queue
import multiprocessing as mp
import time
from queue import Queue
from threading import Thread, Event
import logging

import cv2
import numpy as np
import torch
from tqdm import tqdm

import face_detection
import utils
from models import Wav2Lip
from utils import mirror_index

logger = logging.getLogger(__name__)

device = 'cuda' if torch.cuda.is_available() else 'cpu'


def read_images(img_list):
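    """Read every image path in img_list with cv2.imread and return the decoded BGR frames."""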
    frames = []
    print('reading images...')
    for img_path in tqdm(img_list):
        logger.debug(f'read image path: {img_path}')
        frame = cv2.imread(img_path)
        if frame is None:
            logger.warning(f'failed to read image: {img_path}')
            continue
        frames.append(frame)
    return frames


def _load(checkpoint_path):
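    """Load a torch checkpoint from checkpoint_path, mapping storage to CPU when CUDA is unavailable."""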
    if device == 'cuda':
        checkpoint = torch.load(checkpoint_path)
    else:
        checkpoint = torch.load(checkpoint_path,
                                map_location=lambda storage, loc: storage)
    return checkpoint


def load_model(path):
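    """Build a Wav2Lip model, load weights from `path` (stripping DataParallel 'module.'
    prefixes from the state dict) and return it on `device` in eval mode."""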
    model = Wav2Lip()
    print("Load checkpoint from: {}".format(path))
    logger.info(f'Load checkpoint from {path}')
    checkpoint = _load(path)
    s = checkpoint["state_dict"]
    new_s = {}
    for k, v in s.items():
        new_s[k.replace('module.', '')] = v
    model.load_state_dict(new_s)
    model = model.to(device)
    return model.eval()


def get_smoothened_boxes(boxes, T):
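    """Smooth detected face boxes over time with a sliding mean of window size T to reduce jitter."""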
    for i in range(len(boxes)):
        if i + T > len(boxes):
            window = boxes[len(boxes) - T:]
        else:
            window = boxes[i : i + T]
        boxes[i] = np.mean(window, axis=0)
    return boxes


def face_detect(images):
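    """Detect a face in every frame of `images`, halving the batch size on CUDA OOM errors.

    Returns a list of [cropped_face, (y1, y2, x1, x2)] entries, one per frame, with the
    boxes temporally smoothed. Raises ValueError if any frame has no detectable face.
    """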
    detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
                                            flip_input=False, device=device)
    batch_size = 16

    while True:
        predictions = []
        try:
            for i in range(0, len(images), batch_size):
                predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
        except RuntimeError:
            if batch_size == 1:
                raise RuntimeError(
                    'Image too big to run face detection on GPU. Please use the --resize_factor argument')
            batch_size //= 2
            print('Recovering from OOM error; New batch size: {}'.format(batch_size))
            continue
        break

    results = []
    pady1, pady2, padx1, padx2 = [0, 10, 0, 0]
    for rect, image in zip(predictions, images):
        if rect is None:
            cv2.imwrite('temp/faulty_frame.jpg', image)  # check this frame where the face was not detected.
            raise ValueError('Face not detected! Ensure the video contains a face in all the frames.')

        y1 = max(0, rect[1] - pady1)
        y2 = min(image.shape[0], rect[3] + pady2)
        x1 = max(0, rect[0] - padx1)
        x2 = min(image.shape[1], rect[2] + padx2)

        results.append([x1, y1, x2, y2])

    boxes = np.array(results)
    boxes = get_smoothened_boxes(boxes, T=5)
    results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)]

    del detector
    return results


img_size = 96
wav2lip_batch_size = 128


def datagen(frames, mels):
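    """Yield Wav2Lip input batches built from `frames` and the mel chunks in `mels`.

    A single (static) face detection result is reused for every chunk; the lower half of
    each face crop is zeroed so the model regenerates the mouth region from the audio.
    Yields (img_batch, mel_batch, frame_batch, coords_batch) tuples.
    """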
    img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []

    face_det_results = face_detect(frames)

    for i, m in enumerate(mels):
        idx = 0  # static face: always reuse the first detection result
        frame_to_save = frames[mirror_index(1, i)].copy()
        face, coords = face_det_results[idx].copy()

        face = cv2.resize(face, (img_size, img_size))

        img_batch.append(face)
        mel_batch.append(m)
        frame_batch.append(frame_to_save)
        coords_batch.append(coords)

        if len(img_batch) >= wav2lip_batch_size:
            img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)

            img_masked = img_batch.copy()
            img_masked[:, img_size//2:] = 0
            img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
            mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

            yield img_batch, mel_batch, frame_batch, coords_batch
            img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []

    if len(img_batch) > 0:
        img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
        img_masked = img_batch.copy()
        img_masked[:, img_size//2:] = 0

        img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
        mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

        yield img_batch, mel_batch, frame_batch, coords_batch


def datagen_signal(frame, mel, face_det_results):
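    """Build one Wav2Lip input batch from a single frame and a list of mel chunks.

    The same face crop and coordinates are repeated for every chunk in `mel`.
    Returns (img_batch, mel_batch, frame_batch, coord_batch), or None if `mel` is empty.
    """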
    img_batch, mel_batch, frame_batch, coord_batch = [], [], [], []

    # for i, m in enumerate(mels):
    idx = 0
    frame_to_save = frame.copy()
    face, coord = face_det_results[idx].copy()

    face = cv2.resize(face, (img_size, img_size))

    for i, m in enumerate(mel):
        img_batch.append(face)
        mel_batch.append(m)
        frame_batch.append(frame_to_save)
        coord_batch.append(coord)

    if len(img_batch) >= wav2lip_batch_size:
        img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)

        img_masked = img_batch.copy()
        img_masked[:, img_size // 2:] = 0
        img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
        mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

        return img_batch, mel_batch, frame_batch, coord_batch

    if len(img_batch) > 0:
        img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
        img_masked = img_batch.copy()
        img_masked[:, img_size // 2:] = 0

        img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
        mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

        return img_batch, mel_batch, frame_batch, coord_batch


def inference(render_event, batch_size, face_imgs_path, audio_feat_queue, audio_out_queue, res_frame_queue):
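    """Worker loop intended to run in a separate process.

    While `render_event` is set, it pulls mel batches from `audio_feat_queue` and paired
    audio frames from `audio_out_queue`, runs Wav2Lip over the avatar faces loaded from
    `face_imgs_path`, and pushes (frame, frame_index, audio_frames) tuples to
    `res_frame_queue`; all-silent batches are forwarded with None frames and no inference.
    """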
    model = load_model('./checkpoints/wav2lip.pth')
    # face_list_cycle = read_images(face_imgs_path)
    input_face_list = glob.glob(os.path.join(face_imgs_path, '*.[jpJP][pnPN]*[gG]'))
    input_face_list = sorted(input_face_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
    face_list_cycle = read_images(input_face_list)

    # input_latent_list_cycle = torch.load(latents_out_path)
    length = len(face_list_cycle)
    index = 0
    count = 0
    counttime = 0
    print('start inference')
    while True:
        if render_event.is_set():
            starttime = time.perf_counter()
            mel_batch = []
            try:
                mel_batch = audio_feat_queue.get(block=True, timeout=1)
            except queue.Empty:
                continue

            is_all_silence = True
            audio_frames = []
            for _ in range(batch_size * 2):
                frame, type_ = audio_out_queue.get()
                audio_frames.append((frame, type_))
                if type_ == 0:
                    is_all_silence = False

            if is_all_silence:
                for i in range(batch_size):
                    res_frame_queue.put((None, mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]))
                    index = index + 1
            else:
                print('infer=======')
                t = time.perf_counter()
                img_batch = []
                for i in range(batch_size):
                    idx = mirror_index(length, index + i)
                    face = face_list_cycle[idx]
                    img_batch.append(face)
                img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)

                img_masked = img_batch.copy()
                img_masked[:, face.shape[0] // 2:] = 0

                img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
                mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

                img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
                mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)

                with torch.no_grad():
                    pred = model(mel_batch, img_batch)
                pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.

                counttime += (time.perf_counter() - t)
                count += batch_size
                # _totalframe += 1
                if count >= 100:
                    print(f"------actual avg infer fps:{count / counttime:.4f}")
                    count = 0
                    counttime = 0
                for i, res_frame in enumerate(pred):
                    # self.__pushmedia(res_frame,loop,audio_track,video_track)
                    res_frame_queue.put((res_frame, mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]))
                    index = index + 1
                # print('total batch time:',time.perf_counter()-starttime)
        else:
            time.sleep(1)
    print('musereal inference processor stop')

class Infer:
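    """Threaded Wav2Lip inference driver fed by mel-spectrogram and audio-frame queues.

    The internal Event acts as a run flag: it is set in __init__ before the worker thread
    starts and cleared by stop().
    """
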
    def __init__(self, human):
        self._human = human
        self._feat_queue = Queue()
        self._audio_out_queue = Queue()

        self.batch_size = human.get_batch_size()
        # self.asr = human.chunk_2_mal
        # self.res_frame_queue = human.res_render_queue

        # face_images_path = r'./face/'
        # self.face_images_path = utils.read_files_path(face_images_path)
        self.avatar_id = 'wav2lip_avatar1'
        self.avatar_path = f"./data/{self.avatar_id}"
        self.full_imgs_path = f"{self.avatar_path}/full_imgs"
        self.face_images_path = f"{self.avatar_path}/face_imgs"
        self.coords_path = f"{self.avatar_path}/coords.pkl"
        # self.render_event = mp.Event()
        # mp.Process(target=inference, args=(self.render_event, self.batch_size, self.face_images_path,
        #                                    self.asr.feat_queue, self.asr.output_queue, self.res_frame_queue,
        #                                    )).start()
        # self.render_event.set()

        self._exit_event = Event()
        self._run_thread = Thread(target=self.__on_run)
        self._exit_event.set()
        self._run_thread.start()

    def __on_run(self):
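        """Thread entry point: load the Wav2Lip model, fetch the avatar face cycle from the
        owning Human object and enter the batched inference loop."""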
        model = load_model('./checkpoints/wav2lip.pth')
        print("Model loaded")

        face_list_cycle = self._human.get_face_list_cycle()
        # input_face_list = glob.glob(os.path.join(self.face_images_path, '*.[jpJP][pnPN]*[gG]'))
        # input_face_list = sorted(input_face_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
        # face_list_cycle = read_images(input_face_list)

        # self.__do_run1(face_list_cycle, model)
        self.__do_run2(face_list_cycle, model)

        # frame_h, frame_w = face_list_cycle[0].shape[:-1]

    def __do_run1(self, face_list_cycle, model):
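        """Single-face pipeline: detect the face once, lip-sync each incoming mel batch onto
        face_list_cycle[0], write debug crops to temp/images/ and push RGB frames back to
        the owning Human object."""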
        face_det_results = face_detect(face_list_cycle)

        j = 0

        count = 0
        while self._exit_event.is_set():
            try:
                m = self._feat_queue.get(block=True, timeout=1)
            except queue.Empty:
                continue

            img_batch, mel_batch, frames, coords = datagen_signal(face_list_cycle[0], m, face_det_results)

            img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
            mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)

            time.sleep(0.01)

            with torch.no_grad():
                pred = model(mel_batch, img_batch)

            pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
            for p, f, c in zip(pred, frames, coords):
                y1, y2, x1, x2 = c
                p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1))

                f[y1:y2, x1:x2] = p
                # name = "%04d" % j
                cv2.imwrite(f'temp/images/{j}.jpg', p)
                j = j + 1
                # count = count + 1
                p = cv2.cvtColor(f, cv2.COLOR_BGR2RGB)
                self._human.push_render_image(p)
                # out.write(f)
            # print('infer count:', count)

    def __do_run2(self, face_list_cycle, model):
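        """Batched pipeline: cycle through the avatar faces via mirror_index, skip inference
        for all-silent audio batches (forwarding None frames), and push predicted frames
        with their paired audio frames back to the owning Human object."""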
        length = len(face_list_cycle)
        index = 0
        count = 0
        count_time = 0
        print('start inference')

        while True:
            if self._exit_event.is_set():
                start_time = time.perf_counter()
                batch_size = self._human.get_batch_size()
                try:
                    mel_batch = self._feat_queue.get(block=True, timeout=0.1)
                except queue.Empty:
                    continue

                is_all_silence = True
                audio_frames = []
                for _ in range(batch_size * 2):
                    frame, type_ = self._audio_out_queue.get()
                    audio_frames.append((frame, type_))
                    if type_ == 0:
                        is_all_silence = False
                if is_all_silence:
                    for i in range(batch_size):
                        self._human.push_res_frame(None, mirror_index(length, index), audio_frames[i * 2:i * 2 + 2])
                        index = index + 1
                else:
                    print('infer=======')
                    t = time.perf_counter()
                    img_batch = []
                    for i in range(batch_size):
                        idx = mirror_index(length, index + i)
                        face = face_list_cycle[idx]
                        img_batch.append(face)

                    img_batch = np.asarray(img_batch)
                    mel_batch = np.asarray(mel_batch)
                    img_masked = img_batch.copy()
                    img_masked[:, face.shape[0] // 2:] = 0
                    #
                    img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
                    mel_batch = np.reshape(mel_batch,
                                           [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

                    img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
                    mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)

                    with torch.no_grad():
                        pred = model(mel_batch, img_batch)

                    pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.

                    count_time += (time.perf_counter() - t)
                    count += batch_size
                    # _totalframe += 1
                    if count >= 100:
                        print(f"------actual avg infer fps:{count / count_time:.4f}")
                        count = 0
                        count_time = 0

                    image_index = 0
                    for i, res_frame in enumerate(pred):
                        self._human.push_res_frame(res_frame, mirror_index(length, index),
                                                   audio_frames[i * 2:i * 2 + 2])
                        index = index + 1
                        image_index = image_index + 1
                    print('batch count', image_index)
                    print('total batch time:', time.perf_counter() - start_time)
            else:
                time.sleep(1)
                break
        print('musereal inference processor stop')

    def stop(self):
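        """Clear pending queues, signal the run loop to exit and join the worker thread."""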
        if self._exit_event is None:
            return

        self.pause_talk()

        self._exit_event.clear()
        self._run_thread.join()
        logger.info('Infer stop')

    def pause_talk(self):
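        """Drop any queued mel chunks and audio frames so current speech can be interrupted."""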
        self._feat_queue.queue.clear()
        self._audio_out_queue.queue.clear()

    def push(self, mel_chunks):
        self._feat_queue.put(mel_chunks)

    def push_out_queue(self, frame, type_):
        self._audio_out_queue.put((frame, type_))

    def get_out_put(self):
        return self._audio_out_queue.get()
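

# Usage sketch (an assumption, not part of this module: it presumes a `human` object exposing
# get_batch_size(), get_face_list_cycle(), push_res_frame() and push_render_image() as
# referenced above, and that type_ == 0 marks non-silent audio frames):
#
#   infer = Infer(human)
#   infer.push(mel_chunks)               # feed a batch of mel-spectrogram chunks
#   infer.push_out_queue(pcm_frame, 0)   # feed the matching audio frames
#   ...
#   infer.stop()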