From da37374232651955ce84070eccca097ff63c5116 Mon Sep 17 00:00:00 2001
From: jiegeaiai
Date: Wed, 16 Oct 2024 08:01:11 +0800
Subject: [PATCH] add audio inference handler and related code

---
 human/audio_inference_handler.py | 106 +++++++++++++++++++
 human/audio_mal_handler.py       |  15 +--
 human/human_context.py           |  14 ++-
 tts/tts_audio_handle.py          |  13 ++-
 utils/__init__.py                |   4 +-
 audio.py => utils/audio_utils.py |  37 +++++--
 utils/utils.py                   | 168 ++++++++++++++++++++++++++++++-
 7 files changed, 334 insertions(+), 23 deletions(-)
 create mode 100644 human/audio_inference_handler.py
 rename audio.py => utils/audio_utils.py (96%)

diff --git a/human/audio_inference_handler.py b/human/audio_inference_handler.py
new file mode 100644
index 0000000..d64b141
--- /dev/null
+++ b/human/audio_inference_handler.py
@@ -0,0 +1,106 @@
+#encoding = utf8
+import queue
+import time
+from threading import Event, Thread
+
+import numpy as np
+import torch
+
+from human import AudioHandler
+from utils import load_model, mirror_index, get_device
+
+
+class AudioInferenceHandler(AudioHandler):
+    def __init__(self, context, handler):
+        super().__init__(context, handler)
+
+        self._exit_event = Event()
+        self._run_thread = Thread(target=self.__on_run)
+        self._exit_event.set()
+        self._run_thread.start()
+
+    def on_handle(self, stream, index):
+        if self._handler is not None:
+            self._handler.on_handle(stream, index)
+
+    def __on_run(self):
+        model = load_model(r'.\checkpoints\wav2lip.pth')
+        print("Model loaded")
+
+        face_list_cycle = self._human.get_face_list_cycle()
+
+        length = len(face_list_cycle)
+        index = 0
+        count = 0
+        count_time = 0
+        print('start inference')
+
+        device = get_device()
+        print(f'use device:{device}')
+
+        while True:
+            if self._exit_event.is_set():
+                start_time = time.perf_counter()
+                batch_size = self._context.batch_size()
+                try:
+                    mel_batch = self._feat_queue.get(block=True, timeout=0.1)
+                except queue.Empty:
+                    continue
+
+                is_all_silence = True
+                audio_frames = []
+                for _ in range(batch_size * 2):
+                    frame, type_ = self._audio_out_queue.get()
+                    audio_frames.append((frame, type_))
+                    if type_ == 0:
+                        is_all_silence = False
+                if is_all_silence:
+                    for i in range(batch_size):
+                        self._human.push_res_frame(None, mirror_index(length, index), audio_frames[i * 2:i * 2 + 2])
+                        index = index + 1
+                else:
+                    print('infer=======')
+                    t = time.perf_counter()
+                    img_batch = []
+                    for i in range(batch_size):
+                        idx = mirror_index(length, index + i)
+                        face = face_list_cycle[idx]
+                        img_batch.append(face)
+
+                    img_batch = np.asarray(img_batch)
+                    mel_batch = np.asarray(mel_batch)
+                    img_masked = img_batch.copy()
+                    img_masked[:, face.shape[0] // 2:] = 0
+                    #
+                    img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
+                    mel_batch = np.reshape(mel_batch,
+                                           [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
+
+                    img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
+                    mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
+
+                    with torch.no_grad():
+                        pred = model(mel_batch, img_batch)
+
+                    pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
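The batch preparation above follows the standard Wav2Lip input layout: the lower half of each reference face is zeroed out, the masked and unmasked copies are stacked along the channel axis, and both tensors are transposed from NHWC to NCHW before the forward pass. A self-contained numpy sketch of those shapes, using illustrative values (a batch of 4 RGB faces at 96x96 and 16-frame mel windows) rather than anything fixed by this patch:

import numpy as np

# illustrative shapes only; the real handler gets faces from face_list_cycle
batch_size, img_size = 4, 96
img_batch = np.random.randint(0, 255, (batch_size, img_size, img_size, 3)).astype(np.float32)
mel_batch = np.random.randn(batch_size, 80, 16).astype(np.float32)

img_masked = img_batch.copy()
img_masked[:, img_size // 2:] = 0                                  # zero the lower half (mouth region)
img_in = np.concatenate((img_masked, img_batch), axis=3) / 255.    # (B, H, W, 6)
mel_in = mel_batch[..., np.newaxis]                                # (B, 80, 16, 1)

img_in = np.transpose(img_in, (0, 3, 1, 2))                        # NHWC -> NCHW for the torch model
mel_in = np.transpose(mel_in, (0, 3, 1, 2))
print(img_in.shape, mel_in.shape)                                  # (4, 6, 96, 96) (4, 1, 80, 16)
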
+
+                    count_time += (time.perf_counter() - t)
+                    count += batch_size
+
+                    if count >= 100:
+                        print(f"------actual avg infer fps:{count / count_time:.4f}")
+                        count = 0
+                        count_time = 0
+
+                    image_index = 0
+                    for i, res_frame in enumerate(pred):
+                        self._human.push_res_frame(res_frame, mirror_index(length, index),
+                                                   audio_frames[i * 2:i * 2 + 2])
+                        index = index + 1
+                        image_index = image_index + 1
+                    print('batch count', image_index)
+                print('total batch time:', time.perf_counter() - start_time)
+            else:
+                time.sleep(1)
+                break
+        print('musereal inference processor stop')
diff --git a/human/audio_mal_handler.py b/human/audio_mal_handler.py
index bb96c2f..3b12eee 100644
--- a/human/audio_mal_handler.py
+++ b/human/audio_mal_handler.py
@@ -8,6 +8,7 @@ from threading import Thread, Event
 import numpy as np
 
 from human import AudioHandler
+from utils import melspectrogram
 
 
 logger = logging.getLogger(__name__)
@@ -45,20 +46,20 @@ class AudioMalHandler(AudioHandler):
             # self.output_queue.put((frame, _type))
             self._human.push_out_put(frame, _type)
         # context not enough, do not run network.
-        if len(self.frames) <= self.stride_left_size + self.stride_right_size:
+        if len(self.frames) <= self._context.stride_left_size() + self._context.stride_right_size():
             return
 
         inputs = np.concatenate(self.frames)  # [N * chunk]
-        mel = audio.melspectrogram(inputs)
+        mel = melspectrogram(inputs)
         # print(mel.shape[0],mel.shape,len(mel[0]),len(self.frames))
         # cut off stride
-        left = max(0, self.stride_left_size * 80 / 50)
-        right = min(len(mel[0]), len(mel[0]) - self.stride_right_size * 80 / 50)
-        mel_idx_multiplier = 80. * 2 / self.fps
+        left = max(0, self._context.stride_left_size() * 80 / 50)
+        right = min(len(mel[0]), len(mel[0]) - self._context.stride_right_size() * 80 / 50)
+        mel_idx_multiplier = 80. * 2 / self._context.fps()
         mel_step_size = 16
         i = 0
         mel_chunks = []
-        while i < (len(self.frames) - self.stride_left_size - self.stride_right_size) / 2:
+        while i < (len(self.frames) - self._context.stride_left_size() - self._context.stride_right_size()) / 2:
             start_idx = int(left + i * mel_idx_multiplier)
             # print(start_idx)
             if start_idx + mel_step_size > len(mel[0]):
@@ -70,7 +71,7 @@ class AudioMalHandler(AudioHandler):
 
         self._human.push_mel_chunks(mel_chunks)
         # discard the old part to save memory
-        self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]
+        self.frames = self.frames[-(self._context.stride_left_size() + self._context.stride_right_size()):]
 
     def get_audio_frame(self):
         try:
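The windowing above assumes the stock Wav2Lip mel settings, where the spectrogram has 80 frames per second of audio; mel_idx_multiplier = 80. * 2 / fps therefore advances the window by two audio chunks per step, and each chunk indexes a 16-frame (mel_step_size) slice of the mel array. A standalone sketch of that arithmetic with illustrative defaults (fps 50, strides of 10), not taken verbatim from the handler:

def mel_chunk_starts(num_frames, fps=50, stride_left=10, stride_right=10, mel_fps=80):
    # illustrative re-statement of the windowing in AudioMalHandler; values are examples
    left = max(0, stride_left * mel_fps // 50)       # the handler hardcodes 50 here
    mel_idx_multiplier = mel_fps * 2 / fps           # mel frames consumed per two audio chunks
    starts = []
    i = 0
    while i < (num_frames - stride_left - stride_right) / 2:
        starts.append(int(left + i * mel_idx_multiplier))
        i += 1
    return starts


print(mel_chunk_starts(num_frames=40))   # [16, 19, 22, ...] start columns into the mel array
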
diff --git a/human/human_context.py b/human/human_context.py
index 0db4d10..a728e24 100644
--- a/human/human_context.py
+++ b/human/human_context.py
@@ -1,8 +1,12 @@
 #encoding = utf8
+import logging
+
 from asr import SherpaNcnnAsr
 from nlp import PunctuationSplit, DouBao
 from tts import TTSEdge, TTSAudioSplitHandle
 
+logger = logging.getLogger(__name__)
+
 
 class HumanContext:
     def __init__(self):
@@ -12,6 +16,14 @@ class HumanContext:
         self._stride_left_size = 10
         self._stride_right_size = 10
 
+        full_images, face_frames, coord_frames = load_avatar(r'./face/')
+        self._frame_list_cycle = full_images
+        self._face_list_cycle = face_frames
+        self._coord_list_cycle = coord_frames
+        face_images_length = len(self._face_list_cycle)
+        logging.info(f'face images length: {face_images_length}')
+        print(f'face images length: {face_images_length}')
+
     @property
     def fps(self):
         return self._fps
@@ -33,7 +45,7 @@ class HumanContext:
         return self._stride_right_size
 
     def build(self):
-        tts_handle = TTSAudioSplitHandle(self)
+        tts_handle = TTSAudioSplitHandle(self, None)
         tts = TTSEdge(tts_handle)
         split = PunctuationSplit()
         nlp = DouBao(split, tts)
diff --git a/tts/tts_audio_handle.py b/tts/tts_audio_handle.py
index 545705a..8e7babf 100644
--- a/tts/tts_audio_handle.py
+++ b/tts/tts_audio_handle.py
@@ -2,12 +2,13 @@
 import os
 import shutil
 
-from audio import save_wav
+from utils import save_wav
 from human import AudioHandler
 
 
 class TTSAudioHandle(AudioHandler):
-    def __init__(self):
+    def __init__(self, context, handler):
+        super().__init__(context, handler)
         self._sample_rate = 16000
         self._index = 1
 
@@ -23,11 +24,13 @@ class TTSAudioHandle(AudioHandler):
         self._index = self._index + 1
         return self._index
 
+    def on_handle(self, stream, index):
+        pass
+
 
 class TTSAudioSplitHandle(TTSAudioHandle):
-    def __init__(self, context):
-        super().__init__()
-        self._context = context
+    def __init__(self, context, handler):
+        super().__init__(context, handler)
         self.sample_rate = self._context.get_audio_sample_rate()
         self._chunk = self.sample_rate // self._context.get_fps()
 
diff --git a/utils/__init__.py b/utils/__init__.py
index dc67ac2..0d5c37a 100644
--- a/utils/__init__.py
+++ b/utils/__init__.py
@@ -1,4 +1,6 @@
 #encoding = utf8
 from .async_task_queue import AsyncTaskQueue
-from .utils import mirror_index
+from .utils import mirror_index, load_model, get_device, load_avatar
+from .audio_utils import melspectrogram, save_wav
+
 
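AudioInferenceHandler, AudioMalHandler and the TTS handles are all constructed with the shared context plus the next handler in the chain, and HumanContext.build() above passes (self, None), i.e. a context and no downstream handler yet. The AudioHandler base class is not part of this patch; the following is only a guessed minimal shape of that chain, for orientation:

# guessed sketch of the AudioHandler base class (not included in this patch):
# each handler keeps the shared HumanContext and an optional next handler,
# and forwards processed audio down the chain via on_handle.
class AudioHandler:
    def __init__(self, context, handler):
        self._context = context   # shared HumanContext
        self._handler = handler   # next AudioHandler in the chain, or None

    def on_handle(self, stream, index):
        if self._handler is not None:
            self._handler.on_handle(stream, index)
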
diff --git a/audio.py b/utils/audio_utils.py
similarity index 96%
rename from audio.py
rename to utils/audio_utils.py
index 379e3a8..2927bff 100644
--- a/audio.py
+++ b/utils/audio_utils.py
@@ -1,34 +1,41 @@
+#encoding = utf8
 import librosa
 import librosa.filters
 import numpy as np
-# import tensorflow as tf
+
 from scipy import signal
 from scipy.io import wavfile
 from hparams import hparams as hp
 import soundfile as sf
 from IPython.display import Audio
 
+
 def load_wav(path, sr):
     return librosa.core.load(path, sr=sr)[0]
 
+
 def save_wav(wav, path, sr):
     wav *= 32767 / max(0.01, np.max(np.abs(wav)))
-    #proposed by @dsmiller
+    # proposed by @dsmiller
     wavfile.write(path, sr, wav.astype(np.int16))
 
+
 def save_wavenet_wav(wav, path, sr):
     librosa.output.write_wav(path, wav, sr=sr)
 
+
 def preemphasis(wav, k, preemphasize=True):
     if preemphasize:
         return signal.lfilter([1, -k], [1], wav)
     return wav
 
+
 def inv_preemphasis(wav, k, inv_preemphasize=True):
     if inv_preemphasize:
         return signal.lfilter([1], [1, -k], wav)
     return wav
 
+
 def get_hop_size():
     hop_size = hp.hop_size
     if hop_size is None:
@@ -36,34 +43,39 @@ def get_hop_size():
         hop_size = int(hp.frame_shift_ms / 1000 * hp.sample_rate)
     return hop_size
 
+
 def linearspectrogram(wav):
     D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
     S = _amp_to_db(np.abs(D)) - hp.ref_level_db
-    
+
     if hp.signal_normalization:
         return _normalize(S)
     return S
 
+
 def melspectrogram(wav):
     D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
     S = _amp_to_db(_linear_to_mel(np.abs(D))) - hp.ref_level_db
-    
+
     if hp.signal_normalization:
         return _normalize(S)
     return S
 
+
 def _lws_processor():
     import lws
     return lws.lws(hp.n_fft, get_hop_size(), fftsize=hp.win_size, mode="speech")
 
+
 def _stft(y):
     if hp.use_lws:
         return _lws_processor(hp).stft(y).T
     else:
         return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=get_hop_size(), win_length=hp.win_size)
 
+
 ##########################################################
-#Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
+# Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
 def num_frames(length, fsize, fshift):
     """Compute number of time frames of spectrogram
     """
@@ -83,32 +95,40 @@ def pad_lr(x, fsize, fshift):
     T = len(x) + 2 * pad
     r = (M - 1) * fshift + fsize - T
     return pad, pad + r
+
+
 ##########################################################
-#Librosa correct padding
+# Librosa correct padding
 def librosa_pad_lr(x, fsize, fshift):
     return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]
 
+
 # Conversions
 _mel_basis = None
 
+
 def _linear_to_mel(spectogram):
     global _mel_basis
     if _mel_basis is None:
         _mel_basis = _build_mel_basis()
     return np.dot(_mel_basis, spectogram)
 
+
 def _build_mel_basis():
     assert hp.fmax <= hp.sample_rate // 2
     return librosa.filters.mel(sr=hp.sample_rate, n_fft=hp.n_fft, n_mels=hp.num_mels,
                                fmin=hp.fmin, fmax=hp.fmax)
 
+
 def _amp_to_db(x):
     min_level = np.exp(hp.min_level_db / 20 * np.log(10))
     return 20 * np.log10(np.maximum(min_level, x))
 
+
 def _db_to_amp(x):
     return np.power(10.0, (x) * 0.05)
 
+
 def _normalize(S):
     if hp.allow_clipping_in_normalization:
         if hp.symmetric_mels:
@@ -116,13 +136,14 @@ def _normalize(S):
                            -hp.max_abs_value, hp.max_abs_value)
         else:
             return np.clip(hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)), 0, hp.max_abs_value)
-    
+
     assert S.max() <= 0 and S.min() - hp.min_level_db >= 0
     if hp.symmetric_mels:
         return (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value
     else:
         return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db))
 
+
 def _denormalize(D):
     if hp.allow_clipping_in_normalization:
         if hp.symmetric_mels:
@@ -131,7 +152,7 @@ def _denormalize(D):
                     + hp.min_level_db)
         else:
             return ((np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
-    
+
     if hp.symmetric_mels:
         return (((D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + hp.min_level_db)
     else:
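With audio.py relocated to utils/audio_utils.py and re-exported through the utils package, the mel helpers can be imported directly. A rough usage sketch; the file names are placeholders and the exact output shape depends on the values in hparams (stock Wav2Lip uses 16 kHz audio and 80 mel bins):

from utils import melspectrogram, save_wav
from utils.audio_utils import load_wav

wav = load_wav('example.wav', sr=16000)      # placeholder path; sr should match hp.sample_rate
mel = melspectrogram(wav)                    # shape (hp.num_mels, T)
print(mel.shape)
save_wav(wav, 'example_copy.wav', sr=16000)  # note: save_wav rescales the buffer in place
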
diff --git a/utils/utils.py b/utils/utils.py
index 5297afb..371ac87 100644
--- a/utils/utils.py
+++ b/utils/utils.py
@@ -1,4 +1,17 @@
 #encoding = utf8
+import logging
+import os
+
+import cv2
+import numpy as np
+import torch
+from tqdm import tqdm
+
+import face_detection
+from models import Wav2Lip
+
+logger = logging.getLogger(__name__)
+
 
 def mirror_index(size, index):
     # size = len(self.coord_list_cycle)
@@ -7,4 +20,157 @@ def mirror_index(size, index):
     if turn % 2 == 0:
         return res
     else:
-        return size - res - 1
\ No newline at end of file
+        return size - res - 1
+
+
+def read_images(img_list):
+    frames = []
+    print('reading images...')
+    for img_path in tqdm(img_list):
+        print(f'read image path:{img_path}')
+        frame = cv2.imread(img_path)
+        frames.append(frame)
+    return frames
+
+
+def read_files_path(path):
+    file_paths = []
+    files = os.listdir(path)
+    for file in files:
+        if not os.path.isdir(file):
+            file_paths.append(path + file)
+    return file_paths
+
+
+def get_smoothened_boxes(boxes, t):
+    for i in range(len(boxes)):
+        if i + t > len(boxes):
+            window = boxes[len(boxes) - t:]
+        else:
+            window = boxes[i: i + t]
+        boxes[i] = np.mean(window, axis=0)
+    return boxes
+
+
+def datagen_signal(frame, mel, face_det_results, img_size, wav2lip_batch_size=128):
+    img_batch, mel_batch, frame_batch, coord_batch = [], [], [], []
+
+    idx = 0
+    frame_to_save = frame.copy()
+    face, coord = face_det_results[idx].copy()
+
+    face = cv2.resize(face, (img_size, img_size))
+
+    for i, m in enumerate(mel):
+        img_batch.append(face)
+        mel_batch.append(m)
+        frame_batch.append(frame_to_save)
+        coord_batch.append(coord)
+
+        if len(img_batch) >= wav2lip_batch_size:
+            img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
+
+            img_masked = img_batch.copy()
+            img_masked[:, img_size // 2:] = 0
+            img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
+            mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
+
+            return img_batch, mel_batch, frame_batch, coord_batch
+
+    if len(img_batch) > 0:
+        img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
+        img_masked = img_batch.copy()
+        img_masked[:, img_size//2:] = 0
+
+        img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
+        mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
+
+        return img_batch, mel_batch, frame_batch, coord_batch
+
+
+def face_detect(images, device):
+    detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
+                                            flip_input=False, device=device)
+    batch_size = 16
+
+    while 1:
+        predictions = []
+        try:
+            for i in range(0, len(images), batch_size):
+                predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
+        except RuntimeError:
+            if batch_size == 1:
+                raise RuntimeError(
+                    'Image too big to run face detection on GPU. Please use the --resize_factor argument')
+            batch_size //= 2
+            print('Recovering from OOM error; New batch size: {}'.format(batch_size))
+            continue
+        break
+
+    results = []
+    pad_y1, pad_y2, pad_x1, pad_x2 = [0, 10, 0, 0]
+    for rect, image in zip(predictions, images):
+        if rect is None:
+            cv2.imwrite('temp/faulty_frame.jpg', image)  # check this frame where the face was not detected.
+            raise ValueError('Face not detected! Ensure the video contains a face in all the frames.')
+
+        y1 = max(0, rect[1] - pad_y1)
+        y2 = min(image.shape[0], rect[3] + pad_y2)
+        x1 = max(0, rect[0] - pad_x1)
+        x2 = min(image.shape[1], rect[2] + pad_x2)
+
+        results.append([x1, y1, x2, y2])
+
+    boxes = np.array(results)
+    if not False:
+        boxes = get_smoothened_boxes(boxes, t=5)
+    results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)]
+
+    del detector
+    return results
+
+
+def get_device():
+    return 'cuda' if torch.cuda.is_available() else 'cpu'
+
+
+def _load(checkpoint_path):
+    device = get_device()
+    if device == 'cuda':
+        checkpoint = torch.load(checkpoint_path)
+    else:
+        checkpoint = torch.load(checkpoint_path,
+                                map_location=lambda storage, loc: storage)
+    return checkpoint
+
+
+def load_model(path):
+    model = Wav2Lip()
+    print("Load checkpoint from: {}".format(path))
+    logging.info(f'Load checkpoint from {path}')
+    checkpoint = _load(path)
+    s = checkpoint["state_dict"]
+    new_s = {}
+    for k, v in s.items():
+        new_s[k.replace('module.', '')] = v
+    model.load_state_dict(new_s)
+    device = get_device()
+    model = model.to(device)
+    return model.eval()
+
+
+def load_avatar(path, img_size, device):
+    face_images_path = path
+    face_images_path = read_files_path(face_images_path)
+    full_list_cycle = read_images(face_images_path)
+
+    face_det_results = face_detect(full_list_cycle, device)
+
+    face_frames = []
+    coord_frames = []
+    for face, coord in face_det_results:
+        resized_crop_frame = cv2.resize(face, (img_size, img_size))
+        face_frames.append(resized_crop_frame)
+        coord_frames.append(coord)
+
+    return full_list_cycle, face_frames, coord_frames
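Taken together these helpers cover the avatar preparation that HumanContext and AudioInferenceHandler build on. An illustrative wiring sketch, assuming the repository's packages are importable; the face directory, image size and checkpoint path below are placeholders rather than values fixed by this patch:

from utils import get_device, load_avatar, load_model, mirror_index

device = get_device()                    # 'cuda' if available, otherwise 'cpu'
full_frames, face_frames, coords = load_avatar('./face/', img_size=96, device=device)
model = load_model('./checkpoints/wav2lip.pth')

# mirror_index walks the frame list forwards then backwards so playback loops smoothly
print([mirror_index(len(face_frames), i) for i in range(8)])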