# -*- coding: utf-8 -*-

import logging
import queue
import time
from queue import Queue
from threading import Thread, Event

import numpy as np
import audio


class Chunk2Mal:
    """Background worker that turns streamed audio chunks into mel windows.

    Audio chunks are fed in with :meth:`push_chunk`. A worker thread forwards
    each chunk to ``human.push_audio_frames(chunk, 0)`` and buffers it. Every
    ``_CHUNKS_PER_BATCH`` chunks, the buffered audio is concatenated and run
    through ``audio.melspectrogram``. The result is sliced into windows of
    ``_MEL_STEP_SIZE`` frames, which are pushed to
    ``human.push_mel_chunks_queue``.
    """

    # Number of buffered chunks that triggers one mel computation.
    _CHUNKS_PER_BATCH = 10
    # Width (in mel frames) of every window pushed to the human.
    _MEL_STEP_SIZE = 16
    # Mel frames per second of audio; assumes audio.melspectrogram produces
    # 80 frames/s (e.g. hop 200 at 16 kHz) -- TODO confirm against `audio`.
    _MEL_FRAMES_PER_SECOND = 80.0
    # How long the worker blocks on the queue before re-checking for stop().
    _POLL_TIMEOUT = 0.5

    def __init__(self, human):
        """Create the worker and start its thread immediately.

        :param human: object providing get_audio_sample_rate(), get_fps(),
            push_audio_frames(chunk, kind) and push_mel_chunks_queue(mel).
        """
        self._audio_chunk_queue = Queue()
        self._human = human
        self._chunks = []
        # One chunk per video frame: sample_rate // fps samples
        # (e.g. 16000 // 50 = 320 samples, i.e. 20 ms).
        self._chunk_len = self._human.get_audio_sample_rate() // self._human.get_fps()

        # Despite its name, this event is a *run* flag: the worker loops
        # while it is set, and stop() clears it.
        self._exit_event = Event()
        self._exit_event.set()
        self._thread = Thread(target=self._on_run)
        self._thread.start()
        logging.info('chunk2mal start')

    def _on_run(self):
        """Worker loop: collect chunks and convert each full batch to mel windows."""
        logging.info('chunk2mal run')
        while self._exit_event.is_set():
            try:
                # Blocking with a timeout replaces the old empty()/sleep() poll
                # while still letting stop() take effect promptly.
                chunk = self._audio_chunk_queue.get(block=True, timeout=self._POLL_TIMEOUT)
            except queue.Empty:
                continue
            self._chunks.append(chunk)
            self._human.push_audio_frames(chunk, 0)
            if len(self._chunks) >= self._CHUNKS_PER_BATCH:
                self._process_batch()
        logging.info('chunk2mal exit')

    def _process_batch(self):
        """Compute the mel of the buffered chunks, push its windows, reset the buffer."""
        logging.debug('np.concatenate')
        inputs = np.concatenate(self._chunks)  # [N * chunk]
        # Reset the buffer so each chunk is processed exactly once. Previously
        # the list grew forever and every new chunk re-pushed all old windows.
        self._chunks = []

        mel = audio.melspectrogram(inputs)
        if np.isnan(mel).any():
            # Raising here would kill the worker thread for good; drop the
            # bad batch and keep serving subsequent audio instead.
            logging.error('Mel contains nan! Using a TTS voice? Add a small epsilon noise '
                          'to the wav file and try again; dropping this batch')
            return
        self._push_mel_windows(mel)

    def _push_mel_windows(self, mel):
        """Slice ``mel`` (2-D, frames on axis 1) into windows and push each one.

        Window ``i`` starts at ``int(i * frames_per_second / fps)``. Once a
        window would run past the end, a final window aligned to the last
        frame is pushed and slicing stops. If ``mel`` has fewer than
        ``_MEL_STEP_SIZE`` frames, that final window is the whole of ``mel``.
        """
        n_frames = mel.shape[1]
        step = self._MEL_STEP_SIZE
        fps = self._human.get_fps()
        mel_idx_multiplier = self._MEL_FRAMES_PER_SECOND / fps
        logging.debug('fps: %s, mel_idx_multiplier: %s', fps, mel_idx_multiplier)

        i = 0
        while True:
            start_idx = int(i * mel_idx_multiplier)
            if start_idx + step > n_frames:
                # Clamp at 0: a negative start would wrap around and yield a
                # truncated window when mel is shorter than `step`.
                self._human.push_mel_chunks_queue(mel[:, max(0, n_frames - step):])
                return
            self._human.push_mel_chunks_queue(mel[:, start_idx:start_idx + step])
            i += 1

    def stop(self):
        """Signal the worker to finish and wait for it to exit."""
        if self._thread is None:
            return

        self._exit_event.clear()
        if self._thread.is_alive():
            self._thread.join()
        logging.info('chunk2mal stop')

    def push_chunk(self, chunk):
        """Enqueue one audio chunk for the worker (or for pull_chunk)."""
        self._audio_chunk_queue.put(chunk)

    def pull_chunk(self, timeout=1):
        """Take the next queued chunk, or a silent chunk if none arrives in time.

        NOTE(review): this reads the same queue the worker thread consumes, so
        a chunk returned here is never seen by the mel pipeline (and vice
        versa) -- confirm this sharing is intended.

        :param timeout: seconds to wait for a chunk (default 1, as before).
        :returns: ``(chunk, 1)`` for a queued chunk, or ``(zeros, 0)`` where
            ``zeros`` is a float32 array of ``self._chunk_len`` samples.
        """
        try:
            chunk = self._audio_chunk_queue.get(block=True, timeout=timeout)
        except queue.Empty:
            return np.zeros(self._chunk_len, dtype=np.float32), 0
        return chunk, 1