modify audio handler

brige 2024-10-16 19:04:12 +08:00
parent da37374232
commit d8225d3929
5 changed files with 63 additions and 20 deletions


@@ -1,6 +1,9 @@
 #encoding = utf8
+import logging
 from abc import ABC, abstractmethod
 
+logger = logging.getLogger(__name__)
+
 
 class AudioHandler(ABC):
     def __init__(self, context, handler):
@@ -9,5 +12,10 @@ class AudioHandler(ABC):
     @abstractmethod
     def on_handle(self, stream, index):
-        if self._handler is not None:
-            self._handler.on_handle(stream, index)
+        pass
+
+    def on_next_handle(self, stream, type_):
+        if self._handler is not None:
+            self._handler.on_handle(stream, type_)
+        else:
+            logging.info(f'_handler is None')
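The new on_next_handle turns these handlers into a simple chain of responsibility: each stage does its own work in on_handle, then forwards data to the next stage. A minimal sketch of how two stages compose, assuming __init__ stores its arguments as self._context and self._handler (the GainStage/PrintSink classes are illustrative only, not part of this commit):

class GainStage(AudioHandler):
    def on_handle(self, stream, type_):
        self.on_next_handle(stream * 2, type_)  # transform, then forward


class PrintSink(AudioHandler):
    def on_handle(self, stream, type_):
        print('sink received', stream, 'type', type_)


sink = PrintSink(context=None, handler=None)
head = GainStage(context=None, handler=sink)
head.on_handle(0.5, 0)  # prints: sink received 1.0 type 0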


@@ -1,6 +1,7 @@
 #encoding = utf8
 import queue
 import time
+from queue import Queue
 from threading import Event, Thread
 
 import numpy as np
@@ -14,20 +15,25 @@ class AudioInferenceHandler(AudioHandler):
     def __init__(self, context, handler):
         super().__init__(context, handler)
+        self._mal_queue = Queue()
+        self._audio_queue = Queue()
+
         self._exit_event = Event()
         self._run_thread = Thread(target=self.__on_run)
         self._exit_event.set()
         self._run_thread.start()
 
-    def on_handle(self, stream, index):
-        if self._handler is not None:
-            self._handler.on_handle(stream, index)
+    def on_handle(self, stream, type_):
+        if type_ == 1:
+            self._mal_queue.put(stream)
+        elif type_ == 0:
+            self._audio_queue.put(stream)
 
     def __on_run(self):
        model = load_model(r'.\checkpoints\wav2lip.pth')
        print("Model loaded")
 
-        face_list_cycle = self._human.get_face_list_cycle()
+        face_list_cycle = self._context.face_list_cycle()
 
        length = len(face_list_cycle)
        index = 0
@@ -43,20 +49,21 @@ class AudioInferenceHandler(AudioHandler):
            start_time = time.perf_counter()
            batch_size = self._context.batch_size()
 
            try:
-                mel_batch = self._feat_queue.get(block=True, timeout=0.1)
+                mel_batch = self._mal_queue.get(block=True, timeout=0.1)
            except queue.Empty:
                continue
 
            is_all_silence = True
            audio_frames = []
            for _ in range(batch_size * 2):
-                frame, type_ = self._audio_out_queue.get()
+                frame, type_ = self._audio_queue.get()
                audio_frames.append((frame, type_))
                if type_ == 0:
                    is_all_silence = False
 
            if is_all_silence:
                for i in range(batch_size):
-                    self._human.push_res_frame(None, mirror_index(length, index), audio_frames[i * 2:i * 2 + 2])
+                    self.on_next_handle((None, mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]),
+                                        0)
                    index = index + 1
            else:
                print('infer=======')
@@ -71,7 +78,7 @@ class AudioInferenceHandler(AudioHandler):
                mel_batch = np.asarray(mel_batch)
                img_masked = img_batch.copy()
                img_masked[:, face.shape[0] // 2:] = 0
-
+                #
                img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
                mel_batch = np.reshape(mel_batch,
                                       [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
@@ -94,8 +101,9 @@ class AudioInferenceHandler(AudioHandler):
                image_index = 0
                for i, res_frame in enumerate(pred):
-                    self._human.push_res_frame(res_frame, mirror_index(length, index),
-                                               audio_frames[i * 2:i * 2 + 2])
+                    self.on_next_handle(
+                        (res_frame, mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]),
+                        0)
                    index = index + 1
                    image_index = image_index + 1
                print('batch count', image_index)
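on_handle now sorts incoming data by type_: mel batches (type_ == 1) and audio frames (type_ == 0) land in separate queues, and the worker thread pairs one mel batch with batch_size * 2 audio frames (two 20 ms audio frames per output video frame at these defaults). A condensed sketch of that pairing, with stand-in queues rather than the real handler state:

import queue
from queue import Queue

mal_queue, audio_queue = Queue(), Queue()
batch_size = 16  # default from HumanContext

def next_batch():
    """Pair one mel batch with batch_size * 2 audio frames, or return None."""
    try:
        mel_batch = mal_queue.get(block=True, timeout=0.1)
    except queue.Empty:
        return None
    audio_frames = [audio_queue.get() for _ in range(batch_size * 2)]
    # each item is (frame, type_); type_ == 0 marks real, non-silent audio
    is_all_silence = all(type_ != 0 for _, type_ in audio_frames)
    return mel_batch, audio_frames, is_all_silence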


@@ -27,8 +27,7 @@ class AudioMalHandler(AudioHandler):
         self.chunk = context.sample_rate() // context.fps()
 
     def on_handle(self, stream, index):
-        if self._handler is not None:
-            self._handler.on_handle(stream, index)
+        self._queue.put(stream)
 
     def _on_run(self):
         logging.info('chunk2mal run')
@@ -42,9 +41,7 @@ class AudioMalHandler(AudioHandler):
         for _ in range(self._context.batch_size() * 2):
             frame, _type = self.get_audio_frame()
             self.frames.append(frame)
-            # put to output
-            # self.output_queue.put((frame, _type))
-            self._human.push_out_put(frame, _type)
+            self.on_next_handle((frame, _type), 0)
 
         # context not enough, do not run network.
         if len(self.frames) <= self._context.stride_left_size() + self._context.stride_right_size():
             return
@@ -67,8 +64,7 @@ class AudioMalHandler(AudioHandler):
             else:
                 mel_chunks.append(mel[:, start_idx: start_idx + mel_step_size])
             i += 1
-        # self.feat_queue.put(mel_chunks)
-        self._human.push_mel_chunks(mel_chunks)
+        self.on_next_handle(mel_chunks, 1)
 
         # discard the old part to save memory
         self.frames = self.frames[-(self._context.stride_left_size() + self._context.stride_right_size()):]
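With the defaults from HumanContext (batch_size = 16, stride_left_size = stride_right_size = 10), each pass appends 32 frames, forwards them downstream as type 0, and after emitting mel chunks keeps only the last 20 frames as overlap context for the next window. A toy illustration of that bookkeeping, with frame numbers standing in for real audio chunks:

stride_left, stride_right, batch_size = 10, 10, 16

frames = []
for tick in range(3):
    # 32 new "frames" arrive per pass
    frames.extend(range(tick * batch_size * 2, (tick + 1) * batch_size * 2))
    if len(frames) <= stride_left + stride_right:
        continue  # not enough context to run the mel network yet
    # ... mel chunks would be computed over `frames` and forwarded here ...
    frames = frames[-(stride_left + stride_right):]  # keep a 20-frame overlap
    print(tick, len(frames))  # 20 after every completed pass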


@@ -4,6 +4,7 @@ import logging
 from asr import SherpaNcnnAsr
 from nlp import PunctuationSplit, DouBao
 from tts import TTSEdge, TTSAudioSplitHandle
+from utils import load_avatar, get_device
 
 logger = logging.getLogger(__name__)
@@ -11,12 +12,14 @@ logger = logging.getLogger(__name__)
 class HumanContext:
     def __init__(self):
         self._fps = 50  # 20 ms per frame
+        self._image_size = 96
         self._batch_size = 16
         self._sample_rate = 16000
         self._stride_left_size = 10
         self._stride_right_size = 10
-        full_images, face_frames, coord_frames = load_avatar(r'./face/')
+        self._device = get_device()
+        full_images, face_frames, coord_frames = load_avatar(r'./face/', self._device, self._image_size)
         self._frame_list_cycle = full_images
         self._face_list_cycle = face_frames
         self._coord_list_cycle = coord_frames
@@ -28,6 +31,14 @@ class HumanContext:
     def fps(self):
         return self._fps
 
+    @property
+    def image_size(self):
+        return self._image_size
+
+    @property
+    def device(self):
+        return self._device
+
     @property
     def batch_size(self):
         return self._batch_size
@@ -44,6 +55,10 @@ class HumanContext:
     def stride_right_size(self):
         return self._stride_right_size
 
+    @property
+    def face_list_cycle(self):
+        return self._face_list_cycle
+
     def build(self):
         tts_handle = TTSAudioSplitHandle(self, None)
         tts = TTSEdge(tts_handle)
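The visible part of build() only constructs the TTS stages, but given the handler argument threaded through every constructor, the audio pipeline is presumably wired sink-first. A speculative sketch of how the classes touched by this commit could be chained (this wiring is inferred, not shown in the diff):

context = HumanContext()
render = HumanRender(context, None)                 # final sink
inference = AudioInferenceHandler(context, render)  # mel + audio -> rendered frames
mal = AudioMalHandler(context, inference)           # audio -> mel chunks + frames
# upstream producers (e.g. the TTS split handle) would feed mal.on_handle(...)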

human/human_render.py (new file, 16 additions)

@@ -0,0 +1,16 @@
+#encoding = utf8
+from queue import Queue
+
+from human import AudioHandler
+
+
+class HumanRender(AudioHandler):
+    def __init__(self, context, handler):
+        super().__init__(context, handler)
+        self._queue = Queue(context.batch_size * 2)
+
+    def on_handle(self, stream, index):
+        self._queue.put(stream)
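HumanRender only buffers results for now; presumably a display loop drains the bounded queue. A minimal consumer sketch (draw is a placeholder, and reading _queue directly is for illustration only):

def render_loop(render, draw):
    while True:
        # each item is the tuple forwarded by AudioInferenceHandler:
        # (res_frame, frame_index, audio_frames)
        res_frame, frame_index, audio_frames = render._queue.get()
        draw(res_frame, frame_index)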