# -*- coding: utf-8 -*-
"""Wav2Lip-based real-time talking head: turns incoming audio chunks into lip-synced face frames."""
import io
import logging
import multiprocessing as mp
import platform
import subprocess
import queue
import threading
import time

import numpy as np
import pyaudio
import audio
import face_detection
import utils
from infer import Infer
from models import Wav2Lip
from tts.Chunk2Mal import Chunk2Mal
import torch
import cv2
from tqdm import tqdm
from queue import Queue

from tts.EdgeTTS import EdgeTTS
from tts.TTSBase import TTSBase

device = 'cuda' if torch.cuda.is_available() else 'cpu'


def _load(checkpoint_path):
    if device == 'cuda':
        checkpoint = torch.load(checkpoint_path)
    else:
        checkpoint = torch.load(checkpoint_path,
                                map_location=lambda storage, loc: storage)
    return checkpoint


def load_model(path):
    model = Wav2Lip()
    print("Load checkpoint from: {}".format(path))
    logging.info(f'Load checkpoint from {path}')
    checkpoint = _load(path)
    s = checkpoint["state_dict"]
    # Strip the 'module.' prefix left over from DataParallel training.
    new_s = {}
    for k, v in s.items():
        new_s[k.replace('module.', '')] = v
    model.load_state_dict(new_s)
    model = model.to(device)
    return model.eval()


def read_images(img_list):
    frames = []
    print('reading images...')
    for img_path in tqdm(img_list):
        print(f'read image path:{img_path}')
        frame = cv2.imread(img_path)
        frames.append(frame)
    return frames


def __mirror_index(size, index):
    # Ping-pong mapping: walk the face-image list forward, then backward, so cycling
    # through it never produces a visible jump when the index wraps around.
    turn = index // size
    res = index % size
    if turn % 2 == 0:
        return res
    else:
        return size - res - 1
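

# Worked example (illustration only, not used by the code): with a cycle of 4 face
# images the mapped indices go forward and then backward,
#
#   [__mirror_index(4, i) for i in range(10)]  ->  [0, 1, 2, 3, 3, 2, 1, 0, 0, 1]
#
# which is why the render loops below can simply keep incrementing `index`.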


# python.exe .\inference.py --checkpoint_path .\checkpoints\wav2lip.pth --face
# .\face\img00016.jpg --audio .\audio\audio1.wav
def inference(render_event, batch_size, face_images_path, audio_feat_queue, audio_out_queue, res_frame_queue):
    logging.info(f'Using {device} for inference.')
    print(f'Using {device} for inference.')
    print(f'face_images_path: {face_images_path}')

    model = load_model(r'.\checkpoints\wav2lip.pth')
    face_list_cycle = read_images(face_images_path)
    face_images_length = len(face_list_cycle)
    logging.info(f'face images length: {face_images_length}')
    print(f'face images length: {face_images_length}')

    length = len(face_list_cycle)
    index = 0
    count = 0
    count_time = 0
    logging.info('start inference')
    print(f'start inference: {render_event.is_set()}')
    while render_event.is_set():
        mel_batch = []
        try:
            mel_batch = audio_feat_queue.get(block=True, timeout=1)
        except queue.Empty:
            continue

        # Collect two audio frames per video frame and check whether the whole
        # batch is silence; silent batches skip the model and reuse the raw face.
        audio_frames = []
        is_all_silence = True
        for _ in range(batch_size * 2):
            frame, type_ = audio_out_queue.get()
            audio_frames.append((frame, type_))
            if type_ == 0:
                is_all_silence = False

        print(f'is_all_silence {is_all_silence}')
        if is_all_silence:
            for i in range(batch_size):
                res_frame_queue.put((None, __mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]))
                index = index + 1
        else:
            t = time.perf_counter()
            image_batch = []
            for i in range(batch_size):
                idx = __mirror_index(length, index + i)
                face = face_list_cycle[idx]
                image_batch.append(face)
            image_batch, mel_batch = np.asarray(image_batch), np.asarray(mel_batch)

            # Zero out the lower half of each face; the model inpaints the mouth region.
            image_masked = image_batch.copy()
            image_masked[:, face.shape[0] // 2:] = 0
            image_batch = np.concatenate((image_masked, image_batch), axis=3) / 255.
            mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

            image_batch = torch.FloatTensor(np.transpose(image_batch, (0, 3, 1, 2))).to(device)
            mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)

            with torch.no_grad():
                pred = model(mel_batch, image_batch)

            pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.

            count_time += (time.perf_counter() - t)
            count += batch_size
            if count >= 100:
                logging.info(f"------actual avg infer fps:{count / count_time:.4f}")
                count = 0
                count_time = 0

            for i, res_frame in enumerate(pred):
                res_frame_queue.put((res_frame, __mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]))
                index = index + 1

    logging.info('finish inference')


def get_smoothened_boxes(boxes, T):
    # Average each detection box with the following T boxes to suppress jitter.
    for i in range(len(boxes)):
        if i + T > len(boxes):
            window = boxes[len(boxes) - T:]
        else:
            window = boxes[i: i + T]
        boxes[i] = np.mean(window, axis=0)
    return boxes


def face_detect(images):
    detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
                                            flip_input=False, device=device)

    batch_size = 16

    # Retry with a smaller batch size if face detection runs out of GPU memory.
    while True:
        predictions = []
        try:
            for i in range(0, len(images), batch_size):
                predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
        except RuntimeError:
            if batch_size == 1:
                raise RuntimeError(
                    'Image too big to run face detection on GPU. Please use the --resize_factor argument')
            batch_size //= 2
            print('Recovering from OOM error; New batch size: {}'.format(batch_size))
            continue
        break

    results = []
    pady1, pady2, padx1, padx2 = [0, 10, 0, 0]
    for rect, image in zip(predictions, images):
        if rect is None:
            cv2.imwrite('temp/faulty_frame.jpg', image)  # check this frame where the face was not detected.
            raise ValueError('Face not detected! Ensure the video contains a face in all the frames.')

        y1 = max(0, rect[1] - pady1)
        y2 = min(image.shape[0], rect[3] + pady2)
        x1 = max(0, rect[0] - padx1)
        x2 = min(image.shape[1], rect[2] + padx2)

        results.append([x1, y1, x2, y2])

    boxes = np.array(results)
    boxes = get_smoothened_boxes(boxes, T=5)  # smoothing is always on in this script
    results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)]

    del detector
    return results


img_size = 96
wav2lip_batch_size = 128


def datagen(frames, mels):
    img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []

    face_det_results = face_detect(frames)  # BGR2RGB for CNN face detection

    # A single static face (index 0) is paired with every mel chunk in the queue.
    for i in range(mels.qsize()):
        idx = 0
        frame_to_save = frames[__mirror_index(1, i)].copy()
        face, coords = face_det_results[idx].copy()

        face = cv2.resize(face, (img_size, img_size))
        m = mels.get()

        img_batch.append(face)
        mel_batch.append(m)
        frame_batch.append(frame_to_save)
        coords_batch.append(coords)

        if len(img_batch) >= wav2lip_batch_size:
            img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)

            img_masked = img_batch.copy()
            img_masked[:, img_size // 2:] = 0

            img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
            mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

            yield img_batch, mel_batch, frame_batch, coords_batch
            img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []

    if len(img_batch) > 0:
        img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)

        img_masked = img_batch.copy()
        img_masked[:, img_size // 2:] = 0

        img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
        mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

        yield img_batch, mel_batch, frame_batch, coords_batch
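

# Shape note (illustrative, following the standard Wav2Lip layout): each batch
# produced above is img_batch of shape (N, 96, 96, 6) - the first three channels are
# the face with its lower half zeroed, the last three the unmasked reference - and
# mel_batch of shape (N, 80, 16, 1). Callers transpose both to NCHW before the
# forward pass, e.g. (names here are only for the example):
#
#   img_t = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)  # (N, 6, 96, 96)
#   mel_t = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)  # (N, 1, 80, 16)
#   pred = model(mel_t, img_t)  # (N, 3, 96, 96) in [0, 1]; callers scale back to 0-255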


def datagen_signal(frame, mel, face_det_results):
    # Single-image variant of datagen: pair one detected face with each mel chunk.
    img_batch, mel_batch, frame_batch, coord_batch = [], [], [], []

    idx = 0
    frame_to_save = frame.copy()
    face, coord = face_det_results[idx].copy()

    face = cv2.resize(face, (img_size, img_size))

    for i, m in enumerate(mel):
        img_batch.append(face)
        mel_batch.append(m)
        frame_batch.append(frame_to_save)
        coord_batch.append(coord)

        if len(img_batch) >= wav2lip_batch_size:
            img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
            img_masked = img_batch.copy()
            img_masked[:, img_size // 2:] = 0

            img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
            mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

            return img_batch, mel_batch, frame_batch, coord_batch

    if len(img_batch) > 0:
        img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
        img_masked = img_batch.copy()
        img_masked[:, img_size // 2:] = 0

        img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
        mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

        return img_batch, mel_batch, frame_batch, coord_batch


def load_audio_from_bytes(byte_data):
    # Load audio data from an in-memory byte string of an audio file.
    # Wrap the bytes in a BytesIO so the loader can read them like a file;
    # adjust the arguments to match the audio library actually in use.
    with io.BytesIO(byte_data) as b:
        wav = audio.load_wav(b, 16000)
    return wav


class Human:
    def __init__(self):
        self._fps = 25  # 40 ms per frame
        self._batch_size = 16
        self._sample_rate = 16000
        self._stride_left_size = 10
        self._stride_right_size = 10
        self._feat_queue = mp.Queue(2)
        self._output_queue = mp.Queue()
        self._res_frame_queue = mp.Queue(self._batch_size * 2)

        self._chunk_2_mal = Chunk2Mal(self)
        self._tts = TTSBase(self)
        self._infer = Infer(self)

        self.mel_chunks_queue_ = Queue()
        self.audio_chunks_queue_ = Queue()
        self._test_image_queue = Queue()
        # self._thread = None
        thread = threading.Thread(target=self.test)
        thread.start()
        # self.test()
        # self.play_pcm()

        # face_images_path = r'./face/'
        # self._face_image_paths = utils.read_files_path(face_images_path)
        # print(self._face_image_paths)
        # self.render_event = mp.Event()
        # mp.Process(target=inference, args=(self.render_event, self._batch_size, self._face_image_paths,
        #                                    self._feat_queue, self._output_queue, self._res_frame_queue,
        #                                    )).start()
        # self.render_event.set()

    # def play_pcm(self):
    #     p = pyaudio.PyAudio()
    #     stream = p.open(format=p.get_format_from_width(2), channels=1, rate=16000, output=True)
    #     file1 = r'./audio/en_weather.pcm'
    #
    #     # Write the PCM data straight into the PyAudio output stream.
    #     with open(file1, "rb") as f:
    #         stream.write(f.read())
    #
    #     stream.stop_stream()
    #     stream.close()
    #     p.terminate()

    def inter(self, model, chunks, face_list_cycle, face_det_results, out, j):
        # Turn a group of raw audio chunks into lip-synced frames: compute the mel
        # spectrogram, slice it into per-frame windows, run the model, then paste the
        # predicted mouth region back into the full frame and emit it.
        inputs = np.concatenate(chunks)  # [5 * chunk]
        mel = audio.melspectrogram(inputs)
        print("inter", len(mel[0]))
        if np.isnan(mel.reshape(-1)).sum() > 0:
            raise ValueError(
                'Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again')

        mel_step_size = 16
        print('fps:', self._fps)
        mel_idx_multiplier = 80. / self._fps  # mel frames advanced per video frame
        print('mel_idx_multiplier:', mel_idx_multiplier)
        i = 0
        mel_chunks = []
        while True:
            start_idx = int(i * mel_idx_multiplier)
            if start_idx + mel_step_size > len(mel[0]):
                mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:])
                # self.mel_chunks_queue_.put(mel[:, len(mel[0]) - mel_step_size:])
                break
            mel_chunks.append(mel[:, start_idx: start_idx + mel_step_size])
            # self.mel_chunks_queue_.put(mel[:, start_idx: start_idx + mel_step_size])
            i += 1
        self.mel_chunks_queue_.put(mel_chunks)

        while not self.mel_chunks_queue_.empty():
            print("self.mel_chunks_queue_ len:", self.mel_chunks_queue_.qsize())
            m = self.mel_chunks_queue_.get()
            # mel_batch = np.reshape(m, [len(m), mel_batch.shape[1], mel_batch.shape[2], 1])
            img_batch, mel_batch, frames, coords = datagen_signal(face_list_cycle[0], m, face_det_results)

            img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
            mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)

            with torch.no_grad():
                pred = model(mel_batch, img_batch)

            pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.

            for p, f, c in zip(pred, frames, coords):
                y1, y2, x1, x2 = c
                p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1))

                f[y1:y2, x1:x2] = p
                cv2.imwrite(f'temp/images/{j}.jpg', p)
                j = j + 1
                p = cv2.cvtColor(f, cv2.COLOR_BGR2RGB)
                self._test_image_queue.put(p)
                out.write(f)

        return j
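
    # Timing sketch for inter() (illustrative; assumes the standard Wav2Lip mel
    # settings of 16 kHz audio and 80 mel frames per second). At self._fps = 25,
    # mel_idx_multiplier = 80 / 25 = 3.2, so each successive video frame advances the
    # mel window by 3.2 mel frames, and each window of mel_step_size = 16 mel frames
    # covers roughly 0.2 s of audio context around that frame.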

    def test(self):
        # Offline self-test: lip-sync the face images in ./face/ to a wav file on
        # disk, write the frames to an .avi and mux the audio back in with ffmpeg.
        batch_size = 128
        print('batch_size:', batch_size, ' mel_chunks len:', self.mel_chunks_queue_.qsize())

        face_images_path = r'./face/'
        face_images_path = utils.read_files_path(face_images_path)
        face_list_cycle = read_images(face_images_path)
        face_images_length = len(face_list_cycle)
        logging.info(f'face images length: {face_images_length}')
        print(f'face images length: {face_images_length}')

        model = load_model(r'.\checkpoints\wav2lip.pth')
        print("Model loaded")

        frame_h, frame_w = face_list_cycle[0].shape[:-1]
        out = cv2.VideoWriter('temp/resul_tttt.avi',
                              cv2.VideoWriter_fourcc(*'DIVX'), 25, (frame_w, frame_h))

        face_det_results = face_detect(face_list_cycle)

        audio_path = r'./temp/audio/chunk_0.wav'
        stream = audio.load_wav(audio_path, 16000)
        stream_len = stream.shape[0]
        print('wav length:', stream_len)
        _audio_chunk_queue = queue.Queue()
        index = 0
        chunk_len = 640  # 40 ms of 16 kHz audio, i.e. one video frame at 25 fps
        print('chunk_len:', chunk_len)
        while stream_len >= chunk_len:
            audio_chunk = stream[index:index + chunk_len]
            _audio_chunk_queue.put(audio_chunk)
            stream_len -= chunk_len
            index += chunk_len
        if stream_len > 0:
            audio_chunk = stream[index:index + stream_len]
            _audio_chunk_queue.put(audio_chunk)
            index += stream_len
            stream_len = 0

        print('_audio_chunk_queue:', _audio_chunk_queue.qsize())

        j = 0
        while not _audio_chunk_queue.empty():
            chunks = []
            length = min(64, _audio_chunk_queue.qsize())
            for i in range(length):
                chunks.append(_audio_chunk_queue.get())

            j = self.inter(model, chunks, face_list_cycle, face_det_results, out, j)

        out.release()
        command = 'ffmpeg -y -i {} -i {} -strict -2 -q:v 1 {}'.format(audio_path, 'temp/resul_tttt.avi',
                                                                      'temp/resul_tttt.mp4')
        subprocess.call(command, shell=platform.system() != 'Windows')

        # gen = datagen(face_list_cycle, self.mel_chunks_queue_)
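
    # Chunking sketch for the loop above (illustrative): the wav file is split into
    # 640-sample chunks (640 / 16000 Hz = 40 ms, exactly one video frame at 25 fps)
    # and handed to inter() in groups of up to 64 chunks, i.e. at most
    # 64 x 640 = 40,960 samples or about 2.56 s of audio per call.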

    def get_fps(self):
        return self._fps

    def get_batch_size(self):
        return self._batch_size

    def get_audio_sample_rate(self):
        return self._sample_rate

    def get_stride_left_size(self):
        return self._stride_left_size

    def get_stride_right_size(self):
        return self._stride_right_size

    def on_destroy(self):
        # self.render_event.clear()
        # self._chunk_2_mal.stop()
        # if self._tts is not None:
        #     self._tts.stop()
        logging.info('human destroy')

    def read(self, txt):
        if self._tts is None:
            logging.warning('tts is none')
            return

        self._tts.push_txt(txt)

    def push_audio_chunk(self, audio_chunk):
        self._chunk_2_mal.push_chunk(audio_chunk)

    def push_mel_chunks_queue(self, mel_chunk):
        self._infer.push(mel_chunk)
        # self.audio_chunks_queue_.put(audio_chunk)

    def push_feat_queue(self, mel_chunks):
        print("push_feat_queue")
        self._feat_queue.put(mel_chunks)

    def push_audio_frames(self, chunk, type_):
        self._output_queue.put((chunk, type_))

    def push_render_image(self, image):
        self._test_image_queue.put(image)

    def render(self):
        try:
            # img, aud = self._res_frame_queue.get(block=True, timeout=.3)
            img = self._test_image_queue.get(block=True, timeout=.3)
        except queue.Empty:
            # print('render queue.Empty:')
            return None
        return img

    # def pull_audio_chunk(self):
    #     try:
    #         chunk = self._audio_chunk_queue.get(block=True, timeout=1.0)
    #         type = 1
    #     except queue.Empty:
    #         chunk = np.zeros(self._chunk, dtype=np.float32)
    #         type = 0
    #     return chunk, type
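

# Minimal usage sketch (illustrative only, not part of the original pipeline). It
# assumes the assets referenced above exist: .\checkpoints\wav2lip.pth, face images
# under ./face/ and ./temp/audio/chunk_0.wav. Human() starts its test() thread from
# __init__; here we simply poll render() for finished frames and display them.
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    human = Human()
    try:
        while True:
            frame = human.render()  # RGB frame, or None if nothing is ready yet
            if frame is None:
                continue
            cv2.imshow('human', cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
            if cv2.waitKey(40) & 0xFF == ord('q'):  # ~25 fps pacing, quit with 'q'
                break
    finally:
        human.on_destroy()
        cv2.destroyAllWindows()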