diff --git a/Human.py b/Human.py
index ca78f9e..4f35eda 100644
--- a/Human.py
+++ b/Human.py
@@ -1,11 +1,131 @@
 #encoding = utf8
 import logging
 import multiprocessing as mp
+import glob
+import os
+import queue
+import time
 
+import numpy as np
+
+from models import Wav2Lip
 from tts.Chunk2Mal import Chunk2Mal
+import torch
+import cv2
+from tqdm import tqdm
 
 logger = logging.getLogger(__name__)
 
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+logger.info(f'Using {device} for inference.')
+
+
+def _load(checkpoint_path):
+    if device == 'cuda':
+        checkpoint = torch.load(checkpoint_path)
+    else:
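+        # remap CUDA-saved tensors to CPU storage when no GPU is available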
+        checkpoint = torch.load(checkpoint_path,
+                                map_location=lambda storage, loc: storage)
+    return checkpoint
+
+
+def load_model(path):
+    model = Wav2Lip()
+    print("Load checkpoint from: {}".format(path))
+    checkpoint = _load(path)
+    s = checkpoint["state_dict"]
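+    # checkpoints trained with nn.DataParallel prefix keys with 'module.';
+    # strip the prefix so the weights load into a bare Wav2Lip module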
+    new_s = {}
+    for k, v in s.items():
+        new_s[k.replace('module.', '')] = v
+    model.load_state_dict(new_s)
+    model = model.to(device)
+    return model.eval()
+
+
+def read_images(img_list):
+    frames = []
+    logger.info('reading images...')
+    for img_path in tqdm(img_list):
+        frame = cv2.imread(img_path)
+        if frame is None:
+            logger.warning(f'failed to read image: {img_path}')
+            continue
+        frames.append(frame)
+    return frames
+
+
+def __mirror_index(size, index):
+    # ping-pong over the face frame list: forward on even passes, backward on odd
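+    # e.g. size=4 maps indices 0..7 to 0, 1, 2, 3, 3, 2, 1, 0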
+    turn = index // size
+    res = index % size
+    if turn % 2 == 0:
+        return res
+    else:
+        return size - res - 1
+
+
+#  python.exe .\inference.py --checkpoint_path .\checkpoints\wav2lip.pth --face
+#  .\face\img00016.jpg --audio .\audio\audio1.wav
+def inference(render_event, batch_size, face_images_path, audio_feat_queue, audio_out_queue, res_frame_queue):
+    model = load_model(r'.\checkpoints\wav2lip.pth')
+    # face_images_path points at a directory of frame images (the CLI example
+    # above uses .jpg files); expand it to a sorted file list before loading
+    face_list_cycle = read_images(sorted(glob.glob(os.path.join(face_images_path, '*.jpg'))))
+    length = len(face_list_cycle)
+    logger.info(f'face images length: {length}')
+
+    index = 0
+    count = 0
+    count_time = 0
+    logger.info('start inference')
+    # run until Human.on_destroy() sets render_event
+    while not render_event.is_set():
+        try:
+            mel_batch = audio_feat_queue.get(block=True, timeout=1)
+        except queue.Empty:
+            continue
+
+        audio_frames = []
+        is_all_silence = True
+        for _ in range(batch_size * 2):
+            # audio frames travel on audio_out_queue; audio_feat_queue carries
+            # only the mel batches consumed above
+            frame, audio_type = audio_out_queue.get()
+            audio_frames.append((frame, audio_type))
+
+            if audio_type == 0:  # type 0 is assumed to mark real (non-silent) audio
+                is_all_silence = False
+
+        if is_all_silence:
+            for i in range(batch_size):
+                res_frame_queue.put((None, __mirror_index(length, index), audio_frames[i*2:i*2+2]))
+                index = index + 1
+        else:
+            t = time.perf_counter()
+            image_batch = []
+            for i in range(batch_size):
+                idx = __mirror_index(length, index + i)
+                face = face_list_cycle[idx]
+                image_batch.append(face)
+            image_batch, mel_batch = np.asarray(image_batch), np.asarray(mel_batch)
+
+            image_masked = image_batch.copy()
+            # zero the lower half (mouth region) of every face in the batch
+            image_masked[:, image_masked.shape[1] // 2:] = 0
+
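+            # Wav2Lip's expected 6-channel input: the half-masked face plus the
+            # unmasked reference, concatenated on the channel axis, scaled to [0, 1]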
+            image_batch = np.concatenate((image_masked, image_batch), axis=3) / 255.
+            mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
+
+            image_batch = torch.FloatTensor(np.transpose(image_batch, (0, 3, 1, 2))).to(device)
+            mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
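+            # shapes after transpose: image_batch (N, 6, H, W); mel_batch
+            # (N, 1, 80, 16), assuming Wav2Lip's usual mel settings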
+
+            with torch.no_grad():
+                pred = model(mel_batch, image_batch)
+            pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
+
+            count_time += (time.perf_counter() - t)
+            count += batch_size
+            if count >= 100:
+                logger.info(f"------actual avg infer fps:{count/count_time:.4f}")
+                count = 0
+                count_time = 0
+
+            for i, res_frame in enumerate(pred):
+                res_frame_queue.put((res_frame, __mirror_index(length, index), audio_frames[i*2 : i*2+2]))
+                index = index + 1
+
+    logger.info('finish inference')
+
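+# Note: the consumer side of res_frame_queue is not part of this diff; each
+# entry is (pred_frame_or_None, mirrored_face_index, [audio_frame, audio_frame]);
+# a None frame marks an all-silence batch where the idle face can be reused.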
 
 class Human:
     def __init__(self):
@@ -18,6 +138,14 @@ class Human:
         self._stride_left_size = 10
         self._stride_right_size = 10
         self._feat_queue = mp.Queue(2)
+        self._output_queue = mp.Queue()
+        self._res_frame_queue = mp.Queue(self._batch_size * 2)
+
+        self.face_images_path = r'.\face'
+        self.render_event = mp.Event()
+        # keep a handle to the worker so on_destroy can join it; setting
+        # render_event tells the inference loop to exit
+        self._infer_process = mp.Process(target=inference,
+                                         args=(self.render_event, self._batch_size, self.face_images_path,
+                                               self._feat_queue, self._output_queue, self._res_frame_queue))
+        self._infer_process.start()
 
     def get_fps(self):
         return self._fps
@@ -35,6 +163,8 @@ class Human:
         return self._stride_right_size
 
     def on_destroy(self):
+        self.render_event.set()  # signal the inference process to exit
+        self._infer_process.join()
+
         self._chunk_2_mal.stop()
 
         if self._tts is not None:
@@ -60,4 +190,6 @@ class Human:
         self._chunk_2_mal.push_chunk(chunk)
 
     def push_feat_queue(self, mel_chunks):
+        print("21")
         self._feat_queue.put(mel_chunks)
+        print("22")
diff --git a/tts/Chunk2Mal.py b/tts/Chunk2Mal.py
index 004d682..3d4e8a6 100644
--- a/tts/Chunk2Mal.py
+++ b/tts/Chunk2Mal.py
@@ -22,6 +22,7 @@ class Chunk2Mal:
             try:
                 chunk, type = self.pull_chunk()
                 self._chunks.append(chunk)
+                print("1")
             except queue.Empty:
                 continue
 
@@ -38,6 +39,7 @@ class Chunk2Mal:
             mel_chunks = []
             while i < (len(self._chunks) - self._human.get_stride_left_size()
                        - self._human.get_stride_right_size()) / 2:
+                print("14")
                 start_idx = int(left + i * mel_idx_multiplier)
                 # print(start_idx)
                 if start_idx + mel_step_size > len(mel[0]):
@@ -45,10 +47,13 @@ class Chunk2Mal:
                 else:
                     mel_chunks.append(mel[:, start_idx: start_idx + mel_step_size])
                 i += 1
+                print("13")
             self._human.push_feat_queue(mel_chunks)
+            print("15")
 
             # discard the old part to save memory
             self._chunks = self._chunks[-(self._human.get_stride_left_size() + self._human.get_stride_right_size()):]
+            print("12")
 
         logging.info('chunk2mal exit')
 
@@ -65,7 +70,8 @@ class Chunk2Mal:
             return
 
         self._exit_event.set()
-        self._thread.join()
+        if self._thread.is_alive():
+            self._thread.join()
         logging.info('chunk2mal stop')
 
     def push_chunk(self, chunk):
@@ -73,7 +79,7 @@ class Chunk2Mal:
 
     def pull_chunk(self):
         try:
-            chunk = self._audio_chunk_queue.get(block=True, timeout=1.0)
+            chunk = self._audio_chunk_queue.get(block=True, timeout=1)
             type = 1
         except queue.Empty:
             chunk = np.zeros(self._human.get_chunk(), dtype=np.float32)
diff --git a/ui.py b/ui.py
index b41b4c0..503d018 100644
--- a/ui.py
+++ b/ui.py
@@ -120,18 +120,6 @@ def config_logging(file_name: str, console_level: int=logging.INFO, file_level:
 if __name__ == "__main__":
     # logging.basicConfig(filename='./logs/info.log', level=logging.INFO)
     config_logging('./logs/info.log', logging.INFO, logging.INFO)
-    # logger = logging.getLogger('manager')
-    # # 输出到控制台, 级别为DEBUG
-    # console = logging.StreamHandler()
-    # console.setLevel(logging.DEBUG)
-    # logger.addHandler(console)
-    #
-    # # 输出到文件, 级别为INFO, 文件按大小切分
-    # filelog = logging.handlers.RotatingFileHandler(filename='./logs/info.log', level=logging.INFO,
-    #                                                maxBytes=1024 * 1024, backupCount=5)
-    # filelog.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
-    # logger.setLevel(logging.INFO)
-    # logger.addHandler(filelog)
     logger.info('------------start------------')
     app = App()
     app.mainloop()