add asr and nlp
This commit is contained in:
parent
0ed6249f15
commit
f9d45e6d44
4
asr/__init__.py
Normal file
4
asr/__init__.py
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
#encoding = utf8
|
||||||
|
|
||||||
|
from .sherpa_ncnn_asr import SherpaNcnnAsr
|
||||||
|
from .asr_observer import AsrObserver
|
37
asr/asr_base.py
Normal file
37
asr/asr_base.py
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
#encoding = utf8
|
||||||
|
|
||||||
|
import threading
|
||||||
|
|
||||||
|
from .asr_observer import AsrObserver
|
||||||
|
|
||||||
|
|
||||||
|
class AsrBase:
|
||||||
|
def __init__(self):
|
||||||
|
self._sample_rate = 32000
|
||||||
|
self._samples_per_read = 100
|
||||||
|
self._observers = []
|
||||||
|
|
||||||
|
self._stop_event = threading.Event()
|
||||||
|
self._thread = threading.Thread(target=self._recognize_loop)
|
||||||
|
self._thread.start()
|
||||||
|
|
||||||
|
def _recognize_loop(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _notify_process(self, message: str):
|
||||||
|
for observer in self._observers:
|
||||||
|
observer.process(message)
|
||||||
|
|
||||||
|
def _notify_complete(self, message: str):
|
||||||
|
for observer in self._observers:
|
||||||
|
observer.completed(message)
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
self._stop_event.set()
|
||||||
|
self._thread.join()
|
||||||
|
|
||||||
|
def attach(self, observer: AsrObserver):
|
||||||
|
self._observers.append(observer)
|
||||||
|
|
||||||
|
def detach(self, observer: AsrObserver):
|
||||||
|
self._observers.remove(observer)
|
13
asr/asr_observer.py
Normal file
13
asr/asr_observer.py
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
#encoding = utf8
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
|
||||||
|
class AsrObserver(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def process(self, message: str):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def completed(self, message: str):
|
||||||
|
pass
|
143
asr/sherpa_ncnn_asr.py
Normal file
143
asr/sherpa_ncnn_asr.py
Normal file
@ -0,0 +1,143 @@
|
|||||||
|
#encoding = utf8
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
|
||||||
|
try:
|
||||||
|
import sounddevice as sd
|
||||||
|
except ImportError as e:
|
||||||
|
print("Please install sounddevice first. You can use")
|
||||||
|
print()
|
||||||
|
print(" pip install sounddevice")
|
||||||
|
print()
|
||||||
|
print("to install it")
|
||||||
|
sys.exit(-1)
|
||||||
|
|
||||||
|
import sherpa_ncnn
|
||||||
|
|
||||||
|
|
||||||
|
from asr.asr_base import AsrBase
|
||||||
|
|
||||||
|
|
||||||
|
class SherpaNcnnAsr(AsrBase):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self._recognizer = self._create_recognizer()
|
||||||
|
|
||||||
|
def _create_recognizer(self):
|
||||||
|
base_path = os.path.join(os.getcwd(), '..', 'data', 'asr', 'sherpa-ncnn',
|
||||||
|
'sherpa-ncnn-streaming-zipformer-zh-14M-2023-02-23')
|
||||||
|
recognizer = sherpa_ncnn.Recognizer(
|
||||||
|
tokens=base_path + '/tokens.txt',
|
||||||
|
encoder_param=base_path + '/encoder_jit_trace-pnnx.ncnn.param',
|
||||||
|
encoder_bin=base_path + '/encoder_jit_trace-pnnx.ncnn.bin',
|
||||||
|
decoder_param=base_path + '/decoder_jit_trace-pnnx.ncnn.param',
|
||||||
|
decoder_bin=base_path + '/decoder_jit_trace-pnnx.ncnn.bin',
|
||||||
|
joiner_param=base_path + '/joiner_jit_trace-pnnx.ncnn.param',
|
||||||
|
joiner_bin=base_path + '/joiner_jit_trace-pnnx.ncnn.bin',
|
||||||
|
num_threads=4,
|
||||||
|
decoding_method="modified_beam_search",
|
||||||
|
enable_endpoint_detection=True,
|
||||||
|
rule1_min_trailing_silence=2.4,
|
||||||
|
rule2_min_trailing_silence=1.2,
|
||||||
|
rule3_min_utterance_length=300,
|
||||||
|
hotwords_file="",
|
||||||
|
hotwords_score=1.5,
|
||||||
|
)
|
||||||
|
return recognizer
|
||||||
|
|
||||||
|
def _recognize_loop(self):
|
||||||
|
segment_id = 0
|
||||||
|
last_result = ""
|
||||||
|
with sd.InputStream(channels=1, dtype="float32", samplerate=self._sample_rate) as s:
|
||||||
|
while not self._stop_event.is_set():
|
||||||
|
samples, _ = s.read(self._samples_per_read) # a blocking read
|
||||||
|
samples = samples.reshape(-1)
|
||||||
|
self._recognizer.accept_waveform(self._sample_rate, samples)
|
||||||
|
|
||||||
|
is_endpoint = self._recognizer.is_endpoint
|
||||||
|
|
||||||
|
result = self._recognizer.text
|
||||||
|
if result and (last_result != result):
|
||||||
|
last_result = result
|
||||||
|
print("\r{}:{}".format(segment_id, result), end=".", flush=True)
|
||||||
|
self._notify_process(result)
|
||||||
|
|
||||||
|
if is_endpoint:
|
||||||
|
if result:
|
||||||
|
print("\r{}:{}".format(segment_id, result), flush=True)
|
||||||
|
self._notify_complete(result)
|
||||||
|
segment_id += 1
|
||||||
|
self._recognizer.reset()
|
||||||
|
|
||||||
|
def main():
|
||||||
|
print("Started! Please speak")
|
||||||
|
asr = SherpaNcnnAsr()
|
||||||
|
time.sleep(20)
|
||||||
|
print("Stop! ")
|
||||||
|
asr.stop()
|
||||||
|
|
||||||
|
# print("Started! Please speak")
|
||||||
|
# recognizer = create_recognizer()
|
||||||
|
# sample_rate = recognizer.sample_rate
|
||||||
|
# samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms
|
||||||
|
# last_result = ""
|
||||||
|
# segment_id = 0
|
||||||
|
#
|
||||||
|
# with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s:
|
||||||
|
# while True:
|
||||||
|
# samples, _ = s.read(samples_per_read) # a blocking read
|
||||||
|
# samples = samples.reshape(-1)
|
||||||
|
# recognizer.accept_waveform(sample_rate, samples)
|
||||||
|
#
|
||||||
|
# is_endpoint = recognizer.is_endpoint
|
||||||
|
#
|
||||||
|
# result = recognizer.text
|
||||||
|
# if result and (last_result != result):
|
||||||
|
# last_result = result
|
||||||
|
# print("\r{}:{}".format(segment_id, result), end=".", flush=True)
|
||||||
|
#
|
||||||
|
# if is_endpoint:
|
||||||
|
# if result:
|
||||||
|
# print("\r{}:{}".format(segment_id, result), flush=True)
|
||||||
|
# segment_id += 1
|
||||||
|
# recognizer.reset()
|
||||||
|
|
||||||
|
# print("Started! Please speak")
|
||||||
|
# recognizer = create_recognizer()
|
||||||
|
# sample_rate = recognizer.sample_rate
|
||||||
|
# samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms
|
||||||
|
# last_result = ""
|
||||||
|
# with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s:
|
||||||
|
# while True:
|
||||||
|
# samples, _ = s.read(samples_per_read) # a blocking read
|
||||||
|
# samples = samples.reshape(-1)
|
||||||
|
# recognizer.accept_waveform(sample_rate, samples)
|
||||||
|
# result = recognizer.text
|
||||||
|
# if last_result != result:
|
||||||
|
# last_result = result
|
||||||
|
# print("\r{}".format(result), end="", flush=True)
|
||||||
|
|
||||||
|
'''
|
||||||
|
if __name__ == "__main__":
|
||||||
|
devices = sd.query_devices()
|
||||||
|
print(devices)
|
||||||
|
default_input_device_idx = sd.default.device[0]
|
||||||
|
print(f'Use default device: {devices[default_input_device_idx]["name"]}')
|
||||||
|
|
||||||
|
try:
|
||||||
|
main()
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\nCaught Ctrl + C. Exiting")
|
||||||
|
|
||||||
|
# devices = sd.query_devices()
|
||||||
|
# print(devices)
|
||||||
|
# default_input_device_idx = sd.default.device[0]
|
||||||
|
# print(f'Use default device: {devices[default_input_device_idx]["name"]}')
|
||||||
|
#
|
||||||
|
# try:
|
||||||
|
# main()
|
||||||
|
# except KeyboardInterrupt:
|
||||||
|
# print("\nCaught Ctrl + C. Exiting")
|
||||||
|
'''
|
@ -1,2 +1,4 @@
|
|||||||
|
#encoding = utf8
|
||||||
|
|
||||||
from .wav2lip import Wav2Lip, Wav2Lip_disc_qual
|
from .wav2lip import Wav2Lip, Wav2Lip_disc_qual
|
||||||
from .syncnet import SyncNet_color
|
from .syncnet import SyncNet_color
|
||||||
|
@ -1,2 +1,4 @@
|
|||||||
#encoding = utf8
|
#encoding = utf8
|
||||||
|
|
||||||
|
from .nlp_doubao import DouBao
|
||||||
|
from .nlp_split import PunctuationSplit
|
||||||
|
@ -1,19 +1,28 @@
|
|||||||
#encoding = utf8
|
#encoding = utf8
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from asr import AsrObserver
|
||||||
from utils import AsyncTaskQueue
|
from utils import AsyncTaskQueue
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class NLPBase:
|
class NLPBase(AsrObserver):
|
||||||
def __init__(self):
|
def __init__(self, split):
|
||||||
self._ask_queue = AsyncTaskQueue()
|
self._ask_queue = AsyncTaskQueue()
|
||||||
self._ask_queue.start_worker()
|
self._ask_queue.start_worker()
|
||||||
|
self._split_handle = split
|
||||||
|
|
||||||
async def _request(self, question):
|
async def _request(self, question):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def process(self, message: str):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def completed(self, message: str):
|
||||||
|
print('complete :', message)
|
||||||
|
self.ask(message)
|
||||||
|
|
||||||
def ask(self, question):
|
def ask(self, question):
|
||||||
self._ask_queue.add_task(self._request(question))
|
self._ask_queue.add_task(self._request(question))
|
||||||
|
|
||||||
|
@ -14,8 +14,8 @@ nlp_queue = Queue()
|
|||||||
|
|
||||||
|
|
||||||
class DouBao(NLPBase):
|
class DouBao(NLPBase):
|
||||||
def __init__(self):
|
def __init__(self, split):
|
||||||
super().__init__()
|
super().__init__(split)
|
||||||
# Access Key ID
|
# Access Key ID
|
||||||
# AKLTYTdmOTBmNWFjODkxNDE2Zjk3MjU0NjRhM2JhM2IyN2Y
|
# AKLTYTdmOTBmNWFjODkxNDE2Zjk3MjU0NjRhM2JhM2IyN2Y
|
||||||
# AKLTNDZjNTdhNDlkZGE3NDZjMDlkMzk5YWQ3MDA4MTY1ZDc
|
# AKLTNDZjNTdhNDlkZGE3NDZjMDlkMzk5YWQ3MDA4MTY1ZDc
|
||||||
@ -42,13 +42,19 @@ class DouBao(NLPBase):
|
|||||||
],
|
],
|
||||||
stream=True
|
stream=True
|
||||||
)
|
)
|
||||||
|
sec = ''
|
||||||
async for completion in stream:
|
async for completion in stream:
|
||||||
# print(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
|
# print(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
|
||||||
nlp_queue.put(completion.choices[0].delta.content)
|
# nlp_queue.put(completion.choices[0].delta.content)
|
||||||
# print(completion.choices[0].delta.content, end="")
|
# print(completion.choices[0].delta.content, end="")
|
||||||
|
sec = sec + completion.choices[0].delta.content
|
||||||
|
sec, message = self._split_handle.handle(sec)
|
||||||
|
if len(message) > 0:
|
||||||
|
print(message)
|
||||||
|
print(sec)
|
||||||
print(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
|
print(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
|
||||||
|
|
||||||
|
'''
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# print(get_dou_bao_api())
|
# print(get_dou_bao_api())
|
||||||
dou_bao = DouBao()
|
dou_bao = DouBao()
|
||||||
@ -75,4 +81,4 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
|
|
||||||
dou_bao.stop()
|
dou_bao.stop()
|
||||||
|
'''
|
||||||
|
24
nlp/nlp_split.py
Normal file
24
nlp/nlp_split.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
#encoding = utf8
|
||||||
|
import re
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
|
||||||
|
class NLPSplit(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def handle(self, message: str):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class PunctuationSplit(NLPSplit):
|
||||||
|
def __init__(self):
|
||||||
|
self._pattern = r'[,。、;?!,.!?]'
|
||||||
|
|
||||||
|
def handle(self, message: str):
|
||||||
|
match = re.search(self._pattern, message)
|
||||||
|
if match:
|
||||||
|
pos = match.start() + 1
|
||||||
|
msg = message[:pos]
|
||||||
|
msg = msg.strip()
|
||||||
|
message = message[pos:]
|
||||||
|
return message, msg
|
||||||
|
return message, ''
|
1
test/__init__.py
Normal file
1
test/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
#encoding = utf8
|
43
test/test_asr_nlp.py
Normal file
43
test/test_asr_nlp.py
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
#encoding = utf8
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
|
||||||
|
from asr import SherpaNcnnAsr
|
||||||
|
from nlp import PunctuationSplit
|
||||||
|
from nlp.nlp_doubao import DouBao
|
||||||
|
|
||||||
|
try:
|
||||||
|
import sounddevice as sd
|
||||||
|
except ImportError as e:
|
||||||
|
print("Please install sounddevice first. You can use")
|
||||||
|
print()
|
||||||
|
print(" pip install sounddevice")
|
||||||
|
print()
|
||||||
|
print("to install it")
|
||||||
|
sys.exit(-1)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
print("Started! Please speak")
|
||||||
|
split = PunctuationSplit()
|
||||||
|
nlp = DouBao(split)
|
||||||
|
asr = SherpaNcnnAsr()
|
||||||
|
asr.attach(nlp)
|
||||||
|
time.sleep(20)
|
||||||
|
print("Stop! ")
|
||||||
|
asr.stop()
|
||||||
|
asr.detach(nlp)
|
||||||
|
nlp.stop()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
devices = sd.query_devices()
|
||||||
|
print(devices)
|
||||||
|
default_input_device_idx = sd.default.device[0]
|
||||||
|
print(f'Use default device: {devices[default_input_device_idx]["name"]}')
|
||||||
|
|
||||||
|
try:
|
||||||
|
main()
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\nCaught Ctrl + C. Exiting")
|
Loading…
Reference in New Issue
Block a user