modify ui and nlp tts code
This commit is contained in:
parent
6c0733d6b9
commit
ad54248ff3
@ -1,8 +1,7 @@
|
||||
#encoding = utf8
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
try:
|
||||
import sounddevice as sd
|
||||
@ -16,18 +15,26 @@ except ImportError as e:
|
||||
|
||||
import sherpa_ncnn
|
||||
|
||||
|
||||
from asr.asr_base import AsrBase
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
current_file_path = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
|
||||
class SherpaNcnnAsr(AsrBase):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._recognizer = self._create_recognizer()
|
||||
logger.info('SherpaNcnnAsr init')
|
||||
|
||||
def __del__(self):
|
||||
self.__del__()
|
||||
logger.info('SherpaNcnnAsr del')
|
||||
|
||||
def _create_recognizer(self):
|
||||
base_path = os.path.join(os.getcwd(), 'data', 'asr', 'sherpa-ncnn',
|
||||
base_path = os.path.join(current_file_path, '..', 'data', 'asr', 'sherpa-ncnn',
|
||||
'sherpa-ncnn-streaming-zipformer-zh-14M-2023-02-23')
|
||||
logger.info(f'_create_recognizer init, path:{base_path}')
|
||||
recognizer = sherpa_ncnn.Recognizer(
|
||||
tokens=base_path + '/tokens.txt',
|
||||
encoder_param=base_path + '/encoder_jit_trace-pnnx.ncnn.param',
|
||||
@ -50,6 +57,7 @@ class SherpaNcnnAsr(AsrBase):
|
||||
def _recognize_loop(self):
|
||||
segment_id = 0
|
||||
last_result = ""
|
||||
logger.info(f'_recognize_loop')
|
||||
with sd.InputStream(channels=1, dtype="float32", samplerate=self._sample_rate) as s:
|
||||
while not self._stop_event.is_set():
|
||||
samples, _ = s.read(self._samples_per_read) # a blocking read
|
||||
@ -70,74 +78,3 @@ class SherpaNcnnAsr(AsrBase):
|
||||
self._notify_complete(result)
|
||||
segment_id += 1
|
||||
self._recognizer.reset()
|
||||
|
||||
def main():
|
||||
print("Started! Please speak")
|
||||
asr = SherpaNcnnAsr()
|
||||
time.sleep(20)
|
||||
print("Stop! ")
|
||||
asr.stop()
|
||||
|
||||
# print("Started! Please speak")
|
||||
# recognizer = create_recognizer()
|
||||
# sample_rate = recognizer.sample_rate
|
||||
# samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms
|
||||
# last_result = ""
|
||||
# segment_id = 0
|
||||
#
|
||||
# with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s:
|
||||
# while True:
|
||||
# samples, _ = s.read(samples_per_read) # a blocking read
|
||||
# samples = samples.reshape(-1)
|
||||
# recognizer.accept_waveform(sample_rate, samples)
|
||||
#
|
||||
# is_endpoint = recognizer.is_endpoint
|
||||
#
|
||||
# result = recognizer.text
|
||||
# if result and (last_result != result):
|
||||
# last_result = result
|
||||
# print("\r{}:{}".format(segment_id, result), end=".", flush=True)
|
||||
#
|
||||
# if is_endpoint:
|
||||
# if result:
|
||||
# print("\r{}:{}".format(segment_id, result), flush=True)
|
||||
# segment_id += 1
|
||||
# recognizer.reset()
|
||||
|
||||
# print("Started! Please speak")
|
||||
# recognizer = create_recognizer()
|
||||
# sample_rate = recognizer.sample_rate
|
||||
# samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms
|
||||
# last_result = ""
|
||||
# with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s:
|
||||
# while True:
|
||||
# samples, _ = s.read(samples_per_read) # a blocking read
|
||||
# samples = samples.reshape(-1)
|
||||
# recognizer.accept_waveform(sample_rate, samples)
|
||||
# result = recognizer.text
|
||||
# if last_result != result:
|
||||
# last_result = result
|
||||
# print("\r{}".format(result), end="", flush=True)
|
||||
|
||||
'''
|
||||
if __name__ == "__main__":
|
||||
devices = sd.query_devices()
|
||||
print(devices)
|
||||
default_input_device_idx = sd.default.device[0]
|
||||
print(f'Use default device: {devices[default_input_device_idx]["name"]}')
|
||||
|
||||
try:
|
||||
main()
|
||||
except KeyboardInterrupt:
|
||||
print("\nCaught Ctrl + C. Exiting")
|
||||
|
||||
# devices = sd.query_devices()
|
||||
# print(devices)
|
||||
# default_input_device_idx = sd.default.device[0]
|
||||
# print(f'Use default device: {devices[default_input_device_idx]["name"]}')
|
||||
#
|
||||
# try:
|
||||
# main()
|
||||
# except KeyboardInterrupt:
|
||||
# print("\nCaught Ctrl + C. Exiting")
|
||||
'''
|
||||
|
@ -55,6 +55,7 @@ def detect(net, img, device):
|
||||
|
||||
return bboxlist
|
||||
|
||||
|
||||
def batch_detect(net, imgs, device):
|
||||
imgs = imgs - np.array([104, 117, 123])
|
||||
imgs = imgs.transpose(0, 3, 1, 2)
|
||||
@ -93,6 +94,7 @@ def batch_detect(net, imgs, device):
|
||||
|
||||
return bboxlist
|
||||
|
||||
|
||||
def flip_detect(net, img, device):
|
||||
img = cv2.flip(img, 1)
|
||||
b = detect(net, img, device)
|
||||
|
@ -1,4 +1,6 @@
|
||||
#encoding = utf8
|
||||
|
||||
from .audio_handler import AudioHandler
|
||||
from .human_context import HumanContext
|
||||
from .audio_mal_handler import AudioMalHandler
|
||||
from .audio_inference_handler import AudioInferenceHandler
|
||||
from .human_render import HumanRender
|
||||
|
@ -1,21 +0,0 @@
|
||||
#encoding = utf8
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AudioHandler(ABC):
|
||||
def __init__(self, context, handler):
|
||||
self._context = context
|
||||
self._handler = handler
|
||||
|
||||
@abstractmethod
|
||||
def on_handle(self, stream, index):
|
||||
pass
|
||||
|
||||
def on_next_handle(self, stream, type_):
|
||||
if self._handler is not None:
|
||||
self._handler.on_handle(stream, type_)
|
||||
else:
|
||||
logging.info(f'_handler is None')
|
@ -1,4 +1,6 @@
|
||||
#encoding = utf8
|
||||
import logging
|
||||
import os
|
||||
import queue
|
||||
import time
|
||||
from queue import Queue
|
||||
@ -7,9 +9,12 @@ from threading import Event, Thread
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from .audio_handler import AudioHandler
|
||||
from human_handler import AudioHandler
|
||||
from utils import load_model, mirror_index, get_device
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
current_file_path = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
|
||||
class AudioInferenceHandler(AudioHandler):
|
||||
def __init__(self, context, handler):
|
||||
@ -22,6 +27,7 @@ class AudioInferenceHandler(AudioHandler):
|
||||
self._run_thread = Thread(target=self.__on_run)
|
||||
self._exit_event.set()
|
||||
self._run_thread.start()
|
||||
logger.info("AudioInferenceHandler init")
|
||||
|
||||
def on_handle(self, stream, type_):
|
||||
if type_ == 1:
|
||||
@ -30,8 +36,10 @@ class AudioInferenceHandler(AudioHandler):
|
||||
self._audio_queue.put(stream)
|
||||
|
||||
def __on_run(self):
|
||||
model = load_model(r'.\checkpoints\wav2lip.pth')
|
||||
print("Model loaded")
|
||||
wav2lip_path = os.path.join(current_file_path, '..', 'checkpoints', 'wav2lip.pth')
|
||||
logger.info(f'AudioInferenceHandler init, path:{wav2lip_path}')
|
||||
model = load_model(wav2lip_path)
|
||||
logger.info("Model loaded")
|
||||
|
||||
face_list_cycle = self._context.face_list_cycle
|
||||
|
||||
@ -39,10 +47,10 @@ class AudioInferenceHandler(AudioHandler):
|
||||
index = 0
|
||||
count = 0
|
||||
count_time = 0
|
||||
print('start inference')
|
||||
logger.info('start inference')
|
||||
|
||||
device = get_device()
|
||||
print(f'use device:{device}')
|
||||
logger.info(f'use device:{device}')
|
||||
|
||||
while True:
|
||||
if self._exit_event.is_set():
|
||||
@ -66,7 +74,7 @@ class AudioInferenceHandler(AudioHandler):
|
||||
0)
|
||||
index = index + 1
|
||||
else:
|
||||
print('infer=======')
|
||||
logger.info('infer=======')
|
||||
t = time.perf_counter()
|
||||
img_batch = []
|
||||
for i in range(batch_size):
|
||||
@ -95,20 +103,21 @@ class AudioInferenceHandler(AudioHandler):
|
||||
count += batch_size
|
||||
|
||||
if count >= 100:
|
||||
print(f"------actual avg infer fps:{count / count_time:.4f}")
|
||||
logger.info(f"------actual avg infer fps:{count / count_time:.4f}")
|
||||
count = 0
|
||||
count_time = 0
|
||||
|
||||
image_index = 0
|
||||
for i, res_frame in enumerate(pred):
|
||||
self.on_next_handle(
|
||||
(res_frame, mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]),
|
||||
0)
|
||||
index = index + 1
|
||||
image_index = image_index + 1
|
||||
print('batch count', image_index)
|
||||
print('total batch time:', time.perf_counter() - start_time)
|
||||
logger.info(f'total batch time: {time.perf_counter() - start_time}')
|
||||
else:
|
||||
time.sleep(1)
|
||||
break
|
||||
print('musereal inference processor stop')
|
||||
logger.info('AudioInferenceHandler inference processor stop')
|
||||
|
||||
def stop(self):
|
||||
self._exit_event.clear()
|
||||
self._run_thread.join()
|
||||
|
@ -7,7 +7,7 @@ from threading import Thread, Event
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .audio_handler import AudioHandler
|
||||
from human_handler import AudioHandler
|
||||
from utils import melspectrogram
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -25,6 +25,7 @@ class AudioMalHandler(AudioHandler):
|
||||
|
||||
self.frames = []
|
||||
self.chunk = context.sample_rate // context.fps
|
||||
logger.info("AudioMalHandler init")
|
||||
|
||||
def on_handle(self, stream, index):
|
||||
self._queue.put(stream)
|
||||
|
@ -1,15 +1,17 @@
|
||||
#encoding = utf8
|
||||
import logging
|
||||
import os
|
||||
|
||||
from asr import SherpaNcnnAsr
|
||||
from human.audio_inference_handler import AudioInferenceHandler
|
||||
from human.audio_mal_handler import AudioMalHandler
|
||||
from human.human_render import HumanRender
|
||||
from .audio_inference_handler import AudioInferenceHandler
|
||||
from .audio_mal_handler import AudioMalHandler
|
||||
from .human_render import HumanRender
|
||||
from nlp import PunctuationSplit, DouBao
|
||||
from tts import TTSEdge, TTSAudioSplitHandle
|
||||
from utils import load_avatar, get_device
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
current_file_path = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
|
||||
class HumanContext:
|
||||
@ -23,7 +25,9 @@ class HumanContext:
|
||||
|
||||
self._device = get_device()
|
||||
print(f'device:{self._device}')
|
||||
full_images, face_frames, coord_frames = load_avatar(r'./face/', self._image_size, self._device)
|
||||
base_path = os.path.join(current_file_path, '..', 'face')
|
||||
logger.info(f'_create_recognizer init, path:{base_path}')
|
||||
full_images, face_frames, coord_frames = load_avatar(base_path, self._image_size, self._device)
|
||||
self._frame_list_cycle = full_images
|
||||
self._face_list_cycle = face_frames
|
||||
self._coord_list_cycle = coord_frames
|
||||
@ -31,14 +35,24 @@ class HumanContext:
|
||||
logging.info(f'face images length: {face_images_length}')
|
||||
print(f'face images length: {face_images_length}')
|
||||
|
||||
self.asr = None
|
||||
self.nlp = None
|
||||
self.tts = None
|
||||
self.tts_handle = None
|
||||
self.mal_handler = None
|
||||
self.infer_handler = None
|
||||
self._asr = None
|
||||
self._nlp = None
|
||||
self._tts = None
|
||||
self._tts_handle = None
|
||||
self._mal_handler = None
|
||||
self._infer_handler = None
|
||||
self._render_handler = None
|
||||
|
||||
def __del__(self):
|
||||
print(f'HumanContext: __del__')
|
||||
self._asr.stop()
|
||||
self._nlp.stop()
|
||||
self._tts.stop()
|
||||
self._tts_handle.stop()
|
||||
self._mal_handler.stop()
|
||||
self._infer_handler.stop()
|
||||
self._render_handler.stop()
|
||||
|
||||
@property
|
||||
def fps(self):
|
||||
return self._fps
|
||||
@ -81,17 +95,17 @@ class HumanContext:
|
||||
|
||||
@property
|
||||
def render_handler(self):
|
||||
return self.render_handler
|
||||
return self._render_handler
|
||||
|
||||
def build(self):
|
||||
self._render_handler = HumanRender(self, None)
|
||||
self.infer_handler = AudioInferenceHandler(self, self._render_handler)
|
||||
self.mal_handler = AudioMalHandler(self, self.infer_handler)
|
||||
self.tts_handle = TTSAudioSplitHandle(self, self.mal_handler)
|
||||
self.tts = TTSEdge(self.tts_handle)
|
||||
self._infer_handler = AudioInferenceHandler(self, self._render_handler)
|
||||
self._mal_handler = AudioMalHandler(self, self._infer_handler)
|
||||
self._tts_handle = TTSAudioSplitHandle(self, self._mal_handler)
|
||||
self._tts = TTSEdge(self._tts_handle)
|
||||
split = PunctuationSplit()
|
||||
nlp = DouBao(split, self.tts)
|
||||
self.asr = SherpaNcnnAsr()
|
||||
self.asr.attach(nlp)
|
||||
self._nlp = DouBao(split, self._tts)
|
||||
self._asr = SherpaNcnnAsr()
|
||||
self._asr.attach(self._nlp)
|
||||
|
||||
|
||||
|
@ -9,8 +9,7 @@ from threading import Thread, Event
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from audio_render import AudioRender
|
||||
from .audio_handler import AudioHandler
|
||||
from human_handler import AudioHandler
|
||||
|
||||
|
||||
class HumanRender(AudioHandler):
|
||||
@ -27,12 +26,12 @@ class HumanRender(AudioHandler):
|
||||
self._thread.start()
|
||||
|
||||
def _on_run(self):
|
||||
logging.info('chunk2mal run')
|
||||
logging.info('human render run')
|
||||
while self._exit_event.is_set():
|
||||
self._run_step()
|
||||
time.sleep(0.002)
|
||||
|
||||
logging.info('chunk2mal exit')
|
||||
logging.info('human render exit')
|
||||
|
||||
def _run_step(self):
|
||||
try:
|
||||
@ -58,7 +57,7 @@ class HumanRender(AudioHandler):
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
|
||||
if self._image_render is not None:
|
||||
self._image_render.render(image)
|
||||
self._image_render.on_render(image)
|
||||
|
||||
for audio_frame in audio_frames:
|
||||
frame, type_ = audio_frame
|
||||
@ -69,7 +68,6 @@ class HumanRender(AudioHandler):
|
||||
# new_frame.planes[0].update(frame.tobytes())
|
||||
# new_frame.sample_rate = 16000
|
||||
|
||||
|
||||
def set_audio_render(self, render):
|
||||
self._audio_render = render
|
||||
|
||||
@ -79,4 +77,6 @@ class HumanRender(AudioHandler):
|
||||
def on_handle(self, stream, index):
|
||||
self._queue.put(stream)
|
||||
|
||||
|
||||
def stop(self):
|
||||
self._exit_event.clear()
|
||||
self._thread.join()
|
||||
|
@ -12,6 +12,7 @@ logger = logging.getLogger(__name__)
|
||||
class DouBao(NLPBase):
|
||||
def __init__(self, split, callback=None):
|
||||
super().__init__(split, callback)
|
||||
logger.info("DouBao init")
|
||||
# Access Key ID
|
||||
# AKLTYTdmOTBmNWFjODkxNDE2Zjk3MjU0NjRhM2JhM2IyN2Y
|
||||
# AKLTNDZjNTdhNDlkZGE3NDZjMDlkMzk5YWQ3MDA4MTY1ZDc
|
||||
@ -30,7 +31,7 @@ class DouBao(NLPBase):
|
||||
async def _request(self, question):
|
||||
t = time.time()
|
||||
logger.info(f'_request:{question}')
|
||||
print(f'-------dou_bao ask:', question)
|
||||
logger.info(f'-------dou_bao ask:{question}')
|
||||
try:
|
||||
stream = await self.__client.chat.completions.create(
|
||||
model="ep-20241008152048-fsgzf",
|
||||
@ -51,38 +52,9 @@ class DouBao(NLPBase):
|
||||
except Exception as e:
|
||||
print(e)
|
||||
logger.info(f'_request:{question}, time:{time.time() - t:.4f}s')
|
||||
print(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
|
||||
logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s')
|
||||
|
||||
async def _on_close(self):
|
||||
print('AsyncArk close')
|
||||
logger.info('AsyncArk close')
|
||||
if self.__client is not None and not self.__client.is_closed():
|
||||
await self.__client.close()
|
||||
|
||||
'''
|
||||
if __name__ == "__main__":
|
||||
# print(get_dou_bao_api())
|
||||
dou_bao = DouBao()
|
||||
dou_bao.ask('你好。')
|
||||
dou_bao.ask('你好,你是谁?')
|
||||
dou_bao.ask('你能做什么?')
|
||||
dou_bao.ask('介绍一下,我自己。')
|
||||
count = 1000
|
||||
sec = ''
|
||||
while count >= 0:
|
||||
count = count - 1
|
||||
if nlp_queue.empty():
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
sec = sec + nlp_queue.get(block=True, timeout=0.01)
|
||||
|
||||
pattern = r'[,。、;?!,.!?]'
|
||||
match = re.search(pattern, sec)
|
||||
if match:
|
||||
pos = match.start() + 1
|
||||
print(sec[:pos])
|
||||
sec = sec[pos:]
|
||||
print(sec)
|
||||
|
||||
|
||||
dou_bao.stop()
|
||||
'''
|
||||
|
44
test/test_human_context.py
Normal file
44
test/test_human_context.py
Normal file
@ -0,0 +1,44 @@
|
||||
#encoding = utf8
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
from human import HumanContext
|
||||
from utils import config_logging
|
||||
|
||||
|
||||
# try:
|
||||
# import sounddevice as sd
|
||||
# except ImportError as e:
|
||||
# print("Please install sounddevice first. You can use")
|
||||
# print()
|
||||
# print(" pip install sounddevice")
|
||||
# print()
|
||||
# print("to install it")
|
||||
# sys.exit(-1)
|
||||
#
|
||||
|
||||
|
||||
def main():
|
||||
print("Started! Please speak")
|
||||
human = HumanContext()
|
||||
human.build()
|
||||
time.sleep(60)
|
||||
print("Stop! ")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# devices = sd.query_devices()
|
||||
# print(devices)
|
||||
# default_input_device_idx = sd.default.device[0]
|
||||
# print(f'Use default device: {devices[default_input_device_idx]["name"]}')
|
||||
|
||||
current_file_path = os.path.dirname(os.path.abspath(__file__))
|
||||
log_path = os.path.join(current_file_path, '..', 'logs', 'info.log')
|
||||
config_logging(log_path, logging.INFO, logging.INFO)
|
||||
|
||||
try:
|
||||
main()
|
||||
except KeyboardInterrupt:
|
||||
print("\nCaught Ctrl + C. Exiting")
|
@ -1,9 +1,13 @@
|
||||
#encoding = utf8
|
||||
import heapq
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
|
||||
from utils import save_wav
|
||||
from human import AudioHandler
|
||||
from human_handler import AudioHandler
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TTSAudioHandle(AudioHandler):
|
||||
@ -27,26 +31,38 @@ class TTSAudioHandle(AudioHandler):
|
||||
def on_handle(self, stream, index):
|
||||
pass
|
||||
|
||||
def stop(self):
|
||||
pass
|
||||
|
||||
|
||||
class TTSAudioSplitHandle(TTSAudioHandle):
|
||||
def __init__(self, context, handler):
|
||||
super().__init__(context, handler)
|
||||
self.sample_rate = self._context.sample_rate
|
||||
self._chunk = self.sample_rate // self._context.fps
|
||||
self._priority_queue = []
|
||||
logger.info("TTSAudioSplitHandle init")
|
||||
|
||||
def on_handle(self, stream, index):
|
||||
# heapq.heappush(self._priority_queue, (index, stream))
|
||||
if stream is None:
|
||||
heapq.heappush(self._priority_queue, (index, None))
|
||||
|
||||
stream_len = stream.shape[0]
|
||||
idx = 0
|
||||
|
||||
while stream_len >= self._chunk:
|
||||
self._context.put_audio_frame(stream[idx:idx + self._chunk])
|
||||
self.on_next_handle(stream[idx:idx + self._chunk], 0)
|
||||
stream_len -= self._chunk
|
||||
idx += self._chunk
|
||||
|
||||
def stop(self):
|
||||
pass
|
||||
|
||||
|
||||
class TTSAudioSaveHandle(TTSAudioHandle):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
def __init__(self, context, handler):
|
||||
super().__init__(context, handler)
|
||||
self._save_path_dir = '../temp/audio/'
|
||||
self._clean()
|
||||
|
||||
@ -72,3 +88,6 @@ class TTSAudioSaveHandle(TTSAudioHandle):
|
||||
file_name = self._save_path_dir + str(index) + '.wav'
|
||||
save_wav(stream, file_name, self.sample_rate)
|
||||
|
||||
def stop(self):
|
||||
pass
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
#encoding = utf8
|
||||
|
||||
import logging
|
||||
from io import BytesIO
|
||||
|
||||
import numpy as np
|
||||
@ -9,11 +9,14 @@ import resampy
|
||||
|
||||
from .tts_base import TTSBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TTSEdge(TTSBase):
|
||||
def __init__(self, handle, voice='zh-CN-XiaoyiNeural'):
|
||||
super().__init__(handle)
|
||||
self._voice = voice
|
||||
logger.info(f"TTSEdge init, {voice}")
|
||||
|
||||
async def _on_request(self, txt: str):
|
||||
print('_on_request, txt')
|
||||
@ -42,6 +45,7 @@ class TTSEdge(TTSBase):
|
||||
print('-------tts finish push chunk')
|
||||
|
||||
except Exception as e:
|
||||
self._handle.on_handle(None, index)
|
||||
stream.seek(0)
|
||||
stream.truncate()
|
||||
print('-------tts finish error:', e)
|
||||
|
41
ui.py
41
ui.py
@ -16,8 +16,11 @@ from PIL import Image, ImageTk
|
||||
|
||||
from playsound import playsound
|
||||
|
||||
from audio_render import AudioRender
|
||||
# from Human import Human
|
||||
from human import HumanContext
|
||||
from utils import config_logging
|
||||
|
||||
# from tts.EdgeTTS import EdgeTTS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -48,7 +51,7 @@ class App(customtkinter.CTk):
|
||||
# self.logo_label.grid(row=0, column=0, padx=20, pady=(20, 10))
|
||||
|
||||
self.entry = customtkinter.CTkEntry(self, placeholder_text="输入内容")
|
||||
self.entry.insert(0, "大家好,我是九零科技有限公司,虚拟数字人。")
|
||||
self.entry.insert(0, "大家好,测试虚拟数字人。")
|
||||
self.entry.grid(row=2, column=0, columnspan=2, padx=(20, 0), pady=(20, 20), sticky="nsew")
|
||||
|
||||
self.main_button_1 = customtkinter.CTkButton(master=self, fg_color="transparent", border_width=2,
|
||||
@ -58,13 +61,14 @@ class App(customtkinter.CTk):
|
||||
|
||||
self._init_image_canvas()
|
||||
|
||||
self._is_play_audio = False
|
||||
self._audio_render = AudioRender()
|
||||
# self._human = Human()
|
||||
self._queue = Queue()
|
||||
self._human_context = HumanContext()
|
||||
self._human_context.build()
|
||||
render = self._human_context.render_handler
|
||||
render.set_image_render(self)
|
||||
render.set_audio_render(self._audio_render)
|
||||
self._render()
|
||||
# self.play_audio()
|
||||
|
||||
@ -84,13 +88,14 @@ class App(customtkinter.CTk):
|
||||
self._canvas.pack(fill=customtkinter.BOTH, expand=customtkinter.YES)
|
||||
|
||||
def _render(self):
|
||||
after_time = 29
|
||||
try:
|
||||
image = self._queue.get()
|
||||
image = self._queue.get(block=True, timeout=0.003)
|
||||
if image is None:
|
||||
self.after(20, self._render)
|
||||
self.after(after_time, self._render)
|
||||
return
|
||||
except queue.Empty:
|
||||
self.after(20, self._render)
|
||||
self.after(after_time, self._render)
|
||||
return
|
||||
|
||||
iheight, iwidth = image.shape[0], image.shape[1]
|
||||
@ -111,7 +116,7 @@ class App(customtkinter.CTk):
|
||||
height = self.winfo_height() * 0.5
|
||||
self._canvas.create_image(width, height, anchor=customtkinter.CENTER, image=imgtk)
|
||||
self._canvas.update()
|
||||
self.after(20, self._render)
|
||||
self.after(after_time, self._render)
|
||||
|
||||
def request_tts(self):
|
||||
content = self.entry.get()
|
||||
@ -148,27 +153,9 @@ class App(customtkinter.CTk):
|
||||
sound = AudioSegment.from_mp3('./audio/mp3/' + file_name)
|
||||
sound.export('./audio/wav/' + file_name + '.wav', format="wav")
|
||||
|
||||
# open('./audio/', 'wb') with
|
||||
|
||||
|
||||
def config_logging(file_name: str, console_level: int=logging.INFO, file_level: int=logging.DEBUG):
|
||||
file_handler = logging.FileHandler(file_name, mode='a', encoding="utf8")
|
||||
file_handler.setFormatter(logging.Formatter(
|
||||
'%(asctime)s [%(levelname)s] %(module)s.%(lineno)d %(name)s:\t%(message)s'
|
||||
))
|
||||
file_handler.setLevel(file_level)
|
||||
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setFormatter(logging.Formatter(
|
||||
'[%(asctime)s %(levelname)s] %(message)s',
|
||||
datefmt="%Y/%m/%d %H:%M:%S"
|
||||
))
|
||||
console_handler.setLevel(console_level)
|
||||
|
||||
logging.basicConfig(
|
||||
level=min(console_level, file_level),
|
||||
handlers=[file_handler, console_handler],
|
||||
)
|
||||
def on_render(self, image):
|
||||
self._queue.put(image)
|
||||
print('on_render', self._queue.qsize())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -1,6 +1,6 @@
|
||||
#encoding = utf8
|
||||
|
||||
from .async_task_queue import AsyncTaskQueue
|
||||
from .utils import mirror_index, load_model, get_device, load_avatar
|
||||
from .utils import mirror_index, load_model, get_device, load_avatar, config_logging
|
||||
from .audio_utils import melspectrogram, save_wav
|
||||
|
||||
|
@ -38,7 +38,7 @@ def read_files_path(path):
|
||||
files = os.listdir(path)
|
||||
for file in files:
|
||||
if not os.path.isdir(file):
|
||||
file_paths.append(path + file)
|
||||
file_paths.append(os.path.join(path, file))
|
||||
return file_paths
|
||||
|
||||
|
||||
@ -160,6 +160,7 @@ def load_model(path):
|
||||
|
||||
|
||||
def load_avatar(path, img_size, device):
|
||||
print(f'load avatar:{path}')
|
||||
face_images_path = path
|
||||
face_images_path = read_files_path(face_images_path)
|
||||
full_list_cycle = read_images(face_images_path)
|
||||
@ -174,3 +175,23 @@ def load_avatar(path, img_size, device):
|
||||
coord_frames.append(coord)
|
||||
|
||||
return full_list_cycle, face_frames, coord_frames
|
||||
|
||||
|
||||
def config_logging(file_name: str, console_level: int=logging.INFO, file_level: int=logging.DEBUG):
|
||||
file_handler = logging.FileHandler(file_name, mode='a', encoding="utf8")
|
||||
file_handler.setFormatter(logging.Formatter(
|
||||
'%(asctime)s [%(levelname)s] %(module)s.%(lineno)d %(name)s:\t%(message)s'
|
||||
))
|
||||
file_handler.setLevel(file_level)
|
||||
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setFormatter(logging.Formatter(
|
||||
'[%(asctime)s %(levelname)s] %(message)s',
|
||||
datefmt="%Y/%m/%d %H:%M:%S"
|
||||
))
|
||||
console_handler.setLevel(console_level)
|
||||
|
||||
logging.basicConfig(
|
||||
level=min(console_level, file_level),
|
||||
handlers=[file_handler, console_handler],
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user