# -*- coding: utf-8 -*-
import copy
import glob
import io
import logging

import multiprocessing as mp
import os
import pickle
import platform
import subprocess
import queue
import threading
import time


import numpy as np
import pyaudio

import audio
import face_detection
import utils
from audio_render import AudioRender
from infer import Infer, read_images
from models import Wav2Lip
from tts.Chunk2Mal import Chunk2Mal
import torch
import cv2
from tqdm import tqdm
from queue import Queue

from tts.EdgeTTS import EdgeTTS
from tts.TTSBase import TTSBase
from utils import mirror_index

device = 'cuda' if torch.cuda.is_available() else 'cpu'


def _load(checkpoint_path):
    if device == 'cuda':
        checkpoint = torch.load(checkpoint_path)
    else:
        checkpoint = torch.load(checkpoint_path,
                                map_location=lambda storage, loc: storage)
    return checkpoint


def load_model(path):
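    """Load Wav2Lip weights from `path`, strip any 'module.' prefix from the
    state dict, and return the model in eval mode on the current device."""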
    model = Wav2Lip()
    print("Load checkpoint from: {}".format(path))
    logging.info(f'Load checkpoint from {path}')
    checkpoint = _load(path)
    s = checkpoint["state_dict"]
    new_s = {}
    for k, v in s.items():
        new_s[k.replace('module.', '')] = v
    model.load_state_dict(new_s)
    model = model.to(device)
    return model.eval()


#  python.exe .\inference.py --checkpoint_path .\checkpoints\wav2lip.pth --face
#  .\face\img00016.jpg --audio .\audio\audio1.wav
def inference(render_event, batch_size, face_images_path, audio_feat_queue, audio_out_queue, res_frame_queue):
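    """Inference worker: while render_event is set, pull a batch of mel chunks and
    the matching audio frames, run Wav2Lip on the cyclically mirrored face images
    (skipping the model when every audio frame is silence), and push
    (pred_frame_or_None, face_index, audio_pair) tuples onto res_frame_queue."""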
    logging.info(f'Using {device} for inference.')
    print(f'Using {device} for inference.')

    print(f'face_images_path: {face_images_path}')

    model = load_model(r'.\checkpoints\wav2lip.pth')
    face_list_cycle = read_images(face_images_path)
    face_images_length = len(face_list_cycle)
    logging.info(f'face images length: {face_images_length}')
    print(f'face images length: {face_images_length}')

    length = len(face_list_cycle)
    index = 0
    count = 0
    count_time = 0
    logging.info('start inference')
    print(f'start inference: {render_event.is_set()}')
    while render_event.is_set():
        mel_batch = []
        try:
            mel_batch = audio_feat_queue.get(block=True, timeout=1)
        except queue.Empty:
            continue

        audio_frames = []
        is_all_silence = True
        for _ in range(batch_size * 2):
            frame, type_ = audio_out_queue.get()
            audio_frames.append((frame, type_))

            if type_ == 0:
                is_all_silence = False

        print(f'is_all_silence {is_all_silence}')
        if is_all_silence:
            for i in range(batch_size):
                res_frame_queue.put((None, mirror_index(length, index), audio_frames[i*2:i*2+2]))
                index = index + 1
        else:
            t = time.perf_counter()
            image_batch = []
            for i in range(batch_size):
                idx = mirror_index(length, index + i)
                face = face_list_cycle[idx]
                image_batch.append(face)
            image_batch, mel_batch = np.asarray(image_batch), np.asarray(mel_batch)

            image_masked = image_batch.copy()
            # Wav2Lip input: zero the lower half of each face and concatenate it with
            # the unmasked reference image along the channel axis.
            image_masked[:, image_batch.shape[1] // 2:] = 0

            image_batch = np.concatenate((image_masked, image_batch), axis=3) / 255.
            mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

            image_batch = torch.FloatTensor(np.transpose(image_batch, (0, 3, 1, 2))).to(device)
            mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)

            with torch.no_grad():
                pred = model(mel_batch, image_batch)
            pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.

            count_time += (time.perf_counter() - t)
            count += batch_size
            if count >= 100:
                logging.info(f"------actual avg infer fps:{count/count_time:.4f}")
                count = 0
                count_time = 0

            for i, res_frame in enumerate(pred):
                res_frame_queue.put((res_frame, mirror_index(length, index), audio_frames[i*2 : i*2+2]))
                index = index + 1

    logging.info('finish inference')


def get_smoothened_boxes(boxes, T):
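    """Smooth detected face boxes with a sliding-window mean over T frames."""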
    for i in range(len(boxes)):
        if i + T > len(boxes):
            window = boxes[len(boxes) - T:]
        else:
            window = boxes[i : i + T]
        boxes[i] = np.mean(window, axis=0)
    return boxes


def face_detect(images):
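    """Run face detection on every image, halving the batch size on GPU OOM, then
    pad and smooth the boxes and return [cropped_face, (y1, y2, x1, x2)] per image."""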
    detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
                                            flip_input=False, device=device)
    batch_size = 16

    while True:
        predictions = []
        try:
            for i in range(0, len(images), batch_size):
                predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
        except RuntimeError:
            if batch_size == 1:
                raise RuntimeError(
                    'Image too big to run face detection on GPU. Please use the --resize_factor argument')
            batch_size //= 2
            print('Recovering from OOM error; New batch size: {}'.format(batch_size))
            continue
        break

    results = []
    pady1, pady2, padx1, padx2 = [0, 10, 0, 0]
    for rect, image in zip(predictions, images):
        if rect is None:
            cv2.imwrite('temp/faulty_frame.jpg', image)  # check this frame where the face was not detected.
            raise ValueError('Face not detected! Ensure the video contains a face in all the frames.')

        y1 = max(0, rect[1] - pady1)
        y2 = min(image.shape[0], rect[3] + pady2)
        x1 = max(0, rect[0] - padx1)
        x2 = min(image.shape[1], rect[2] + padx2)

        results.append([x1, y1, x2, y2])

    boxes = np.array(results)
    boxes = get_smoothened_boxes(boxes, T=5)
    results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)]

    del detector
    return results


img_size = 96
wav2lip_batch_size = 128


# Load audio data from a byte stream.
def load_audio_from_bytes(byte_data):
    # Wrap the bytes in a BytesIO object so the audio loader can read it like a file.
    with io.BytesIO(byte_data) as b:
        wav = audio.load_wav(b, 16000)  # adjust the arguments to match the audio library in use
    return wav

# Assumes you already have the audio file's raw bytes.
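# A minimal usage sketch (using the sample wav from the command above):
#     with open('./audio/audio1.wav', 'rb') as f:
#         wav = load_audio_from_bytes(f.read())  # float samples resampled to 16 kHz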


class Human:
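    """Digital-human pipeline: owns the avatar frames, face crops and coordinates,
    the TTS -> mel-chunk -> Wav2Lip inference chain, and the queues consumed by
    render()."""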
    def __init__(self):
        self._fps = 50  # 20 ms per frame
        self._batch_size = 16
        self._sample_rate = 16000
        self._stride_left_size = 10
        self._stride_right_size = 10
        self._res_frame_queue = mp.Queue(self._batch_size * 2)

        full_images, face_frames, coord_frames = self._avatar()
        self._frame_list_cycle = full_images
        self._face_list_cycle = face_frames
        self._coord_list_cycle = coord_frames
        face_images_length = len(self._face_list_cycle)
        logging.info(f'face images length: {face_images_length}')
        print(f'face images length: {face_images_length}')

        # self.avatar_id = 'wav2lip_avatar1'
        # self.avatar_path = f"./data/{self.avatar_id}"
        # self.full_imgs_path = f"{self.avatar_path}/full_imgs"
        # self.face_imgs_path = f"{self.avatar_path}/face_imgs"
        # self.coords_path = f"{self.avatar_path}/coords.pkl"
        # self.__loadavatar()

        self.stop = False
        self.res_render_queue = Queue(self._batch_size * 2)

        # Queues referenced by the legacy test/debug helpers below; they were
        # never initialised, so those code paths crashed when exercised.
        self.mel_chunks_queue_ = Queue()
        self.audio_chunks_queue_ = Queue()
        self._test_image_queue = Queue()
        self._feat_queue = mp.Queue()
        self._output_queue = mp.Queue()

        self.chunk_2_mal = Chunk2Mal(self)
        self._tts = TTSBase(self)
        self._infer = Infer(self)
        # self.chunk_2_mal.warm_up()

        self.audio_render = AudioRender()

        #
        # self._thread = None
        # thread = threading.Thread(target=self.test)
        # thread.start()
        # self.test()
        # self.play_pcm()

        # face_images_path = r'./face/'
        # self._face_image_paths = utils.read_files_path(face_images_path)
        # print(self._face_image_paths)
        # self.render_event = mp.Event()
        # mp.Process(target=inference, args=(self.render_event, self._batch_size, self._face_image_paths,
        #                                    self._feat_queue, self._output_queue, self._res_frame_queue,
        #                                    )).start()
        # self.render_event.set()

    def __del__(self):
        print('Human del')
    # def play_pcm(self):
    #     p = pyaudio.PyAudio()
    #     stream = p.open(format=p.get_format_from_width(2), channels=1, rate=16000, output=True)
    #     file1 = r'./audio/en_weather.pcm'
    #
    #     # Write the PCM data directly into the PyAudio output stream
    #     with open(file1, "rb") as f:
    #         stream.write(f.read())
    #
    #     stream.stop_stream()
    #     stream.close()
    #     p.terminate()

    def __loadavatar(self):
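        """Load a precomputed avatar: face-box coordinates from a pickle file and
        the full frames from the avatar image folder."""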
        with open(self.coords_path, 'rb') as f:
            self._coord_list_cycle = pickle.load(f)
        input_img_list = glob.glob(os.path.join(self.full_imgs_path, '*.[jpJP][pnPN]*[gG]'))
        input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
        self._frame_list_cycle = read_images(input_img_list)

    def _avatar(self):
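        """Build the avatar from ./face/: return the full frames, the 96x96 face
        crops, and the crop coordinates produced by face detection."""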
        face_images_path = r'./face/'
        face_images_path = utils.read_files_path(face_images_path)
        full_list_cycle = read_images(face_images_path)

        face_det_results = face_detect(full_list_cycle)

        face_frames = []
        coord_frames = []
        for face, coord in face_det_results:
            resized_crop_frame = cv2.resize(face, (img_size, img_size))
            face_frames.append(resized_crop_frame)
            coord_frames.append(coord)

        return full_list_cycle, face_frames, coord_frames

    def inter(self, model, chunks, face_list_cycle, face_det_results, out, j):
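        """Process one batch of raw audio chunks: compute the mel spectrogram, slice
        it into mel chunks, run Wav2Lip, paste each predicted mouth back into its
        frame, push preview frames to the image queue, and write frames to `out`.
        Returns the updated frame counter j."""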
        inputs = np.concatenate(chunks)  # [5 * chunk]
        mel = audio.melspectrogram(inputs)
        print("inter", len(mel[0]))
        if np.isnan(mel.reshape(-1)).sum() > 0:
            raise ValueError(
                'Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again')

        mel_step_size = 16

        print('fps:', self._fps)
        mel_idx_multiplier = 80. / self._fps
        print('mel_idx_multiplier:', mel_idx_multiplier)
        i = 0
        mel_chunks = []
        while True:
            start_idx = int(i * mel_idx_multiplier)
            if start_idx + mel_step_size > len(mel[0]):
                mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:])
                # self.mel_chunks_queue_.put(mel[:, len(mel[0]) - mel_step_size:])
                break
            mel_chunks.append(mel[:, start_idx: start_idx + mel_step_size])
            # self.mel_chunks_queue_.put(mel[:, start_idx: start_idx + mel_step_size])
            i += 1
        self.mel_chunks_queue_.put(mel_chunks)
        while not self.mel_chunks_queue_.empty():
            print("self.mel_chunks_queue_ len:", self.mel_chunks_queue_.qsize())
            m = self.mel_chunks_queue_.get()
            # mel_batch = np.reshape(m, [len(m), mel_batch.shape[1], mel_batch.shape[2], 1])
            img_batch, mel_batch, frames, coords = utils.datagen_signal(face_list_cycle[0],
                                                                        m, face_det_results, img_size)

            img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
            mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)

            with torch.no_grad():
                pred = model(mel_batch, img_batch)

            pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
            for p, f, c in zip(pred, frames, coords):
                y1, y2, x1, x2 = c
                p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1))

                f[y1:y2, x1:x2] = p
                name = "%04d" % j
                cv2.imwrite(f'temp/images/{name}.jpg', p)
                j = j + 1
                p = cv2.cvtColor(f, cv2.COLOR_BGR2RGB)
                self._test_image_queue.put(p)
                out.write(f)
        return j

    def test(self):
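        """Offline smoke test: lip-sync ./temp/audio/chunk_0.wav onto the ./face/
        avatar, write temp/resul_tttt.avi, and mux it with the audio via ffmpeg."""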
        batch_size = 128
        print('batch_size:', batch_size, ' mel_chunks len:', self.mel_chunks_queue_.qsize())

        face_images_path = r'./face/'
        face_images_path = utils.read_files_path(face_images_path)
        face_list_cycle = read_images(face_images_path)
        face_images_length = len(face_list_cycle)
        logging.info(f'face images length: {face_images_length}')
        print(f'face images length: {face_images_length}')

        model = load_model(r'.\checkpoints\wav2lip.pth')
        print("Model loaded")

        frame_h, frame_w = face_list_cycle[0].shape[:-1]
        out = cv2.VideoWriter('temp/resul_tttt.avi',
                              cv2.VideoWriter_fourcc(*'DIVX'), 25, (frame_w, frame_h))

        face_det_results = face_detect(face_list_cycle)

        audio_path = r'./temp/audio/chunk_0.wav'
        stream = audio.load_wav(audio_path, 16000)
        stream_len = stream.shape[0]
        print('wav length:', stream_len)
        _audio_chunk_queue = queue.Queue()
        index = 0
        chunk_len = 320  # 320 samples = 20 ms of audio at 16 kHz
        print('chunk_len:', chunk_len)
        while stream_len >= chunk_len:
            audio_chunk = stream[index:index + chunk_len]
            _audio_chunk_queue.put(audio_chunk)
            stream_len -= chunk_len
            index += chunk_len
        if stream_len > 0:
            # Push the remaining tail, shorter than one full chunk.
            audio_chunk = stream[index:index + stream_len]
            _audio_chunk_queue.put(audio_chunk)

        print('_audio_chunk_queue:', _audio_chunk_queue.qsize())

        j = 0
        while not _audio_chunk_queue.empty():
            chunks = []
            length = min(batch_size, _audio_chunk_queue.qsize())
            for i in range(length):
                chunks.append(_audio_chunk_queue.get())

            j = self.inter(model, chunks, face_list_cycle, face_det_results, out, j)

        out.release()
        command = 'ffmpeg -y -i {} -i {} -strict -2 -q:v 1 {}'.format(audio_path, 'temp/resul_tttt.avi',
                                                                      'temp/resul_tttt.mp4')
        subprocess.call(command, shell=platform.system() != 'Windows')

        # gen = datagen(face_list_cycle, self.mel_chunks_queue_)

    def get_face_list_cycle(self):
        return self._face_list_cycle

    def get_fps(self):
        return self._fps

    def get_batch_size(self):
        return self._batch_size

    def get_audio_sample_rate(self):
        return self._sample_rate

    def get_stride_left_size(self):
        return self._stride_left_size

    def get_stride_right_size(self):
        return self._stride_right_size

    def on_destroy(self):
        self.stop = True
        # self.render_event.clear()

        self.chunk_2_mal.stop()
        self._tts.stop()
        self._infer.stop()
        # if self._tts is not None:
        #     self._tts.stop()
        logging.info('human destroy')

    def pause_talk(self):
        self._tts.pause_talk()
        self.chunk_2_mal.pause_talk()
        self._infer.pause_talk()

    def read(self, txt):
        if self._tts is None:
            logging.warning('tts is none')
            return
        self._tts.push_txt(txt)

    def put_audio_frame(self, audio_chunk):
        self.chunk_2_mal.put_audio_frame(audio_chunk)

    # def push_audio_chunk(self, audio_chunk):
    #     self._chunk_2_mal.push_chunk(audio_chunk)

    def push_mel_chunks(self, mel_chunks):
        self._infer.push(mel_chunks)

    def push_out_put(self, frame, type_):
        self._infer.push_out_queue(frame, type_)

    def get_out_put(self):
        return self._infer.get_out_put()

    def push_mel_chunks_queue(self, audio_chunk):
        self.audio_chunks_queue_.put(audio_chunk)

    def push_feat_queue(self, mel_chunks):
        print("push_feat_queue")
        self._feat_queue.put(mel_chunks)

    def push_audio_frames(self, chunk, type_):
        self._output_queue.put((chunk, type_))

    def push_render_image(self, image):
        self._test_image_queue.put(image)

    def push_res_frame(self, res_frame, idx, audio_frames):
        if self.stop:
            print("push_res_frame stop")
            return
        self.res_render_queue.put((res_frame, idx, audio_frames))

    def render(self):
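        """Pop one result from the render queue, composite the predicted mouth into
        the full frame (or reuse the idle frame when both audio frames are silence),
        play the audio, and return the RGB image; return None if the queue is empty."""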
        try:
            res_frame, idx, audio_frames = self.res_render_queue.get(block=True, timeout=.03)
        except queue.Empty:
            # print('render queue.Empty:')
            return None

        if audio_frames[0][1] != 0 and audio_frames[1][1] != 0:
            # Both audio frames are silence: show the unmodified full frame.
            combine_frame = self._frame_list_cycle[idx]
        else:
            bbox = self._coord_list_cycle[idx]
            combine_frame = copy.deepcopy(self._frame_list_cycle[idx])
            y1, y2, x1, x2 = bbox
            try:
                res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
            except Exception:
                return None
            # combine_frame = get_image(ori_frame,res_frame,bbox)
            # t=time.perf_counter()
            combine_frame[y1:y2, x1:x2] = res_frame

        image = combine_frame
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        for audio_frame in audio_frames:
            frame, type_ = audio_frame
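            # Convert float samples in [-1, 1] to 16-bit PCM for the audio renderer.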
            frame = (frame * 32767).astype(np.int16)
            self.audio_render.write(frame.tobytes(), int(frame.shape[0]*2))
            # new_frame = AudioFrame(format='s16', layout='mono', samples=frame.shape[0])
            # new_frame.planes[0].update(frame.tobytes())
            # new_frame.sample_rate = 16000
        return image


        # print('blending time:',time.perf_counter()-t)

    # def pull_audio_chunk(self):
    #     try:
    #         chunk = self._audio_chunk_queue.get(block=True, timeout=1.0)
    #         type = 1
    #     except queue.Empty:
    #         chunk = np.zeros(self._chunk, dtype=np.float32)
    #         type = 0
    #     return chunk, type