modify human

This commit is contained in:
brige 2024-09-21 20:58:26 +08:00
parent cb9ea0ac17
commit 4e1e923c0b
4 changed files with 76 additions and 41 deletions

View File

@ -14,6 +14,9 @@ import torch
import cv2 import cv2
from tqdm import tqdm from tqdm import tqdm
from tts.EdgeTTS import EdgeTTS
from tts.TTSBase import TTSBase
device = 'cuda' if torch.cuda.is_available() else 'cpu' device = 'cuda' if torch.cuda.is_available() else 'cpu'
@ -139,19 +142,18 @@ def inference(render_event, batch_size, face_images_path, audio_feat_queue, audi
class Human: class Human:
def __init__(self): def __init__(self):
self._text = None
self._tts = None
self._fps = 50 # 20 ms per frame self._fps = 50 # 20 ms per frame
self._batch_size = 16 self._batch_size = 16
self._sample_rate = 16000 self._sample_rate = 16000
self._chunk = self._sample_rate // self._fps # 320 samples per chunk (20ms * 16000 / 1000)
self._chunk_2_mal = Chunk2Mal(self)
self._stride_left_size = 10 self._stride_left_size = 10
self._stride_right_size = 10 self._stride_right_size = 10
self._feat_queue = mp.Queue(2) self._feat_queue = mp.Queue(2)
self._output_queue = mp.Queue() self._output_queue = mp.Queue()
self._res_frame_queue = mp.Queue(self._batch_size * 2) self._res_frame_queue = mp.Queue(self._batch_size * 2)
self._chunk_2_mal = Chunk2Mal(self)
self._tts = TTSBase(self)
face_images_path = r'./face/' face_images_path = r'./face/'
self._face_image_paths = utils.read_files_path(face_images_path) self._face_image_paths = utils.read_files_path(face_images_path)
print(self._face_image_paths) print(self._face_image_paths)
@ -167,8 +169,8 @@ class Human:
def get_batch_size(self): def get_batch_size(self):
return self._batch_size return self._batch_size
def get_chunk(self): def get_audio_sample_rate(self):
return self._chunk return self._sample_rate
def get_stride_left_size(self): def get_stride_left_size(self):
return self._stride_left_size return self._stride_left_size
@ -183,14 +185,6 @@ class Human:
self._tts.stop() self._tts.stop()
logging.info('human destroy') logging.info('human destroy')
def set_tts(self, tts):
if self._tts == tts:
return
self._tts = tts
self._tts.start()
self._chunk_2_mal.start()
def read(self, txt): def read(self, txt):
if self._tts is None: if self._tts is None:
logging.warning('tts is none') logging.warning('tts is none')
@ -198,8 +192,8 @@ class Human:
self._tts.push_txt(txt) self._tts.push_txt(txt)
def push_audio_chunk(self, chunk): def push_audio_chunk(self, audio_chunk):
self._chunk_2_mal.push_chunk(chunk) self._chunk_2_mal.push_chunk(audio_chunk)
def push_feat_queue(self, mel_chunks): def push_feat_queue(self, mel_chunks):
print("push_feat_queue") print("push_feat_queue")

View File

@ -14,12 +14,20 @@ class Chunk2Mal:
self._audio_chunk_queue = Queue() self._audio_chunk_queue = Queue()
self._human = human self._human = human
self._thread = None self._thread = None
self._exit_event = None
self._chunks = [] self._chunks = []
# 320 samples per audio chunk (20 ms * 16000 Hz / 1000)
self._chunk_len = self._human.get_audio_sample_rate // self._human.get_fps()
self._exit_event = Event()
self._thread = Thread(target=self._on_run)
self._exit_event.set()
self._thread.start()
logging.info('chunk2mal start')
def _on_run(self): def _on_run(self):
logging.info('chunk2mal run') logging.info('chunk2mal run')
while not self._exit_event.is_set(): while self._exit_event.is_set():
try: try:
chunk, type_ = self.pull_chunk() chunk, type_ = self.pull_chunk()
self._chunks.append(chunk) self._chunks.append(chunk)
@ -57,19 +65,11 @@ class Chunk2Mal:
logging.info('chunk2mal exit') logging.info('chunk2mal exit')
def start(self):
if self._exit_event is not None:
return
self._exit_event = Event()
self._thread = Thread(target=self._on_run)
self._thread.start()
logging.info('chunk2mal start')
def stop(self): def stop(self):
if self._exit_event is None: if self._exit_event is None:
return return
self._exit_event.set() self._exit_event.clear()
if self._thread.is_alive(): if self._thread.is_alive():
self._thread.join() self._thread.join()
logging.info('chunk2mal stop') logging.info('chunk2mal stop')

View File

@ -1,10 +1,19 @@
#encoding = utf8 #encoding = utf8
import logging import logging
import asyncio
import time
import edge_tts
import numpy as np
import soundfile
import resampy
import queue import queue
from io import BytesIO from io import BytesIO
from queue import Queue from queue import Queue
from threading import Thread, Event from threading import Thread, Event
logger = logging.getLogger(__name__)
class TTSBase: class TTSBase:
def __init__(self, human): def __init__(self, human):
@ -13,13 +22,18 @@ class TTSBase:
self._queue = Queue() self._queue = Queue()
self._exit_event = None self._exit_event = None
self._io_stream = BytesIO() self._io_stream = BytesIO()
self._fps = human.get_fps()
self._sample_rate = 16000 self._sample_rate = 16000
self._chunk = self._sample_rate // self._fps self._chunk = self._sample_rate // self._human.get_fps()
self._exit_event = Event()
self._thread = Thread(target=self._on_run)
self._exit_event.set()
self._thread.start()
logging.info('tts start')
def _on_run(self): def _on_run(self):
logging.info('tts run') logging.info('tts run')
while not self._exit_event.is_set(): while self._exit_event.is_set():
try: try:
txt = self._queue.get(block=True, timeout=1) txt = self._queue.get(block=True, timeout=1)
except queue.Empty: except queue.Empty:
@ -28,21 +42,50 @@ class TTSBase:
logging.info('tts exit') logging.info('tts exit')
def _request(self, txt): def _request(self, txt):
pass voice = 'zh-CN-XiaoyiNeural'
t = time.time()
asyncio.new_event_loop().run_until_complete(self.__on_request(voice, txt))
logger.info(f'edge tts time:{time.time() - t : 0.4f}s')
def start(self): self._io_stream.seek(0)
if self._exit_event is not None: stream = self.__create_bytes_stream(self._io_stream)
return stream_len = stream.shape[0]
self._exit_event = Event() index = 0
self._thread = Thread(target=self._on_run) while stream_len >= self._chunk:
self._thread.start() self._human.push_audio_chunk(stream[index:index + self._chunk])
logging.info('tts start') stream_len -= self._chunk
index += self._chunk
def __create_bytes_stream(self, io_stream):
    """Decode the audio held in *io_stream* into single-channel samples.

    The stream is read with ``soundfile.read`` and converted to float32.
    Multi-channel audio is reduced to its first channel, and the result is
    resampled to ``self._sample_rate`` when the decoded rate differs.

    Args:
        io_stream: File-like object handed to ``soundfile.read``.

    Returns:
        The decoded samples, at ``self._sample_rate`` whenever the stream
        is non-empty.
    """
    stream, sample_rate = soundfile.read(io_stream)
    logger.info(f'tts audio stream {sample_rate} : {stream.shape}')
    stream = stream.astype(np.float32)

    if stream.ndim > 1:
        logger.warning(f'tts audio has {stream.shape[1]} channels, only use the first')
        # Column 0 is the first channel, matching the warning above.
        stream = stream[:, 0]

    # An empty stream is returned as-is instead of being resampled.
    if sample_rate != self._sample_rate and stream.shape[0] > 0:
        logger.warning(f'tts audio sample rate is {sample_rate}, resample to {self._sample_rate}')
        stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self._sample_rate)

    return stream
async def __on_request(self, voice, txt):
    """Stream synthesized speech for *txt* into ``self._io_stream``.

    Args:
        voice: Voice name passed to ``edge_tts.Communicate``.
        txt: Text to synthesize.
    """
    communicate = edge_tts.Communicate(txt, voice)
    async for chunk in communicate.stream():
        # Only 'audio' messages are buffered; every other message type is skipped.
        if chunk['type'] == 'audio':
            self._io_stream.write(chunk['data'])
def stop(self): def stop(self):
if self._exit_event is None: if self._exit_event is None:
return return
self._exit_event.set() self._exit_event.clear()
self._thread.join() self._thread.join()
logging.info('tts stop') logging.info('tts stop')

2
ui.py
View File

@ -49,8 +49,6 @@ class App(customtkinter.CTk):
self._init_image_canvas() self._init_image_canvas()
self._human = Human() self._human = Human()
tts = EdgeTTS(self._human)
self._human.set_tts(tts)
self._render() self._render()
def on_destroy(self): def on_destroy(self):