#encoding = utf8
import heapq
import logging
import os
import shutil
from threading import Lock

from eventbus import EventBus
from utils import save_wav
from human_handler import AudioHandler

logger = logging.getLogger(__name__)


class TTSAudioHandle(AudioHandler):
    """Common base for handlers that consume synthesized (TTS) audio.

    On construction it subscribes ``_on_stop`` to the ``'stop'`` event and
    ``on_clear_cache`` to the ``'clear_cache'`` event of ``EventBus``; both
    subscriptions are removed again in ``__del__``.  It also hands out
    increasing indices via ``get_index``; ``on_clear_cache`` resets that
    counter so the next index handed out is 0 again.
    """

    def __init__(self, context, handler):
        super().__init__(context, handler)
        self._index = -1            # last value returned by get_index()
        self._sample_rate = 16000   # default rate; subclasses may overwrite it

        EventBus().register('stop', self._on_stop)
        EventBus().register('clear_cache', self.on_clear_cache)

    def __del__(self):
        # NOTE(review): if EventBus keeps strong references to the bound
        # methods registered in __init__, this finalizer may never run while
        # the bus is alive — confirm against EventBus.
        EventBus().unregister('stop', self._on_stop)
        EventBus().unregister('clear_cache', self.on_clear_cache)

    def _on_stop(self, *args, **kwargs):
        """'stop' event callback: delegate to ``stop()``, ignoring event arguments."""
        self.stop()

    def on_clear_cache(self, *args, **kwargs):
        """'clear_cache' event callback: restart index numbering from 0."""
        self._index = -1

    @property
    def sample_rate(self):
        """Sample rate (Hz) this handler works with."""
        return self._sample_rate

    @sample_rate.setter
    def sample_rate(self, value):
        self._sample_rate = value

    def get_index(self):
        """Advance the counter and return the new index (0 right after a reset)."""
        self._index += 1
        return self._index

    def on_handle(self, stream, index):
        """Consume one piece of TTS output; no-op here, overridden by subclasses."""

    def stop(self):
        """Stop processing; no-op here, overridden by subclasses."""

    def pause_talk(self):
        """Pause speaking; no-op here, overridden by subclasses."""


class TTSAudioSplitHandle(TTSAudioHandle):
    """Split TTS audio into per-frame chunks and forward them in sentence order.

    Each ``on_handle`` call carries the ``index`` of its sentence, and calls
    may arrive out of order.  Results are buffered in a min-heap keyed by
    ``index`` and released strictly in order, starting at ``self._current``.
    A released sentence is forwarded as consecutive slices of
    ``sample_rate // fps`` samples via ``on_next_handle((chunk, txt), 0)``.
    """

    def __init__(self, context, handler):
        super().__init__(context, handler)
        self.sample_rate = self._context.sample_rate
        # Samples per chunk; a trailing remainder shorter than this is dropped.
        self._chunk = self.sample_rate // self._context.fps
        # Heap of (index, seq, payload). The monotonically increasing seq
        # breaks ties on equal indices so payloads are never compared.
        self._priority_queue = []
        self._seq = 0
        self._lock = Lock()
        # Index of the next sentence to release downstream.
        self._current = 0
        # True while one thread is forwarding released sentences; keeps the
        # chunks of different sentences from interleaving.
        self._draining = False
        self._is_running = True
        logger.info("TTSAudioSplitHandle init")

    def on_handle(self, stream, index):
        """Buffer one sentence and forward every sentence that is now in order.

        Args:
            stream: ``(s, txt)`` tuple. ``s`` is sliced along its first axis
                (``s.shape[0]`` samples) or is ``None`` when the sentence has
                no audio; a ``None`` still consumes its index so later
                sentences are not blocked.
            index: Position of this sentence; sentences are released in
                ascending index order starting at ``self._current``.
        """
        if not self._is_running:
            logger.info('TTSAudioSplitHandle::on_handle is not running')
            return

        s, txt = stream
        payload = None
        if s is not None:
            chunks = self._split_into_chunks(s)
            if not self._is_running:
                return
            payload = (chunks, txt)

        with self._lock:
            heapq.heappush(self._priority_queue, (index, self._seq, payload))
            self._seq += 1
            logger.info('TTSAudioSplitHandle::on_handle %s, %s, %s',
                        index, self._current, len(self._priority_queue))
            if self._draining:
                # Another thread is forwarding and will pick this entry up.
                return
            self._draining = True

        self._drain_ready()

    def _split_into_chunks(self, s):
        """Return ``s`` cut into consecutive ``self._chunk``-sample slices (tail dropped)."""
        chunks = []
        idx = 0
        stream_len = s.shape[0]
        while stream_len >= self._chunk and self._is_running:
            chunks.append(s[idx:idx + self._chunk])
            stream_len -= self._chunk
            idx += self._chunk
        return chunks

    def _pop_ready_locked(self):
        """Pop payloads of consecutive sentences starting at ``self._current``.

        Caller must hold ``self._lock``. Entries whose index is already behind
        ``self._current`` (duplicates, or results that arrived after a reset)
        are discarded so they cannot block the head of the heap.
        """
        queue = self._priority_queue
        ready = []
        while queue:
            head = queue[0][0]
            if head < self._current:
                heapq.heappop(queue)
                logger.warning('TTSAudioSplitHandle drop stale index %s (current %s)',
                               head, self._current)
            elif head == self._current:
                ready.append(heapq.heappop(queue)[2])
                self._current += 1
            else:
                break
        return ready

    def _drain_ready(self):
        """Forward buffered sentences in order until the next index is missing.

        Only the thread that set ``self._draining`` may call this. The flag
        is cleared in the same critical section that finds nothing ready, so
        an entry pushed concurrently is never left stranded.
        """
        try:
            while True:
                with self._lock:
                    ready = self._pop_ready_locked()
                    if not ready:
                        self._draining = False
                        return
                # Forward outside the lock so a slow downstream handler blocks
                # neither producers nor on_clear_cache.
                for payload in ready:
                    self._forward(payload)
        except BaseException:
            with self._lock:
                self._draining = False
            raise

    def _forward(self, payload):
        """Send one sentence's chunks downstream; a ``None`` payload sends nothing."""
        if payload is None:
            return
        chunks, txt = payload
        for chunk in chunks:
            logger.info('TTSAudioSplitHandle::on_handle push')
            self.on_next_handle((chunk, txt), 0)
            logger.info('TTSAudioSplitHandle::on_handle push finish')

    def stop(self):
        """Stop accepting and splitting new audio."""
        self._is_running = False

    def on_clear_cache(self, *args, **kwargs):
        """Drop all buffered sentences and expect index 0 next.

        The reset is unconditional: even with an empty buffer ``_current``
        must return to 0, because index numbering restarts after a clear.
        """
        super().on_clear_cache()
        with self._lock:
            logger.info('TTSAudioSplitHandle::on_clear_cache %s', self._current)
            self._priority_queue.clear()
            self._current = 0


class TTSAudioSaveHandle(TTSAudioHandle):
    """Write each received stream to ``<save dir>/<index>.wav``.

    The save directory is emptied once at construction time.
    """

    def __init__(self, context, handler):
        super().__init__(context, handler)
        # Relative path, resolved against the process working directory.
        self._save_path_dir = '../temp/audio/'
        self._clean()

    def _clean(self):
        """Empty the save directory: delete its files and remove its sub-directories.

        If the directory does not exist, only a message is printed.
        """
        directory = self._save_path_dir
        if os.path.exists(directory):
            for name in os.listdir(directory):
                self._remove_entry(os.path.join(directory, name))
        else:
            print(f"The directory {directory} does not exist.")

    @staticmethod
    def _remove_entry(path):
        """Delete a single file, or a whole directory tree, at ``path``."""
        if os.path.isfile(path):
            os.remove(path)
            print(f"Deleted file: {path}")
        elif os.path.isdir(path):
            # Remove the folder together with everything inside it.
            shutil.rmtree(path)
            print(f"Deleted directory and its contents: {path}")

    def on_handle(self, stream, index):
        """Save ``stream`` to ``<save dir>/<index>.wav`` at ``self.sample_rate``."""
        # NOTE(review): ``stream`` is handed to save_wav unchanged, while
        # TTSAudioSplitHandle receives an ``(audio, txt)`` tuple here —
        # confirm which shape callers send to this handler.
        save_wav(stream, f'{self._save_path_dir}{index!s}.wav', self.sample_rate)

    def stop(self):
        """Nothing to stop: each stream is saved synchronously."""