# human/tts/Chunk2Mal.py
# coding: utf-8
import logging
import queue
import time
from queue import Queue
import multiprocessing as mp
from threading import Thread, Event
import numpy as np
import audio
from audio_render import AudioRender


class Chunk2Mal:
    def __init__(self, human):
        self._human = human
        self._thread = None
        self.frames = []
        self.queue = Queue()
        self.output_queue = mp.Queue()
        self.feat_queue = mp.Queue(2)
        # 320 samples per audio chunk (20 ms * 16000 / 1000)
        self.chunk = self._human.get_audio_sample_rate() // self._human.get_fps()
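        # e.g. 16000 Hz // 50 fps -> 320 samples, i.e. one 20 ms chunk per frame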
        self._exit_event = Event()
        self._thread = Thread(target=self._on_run)
        self._exit_event.set()
        self._thread.start()
        # self._audio_render = AudioRender()
        self._stream_len = 0
        logging.info('chunk2mal start')
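
    # Lifecycle note: _exit_event doubles as a running flag. It is set before
    # the worker thread starts and polled in _on_run(); stop() clears it.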
    def _on_run(self):
        logging.info('chunk2mal run')
        while self._exit_event.is_set():
            self._run_step()
            time.sleep(0.01)
        logging.info('chunk2mal exit')
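
    # _run_step pipeline: pull 2 * batch_size audio chunks, forward each to the
    # host via push_out_put, then (once enough stride context has accumulated)
    # compute a mel spectrogram over the buffered audio and slice it into
    # per-video-frame windows for push_mel_chunks.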
    def _run_step(self):
        for _ in range(self._human.get_batch_size() * 2):
            frame, _type = self.get_audio_frame()
            self.frames.append(frame)
            # put to output
            self._human.push_out_put(frame, _type)

        # context not enough, do not run network.
        if len(self.frames) <= self._human.get_stride_left_size() + self._human.get_stride_right_size():
            return

        inputs = np.concatenate(self.frames)  # [N * chunk]
        mel = audio.melspectrogram(inputs)
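        # mel has shape (n_mels, T). The 80/50 and 80*2/fps factors below assume
        # a Wav2Lip-style audio module (16 kHz, hop size 200), which yields
        # roughly 80 mel frames per second of audio.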
        # cut off the stride context at both ends (in mel-frame units)
        left = max(0, self._human.get_stride_left_size() * 80 / 50)
        right = min(len(mel[0]), len(mel[0]) - self._human.get_stride_right_size() * 80 / 50)  # currently unused
        mel_idx_multiplier = 80. * 2 / self._human.get_fps()
        mel_step_size = 16
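        # Each loop iteration below emits one 16-frame mel window per two audio
        # chunks (i.e. per output video frame), so the read index advances by
        # 80 * 2 / fps mel frames per step.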
        i = 0
        mel_chunks = []
        while i < (len(self.frames) - self._human.get_stride_left_size() - self._human.get_stride_right_size()) / 2:
            start_idx = int(left + i * mel_idx_multiplier)
            if start_idx + mel_step_size > len(mel[0]):
                # clamp the last window to the end of the spectrogram
                mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:])
            else:
                mel_chunks.append(mel[:, start_idx: start_idx + mel_step_size])
            i += 1
        self._human.push_mel_chunks(mel_chunks)
        # legacy path: self.feat_queue.put(mel_chunks) (consumed by get_next_feat)

        # discard the old part to save memory
        self.frames = self.frames[-(self._human.get_stride_left_size() + self._human.get_stride_right_size()):]
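        # Keeping only the stride context lets the next batch overlap this one,
        # so consecutive mel windows stay continuous across batch boundaries.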

    def stop(self):
        if self._exit_event is None:
            return
        self._exit_event.clear()
        if self._thread.is_alive():
            self._thread.join()
        logging.info('chunk2mal stop')

    def pause_talk(self):
        self.queue.queue.clear()

    def put_audio_frame(self, audio_chunk):  # 16 kHz, 20 ms pcm
        self.queue.put(audio_chunk)

    def get_audio_frame(self):
        try:
            frame = self.queue.get(block=True, timeout=0.01)
            _type = 0
        except queue.Empty:
            # queue underrun: substitute one chunk of silence
            frame = np.zeros(self.chunk, dtype=np.float32)
            _type = 1
        return frame, _type
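
    # A frame type of 0 marks real audio pulled from the queue; 1 marks the
    # zero-filled silence substituted on underrun, so downstream consumers can
    # tell speech apart from filler.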
    def get_audio_out(self):  # get origin audio pcm to nerf
        return self.output_queue.get()

    def warm_up(self):
        for _ in range(self._human.get_stride_left_size() + self._human.get_stride_right_size()):
            audio_frame, _type = self.get_audio_frame()
            self.frames.append(audio_frame)
            self.output_queue.put((audio_frame, _type))
        for _ in range(self._human.get_stride_right_size()):
            self.output_queue.get()

    def get_next_feat(self, block, timeout):
        return self.feat_queue.get(block=block, timeout=timeout)
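

# --- Usage sketch (illustrative only) ----------------------------------------
# Chunk2Mal expects a host object exposing the methods called above. The stub
# below is hypothetical (not part of this repo) and only shows the wiring;
# running it also requires the project's `audio` module to be importable.
#
#   class HumanStub:
#       def get_audio_sample_rate(self): return 16000
#       def get_fps(self): return 50          # matches the 80/50 factor above
#       def get_batch_size(self): return 16
#       def get_stride_left_size(self): return 10
#       def get_stride_right_size(self): return 10
#       def push_out_put(self, frame, frame_type): pass
#       def push_mel_chunks(self, mel_chunks): pass
#
#   chunk2mal = Chunk2Mal(HumanStub())   # worker thread starts in __init__
#   chunk2mal.put_audio_frame(np.zeros(320, dtype=np.float32))  # one 20 ms chunk
#   chunk2mal.stop()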