diff --git a/Human.py b/Human.py
index d79dc99..c4aa573 100644
--- a/Human.py
+++ b/Human.py
@@ -14,6 +14,9 @@
 import torch
 import cv2
 from tqdm import tqdm
+from tts.EdgeTTS import EdgeTTS
+from tts.TTSBase import TTSBase
+
 
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
@@ -139,19 +142,18 @@ def inference(render_event, batch_size, face_images_path, audio_feat_queue, audi
 
 class Human:
     def __init__(self):
-        self._text = None
-        self._tts = None
         self._fps = 50  # 20 ms per frame
         self._batch_size = 16
         self._sample_rate = 16000
-        self._chunk = self._sample_rate // self._fps  # 320 samples per chunk (20ms * 16000 / 1000)
-        self._chunk_2_mal = Chunk2Mal(self)
         self._stride_left_size = 10
         self._stride_right_size = 10
         self._feat_queue = mp.Queue(2)
         self._output_queue = mp.Queue()
         self._res_frame_queue = mp.Queue(self._batch_size * 2)
 
+        self._chunk_2_mal = Chunk2Mal(self)
+        self._tts = TTSBase(self)
+
         face_images_path = r'./face/'
         self._face_image_paths = utils.read_files_path(face_images_path)
         print(self._face_image_paths)
@@ -167,8 +169,8 @@ class Human:
     def get_batch_size(self):
         return self._batch_size
 
-    def get_chunk(self):
-        return self._chunk
+    def get_audio_sample_rate(self):
+        return self._sample_rate
 
     def get_stride_left_size(self):
         return self._stride_left_size
@@ -183,14 +185,6 @@ class Human:
         self._tts.stop()
         logging.info('human destroy')
 
-    def set_tts(self, tts):
-        if self._tts == tts:
-            return
-
-        self._tts = tts
-        self._tts.start()
-        self._chunk_2_mal.start()
-
     def read(self, txt):
         if self._tts is None:
             logging.warning('tts is none')
@@ -198,8 +192,8 @@ class Human:
 
         self._tts.push_txt(txt)
 
-    def push_audio_chunk(self, chunk):
-        self._chunk_2_mal.push_chunk(chunk)
+    def push_audio_chunk(self, audio_chunk):
+        self._chunk_2_mal.push_chunk(audio_chunk)
 
     def push_feat_queue(self, mel_chunks):
         print("push_feat_queue")
diff --git a/tts/Chunk2Mal.py b/tts/Chunk2Mal.py
index f0d32af..460ab94 100644
--- a/tts/Chunk2Mal.py
+++ b/tts/Chunk2Mal.py
@@ -14,12 +14,20 @@ class Chunk2Mal:
         self._audio_chunk_queue = Queue()
         self._human = human
         self._thread = None
-        self._exit_event = None
+        self._chunks = []
+        # 320 samples per chunk (20 ms * 16000 / 1000)
+        self._chunk_len = self._human.get_audio_sample_rate() // self._human.get_fps()
+
+        self._exit_event = Event()
+        self._thread = Thread(target=self._on_run)
+        self._exit_event.set()
+        self._thread.start()
+        logging.info('chunk2mal start')
 
     def _on_run(self):
         logging.info('chunk2mal run')
-        while not self._exit_event.is_set():
+        while self._exit_event.is_set():
             try:
                 chunk, type_ = self.pull_chunk()
                 self._chunks.append(chunk)
@@ -57,19 +65,11 @@ class Chunk2Mal:
 
         logging.info('chunk2mal exit')
 
-    def start(self):
-        if self._exit_event is not None:
-            return
-
-        self._exit_event = Event()
-        self._thread = Thread(target=self._on_run)
-        self._thread.start()
-        logging.info('chunk2mal start')
-
     def stop(self):
         if self._exit_event is None:
             return
-        self._exit_event.set()
+        self._exit_event.clear()
         if self._thread.is_alive():
             self._thread.join()
         logging.info('chunk2mal stop')
diff --git a/tts/TTSBase.py b/tts/TTSBase.py
index b15ad94..d38a539 100644
--- a/tts/TTSBase.py
+++ b/tts/TTSBase.py
@@ -1,10 +1,19 @@
 #encoding = utf8
 import logging
+import asyncio
+import time
+
+import edge_tts
+import numpy as np
+import soundfile
+import resampy
 import queue
 from io import BytesIO
 from queue import Queue
 from threading import Thread, Event
 
 
+logger = logging.getLogger(__name__)
+
 class TTSBase:
     def __init__(self, human):
@@ -13,13 +22,18 @@ class TTSBase:
         self._queue = Queue()
         self._exit_event = None
         self._io_stream = BytesIO()
-        self._fps = human.get_fps()
         self._sample_rate = 16000
-        self._chunk = self._sample_rate // self._fps
+        self._chunk = self._sample_rate // self._human.get_fps()
+
+        self._exit_event = Event()
+        self._thread = Thread(target=self._on_run)
+        self._exit_event.set()
+        self._thread.start()
+        logging.info('tts start')
 
     def _on_run(self):
         logging.info('tts run')
-        while not self._exit_event.is_set():
+        while self._exit_event.is_set():
             try:
                 txt = self._queue.get(block=True, timeout=1)
             except queue.Empty:
@@ -28,21 +42,50 @@ class TTSBase:
         logging.info('tts exit')
 
     def _request(self, txt):
-        pass
+        voice = 'zh-CN-XiaoyiNeural'
+        t = time.time()
+        asyncio.new_event_loop().run_until_complete(self.__on_request(voice, txt))
+        logger.info(f'edge tts time:{time.time() - t : 0.4f}s')
 
-    def start(self):
-        if self._exit_event is not None:
-            return
-        self._exit_event = Event()
-        self._thread = Thread(target=self._on_run)
-        self._thread.start()
-        logging.info('tts start')
+        self._io_stream.seek(0)
+        stream = self.__create_bytes_stream(self._io_stream)
+        stream_len = stream.shape[0]
+        index = 0
+        while stream_len >= self._chunk:
+            self._human.push_audio_chunk(stream[index:index + self._chunk])
+            stream_len -= self._chunk
+            index += self._chunk
+
+    def __create_bytes_stream(self, io_stream):
+        stream, sample_rate = soundfile.read(io_stream)
+        logger.info(f'tts audio stream {sample_rate} : {stream.shape}')
+        stream = stream.astype(np.float32)
+
+        if stream.ndim > 1:
+            logger.warning(f'tts audio has {stream.shape[1]} channels, only use the first')
+            stream = stream[:, 0]
+
+        if sample_rate != self._sample_rate and stream.shape[0] > 0:
+            logger.warning(f'tts audio sample rate is {sample_rate}, resample to {self._sample_rate}')
+            stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self._sample_rate)
+
+        return stream
+
+    async def __on_request(self, voice, txt):
+        communicate = edge_tts.Communicate(txt, voice)
+        first = True
+        async for chunk in communicate.stream():
+            if first:
+                first = False
+
+            if chunk['type'] == 'audio':
+                self._io_stream.write(chunk['data'])
 
     def stop(self):
         if self._exit_event is None:
             return
-        self._exit_event.set()
+        self._exit_event.clear()
         self._thread.join()
         logging.info('tts stop')
diff --git a/ui.py b/ui.py
index e6684ba..dc4fc6b 100644
--- a/ui.py
+++ b/ui.py
@@ -49,8 +49,6 @@ class App(customtkinter.CTk):
         self._init_image_canvas()
 
         self._human = Human()
-        tts = EdgeTTS(self._human)
-        self._human.set_tts(tts)
         self._render()
 
     def on_destroy(self):
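
Usage sketch (not part of the diff): with this refactor, Chunk2Mal and TTSBase start their worker threads inside their constructors, so a caller only builds a Human and pushes text, as ui.py now does. The sketch below uses only the methods shown above (Human(), get_audio_sample_rate(), get_fps(), read()); the "from Human import Human" import path and the sample string are assumptions for illustration.

# Minimal usage sketch, assuming the Human/TTSBase/Chunk2Mal constructors in the diff above.
from Human import Human  # assumed import path; ui.py constructs Human() the same way

human = Human()  # Chunk2Mal and TTSBase are created here and their threads are already running

# 20 ms chunking arithmetic shared by both workers: 16000 samples/s // 50 fps == 320 samples
chunk_len = human.get_audio_sample_rate() // human.get_fps()
assert chunk_len == 320

human.read('text to speak')  # queued to TTSBase; synthesized audio flows back via push_audio_chunk in 320-sample chunks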