modify human
This commit is contained in:
parent
cb9ea0ac17
commit
4e1e923c0b
26
Human.py
26
Human.py
@ -14,6 +14,9 @@ import torch
|
||||
import cv2
|
||||
from tqdm import tqdm
|
||||
|
||||
from tts.EdgeTTS import EdgeTTS
|
||||
from tts.TTSBase import TTSBase
|
||||
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
|
||||
|
||||
@ -139,19 +142,18 @@ def inference(render_event, batch_size, face_images_path, audio_feat_queue, audi
|
||||
|
||||
class Human:
|
||||
def __init__(self):
|
||||
self._text = None
|
||||
self._tts = None
|
||||
self._fps = 50 # 20 ms per frame
|
||||
self._batch_size = 16
|
||||
self._sample_rate = 16000
|
||||
self._chunk = self._sample_rate // self._fps # 320 samples per chunk (20ms * 16000 / 1000)
|
||||
self._chunk_2_mal = Chunk2Mal(self)
|
||||
self._stride_left_size = 10
|
||||
self._stride_right_size = 10
|
||||
self._feat_queue = mp.Queue(2)
|
||||
self._output_queue = mp.Queue()
|
||||
self._res_frame_queue = mp.Queue(self._batch_size * 2)
|
||||
|
||||
self._chunk_2_mal = Chunk2Mal(self)
|
||||
self._tts = TTSBase(self)
|
||||
|
||||
face_images_path = r'./face/'
|
||||
self._face_image_paths = utils.read_files_path(face_images_path)
|
||||
print(self._face_image_paths)
|
||||
@ -167,8 +169,8 @@ class Human:
|
||||
def get_batch_size(self):
|
||||
return self._batch_size
|
||||
|
||||
def get_chunk(self):
|
||||
return self._chunk
|
||||
def get_audio_sample_rate(self):
|
||||
return self._sample_rate
|
||||
|
||||
def get_stride_left_size(self):
|
||||
return self._stride_left_size
|
||||
@ -183,14 +185,6 @@ class Human:
|
||||
self._tts.stop()
|
||||
logging.info('human destroy')
|
||||
|
||||
def set_tts(self, tts):
|
||||
if self._tts == tts:
|
||||
return
|
||||
|
||||
self._tts = tts
|
||||
self._tts.start()
|
||||
self._chunk_2_mal.start()
|
||||
|
||||
def read(self, txt):
|
||||
if self._tts is None:
|
||||
logging.warning('tts is none')
|
||||
@ -198,8 +192,8 @@ class Human:
|
||||
|
||||
self._tts.push_txt(txt)
|
||||
|
||||
def push_audio_chunk(self, chunk):
|
||||
self._chunk_2_mal.push_chunk(chunk)
|
||||
def push_audio_chunk(self, audio_chunk):
|
||||
self._chunk_2_mal.push_chunk(audio_chunk)
|
||||
|
||||
def push_feat_queue(self, mel_chunks):
|
||||
print("push_feat_queue")
|
||||
|
@ -14,12 +14,20 @@ class Chunk2Mal:
|
||||
self._audio_chunk_queue = Queue()
|
||||
self._human = human
|
||||
self._thread = None
|
||||
self._exit_event = None
|
||||
|
||||
self._chunks = []
|
||||
# 320 samples per chunk (20ms * 16000 / 1000)audio_chunk
|
||||
self._chunk_len = self._human.get_audio_sample_rate // self._human.get_fps()
|
||||
|
||||
self._exit_event = Event()
|
||||
self._thread = Thread(target=self._on_run)
|
||||
self._exit_event.set()
|
||||
self._thread.start()
|
||||
logging.info('chunk2mal start')
|
||||
|
||||
def _on_run(self):
|
||||
logging.info('chunk2mal run')
|
||||
while not self._exit_event.is_set():
|
||||
while self._exit_event.is_set():
|
||||
try:
|
||||
chunk, type_ = self.pull_chunk()
|
||||
self._chunks.append(chunk)
|
||||
@ -57,19 +65,11 @@ class Chunk2Mal:
|
||||
|
||||
logging.info('chunk2mal exit')
|
||||
|
||||
def start(self):
|
||||
if self._exit_event is not None:
|
||||
return
|
||||
self._exit_event = Event()
|
||||
self._thread = Thread(target=self._on_run)
|
||||
self._thread.start()
|
||||
logging.info('chunk2mal start')
|
||||
|
||||
def stop(self):
|
||||
if self._exit_event is None:
|
||||
return
|
||||
|
||||
self._exit_event.set()
|
||||
self._exit_event.clear()
|
||||
if self._thread.is_alive():
|
||||
self._thread.join()
|
||||
logging.info('chunk2mal stop')
|
||||
|
@ -1,10 +1,19 @@
|
||||
#encoding = utf8
|
||||
import logging
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
import edge_tts
|
||||
import numpy as np
|
||||
import soundfile
|
||||
import resampy
|
||||
import queue
|
||||
from io import BytesIO
|
||||
from queue import Queue
|
||||
from threading import Thread, Event
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TTSBase:
|
||||
def __init__(self, human):
|
||||
@ -13,13 +22,18 @@ class TTSBase:
|
||||
self._queue = Queue()
|
||||
self._exit_event = None
|
||||
self._io_stream = BytesIO()
|
||||
self._fps = human.get_fps()
|
||||
self._sample_rate = 16000
|
||||
self._chunk = self._sample_rate // self._fps
|
||||
self._chunk = self._sample_rate // self._human.get_fps()
|
||||
|
||||
self._exit_event = Event()
|
||||
self._thread = Thread(target=self._on_run)
|
||||
self._exit_event.set()
|
||||
self._thread.start()
|
||||
logging.info('tts start')
|
||||
|
||||
def _on_run(self):
|
||||
logging.info('tts run')
|
||||
while not self._exit_event.is_set():
|
||||
while self._exit_event.is_set():
|
||||
try:
|
||||
txt = self._queue.get(block=True, timeout=1)
|
||||
except queue.Empty:
|
||||
@ -28,21 +42,50 @@ class TTSBase:
|
||||
logging.info('tts exit')
|
||||
|
||||
def _request(self, txt):
|
||||
pass
|
||||
voice = 'zh-CN-XiaoyiNeural'
|
||||
t = time.time()
|
||||
asyncio.new_event_loop().run_until_complete(self.__on_request(voice, txt))
|
||||
logger.info(f'edge tts time:{time.time() - t : 0.4f}s')
|
||||
|
||||
def start(self):
|
||||
if self._exit_event is not None:
|
||||
return
|
||||
self._exit_event = Event()
|
||||
self._thread = Thread(target=self._on_run)
|
||||
self._thread.start()
|
||||
logging.info('tts start')
|
||||
self._io_stream.seek(0)
|
||||
stream = self.__create_bytes_stream(self._io_stream)
|
||||
stream_len = stream.shape[0]
|
||||
index = 0
|
||||
while stream_len >= self._chunk:
|
||||
self._human.push_audio_chunk(stream[index:index + self._chunk])
|
||||
stream_len -= self._chunk
|
||||
index += self._chunk
|
||||
|
||||
def __create_bytes_stream(self, io_stream):
|
||||
stream, sample_rate = soundfile.read(io_stream)
|
||||
logger.info(f'tts audio stream {sample_rate} : {stream.shape}')
|
||||
stream = stream.astype(np.float32)
|
||||
|
||||
if stream.ndim > 1:
|
||||
logger.warning(f'tts audio has {stream.shape[1]} channels, only use the first')
|
||||
stream = stream[:, 1]
|
||||
|
||||
if sample_rate != self._sample_rate and stream.shape[0] > 0:
|
||||
logger.warning(f'tts audio sample rate is {sample_rate}, resample to {self._sample_rate}')
|
||||
stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self._sample_rate)
|
||||
|
||||
return stream
|
||||
|
||||
async def __on_request(self, voice, txt):
|
||||
communicate = edge_tts.Communicate(txt, voice)
|
||||
first = True
|
||||
async for chuck in communicate.stream():
|
||||
if first:
|
||||
first = False
|
||||
|
||||
if chuck['type'] == 'audio':
|
||||
self._io_stream.write(chuck['data'])
|
||||
|
||||
def stop(self):
|
||||
if self._exit_event is None:
|
||||
return
|
||||
|
||||
self._exit_event.set()
|
||||
self._exit_event.clear()
|
||||
self._thread.join()
|
||||
logging.info('tts stop')
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user