#encoding = utf8

import logging
import queue
import time
from queue import Queue
import multiprocessing as mp
from threading import Thread, Event

import numpy as np
import audio
from audio_render import AudioRender


class Chunk2Mal:
    """Turn a stream of raw PCM audio chunks into mel-spectrogram windows.

    The constructor starts a background thread that repeatedly:

    * pulls ``batch_size * 2`` audio chunks from ``self.queue`` (substituting
      silence when nothing arrives within 10 ms),
    * forwards every chunk to ``human.push_out_put``,
    * once more than ``stride_left_size + stride_right_size`` chunks are
      buffered, computes ``audio.melspectrogram`` over the buffer and passes
      the resulting fixed-width mel windows to ``human.push_mel_chunks``.
    """

    # Mel frames per second produced by ``audio.melspectrogram``.
    # NOTE(review): assumed to be 80 (e.g. 16 kHz audio with a 200-sample
    # hop); confirm against the hparams of the ``audio`` module.
    _MEL_FPS = 80.0
    # Width, in mel frames, of each window handed to ``push_mel_chunks``.
    _MEL_STEP_SIZE = 16

    def __init__(self, human):
        """Read timing parameters from *human* and start the worker thread.

        Args:
            human: owner object. This class calls its ``get_fps``,
                ``get_batch_size``, ``get_stride_left_size``,
                ``get_stride_right_size``, ``get_audio_sample_rate``,
                ``push_out_put``, ``get_out_put`` and ``push_mel_chunks``.
        """
        self._human = human

        # Buffer of audio chunks awaiting feature extraction (includes the
        # left/right context carried over between steps).
        self.frames = []
        # Incoming audio chunks queued by put_audio_frame().
        self.queue = Queue()

        self.fps = human.get_fps()
        self.batch_size = human.get_batch_size()
        self.stride_left_size = human.get_stride_left_size()
        self.stride_right_size = human.get_stride_right_size()

        # Samples per audio chunk, e.g. 320 for 16 kHz at 50 chunks/s (20 ms).
        self.chunk = self._human.get_audio_sample_rate() // self.fps
        self._stream_len = 0

        # Despite its name, the worker keeps running while this event is SET;
        # stop() clears it to request shutdown.
        self._exit_event = Event()
        self._exit_event.set()
        # Start last so every attribute the worker reads already exists.
        self._thread = Thread(target=self._on_run)
        self._thread.start()
        logging.info('chunk2mal start')

    def _on_run(self):
        """Worker loop: run feature-extraction steps until stop() is called."""
        logging.info('chunk2mal run')
        while self._exit_event.is_set():
            self._run_step()
            time.sleep(0.01)

        logging.info('chunk2mal exit')

    def _run_step(self):
        """Pull one batch of audio, forward it, and emit mel windows if possible."""
        for _ in range(self.batch_size * 2):
            frame, type_ = self.get_audio_frame()
            self.frames.append(frame)
            self._human.push_out_put(frame, type_)

        # Not enough left/right context yet: do not run the feature extractor.
        if len(self.frames) <= self.stride_left_size + self.stride_right_size:
            return

        inputs = np.concatenate(self.frames)  # [N * chunk]
        mel = audio.melspectrogram(inputs)
        self._human.push_mel_chunks(self._slice_mel_chunks(mel, len(self.frames)))

        # Discard consumed chunks to bound memory, keeping only the context.
        self._trim_frames()

    def _slice_mel_chunks(self, mel, num_frames):
        """Cut *mel* into fixed-width windows, one per two buffered audio chunks.

        Args:
            mel: 2-D array indexed as ``mel[:, a:b]``, with time (in mel
                frames) on the second axis.
            num_frames: number of audio chunks *mel* was computed from,
                including the left/right context chunks.

        Returns:
            A list of ``ceil((num_frames - context) / 2)`` slices of width
            ``_MEL_STEP_SIZE``. A window that would run past the end of
            *mel* is clamped to its last ``_MEL_STEP_SIZE`` mel frames.
        """
        mel_len = mel.shape[1]
        step = self._MEL_STEP_SIZE
        context = self.stride_left_size + self.stride_right_size
        # Skip the left-context chunks, converted from audio chunks to mel
        # frames (mel frames per chunk = _MEL_FPS / fps).
        left = max(0, self.stride_left_size * self._MEL_FPS / self.fps)
        # Mel-frame advance per window: two audio chunks per window.
        mel_idx_multiplier = self._MEL_FPS * 2 / self.fps

        mel_chunks = []
        # (x + 1) // 2 == ceil(x / 2); a non-positive count yields no windows.
        for i in range((num_frames - context + 1) // 2):
            start_idx = int(left + i * mel_idx_multiplier)
            if start_idx + step > mel_len:
                mel_chunks.append(mel[:, mel_len - step:])
            else:
                mel_chunks.append(mel[:, start_idx:start_idx + step])
        return mel_chunks

    def _trim_frames(self):
        """Keep only the trailing ``stride_left + stride_right`` chunks."""
        context = self.stride_left_size + self.stride_right_size
        # Slice from an explicit start index: ``frames[-0:]`` would keep the
        # entire list, leaking every chunk when the context size is zero.
        self.frames = self.frames[max(0, len(self.frames) - context):]

    def stop(self):
        """Signal the worker to exit and wait for it to finish."""
        if self._exit_event is None:
            return

        self._exit_event.clear()
        if self._thread.is_alive():
            self._thread.join()
        logging.info('chunk2mal stop')

    def pause_talk(self):
        """Drop every audio chunk still waiting in the input queue."""
        # Hold the queue's own lock: clearing the deque unlocked can race with
        # Queue.get (size check then pop), which would raise IndexError in the
        # worker thread instead of queue.Empty.
        with self.queue.mutex:
            self.queue.queue.clear()

    def put_audio_frame(self, audio_chunk): #16khz 20ms pcm
        """Queue one audio chunk for processing by the worker thread."""
        self.queue.put(audio_chunk)

    def get_audio_frame(self):
        """Return the next queued chunk, or silence if none arrives in 10 ms.

        Returns:
            ``(frame, 0)`` for a chunk taken from the queue, or
            ``(np.zeros(self.chunk, dtype=np.float32), 1)`` on timeout.
        """
        try:
            frame = self.queue.get(block=True, timeout=0.01)
            type_ = 0
        except queue.Empty:
            frame = np.zeros(self.chunk, dtype=np.float32)
            type_ = 1

        return frame, type_

    def warm_up(self):
        """Pre-fill the buffer with ``stride_left + stride_right`` context chunks.

        Each chunk is also forwarded via ``human.push_out_put``. Afterwards,
        ``stride_left_size`` outputs are taken back via ``human.get_out_put``.
        """
        # NOTE(review): the worker thread started in __init__ also appends to
        # self.frames; this method is not synchronised with it.
        for _ in range(self.stride_left_size + self.stride_right_size):
            audio_frame, type_ = self.get_audio_frame()
            self.frames.append(audio_frame)
            self._human.push_out_put(audio_frame, type_)
        for _ in range(self.stride_left_size):
            self._human.get_out_put()

    #
    # def get_next_feat(self, block, timeout):
    #     return self.feat_queue.get(block, timeout)