try close infer

This commit is contained in:
brige 2024-10-04 18:10:03 +08:00
parent 6445b6ee05
commit aef7d3d499
4 changed files with 39 additions and 26 deletions

View File

@ -224,12 +224,13 @@ class Human:
# self.coords_path = f"{self.avatar_path}/coords.pkl" # self.coords_path = f"{self.avatar_path}/coords.pkl"
# self.__loadavatar() # self.__loadavatar()
self.stop = False
self.res_render_queue = Queue(self._batch_size * 2) self.res_render_queue = Queue(self._batch_size * 2)
self.chunk_2_mal = Chunk2Mal(self) self.chunk_2_mal = Chunk2Mal(self)
self._tts = TTSBase(self) self._tts = TTSBase(self)
self._infer = Infer(self) self._infer = Infer(self)
self.chunk_2_mal.warm_up() # self.chunk_2_mal.warm_up()
self.audio_render = AudioRender() self.audio_render = AudioRender()
@ -249,6 +250,8 @@ class Human:
# )).start() # )).start()
# self.render_event.set() # self.render_event.set()
def __del__(self):
print('Human del')
# def play_pcm(self): # def play_pcm(self):
# p = pyaudio.PyAudio() # p = pyaudio.PyAudio()
# stream = p.open(format=p.get_format_from_width(2), channels=1, rate=16000, output=True) # stream = p.open(format=p.get_format_from_width(2), channels=1, rate=16000, output=True)
@ -413,8 +416,12 @@ class Human:
return self._stride_right_size return self._stride_right_size
def on_destroy(self): def on_destroy(self):
self.stop = True
# self.render_event.clear() # self.render_event.clear()
# self._chunk_2_mal.stop()
self.chunk_2_mal.stop()
self._tts.stop()
self._infer.stop()
# if self._tts is not None: # if self._tts is not None:
# self._tts.stop() # self._tts.stop()
logging.info('human destroy') logging.info('human destroy')
@ -459,13 +466,14 @@ class Human:
self._test_image_queue.put(image) self._test_image_queue.put(image)
def push_res_frame(self, res_frame, idx, audio_frames): def push_res_frame(self, res_frame, idx, audio_frames):
if self.stop:
print("push_res_frame stop")
return
self.res_render_queue.put((res_frame, idx, audio_frames)) self.res_render_queue.put((res_frame, idx, audio_frames))
def render(self): def render(self):
try: try:
# img, aud = self._res_frame_queue.get(block=True, timeout=.3) res_frame, idx, audio_frames = self.res_render_queue.get(block=True, timeout=.03)
# img = self._test_image_queue.get(block=True, timeout=.3)
res_frame, idx, audio_frames = self.res_render_queue.get(block=True, timeout=.3)
except queue.Empty: except queue.Empty:
# print('render queue.Empty:') # print('render queue.Empty:')
return None return None

View File

@ -348,18 +348,13 @@ class Infer:
count = 0 count = 0
count_time = 0 count_time = 0
print('start inference') print('start inference')
#
# face_images_path = r'./face/'
# face_images_path = utils.read_files_path(face_images_path)
# face_list_cycle1 = read_images(face_images_path)
# face_det_results = face_detect(face_list_cycle1)
while True: while True:
if self._exit_event.is_set(): if self._exit_event.is_set():
start_time = time.perf_counter() start_time = time.perf_counter()
batch_size = self._human.get_batch_size() batch_size = self._human.get_batch_size()
try: try:
mel_batch = self._feat_queue.get(block=True, timeout=1) mel_batch = self._feat_queue.get(block=True, timeout=0.1)
except queue.Empty: except queue.Empty:
continue continue
@ -370,10 +365,8 @@ class Infer:
audio_frames.append((frame, type_)) audio_frames.append((frame, type_))
if type_ == 0: if type_ == 0:
is_all_silence = False is_all_silence = False
if is_all_silence: if is_all_silence:
for i in range(batch_size): for i in range(batch_size):
# res_frame_queue.put((None, mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]))
self._human.push_res_frame(None, mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]) self._human.push_res_frame(None, mirror_index(length, index), audio_frames[i * 2:i * 2 + 2])
index = index + 1 index = index + 1
else: else:
@ -385,9 +378,6 @@ class Infer:
face = face_list_cycle[idx] face = face_list_cycle[idx]
img_batch.append(face) img_batch.append(face)
# img_batch_1, mel_batch_1, frames, coords = datagen_signal(face_list_cycle1,
# mel_batch, face_det_results)
img_batch = np.asarray(img_batch) img_batch = np.asarray(img_batch)
mel_batch = np.asarray(mel_batch) mel_batch = np.asarray(mel_batch)
img_masked = img_batch.copy() img_masked = img_batch.copy()
@ -402,7 +392,7 @@ class Infer:
with torch.no_grad(): with torch.no_grad():
pred = model(mel_batch, img_batch) pred = model(mel_batch, img_batch)
# pred = model(mel_batch, img_batch) * 255.0
pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255. pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
count_time += (time.perf_counter() - t) count_time += (time.perf_counter() - t)
@ -412,18 +402,30 @@ class Infer:
print(f"------actual avg infer fps:{count / count_time:.4f}") print(f"------actual avg infer fps:{count / count_time:.4f}")
count = 0 count = 0
count_time = 0 count_time = 0
image_index = 0
for i, res_frame in enumerate(pred): for i, res_frame in enumerate(pred):
# self.__pushmedia(res_frame,loop,audio_track,video_track)
# res_frame_queue.put(
# (res_frame, __mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]))
self._human.push_res_frame(res_frame, mirror_index(length, index), self._human.push_res_frame(res_frame, mirror_index(length, index),
audio_frames[i * 2:i * 2 + 2]) audio_frames[i * 2:i * 2 + 2])
index = index + 1 index = index + 1
# print('total batch time:',time.perf_counter()-start_time) image_index = image_index + 1
print('batch count', image_index)
print('total batch time:', time.perf_counter() - start_time)
else: else:
time.sleep(1) time.sleep(1)
break
print('musereal inference processor stop') print('musereal inference processor stop')
def stop(self):
if self._exit_event is None:
return
self.pause_talk()
self._exit_event.clear()
self._run_thread.join()
logging.info('Infer stop')
def pause_talk(self): def pause_talk(self):
self._feat_queue.queue.clear() self._feat_queue.queue.clear()
self._audio_out_queue.queue.clear() self._audio_out_queue.queue.clear()

View File

@ -95,9 +95,8 @@ class TTSBase:
pass pass
def stop(self): def stop(self):
self._pcm_stream.stop_stream() self.input_stream.seek(0)
self._pcm_player.close(self._pcm_stream) self.input_stream.truncate()
self._pcm_player.terminate()
if self._exit_event is None: if self._exit_event is None:
return return

8
ui.py
View File

@ -59,6 +59,10 @@ class App(customtkinter.CTk):
self._render() self._render()
# self.play_audio() # self.play_audio()
def destroy(self):
self.on_destroy()
super().destroy()
def on_destroy(self): def on_destroy(self):
logger.info('------------App destroy------------') logger.info('------------App destroy------------')
self._human.on_destroy() self._human.on_destroy()
@ -173,5 +177,5 @@ if __name__ == "__main__":
logger.info('------------start------------') logger.info('------------start------------')
app = App() app = App()
app.mainloop() app.mainloop()
app.on_destroy() # app.on_destroy()
# logger.info('------------exit------------') logger.info('------------exit------------')