# -*- coding: utf-8 -*-
import time

import librosa
import numpy as np
import requests
import resampy
import soundfile as sf
from io import BytesIO


def download_tts(url, timeout=60):
    """Download an MP3 from *url* and convert it to a WAV file.

    The MP3 is saved to ``./audio/mp3/<file_name>`` and then exported as
    ``./audio/wav/<file_name>.wav``, where ``file_name`` is *url* with its
    first three characters removed.

    Args:
        url: Address of the MP3 to fetch.
        timeout: Seconds passed to ``requests.get`` as its timeout, so a
            stalled server cannot block this call indefinitely.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status. In
            that case nothing is written to disk.
    """
    # NOTE(review): url[3:] keeps any "/" characters from the URL, which would
    # be treated as sub-directories under ./audio/mp3/ -- confirm the expected
    # url format with callers.
    file_name = url[3:]
    print(file_name)
    print('download tts', url)
    resp = requests.get(url, timeout=timeout)
    # Fail before touching disk so an error body is never saved as an .mp3.
    resp.raise_for_status()

    mp3_path = './audio/mp3/' + file_name
    with open(mp3_path, 'wb') as mp3:
        mp3.write(resp.content)

    from pydub import AudioSegment
    sound = AudioSegment.from_mp3(mp3_path)
    sound.export('./audio/wav/' + file_name + '.wav', format="wav")


def __create_bytes_stream(byte_stream):
    """Decode an in-memory audio file into a mono float32 waveform at 16 kHz.

    Args:
        byte_stream: File-like object (e.g. ``BytesIO``) containing audio data
            readable by ``soundfile``.

    Returns:
        A 1-D ``np.float32`` array. Multi-channel input is reduced to its first
        channel. Non-empty audio whose rate is not 16000 Hz is resampled to
        16000 Hz. Empty audio is returned without resampling.
    """
    samples, sr = sf.read(byte_stream)  # float64 samples by default
    print(f'[INFO]tts audio stream {sr}: {samples.shape}')
    samples = samples.astype(np.float32)

    # Keep channel 0 only (no down-mix of the other channels).
    if samples.ndim > 1:
        print(f'[WARN] audio has {samples.shape[1]} channels, only use the first.')
        samples = samples[:, 0]

    # Skip resampling when already at the target rate or when there is no audio.
    needs_resample = sr != 16000 and samples.shape[0] > 0
    if not needs_resample:
        return samples

    print(f'[WARN] audio sample rate is {sr}, resampling into {16000}.')
    return resampy.resample(x=samples, sr_orig=sr, sr_new=16000)


def test_async_tts(url, content):
    """POST *content* to a TTS endpoint with aiohttp and save the returned audio.

    The JSON request asks for model ``tts-1``, voice ``alloy``, speed 1.0.
    On HTTP 200 the response body is decoded to a mono 16 kHz waveform and
    written to ``output_audio.wav`` in the current directory. Otherwise the
    status code and response text are printed.

    Args:
        url: Endpoint URL to POST to.
        content: Text to synthesize.
    """
    import aiohttp
    import asyncio

    payload = {
        "model": "tts-1",
        "input": content,
        "voice": "alloy",
        "speed": 1.0
    }

    async def fetch_audio():
        async with aiohttp.ClientSession() as session:
            async with session.post(url, json=payload) as response:
                if response.status != 200:
                    print("Error:", response.status, await response.text())
                    return

                raw = await response.read()
                audio_stream = __create_bytes_stream(BytesIO(raw))

                # Save as a new audio file.
                sf.write("output_audio.wav", audio_stream, 16000)
                print("Audio data received and saved to output_audio.wav")

    # Drive the coroutine to completion on a fresh event loop.
    asyncio.run(fetch_audio())


def test_sync_tts(url, content):
    """POST *content* to a TTS endpoint with requests and save the returned audio.

    This is the blocking counterpart of ``test_async_tts`` and sends the same
    JSON request (model ``tts-1``, voice ``alloy``, speed 1.0). On HTTP 200
    the audio is decoded to mono 16 kHz and written to ``output_audio.wav``.
    Otherwise the status code and response text are printed.

    Args:
        url: Endpoint URL to POST to.
        content: Text to synthesize.
    """
    response = requests.post(url, json={
        "model": "tts-1",
        "input": content,
        "voice": "alloy",
        "speed": 1.0
    })
    if response.status_code != 200:
        print("Error:", response.status_code, response.text)
        return

    decoded = __create_bytes_stream(BytesIO(response.content))

    # Save as a new audio file.
    sf.write("output_audio.wav", decoded, 16000)
    print("Audio data received and saved to output_audio.wav")


def main():
    """Run the synchronous TTS test against the hosted speech endpoint."""
    # Local alternative endpoint: "http://localhost:8082/v1/audio/speech"
    endpoint = "https://tts.mzzsfy.eu.org/v1/audio/speech"
    # Sample text (Chinese): "Wrote a high-performance TTS (text-to-speech)
    # tool; 5,000 characters in only 5 seconds; free to use."
    text = "写了一个高性能tts(文本转声音)工具,5千字仅需5秒,免费使用"
    # To exercise the aiohttp path instead, call test_async_tts(endpoint, text).
    test_sync_tts(endpoint, text)


if __name__ == "__main__":
    # Script entry point: run main() and report how long the TTS round-trip took.
    import sys
    import traceback

    # perf_counter is monotonic, so the elapsed time is immune to wall-clock changes.
    start = time.perf_counter()
    try:
        main()
        print(f'-------tts time:{time.perf_counter() - start:.4f}s')
    except KeyboardInterrupt:
        print("\nCaught Ctrl + C. Exiting")
    except Exception:
        # Show the full traceback (not just str(e)) and signal failure to the shell.
        traceback.print_exc()
        sys.exit(1)