add asr and nlp

This commit is contained in:
brige 2024-10-09 20:19:00 +08:00
parent 0ed6249f15
commit f9d45e6d44
11 changed files with 292 additions and 8 deletions

4
asr/__init__.py Normal file
View File

@ -0,0 +1,4 @@
#encoding = utf8
from .sherpa_ncnn_asr import SherpaNcnnAsr
from .asr_observer import AsrObserver

37
asr/asr_base.py Normal file
View File

@ -0,0 +1,37 @@
#encoding = utf8
import threading
from .asr_observer import AsrObserver
class AsrBase:
    """Base class for recognizers that push results to attached observers.

    Constructing an instance immediately starts a background thread that
    runs ``_recognize_loop``. Subclasses must therefore initialise every
    attribute their loop reads *before* calling ``super().__init__()``.
    """

    def __init__(self):
        self._sample_rate = 32000
        # NOTE(review): 100 frames per read at 32 kHz is roughly 3 ms of
        # audio per iteration - confirm this is the intended chunk size.
        self._samples_per_read = 100
        self._observers = []
        self._stop_event = threading.Event()
        self._thread = threading.Thread(target=self._recognize_loop)
        self._thread.start()

    def _recognize_loop(self):
        """Body of the worker thread; the base implementation does nothing.

        Subclasses override this and are expected to return once
        ``self._stop_event`` is set, otherwise ``stop()`` blocks forever.
        """
        pass

    def _notify_process(self, message: str):
        """Call ``process(message)`` on every attached observer."""
        # Iterate over a snapshot: observers may be attached or detached from
        # another thread (or from inside a callback) while we are notifying,
        # and mutating a list during iteration silently skips elements.
        for observer in list(self._observers):
            observer.process(message)

    def _notify_complete(self, message: str):
        """Call ``completed(message)`` on every attached observer."""
        for observer in list(self._observers):
            observer.completed(message)

    def stop(self):
        """Signal the worker thread to stop and wait until it has finished."""
        self._stop_event.set()
        # A thread cannot join itself (threading raises RuntimeError), which
        # would happen if an observer callback called stop().
        if self._thread is not threading.current_thread():
            self._thread.join()

    def attach(self, observer: "AsrObserver"):
        """Register ``observer`` so it receives subsequent notifications."""
        self._observers.append(observer)

    def detach(self, observer: "AsrObserver"):
        """Unregister ``observer``.

        Raises:
            ValueError: if ``observer`` is not currently attached.
        """
        self._observers.remove(observer)

13
asr/asr_observer.py Normal file
View File

@ -0,0 +1,13 @@
#encoding = utf8
from abc import ABC, abstractmethod
class AsrObserver(ABC):
    """Interface for objects that want to receive recognition messages.

    Both hooks are abstract, so a concrete subclass has to override each of
    them before it can be instantiated.
    """

    @abstractmethod
    def process(self, message: str):
        """Handle a message delivered while recognition is in progress."""

    @abstractmethod
    def completed(self, message: str):
        """Handle a message delivered when a recognition segment is complete."""

143
asr/sherpa_ncnn_asr.py Normal file
View File

@ -0,0 +1,143 @@
#encoding = utf8
import os
import sys
import time
try:
import sounddevice as sd
except ImportError as e:
print("Please install sounddevice first. You can use")
print()
print(" pip install sounddevice")
print()
print("to install it")
sys.exit(-1)
import sherpa_ncnn
from asr.asr_base import AsrBase
class SherpaNcnnAsr(AsrBase):
    """Recognizer that feeds microphone audio into a sherpa-ncnn model.

    Text changes are reported to observers through ``_notify_process`` and
    the text of each finished segment through ``_notify_complete``.
    """

    def __init__(self):
        # AsrBase.__init__ starts the thread that runs _recognize_loop, so the
        # recognizer must exist before the base constructor is called;
        # otherwise the loop can read self._recognizer before it is assigned.
        self._recognizer = self._create_recognizer()
        super().__init__()

    def _create_recognizer(self):
        """Build the sherpa-ncnn recognizer from the bundled model files.

        The model directory is resolved relative to the current working
        directory (``../data/asr/sherpa-ncnn/...``).
        """
        base_path = os.path.join(os.getcwd(), '..', 'data', 'asr', 'sherpa-ncnn',
                                 'sherpa-ncnn-streaming-zipformer-zh-14M-2023-02-23')

        def model_file(name):
            # Join with os.path instead of concatenating '/' by hand.
            return os.path.join(base_path, name)

        recognizer = sherpa_ncnn.Recognizer(
            tokens=model_file('tokens.txt'),
            encoder_param=model_file('encoder_jit_trace-pnnx.ncnn.param'),
            encoder_bin=model_file('encoder_jit_trace-pnnx.ncnn.bin'),
            decoder_param=model_file('decoder_jit_trace-pnnx.ncnn.param'),
            decoder_bin=model_file('decoder_jit_trace-pnnx.ncnn.bin'),
            joiner_param=model_file('joiner_jit_trace-pnnx.ncnn.param'),
            joiner_bin=model_file('joiner_jit_trace-pnnx.ncnn.bin'),
            num_threads=4,
            decoding_method="modified_beam_search",
            enable_endpoint_detection=True,
            rule1_min_trailing_silence=2.4,
            rule2_min_trailing_silence=1.2,
            rule3_min_utterance_length=300,
            hotwords_file="",
            hotwords_score=1.5,
        )
        return recognizer

    def _recognize_loop(self):
        """Read audio from the default input device until stop is requested."""
        segment_id = 0
        last_result = ""
        with sd.InputStream(channels=1, dtype="float32", samplerate=self._sample_rate) as s:
            while not self._stop_event.is_set():
                samples, _ = s.read(self._samples_per_read)  # a blocking read
                samples = samples.reshape(-1)
                self._recognizer.accept_waveform(self._sample_rate, samples)

                is_endpoint = self._recognizer.is_endpoint
                result = self._recognizer.text

                # Only report when the recognized text actually changed.
                if result and (last_result != result):
                    last_result = result
                    print("\r{}:{}".format(segment_id, result), end=".", flush=True)
                    self._notify_process(result)

                if is_endpoint:
                    if result:
                        print("\r{}:{}".format(segment_id, result), flush=True)
                        self._notify_complete(result)
                        segment_id += 1
                    self._recognizer.reset()
                    # Start the next segment with a clean slate so a first
                    # partial identical to the previous final text is still
                    # reported.
                    last_result = ""
def main():
    """Run the recognizer for 20 seconds, then shut it down.

    The recognizer is stopped in a ``finally`` clause so that an interrupt
    during the wait (e.g. Ctrl+C) still stops its worker thread; that thread
    is not a daemon and would otherwise keep the process alive.
    """
    print("Started! Please speak")
    asr = SherpaNcnnAsr()
    try:
        time.sleep(20)
        print("Stop! ")
    finally:
        asr.stop()
# print("Started! Please speak")
# recognizer = create_recognizer()
# sample_rate = recognizer.sample_rate
# samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms
# last_result = ""
# segment_id = 0
#
# with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s:
# while True:
# samples, _ = s.read(samples_per_read) # a blocking read
# samples = samples.reshape(-1)
# recognizer.accept_waveform(sample_rate, samples)
#
# is_endpoint = recognizer.is_endpoint
#
# result = recognizer.text
# if result and (last_result != result):
# last_result = result
# print("\r{}:{}".format(segment_id, result), end=".", flush=True)
#
# if is_endpoint:
# if result:
# print("\r{}:{}".format(segment_id, result), flush=True)
# segment_id += 1
# recognizer.reset()
# print("Started! Please speak")
# recognizer = create_recognizer()
# sample_rate = recognizer.sample_rate
# samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms
# last_result = ""
# with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s:
# while True:
# samples, _ = s.read(samples_per_read) # a blocking read
# samples = samples.reshape(-1)
# recognizer.accept_waveform(sample_rate, samples)
# result = recognizer.text
# if last_result != result:
# last_result = result
# print("\r{}".format(result), end="", flush=True)
'''
if __name__ == "__main__":
devices = sd.query_devices()
print(devices)
default_input_device_idx = sd.default.device[0]
print(f'Use default device: {devices[default_input_device_idx]["name"]}')
try:
main()
except KeyboardInterrupt:
print("\nCaught Ctrl + C. Exiting")
# devices = sd.query_devices()
# print(devices)
# default_input_device_idx = sd.default.device[0]
# print(f'Use default device: {devices[default_input_device_idx]["name"]}')
#
# try:
# main()
# except KeyboardInterrupt:
# print("\nCaught Ctrl + C. Exiting")
'''

View File

@ -1,2 +1,4 @@
#encoding = utf8
from .wav2lip import Wav2Lip, Wav2Lip_disc_qual
from .syncnet import SyncNet_color
from .syncnet import SyncNet_color

View File

@ -1,2 +1,4 @@
#encoding = utf8
from .nlp_doubao import DouBao
from .nlp_split import PunctuationSplit

View File

@ -1,19 +1,28 @@
#encoding = utf8
import logging
from asr import AsrObserver
from utils import AsyncTaskQueue
logger = logging.getLogger(__name__)
class NLPBase:
def __init__(self):
class NLPBase(AsrObserver):
    """ASR observer that forwards completed recognition text as a question.

    Questions are turned into ``_request`` coroutines and handed to an
    ``AsyncTaskQueue`` whose worker is started in the constructor.
    """

    def __init__(self, split):
        """Create the task queue, start its worker and keep the splitter.

        Args:
            split: object stored as ``self._split_handle`` for subclasses.
        """
        self._ask_queue = AsyncTaskQueue()
        self._ask_queue.start_worker()
        self._split_handle = split

    async def _request(self, question):
        """Coroutine hook for answering ``question``; no-op in the base class."""
        return None

    def process(self, message: str):
        """Ignore messages delivered while recognition is in progress."""
        return None

    def completed(self, message: str):
        """Print the completed message and submit it as a question."""
        print('complete :', message)
        self.ask(message)

    def ask(self, question):
        """Queue a ``_request`` coroutine for ``question``."""
        pending = self._request(question)
        self._ask_queue.add_task(pending)

View File

@ -14,8 +14,8 @@ nlp_queue = Queue()
class DouBao(NLPBase):
def __init__(self):
super().__init__()
def __init__(self, split):
super().__init__(split)
# Access Key ID
# AKLTYTdmOTBmNWFjODkxNDE2Zjk3MjU0NjRhM2JhM2IyN2Y
# AKLTNDZjNTdhNDlkZGE3NDZjMDlkMzk5YWQ3MDA4MTY1ZDc
@ -42,13 +42,19 @@ class DouBao(NLPBase):
],
stream=True
)
sec = ''
async for completion in stream:
# print(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
nlp_queue.put(completion.choices[0].delta.content)
# nlp_queue.put(completion.choices[0].delta.content)
# print(completion.choices[0].delta.content, end="")
sec = sec + completion.choices[0].delta.content
sec, message = self._split_handle.handle(sec)
if len(message) > 0:
print(message)
print(sec)
print(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
'''
if __name__ == "__main__":
# print(get_dou_bao_api())
dou_bao = DouBao()
@ -75,4 +81,4 @@ if __name__ == "__main__":
dou_bao.stop()
'''

24
nlp/nlp_split.py Normal file
View File

@ -0,0 +1,24 @@
#encoding = utf8
import re
from abc import ABC, abstractmethod
class NLPSplit(ABC):
    """Strategy interface for cutting a segment off the front of a text."""

    @abstractmethod
    def handle(self, message: str):
        """Split ``message``; implementations return ``(remainder, segment)``."""


class PunctuationSplit(NLPSplit):
    """Splitter that cuts right after the first punctuation mark."""

    def __init__(self):
        # Character class of the full-width and ASCII marks that end a segment.
        self._pattern = r'[,。、;?!,.!?]'

    def handle(self, message: str):
        """Cut ``message`` after its first punctuation mark.

        Returns:
            A ``(remainder, segment)`` tuple. ``segment`` is the text up to
            and including the mark, stripped of surrounding whitespace;
            ``remainder`` is everything after it, unstripped. When no mark is
            found the result is ``(message, '')``.
        """
        found = re.search(self._pattern, message)
        if found is None:
            return message, ''

        # The pattern matches exactly one character, so the end of the match
        # is the cut position.
        cut = found.end()
        segment, remainder = message[:cut], message[cut:]
        return remainder, segment.strip()

1
test/__init__.py Normal file
View File

@ -0,0 +1 @@
#encoding = utf8

43
test/test_asr_nlp.py Normal file
View File

@ -0,0 +1,43 @@
#encoding = utf8
import sys
import time
from asr import SherpaNcnnAsr
from nlp import PunctuationSplit
from nlp.nlp_doubao import DouBao
try:
import sounddevice as sd
except ImportError as e:
print("Please install sounddevice first. You can use")
print()
print(" pip install sounddevice")
print()
print("to install it")
sys.exit(-1)
def main():
    """Wire the recognizer to the NLP back end and run for 20 seconds.

    Cleanup runs in ``finally`` clauses so that an interrupt during the wait
    (e.g. Ctrl+C) still stops the recognizer thread and the NLP object
    instead of leaving them running.
    """
    print("Started! Please speak")
    split = PunctuationSplit()
    nlp = DouBao(split)
    try:
        asr = SherpaNcnnAsr()
        asr.attach(nlp)
        try:
            time.sleep(20)
            print("Stop! ")
        finally:
            # Same shutdown order as before: stop recognition first, then
            # unhook the observer.
            asr.stop()
            asr.detach(nlp)
    finally:
        nlp.stop()
if __name__ == "__main__":
    # List every audio device, then report which input device will be used.
    devices = sd.query_devices()
    print(devices)
    default_input_device_idx = sd.default.device[0]
    if default_input_device_idx < 0:
        # A negative index means there is no default input device; indexing
        # with it would silently pick the *last* device in the list.
        print("No default input device found")
        sys.exit(-1)
    print(f'Use default device: {devices[default_input_device_idx]["name"]}')

    try:
        main()
    except KeyboardInterrupt:
        print("\nCaught Ctrl + C. Exiting")