diff --git a/Human.py b/Human.py index f37e98b..ca78f9e 100644 --- a/Human.py +++ b/Human.py @@ -1,14 +1,63 @@ #encoding = utf8 -from queue import Queue +import logging +import multiprocessing as mp + +from tts.Chunk2Mal import Chunk2Mal + +logger = logging.getLogger(__name__) class Human: def __init__(self): self._tts = None - self._audio_chunk_queue = Queue() + self._fps = 50 # 20 ms per frame + self._batch_size = 16 + self._sample_rate = 16000 + self._chunk = self._sample_rate // self._fps # 320 samples per chunk (20ms * 16000 / 1000) + self._chunk_2_mal = Chunk2Mal(self) + self._stride_left_size = 10 + self._stride_right_size = 10 + self._feat_queue = mp.Queue(2) + + def get_fps(self): + return self._fps + + def get_batch_size(self): + return self._batch_size + + def get_chunk(self): + return self._chunk + + def get_stride_left_size(self): + return self._stride_left_size + + def get_stride_right_size(self): + return self._stride_right_size + + def on_destroy(self): + self._chunk_2_mal.stop() + + if self._tts is not None: + self._tts.stop() + logger.info('human destroy') def set_tts(self, tts): + if self._tts == tts: + return + self._tts = tts + self._tts.start() + self._chunk_2_mal.start() + + def read(self, txt): + if self._tts is None: + logger.warning('tts is none') + return + + self._tts.push_txt(txt) def push_audio_chunk(self, chunk): - pass + self._chunk_2_mal.push_chunk(chunk) + + def push_feat_queue(self, mel_chunks): + self._feat_queue.put(mel_chunks) diff --git a/requirements.txt b/requirements.txt index 08809c8..cc1e1f2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ librosa~=0.10.2.post1 numpy~=1.26.3 opencv-contrib-python opencv-python~=4.10.0.84 -torch~=2.4.0+cu118 +torch torchvision tqdm~=4.66.5 numba diff --git a/tts/Chunk2Mal.py b/tts/Chunk2Mal.py new file mode 100644 index 0000000..004d682 --- /dev/null +++ b/tts/Chunk2Mal.py @@ -0,0 +1,84 @@ +#encoding = utf8 +import logging +import queue +from queue import Queue +from threading import Thread, Event + +import numpy as np +import audio + + +class Chunk2Mal: + def __init__(self, human): + self._audio_chunk_queue = Queue() + self._human = human + self._thread = None + self._exit_event = None + self._chunks = [] + + def _on_run(self): + logging.info('chunk2mal run') + while not self._exit_event.is_set(): + try: + chunk, type = self.pull_chunk() + self._chunks.append(chunk) + except queue.Empty: + continue + + if len(self._chunks) <= self._human.get_stride_left_size() + self._human.get_stride_right_size(): + continue + + inputs = np.concatenate(self._chunks) # [N * chunk] + mel = audio.melspectrogram(inputs) + left = max(0, self._human.get_stride_left_size() * 80 / 50) + right = min(len(mel[0]), len(mel[0]) - self._human.get_stride_right_size() * 80 / 50) + mel_idx_multiplier = 80. * 2 / self._human.get_fps() + mel_step_size = 16 + i = 0 + mel_chunks = [] + while i < (len(self._chunks) - self._human.get_stride_left_size() + - self._human.get_stride_right_size()) / 2: + start_idx = int(left + i * mel_idx_multiplier) + # print(start_idx) + if start_idx + mel_step_size > len(mel[0]): + mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:]) + else: + mel_chunks.append(mel[:, start_idx: start_idx + mel_step_size]) + i += 1 + self._human.push_feat_queue(mel_chunks) + + # discard the old part to save memory + self._chunks = self._chunks[-(self._human.get_stride_left_size() + self._human.get_stride_right_size()):] + + logging.info('chunk2mal exit') + + def start(self): + if self._exit_event is not None: + return + self._exit_event = Event() + self._thread = Thread(target=self._on_run) + self._thread.start() + logging.info('chunk2mal start') + + def stop(self): + if self._exit_event is None: + return + + self._exit_event.set() + self._thread.join() + logging.info('chunk2mal stop') + + def push_chunk(self, chunk): + self._audio_chunk_queue.put(chunk) + + def pull_chunk(self): + try: + chunk = self._audio_chunk_queue.get(block=True, timeout=1.0) + type = 1 + except queue.Empty: + chunk = np.zeros(self._human.get_chunk(), dtype=np.float32) + type = 0 + return chunk, type + + + diff --git a/tts/TTSBase.py b/tts/TTSBase.py index bc4eb28..b15ad94 100644 --- a/tts/TTSBase.py +++ b/tts/TTSBase.py @@ -1,4 +1,5 @@ #encoding = utf8 +import logging import queue from io import BytesIO from queue import Queue @@ -12,17 +13,19 @@ class TTSBase: self._queue = Queue() self._exit_event = None self._io_stream = BytesIO() - self._fps = 50 + self._fps = human.get_fps() self._sample_rate = 16000 self._chunk = self._sample_rate // self._fps def _on_run(self): + logging.info('tts run') while not self._exit_event.is_set(): try: txt = self._queue.get(block=True, timeout=1) except queue.Empty: continue self._request(txt) + logging.info('tts exit') def _request(self, txt): pass @@ -33,6 +36,7 @@ class TTSBase: self._exit_event = Event() self._thread = Thread(target=self._on_run) self._thread.start() + logging.info('tts start') def stop(self): if self._exit_event is None: @@ -40,6 +44,7 @@ class TTSBase: self._exit_event.set() self._thread.join() + logging.info('tts stop') def clear(self): self._queue.queue.clear() diff --git a/ui.py b/ui.py index 3a27263..b41b4c0 100644 --- a/ui.py +++ b/ui.py @@ -1,6 +1,7 @@ #encoding = utf8 import json import logging +from logging import handlers import tkinter import tkinter.messagebox import customtkinter @@ -52,6 +53,10 @@ class App(customtkinter.CTk): tts = EdgeTTS(self._human) self._human.set_tts(tts) + def on_destroy(self): + logger.info('------------App destroy------------') + self._human.on_destroy() + def _init_image_canvas(self): self._canvas = customtkinter.CTkCanvas(self.image_frame) self._canvas.pack(fill=customtkinter.BOTH, expand=customtkinter.YES) @@ -60,20 +65,21 @@ class App(customtkinter.CTk): content = self.entry.get() print('content:', content) self.entry.delete(0, customtkinter.END) - payload = { - 'text': content, - 'voice': 'zh-CN-XiaoyiNeural' - } - resp = requests.get(self._tts_url + '/tts', params=urlencode(payload)) - if resp.status_code != 200: - print('tts error', resp.status_code) - return - - print(resp.content) - - resJson = json.loads(resp.text) - url = resJson.get('url') - self.download_tts(url) + self._human.read(content) + # payload = { + # 'text': content, + # 'voice': 'zh-CN-XiaoyiNeural' + # } + # resp = requests.get(self._tts_url + '/tts', params=urlencode(payload)) + # if resp.status_code != 200: + # print('tts error', resp.status_code) + # return + # + # print(resp.content) + # + # resJson = json.loads(resp.text) + # url = resJson.get('url') + # self.download_tts(url) def download_tts(self, url): file_name = url[3:] @@ -91,8 +97,43 @@ class App(customtkinter.CTk): # open('./audio/', 'wb') with -if __name__ == "__main__": - logging.basicConfig(filename='./logs/info.log', level=logging.INFO) +def config_logging(file_name: str, console_level: int=logging.INFO, file_level: int=logging.DEBUG): + file_handler = logging.FileHandler(file_name, mode='a', encoding="utf8") + file_handler.setFormatter(logging.Formatter( + '%(asctime)s [%(levelname)s] %(module)s.%(lineno)d %(name)s:\t%(message)s' + )) + file_handler.setLevel(file_level) + console_handler = logging.StreamHandler() + console_handler.setFormatter(logging.Formatter( + '[%(asctime)s %(levelname)s] %(message)s', + datefmt="%Y/%m/%d %H:%M:%S" + )) + console_handler.setLevel(console_level) + + logging.basicConfig( + level=min(console_level, file_level), + handlers=[file_handler, console_handler], + ) + + +if __name__ == "__main__": + # logging.basicConfig(filename='./logs/info.log', level=logging.INFO) + config_logging('./logs/info.log', logging.INFO, logging.INFO) + # logger = logging.getLogger('manager') + # # 输出到控制台, 级别为DEBUG + # console = logging.StreamHandler() + # console.setLevel(logging.DEBUG) + # logger.addHandler(console) + # + # # 输出到文件, 级别为INFO, 文件按大小切分 + # filelog = logging.handlers.RotatingFileHandler(filename='./logs/info.log', level=logging.INFO, + # maxBytes=1024 * 1024, backupCount=5) + # filelog.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')) + # logger.setLevel(logging.INFO) + # logger.addHandler(filelog) + logger.info('------------start------------') app = App() app.mainloop() + app.on_destroy() + # logger.info('------------exit------------')