#encoding = utf8
import asyncio
import logging
import time

import edge_tts
import numpy as np
import soundfile
import resampy

from tts.TTSBase import TTSBase

# Module-level logger named after this module (``__name__``); used by EdgeTTS below.
logger = logging.getLogger(__name__)


class EdgeTTS(TTSBase):
    """TTS backend that synthesizes speech with the ``edge_tts`` package.

    Audio bytes received from the service are buffered in ``self._io_stream``,
    decoded with ``soundfile``, reduced to a single channel, resampled to
    ``self._sample_rate`` when needed, and pushed to ``self._human`` in
    chunks of ``self._chunk`` samples.

    NOTE(review): ``_io_stream``, ``_sample_rate``, ``_chunk`` and ``_human``
    are not set in this class; presumably ``TTSBase`` initialises them --
    confirm there.
    """

    DEFAULT_VOICE = 'zh-CN-XiaoyiNeural'

    def __init__(self, human, voice=DEFAULT_VOICE):
        """Create the backend.

        Args:
            human: Consumer object forwarded to ``TTSBase`` (presumably stored
                as ``self._human``; it must provide ``push_audio_chunk``).
            voice: Edge TTS voice name used for every request. Defaults to
                ``'zh-CN-XiaoyiNeural'``.
        """
        # Set before the base initialiser so the attribute exists even if the
        # base class touches it during construction.
        self._voice = voice
        super().__init__(human)

    def _request(self, txt):
        """Synthesize ``txt`` and push the resulting audio to ``self._human``.

        The shared buffer is cleared first so audio from an earlier request is
        never decoded again. The final partial chunk is zero-padded instead of
        dropped, so the end of the utterance is not lost.

        Args:
            txt: Text to synthesize.
        """
        # Start from an empty buffer; writes below append at the current position.
        self._io_stream.seek(0)
        self._io_stream.truncate()

        t = time.time()
        loop = asyncio.new_event_loop()
        try:
            loop.run_until_complete(self.__on_request(self._voice, txt))
        finally:
            # A fresh loop is created per request, so it must be closed here
            # or it leaks.
            loop.run_until_complete(loop.shutdown_asyncgens())
            loop.close()
        logger.info(f'edge tts time:{time.time() - t : 0.4f}s')

        # After the truncate above, the position equals the number of bytes written.
        if self._io_stream.tell() == 0:
            logger.warning('edge tts returned no audio')
            return

        self._io_stream.seek(0)
        stream = self.__create_bytes_stream(self._io_stream)
        self.__push_chunks(stream)

    def __push_chunks(self, stream):
        """Push ``stream`` to ``self._human`` in pieces of ``self._chunk`` samples.

        The last piece is zero-padded to full length so every pushed chunk has
        the same size.

        Args:
            stream: 1-D array of audio samples.
        """
        chunk = self._chunk
        for start in range(0, stream.shape[0], chunk):
            piece = stream[start:start + chunk]
            if piece.shape[0] < chunk:
                piece = np.pad(piece, (0, chunk - piece.shape[0]))
            self._human.push_audio_chunk(piece)

    def __create_bytes_stream(self, io_stream):
        """Decode ``io_stream`` into a single-channel sample array.

        Args:
            io_stream: Seekable file-like object positioned at the start of
                the encoded audio.

        Returns:
            1-D array of samples (cast to float32 after decoding), resampled
            to ``self._sample_rate`` when the decoded rate differs.
        """
        stream, sample_rate = soundfile.read(io_stream)
        logger.info(f'tts audio stream {sample_rate} : {stream.shape}')
        stream = stream.astype(np.float32)

        if stream.ndim > 1:
            logger.warning(f'tts audio has {stream.shape[1]} channels, only use the first')
            # Column 0 is the first channel, as the log message states.
            stream = stream[:, 0]

        # Skip resampling for empty arrays; there is nothing to convert.
        if sample_rate != self._sample_rate and stream.shape[0] > 0:
            logger.warning(f'tts audio sample rate is { sample_rate }, resample to {self._sample_rate}')
            stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self._sample_rate)

        return stream

    async def __on_request(self, voice, txt):
        """Stream synthesized audio for ``txt`` into ``self._io_stream``.

        Only messages whose ``'type'`` is ``'audio'`` are written; every other
        message yielded by ``Communicate.stream()`` is ignored.

        Args:
            voice: Edge TTS voice name.
            txt: Text to synthesize.
        """
        communicate = edge_tts.Communicate(txt, voice)
        async for chunk in communicate.stream():
            if chunk['type'] == 'audio':
                self._io_stream.write(chunk['data'])