modify video render
commit d0c8ddb18b
parent c01ec04cd3
@@ -3,4 +3,5 @@
 from .human_context import HumanContext
 from .audio_mal_handler import AudioMalHandler
 from .audio_inference_handler import AudioInferenceHandler
+from .audio_inference_onnx_handler import AudioInferenceOnnxHandler
 from .human_render import HumanRender
human/audio_inference_onnx_handler.py | 207 lines (new file)
@@ -0,0 +1,207 @@
+#encoding = utf8
+import logging
+import os
+import queue
+import time
+from threading import Event, Thread
+
+from gfpgan import GFPGANer
+from eventbus import EventBus
+from human_handler import AudioHandler
+from utils import mirror_index, get_device, SyncQueue
+
+logger = logging.getLogger(__name__)
+current_file_path = os.path.dirname(os.path.abspath(__file__))
+
+
+def load_gfpgan_model(model_path):
+    logger.info(f'load_gfpgan_model, path:{model_path}')
+    model = GFPGANer(
+        model_path=model_path,
+        upscale=1,
+        arch='clean',
+        channel_multiplier=2,
+        bg_upsampler=None,
+    )
+    return model
+
+
+def load_model(model_path):
+    # Builds an onnxruntime session. Note: this local definition replaces the
+    # load_model helper in utils, which it would otherwise shadow if imported.
+    import onnxruntime as ort
+
+    sess_opt = ort.SessionOptions()
+    sess_opt.intra_op_num_threads = 8
+    sess = ort.InferenceSession(model_path, sess_options=sess_opt,
+                                providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+    return sess
+
+
+class AudioInferenceOnnxHandler(AudioHandler):
+    def __init__(self, context, handler):
+        super().__init__(context, handler)
+
+        EventBus().register('stop', self._on_stop)
+        EventBus().register('clear_cache', self.on_clear_cache)
+        self._mal_queue = SyncQueue(1, "AudioInferenceHandler_Mel")
+        self._audio_queue = SyncQueue(context.batch_size * 2, "AudioInferenceHandler_Audio")
+
+        self._is_running = True
+        self._exit_event = Event()
+        self._run_thread = Thread(target=self.__on_run, name="AudioInferenceHandlerThread")
+        self._exit_event.set()
+        self._run_thread.start()
+
+        logger.info("AudioInferenceHandler init")
+
+    def __del__(self):
+        EventBus().unregister('stop', self._on_stop)
+        EventBus().unregister('clear_cache', self.on_clear_cache)
+
+    def _on_stop(self, *args, **kwargs):
+        self.stop()
+
+    def on_clear_cache(self, *args, **kwargs):
+        self._mal_queue.clear()
+        self._audio_queue.clear()
+
+    def on_handle(self, stream, type_):
+        if not self._is_running:
+            return
+
+        if type_ == 1:
+            self._mal_queue.put(stream)
+        elif type_ == 0:
+            self._audio_queue.put(stream)
+
+    def on_message(self, message):
+        super().on_message(message)
+
+    def __on_run(self):
+        model_path = os.path.join(current_file_path, '..', 'checkpoints', 'weights', 'wav2lip')
+        model_a_path = os.path.join(model_path, 'model_a_general.onnx')
+        logger.info(f'AudioInferenceHandler model_a_path, path:{model_a_path}')
+        model_g_path = os.path.join(model_path, 'model_g_general.onnx')
+        logger.info(f'AudioInferenceHandler model_g_path, path:{model_g_path}')
+
+        model_a = load_model(model_a_path)
+        model_g = load_model(model_g_path)
+        logger.info("Model loaded")
+
+        # GFPGAN is loaded here but not yet applied anywhere in this loop.
+        gfpgan_model_path = os.path.join(current_file_path, '..', 'checkpoints', 'gfpgan', 'weights', 'GFPGANv1.4.pth')
+        gfpgan_model = load_gfpgan_model(gfpgan_model_path)
+
+        face_list_cycle = self._context.face_list_cycle
+
+        length = len(face_list_cycle)
+        index = 0
+        count = 0
+        count_time = 0
+        logger.info('start inference')
+
+        device = get_device()
+        logger.info(f'use device:{device}')
+
+        while self._is_running:
+            if self._exit_event.is_set():
+                start_time = time.perf_counter()
+                batch_size = self._context.batch_size
+                try:
+                    mel_batch = self._mal_queue.get(timeout=0.02)
+                except queue.Empty:
+                    continue
+
+                # Collect two audio frames per output frame; the batch is
+                # silent only if every frame carries the silence type (1).
+                is_all_silence = True
+                audio_frames = []
+                current_text = ''
+                for _ in range(batch_size * 2):
+                    frame, type_ = self._audio_queue.get()
+                    current_text = frame[1]
+                    audio_frames.append((frame, type_))
+                    if type_ == 0:
+                        is_all_silence = False
+
+                if not self._is_running:
+                    print('AudioInferenceHandler not running')
+                    break
+
+                if is_all_silence:
+                    # No speech: pass through the idle face cycle untouched.
+                    for i in range(batch_size):
+                        if not self._is_running:
+                            break
+                        self.on_next_handle((None, mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]), 0)
+                        index = index + 1
+                else:
+                    logger.info(f'infer======= {current_text}')
+                    t = time.perf_counter()
+                    img_batch = []
+                    for i in range(len(mel_batch)):
+                        idx = mirror_index(length, index + i)
+                        face = face_list_cycle[idx]
+                        img_batch.append(face)
+
+                    # Model A: mel window batch -> audio embeddings.
+                    onnx_input = {'audio_seqs__0': mel_batch}
+                    onnx_names = [output.name for output in model_a.get_outputs()]
+                    embeddings = model_a.run(onnx_names, onnx_input)[0]
+
+                    # Model G: audio embeddings + face batch -> predicted frames.
+                    onnx_input = {"audio_embedings__0": embeddings, "img_seqs__1": img_batch}
+                    onnx_names = [output.name for output in model_g.get_outputs()]
+                    pred = model_g.run(onnx_names, onnx_input)[0]
+
+                    count_time += (time.perf_counter() - t)
+                    count += batch_size
+
+                    if count >= 100:
+                        logger.info(f"------actual avg infer fps:{count / count_time:.4f}")
+                        count = 0
+                        count_time = 0
+
+                    # enumerate so each predicted frame pairs with its own two
+                    # audio frames (a bare `for res_frame in zip(pred)` would
+                    # reuse the stale i left over from the img_batch loop).
+                    for i, res_frame in enumerate(pred):
+                        if not self._is_running:
+                            break
+                        self.on_next_handle((res_frame, mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]), 0)
+                        index = index + 1
+                    logger.info(f'total batch time: {time.perf_counter() - start_time}')
+            else:
+                time.sleep(1)
+                break
+        logger.info('AudioInferenceHandler inference processor stop')
+
+    def stop(self):
+        logger.info('AudioInferenceHandler stop')
+        self._is_running = False
+        self._exit_event.clear()
+        if self._run_thread.is_alive():
+            logger.info('AudioInferenceHandler stop join')
+            self._run_thread.join()
+        logger.info('AudioInferenceHandler stop exit')
+
+    def pause_talk(self):
+        print('AudioInferenceHandler pause_talk', self._audio_queue.size(), self._mal_queue.size())
+        self._audio_queue.clear()
+        self._mal_queue.clear()
@@ -4,6 +4,7 @@ import os
 
 from asr import SherpaNcnnAsr
 from eventbus import EventBus
+from .audio_inference_onnx_handler import AudioInferenceOnnxHandler
 from .audio_inference_handler import AudioInferenceHandler
 from .audio_mal_handler import AudioMalHandler
 from .human_render import HumanRender
@@ -99,6 +100,14 @@ class HumanContext:
     def coord_list_cycle(self):
         return self._coord_list_cycle
 
+    @property
+    def align_frames(self):
+        return self._align_frames
+
+    @property
+    def inv_m_frames(self):
+        return self._inv_m_frames
+
     @property
     def render_handler(self):
         return self._render_handler
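
Both the inference handler above and the render loop below walk these per-frame cycles with mirror_index from utils. Its implementation is not part of this diff; a plausible sketch, assuming the usual ping-pong traversal so the face cycle loops back and forth without a visible jump at the wrap point:

def mirror_index(size, index):
    # 0, 1, ..., size-1, size-1, ..., 1, 0, 0, 1, ...  (ping-pong)
    turn = index // size
    res = index % size
    if turn % 2 == 0:
        return res
    return size - res - 1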
@@ -10,6 +10,17 @@ import numpy as np
 from .base_render import BaseRender
 
 
+def img_warp_back_inv_m(img, img_to, inv_m):
+    h_up, w_up, c = img_to.shape
+
+    mask = np.ones_like(img).astype(np.float32)
+    inv_mask = cv2.warpAffine(mask, inv_m, (w_up, h_up))
+    inv_img = cv2.warpAffine(img, inv_m, (w_up, h_up))
+
+    img_to[inv_mask == 1] = inv_img[inv_mask == 1]
+    return img_to
+
+
 class VideoRender(BaseRender):
     def __init__(self, play_clock, context, human_render):
         super().__init__(play_clock, context, 'Video')
@@ -24,18 +35,22 @@ class VideoRender(BaseRender):
             else:
                 bbox = self._context.coord_list_cycle[idx]
                 combine_frame = copy.deepcopy(self._context.frame_list_cycle[idx])
+                af = copy.deepcopy(self._context.align_frames[idx])
+                inv_m = self._context.inv_m_frames[idx]
                 y1, y2, x1, x2 = bbox
                 try:
                     res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
-                except:
-                    print('resize error')
+                    af[y1:y2, x1:x2] = res_frame
+                    combine_frame = img_warp_back_inv_m(af, combine_frame, inv_m)
+                except Exception as e:
+                    print('resize error', e)
                     return
                 # cv2.imwrite(f'./images/res_frame_{ self.index }.png', res_frame)
-                combine_frame[y1:y2, x1:x2] = res_frame
+                # combine_frame[y1:y2, x1:x2] = res_frame
                 # cv2.imwrite(f'/combine_frame_{self.index}.png', combine_frame)
                 # self.index = self.index + 1
 
                 image = combine_frame
-                # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                 if self._human_render is not None:
                     self._human_render.put_image(image)
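
For context, a self-contained sketch of the warp-back compositing the hunk above switches to (the sizes and the affine matrix here are made up for illustration): the edited, aligned face crop is warped back into full-frame coordinates with the inverse alignment matrix, and a warped all-ones mask selects exactly the destination pixels to overwrite.

import cv2
import numpy as np

frame = np.zeros((720, 1280, 3), dtype=np.uint8)    # full background frame
crop = np.full((256, 256, 3), 255, dtype=np.uint8)  # edited, aligned face crop

m = cv2.getRotationMatrix2D((128, 128), 15, 1.0)    # hypothetical alignment
inv_m = cv2.invertAffineTransform(m)                # maps crop back into frame

h, w = frame.shape[:2]
mask = np.ones_like(crop, dtype=np.float32)
inv_mask = cv2.warpAffine(mask, inv_m, (w, h))      # where the crop lands
inv_img = cv2.warpAffine(crop, inv_m, (w, h))       # crop in frame coordinates
frame[inv_mask == 1] = inv_img[inv_mask == 1]       # paste inside the mask only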