#encoding = utf8

import os
import sys
import time

try:
    import sounddevice as sd
except ImportError as e:
    print("Please install sounddevice first. You can use")
    print()
    print("  pip install sounddevice")
    print()
    print("to install it")
    sys.exit(-1)

import sherpa_ncnn


from asr.asr_base import AsrBase


class SherpaNcnnAsr(AsrBase):
    """Streaming speech recognizer backed by a sherpa-ncnn zipformer model.

    Audio is read as mono float32 from a sounddevice input stream and fed to
    the recognizer. Results are reported through the AsrBase hooks:
    ``_notify_process`` receives every new non-empty partial transcript, and
    ``_notify_complete`` receives the transcript of each non-empty segment
    closed by endpoint detection.

    Attributes read from AsrBase (defined outside this file): ``_sample_rate``,
    ``_samples_per_read``, ``_stop_event``, ``_notify_process`` and
    ``_notify_complete``. Presumably AsrBase also drives ``_recognize_loop``
    (e.g. on a worker thread) -- TODO confirm in asr/asr_base.py.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): if AsrBase.__init__ already starts the recognition loop,
        # _recognize_loop could touch self._recognizer before it is assigned
        # below; confirm the base-class start-up order.
        self._recognizer = self._create_recognizer()

    def _create_recognizer(self):
        """Build the sherpa-ncnn recognizer from the bundled model files.

        The model directory is resolved relative to the current working
        directory (``<cwd>/../data/asr/sherpa-ncnn/<model>``), so the process
        must be started from the directory this layout expects.

        Returns:
            A configured ``sherpa_ncnn.Recognizer`` with endpoint detection
            enabled and modified beam search decoding.
        """
        base_path = os.path.join(os.getcwd(), '..', 'data', 'asr', 'sherpa-ncnn',
                                 'sherpa-ncnn-streaming-zipformer-zh-14M-2023-02-23')

        def model_file(name):
            # Join with os.path.join instead of '+ "/"' string concatenation.
            return os.path.join(base_path, name)

        recognizer = sherpa_ncnn.Recognizer(
            tokens=model_file('tokens.txt'),
            encoder_param=model_file('encoder_jit_trace-pnnx.ncnn.param'),
            encoder_bin=model_file('encoder_jit_trace-pnnx.ncnn.bin'),
            decoder_param=model_file('decoder_jit_trace-pnnx.ncnn.param'),
            decoder_bin=model_file('decoder_jit_trace-pnnx.ncnn.bin'),
            joiner_param=model_file('joiner_jit_trace-pnnx.ncnn.param'),
            joiner_bin=model_file('joiner_jit_trace-pnnx.ncnn.bin'),
            num_threads=4,
            decoding_method="modified_beam_search",
            enable_endpoint_detection=True,
            rule1_min_trailing_silence=2.4,
            rule2_min_trailing_silence=1.2,
            rule3_min_utterance_length=300,
            hotwords_file="",
            hotwords_score=1.5,
        )
        return recognizer

    def _recognize_loop(self):
        """Stream microphone audio through the recognizer until stopped.

        Each iteration does a blocking read of ``_samples_per_read`` frames.
        A changed, non-empty partial transcript is printed in place and passed
        to ``_notify_process``. When the recognizer reports an endpoint, a
        non-empty transcript is printed on its own line and passed to
        ``_notify_complete``, and the recognizer is reset for the next segment.
        The loop exits once ``_stop_event`` is set (checked between reads).
        """
        segment_id = 0
        last_result = ""
        with sd.InputStream(channels=1, dtype="float32", samplerate=self._sample_rate) as s:
            while not self._stop_event.is_set():
                samples, _ = s.read(self._samples_per_read)  # a blocking read
                # Flatten the single-channel block to the 1-D waveform the recognizer takes.
                samples = samples.reshape(-1)
                self._recognizer.accept_waveform(self._sample_rate, samples)

                is_endpoint = self._recognizer.is_endpoint

                result = self._recognizer.text
                if result and (last_result != result):
                    last_result = result
                    print("\r{}:{}".format(segment_id, result), end=".", flush=True)
                    self._notify_process(result)

                if is_endpoint:
                    if result:
                        print("\r{}:{}".format(segment_id, result), flush=True)
                        self._notify_complete(result)
                        segment_id += 1
                    self._recognizer.reset()
                    # Start the next segment with a clean slate. Otherwise a
                    # partial identical to the previous segment's text would
                    # be treated as "unchanged" and never reported.
                    last_result = ""

def main(duration=20):
    """Run the recognizer for ``duration`` seconds, then stop it.

    Recognition output is produced by SherpaNcnnAsr itself; this function only
    controls its lifetime. ``asr.stop()`` runs in a ``finally`` clause, so an
    interruption during the wait (e.g. Ctrl+C) still shuts the recognizer down
    before the exception propagates to the caller.

    Args:
        duration: seconds to listen before stopping (default 20).
    """
    print("Started! Please speak")
    asr = SherpaNcnnAsr()
    try:
        time.sleep(duration)
        print("Stop! ")
    finally:
        asr.stop()

# Standalone demo: list audio devices, report the default input device, then
# run main(). The __name__ guard already keeps this from running on import,
# so the block does not need to be disabled by wrapping it in a string.
if __name__ == "__main__":
    devices = sd.query_devices()
    print(devices)
    default_input_device_idx = sd.default.device[0]
    print(f'Use default device: {devices[default_input_device_idx]["name"]}')

    try:
        main()
    except KeyboardInterrupt:
        print("\nCaught Ctrl + C. Exiting")