#encoding = utf8
import logging
import time
from io import BytesIO

import aiohttp
import numpy as np
import requests
import soundfile as sf
import edge_tts
import resampy

from .tts_base import TTSBase

# Module-level logger named after this module; used throughout TTSEdgeHttp.
logger = logging.getLogger(__name__)


class TTSEdgeHttp(TTSBase):
    """TTS backend that fetches synthesized audio from an HTTP speech endpoint.

    Text is POSTed as JSON to ``self._url``. The returned audio bytes are
    decoded with ``soundfile``, reduced to the first channel, resampled to
    ``handle.sample_rate`` when needed, and forwarded to ``handle.on_handle``.
    """

    def __init__(self, handle, voice='zh-CN-XiaoyiNeural'):
        """Create the backend.

        Args:
            handle: Downstream consumer. It must expose ``sample_rate`` and
                ``on_handle(data, index)``; both are used by this class.
            voice: Voice name stored on the instance.
                NOTE(review): ``_on_request`` currently hard-codes
                ``"voice": "alloy"`` in the payload, so this value is not sent
                to the server. Confirm whether that is intended.
        """
        super().__init__(handle)
        self._voice = voice
        # self._url = 'http://localhost:8082/v1/audio/speech'
        self._url = 'https://tts.mzzsfy.eu.org/v1/audio/speech'
        logger.info(f"TTSEdge init, {voice}")
        # In-flight synchronous responses, tracked so on_clear() can close them.
        self._response_list = []

    async def _on_async_request(self, data):
        """Asynchronously POST ``data`` and return the audio body.

        This is a coroutine and must be awaited. aiohttp's session and request
        objects only support ``async with``.

        Returns:
            ``BytesIO`` with the response body on HTTP 200, otherwise ``None``.
            This matches ``_on_sync_request``.
        """
        async with aiohttp.ClientSession() as session:
            async with session.post(self._url, json=data) as response:
                logger.info(f'TTSEdgeHttp, _on_request, response status:{response.status}')
                if response.status != 200:
                    logger.warning(f'TTSEdgeHttp, request failed, status:{response.status}')
                    return None
                return BytesIO(await response.read())

    def _on_sync_request(self, data):
        """POST ``data`` to the speech endpoint and return the audio body.

        While the body is being read, the response is listed in
        ``self._response_list`` so that ``on_clear`` can close it. It is
        always removed again, even if reading the body raises.

        Returns:
            ``BytesIO`` with the response body on HTTP 200, otherwise ``None``.
        """
        # NOTE(review): no timeout is passed, so a stalled server blocks here
        # indefinitely.
        response = requests.post(self._url, json=data)
        self._response_list.append(response)
        try:
            if response.status_code != 200:
                logger.warning(f'TTSEdgeHttp, request failed, status:{response.status_code}')
                return None
            return BytesIO(response.content)
        finally:
            self._response_list.remove(response)

    def _on_request(self, txt: str):
        """Synthesize ``txt`` and return the raw audio as ``BytesIO``, or ``None`` on failure."""
        logger.info(f'TTSEdgeHttp, _on_request, txt:{txt}')
        data = {
            "model": "tts-1",
            "input": txt,
            "voice": "alloy",
            "speed": 1.0,
            "thread": 10
        }

        # _on_async_request is a coroutine and needs an async caller;
        # this method is synchronous, so the blocking path is used.
        return self._on_sync_request(data)

    def _on_handle(self, stream, index):
        """Decode one synthesized clip and forward it to the handle.

        Args:
            stream: ``(audio, txt)`` tuple. ``audio`` is what ``_on_request``
                returned: a ``BytesIO``, or ``None`` when the request failed.
                Presumably packed by ``TTSBase``; confirm against the base
                class.
            index: Passed through unchanged to ``handle.on_handle``.

        On a missing buffer or any processing error,
        ``handle.on_handle(None, index)`` is called instead. The audio buffer
        is always closed afterwards.
        """
        st, txt = stream
        if st is None:
            # The request failed upstream; report it without touching the buffer.
            logger.error(f'-------tts finish error: no audio, txt:{txt}')
            self._handle.on_handle(None, index)
            return

        try:
            st.seek(0)
            t = time.perf_counter()
            byte_stream = self.__create_bytes_stream(st)
            logger.info(f'-------tts resample time:{time.perf_counter() - t:.4f}s, txt:{txt}')
            t = time.perf_counter()
            self._handle.on_handle((byte_stream, txt), index)
            logger.info(f'-------tts handle time:{time.perf_counter() - t:.4f}s')
        except Exception as e:
            self._handle.on_handle(None, index)
            logger.exception(f'-------tts finish error:{e}')
        finally:
            st.close()

    def __create_bytes_stream(self, byte_stream):
        """Decode ``byte_stream`` into float32 samples at the handle's rate.

        Only the first channel of multi-channel audio is kept. Audio is
        resampled with ``resampy`` when its rate differs from
        ``self._handle.sample_rate`` and it is non-empty.

        Returns:
            1-D ``numpy`` array of samples.
        """
        stream, sample_rate = sf.read(byte_stream)  # [T*sample_rate,] float64
        logger.info(f'tts audio stream {sample_rate}: {stream.shape}')
        stream = stream.astype(np.float32)

        if stream.ndim > 1:
            logger.warning(f'audio has {stream.shape[1]} channels, only use the first.')
            stream = stream[:, 0]

        if sample_rate != self._handle.sample_rate and stream.shape[0] > 0:
            logger.warning(f'audio sample rate is {sample_rate}, resampling into {self._handle.sample_rate}.')
            stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self._handle.sample_rate)

        return stream

    def _on_close(self):
        """Log shutdown; this backend holds no persistent resources."""
        logger.info('TTSEdge close')

    def on_clear(self):
        """Close any in-flight HTTP responses, then delegate to ``TTSBase.on_clear``."""
        logger.info('TTSEdgeHttp clear_cache')
        # Iterate over a snapshot: _on_sync_request removes entries from this
        # list, possibly from another thread while this loop runs.
        for response in list(self._response_list):
            response.close()
        super().on_clear()