From f9d45e6d444e5be15b54c62a77fcbe35dad84617 Mon Sep 17 00:00:00 2001 From: brige Date: Wed, 9 Oct 2024 20:19:00 +0800 Subject: [PATCH] add asr and nlp --- asr/__init__.py | 4 ++ asr/asr_base.py | 37 +++++++++++ asr/asr_observer.py | 13 ++++ asr/sherpa_ncnn_asr.py | 143 +++++++++++++++++++++++++++++++++++++++++ models/__init__.py | 4 +- nlp/__init__.py | 2 + nlp/nlp_base.py | 13 +++- nlp/nlp_doubao.py | 16 +++-- nlp/nlp_split.py | 24 +++++++ test/__init__.py | 1 + test/test_asr_nlp.py | 43 +++++++++++++ 11 files changed, 292 insertions(+), 8 deletions(-) create mode 100644 asr/__init__.py create mode 100644 asr/asr_base.py create mode 100644 asr/asr_observer.py create mode 100644 asr/sherpa_ncnn_asr.py create mode 100644 nlp/nlp_split.py create mode 100644 test/__init__.py create mode 100644 test/test_asr_nlp.py diff --git a/asr/__init__.py b/asr/__init__.py new file mode 100644 index 0000000..676292f --- /dev/null +++ b/asr/__init__.py @@ -0,0 +1,4 @@ +#encoding = utf8 + +from .sherpa_ncnn_asr import SherpaNcnnAsr +from .asr_observer import AsrObserver diff --git a/asr/asr_base.py b/asr/asr_base.py new file mode 100644 index 0000000..d59acea --- /dev/null +++ b/asr/asr_base.py @@ -0,0 +1,37 @@ +#encoding = utf8 + +import threading + +from .asr_observer import AsrObserver + + +class AsrBase: + def __init__(self): + self._sample_rate = 32000 + self._samples_per_read = 100 + self._observers = [] + + self._stop_event = threading.Event() + self._thread = threading.Thread(target=self._recognize_loop) + self._thread.start() + + def _recognize_loop(self): + pass + + def _notify_process(self, message: str): + for observer in self._observers: + observer.process(message) + + def _notify_complete(self, message: str): + for observer in self._observers: + observer.completed(message) + + def stop(self): + self._stop_event.set() + self._thread.join() + + def attach(self, observer: AsrObserver): + self._observers.append(observer) + + def detach(self, observer: AsrObserver): + self._observers.remove(observer) diff --git a/asr/asr_observer.py b/asr/asr_observer.py new file mode 100644 index 0000000..6be059d --- /dev/null +++ b/asr/asr_observer.py @@ -0,0 +1,13 @@ +#encoding = utf8 + +from abc import ABC, abstractmethod + + +class AsrObserver(ABC): + @abstractmethod + def process(self, message: str): + pass + + @abstractmethod + def completed(self, message: str): + pass diff --git a/asr/sherpa_ncnn_asr.py b/asr/sherpa_ncnn_asr.py new file mode 100644 index 0000000..6306b93 --- /dev/null +++ b/asr/sherpa_ncnn_asr.py @@ -0,0 +1,143 @@ +#encoding = utf8 + +import os +import sys +import time + +try: + import sounddevice as sd +except ImportError as e: + print("Please install sounddevice first. You can use") + print() + print(" pip install sounddevice") + print() + print("to install it") + sys.exit(-1) + +import sherpa_ncnn + + +from asr.asr_base import AsrBase + + +class SherpaNcnnAsr(AsrBase): + def __init__(self): + super().__init__() + self._recognizer = self._create_recognizer() + + def _create_recognizer(self): + base_path = os.path.join(os.getcwd(), '..', 'data', 'asr', 'sherpa-ncnn', + 'sherpa-ncnn-streaming-zipformer-zh-14M-2023-02-23') + recognizer = sherpa_ncnn.Recognizer( + tokens=base_path + '/tokens.txt', + encoder_param=base_path + '/encoder_jit_trace-pnnx.ncnn.param', + encoder_bin=base_path + '/encoder_jit_trace-pnnx.ncnn.bin', + decoder_param=base_path + '/decoder_jit_trace-pnnx.ncnn.param', + decoder_bin=base_path + '/decoder_jit_trace-pnnx.ncnn.bin', + joiner_param=base_path + '/joiner_jit_trace-pnnx.ncnn.param', + joiner_bin=base_path + '/joiner_jit_trace-pnnx.ncnn.bin', + num_threads=4, + decoding_method="modified_beam_search", + enable_endpoint_detection=True, + rule1_min_trailing_silence=2.4, + rule2_min_trailing_silence=1.2, + rule3_min_utterance_length=300, + hotwords_file="", + hotwords_score=1.5, + ) + return recognizer + + def _recognize_loop(self): + segment_id = 0 + last_result = "" + with sd.InputStream(channels=1, dtype="float32", samplerate=self._sample_rate) as s: + while not self._stop_event.is_set(): + samples, _ = s.read(self._samples_per_read) # a blocking read + samples = samples.reshape(-1) + self._recognizer.accept_waveform(self._sample_rate, samples) + + is_endpoint = self._recognizer.is_endpoint + + result = self._recognizer.text + if result and (last_result != result): + last_result = result + print("\r{}:{}".format(segment_id, result), end=".", flush=True) + self._notify_process(result) + + if is_endpoint: + if result: + print("\r{}:{}".format(segment_id, result), flush=True) + self._notify_complete(result) + segment_id += 1 + self._recognizer.reset() + +def main(): + print("Started! Please speak") + asr = SherpaNcnnAsr() + time.sleep(20) + print("Stop! ") + asr.stop() + + # print("Started! Please speak") + # recognizer = create_recognizer() + # sample_rate = recognizer.sample_rate + # samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms + # last_result = "" + # segment_id = 0 + # + # with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s: + # while True: + # samples, _ = s.read(samples_per_read) # a blocking read + # samples = samples.reshape(-1) + # recognizer.accept_waveform(sample_rate, samples) + # + # is_endpoint = recognizer.is_endpoint + # + # result = recognizer.text + # if result and (last_result != result): + # last_result = result + # print("\r{}:{}".format(segment_id, result), end=".", flush=True) + # + # if is_endpoint: + # if result: + # print("\r{}:{}".format(segment_id, result), flush=True) + # segment_id += 1 + # recognizer.reset() + + # print("Started! Please speak") + # recognizer = create_recognizer() + # sample_rate = recognizer.sample_rate + # samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms + # last_result = "" + # with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s: + # while True: + # samples, _ = s.read(samples_per_read) # a blocking read + # samples = samples.reshape(-1) + # recognizer.accept_waveform(sample_rate, samples) + # result = recognizer.text + # if last_result != result: + # last_result = result + # print("\r{}".format(result), end="", flush=True) + +''' +if __name__ == "__main__": + devices = sd.query_devices() + print(devices) + default_input_device_idx = sd.default.device[0] + print(f'Use default device: {devices[default_input_device_idx]["name"]}') + + try: + main() + except KeyboardInterrupt: + print("\nCaught Ctrl + C. Exiting") + + # devices = sd.query_devices() + # print(devices) + # default_input_device_idx = sd.default.device[0] + # print(f'Use default device: {devices[default_input_device_idx]["name"]}') + # + # try: + # main() + # except KeyboardInterrupt: + # print("\nCaught Ctrl + C. Exiting") +''' diff --git a/models/__init__.py b/models/__init__.py index 4374370..1c144be 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -1,2 +1,4 @@ +#encoding = utf8 + from .wav2lip import Wav2Lip, Wav2Lip_disc_qual -from .syncnet import SyncNet_color \ No newline at end of file +from .syncnet import SyncNet_color diff --git a/nlp/__init__.py b/nlp/__init__.py index 6bb50fc..050c481 100644 --- a/nlp/__init__.py +++ b/nlp/__init__.py @@ -1,2 +1,4 @@ #encoding = utf8 +from .nlp_doubao import DouBao +from .nlp_split import PunctuationSplit diff --git a/nlp/nlp_base.py b/nlp/nlp_base.py index e2f94aa..b22b62a 100644 --- a/nlp/nlp_base.py +++ b/nlp/nlp_base.py @@ -1,19 +1,28 @@ #encoding = utf8 import logging +from asr import AsrObserver from utils import AsyncTaskQueue logger = logging.getLogger(__name__) -class NLPBase: - def __init__(self): +class NLPBase(AsrObserver): + def __init__(self, split): self._ask_queue = AsyncTaskQueue() self._ask_queue.start_worker() + self._split_handle = split async def _request(self, question): pass + def process(self, message: str): + pass + + def completed(self, message: str): + print('complete :', message) + self.ask(message) + def ask(self, question): self._ask_queue.add_task(self._request(question)) diff --git a/nlp/nlp_doubao.py b/nlp/nlp_doubao.py index 53507d0..872652d 100644 --- a/nlp/nlp_doubao.py +++ b/nlp/nlp_doubao.py @@ -14,8 +14,8 @@ nlp_queue = Queue() class DouBao(NLPBase): - def __init__(self): - super().__init__() + def __init__(self, split): + super().__init__(split) # Access Key ID # AKLTYTdmOTBmNWFjODkxNDE2Zjk3MjU0NjRhM2JhM2IyN2Y # AKLTNDZjNTdhNDlkZGE3NDZjMDlkMzk5YWQ3MDA4MTY1ZDc @@ -42,13 +42,19 @@ class DouBao(NLPBase): ], stream=True ) + sec = '' async for completion in stream: # print(f'-------dou_bao nlp time:{time.time() - t:.4f}s') - nlp_queue.put(completion.choices[0].delta.content) + # nlp_queue.put(completion.choices[0].delta.content) # print(completion.choices[0].delta.content, end="") + sec = sec + completion.choices[0].delta.content + sec, message = self._split_handle.handle(sec) + if len(message) > 0: + print(message) + print(sec) print(f'-------dou_bao nlp time:{time.time() - t:.4f}s') - +''' if __name__ == "__main__": # print(get_dou_bao_api()) dou_bao = DouBao() @@ -75,4 +81,4 @@ if __name__ == "__main__": dou_bao.stop() - +''' diff --git a/nlp/nlp_split.py b/nlp/nlp_split.py new file mode 100644 index 0000000..5b01228 --- /dev/null +++ b/nlp/nlp_split.py @@ -0,0 +1,24 @@ +#encoding = utf8 +import re +from abc import ABC, abstractmethod + + +class NLPSplit(ABC): + @abstractmethod + def handle(self, message: str): + pass + + +class PunctuationSplit(NLPSplit): + def __init__(self): + self._pattern = r'[,。、;?!,.!?]' + + def handle(self, message: str): + match = re.search(self._pattern, message) + if match: + pos = match.start() + 1 + msg = message[:pos] + msg = msg.strip() + message = message[pos:] + return message, msg + return message, '' diff --git a/test/__init__.py b/test/__init__.py new file mode 100644 index 0000000..069ea69 --- /dev/null +++ b/test/__init__.py @@ -0,0 +1 @@ +#encoding = utf8 diff --git a/test/test_asr_nlp.py b/test/test_asr_nlp.py new file mode 100644 index 0000000..88b8bee --- /dev/null +++ b/test/test_asr_nlp.py @@ -0,0 +1,43 @@ +#encoding = utf8 + +import sys +import time + +from asr import SherpaNcnnAsr +from nlp import PunctuationSplit +from nlp.nlp_doubao import DouBao + +try: + import sounddevice as sd +except ImportError as e: + print("Please install sounddevice first. You can use") + print() + print(" pip install sounddevice") + print() + print("to install it") + sys.exit(-1) + + +def main(): + print("Started! Please speak") + split = PunctuationSplit() + nlp = DouBao(split) + asr = SherpaNcnnAsr() + asr.attach(nlp) + time.sleep(20) + print("Stop! ") + asr.stop() + asr.detach(nlp) + nlp.stop() + + +if __name__ == "__main__": + devices = sd.query_devices() + print(devices) + default_input_device_idx = sd.default.device[0] + print(f'Use default device: {devices[default_input_device_idx]["name"]}') + + try: + main() + except KeyboardInterrupt: + print("\nCaught Ctrl + C. Exiting")