#encoding = utf8
import logging
import asyncio
import time

import edge_tts
import numpy as np
import soundfile as sf
import resampy
import queue
from io import BytesIO
from queue import Queue
from threading import Thread, Event


import audio

logger = logging.getLogger(__name__)


class TTSBase:
    """Background text-to-speech worker built on edge-tts.

    Text submitted with :meth:`push_txt` is synthesized on a dedicated thread
    (voice ``zh-CN-XiaoyiNeural``), decoded with soundfile to float32 samples,
    reduced to the first channel, resampled to the human's audio sample rate
    when needed, and handed to ``human.put_audio_frame`` in frames of
    ``sample_rate // fps`` samples.

    ``human`` must provide ``get_audio_sample_rate()``, ``get_fps()`` and
    ``put_audio_frame(frame)``. The worker thread starts in ``__init__``;
    call :meth:`stop` to shut it down.
    """

    def __init__(self, human):
        self._human = human
        self._queue = Queue()
        # Scratch buffer that accumulates the encoded audio of one request.
        self.input_stream = BytesIO()
        # Samples per output frame (e.g. 16000 Hz // 50 fps = 320).
        self.chunk = self._human.get_audio_sample_rate() // self._human.get_fps()

        # NOTE: despite its name, this event is *set* while the worker should
        # keep running and is cleared by stop() to request shutdown.
        self._exit_event = Event()
        self._exit_event.set()
        self._thread = Thread(target=self._on_run)
        self._thread.start()
        logger.info('tts start')

    def _on_run(self):
        """Worker loop: synthesize queued text until the run flag is cleared."""
        logger.info('tts run')
        while self._exit_event.is_set():
            try:
                # Time out periodically so a cleared flag is noticed when idle.
                txt = self._queue.get(block=True, timeout=1)
            except queue.Empty:
                continue
            if txt is None:
                # Wake-up sentinel posted by stop(); just re-check the flag.
                continue
            try:
                self._request(txt)
            except Exception:
                # A single failed request (network error, undecodable audio,
                # ...) must not kill the worker and silently disable TTS.
                logger.exception('tts request failed for text: %r', txt)
        logger.info('tts exit')

    def _request(self, txt):
        """Synthesize ``txt`` and push the resulting audio frames to the human.

        The scratch buffer is always reset afterwards, even on failure, so a
        partial download cannot leak into the next request.
        """
        voice = 'zh-CN-XiaoyiNeural'
        try:
            start = time.perf_counter()
            # asyncio.run creates a fresh event loop and closes it afterwards;
            # a bare new_event_loop() would leak one loop per request.
            asyncio.run(self.__main(voice, txt))
            logger.info('-------edge tts time:%.4fs', time.perf_counter() - start)

            if self.input_stream.tell() == 0:
                # No audio was received; soundfile cannot decode an empty buffer.
                logger.warning('-------edge tts returned no audio')
                return

            self.input_stream.seek(0)
            stream = self.__create_bytes_stream(self.input_stream)
            logger.info('-------tts start push chunk')
            self._push_frames(stream)
            logger.info('-------tts finish push chunk')
        finally:
            self.input_stream.seek(0)
            self.input_stream.truncate()

    def _push_frames(self, stream):
        """Hand ``stream`` to the human as consecutive ``self.chunk``-sample frames.

        A trailing partial frame shorter than ``self.chunk`` samples is dropped.
        """
        chunk = self.chunk
        for start in range(0, len(stream) - chunk + 1, chunk):
            self._human.put_audio_frame(stream[start:start + chunk])

    def __create_bytes_stream(self, byte_stream):
        """Decode ``byte_stream`` into 1-D float32 samples at the human's rate."""
        stream, sample_rate = sf.read(byte_stream)  # float64 by default
        logger.info('tts audio stream %s: %s', sample_rate, stream.shape)
        stream = stream.astype(np.float32)

        if stream.ndim > 1:
            logger.warning('audio has %d channels, only use the first.', stream.shape[1])
            stream = stream[:, 0]

        target_rate = self._human.get_audio_sample_rate()
        if sample_rate != target_rate and stream.shape[0] > 0:
            logger.warning('audio sample rate is %s, resampling into %s.',
                           sample_rate, target_rate)
            stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=target_rate)

        return stream

    async def __main(self, voicename: str, text: str):
        """Stream synthesized audio for ``text`` into ``self.input_stream``."""
        communicate = edge_tts.Communicate(text, voicename)
        async for chunk in communicate.stream():
            if chunk["type"] == "audio":
                self.input_stream.write(chunk["data"])
            # Other chunk types (e.g. "WordBoundary") are ignored.

    def stop(self):
        """Stop the worker thread, wait for it to exit and reset the buffer."""
        if self._exit_event is not None:
            self._exit_event.clear()
            # Wake the worker immediately instead of waiting for its get() timeout.
            self._queue.put(None)
            self._thread.join()
            logger.info('tts stop')
        # Reset only after join() so the buffer is never truncated while the
        # worker is still writing to or reading from it.
        self.input_stream.seek(0)
        self.input_stream.truncate()

    def pause_talk(self):
        """Drop all pending text; a request already in progress is not interrupted."""
        self.clear()

    def clear(self):
        """Discard all queued text without touching the queue's internals."""
        # get_nowait() goes through the queue's lock; mutating Queue.queue
        # directly would race with the worker thread.
        while True:
            try:
                self._queue.get_nowait()
            except queue.Empty:
                break

    def push_txt(self, txt):
        """Queue ``txt`` for synthesis on the worker thread."""
        self._queue.put(txt)