Add audio inference handler and related code

This commit is contained in:
jiegeaiai 2024-10-16 08:01:11 +08:00
parent dadfaf4eaf
commit da37374232
7 changed files with 334 additions and 23 deletions

View File

@@ -0,0 +1,106 @@
+#encoding = utf8
+import queue
+import time
+from threading import Event, Thread
+
+import numpy as np
+import torch
+
+from human import AudioHandler
+from utils import load_model, mirror_index, get_device
+
+
+class AudioInferenceHandler(AudioHandler):
+    def __init__(self, context, handler):
+        super().__init__(context, handler)
+
+        self._exit_event = Event()
+        self._run_thread = Thread(target=self.__on_run)
+        self._exit_event.set()
+        self._run_thread.start()
+
+    def on_handle(self, stream, index):
+        if self._handler is not None:
+            self._handler.on_handle(stream, index)
+
+    def __on_run(self):
+        model = load_model(r'.\checkpoints\wav2lip.pth')
+        print("Model loaded")
+
+        face_list_cycle = self._human.get_face_list_cycle()
+        length = len(face_list_cycle)
+        index = 0
+        count = 0
+        count_time = 0
+        print('start inference')
+
+        device = get_device()
+        print(f'use device:{device}')
+
+        while True:
+            if self._exit_event.is_set():
+                start_time = time.perf_counter()
+                batch_size = self._context.batch_size()
+                try:
+                    mel_batch = self._feat_queue.get(block=True, timeout=0.1)
+                except queue.Empty:
+                    continue
+
+                is_all_silence = True
+                audio_frames = []
+                for _ in range(batch_size * 2):
+                    frame, type_ = self._audio_out_queue.get()
+                    audio_frames.append((frame, type_))
+                    if type_ == 0:
+                        is_all_silence = False
+
+                if is_all_silence:
+                    for i in range(batch_size):
+                        self._human.push_res_frame(None, mirror_index(length, index),
+                                                   audio_frames[i * 2:i * 2 + 2])
+                        index = index + 1
+                else:
+                    print('infer=======')
+                    t = time.perf_counter()
+                    img_batch = []
+                    for i in range(batch_size):
+                        idx = mirror_index(length, index + i)
+                        face = face_list_cycle[idx]
+                        img_batch.append(face)
+
+                    img_batch = np.asarray(img_batch)
+                    mel_batch = np.asarray(mel_batch)
+                    img_masked = img_batch.copy()
+                    img_masked[:, face.shape[0] // 2:] = 0
+
+                    img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
+                    mel_batch = np.reshape(mel_batch,
+                                           [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
+
+                    img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
+                    mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
+
+                    with torch.no_grad():
+                        pred = model(mel_batch, img_batch)
+                    pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
+
+                    count_time += (time.perf_counter() - t)
+                    count += batch_size
+                    if count >= 100:
+                        print(f"------actual avg infer fps:{count / count_time:.4f}")
+                        count = 0
+                        count_time = 0
+
+                    image_index = 0
+                    for i, res_frame in enumerate(pred):
+                        self._human.push_res_frame(res_frame, mirror_index(length, index),
+                                                   audio_frames[i * 2:i * 2 + 2])
+                        index = index + 1
+                        image_index = image_index + 1
+                    print('batch count', image_index)
+
+                print('total batch time:', time.perf_counter() - start_time)
+            else:
+                time.sleep(1)
+                break
+        print('musereal inference processor stop')
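
The new handler consumes one mel batch plus 2 * batch_size audio frames per iteration (two audio chunks per video frame) and skips the network entirely when every frame is silent. A minimal downstream consumer, sketched under the assumption that AudioHandler subclasses only need to override on_handle (PrintHandler itself is hypothetical, not part of this commit):

#encoding = utf8
from human import AudioHandler

class PrintHandler(AudioHandler):
    # receives each result frame forwarded down the handler chain
    def on_handle(self, stream, index):
        # stream is a predicted BGR frame, or None for an all-silence batch
        print('frame', index, 'silent' if stream is None else stream.shape)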

View File

@@ -8,6 +8,7 @@ from threading import Thread, Event
import numpy as np

from human import AudioHandler
+from utils import melspectrogram

logger = logging.getLogger(__name__)
@@ -45,20 +46,20 @@ class AudioMalHandler(AudioHandler):
            # self.output_queue.put((frame, _type))
            self._human.push_out_put(frame, _type)

        # context not enough, do not run network.
-        if len(self.frames) <= self.stride_left_size + self.stride_right_size:
+        if len(self.frames) <= self._context.stride_left_size() + self._context.stride_right_size():
            return

        inputs = np.concatenate(self.frames)  # [N * chunk]
-        mel = audio.melspectrogram(inputs)
+        mel = melspectrogram(inputs)
        # print(mel.shape[0],mel.shape,len(mel[0]),len(self.frames))
        # cut off stride
-        left = max(0, self.stride_left_size * 80 / 50)
-        right = min(len(mel[0]), len(mel[0]) - self.stride_right_size * 80 / 50)
-        mel_idx_multiplier = 80. * 2 / self.fps
+        left = max(0, self._context.stride_left_size() * 80 / 50)
+        right = min(len(mel[0]), len(mel[0]) - self._context.stride_right_size() * 80 / 50)
+        mel_idx_multiplier = 80. * 2 / self._context.fps()
        mel_step_size = 16
        i = 0
        mel_chunks = []
-        while i < (len(self.frames) - self.stride_left_size - self.stride_right_size) / 2:
+        while i < (len(self.frames) - self._context.stride_left_size() - self._context.stride_right_size()) / 2:
            start_idx = int(left + i * mel_idx_multiplier)
            # print(start_idx)
            if start_idx + mel_step_size > len(mel[0]):
@@ -70,7 +71,7 @@ class AudioMalHandler(AudioHandler):
        self._human.push_mel_chunks(mel_chunks)
        # discard the old part to save memory
-        self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]
+        self.frames = self.frames[-(self._context.stride_left_size() + self._context.stride_right_size()):]

    def get_audio_frame(self):
        try:
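
A note on the arithmetic in the hunk above: the 80/50 factors convert stride sizes from 20 ms audio chunks (50 per second) into mel-spectrogram frames (80 per second), and mel_idx_multiplier advances the 16-frame mel window by two audio chunks, i.e. one video frame, per iteration. A standalone check, using the strides of 10 set in HumanContext and an assumed fps of 50:

# Worked check of the mel window indexing (fps=50 is an assumption; strides are the defaults above).
stride_left_size = stride_right_size = 10
fps = 50
left = max(0, stride_left_size * 80 / 50)    # 16.0 -> mel frames covered by the left stride
mel_idx_multiplier = 80. * 2 / fps           # 3.2 mel frames per video frame
mel_step_size = 16
for i in range(3):
    start_idx = int(left + i * mel_idx_multiplier)
    print(start_idx, start_idx + mel_step_size)  # (16, 32), (19, 35), (22, 38)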

View File

@@ -1,8 +1,12 @@
#encoding = utf8
+import logging
+
from asr import SherpaNcnnAsr
from nlp import PunctuationSplit, DouBao
from tts import TTSEdge, TTSAudioSplitHandle

+logger = logging.getLogger(__name__)
+

class HumanContext:
    def __init__(self):
@@ -12,6 +16,14 @@ class HumanContext:
        self._stride_left_size = 10
        self._stride_right_size = 10
+        full_images, face_frames, coord_frames = load_avatar(r'./face/')
+        self._frame_list_cycle = full_images
+        self._face_list_cycle = face_frames
+        self._coord_list_cycle = coord_frames
+
+        face_images_length = len(self._face_list_cycle)
+        logging.info(f'face images length: {face_images_length}')
+        print(f'face images length: {face_images_length}')

    @property
    def fps(self):
        return self._fps
@@ -33,7 +45,7 @@ class HumanContext:
        return self._stride_right_size

    def build(self):
-        tts_handle = TTSAudioSplitHandle(self)
+        tts_handle = TTSAudioSplitHandle(self, None)
        tts = TTSEdge(tts_handle)
        split = PunctuationSplit()
        nlp = DouBao(split, tts)
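
build() now threads the (context, handler) pair explicitly through the pipeline, so each stage holds a reference to its downstream consumer. The resulting data flow, reading the construction order above in reverse (the ASR stage is imported at the top and presumably attached later in build()):

# text from DouBao --> PunctuationSplit cuts it into sentences --> TTSEdge synthesizes audio
#   --> TTSAudioSplitHandle slices the audio into per-frame chunks (None: no downstream handler yet)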

View File

@@ -2,12 +2,13 @@
import os
import shutil

-from audio import save_wav
+from utils import save_wav
from human import AudioHandler


class TTSAudioHandle(AudioHandler):
-    def __init__(self):
+    def __init__(self, context, handler):
+        super().__init__(context, handler)
        self._sample_rate = 16000
        self._index = 1
@@ -23,11 +24,13 @@ class TTSAudioHandle(AudioHandler):
        self._index = self._index + 1
        return self._index

+    def on_handle(self, stream, index):
+        pass
+

class TTSAudioSplitHandle(TTSAudioHandle):
-    def __init__(self, context):
-        super().__init__()
-        self._context = context
+    def __init__(self, context, handler):
+        super().__init__(context, handler)
        self.sample_rate = self._context.get_audio_sample_rate()
        self._chunk = self.sample_rate // self._context.get_fps()

View File

@@ -1,4 +1,6 @@
#encoding = utf8
from .async_task_queue import AsyncTaskQueue
-from .utils import mirror_index
+from .utils import mirror_index, load_model, get_device, load_avatar
+from .audio_utils import melspectrogram, save_wav

View File

@@ -1,34 +1,41 @@
+#encoding = utf8
import librosa
import librosa.filters
import numpy as np
-# import tensorflow as tf
from scipy import signal
from scipy.io import wavfile
from hparams import hparams as hp
import soundfile as sf
from IPython.display import Audio


def load_wav(path, sr):
    return librosa.core.load(path, sr=sr)[0]


def save_wav(wav, path, sr):
    wav *= 32767 / max(0.01, np.max(np.abs(wav)))
    # proposed by @dsmiller
    wavfile.write(path, sr, wav.astype(np.int16))


def save_wavenet_wav(wav, path, sr):
    librosa.output.write_wav(path, wav, sr=sr)


def preemphasis(wav, k, preemphasize=True):
    if preemphasize:
        return signal.lfilter([1, -k], [1], wav)
    return wav


def inv_preemphasis(wav, k, inv_preemphasize=True):
    if inv_preemphasize:
        return signal.lfilter([1], [1, -k], wav)
    return wav


def get_hop_size():
    hop_size = hp.hop_size
    if hop_size is None:
@@ -36,6 +43,7 @@ def get_hop_size():
        hop_size = int(hp.frame_shift_ms / 1000 * hp.sample_rate)
    return hop_size


def linearspectrogram(wav):
    D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
    S = _amp_to_db(np.abs(D)) - hp.ref_level_db
@@ -44,6 +52,7 @@ def linearspectrogram(wav):
        return _normalize(S)
    return S


def melspectrogram(wav):
    D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
    S = _amp_to_db(_linear_to_mel(np.abs(D))) - hp.ref_level_db
@@ -52,16 +61,19 @@ def melspectrogram(wav):
        return _normalize(S)
    return S


def _lws_processor():
    import lws
    return lws.lws(hp.n_fft, get_hop_size(), fftsize=hp.win_size, mode="speech")


def _stft(y):
    if hp.use_lws:
        return _lws_processor(hp).stft(y).T
    else:
        return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=get_hop_size(), win_length=hp.win_size)


##########################################################
# Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
def num_frames(length, fsize, fshift):
@@ -83,32 +95,40 @@ def pad_lr(x, fsize, fshift):
    T = len(x) + 2 * pad
    r = (M - 1) * fshift + fsize - T
    return pad, pad + r


##########################################################
# Librosa correct padding
def librosa_pad_lr(x, fsize, fshift):
    return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]


# Conversions
_mel_basis = None


def _linear_to_mel(spectogram):
    global _mel_basis
    if _mel_basis is None:
        _mel_basis = _build_mel_basis()
    return np.dot(_mel_basis, spectogram)


def _build_mel_basis():
    assert hp.fmax <= hp.sample_rate // 2
    return librosa.filters.mel(sr=hp.sample_rate, n_fft=hp.n_fft, n_mels=hp.num_mels,
                               fmin=hp.fmin, fmax=hp.fmax)


def _amp_to_db(x):
    min_level = np.exp(hp.min_level_db / 20 * np.log(10))
    return 20 * np.log10(np.maximum(min_level, x))


def _db_to_amp(x):
    return np.power(10.0, (x) * 0.05)


def _normalize(S):
    if hp.allow_clipping_in_normalization:
        if hp.symmetric_mels:
@@ -123,6 +143,7 @@ def _normalize(S):
    else:
        return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db))


def _denormalize(D):
    if hp.allow_clipping_in_normalization:
        if hp.symmetric_mels:
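
With the new re-exports in utils/__init__.py, the mel pipeline above is reachable as utils.melspectrogram. A quick smoke test, assuming the hparams target 16 kHz mono audio (the sample rate TTSAudioHandle uses):

# Smoke test for the relocated audio utilities (16 kHz mono input assumed).
import numpy as np
from utils import melspectrogram, save_wav

wav = np.random.uniform(-1, 1, 16000).astype(np.float32)  # one second of noise
mel = melspectrogram(wav)
print(mel.shape)  # (hp.num_mels, frames); AudioMalHandler slices 16-frame windows along axis 1
save_wav(wav.copy(), 'check.wav', 16000)  # save_wav scales its input in place, hence the copy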

View File

@@ -1,4 +1,17 @@
#encoding = utf8
+import logging
+import os
+
+import cv2
+import numpy as np
+import torch
+from tqdm import tqdm
+
+import face_detection
+from models import Wav2Lip
+
+logger = logging.getLogger(__name__)
+

def mirror_index(size, index):
    # size = len(self.coord_list_cycle)
@@ -8,3 +21,156 @@ def mirror_index(size, index):
        return res
    else:
        return size - res - 1
+
+
+def read_images(img_list):
+    frames = []
+    print('reading images...')
+    for img_path in tqdm(img_list):
+        print(f'read image path:{img_path}')
+        frame = cv2.imread(img_path)
+        frames.append(frame)
+    return frames
+
+
+def read_files_path(path):
+    file_paths = []
+    files = os.listdir(path)
+    for file in files:
+        full_path = os.path.join(path, file)  # join with the directory so isdir() checks the right location
+        if not os.path.isdir(full_path):
+            file_paths.append(full_path)
+    return file_paths
+
+
+def get_smoothened_boxes(boxes, t):
+    # average each box with a sliding window of t neighbours
+    for i in range(len(boxes)):
+        if i + t > len(boxes):
+            window = boxes[len(boxes) - t:]
+        else:
+            window = boxes[i: i + t]
+        boxes[i] = np.mean(window, axis=0)
+    return boxes
+
+
+def datagen_signal(frame, mel, face_det_results, img_size, wav2lip_batch_size=128):
+    img_batch, mel_batch, frame_batch, coord_batch = [], [], [], []
+
+    idx = 0
+    frame_to_save = frame.copy()
+    face, coord = face_det_results[idx].copy()
+    face = cv2.resize(face, (img_size, img_size))
+
+    for i, m in enumerate(mel):
+        img_batch.append(face)
+        mel_batch.append(m)
+        frame_batch.append(frame_to_save)
+        coord_batch.append(coord)
+
+        if len(img_batch) >= wav2lip_batch_size:
+            img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
+
+            img_masked = img_batch.copy()
+            img_masked[:, img_size // 2:] = 0
+            img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
+            mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
+
+            return img_batch, mel_batch, frame_batch, coord_batch
+
+    if len(img_batch) > 0:
+        img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
+
+        img_masked = img_batch.copy()
+        img_masked[:, img_size // 2:] = 0
+        img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
+        mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
+
+    return img_batch, mel_batch, frame_batch, coord_batch
+
+
+def face_detect(images, device):
+    detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
+                                            flip_input=False, device=device)
+    batch_size = 16
+
+    while 1:
+        predictions = []
+        try:
+            for i in range(0, len(images), batch_size):
+                predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
+        except RuntimeError:
+            if batch_size == 1:
+                raise RuntimeError(
+                    'Image too big to run face detection on GPU. Please use the --resize_factor argument')
+            batch_size //= 2
+            print('Recovering from OOM error; New batch size: {}'.format(batch_size))
+            continue
+        break
+
+    results = []
+    pad_y1, pad_y2, pad_x1, pad_x2 = [0, 10, 0, 0]
+    for rect, image in zip(predictions, images):
+        if rect is None:
+            cv2.imwrite('temp/faulty_frame.jpg', image)  # check this frame where the face was not detected.
+            raise ValueError('Face not detected! Ensure the video contains a face in all the frames.')
+
+        y1 = max(0, rect[1] - pad_y1)
+        y2 = min(image.shape[0], rect[3] + pad_y2)
+        x1 = max(0, rect[0] - pad_x1)
+        x2 = min(image.shape[1], rect[2] + pad_x2)
+        results.append([x1, y1, x2, y2])
+
+    boxes = np.array(results)
+    boxes = get_smoothened_boxes(boxes, t=5)
+    results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)]
+
+    del detector
+    return results
+
+
+def get_device():
+    return 'cuda' if torch.cuda.is_available() else 'cpu'
+
+
+def _load(checkpoint_path):
+    device = get_device()
+    if device == 'cuda':
+        checkpoint = torch.load(checkpoint_path)
+    else:
+        checkpoint = torch.load(checkpoint_path,
+                                map_location=lambda storage, loc: storage)
+    return checkpoint
+
+
+def load_model(path):
+    model = Wav2Lip()
+    print("Load checkpoint from: {}".format(path))
+    logging.info(f'Load checkpoint from {path}')
+    checkpoint = _load(path)
+    s = checkpoint["state_dict"]
+    new_s = {}
+    for k, v in s.items():
+        new_s[k.replace('module.', '')] = v
+    model.load_state_dict(new_s)
+
+    device = get_device()
+    model = model.to(device)
+    return model.eval()
+
+
+def load_avatar(path, img_size, device):
+    face_images_path = path
+    face_images_path = read_files_path(face_images_path)
+    full_list_cycle = read_images(face_images_path)
+
+    face_det_results = face_detect(full_list_cycle, device)
+
+    face_frames = []
+    coord_frames = []
+    for face, coord in face_det_results:
+        resized_crop_frame = cv2.resize(face, (img_size, img_size))
+        face_frames.append(resized_crop_frame)
+        coord_frames.append(coord)
+
+    return full_list_cycle, face_frames, coord_frames
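
mirror_index is what keeps the avatar loop seamless: instead of jumping from the last frame back to the first, it sweeps indices forward and then backward over the cycle. A short demonstration, plus a hedged end-to-end loading sketch (img_size=96 is an assumption matching the usual wav2lip face crop; the paths are the ones used elsewhere in this commit):

from utils import mirror_index, load_model, get_device, load_avatar

# ping-pong over a 4-frame cycle: 0,1,2,3 then 3,2,1,0, so no visual jump at the boundary
print([mirror_index(4, i) for i in range(8)])  # [0, 1, 2, 3, 3, 2, 1, 0]

# device = get_device()
# model = load_model(r'.\checkpoints\wav2lip.pth')           # checkpoint path from AudioInferenceHandler
# full, faces, coords = load_avatar('./face/', 96, device)   # img_size=96 assumed (wav2lip default)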