From ad54248ff37ee5be78f3977c334102a8f4f7191f Mon Sep 17 00:00:00 2001
From: brige
Date: Thu, 17 Oct 2024 23:26:21 +0800
Subject: [PATCH] modify ui and nlp tts code

---
 asr/sherpa_ncnn_asr.py                 | 87 ++++----------------------
 face_detection/detection/sfd/detect.py |  2 +
 human/__init__.py                      |  4 +-
 human/audio_handler.py                 | 21 -------
 human/audio_inference_handler.py       | 35 ++++++----
 human/audio_mal_handler.py             |  3 +-
 human/human_context.py                 | 53 +++++++++------
 human/human_render.py                  | 14 ++---
 nlp/nlp_doubao.py                      | 36 ++---------
 test/test_human_context.py             | 44 +++++++++++++
 tts/tts_audio_handle.py                | 29 ++++++--
 tts/tts_edge.py                        |  6 +-
 ui.py                                  | 43 +++++------
 utils/__init__.py                      |  2 +-
 utils/utils.py                         | 23 ++++++-
 15 files changed, 201 insertions(+), 201 deletions(-)
 delete mode 100644 human/audio_handler.py
 create mode 100644 test/test_human_context.py

diff --git a/asr/sherpa_ncnn_asr.py b/asr/sherpa_ncnn_asr.py
index 777d04b..331320c 100644
--- a/asr/sherpa_ncnn_asr.py
+++ b/asr/sherpa_ncnn_asr.py
@@ -1,8 +1,7 @@
 #encoding = utf8
-
+import logging
 import os
 import sys
-import time
 
 try:
     import sounddevice as sd
@@ -16,18 +15,26 @@ except ImportError as e:
 
 import sherpa_ncnn
-
 from asr.asr_base import AsrBase
 
+logger = logging.getLogger(__name__)
+
+current_file_path = os.path.dirname(os.path.abspath(__file__))
+
 
 class SherpaNcnnAsr(AsrBase):
    def __init__(self):
        super().__init__()
        self._recognizer = self._create_recognizer()
+        logger.info('SherpaNcnnAsr init')
+
+    def __del__(self):
+        self.stop()
+        logger.info('SherpaNcnnAsr del')
 
     def _create_recognizer(self):
-        base_path = os.path.join(os.getcwd(), 'data', 'asr', 'sherpa-ncnn',
+        base_path = os.path.join(current_file_path, '..', 'data', 'asr', 'sherpa-ncnn',
                                  'sherpa-ncnn-streaming-zipformer-zh-14M-2023-02-23')
+        logger.info(f'_create_recognizer init, path:{base_path}')
         recognizer = sherpa_ncnn.Recognizer(
             tokens=base_path + '/tokens.txt',
             encoder_param=base_path + '/encoder_jit_trace-pnnx.ncnn.param',
@@ -50,6 +57,7 @@ class SherpaNcnnAsr(AsrBase):
     def _recognize_loop(self):
         segment_id = 0
         last_result = ""
+        logger.info('_recognize_loop')
         with sd.InputStream(channels=1, dtype="float32", samplerate=self._sample_rate) as s:
             while not self._stop_event.is_set():
                 samples, _ = s.read(self._samples_per_read)  # a blocking read
@@ -70,74 +78,3 @@ class SherpaNcnnAsr(AsrBase):
                     self._notify_complete(result)
                     segment_id += 1
                     self._recognizer.reset()
-
-def main():
-    print("Started! Please speak")
-    asr = SherpaNcnnAsr()
-    time.sleep(20)
-    print("Stop! ")
-    asr.stop()
-
-    # print("Started! Please speak")
-    # recognizer = create_recognizer()
-    # sample_rate = recognizer.sample_rate
-    # samples_per_read = int(0.1 * sample_rate)  # 0.1 second = 100 ms
-    # last_result = ""
-    # segment_id = 0
-    #
-    # with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s:
-    #     while True:
-    #         samples, _ = s.read(samples_per_read)  # a blocking read
-    #         samples = samples.reshape(-1)
-    #         recognizer.accept_waveform(sample_rate, samples)
-    #
-    #         is_endpoint = recognizer.is_endpoint
-    #
-    #         result = recognizer.text
-    #         if result and (last_result != result):
-    #             last_result = result
-    #             print("\r{}:{}".format(segment_id, result), end=".", flush=True)
-    #
-    #         if is_endpoint:
-    #             if result:
-    #                 print("\r{}:{}".format(segment_id, result), flush=True)
-    #                 segment_id += 1
-    #             recognizer.reset()
-
-    # print("Started! Please speak")
Please speak") - # recognizer = create_recognizer() - # sample_rate = recognizer.sample_rate - # samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms - # last_result = "" - # with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s: - # while True: - # samples, _ = s.read(samples_per_read) # a blocking read - # samples = samples.reshape(-1) - # recognizer.accept_waveform(sample_rate, samples) - # result = recognizer.text - # if last_result != result: - # last_result = result - # print("\r{}".format(result), end="", flush=True) - -''' -if __name__ == "__main__": - devices = sd.query_devices() - print(devices) - default_input_device_idx = sd.default.device[0] - print(f'Use default device: {devices[default_input_device_idx]["name"]}') - - try: - main() - except KeyboardInterrupt: - print("\nCaught Ctrl + C. Exiting") - - # devices = sd.query_devices() - # print(devices) - # default_input_device_idx = sd.default.device[0] - # print(f'Use default device: {devices[default_input_device_idx]["name"]}') - # - # try: - # main() - # except KeyboardInterrupt: - # print("\nCaught Ctrl + C. Exiting") -''' diff --git a/face_detection/detection/sfd/detect.py b/face_detection/detection/sfd/detect.py index efef627..d6ff706 100644 --- a/face_detection/detection/sfd/detect.py +++ b/face_detection/detection/sfd/detect.py @@ -55,6 +55,7 @@ def detect(net, img, device): return bboxlist + def batch_detect(net, imgs, device): imgs = imgs - np.array([104, 117, 123]) imgs = imgs.transpose(0, 3, 1, 2) @@ -93,6 +94,7 @@ def batch_detect(net, imgs, device): return bboxlist + def flip_detect(net, img, device): img = cv2.flip(img, 1) b = detect(net, img, device) diff --git a/human/__init__.py b/human/__init__.py index 175ad85..966011f 100644 --- a/human/__init__.py +++ b/human/__init__.py @@ -1,4 +1,6 @@ #encoding = utf8 -from .audio_handler import AudioHandler from .human_context import HumanContext +from .audio_mal_handler import AudioMalHandler +from .audio_inference_handler import AudioInferenceHandler +from .human_render import HumanRender diff --git a/human/audio_handler.py b/human/audio_handler.py deleted file mode 100644 index 79f9239..0000000 --- a/human/audio_handler.py +++ /dev/null @@ -1,21 +0,0 @@ -#encoding = utf8 -import logging -from abc import ABC, abstractmethod - -logger = logging.getLogger(__name__) - - -class AudioHandler(ABC): - def __init__(self, context, handler): - self._context = context - self._handler = handler - - @abstractmethod - def on_handle(self, stream, index): - pass - - def on_next_handle(self, stream, type_): - if self._handler is not None: - self._handler.on_handle(stream, type_) - else: - logging.info(f'_handler is None') diff --git a/human/audio_inference_handler.py b/human/audio_inference_handler.py index 1b829d5..8b74b25 100644 --- a/human/audio_inference_handler.py +++ b/human/audio_inference_handler.py @@ -1,4 +1,6 @@ #encoding = utf8 +import logging +import os import queue import time from queue import Queue @@ -7,9 +9,12 @@ from threading import Event, Thread import numpy as np import torch -from .audio_handler import AudioHandler +from human_handler import AudioHandler from utils import load_model, mirror_index, get_device +logger = logging.getLogger(__name__) +current_file_path = os.path.dirname(os.path.abspath(__file__)) + class AudioInferenceHandler(AudioHandler): def __init__(self, context, handler): @@ -22,6 +27,7 @@ class AudioInferenceHandler(AudioHandler): self._run_thread = Thread(target=self.__on_run) self._exit_event.set() 
         self._run_thread.start()
+        logger.info("AudioInferenceHandler init")
 
     def on_handle(self, stream, type_):
         if type_ == 1:
@@ -30,8 +38,10 @@ class AudioInferenceHandler(AudioHandler):
             self._audio_queue.put(stream)
 
     def __on_run(self):
-        model = load_model(r'.\checkpoints\wav2lip.pth')
-        print("Model loaded")
+        wav2lip_path = os.path.join(current_file_path, '..', 'checkpoints', 'wav2lip.pth')
+        logger.info(f'loading model, path:{wav2lip_path}')
+        model = load_model(wav2lip_path)
+        logger.info("Model loaded")
 
         face_list_cycle = self._context.face_list_cycle
 
@@ -39,10 +49,10 @@ class AudioInferenceHandler(AudioHandler):
         index = 0
         count = 0
         count_time = 0
-        print('start inference')
+        logger.info('start inference')
 
         device = get_device()
-        print(f'use device:{device}')
+        logger.info(f'use device:{device}')
 
         while True:
             if self._exit_event.is_set():
@@ -66,7 +76,7 @@ class AudioInferenceHandler(AudioHandler):
                                         0)
                     index = index + 1
                 else:
-                    print('infer=======')
+                    logger.info('infer=======')
                     t = time.perf_counter()
                     img_batch = []
                     for i in range(batch_size):
@@ -95,20 +105,21 @@ class AudioInferenceHandler(AudioHandler):
 
                     count += batch_size
                     if count >= 100:
-                        print(f"------actual avg infer fps:{count / count_time:.4f}")
+                        logger.info(f"------actual avg infer fps:{count / count_time:.4f}")
                         count = 0
                         count_time = 0
 
-                    image_index = 0
                     for i, res_frame in enumerate(pred):
                         self.on_next_handle(
                             (res_frame, mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]),
                             0)
                         index = index + 1
-                        image_index = image_index + 1
-                    print('batch count', image_index)
-                    print('total batch time:', time.perf_counter() - start_time)
+                    logger.info(f'total batch time: {time.perf_counter() - start_time}')
             else:
                 time.sleep(1)
                 break
-        print('musereal inference processor stop')
+        logger.info('AudioInferenceHandler inference processor stop')
+
+    def stop(self):
+        self._exit_event.clear()
+        self._run_thread.join()
diff --git a/human/audio_mal_handler.py b/human/audio_mal_handler.py
index 78dc4ae..a9dc50d 100644
--- a/human/audio_mal_handler.py
+++ b/human/audio_mal_handler.py
@@ -7,7 +7,7 @@ from threading import Thread, Event
 
 import numpy as np
 
-from .audio_handler import AudioHandler
+from human_handler import AudioHandler
 from utils import melspectrogram
 
 logger = logging.getLogger(__name__)
@@ -25,6 +25,7 @@ class AudioMalHandler(AudioHandler):
 
         self.frames = []
         self.chunk = context.sample_rate // context.fps
+        logger.info("AudioMalHandler init")
 
     def on_handle(self, stream, index):
         self._queue.put(stream)
diff --git a/human/human_context.py b/human/human_context.py
index 34b9ce0..12fcb0e 100644
--- a/human/human_context.py
+++ b/human/human_context.py
@@ -1,15 +1,17 @@
 #encoding = utf8
 import logging
+import os
 
 from asr import SherpaNcnnAsr
-from human.audio_inference_handler import AudioInferenceHandler
-from human.audio_mal_handler import AudioMalHandler
-from human.human_render import HumanRender
+from .audio_inference_handler import AudioInferenceHandler
+from .audio_mal_handler import AudioMalHandler
+from .human_render import HumanRender
 from nlp import PunctuationSplit, DouBao
 from tts import TTSEdge, TTSAudioSplitHandle
 from utils import load_avatar, get_device
 
 logger = logging.getLogger(__name__)
+current_file_path = os.path.dirname(os.path.abspath(__file__))
 
 
 class HumanContext:
@@ -23,7 +25,9 @@ class HumanContext:
         self._device = get_device()
         print(f'device:{self._device}')
 
-        full_images, face_frames, coord_frames = load_avatar(r'./face/', self._image_size, self._device)
+        base_path = os.path.join(current_file_path, '..', 'face')
+        logger.info(f'load avatar, path:{base_path}')
+        full_images, face_frames, coord_frames = load_avatar(base_path, self._image_size, self._device)
         self._frame_list_cycle = full_images
         self._face_list_cycle = face_frames
         self._coord_list_cycle = coord_frames
@@ -31,14 +35,27 @@ class HumanContext:
         logging.info(f'face images length: {face_images_length}')
         print(f'face images length: {face_images_length}')
 
-        self.asr = None
-        self.nlp = None
-        self.tts = None
-        self.tts_handle = None
-        self.mal_handler = None
-        self.infer_handler = None
+        self._asr = None
+        self._nlp = None
+        self._tts = None
+        self._tts_handle = None
+        self._mal_handler = None
+        self._infer_handler = None
         self._render_handler = None
 
+    def __del__(self):
+        print('HumanContext: __del__')
+        if self._asr is None:
+            # build() never ran, so there is nothing to stop
+            return
+        self._asr.stop()
+        self._nlp.stop()
+        self._tts.stop()
+        self._tts_handle.stop()
+        self._mal_handler.stop()
+        self._infer_handler.stop()
+        self._render_handler.stop()
+
     @property
     def fps(self):
         return self._fps
@@ -81,17 +98,17 @@ class HumanContext:
 
     @property
     def render_handler(self):
-        return self.render_handler
+        return self._render_handler
 
     def build(self):
         self._render_handler = HumanRender(self, None)
-        self.infer_handler = AudioInferenceHandler(self, self._render_handler)
-        self.mal_handler = AudioMalHandler(self, self.infer_handler)
-        self.tts_handle = TTSAudioSplitHandle(self, self.mal_handler)
-        self.tts = TTSEdge(self.tts_handle)
+        self._infer_handler = AudioInferenceHandler(self, self._render_handler)
+        self._mal_handler = AudioMalHandler(self, self._infer_handler)
+        self._tts_handle = TTSAudioSplitHandle(self, self._mal_handler)
+        self._tts = TTSEdge(self._tts_handle)
         split = PunctuationSplit()
-        nlp = DouBao(split, self.tts)
-        self.asr = SherpaNcnnAsr()
-        self.asr.attach(nlp)
+        self._nlp = DouBao(split, self._tts)
+        self._asr = SherpaNcnnAsr()
+        self._asr.attach(self._nlp)
diff --git a/human/human_render.py b/human/human_render.py
index db319f0..9ff5f40 100644
--- a/human/human_render.py
+++ b/human/human_render.py
@@ -9,8 +9,7 @@ from threading import Thread, Event
 import cv2
 import numpy as np
 
-from audio_render import AudioRender
-from .audio_handler import AudioHandler
+from human_handler import AudioHandler
 
 
 class HumanRender(AudioHandler):
@@ -27,12 +26,12 @@ class HumanRender(AudioHandler):
         self._thread.start()
 
     def _on_run(self):
-        logging.info('chunk2mal run')
+        logging.info('human render run')
         while self._exit_event.is_set():
             self._run_step()
             time.sleep(0.002)
 
-        logging.info('chunk2mal exit')
+        logging.info('human render exit')
 
     def _run_step(self):
         try:
@@ -58,7 +57,7 @@ class HumanRender(AudioHandler):
             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
 
         if self._image_render is not None:
-            self._image_render.render(image)
+            self._image_render.on_render(image)
 
         for audio_frame in audio_frames:
             frame, type_ = audio_frame
@@ -69,7 +68,6 @@ class HumanRender(AudioHandler):
         # new_frame.planes[0].update(frame.tobytes())
         # new_frame.sample_rate = 16000
 
-
     def set_audio_render(self, render):
         self._audio_render = render
 
@@ -79,4 +77,6 @@ class HumanRender(AudioHandler):
     def on_handle(self, stream, index):
         self._queue.put(stream)
 
-
+    def stop(self):
+        self._exit_event.clear()
+        self._thread.join()
diff --git a/nlp/nlp_doubao.py b/nlp/nlp_doubao.py
index ab7c9ad..303a19a 100644
--- a/nlp/nlp_doubao.py
+++ b/nlp/nlp_doubao.py
@@ -12,6 +12,7 @@ logger = logging.getLogger(__name__)
 class DouBao(NLPBase):
     def __init__(self, split, callback=None):
         super().__init__(split, callback)
logger.info("DouBao init") # Access Key ID # AKLTYTdmOTBmNWFjODkxNDE2Zjk3MjU0NjRhM2JhM2IyN2Y # AKLTNDZjNTdhNDlkZGE3NDZjMDlkMzk5YWQ3MDA4MTY1ZDc @@ -30,7 +31,7 @@ class DouBao(NLPBase): async def _request(self, question): t = time.time() logger.info(f'_request:{question}') - print(f'-------dou_bao ask:', question) + logger.info(f'-------dou_bao ask:{question}') try: stream = await self.__client.chat.completions.create( model="ep-20241008152048-fsgzf", @@ -51,38 +52,9 @@ class DouBao(NLPBase): except Exception as e: print(e) logger.info(f'_request:{question}, time:{time.time() - t:.4f}s') - print(f'-------dou_bao nlp time:{time.time() - t:.4f}s') + logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s') async def _on_close(self): - print('AsyncArk close') + logger.info('AsyncArk close') if self.__client is not None and not self.__client.is_closed(): await self.__client.close() - -''' -if __name__ == "__main__": - # print(get_dou_bao_api()) - dou_bao = DouBao() - dou_bao.ask('你好。') - dou_bao.ask('你好,你是谁?') - dou_bao.ask('你能做什么?') - dou_bao.ask('介绍一下,我自己。') - count = 1000 - sec = '' - while count >= 0: - count = count - 1 - if nlp_queue.empty(): - time.sleep(0.1) - continue - sec = sec + nlp_queue.get(block=True, timeout=0.01) - - pattern = r'[,。、;?!,.!?]' - match = re.search(pattern, sec) - if match: - pos = match.start() + 1 - print(sec[:pos]) - sec = sec[pos:] - print(sec) - - - dou_bao.stop() -''' diff --git a/test/test_human_context.py b/test/test_human_context.py new file mode 100644 index 0000000..da35ecb --- /dev/null +++ b/test/test_human_context.py @@ -0,0 +1,44 @@ +#encoding = utf8 +import logging +import os +import sys +import time + +from human import HumanContext +from utils import config_logging + + +# try: +# import sounddevice as sd +# except ImportError as e: +# print("Please install sounddevice first. You can use") +# print() +# print(" pip install sounddevice") +# print() +# print("to install it") +# sys.exit(-1) +# + + +def main(): + print("Started! Please speak") + human = HumanContext() + human.build() + time.sleep(60) + print("Stop! ") + + +if __name__ == "__main__": + # devices = sd.query_devices() + # print(devices) + # default_input_device_idx = sd.default.device[0] + # print(f'Use default device: {devices[default_input_device_idx]["name"]}') + + current_file_path = os.path.dirname(os.path.abspath(__file__)) + log_path = os.path.join(current_file_path, '..', 'logs', 'info.log') + config_logging(log_path, logging.INFO, logging.INFO) + + try: + main() + except KeyboardInterrupt: + print("\nCaught Ctrl + C. 
Exiting") diff --git a/tts/tts_audio_handle.py b/tts/tts_audio_handle.py index e892033..7c09d31 100644 --- a/tts/tts_audio_handle.py +++ b/tts/tts_audio_handle.py @@ -1,9 +1,13 @@ #encoding = utf8 +import heapq +import logging import os import shutil from utils import save_wav -from human import AudioHandler +from human_handler import AudioHandler + +logger = logging.getLogger(__name__) class TTSAudioHandle(AudioHandler): @@ -27,26 +31,38 @@ class TTSAudioHandle(AudioHandler): def on_handle(self, stream, index): pass + def stop(self): + pass + class TTSAudioSplitHandle(TTSAudioHandle): def __init__(self, context, handler): super().__init__(context, handler) self.sample_rate = self._context.sample_rate self._chunk = self.sample_rate // self._context.fps + self._priority_queue = [] + logger.info("TTSAudioSplitHandle init") def on_handle(self, stream, index): + # heapq.heappush(self._priority_queue, (index, stream)) + if stream is None: + heapq.heappush(self._priority_queue, (index, None)) + stream_len = stream.shape[0] idx = 0 while stream_len >= self._chunk: - self._context.put_audio_frame(stream[idx:idx + self._chunk]) + self.on_next_handle(stream[idx:idx + self._chunk], 0) stream_len -= self._chunk idx += self._chunk + def stop(self): + pass + class TTSAudioSaveHandle(TTSAudioHandle): - def __init__(self): - super().__init__() + def __init__(self, context, handler): + super().__init__(context, handler) self._save_path_dir = '../temp/audio/' self._clean() @@ -72,3 +88,6 @@ class TTSAudioSaveHandle(TTSAudioHandle): file_name = self._save_path_dir + str(index) + '.wav' save_wav(stream, file_name, self.sample_rate) + def stop(self): + pass + diff --git a/tts/tts_edge.py b/tts/tts_edge.py index 0926c25..ceea740 100644 --- a/tts/tts_edge.py +++ b/tts/tts_edge.py @@ -1,5 +1,5 @@ #encoding = utf8 - +import logging from io import BytesIO import numpy as np @@ -9,11 +9,14 @@ import resampy from .tts_base import TTSBase +logger = logging.getLogger(__name__) + class TTSEdge(TTSBase): def __init__(self, handle, voice='zh-CN-XiaoyiNeural'): super().__init__(handle) self._voice = voice + logger.info(f"TTSEdge init, {voice}") async def _on_request(self, txt: str): print('_on_request, txt') @@ -42,6 +45,7 @@ class TTSEdge(TTSBase): print('-------tts finish push chunk') except Exception as e: + self._handle.on_handle(None, index) stream.seek(0) stream.truncate() print('-------tts finish error:', e) diff --git a/ui.py b/ui.py index 1ea22bd..f4388c8 100644 --- a/ui.py +++ b/ui.py @@ -16,8 +16,11 @@ from PIL import Image, ImageTk from playsound import playsound +from audio_render import AudioRender # from Human import Human from human import HumanContext +from utils import config_logging + # from tts.EdgeTTS import EdgeTTS logger = logging.getLogger(__name__) @@ -48,7 +51,7 @@ class App(customtkinter.CTk): # self.logo_label.grid(row=0, column=0, padx=20, pady=(20, 10)) self.entry = customtkinter.CTkEntry(self, placeholder_text="输入内容") - self.entry.insert(0, "大家好,我是九零科技有限公司,虚拟数字人。") + self.entry.insert(0, "大家好,测试虚拟数字人。") self.entry.grid(row=2, column=0, columnspan=2, padx=(20, 0), pady=(20, 20), sticky="nsew") self.main_button_1 = customtkinter.CTkButton(master=self, fg_color="transparent", border_width=2, @@ -58,13 +61,14 @@ class App(customtkinter.CTk): self._init_image_canvas() - self._is_play_audio = False + self._audio_render = AudioRender() # self._human = Human() self._queue = Queue() self._human_context = HumanContext() self._human_context.build() render = self._human_context.render_handler 
         render.set_image_render(self)
+        render.set_audio_render(self._audio_render)
 
         self._render()
         # self.play_audio()
@@ -84,13 +90,14 @@ class App(customtkinter.CTk):
         self._canvas.pack(fill=customtkinter.BOTH, expand=customtkinter.YES)
 
     def _render(self):
+        after_time = 29
         try:
-            image = self._queue.get()
+            image = self._queue.get(block=True, timeout=0.003)
             if image is None:
-                self.after(20, self._render)
+                self.after(after_time, self._render)
                 return
         except queue.Empty:
-            self.after(20, self._render)
+            self.after(after_time, self._render)
             return
 
         iheight, iwidth = image.shape[0], image.shape[1]
@@ -111,7 +118,7 @@ class App(customtkinter.CTk):
         height = self.winfo_height() * 0.5
         self._canvas.create_image(width, height, anchor=customtkinter.CENTER, image=imgtk)
         self._canvas.update()
-        self.after(20, self._render)
+        self.after(after_time, self._render)
 
     def request_tts(self):
         content = self.entry.get()
@@ -148,27 +155,9 @@ class App(customtkinter.CTk):
         sound = AudioSegment.from_mp3('./audio/mp3/' + file_name)
         sound.export('./audio/wav/' + file_name + '.wav', format="wav")
 
-        # open('./audio/', 'wb') with
-
-
-def config_logging(file_name: str, console_level: int=logging.INFO, file_level: int=logging.DEBUG):
-    file_handler = logging.FileHandler(file_name, mode='a', encoding="utf8")
-    file_handler.setFormatter(logging.Formatter(
-        '%(asctime)s [%(levelname)s] %(module)s.%(lineno)d %(name)s:\t%(message)s'
-    ))
-    file_handler.setLevel(file_level)
-
-    console_handler = logging.StreamHandler()
-    console_handler.setFormatter(logging.Formatter(
-        '[%(asctime)s %(levelname)s] %(message)s',
-        datefmt="%Y/%m/%d %H:%M:%S"
-    ))
-    console_handler.setLevel(console_level)
-
-    logging.basicConfig(
-        level=min(console_level, file_level),
-        handlers=[file_handler, console_handler],
-    )
+    def on_render(self, image):
+        self._queue.put(image)
+        print('on_render', self._queue.qsize())
 
 
 if __name__ == "__main__":
diff --git a/utils/__init__.py b/utils/__init__.py
index 0d5c37a..425ee71 100644
--- a/utils/__init__.py
+++ b/utils/__init__.py
@@ -1,6 +1,6 @@
 #encoding = utf8
 
 from .async_task_queue import AsyncTaskQueue
-from .utils import mirror_index, load_model, get_device, load_avatar
+from .utils import mirror_index, load_model, get_device, load_avatar, config_logging
 from .audio_utils import melspectrogram, save_wav
 
diff --git a/utils/utils.py b/utils/utils.py
index 371ac87..5cad18a 100644
--- a/utils/utils.py
+++ b/utils/utils.py
@@ -38,7 +38,7 @@ def read_files_path(path):
     files = os.listdir(path)
     for file in files:
         if not os.path.isdir(file):
-            file_paths.append(path + file)
+            file_paths.append(os.path.join(path, file))
     return file_paths
 
 
@@ -160,6 +160,7 @@ def load_model(path):
 
 
 def load_avatar(path, img_size, device):
+    print(f'load avatar:{path}')
     face_images_path = path
     face_images_path = read_files_path(face_images_path)
     full_list_cycle = read_images(face_images_path)
@@ -174,3 +175,23 @@ def load_avatar(path, img_size, device):
         coord_frames.append(coord)
 
     return full_list_cycle, face_frames, coord_frames
+
+
+def config_logging(file_name: str, console_level: int=logging.INFO, file_level: int=logging.DEBUG):
+    file_handler = logging.FileHandler(file_name, mode='a', encoding="utf8")
+    file_handler.setFormatter(logging.Formatter(
+        '%(asctime)s [%(levelname)s] %(module)s.%(lineno)d %(name)s:\t%(message)s'
+    ))
+    file_handler.setLevel(file_level)
+
+    console_handler = logging.StreamHandler()
+    console_handler.setFormatter(logging.Formatter(
+        '[%(asctime)s %(levelname)s] %(message)s',
+        datefmt="%Y/%m/%d %H:%M:%S"
+    ))
+    console_handler.setLevel(console_level)
+
+    logging.basicConfig(
+        level=min(console_level, file_level),
+        handlers=[file_handler, console_handler],
+    )