diff --git a/human/audio_handler.py b/human/audio_handler.py
index 6d5648b..94a7b41 100644
--- a/human/audio_handler.py
+++ b/human/audio_handler.py
@@ -3,6 +3,11 @@ from abc import ABC, abstractmethod
 
 
 class AudioHandler(ABC):
+    def __init__(self, context, handler):
+        self._context = context
+        self._handler = handler
+
     @abstractmethod
     def on_handle(self, stream, index):
-        pass
+        if self._handler is not None:
+            self._handler.on_handle(stream, index)
diff --git a/human/audio_mal_handler.py b/human/audio_mal_handler.py
new file mode 100644
index 0000000..bb96c2f
--- /dev/null
+++ b/human/audio_mal_handler.py
@@ -0,0 +1,121 @@
+# -*- coding: utf-8 -*-
+"""Audio handler that slices queued audio frames into mel-spectrogram chunks."""
+import logging
+import queue
+import time
+from queue import Queue
+from threading import Thread, Event
+
+import numpy as np
+
+from human import AudioHandler
+
+logger = logging.getLogger(__name__)
+
+
+class AudioMalHandler(AudioHandler):
+    """Worker-thread handler that batches audio frames into mel chunks.
+
+    NOTE(review): ``_run_step`` reads ``self._human``,
+    ``self.stride_left_size``, ``self.stride_right_size``, ``self.fps`` and
+    the module-level name ``audio``. None of them is assigned or imported in
+    this module, so they must be provided before a step can complete.
+    """
+
+    def __init__(self, context, handler):
+        """Store the collaborators and start the worker thread.
+
+        Args:
+            context: Object providing ``sample_rate()``, ``fps()`` and
+                ``batch_size()``.
+            handler: Handler that ``on_handle`` forwards to, or ``None``.
+        """
+        super().__init__(context, handler)
+
+        self._queue = Queue()
+        # State read by the worker must exist before the thread is started.
+        self.frames = []
+        self.chunk = context.sample_rate() // context.fps()
+
+        self._exit_event = Event()
+        self._exit_event.set()
+        self._thread = Thread(target=self._on_run)
+        self._thread.start()
+
+    def on_handle(self, stream, index):
+        """Forward ``stream`` and ``index`` to the wrapped handler, if any."""
+        # NOTE(review): nothing in this class puts data on ``self._queue``.
+        if self._handler is not None:
+            self._handler.on_handle(stream, index)
+
+    def _on_run(self):
+        """Call ``_run_step`` in a loop until the exit event is cleared."""
+        logger.info('chunk2mal run')
+        while self._exit_event.is_set():
+            self._run_step()
+            time.sleep(0.01)
+
+        logger.info('chunk2mal exit')
+
+    def _run_step(self):
+        """Collect ``batch_size() * 2`` frames and push their mel chunks.
+
+        Every frame is also passed to ``self._human.push_out_put``. No mel
+        work is done until more than ``stride_left_size + stride_right_size``
+        frames are buffered.
+        """
+        for _ in range(self._context.batch_size() * 2):
+            frame, _type = self.get_audio_frame()
+            self.frames.append(frame)
+            self._human.push_out_put(frame, _type)
+        # context not enough, do not run network.
+        if len(self.frames) <= self.stride_left_size + self.stride_right_size:
+            return
+
+        inputs = np.concatenate(self.frames)  # [N * chunk]
+        mel = audio.melspectrogram(inputs)
+        # cut off stride
+        left = max(0, self.stride_left_size * 80 / 50)
+        mel_idx_multiplier = 80. * 2 / self.fps
+        mel_step_size = 16
+        i = 0
+        mel_chunks = []
+        while i < (len(self.frames) - self.stride_left_size - self.stride_right_size) / 2:
+            start_idx = int(left + i * mel_idx_multiplier)
+            if start_idx + mel_step_size > len(mel[0]):
+                # Window would overrun the spectrogram: use the last full one.
+                mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:])
+            else:
+                mel_chunks.append(mel[:, start_idx: start_idx + mel_step_size])
+            i += 1
+        self._human.push_mel_chunks(mel_chunks)
+
+        # discard the old part to save memory
+        self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]
+
+    def get_audio_frame(self):
+        """Return a ``(frame, type_)`` pair for the next audio frame.
+
+        ``type_`` is 0 for a frame taken from the queue and 1 for a
+        zero-filled frame of ``self.chunk`` samples used when the queue
+        stays empty for 10 ms.
+        """
+        try:
+            frame = self._queue.get(block=True, timeout=0.01)
+            type_ = 0
+        except queue.Empty:
+            frame = np.zeros(self.chunk, dtype=np.float32)
+            type_ = 1
+
+        return frame, type_
+
+    def stop(self):
+        """Clear the exit event and wait for the worker thread to finish."""
+        logger.info('stop')
+        if self._exit_event is None:
+            return
+
+        self._exit_event.clear()
+        if self._thread.is_alive():
+            self._thread.join()
+        logger.info('chunk2mal stop')
diff --git a/tts/Chunk2Mal.py b/tts/Chunk2Mal.py
index f284c38..7b0536a 100644
--- a/tts/Chunk2Mal.py
+++ b/tts/Chunk2Mal.py
@@ -51,8 +51,6 @@ class Chunk2Mal:
         for _ in range(self.batch_size * 2):
             frame, _type = self.get_audio_frame()
             self.frames.append(frame)
-            # put to output
-            # self.output_queue.put((frame, _type))
             self._human.push_out_put(frame, _type)
         # context not enough, do not run network.
         if len(self.frames) <= self.stride_left_size + self.stride_right_size: