modify audio handle
This commit is contained in:
parent
da37374232
commit
d8225d3929
@ -1,6 +1,9 @@
|
|||||||
#encoding = utf8
|
#encoding = utf8
|
||||||
|
import logging
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class AudioHandler(ABC):
|
class AudioHandler(ABC):
|
||||||
def __init__(self, context, handler):
|
def __init__(self, context, handler):
|
||||||
@ -9,5 +12,10 @@ class AudioHandler(ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def on_handle(self, stream, index):
|
def on_handle(self, stream, index):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_next_handle(self, stream, type_):
|
||||||
if self._handler is not None:
|
if self._handler is not None:
|
||||||
self._handler.on_handle(stream, index)
|
self._handler.on_handle(stream, type_)
|
||||||
|
else:
|
||||||
|
logging.info(f'_handler is None')
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
#encoding = utf8
|
#encoding = utf8
|
||||||
import queue
|
import queue
|
||||||
import time
|
import time
|
||||||
|
from queue import Queue
|
||||||
from threading import Event, Thread
|
from threading import Event, Thread
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -14,20 +15,25 @@ class AudioInferenceHandler(AudioHandler):
|
|||||||
def __init__(self, context, handler):
|
def __init__(self, context, handler):
|
||||||
super().__init__(context, handler)
|
super().__init__(context, handler)
|
||||||
|
|
||||||
|
self._mal_queue = Queue()
|
||||||
|
self._audio_queue = Queue()
|
||||||
|
|
||||||
self._exit_event = Event()
|
self._exit_event = Event()
|
||||||
self._run_thread = Thread(target=self.__on_run)
|
self._run_thread = Thread(target=self.__on_run)
|
||||||
self._exit_event.set()
|
self._exit_event.set()
|
||||||
self._run_thread.start()
|
self._run_thread.start()
|
||||||
|
|
||||||
def on_handle(self, stream, index):
|
def on_handle(self, stream, type_):
|
||||||
if self._handler is not None:
|
if type_ == 1:
|
||||||
self._handler.on_handle(stream, index)
|
self._mal_queue.put(stream)
|
||||||
|
elif type_ == 0:
|
||||||
|
self._audio_queue.put(stream)
|
||||||
|
|
||||||
def __on_run(self):
|
def __on_run(self):
|
||||||
model = load_model(r'.\checkpoints\wav2lip.pth')
|
model = load_model(r'.\checkpoints\wav2lip.pth')
|
||||||
print("Model loaded")
|
print("Model loaded")
|
||||||
|
|
||||||
face_list_cycle = self._human.get_face_list_cycle()
|
face_list_cycle = self._context.face_list_cycle()
|
||||||
|
|
||||||
length = len(face_list_cycle)
|
length = len(face_list_cycle)
|
||||||
index = 0
|
index = 0
|
||||||
@ -43,20 +49,21 @@ class AudioInferenceHandler(AudioHandler):
|
|||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
batch_size = self._context.batch_size()
|
batch_size = self._context.batch_size()
|
||||||
try:
|
try:
|
||||||
mel_batch = self._feat_queue.get(block=True, timeout=0.1)
|
mel_batch = self._mal_queue.get(block=True, timeout=0.1)
|
||||||
except queue.Empty:
|
except queue.Empty:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
is_all_silence = True
|
is_all_silence = True
|
||||||
audio_frames = []
|
audio_frames = []
|
||||||
for _ in range(batch_size * 2):
|
for _ in range(batch_size * 2):
|
||||||
frame, type_ = self._audio_out_queue.get()
|
frame, type_ = self._audio_queue.get()
|
||||||
audio_frames.append((frame, type_))
|
audio_frames.append((frame, type_))
|
||||||
if type_ == 0:
|
if type_ == 0:
|
||||||
is_all_silence = False
|
is_all_silence = False
|
||||||
if is_all_silence:
|
if is_all_silence:
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
self._human.push_res_frame(None, mirror_index(length, index), audio_frames[i * 2:i * 2 + 2])
|
self.on_next_handle((None, mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]),
|
||||||
|
0)
|
||||||
index = index + 1
|
index = index + 1
|
||||||
else:
|
else:
|
||||||
print('infer=======')
|
print('infer=======')
|
||||||
@ -71,7 +78,7 @@ class AudioInferenceHandler(AudioHandler):
|
|||||||
mel_batch = np.asarray(mel_batch)
|
mel_batch = np.asarray(mel_batch)
|
||||||
img_masked = img_batch.copy()
|
img_masked = img_batch.copy()
|
||||||
img_masked[:, face.shape[0] // 2:] = 0
|
img_masked[:, face.shape[0] // 2:] = 0
|
||||||
#
|
|
||||||
img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
|
img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
|
||||||
mel_batch = np.reshape(mel_batch,
|
mel_batch = np.reshape(mel_batch,
|
||||||
[len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
|
[len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
|
||||||
@ -94,8 +101,9 @@ class AudioInferenceHandler(AudioHandler):
|
|||||||
|
|
||||||
image_index = 0
|
image_index = 0
|
||||||
for i, res_frame in enumerate(pred):
|
for i, res_frame in enumerate(pred):
|
||||||
self._human.push_res_frame(res_frame, mirror_index(length, index),
|
self.on_next_handle(
|
||||||
audio_frames[i * 2:i * 2 + 2])
|
(res_frame, mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]),
|
||||||
|
0)
|
||||||
index = index + 1
|
index = index + 1
|
||||||
image_index = image_index + 1
|
image_index = image_index + 1
|
||||||
print('batch count', image_index)
|
print('batch count', image_index)
|
||||||
|
@ -27,8 +27,7 @@ class AudioMalHandler(AudioHandler):
|
|||||||
self.chunk = context.sample_rate() // context.fps()
|
self.chunk = context.sample_rate() // context.fps()
|
||||||
|
|
||||||
def on_handle(self, stream, index):
|
def on_handle(self, stream, index):
|
||||||
if self._handler is not None:
|
self._queue.put(stream)
|
||||||
self._handler.on_handle(stream, index)
|
|
||||||
|
|
||||||
def _on_run(self):
|
def _on_run(self):
|
||||||
logging.info('chunk2mal run')
|
logging.info('chunk2mal run')
|
||||||
@ -42,9 +41,7 @@ class AudioMalHandler(AudioHandler):
|
|||||||
for _ in range(self._context.batch_size() * 2):
|
for _ in range(self._context.batch_size() * 2):
|
||||||
frame, _type = self.get_audio_frame()
|
frame, _type = self.get_audio_frame()
|
||||||
self.frames.append(frame)
|
self.frames.append(frame)
|
||||||
# put to output
|
self.on_next_handle((frame, _type), 0)
|
||||||
# self.output_queue.put((frame, _type))
|
|
||||||
self._human.push_out_put(frame, _type)
|
|
||||||
# context not enough, do not run network.
|
# context not enough, do not run network.
|
||||||
if len(self.frames) <= self._context.stride_left_size() + self._context.stride_right_size():
|
if len(self.frames) <= self._context.stride_left_size() + self._context.stride_right_size():
|
||||||
return
|
return
|
||||||
@ -67,8 +64,7 @@ class AudioMalHandler(AudioHandler):
|
|||||||
else:
|
else:
|
||||||
mel_chunks.append(mel[:, start_idx: start_idx + mel_step_size])
|
mel_chunks.append(mel[:, start_idx: start_idx + mel_step_size])
|
||||||
i += 1
|
i += 1
|
||||||
# self.feat_queue.put(mel_chunks)
|
self.on_next_handle(mel_chunks, 1)
|
||||||
self._human.push_mel_chunks(mel_chunks)
|
|
||||||
|
|
||||||
# discard the old part to save memory
|
# discard the old part to save memory
|
||||||
self.frames = self.frames[-(self._context.stride_left_size() + self._context.stride_right_size()):]
|
self.frames = self.frames[-(self._context.stride_left_size() + self._context.stride_right_size()):]
|
||||||
|
@ -4,6 +4,7 @@ import logging
|
|||||||
from asr import SherpaNcnnAsr
|
from asr import SherpaNcnnAsr
|
||||||
from nlp import PunctuationSplit, DouBao
|
from nlp import PunctuationSplit, DouBao
|
||||||
from tts import TTSEdge, TTSAudioSplitHandle
|
from tts import TTSEdge, TTSAudioSplitHandle
|
||||||
|
from utils import load_avatar, get_device
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -11,12 +12,14 @@ logger = logging.getLogger(__name__)
|
|||||||
class HumanContext:
|
class HumanContext:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._fps = 50 # 20 ms per frame
|
self._fps = 50 # 20 ms per frame
|
||||||
|
self._image_size = 96
|
||||||
self._batch_size = 16
|
self._batch_size = 16
|
||||||
self._sample_rate = 16000
|
self._sample_rate = 16000
|
||||||
self._stride_left_size = 10
|
self._stride_left_size = 10
|
||||||
self._stride_right_size = 10
|
self._stride_right_size = 10
|
||||||
|
|
||||||
full_images, face_frames, coord_frames = load_avatar(r'./face/')
|
self._device = get_device()
|
||||||
|
full_images, face_frames, coord_frames = load_avatar(r'./face/', self._device, self._image_size)
|
||||||
self._frame_list_cycle = full_images
|
self._frame_list_cycle = full_images
|
||||||
self._face_list_cycle = face_frames
|
self._face_list_cycle = face_frames
|
||||||
self._coord_list_cycle = coord_frames
|
self._coord_list_cycle = coord_frames
|
||||||
@ -28,6 +31,14 @@ class HumanContext:
|
|||||||
def fps(self):
|
def fps(self):
|
||||||
return self._fps
|
return self._fps
|
||||||
|
|
||||||
|
@property
|
||||||
|
def image_size(self):
|
||||||
|
return self._image_size
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self):
|
||||||
|
return self._device
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def batch_size(self):
|
def batch_size(self):
|
||||||
return self._batch_size
|
return self._batch_size
|
||||||
@ -44,6 +55,10 @@ class HumanContext:
|
|||||||
def stride_right_size(self):
|
def stride_right_size(self):
|
||||||
return self._stride_right_size
|
return self._stride_right_size
|
||||||
|
|
||||||
|
@property
|
||||||
|
def face_list_cycle(self):
|
||||||
|
return self._face_list_cycle
|
||||||
|
|
||||||
def build(self):
|
def build(self):
|
||||||
tts_handle = TTSAudioSplitHandle(self, None)
|
tts_handle = TTSAudioSplitHandle(self, None)
|
||||||
tts = TTSEdge(tts_handle)
|
tts = TTSEdge(tts_handle)
|
||||||
|
16
human/human_render.py
Normal file
16
human/human_render.py
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
#encoding = utf8
|
||||||
|
from queue import Queue
|
||||||
|
|
||||||
|
from human import AudioHandler
|
||||||
|
|
||||||
|
|
||||||
|
class HumanRender(AudioHandler):
|
||||||
|
def __init__(self, context, handler):
|
||||||
|
super().__init__(context, handler)
|
||||||
|
|
||||||
|
self._queue = Queue(context.batch_size * 2)
|
||||||
|
|
||||||
|
def on_handle(self, stream, index):
|
||||||
|
self._queue.put(stream)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user