modify human
This commit is contained in:
parent
cb9ea0ac17
commit
4e1e923c0b
26
Human.py
26
Human.py
@ -14,6 +14,9 @@ import torch
|
|||||||
import cv2
|
import cv2
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from tts.EdgeTTS import EdgeTTS
|
||||||
|
from tts.TTSBase import TTSBase
|
||||||
|
|
||||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||||
|
|
||||||
|
|
||||||
@ -139,19 +142,18 @@ def inference(render_event, batch_size, face_images_path, audio_feat_queue, audi
|
|||||||
|
|
||||||
class Human:
|
class Human:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._text = None
|
|
||||||
self._tts = None
|
|
||||||
self._fps = 50 # 20 ms per frame
|
self._fps = 50 # 20 ms per frame
|
||||||
self._batch_size = 16
|
self._batch_size = 16
|
||||||
self._sample_rate = 16000
|
self._sample_rate = 16000
|
||||||
self._chunk = self._sample_rate // self._fps # 320 samples per chunk (20ms * 16000 / 1000)
|
|
||||||
self._chunk_2_mal = Chunk2Mal(self)
|
|
||||||
self._stride_left_size = 10
|
self._stride_left_size = 10
|
||||||
self._stride_right_size = 10
|
self._stride_right_size = 10
|
||||||
self._feat_queue = mp.Queue(2)
|
self._feat_queue = mp.Queue(2)
|
||||||
self._output_queue = mp.Queue()
|
self._output_queue = mp.Queue()
|
||||||
self._res_frame_queue = mp.Queue(self._batch_size * 2)
|
self._res_frame_queue = mp.Queue(self._batch_size * 2)
|
||||||
|
|
||||||
|
self._chunk_2_mal = Chunk2Mal(self)
|
||||||
|
self._tts = TTSBase(self)
|
||||||
|
|
||||||
face_images_path = r'./face/'
|
face_images_path = r'./face/'
|
||||||
self._face_image_paths = utils.read_files_path(face_images_path)
|
self._face_image_paths = utils.read_files_path(face_images_path)
|
||||||
print(self._face_image_paths)
|
print(self._face_image_paths)
|
||||||
@ -167,8 +169,8 @@ class Human:
|
|||||||
def get_batch_size(self):
|
def get_batch_size(self):
|
||||||
return self._batch_size
|
return self._batch_size
|
||||||
|
|
||||||
def get_chunk(self):
|
def get_audio_sample_rate(self):
|
||||||
return self._chunk
|
return self._sample_rate
|
||||||
|
|
||||||
def get_stride_left_size(self):
|
def get_stride_left_size(self):
|
||||||
return self._stride_left_size
|
return self._stride_left_size
|
||||||
@ -183,14 +185,6 @@ class Human:
|
|||||||
self._tts.stop()
|
self._tts.stop()
|
||||||
logging.info('human destroy')
|
logging.info('human destroy')
|
||||||
|
|
||||||
def set_tts(self, tts):
|
|
||||||
if self._tts == tts:
|
|
||||||
return
|
|
||||||
|
|
||||||
self._tts = tts
|
|
||||||
self._tts.start()
|
|
||||||
self._chunk_2_mal.start()
|
|
||||||
|
|
||||||
def read(self, txt):
|
def read(self, txt):
|
||||||
if self._tts is None:
|
if self._tts is None:
|
||||||
logging.warning('tts is none')
|
logging.warning('tts is none')
|
||||||
@ -198,8 +192,8 @@ class Human:
|
|||||||
|
|
||||||
self._tts.push_txt(txt)
|
self._tts.push_txt(txt)
|
||||||
|
|
||||||
def push_audio_chunk(self, chunk):
|
def push_audio_chunk(self, audio_chunk):
|
||||||
self._chunk_2_mal.push_chunk(chunk)
|
self._chunk_2_mal.push_chunk(audio_chunk)
|
||||||
|
|
||||||
def push_feat_queue(self, mel_chunks):
|
def push_feat_queue(self, mel_chunks):
|
||||||
print("push_feat_queue")
|
print("push_feat_queue")
|
||||||
|
@ -14,12 +14,20 @@ class Chunk2Mal:
|
|||||||
self._audio_chunk_queue = Queue()
|
self._audio_chunk_queue = Queue()
|
||||||
self._human = human
|
self._human = human
|
||||||
self._thread = None
|
self._thread = None
|
||||||
self._exit_event = None
|
|
||||||
self._chunks = []
|
self._chunks = []
|
||||||
|
# 320 samples per chunk (20ms * 16000 / 1000)audio_chunk
|
||||||
|
self._chunk_len = self._human.get_audio_sample_rate // self._human.get_fps()
|
||||||
|
|
||||||
|
self._exit_event = Event()
|
||||||
|
self._thread = Thread(target=self._on_run)
|
||||||
|
self._exit_event.set()
|
||||||
|
self._thread.start()
|
||||||
|
logging.info('chunk2mal start')
|
||||||
|
|
||||||
def _on_run(self):
|
def _on_run(self):
|
||||||
logging.info('chunk2mal run')
|
logging.info('chunk2mal run')
|
||||||
while not self._exit_event.is_set():
|
while self._exit_event.is_set():
|
||||||
try:
|
try:
|
||||||
chunk, type_ = self.pull_chunk()
|
chunk, type_ = self.pull_chunk()
|
||||||
self._chunks.append(chunk)
|
self._chunks.append(chunk)
|
||||||
@ -57,19 +65,11 @@ class Chunk2Mal:
|
|||||||
|
|
||||||
logging.info('chunk2mal exit')
|
logging.info('chunk2mal exit')
|
||||||
|
|
||||||
def start(self):
|
|
||||||
if self._exit_event is not None:
|
|
||||||
return
|
|
||||||
self._exit_event = Event()
|
|
||||||
self._thread = Thread(target=self._on_run)
|
|
||||||
self._thread.start()
|
|
||||||
logging.info('chunk2mal start')
|
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
if self._exit_event is None:
|
if self._exit_event is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
self._exit_event.set()
|
self._exit_event.clear()
|
||||||
if self._thread.is_alive():
|
if self._thread.is_alive():
|
||||||
self._thread.join()
|
self._thread.join()
|
||||||
logging.info('chunk2mal stop')
|
logging.info('chunk2mal stop')
|
||||||
|
@ -1,10 +1,19 @@
|
|||||||
#encoding = utf8
|
#encoding = utf8
|
||||||
import logging
|
import logging
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
|
||||||
|
import edge_tts
|
||||||
|
import numpy as np
|
||||||
|
import soundfile
|
||||||
|
import resampy
|
||||||
import queue
|
import queue
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
from threading import Thread, Event
|
from threading import Thread, Event
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class TTSBase:
|
class TTSBase:
|
||||||
def __init__(self, human):
|
def __init__(self, human):
|
||||||
@ -13,13 +22,18 @@ class TTSBase:
|
|||||||
self._queue = Queue()
|
self._queue = Queue()
|
||||||
self._exit_event = None
|
self._exit_event = None
|
||||||
self._io_stream = BytesIO()
|
self._io_stream = BytesIO()
|
||||||
self._fps = human.get_fps()
|
|
||||||
self._sample_rate = 16000
|
self._sample_rate = 16000
|
||||||
self._chunk = self._sample_rate // self._fps
|
self._chunk = self._sample_rate // self._human.get_fps()
|
||||||
|
|
||||||
|
self._exit_event = Event()
|
||||||
|
self._thread = Thread(target=self._on_run)
|
||||||
|
self._exit_event.set()
|
||||||
|
self._thread.start()
|
||||||
|
logging.info('tts start')
|
||||||
|
|
||||||
def _on_run(self):
|
def _on_run(self):
|
||||||
logging.info('tts run')
|
logging.info('tts run')
|
||||||
while not self._exit_event.is_set():
|
while self._exit_event.is_set():
|
||||||
try:
|
try:
|
||||||
txt = self._queue.get(block=True, timeout=1)
|
txt = self._queue.get(block=True, timeout=1)
|
||||||
except queue.Empty:
|
except queue.Empty:
|
||||||
@ -28,21 +42,50 @@ class TTSBase:
|
|||||||
logging.info('tts exit')
|
logging.info('tts exit')
|
||||||
|
|
||||||
def _request(self, txt):
|
def _request(self, txt):
|
||||||
pass
|
voice = 'zh-CN-XiaoyiNeural'
|
||||||
|
t = time.time()
|
||||||
|
asyncio.new_event_loop().run_until_complete(self.__on_request(voice, txt))
|
||||||
|
logger.info(f'edge tts time:{time.time() - t : 0.4f}s')
|
||||||
|
|
||||||
def start(self):
|
self._io_stream.seek(0)
|
||||||
if self._exit_event is not None:
|
stream = self.__create_bytes_stream(self._io_stream)
|
||||||
return
|
stream_len = stream.shape[0]
|
||||||
self._exit_event = Event()
|
index = 0
|
||||||
self._thread = Thread(target=self._on_run)
|
while stream_len >= self._chunk:
|
||||||
self._thread.start()
|
self._human.push_audio_chunk(stream[index:index + self._chunk])
|
||||||
logging.info('tts start')
|
stream_len -= self._chunk
|
||||||
|
index += self._chunk
|
||||||
|
|
||||||
|
def __create_bytes_stream(self, io_stream):
|
||||||
|
stream, sample_rate = soundfile.read(io_stream)
|
||||||
|
logger.info(f'tts audio stream {sample_rate} : {stream.shape}')
|
||||||
|
stream = stream.astype(np.float32)
|
||||||
|
|
||||||
|
if stream.ndim > 1:
|
||||||
|
logger.warning(f'tts audio has {stream.shape[1]} channels, only use the first')
|
||||||
|
stream = stream[:, 1]
|
||||||
|
|
||||||
|
if sample_rate != self._sample_rate and stream.shape[0] > 0:
|
||||||
|
logger.warning(f'tts audio sample rate is {sample_rate}, resample to {self._sample_rate}')
|
||||||
|
stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self._sample_rate)
|
||||||
|
|
||||||
|
return stream
|
||||||
|
|
||||||
|
async def __on_request(self, voice, txt):
|
||||||
|
communicate = edge_tts.Communicate(txt, voice)
|
||||||
|
first = True
|
||||||
|
async for chuck in communicate.stream():
|
||||||
|
if first:
|
||||||
|
first = False
|
||||||
|
|
||||||
|
if chuck['type'] == 'audio':
|
||||||
|
self._io_stream.write(chuck['data'])
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
if self._exit_event is None:
|
if self._exit_event is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
self._exit_event.set()
|
self._exit_event.clear()
|
||||||
self._thread.join()
|
self._thread.join()
|
||||||
logging.info('tts stop')
|
logging.info('tts stop')
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user