添加chunk处理

This commit is contained in:
jiegeaiai 2024-09-05 00:51:14 +08:00
parent 9569009f32
commit cf9fc3545d
5 changed files with 200 additions and 21 deletions

View File

@ -1,14 +1,63 @@
#encoding = utf8
from queue import Queue
import logging
import multiprocessing as mp
from tts.Chunk2Mal import Chunk2Mal
logger = logging.getLogger(__name__)
class Human:
    """Owns the TTS engine and the audio-chunk-to-mel worker.

    The class holds the shared audio parameters (frame rate, batch size,
    samples per chunk, stride sizes) that the worker reads through the
    ``get_*`` accessors, and forwards text / audio chunks / mel features
    to the component responsible for them.
    """

    def __init__(self):
        self._tts = None
        self._audio_chunk_queue = Queue()
        self._fps = 50  # 20 ms per frame
        self._batch_size = 16
        self._sample_rate = 16000
        # 320 samples per chunk (20 ms * 16000 / 1000)
        self._chunk = self._sample_rate // self._fps
        self._chunk_2_mal = Chunk2Mal(self)
        self._stride_left_size = 10
        self._stride_right_size = 10
        # Bounded to 2 entries: put() blocks while the queue is full.
        # NOTE(review): no consumer is visible here; if nothing drains this
        # queue the producer thread will block in push_feat_queue() and
        # Chunk2Mal.stop() will wait on it -- confirm a reader exists.
        self._feat_queue = mp.Queue(2)

    def get_fps(self):
        """Return the number of audio chunks per second."""
        return self._fps

    def get_batch_size(self):
        """Return the configured batch size."""
        return self._batch_size

    def get_chunk(self):
        """Return the number of samples in one audio chunk."""
        return self._chunk

    def get_stride_left_size(self):
        """Return the left stride size, counted in chunks."""
        return self._stride_left_size

    def get_stride_right_size(self):
        """Return the right stride size, counted in chunks."""
        return self._stride_right_size

    def on_destroy(self):
        """Stop the mel worker and, if one is installed, the TTS engine."""
        self._chunk_2_mal.stop()
        if self._tts is not None:
            self._tts.stop()
        logger.info('human destroy')

    def set_tts(self, tts):
        """Install ``tts`` as the active engine and start it.

        Does nothing when ``tts`` is already the active engine.  A
        previously installed engine is stopped before being replaced so its
        worker is not left running.  Passing ``None`` only removes the
        current engine.
        """
        if self._tts == tts:
            return

        previous = self._tts
        if previous is not None:
            previous.stop()

        self._tts = tts
        if tts is None:
            return

        tts.start()
        self._chunk_2_mal.start()

    def read(self, txt):
        """Queue ``txt`` on the TTS engine; warn and return if none is set."""
        if self._tts is None:
            logger.warning('tts is none')
            return
        self._tts.push_txt(txt)

    def push_audio_chunk(self, chunk):
        """Forward one audio chunk to the mel worker."""
        self._chunk_2_mal.push_chunk(chunk)

    def push_feat_queue(self, mel_chunks):
        """Put a list of mel windows on the feature queue (may block)."""
        self._feat_queue.put(mel_chunks)

View File

@ -2,7 +2,7 @@ librosa~=0.10.2.post1
numpy~=1.26.3
opencv-contrib-python
opencv-python~=4.10.0.84
torch~=2.4.0+cu118
torch
torchvision
tqdm~=4.66.5
numba

84
tts/Chunk2Mal.py Normal file
View File

@ -0,0 +1,84 @@
#encoding = utf8
import logging
import queue
from queue import Queue
from threading import Thread, Event
import numpy as np
import audio
class Chunk2Mal:
    """Background worker that turns queued audio chunks into mel windows.

    Chunks pushed with :meth:`push_chunk` are accumulated; once more chunks
    than the combined left/right stride are buffered, a mel spectrogram of
    the whole buffer is computed, sliced into fixed-width windows and handed
    to ``human.push_feat_queue``.
    """

    def __init__(self, human):
        self._audio_chunk_queue = Queue()
        self._human = human
        self._thread = None
        self._exit_event = None
        self._chunks = []

    def _on_run(self):
        """Thread body: pull chunks and emit mel windows until stopped."""
        logging.info('chunk2mal run')
        # Keep a local reference: stop() clears the attribute after joining.
        exit_event = self._exit_event
        while not exit_event.is_set():
            # pull_chunk() never raises queue.Empty; it returns silence.
            chunk, _kind = self.pull_chunk()
            self._chunks.append(chunk)

            context = (self._human.get_stride_left_size()
                       + self._human.get_stride_right_size())
            if len(self._chunks) <= context:
                # Not enough context buffered yet.
                continue

            inputs = np.concatenate(self._chunks)  # [N * chunk]
            mel = audio.melspectrogram(inputs)
            mel_chunks = self._slice_mel(mel, len(self._chunks))
            self._human.push_feat_queue(mel_chunks)
            self._trim_chunks()
        logging.info('chunk2mal exit')

    def _slice_mel(self, mel, chunk_count):
        """Cut ``mel`` into 16-column windows for the non-stride chunks.

        One window is produced for every two chunks beyond the combined
        stride (rounded up).  A window that would run past the end of
        ``mel`` is replaced by the last 16 columns.
        """
        stride_left = self._human.get_stride_left_size()
        stride_right = self._human.get_stride_right_size()
        mel_width = len(mel[0])
        # NOTE(review): 80 / 50 presumably converts chunks to mel columns
        # (80 mel columns per second, 50 chunks per second) -- confirm
        # against the audio module's hop settings.
        left = max(0, stride_left * 80 / 50)
        mel_idx_multiplier = 80. * 2 / self._human.get_fps()
        mel_step_size = 16

        # Same count as ``while i < remaining / 2`` (ceil of remaining / 2).
        remaining = chunk_count - stride_left - stride_right
        window_count = (remaining + 1) // 2

        mel_chunks = []
        for i in range(window_count):
            start_idx = int(left + i * mel_idx_multiplier)
            if start_idx + mel_step_size > mel_width:
                mel_chunks.append(mel[:, mel_width - mel_step_size:])
            else:
                mel_chunks.append(mel[:, start_idx: start_idx + mel_step_size])
        return mel_chunks

    def _trim_chunks(self):
        """Keep only the trailing stride context to bound memory use."""
        context = (self._human.get_stride_left_size()
                   + self._human.get_stride_right_size())
        # ``lst[-0:]`` is the whole list, so a zero context needs its own case.
        self._chunks = self._chunks[-context:] if context > 0 else []

    def start(self):
        """Start the worker thread; no-op if it is already running."""
        if self._exit_event is not None:
            return
        self._exit_event = Event()
        self._thread = Thread(target=self._on_run)
        self._thread.start()
        logging.info('chunk2mal start')

    def stop(self):
        """Signal the worker to exit and wait for it; no-op if not started.

        The stored event and thread are cleared afterwards so that
        :meth:`start` can launch a fresh worker.
        """
        if self._exit_event is None:
            return
        self._exit_event.set()
        self._thread.join()
        self._exit_event = None
        self._thread = None
        logging.info('chunk2mal stop')

    def push_chunk(self, chunk):
        """Queue one audio chunk for the worker."""
        self._audio_chunk_queue.put(chunk)

    def pull_chunk(self):
        """Return ``(chunk, kind)``, waiting up to one second for a chunk.

        ``kind`` is 1 for a queued chunk.  On timeout a zero-filled float32
        chunk of ``human.get_chunk()`` samples is returned with ``kind`` 0.
        """
        try:
            chunk = self._audio_chunk_queue.get(block=True, timeout=1.0)
            kind = 1
        except queue.Empty:
            chunk = np.zeros(self._human.get_chunk(), dtype=np.float32)
            kind = 0
        return chunk, kind

View File

@ -1,4 +1,5 @@
#encoding = utf8
import logging
import queue
from io import BytesIO
from queue import Queue
@ -12,17 +13,19 @@ class TTSBase:
self._queue = Queue()
self._exit_event = None
self._io_stream = BytesIO()
self._fps = 50
self._fps = human.get_fps()
self._sample_rate = 16000
self._chunk = self._sample_rate // self._fps
def _on_run(self):
logging.info('tts run')
while not self._exit_event.is_set():
try:
txt = self._queue.get(block=True, timeout=1)
except queue.Empty:
continue
self._request(txt)
logging.info('tts exit')
def _request(self, txt):
pass
@ -33,6 +36,7 @@ class TTSBase:
self._exit_event = Event()
self._thread = Thread(target=self._on_run)
self._thread.start()
logging.info('tts start')
def stop(self):
if self._exit_event is None:
@ -40,6 +44,7 @@ class TTSBase:
self._exit_event.set()
self._thread.join()
logging.info('tts stop')
def clear(self):
self._queue.queue.clear()

73
ui.py
View File

@ -1,6 +1,7 @@
#encoding = utf8
import json
import logging
from logging import handlers
import tkinter
import tkinter.messagebox
import customtkinter
@ -52,6 +53,10 @@ class App(customtkinter.CTk):
tts = EdgeTTS(self._human)
self._human.set_tts(tts)
def on_destroy(self):
    """Log the shutdown banner, then tear down the owned ``Human``."""
    logger.info('------------App destroy------------')
    human = self._human
    human.on_destroy()
def _init_image_canvas(self):
self._canvas = customtkinter.CTkCanvas(self.image_frame)
self._canvas.pack(fill=customtkinter.BOTH, expand=customtkinter.YES)
@ -60,20 +65,21 @@ class App(customtkinter.CTk):
content = self.entry.get()
print('content:', content)
self.entry.delete(0, customtkinter.END)
payload = {
'text': content,
'voice': 'zh-CN-XiaoyiNeural'
}
resp = requests.get(self._tts_url + '/tts', params=urlencode(payload))
if resp.status_code != 200:
print('tts error', resp.status_code)
return
print(resp.content)
resJson = json.loads(resp.text)
url = resJson.get('url')
self.download_tts(url)
self._human.read(content)
# payload = {
# 'text': content,
# 'voice': 'zh-CN-XiaoyiNeural'
# }
# resp = requests.get(self._tts_url + '/tts', params=urlencode(payload))
# if resp.status_code != 200:
# print('tts error', resp.status_code)
# return
#
# print(resp.content)
#
# resJson = json.loads(resp.text)
# url = resJson.get('url')
# self.download_tts(url)
def download_tts(self, url):
file_name = url[3:]
@ -91,8 +97,43 @@ class App(customtkinter.CTk):
# open('./audio/', 'wb') with
if __name__ == "__main__":
logging.basicConfig(filename='./logs/info.log', level=logging.INFO)
def config_logging(file_name: str, console_level: int = logging.INFO, file_level: int = logging.DEBUG):
    """Configure root logging with a UTF-8 file handler and a console handler.

    Args:
        file_name: Path of the log file; it is opened in append mode.  Its
            parent directory is created if it does not exist yet.
        console_level: Minimum level emitted on the console handler.
        file_level: Minimum level written to the file handler.

    The root logger level is set to the lower of the two levels so neither
    handler is starved.  ``logging.basicConfig`` does nothing when the root
    logger already has handlers.
    """
    import os

    # FileHandler opens the file immediately and fails if the directory is
    # missing, so make sure it exists first.
    log_dir = os.path.dirname(os.path.abspath(file_name))
    os.makedirs(log_dir, exist_ok=True)

    file_handler = logging.FileHandler(file_name, mode='a', encoding="utf8")
    file_handler.setFormatter(logging.Formatter(
        '%(asctime)s [%(levelname)s] %(module)s.%(lineno)d %(name)s:\t%(message)s'
    ))
    file_handler.setLevel(file_level)

    console_handler = logging.StreamHandler()
    console_handler.setFormatter(logging.Formatter(
        '[%(asctime)s %(levelname)s] %(message)s',
        datefmt="%Y/%m/%d %H:%M:%S"
    ))
    console_handler.setLevel(console_level)

    logging.basicConfig(
        level=min(console_level, file_level),
        handlers=[file_handler, console_handler],
    )
if __name__ == "__main__":
    # Script entry point: set up logging, run the UI main loop, then tear
    # the application down once the main loop returns.
    # logging.basicConfig(filename='./logs/info.log', level=logging.INFO)
    config_logging('./logs/info.log', logging.INFO, logging.INFO)
    # logger = logging.getLogger('manager')
    # # Output to the console at DEBUG level
    # console = logging.StreamHandler()
    # console.setLevel(logging.DEBUG)
    # logger.addHandler(console)
    #
    # # Output to a file at INFO level, rotating the file by size
    # filelog = logging.handlers.RotatingFileHandler(filename='./logs/info.log', level=logging.INFO,
    # maxBytes=1024 * 1024, backupCount=5)
    # filelog.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
    # logger.setLevel(logging.INFO)
    # logger.addHandler(filelog)
    logger.info('------------start------------')
    app = App()
    app.mainloop()
    # Runs after the main loop has returned.
    app.on_destroy()
    # logger.info('------------exit------------')