#encoding = utf8
import logging
from io import BytesIO

import numpy as np
import soundfile as sf
import edge_tts
import resampy

from .tts_base import TTSBase

logger = logging.getLogger(__name__)


class TTSEdge(TTSBase):
    """Text-to-speech backend using ``edge_tts`` (Microsoft Edge online TTS).

    Each request's audio chunks are buffered in memory and decoded with
    ``soundfile``. The audio is converted to float32, keeps only the first
    channel, and is resampled to ``handle.sample_rate`` when needed. The result
    is passed to ``handle.on_handle(audio, index)``.
    """

    def __init__(self, handle, voice='zh-CN-XiaoyiNeural'):
        """Create the backend.

        Args:
            handle: Consumer passed to ``TTSBase``. This class reads its
                ``sample_rate`` attribute and calls its
                ``on_handle(audio, index)`` method.
            voice: Edge TTS voice name used for every request.
        """
        super().__init__(handle)
        self._voice = voice
        logger.info(f"TTSEdge init, {voice}")

    async def _on_request(self, txt: str) -> BytesIO:
        """Synthesise ``txt`` and return the encoded audio as an in-memory stream.

        Args:
            txt: Text to speak.

        Returns:
            A ``BytesIO`` holding the concatenated ``"audio"`` chunk payloads
            produced by ``edge_tts``. The stream position is left at the end,
            so callers must rewind before reading (``_on_handle`` does).
        """
        logger.info('_on_request, txt: %s', txt)
        communicate = edge_tts.Communicate(txt, self._voice)
        byte_stream = BytesIO()
        # Only audio payloads are kept; metadata chunks (e.g. "WordBoundary")
        # are ignored. The async-for drives the generator to exhaustion, so
        # there is no separate generator to close afterwards.
        async for chunk in communicate.stream():
            if chunk["type"] == "audio":
                byte_stream.write(chunk["data"])
        return byte_stream

    async def _on_handle(self, stream, index):
        """Decode buffered audio and forward it to the handle.

        If decoding or forwarding raises, ``on_handle(None, index)`` is
        called instead, so the consumer still gets a result for ``index``.
        The input stream is always closed, even if that fallback call raises.

        Args:
            stream: Seekable binary stream, e.g. the one returned by
                ``_on_request``.
            index: Identifier forwarded unchanged to ``on_handle``.
        """
        try:
            stream.seek(0)
            audio = self.__create_bytes_stream(stream)
            logger.info('-------tts start push chunk')
            self._handle.on_handle(audio, index)
            logger.info('-------tts finish push chunk')
        except Exception:
            # Keep the traceback; the consumer is notified with None below.
            logger.exception('-------tts finish error')
            self._handle.on_handle(None, index)
        finally:
            # close() releases the BytesIO buffer; truncating first is unnecessary.
            stream.close()

    def __create_bytes_stream(self, byte_stream):
        """Decode an encoded audio stream into a 1-D float32 array.

        Args:
            byte_stream: Readable binary stream in a format ``soundfile``
                can decode.

        Returns:
            ``np.ndarray`` of float32 samples. Only the first channel is kept
            for multi-channel input, and non-empty audio is resampled to
            ``self._handle.sample_rate`` when the source rate differs.
        """
        stream, sample_rate = sf.read(byte_stream)  # float64 by default
        logger.info('[INFO]tts audio stream %s: %s', sample_rate, stream.shape)
        stream = stream.astype(np.float32)

        if stream.ndim > 1:
            logger.warning('[WARN] audio has %s channels, only use the first.',
                           stream.shape[1])
            stream = stream[:, 0]

        target_rate = self._handle.sample_rate
        # Skip resampling for empty audio; there is nothing to convert.
        if sample_rate != target_rate and stream.shape[0] > 0:
            logger.warning('[WARN] audio sample rate is %s, resampling into %s.',
                           sample_rate, target_rate)
            stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=target_rate)

        return stream

    async def _on_close(self):
        """Close hook (presumably invoked by ``TTSBase`` on shutdown).

        There is nothing to release here: each request's buffer is closed in
        ``_on_handle``.
        """
        logger.info('TTSEdge close')