modify inter

This commit is contained in:
brige 2024-11-19 23:18:09 +08:00
parent d650a5d00e
commit 32e4444bb5
3 changed files with 9 additions and 6 deletions

View File

@ -59,7 +59,7 @@ class AudioInferenceHandler(AudioHandler):
super().on_message(message) super().on_message(message)
def __on_run(self): def __on_run(self):
wav2lip_path = os.path.join(current_file_path, '..', 'checkpoints', 'wav2lip.pth') wav2lip_path = os.path.join(current_file_path, '..', 'checkpoints', 'wav2lip_gan.pth')
logger.info(f'AudioInferenceHandler init, path:{wav2lip_path}') logger.info(f'AudioInferenceHandler init, path:{wav2lip_path}')
model = load_model(wav2lip_path) model = load_model(wav2lip_path)
logger.info("Model loaded") logger.info("Model loaded")

View File

@ -18,7 +18,7 @@ current_file_path = os.path.dirname(os.path.abspath(__file__))
class HumanContext: class HumanContext:
def __init__(self): def __init__(self):
self._fps = 50 # 20 ms per frame self._fps = 50 # 20 ms per frame
self._image_size = 96 self._image_size = 128
self._batch_size = 16 self._batch_size = 16
self._sample_rate = 16000 self._sample_rate = 16000
self._stride_left_size = 10 self._stride_left_size = 10
@ -37,9 +37,9 @@ class HumanContext:
print(f'device:{self._device}') print(f'device:{self._device}')
base_path = os.path.join(current_file_path, '..') base_path = os.path.join(current_file_path, '..')
logger.info(f'base path:{base_path}') logger.info(f'base path:{base_path}')
full_images, face_frames, coord_frames = load_avatar(base_path, self._image_size, self._device) # full_images, face_frames, coord_frames = load_avatar(base_path, self._image_size, self._device)
# full_images, face_frames, coord_frames = load_avatar_from_processed(base_path, full_images, face_frames, coord_frames = load_avatar_from_processed(base_path,
# 'wav2lip_avatar1') 'wav2lip_avatar2')
self._frame_list_cycle = full_images self._frame_list_cycle = full_images
self._face_list_cycle = face_frames self._face_list_cycle = face_frames
self._coord_list_cycle = coord_frames self._coord_list_cycle = coord_frames

View File

@ -15,6 +15,7 @@ class VideoRender(BaseRender):
def __init__(self, play_clock, context, human_render): def __init__(self, play_clock, context, human_render):
super().__init__(play_clock, context, 'Video') super().__init__(play_clock, context, 'Video')
self._human_render = human_render self._human_render = human_render
self.index = 0
def render(self, frame, ps): def render(self, frame, ps):
res_frame, idx, type_ = frame res_frame, idx, type_ = frame
@ -30,7 +31,9 @@ class VideoRender(BaseRender):
except: except:
print('resize error') print('resize error')
return return
combine_frame[y1:y2, x1:x2, :3] = res_frame cv2.imwrite(f'res_frame_{ self.index }.png', res_frame)
self.index = self.index + 1
combine_frame[y1:y2, x1:x2] = res_frame
image = combine_frame image = combine_frame
# image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)