render image to ui

commit e606fb6ef5 (parent bef51d5c47)
Author: jiegeaiai, 2024-09-27 01:34:52 +08:00
7 changed files with 273 additions and 34 deletions

.gitignore (vendored)

@@ -1,5 +1,4 @@
 *.pkl
-*.jpg
 *.mp4
 *.pth
 *.pyc


@@ -1,4 +1,5 @@
 #encoding = utf8
+import io
 import logging
 import multiprocessing as mp
@@ -14,6 +15,7 @@ import pyaudio
 import audio
 import face_detection
 import utils
+from infer import Infer
 from models import Wav2Lip
 from tts.Chunk2Mal import Chunk2Mal
 import torch
@@ -160,7 +162,6 @@ def get_smoothened_boxes(boxes, T):
 def face_detect(images):
     detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
                                             flip_input=False, device=device)
-
     batch_size = 16
     while 1:
@@ -281,6 +282,16 @@ def datagen_signal(frame, mel, face_det_results):
     return img_batch, mel_batch, frame_batch, coords_batch


+# load audio data from a byte stream
+def load_audio_from_bytes(byte_data):
+    # create an in-memory stream with BytesIO
+    with io.BytesIO(byte_data) as b:
+        wav = audio.load_wav(b, 16000)  # adjust to the actual library's parameters
+        return wav


+# assumes you already have the raw bytes of an audio file
 class Human:
     def __init__(self):
         self._fps = 25  # 40 ms per frame
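The helper added here is the in-memory counterpart of the audio.load_wav call in test() below; the commented-out block in test() sketches the intended round trip. A minimal usage sketch, assuming audio.load_wav accepts a file-like object (which is the premise of the helper):

    # hypothetical round trip: decode a wav file's raw bytes from memory
    with open('./audio/test.wav', 'rb') as f:
        byte_data = f.read()
    wav = load_audio_from_bytes(byte_data)   # 16 kHz mono samples
    mel = audio.melspectrogram(wav)          # same downstream path as test()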
@@ -294,12 +305,15 @@ class Human:
         self._chunk_2_mal = Chunk2Mal(self)
         self._tts = TTSBase(self)
+        self._infer = Infer(self)

         self.mel_chunks_queue_ = Queue()
         self.audio_chunks_queue_ = Queue()
         self._test_image_queue = Queue()

         self._thread = None
+        thread = threading.Thread(target=self.test)
+        thread.start()
         # self.test()
         # self.play_pcm()
@@ -326,7 +340,13 @@ class Human:
     # p.terminate()

     def test(self):
-        wav = audio.load_wav(r'./audio/audio1.wav', 16000)
+        wav = audio.load_wav(r'./audio/test.wav', 16000)
+        # with open(r'./audio/test.wav', 'rb') as f:
+        #     byte_data = f.read()
+        #
+        # byte_data = byte_data[16:]
+        # inputs = np.concatenate(byte_data)  # [N * chunk]
+        # wav = load_audio_from_bytes(inputs)
         mel = audio.melspectrogram(wav)
         if np.isnan(mel.reshape(-1)).sum() > 0:
             raise ValueError(
@@ -432,7 +452,7 @@ class Human:
         self._chunk_2_mal.push_chunk(audio_chunk)

     def push_mel_chunks_queue(self, mel_chunk):
-        self.mel_chunks_queue_.put(mel_chunk)
+        self._infer.push(mel_chunk)
         # self.audio_chunks_queue_.put(audio_chunk)

     def push_feat_queue(self, mel_chunks):
@@ -443,6 +463,9 @@ class Human:
         print("push_audio_frames")
         self._output_queue.put((chunk, type_))

+    def push_render_image(self, image):
+        self._test_image_queue.put(image)
+
     def render(self):
         try:
             # img, aud = self._res_frame_queue.get(block=True, timeout=.3)
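On the consumer side, ui.py's _render (last file in this diff) draws whatever render() returns every 33 ms. The body of render() is truncated in this diff, so the following is only an illustrative sketch of how it might drain _test_image_queue, assuming PIL:

    # hypothetical consumer (sketch): pull one RGB frame, wrap it for Tk
    from PIL import Image, ImageTk

    def render(self):
        try:
            image = self._test_image_queue.get(block=True, timeout=0.03)
        except queue.Empty:  # assumes `import queue` at module level
            return None
        return ImageTk.PhotoImage(Image.fromarray(image))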


@@ -22,7 +22,7 @@ async def play_tts(text, voice):
     # set up PyAudio
     audio = pyaudio.PyAudio()
-    stream = audio.open(format=pyaudio.paInt16, channels=1, rate=24000, output=True)
+    stream = audio.open(format=pyaudio.paInt16, channels=1, rate=16000, output=True)
     # async for chunk in communicate.stream():  # use the stream() method
     #     if chunk['type'] == 'audio':  # make sure the chunk is a byte stream

infer.py (new file)

@@ -0,0 +1,200 @@
#encoding = utf8
import queue
from queue import Queue
from threading import Thread, Event
import logging

import cv2
import numpy as np
import torch
from tqdm import tqdm

import face_detection
import utils
from models import Wav2Lip

logger = logging.getLogger(__name__)

device = 'cuda' if torch.cuda.is_available() else 'cpu'


def read_images(img_list):
    frames = []
    print('reading images...')
    for img_path in tqdm(img_list):
        print(f'read image path:{img_path}')
        frame = cv2.imread(img_path)
        frames.append(frame)
    return frames


def _load(checkpoint_path):
    if device == 'cuda':
        checkpoint = torch.load(checkpoint_path)
    else:
        checkpoint = torch.load(checkpoint_path,
                                map_location=lambda storage, loc: storage)
    return checkpoint


def load_model(path):
    model = Wav2Lip()
    print("Load checkpoint from: {}".format(path))
    logging.info(f'Load checkpoint from {path}')
    checkpoint = _load(path)
    s = checkpoint["state_dict"]
    new_s = {}
    for k, v in s.items():
        # strip the 'module.' prefix left by DataParallel checkpoints
        new_s[k.replace('module.', '')] = v
    model.load_state_dict(new_s)
    model = model.to(device)
    return model.eval()


def get_smoothened_boxes(boxes, T):
    # temporally smooth detection boxes over a window of T frames
    for i in range(len(boxes)):
        if i + T > len(boxes):
            window = boxes[len(boxes) - T:]
        else:
            window = boxes[i:i + T]
        boxes[i] = np.mean(window, axis=0)
    return boxes


def face_detect(images):
    detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
                                            flip_input=False, device=device)

    batch_size = 16
    while 1:
        predictions = []
        try:
            for i in range(0, len(images), batch_size):
                predictions.extend(detector.get_detections_for_batch(np.array(images[i:i + batch_size])))
        except RuntimeError:
            # on GPU OOM, halve the batch size and retry
            if batch_size == 1:
                raise RuntimeError(
                    'Image too big to run face detection on GPU. Please use the --resize_factor argument')
            batch_size //= 2
            print('Recovering from OOM error; New batch size: {}'.format(batch_size))
            continue
        break

    results = []
    pady1, pady2, padx1, padx2 = [0, 10, 0, 0]
    for rect, image in zip(predictions, images):
        if rect is None:
            cv2.imwrite('temp/faulty_frame.jpg', image)  # check this frame where the face was not detected.
            raise ValueError('Face not detected! Ensure the video contains a face in all the frames.')

        y1 = max(0, rect[1] - pady1)
        y2 = min(image.shape[0], rect[3] + pady2)
        x1 = max(0, rect[0] - padx1)
        x2 = min(image.shape[1], rect[2] + padx2)
        results.append([x1, y1, x2, y2])

    boxes = np.array(results)
    boxes = get_smoothened_boxes(boxes, T=5)
    results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)]

    del detector
    return results


img_size = 96
wav2lip_batch_size = 128


def datagen_signal(frame, mel, face_det_results):
    img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []

    # for i, m in enumerate(mels):
    idx = 0
    frame_to_save = frame.copy()
    face, coords = face_det_results[idx].copy()

    face = cv2.resize(face, (img_size, img_size))
    m = mel

    img_batch.append(face)
    mel_batch.append(m)
    frame_batch.append(frame_to_save)
    coords_batch.append(coords)

    if len(img_batch) >= wav2lip_batch_size:
        img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
        img_masked = img_batch.copy()
        # zero out the lower half; Wav2Lip regenerates the mouth region
        img_masked[:, img_size // 2:] = 0

        img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
        mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

        return img_batch, mel_batch, frame_batch, coords_batch

    if len(img_batch) > 0:
        img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
        img_masked = img_batch.copy()
        img_masked[:, img_size // 2:] = 0

        img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
        mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])

    return img_batch, mel_batch, frame_batch, coords_batch
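# Shape sanity check (a sketch, not part of the committed file): with one face
# and one mel window, datagen_signal returns img_batch of shape (1, 96, 96, 6)
# (channels 0-2: the lower-half-masked face; channels 3-5: the reference face)
# and mel_batch of shape (1, 80, 16, 1), the input layout Wav2Lip expects.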
class Infer:
    def __init__(self, human):
        self._human = human
        self._queue = Queue()

        self._exit_event = Event()
        self._run_thread = Thread(target=self.__on_run)
        self._exit_event.set()
        self._run_thread.start()

    def __on_run(self):
        face_images_path = r'./face/'
        face_images_path = utils.read_files_path(face_images_path)
        face_list_cycle = read_images(face_images_path)
        face_images_length = len(face_list_cycle)
        logging.info(f'face images length: {face_images_length}')
        print(f'face images length: {face_images_length}')

        model = load_model(r'.\checkpoints\wav2lip.pth')
        print("Model loaded")

        # frame_h, frame_w = face_list_cycle[0].shape[:-1]
        face_det_results = face_detect(face_list_cycle)

        j = 0
        while self._exit_event.is_set():
            try:
                # block until a mel window arrives; the timeout lets the exit flag be rechecked
                m = self._queue.get(block=True, timeout=1)
            except queue.Empty:
                continue

            img_batch, mel_batch, frames, coords = datagen_signal(face_list_cycle[0], m, face_det_results)

            img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
            mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)

            with torch.no_grad():
                pred = model(mel_batch, img_batch)

            pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.

            for p, f, c in zip(pred, frames, coords):
                y1, y2, x1, x2 = c
                # paste the predicted mouth crop back into the full frame
                p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1))
                f[y1:y2, x1:x2] = p
                # name = "%04d" % j
                # cv2.imwrite(f'temp/images/{j}.jpg', p)
                # j = j + 1
                # convert BGR (OpenCV) to RGB for the UI
                p = cv2.cvtColor(f, cv2.COLOR_BGR2RGB)
                self._human.push_render_image(p)
            # out.write(f)

    def push(self, chunk):
        self._queue.put(chunk)
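Taken together, the commit wires a full path from synthesized speech to rendered frames. A condensed map of the flow, using only names that appear in this diff:

    # TTSBase._on_run            -> Human.push_audio_chunk      (fixed-size PCM chunks)
    # Chunk2Mal._on_run          -> Human.push_mel_chunks_queue (16-frame mel windows)
    # Human.push_mel_chunks_queue -> Infer.push                 (queued for inference)
    # Infer.__on_run             -> Wav2Lip -> Human.push_render_image (RGB frames)
    # ui.py's _render polls the image queue every 33 ms and draws the frame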


@@ -2,6 +2,7 @@
 import logging
 import queue
+import time
 from queue import Queue
 from threading import Thread, Event
@@ -28,25 +29,28 @@ class Chunk2Mal:
     def _on_run(self):
         logging.info('chunk2mal run')
         while self._exit_event.is_set():
+            if self._audio_chunk_queue.empty():
+                time.sleep(0.5)
+                continue
             try:
-                chunk, type_ = self.pull_chunk()
+                chunk = self._audio_chunk_queue.get(block=True, timeout=1)
                 self._chunks.append(chunk)
-                self._human.push_audio_frames(chunk, type_)
+                self._human.push_audio_frames(chunk, 0)
+                if len(self._chunks) < 10:
+                    continue
             except queue.Empty:
                 # print('Chunk2Mal queue.Empty')
                 continue

-            if type_ == 0:
-                continue
-
             logging.info('np.concatenate')
-            mel = audio.melspectrogram(chunk)
+            inputs = np.concatenate(self._chunks)  # [N * chunk]
+            mel = audio.melspectrogram(inputs)
             if np.isnan(mel.reshape(-1)).sum() > 0:
                 raise ValueError(
                     'Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again')

             mel_step_size = 16
             print('fps:', self._human.get_fps())
             mel_idx_multiplier = 80. / self._human.get_fps()
             print('mel_idx_multiplier:', mel_idx_multiplier)
@@ -55,10 +59,8 @@ class Chunk2Mal:
         while 1:
             start_idx = int(i * mel_idx_multiplier)
             if start_idx + mel_step_size > len(mel[0]):
-                # mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:])
                 self._human.push_mel_chunks_queue(mel[:, len(mel[0]) - mel_step_size:])
                 break
-            # mel_chunks.append(mel[:, start_idx: start_idx + mel_step_size])
             self._human.push_mel_chunks_queue(mel[:, start_idx: start_idx + mel_step_size])
             i += 1
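The loop above steps a 16-frame window across the mel spectrogram at the video frame rate. A standalone sketch of the arithmetic, assuming the 25 fps default from Human and Wav2Lip's 80 mel frames per second:

    fps = 25
    mel_idx_multiplier = 80. / fps     # advance 3.2 mel frames per video frame
    mel_step_size = 16                 # each window spans 16/80 s = 200 ms of audio
    for i in range(4):
        start_idx = int(i * mel_idx_multiplier)
        print(start_idx, start_idx + mel_step_size)  # (0,16) (3,19) (6,22) (9,25)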


@@ -7,6 +7,7 @@ import edge_tts
 import numpy as np
 import pyaudio
 import soundfile
+import sounddevice
 import resampy
 import queue
 from io import BytesIO
@@ -23,18 +24,16 @@ class TTSBase:
         self._human = human
         self._thread = None
         self._queue = Queue()
-        self._exit_event = None
         self._io_stream = BytesIO()
-        self._sample_rate = 16000
-        self._chunk_len = self._sample_rate // self._human.get_fps()
+        self._chunk_len = self._human.get_audio_sample_rate() // self._human.get_fps()

         self._exit_event = Event()
         self._thread = Thread(target=self._on_run)
         self._exit_event.set()
         self._thread.start()
-        self._pcm_player = pyaudio.PyAudio()
-        self._pcm_stream = self._pcm_player.open(format=pyaudio.paInt16,
-                                                 channels=1, rate=16000, output=True)
+        # self._pcm_player = pyaudio.PyAudio()
+        # self._pcm_stream = self._pcm_player.open(format=pyaudio.paInt16,
+        #                                          channels=1, rate=24000, output=True)
         logging.info('tts start')

     def _on_run(self):
@@ -56,16 +55,24 @@ class TTSBase:
             self._io_stream.seek(0)
             stream = self.__create_bytes_stream(self._io_stream)
             stream_len = stream.shape[0]
+            # try:
+            #     sounddevice.play(stream, samplerate=self._human.get_audio_sample_rate())
+            #     sounddevice.wait()  # wait for playback to finish
+            # except Exception as e:
+            #     logger.error(f"error while playing audio: {e}")
             index = 0
             while stream_len >= self._chunk_len:
                 audio_chunk = stream[index:index + self._chunk_len]
+                # sounddevice.play(audio_chunk, samplerate=self._human.get_audio_sample_rate())
                 # self._pcm_stream.write(audio_chunk)
-                # self._pcm_stream.write(AudioSegment.from_mp3(audio_chunk))
+                # self._pcm_stream.write(audio_chunk.tobytes())
                 # self._human.push_audio_chunk(audio_chunk)
                 # self._human.push_mel_chunks_queue(audio_chunk)
                 self._human.push_audio_chunk(audio_chunk)
                 stream_len -= self._chunk_len
                 index += self._chunk_len
+            self._io_stream.seek(0)
+            self._io_stream.truncate()

     def __create_bytes_stream(self, io_stream):
         stream, sample_rate = soundfile.read(io_stream)
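The _chunk_len set in __init__ ties the chunking loop above to the video frame rate. Checking the numbers, assuming a 16000 Hz sample rate and 25 fps as used elsewhere in this diff:

    sample_rate = 16000                # Human.get_audio_sample_rate() (assumed)
    fps = 25                           # Human.get_fps() default
    chunk_len = sample_rate // fps     # 640 samples
    print(chunk_len / sample_rate)     # 0.04 s: one 40 ms audio chunk per frame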
@@ -76,29 +83,34 @@ class TTSBase:
             logger.warning(f'tts audio has {stream.shape[1]} channels, only use the first')
             stream = stream[:, 1]

-        if sample_rate != self._sample_rate and stream.shape[0] > 0:
-            logger.warning(f'tts audio sample rate is {sample_rate}, resample to {self._sample_rate}')
-            stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self._sample_rate)
+        if sample_rate != self._human.get_audio_sample_rate() and stream.shape[0] > 0:
+            logger.warning(f'tts audio sample rate is {sample_rate}, resample to {self._human.get_audio_sample_rate()}')
+            stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self._human.get_audio_sample_rate())

         return stream
     async def __on_request(self, voice, txt):
         communicate = edge_tts.Communicate(txt, voice)
         first = True
-        # total_data = b''
-        # CHUNK_SIZE = self._chunk_len
+        total_data = b''
+        CHUNK_SIZE = self._chunk_len
         async for chunk in communicate.stream():
             if chunk["type"] == "audio" and chunk["data"]:
-                self._io_stream.write(chunk['data'])
-                # total_data += chunk["data"]
-                # if len(total_data) >= CHUNK_SIZE:
-                #     # print(f"Time elapsed: {time.time() - start_time:.2f} seconds")  # Print time
-                #     audio_data = AudioSegment.from_mp3(BytesIO(total_data[:CHUNK_SIZE]))  # .raw_data
-                #     audio_data = audio_data.set_frame_rate(self._human.get_audio_sample_rate())
-                #     # self._human.push_audio_chunk(audio_data)
-                #     self._pcm_stream.write(audio_data.raw_data)
-                #     # play_audio(total_data[:CHUNK_SIZE], stream)  # Play first CHUNK_SIZE bytes
-                #     total_data = total_data[CHUNK_SIZE:]  # Remove played data
+                data = chunk['data']
+                self._io_stream.write(data)
+            elif chunk["type"] == "WordBoundary":
+                pass
+            '''
+            total_data += chunk["data"]
+            if len(total_data) >= CHUNK_SIZE:
+                # print(f"Time elapsed: {time.time() - start_time:.2f} seconds")  # Print time
+                audio_data = AudioSegment.from_mp3(BytesIO(total_data[:CHUNK_SIZE]))  # .raw_data
+                audio_data = audio_data.set_frame_rate(self._human.get_audio_sample_rate())
+                # self._human.push_audio_chunk(audio_data)
+                self._pcm_stream.write(audio_data.raw_data)
+                # play_audio(total_data[:CHUNK_SIZE], stream)  # Play first CHUNK_SIZE bytes
+                total_data = total_data[CHUNK_SIZE:]  # Remove played data
+            '''
         # if first:
         #     first = False

@@ -106,10 +118,12 @@ class TTSBase:
         # if chuck['type'] == 'audio':
         #     # self._io_stream.write(chuck['data'])
         #     self._io_stream.write(AudioSegment.from_mp3(BytesIO(total_data[:CHUNK_SIZE])).raw_data)

         # if len(total_data) > 0:
         #     self._pcm_stream.write(AudioSegment.from_mp3(BytesIO(total_data)).raw_data)
         #     audio_data = AudioSegment.from_mp3(BytesIO(total_data))  # .raw_data
         #     audio_data = audio_data.set_frame_rate(self._human.get_audio_sample_rate())
+        #     self._pcm_stream.write(audio_data.raw_data)
         #     self._human.push_audio_chunk(audio_data)
         #     self._io_stream.write(AudioSegment.from_mp3(BytesIO(total_data)).raw_data)

ui.py

@@ -63,10 +63,11 @@ class App(customtkinter.CTk):
         self._human.on_destroy()

     def play_audio(self):
+        # return
         if self._is_play_audio:
             return
         self._is_play_audio = True
-        file = os.path.curdir + '/audio/audio1.wav'
+        file = os.path.curdir + '/audio/test.wav'
         print(file)
         winsound.PlaySound(file, winsound.SND_ASYNC or winsound.SND_FILENAME)
         # playsound(file)
@@ -104,7 +105,7 @@ class App(customtkinter.CTk):
         height = self.winfo_height() * 0.5
         self._canvas.create_image(width, height, anchor=customtkinter.CENTER, image=imgtk)
         self._canvas.update()
-        self.after(34, self._render)
+        self.after(33, self._render)

     def request_tts(self):
         content = self.entry.get()