#encoding = utf8
import logging
import os
import sys
import time

try:
    import sounddevice as sd
except ImportError:
    # sounddevice is required for microphone capture in this module; fail
    # fast with an installation hint. The hint is an error report, so it is
    # written to stderr rather than stdout.
    print("Please install sounddevice first. You can use", file=sys.stderr)
    print(file=sys.stderr)
    print("  pip install sounddevice", file=sys.stderr)
    print(file=sys.stderr)
    print("to install it", file=sys.stderr)
    sys.exit(-1)

import sherpa_ncnn

from asr.asr_base import AsrBase
logger = logging.getLogger(__name__)

# Absolute directory containing this module. Model paths are resolved relative
# to it so they do not depend on the process's current working directory.
current_file_path = os.path.dirname(os.path.abspath(__file__))


class SherpaNcnnAsr(AsrBase):
    """Streaming speech recognizer built on sherpa-ncnn.

    Audio is read from a ``sounddevice`` input stream in ``_recognize_loop``
    and fed to a sherpa-ncnn streaming zipformer model (the ``zh``/``14M``
    in the model directory name suggests a small Chinese model). Partial
    hypotheses are reported through ``_notify_process`` and finished
    utterances through ``_notify_complete``.

    NOTE(review): ``_notify_process``, ``_notify_complete``, ``_sample_rate``,
    ``_samples_per_read``, ``_stop_event`` and ``_hot_words_file`` are not
    defined in this class; presumably ``AsrBase`` provides them — confirm.
    """

    def __init__(self):
        # The base initializer must run first: _create_recognizer reads
        # self._hot_words_file, which this class never assigns itself.
        super().__init__()
        self._recognizer = self._create_recognizer()
        logger.info('SherpaNcnnAsr init')
        print('SherpaNcnnAsr init')

    def __del__(self):
        # Chain to the base-class finalizer only if one exists. ``object``
        # defines no __del__, so an unconditional super().__del__() could
        # raise AttributeError. Calling self.__del__() here would recurse forever.
        parent_del = getattr(super(), '__del__', None)
        if parent_del is not None:
            parent_del()
        logger.info('SherpaNcnnAsr del')

    def _create_recognizer(self):
        """Build and return the sherpa-ncnn streaming recognizer.

        Model files are loaded from
        ``<this dir>/../data/asr/sherpa-ncnn/sherpa-ncnn-streaming-zipformer-zh-14M-2023-02-23``.
        Endpoint detection is enabled so ``_recognize_loop`` can split the
        stream into utterances. Hotwords are read from ``self._hot_words_file``.

        Returns:
            sherpa_ncnn.Recognizer: the configured recognizer.
        """
        base_path = os.path.join(current_file_path, '..', 'data', 'asr', 'sherpa-ncnn',
                                 'sherpa-ncnn-streaming-zipformer-zh-14M-2023-02-23')
        logger.info('_create_recognizer init, path:%s', base_path)

        def model_file(name):
            # Join with os.path.join so separators match base_path's.
            return os.path.join(base_path, name)

        return sherpa_ncnn.Recognizer(
            tokens=model_file('tokens.txt'),
            encoder_param=model_file('encoder_jit_trace-pnnx.ncnn.param'),
            encoder_bin=model_file('encoder_jit_trace-pnnx.ncnn.bin'),
            decoder_param=model_file('decoder_jit_trace-pnnx.ncnn.param'),
            decoder_bin=model_file('decoder_jit_trace-pnnx.ncnn.bin'),
            joiner_param=model_file('joiner_jit_trace-pnnx.ncnn.param'),
            joiner_bin=model_file('joiner_jit_trace-pnnx.ncnn.bin'),
            num_threads=4,
            decoding_method='modified_beam_search',
            enable_endpoint_detection=True,
            # Endpoint rules; units are presumably seconds per the sherpa-ncnn
            # docs — confirm before tuning.
            rule1_min_trailing_silence=2.4,
            rule2_min_trailing_silence=1.2,
            rule3_min_utterance_length=300,
            hotwords_file=self._hot_words_file,
            hotwords_score=1.5,
        )

    def _recognize_loop(self):
        """Capture microphone audio and stream it through the recognizer.

        Whenever the partial text changes, it is printed and passed to
        ``_notify_process``. At an endpoint, any non-empty final text is
        printed and passed to ``_notify_complete``. The recognizer is then
        reset for the next utterance.
        """
        segment_id = 0
        # NOTE(review): presumably gives __init__ time to assign
        # self._recognizer, in case AsrBase starts this loop on a thread inside
        # super().__init__(). TODO: confirm, and replace with explicit sync.
        time.sleep(3)
        last_result = ""
        logger.info('_recognize_loop')
        print('_recognize_loop')

        with sd.InputStream(channels=1, dtype="float32", samplerate=self._sample_rate) as s:
            # Loops while the event is SET. Despite its name, _stop_event
            # appears to mean "keep running" here — confirm against AsrBase.
            while self._stop_event.is_set():
                samples, _ = s.read(self._samples_per_read)  # a blocking read
                # Flatten the single-channel read buffer to 1-D.
                samples = samples.reshape(-1)
                self._recognizer.accept_waveform(self._sample_rate, samples)

                is_endpoint = self._recognizer.is_endpoint
                result = self._recognizer.text

                if result and result != last_result:
                    last_result = result
                    print("\r{}:{}".format(segment_id, result), end=".", flush=True)
                    self._notify_process(result)

                if is_endpoint:
                    if result:
                        print("\r{}:{}".format(segment_id, result), flush=True)
                        self._notify_complete(result)
                        segment_id += 1
                    self._recognizer.reset()
                    # A new utterance starts here. Forget the previous text so
                    # an identical first partial is still reported.
                    last_result = ""

        logger.info('_recognize_loop exit')