modify ui and nlp tts code

parent 6c0733d6b9
commit ad54248ff3
@@ -1,8 +1,7 @@
 #encoding = utf8
+import logging
 import os
 import sys
-import time

 try:
     import sounddevice as sd
@@ -16,18 +15,26 @@ except ImportError as e:

 import sherpa_ncnn


 from asr.asr_base import AsrBase
+logger = logging.getLogger(__name__)

+current_file_path = os.path.dirname(os.path.abspath(__file__))


 class SherpaNcnnAsr(AsrBase):
     def __init__(self):
         super().__init__()
         self._recognizer = self._create_recognizer()
+        logger.info('SherpaNcnnAsr init')

+    def __del__(self):
+        self.__del__()
+        logger.info('SherpaNcnnAsr del')

     def _create_recognizer(self):
-        base_path = os.path.join(os.getcwd(), 'data', 'asr', 'sherpa-ncnn',
+        base_path = os.path.join(current_file_path, '..', 'data', 'asr', 'sherpa-ncnn',
                                  'sherpa-ncnn-streaming-zipformer-zh-14M-2023-02-23')
+        logger.info(f'_create_recognizer init, path:{base_path}')
         recognizer = sherpa_ncnn.Recognizer(
             tokens=base_path + '/tokens.txt',
             encoder_param=base_path + '/encoder_jit_trace-pnnx.ncnn.param',
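Note: the new __del__ above calls self.__del__(), which recurses until RecursionError whenever the destructor runs. A minimal sketch of a non-recursive teardown, assuming AsrBase exposes the stop() that the removed main() below uses (my suggestion, not part of this commit):

    import logging

    logger = logging.getLogger(__name__)


    class Recognizer:
        def stop(self):
            # idempotent shutdown of the recognize loop (assumed AsrBase API)
            pass

        def __del__(self):
            # delegate to stop() instead of re-entering __del__
            self.stop()
            logger.info('SherpaNcnnAsr del')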
@@ -50,6 +57,7 @@ class SherpaNcnnAsr(AsrBase):
     def _recognize_loop(self):
         segment_id = 0
         last_result = ""
+        logger.info(f'_recognize_loop')
         with sd.InputStream(channels=1, dtype="float32", samplerate=self._sample_rate) as s:
             while not self._stop_event.is_set():
                 samples, _ = s.read(self._samples_per_read)  # a blocking read
@@ -70,74 +78,3 @@ class SherpaNcnnAsr(AsrBase):
                         self._notify_complete(result)
                     segment_id += 1
                     self._recognizer.reset()

-def main():
-    print("Started! Please speak")
-    asr = SherpaNcnnAsr()
-    time.sleep(20)
-    print("Stop! ")
-    asr.stop()
-
-# print("Started! Please speak")
-# recognizer = create_recognizer()
-# sample_rate = recognizer.sample_rate
-# samples_per_read = int(0.1 * sample_rate)  # 0.1 second = 100 ms
-# last_result = ""
-# segment_id = 0
-#
-# with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s:
-#     while True:
-#         samples, _ = s.read(samples_per_read)  # a blocking read
-#         samples = samples.reshape(-1)
-#         recognizer.accept_waveform(sample_rate, samples)
-#
-#         is_endpoint = recognizer.is_endpoint
-#
-#         result = recognizer.text
-#         if result and (last_result != result):
-#             last_result = result
-#             print("\r{}:{}".format(segment_id, result), end=".", flush=True)
-#
-#         if is_endpoint:
-#             if result:
-#                 print("\r{}:{}".format(segment_id, result), flush=True)
-#                 segment_id += 1
-#             recognizer.reset()
-
-# print("Started! Please speak")
-# recognizer = create_recognizer()
-# sample_rate = recognizer.sample_rate
-# samples_per_read = int(0.1 * sample_rate)  # 0.1 second = 100 ms
-# last_result = ""
-# with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s:
-#     while True:
-#         samples, _ = s.read(samples_per_read)  # a blocking read
-#         samples = samples.reshape(-1)
-#         recognizer.accept_waveform(sample_rate, samples)
-#         result = recognizer.text
-#         if last_result != result:
-#             last_result = result
-#             print("\r{}".format(result), end="", flush=True)
-
-'''
-if __name__ == "__main__":
-    devices = sd.query_devices()
-    print(devices)
-    default_input_device_idx = sd.default.device[0]
-    print(f'Use default device: {devices[default_input_device_idx]["name"]}')
-
-    try:
-        main()
-    except KeyboardInterrupt:
-        print("\nCaught Ctrl + C. Exiting")
-
-# devices = sd.query_devices()
-# print(devices)
-# default_input_device_idx = sd.default.device[0]
-# print(f'Use default device: {devices[default_input_device_idx]["name"]}')
-#
-# try:
-#     main()
-# except KeyboardInterrupt:
-#     print("\nCaught Ctrl + C. Exiting")
-'''
@@ -55,6 +55,7 @@ def detect(net, img, device):

     return bboxlist

+
 def batch_detect(net, imgs, device):
     imgs = imgs - np.array([104, 117, 123])
     imgs = imgs.transpose(0, 3, 1, 2)
@@ -93,6 +94,7 @@ def batch_detect(net, imgs, device):

     return bboxlist

+
 def flip_detect(net, img, device):
     img = cv2.flip(img, 1)
     b = detect(net, img, device)
@@ -1,4 +1,6 @@
 #encoding = utf8

-from .audio_handler import AudioHandler
 from .human_context import HumanContext
+from .audio_mal_handler import AudioMalHandler
+from .audio_inference_handler import AudioInferenceHandler
+from .human_render import HumanRender
@@ -1,21 +0,0 @@
-#encoding = utf8
-import logging
-from abc import ABC, abstractmethod
-
-logger = logging.getLogger(__name__)
-
-
-class AudioHandler(ABC):
-    def __init__(self, context, handler):
-        self._context = context
-        self._handler = handler
-
-    @abstractmethod
-    def on_handle(self, stream, index):
-        pass
-
-    def on_next_handle(self, stream, type_):
-        if self._handler is not None:
-            self._handler.on_handle(stream, type_)
-        else:
-            logging.info(f'_handler is None')
@@ -1,4 +1,6 @@
 #encoding = utf8
+import logging
+import os
 import queue
 import time
 from queue import Queue
@@ -7,9 +9,12 @@ from threading import Event, Thread
 import numpy as np
 import torch

-from .audio_handler import AudioHandler
+from human_handler import AudioHandler
 from utils import load_model, mirror_index, get_device

+logger = logging.getLogger(__name__)
+current_file_path = os.path.dirname(os.path.abspath(__file__))


 class AudioInferenceHandler(AudioHandler):
     def __init__(self, context, handler):
@@ -22,6 +27,7 @@ class AudioInferenceHandler(AudioHandler):
         self._run_thread = Thread(target=self.__on_run)
         self._exit_event.set()
         self._run_thread.start()
+        logger.info("AudioInferenceHandler init")

     def on_handle(self, stream, type_):
         if type_ == 1:
@@ -30,8 +36,10 @@ class AudioInferenceHandler(AudioHandler):
             self._audio_queue.put(stream)

     def __on_run(self):
-        model = load_model(r'.\checkpoints\wav2lip.pth')
-        print("Model loaded")
+        wav2lip_path = os.path.join(current_file_path, '..', 'checkpoints', 'wav2lip.pth')
+        logger.info(f'AudioInferenceHandler init, path:{wav2lip_path}')
+        model = load_model(wav2lip_path)
+        logger.info("Model loaded")

         face_list_cycle = self._context.face_list_cycle

@@ -39,10 +47,10 @@ class AudioInferenceHandler(AudioHandler):
         index = 0
         count = 0
         count_time = 0
-        print('start inference')
+        logger.info('start inference')

         device = get_device()
-        print(f'use device:{device}')
+        logger.info(f'use device:{device}')

         while True:
             if self._exit_event.is_set():
@@ -66,7 +74,7 @@ class AudioInferenceHandler(AudioHandler):
                     0)
                 index = index + 1
             else:
-                print('infer=======')
+                logger.info('infer=======')
                 t = time.perf_counter()
                 img_batch = []
                 for i in range(batch_size):
@@ -95,20 +103,21 @@ class AudioInferenceHandler(AudioHandler):
                 count += batch_size

                 if count >= 100:
-                    print(f"------actual avg infer fps:{count / count_time:.4f}")
+                    logger.info(f"------actual avg infer fps:{count / count_time:.4f}")
                     count = 0
                     count_time = 0

-                image_index = 0
                 for i, res_frame in enumerate(pred):
                     self.on_next_handle(
                         (res_frame, mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]),
                         0)
                     index = index + 1
-                    image_index = image_index + 1
-                print('batch count', image_index)
-                print('total batch time:', time.perf_counter() - start_time)
+                logger.info(f'total batch time: {time.perf_counter() - start_time}')
             else:
                 time.sleep(1)
                 break
-        print('musereal inference processor stop')
+        logger.info('AudioInferenceHandler inference processor stop')

+    def stop(self):
+        self._exit_event.clear()
+        self._run_thread.join()
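The new stop() pairs with the _exit_event checks in __on_run: the loop does work while the event is set, and stop() clears it and joins the thread. A small self-contained sketch of that pattern (illustrative names, not the handler's real code):

    import threading
    import time

    run_event = threading.Event()
    run_event.set()  # set while the worker should keep running

    def worker():
        while True:
            if run_event.is_set():
                time.sleep(0.01)  # one unit of inference work per pass
            else:
                break

    t = threading.Thread(target=worker)
    t.start()
    run_event.clear()  # ask the loop to exit
    t.join()
    print('worker stopped')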
@@ -7,7 +7,7 @@ from threading import Thread, Event

 import numpy as np

-from .audio_handler import AudioHandler
+from human_handler import AudioHandler
 from utils import melspectrogram

 logger = logging.getLogger(__name__)
@@ -25,6 +25,7 @@ class AudioMalHandler(AudioHandler):

         self.frames = []
         self.chunk = context.sample_rate // context.fps
+        logger.info("AudioMalHandler init")

     def on_handle(self, stream, index):
         self._queue.put(stream)
@@ -1,15 +1,17 @@
 #encoding = utf8
 import logging
+import os

 from asr import SherpaNcnnAsr
-from human.audio_inference_handler import AudioInferenceHandler
-from human.audio_mal_handler import AudioMalHandler
-from human.human_render import HumanRender
+from .audio_inference_handler import AudioInferenceHandler
+from .audio_mal_handler import AudioMalHandler
+from .human_render import HumanRender
 from nlp import PunctuationSplit, DouBao
 from tts import TTSEdge, TTSAudioSplitHandle
 from utils import load_avatar, get_device

 logger = logging.getLogger(__name__)
+current_file_path = os.path.dirname(os.path.abspath(__file__))


 class HumanContext:
@@ -23,7 +25,9 @@ class HumanContext:

         self._device = get_device()
         print(f'device:{self._device}')
-        full_images, face_frames, coord_frames = load_avatar(r'./face/', self._image_size, self._device)
+        base_path = os.path.join(current_file_path, '..', 'face')
+        logger.info(f'_create_recognizer init, path:{base_path}')
+        full_images, face_frames, coord_frames = load_avatar(base_path, self._image_size, self._device)
         self._frame_list_cycle = full_images
         self._face_list_cycle = face_frames
         self._coord_list_cycle = coord_frames
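Both this hunk and the earlier SherpaNcnnAsr change replace working-directory-relative paths (os.getcwd(), r'./face/') with paths anchored at the module file, so loading works no matter where the process is launched from. (The reused log text '_create_recognizer init' here looks like a copy-paste from the ASR module.) A short sketch of the pattern; the normpath call is my addition, not in the commit:

    import os

    current_file_path = os.path.dirname(os.path.abspath(__file__))
    face_dir = os.path.normpath(os.path.join(current_file_path, '..', 'face'))
    print(face_dir)  # same result regardless of the caller's cwd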
@@ -31,14 +35,24 @@ class HumanContext:
         logging.info(f'face images length: {face_images_length}')
         print(f'face images length: {face_images_length}')

-        self.asr = None
-        self.nlp = None
-        self.tts = None
-        self.tts_handle = None
-        self.mal_handler = None
-        self.infer_handler = None
+        self._asr = None
+        self._nlp = None
+        self._tts = None
+        self._tts_handle = None
+        self._mal_handler = None
+        self._infer_handler = None
         self._render_handler = None

+    def __del__(self):
+        print(f'HumanContext: __del__')
+        self._asr.stop()
+        self._nlp.stop()
+        self._tts.stop()
+        self._tts_handle.stop()
+        self._mal_handler.stop()
+        self._infer_handler.stop()
+        self._render_handler.stop()
+
     @property
     def fps(self):
         return self._fps
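The new HumanContext.__del__ calls stop() on every pipeline member, but all of them start as None and are only assigned in build(); if build() never ran, each call raises AttributeError. A guarded variant, a sketch of my suggestion rather than the commit's code:

    class HumanContextSketch:
        def __init__(self):
            # pipeline members are created lazily in build()
            self._pipeline = []

        def build(self, *handlers):
            self._pipeline = list(handlers)

        def __del__(self):
            for handler in self._pipeline:
                if handler is not None:
                    handler.stop()  # only stop what build() created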
@@ -81,17 +95,17 @@ class HumanContext:

     @property
     def render_handler(self):
-        return self.render_handler
+        return self._render_handler

     def build(self):
         self._render_handler = HumanRender(self, None)
-        self.infer_handler = AudioInferenceHandler(self, self._render_handler)
-        self.mal_handler = AudioMalHandler(self, self.infer_handler)
-        self.tts_handle = TTSAudioSplitHandle(self, self.mal_handler)
-        self.tts = TTSEdge(self.tts_handle)
+        self._infer_handler = AudioInferenceHandler(self, self._render_handler)
+        self._mal_handler = AudioMalHandler(self, self._infer_handler)
+        self._tts_handle = TTSAudioSplitHandle(self, self._mal_handler)
+        self._tts = TTSEdge(self._tts_handle)
         split = PunctuationSplit()
-        nlp = DouBao(split, self.tts)
-        self.asr = SherpaNcnnAsr()
-        self.asr.attach(nlp)
+        self._nlp = DouBao(split, self._tts)
+        self._asr = SherpaNcnnAsr()
+        self._asr.attach(self._nlp)
@@ -9,8 +9,7 @@ from threading import Thread, Event
 import cv2
 import numpy as np

-from audio_render import AudioRender
-from .audio_handler import AudioHandler
+from human_handler import AudioHandler


 class HumanRender(AudioHandler):
@@ -27,12 +26,12 @@ class HumanRender(AudioHandler):
         self._thread.start()

     def _on_run(self):
-        logging.info('chunk2mal run')
+        logging.info('human render run')
         while self._exit_event.is_set():
             self._run_step()
             time.sleep(0.002)

-        logging.info('chunk2mal exit')
+        logging.info('human render exit')

     def _run_step(self):
         try:
@@ -58,7 +57,7 @@ class HumanRender(AudioHandler):
             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

             if self._image_render is not None:
-                self._image_render.render(image)
+                self._image_render.on_render(image)

             for audio_frame in audio_frames:
                 frame, type_ = audio_frame
@@ -69,7 +68,6 @@ class HumanRender(AudioHandler):
                 # new_frame.planes[0].update(frame.tobytes())
                 # new_frame.sample_rate = 16000

-
     def set_audio_render(self, render):
         self._audio_render = render

@@ -79,4 +77,6 @@ class HumanRender(AudioHandler):
     def on_handle(self, stream, index):
         self._queue.put(stream)

+    def stop(self):
+        self._exit_event.clear()
+        self._thread.join()
@@ -12,6 +12,7 @@ logger = logging.getLogger(__name__)
 class DouBao(NLPBase):
     def __init__(self, split, callback=None):
         super().__init__(split, callback)
+        logger.info("DouBao init")
         # Access Key ID
         # AKLTYTdmOTBmNWFjODkxNDE2Zjk3MjU0NjRhM2JhM2IyN2Y
         # AKLTNDZjNTdhNDlkZGE3NDZjMDlkMzk5YWQ3MDA4MTY1ZDc
@@ -30,7 +31,7 @@ class DouBao(NLPBase):
     async def _request(self, question):
         t = time.time()
         logger.info(f'_request:{question}')
-        print(f'-------dou_bao ask:', question)
+        logger.info(f'-------dou_bao ask:{question}')
         try:
             stream = await self.__client.chat.completions.create(
                 model="ep-20241008152048-fsgzf",
@@ -51,38 +52,9 @@ class DouBao(NLPBase):
         except Exception as e:
             print(e)
         logger.info(f'_request:{question}, time:{time.time() - t:.4f}s')
-        print(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
+        logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s')

     async def _on_close(self):
-        print('AsyncArk close')
+        logger.info('AsyncArk close')
         if self.__client is not None and not self.__client.is_closed():
             await self.__client.close()

-'''
-if __name__ == "__main__":
-    # print(get_dou_bao_api())
-    dou_bao = DouBao()
-    dou_bao.ask('你好。')
-    dou_bao.ask('你好,你是谁?')
-    dou_bao.ask('你能做什么?')
-    dou_bao.ask('介绍一下,我自己。')
-    count = 1000
-    sec = ''
-    while count >= 0:
-        count = count - 1
-        if nlp_queue.empty():
-            time.sleep(0.1)
-            continue
-        sec = sec + nlp_queue.get(block=True, timeout=0.01)
-
-        pattern = r'[,。、;?!,.!?]'
-        match = re.search(pattern, sec)
-        if match:
-            pos = match.start() + 1
-            print(sec[:pos])
-            sec = sec[pos:]
-        print(sec)
-
-
-    dou_bao.stop()
-'''
test/test_human_context.py (new file, 44 lines)
@@ -0,0 +1,44 @@
+#encoding = utf8
+import logging
+import os
+import sys
+import time
+
+from human import HumanContext
+from utils import config_logging
+
+
+# try:
+#     import sounddevice as sd
+# except ImportError as e:
+#     print("Please install sounddevice first. You can use")
+#     print()
+#     print("  pip install sounddevice")
+#     print()
+#     print("to install it")
+#     sys.exit(-1)
+#
+
+
+def main():
+    print("Started! Please speak")
+    human = HumanContext()
+    human.build()
+    time.sleep(60)
+    print("Stop! ")
+
+
+if __name__ == "__main__":
+    # devices = sd.query_devices()
+    # print(devices)
+    # default_input_device_idx = sd.default.device[0]
+    # print(f'Use default device: {devices[default_input_device_idx]["name"]}')
+
+    current_file_path = os.path.dirname(os.path.abspath(__file__))
+    log_path = os.path.join(current_file_path, '..', 'logs', 'info.log')
+    config_logging(log_path, logging.INFO, logging.INFO)
+
+    try:
+        main()
+    except KeyboardInterrupt:
+        print("\nCaught Ctrl + C. Exiting")
@@ -1,9 +1,13 @@
 #encoding = utf8
+import heapq
+import logging
 import os
 import shutil

 from utils import save_wav
-from human import AudioHandler
+from human_handler import AudioHandler

+logger = logging.getLogger(__name__)


 class TTSAudioHandle(AudioHandler):
@@ -27,26 +31,38 @@ class TTSAudioHandle(AudioHandler):
     def on_handle(self, stream, index):
         pass

+    def stop(self):
+        pass


 class TTSAudioSplitHandle(TTSAudioHandle):
     def __init__(self, context, handler):
         super().__init__(context, handler)
         self.sample_rate = self._context.sample_rate
         self._chunk = self.sample_rate // self._context.fps
+        self._priority_queue = []
+        logger.info("TTSAudioSplitHandle init")

     def on_handle(self, stream, index):
+        # heapq.heappush(self._priority_queue, (index, stream))
+        if stream is None:
+            heapq.heappush(self._priority_queue, (index, None))

         stream_len = stream.shape[0]
         idx = 0

         while stream_len >= self._chunk:
-            self._context.put_audio_frame(stream[idx:idx + self._chunk])
+            self.on_next_handle(stream[idx:idx + self._chunk], 0)
             stream_len -= self._chunk
             idx += self._chunk

+    def stop(self):
+        pass


 class TTSAudioSaveHandle(TTSAudioHandle):
-    def __init__(self):
-        super().__init__()
+    def __init__(self, context, handler):
+        super().__init__(context, handler)
         self._save_path_dir = '../temp/audio/'
         self._clean()

@@ -72,3 +88,6 @@ class TTSAudioSaveHandle(TTSAudioHandle):
         file_name = self._save_path_dir + str(index) + '.wav'
         save_wav(stream, file_name, self.sample_rate)

+    def stop(self):
+        pass
@@ -1,5 +1,5 @@
 #encoding = utf8
+import logging
 from io import BytesIO

 import numpy as np
@@ -9,11 +9,14 @@ import resampy

 from .tts_base import TTSBase

+logger = logging.getLogger(__name__)


 class TTSEdge(TTSBase):
     def __init__(self, handle, voice='zh-CN-XiaoyiNeural'):
         super().__init__(handle)
         self._voice = voice
+        logger.info(f"TTSEdge init, {voice}")

     async def _on_request(self, txt: str):
         print('_on_request, txt')
@@ -42,6 +45,7 @@ class TTSEdge(TTSBase):
                 print('-------tts finish push chunk')

         except Exception as e:
+            self._handle.on_handle(None, index)
             stream.seek(0)
             stream.truncate()
             print('-------tts finish error:', e)

ui.py (41 changed lines)
@@ -16,8 +16,11 @@ from PIL import Image, ImageTk

 from playsound import playsound

+from audio_render import AudioRender
 # from Human import Human
 from human import HumanContext
+from utils import config_logging

 # from tts.EdgeTTS import EdgeTTS

 logger = logging.getLogger(__name__)
@@ -48,7 +51,7 @@ class App(customtkinter.CTk):
         # self.logo_label.grid(row=0, column=0, padx=20, pady=(20, 10))

         self.entry = customtkinter.CTkEntry(self, placeholder_text="输入内容")
-        self.entry.insert(0, "大家好,我是九零科技有限公司,虚拟数字人。")
+        self.entry.insert(0, "大家好,测试虚拟数字人。")
         self.entry.grid(row=2, column=0, columnspan=2, padx=(20, 0), pady=(20, 20), sticky="nsew")

         self.main_button_1 = customtkinter.CTkButton(master=self, fg_color="transparent", border_width=2,
@@ -58,13 +61,14 @@ class App(customtkinter.CTk):

         self._init_image_canvas()

-        self._is_play_audio = False
+        self._audio_render = AudioRender()
         # self._human = Human()
         self._queue = Queue()
         self._human_context = HumanContext()
         self._human_context.build()
         render = self._human_context.render_handler
         render.set_image_render(self)
+        render.set_audio_render(self._audio_render)
         self._render()
         # self.play_audio()

@@ -84,13 +88,14 @@ class App(customtkinter.CTk):
         self._canvas.pack(fill=customtkinter.BOTH, expand=customtkinter.YES)

     def _render(self):
+        after_time = 29
         try:
-            image = self._queue.get()
+            image = self._queue.get(block=True, timeout=0.003)
             if image is None:
-                self.after(20, self._render)
+                self.after(after_time, self._render)
                 return
         except queue.Empty:
-            self.after(20, self._render)
+            self.after(after_time, self._render)
             return

         iheight, iwidth = image.shape[0], image.shape[1]
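The _render change replaces a blocking Queue.get() with a 3 ms timeout and reschedules itself via after(29) (roughly 34 fps), so the Tk main loop stays responsive even when no frame is ready. A minimal runnable sketch of the pattern with illustrative names, not the app's real API:

    import queue
    import tkinter as tk

    frames = queue.Queue()
    root = tk.Tk()

    def render(after_ms=29):
        try:
            frame = frames.get(block=True, timeout=0.003)
            print('draw', frame)  # the app draws onto its canvas here
        except queue.Empty:
            pass  # nothing ready; fall through and reschedule
        root.after(after_ms, render)

    render()
    root.mainloop()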
@@ -111,7 +116,7 @@ class App(customtkinter.CTk):
         height = self.winfo_height() * 0.5
         self._canvas.create_image(width, height, anchor=customtkinter.CENTER, image=imgtk)
         self._canvas.update()
-        self.after(20, self._render)
+        self.after(after_time, self._render)

     def request_tts(self):
         content = self.entry.get()
@@ -148,27 +153,9 @@ class App(customtkinter.CTk):
         sound = AudioSegment.from_mp3('./audio/mp3/' + file_name)
         sound.export('./audio/wav/' + file_name + '.wav', format="wav")

-    # open('./audio/', 'wb') with
+    def on_render(self, image):
+        self._queue.put(image)
+        print('on_render', self._queue.qsize())

-def config_logging(file_name: str, console_level: int=logging.INFO, file_level: int=logging.DEBUG):
-    file_handler = logging.FileHandler(file_name, mode='a', encoding="utf8")
-    file_handler.setFormatter(logging.Formatter(
-        '%(asctime)s [%(levelname)s] %(module)s.%(lineno)d %(name)s:\t%(message)s'
-    ))
-    file_handler.setLevel(file_level)
-
-    console_handler = logging.StreamHandler()
-    console_handler.setFormatter(logging.Formatter(
-        '[%(asctime)s %(levelname)s] %(message)s',
-        datefmt="%Y/%m/%d %H:%M:%S"
-    ))
-    console_handler.setLevel(console_level)
-
-    logging.basicConfig(
-        level=min(console_level, file_level),
-        handlers=[file_handler, console_handler],
-    )
-
-
 if __name__ == "__main__":
@@ -1,6 +1,6 @@
 #encoding = utf8

 from .async_task_queue import AsyncTaskQueue
-from .utils import mirror_index, load_model, get_device, load_avatar
+from .utils import mirror_index, load_model, get_device, load_avatar, config_logging
 from .audio_utils import melspectrogram, save_wav

@@ -38,7 +38,7 @@ def read_files_path(path):
     files = os.listdir(path)
     for file in files:
         if not os.path.isdir(file):
-            file_paths.append(path + file)
+            file_paths.append(os.path.join(path, file))
     return file_paths

@@ -160,6 +160,7 @@ def load_model(path):


 def load_avatar(path, img_size, device):
+    print(f'load avatar:{path}')
     face_images_path = path
     face_images_path = read_files_path(face_images_path)
     full_list_cycle = read_images(face_images_path)
@@ -174,3 +175,23 @@ def load_avatar(path, img_size, device):
         coord_frames.append(coord)

     return full_list_cycle, face_frames, coord_frames
+
+
+def config_logging(file_name: str, console_level: int=logging.INFO, file_level: int=logging.DEBUG):
+    file_handler = logging.FileHandler(file_name, mode='a', encoding="utf8")
+    file_handler.setFormatter(logging.Formatter(
+        '%(asctime)s [%(levelname)s] %(module)s.%(lineno)d %(name)s:\t%(message)s'
+    ))
+    file_handler.setLevel(file_level)
+
+    console_handler = logging.StreamHandler()
+    console_handler.setFormatter(logging.Formatter(
+        '[%(asctime)s %(levelname)s] %(message)s',
+        datefmt="%Y/%m/%d %H:%M:%S"
+    ))
+    console_handler.setLevel(console_level)
+
+    logging.basicConfig(
+        level=min(console_level, file_level),
+        handlers=[file_handler, console_handler],
+    )
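With config_logging moved out of ui.py into utils, any entry point can set up file-plus-console logging, as the new test file does. Example usage mirroring test_human_context.py (the log path is that file's choice):

    import logging
    import os

    from utils import config_logging

    current_file_path = os.path.dirname(os.path.abspath(__file__))
    log_path = os.path.join(current_file_path, '..', 'logs', 'info.log')
    config_logging(log_path, logging.INFO, logging.INFO)
    logging.getLogger(__name__).info('logging configured')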