#encoding = utf8
import logging
import os
import sys
import time

try:
    import sounddevice as sd
except ImportError:
    # sounddevice is required for microphone capture. The install hint is an
    # error diagnostic, so it goes to stderr rather than stdout; the process
    # is then aborted (exit status -1, i.e. 255 on POSIX shells).
    print(
        "Please install sounddevice first. You can use\n"
        "\n"
        "  pip install sounddevice\n"
        "\n"
        "to install it",
        file=sys.stderr,
    )
    sys.exit(-1)

import sherpa_ncnn

from asr.asr_base import AsrBase
logger = logging.getLogger(name=__name__)

# Absolute directory holding this module; model files are resolved relative
# to it so loading does not depend on the process's working directory.
current_file_path = os.path.abspath(os.path.dirname(__file__))


class SherpaNcnnAsr(AsrBase):
    """Streaming Chinese ASR backend built on sherpa-ncnn.

    The ``sherpa-ncnn-streaming-zipformer-zh-14M-2023-02-23`` model is loaded
    from ``../data/asr/sherpa-ncnn`` (relative to this module) at construction
    time. ``_stop_event``, ``_notify_complete`` and ``_hot_words_file`` are
    not defined in this class and are expected to be provided by ``AsrBase``.
    """

    def __init__(self):
        super().__init__()
        self._recognizer = self._create_recognizer()
        logger.info('SherpaNcnnAsr init')

    def __del__(self):
        # The previous implementation called ``self.__del__()`` here, which
        # recursed until RecursionError. Delegate to the base-class finalizer
        # instead, but only if one exists: ``object`` itself defines none.
        base_del = getattr(super(), '__del__', None)
        if base_del is not None:
            base_del()
        logger.info('SherpaNcnnAsr del')

    def _create_recognizer(self):
        """Build and return the streaming ``sherpa_ncnn.Recognizer``.

        Uses 4 threads and modified beam search, enables endpoint detection,
        and biases decoding with the hotwords listed in
        ``self._hot_words_file``.
        """
        base_path = os.path.join(current_file_path, '..', 'data', 'asr', 'sherpa-ncnn',
                                 'sherpa-ncnn-streaming-zipformer-zh-14M-2023-02-23')
        logger.info(f'_create_recognizer init, path:{base_path}')

        def model_file(name):
            # Resolve a file name inside the model directory.
            return os.path.join(base_path, name)

        recognizer = sherpa_ncnn.Recognizer(
            tokens=model_file('tokens.txt'),
            encoder_param=model_file('encoder_jit_trace-pnnx.ncnn.param'),
            encoder_bin=model_file('encoder_jit_trace-pnnx.ncnn.bin'),
            decoder_param=model_file('decoder_jit_trace-pnnx.ncnn.param'),
            decoder_bin=model_file('decoder_jit_trace-pnnx.ncnn.bin'),
            joiner_param=model_file('joiner_jit_trace-pnnx.ncnn.param'),
            joiner_bin=model_file('joiner_jit_trace-pnnx.ncnn.bin'),
            num_threads=4,
            decoding_method='modified_beam_search',
            enable_endpoint_detection=True,
            # Endpoint-rule thresholds; see the sherpa-ncnn Recognizer docs
            # for the exact meaning and units of each rule.
            rule1_min_trailing_silence=2.4,
            rule2_min_trailing_silence=1.2,
            rule3_min_utterance_length=300,
            hotwords_file=self._hot_words_file,
            hotwords_score=1.5,
        )
        return recognizer

    def _recognize_loop(self):
        """Report recognition results until ``self._stop_event`` is set.

        NOTE(review): this is currently a placeholder. It does not read audio
        or use ``self._recognizer``. Instead it reports the fixed phrase
        '中国人民万岁' as a completed utterance immediately, then roughly
        every 10 seconds. The microphone-driven implementation is kept below,
        commented out.
        """
        segment_id = 0
        logger.info('_recognize_loop')
        while not self._stop_event.is_set():
            self._notify_complete('中国人民万岁')
            segment_id += 1
            # Block on the stop event instead of time.sleep(10), so a stop
            # request ends the loop right away rather than up to 10 s later.
            # Assumes _stop_event is a threading.Event-style object whose
            # wait() accepts a timeout. TODO: confirm against AsrBase.
            self._stop_event.wait(10)

        # Disabled microphone-driven implementation. It uses sounddevice and
        # self._recognizer, and references self._sample_rate,
        # self._samples_per_read and self._notify_process, none of which are
        # defined in this class:
        #
        # last_result = ""
        # with sd.InputStream(channels=1, dtype="float32", samplerate=self._sample_rate) as s:
        #     while not self._stop_event.is_set():
        #         samples, _ = s.read(self._samples_per_read)  # a blocking read
        #         samples = samples.reshape(-1)
        #         self._recognizer.accept_waveform(self._sample_rate, samples)
        #
        #         is_endpoint = self._recognizer.is_endpoint
        #
        #         result = self._recognizer.text
        #         if result and (last_result != result):
        #             last_result = result
        #             print("\r{}:{}".format(segment_id, result), end=".", flush=True)
        #             self._notify_process(result)
        #
        #         if is_endpoint:
        #             if result:
        #                 print("\r{}:{}".format(segment_id, result), flush=True)
        #                 self._notify_complete(result)
        #                 segment_id += 1
        #             self._recognizer.reset()