diff --git a/asr/asr_base.py b/asr/asr_base.py index 4eadd41..1118e44 100644 --- a/asr/asr_base.py +++ b/asr/asr_base.py @@ -30,11 +30,11 @@ class AsrBase: pass def _notify_process(self, message: str): + EventBus().post('clear_cache') for observer in self._observers: observer.process(message) def _notify_complete(self, message: str): - EventBus().post('clear_cache') for observer in self._observers: observer.completed(message) diff --git a/asr/sherpa_ncnn_asr.py b/asr/sherpa_ncnn_asr.py index 4124e76..9216cb6 100644 --- a/asr/sherpa_ncnn_asr.py +++ b/asr/sherpa_ncnn_asr.py @@ -83,15 +83,16 @@ class SherpaNcnnAsr(AsrBase): self._notify_complete(result) segment_id += 1 self._recognizer.reset() + ''' - while self._stop_event.is_set(): - logger.info(f'_recognize_loop000') - self._notify_complete('介绍中国5000年历史文学') - logger.info(f'_recognize_loop111') - segment_id += 1 - time.sleep(150) - logger.info(f'_recognize_loop222') - logger.info(f'_recognize_loop exit') +while self._stop_event.is_set(): + logger.info(f'_recognize_loop000') + self._notify_complete('介绍中国5000年历史文学') + logger.info(f'_recognize_loop111') + segment_id += 1 + time.sleep(150) + logger.info(f'_recognize_loop222') +logger.info(f'_recognize_loop exit') ''' diff --git a/audio_render/AudioRender.dll b/audio_render/AudioRender.dll deleted file mode 100644 index 6487ba6..0000000 Binary files a/audio_render/AudioRender.dll and /dev/null differ diff --git a/audio_render/AudioRender.lib b/audio_render/AudioRender.lib deleted file mode 100644 index 1872fe7..0000000 Binary files a/audio_render/AudioRender.lib and /dev/null differ diff --git a/audio_render/AudioRender.pdb b/audio_render/AudioRender.pdb deleted file mode 100644 index e87a60d..0000000 Binary files a/audio_render/AudioRender.pdb and /dev/null differ diff --git a/audio_render/__init__.py b/audio_render/__init__.py deleted file mode 100644 index 8208c73..0000000 --- a/audio_render/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -#encoding = utf8 - -from .audio_render import AudioRender diff --git a/audio_render/audio_render.py b/audio_render/audio_render.py deleted file mode 100644 index 2ece237..0000000 --- a/audio_render/audio_render.py +++ /dev/null @@ -1,39 +0,0 @@ -#encoding = utf8 - -from ctypes import * -import os - -import numpy as np - -current = os.path.dirname(__file__) -dynamic_path = os.path.join(current, 'AudioRender.dll') - - -def audio_render_log_callback(level, log, size): - print(f'level={level}, log={log}, len={size}') - - -class AudioRender: - def __init__(self): - self.__audio_render_obj = WinDLL(dynamic_path) - print(self.__audio_render_obj) - if self.__audio_render_obj is not None: - CALLBACK_TYPE = CFUNCTYPE(None, c_int, c_ubyte, c_uint) - c_callback = CALLBACK_TYPE(audio_render_log_callback) - self.__init = self.__audio_render_obj.Initialize(c_callback) - print('AudioRender init', self.__init) - - def __del__(self): - print('AudioRender __del__') - if self.__audio_render_obj is None: - return - if self.__init: - self.__audio_render_obj.Uninitialize() - - def write(self, data, size): - if not self.__init: - return False - - self.__audio_render_obj.argtypes = (POINTER(c_uint8), c_uint) - byte_data = np.frombuffer(data, dtype=np.uint8) - return self.__audio_render_obj.Write(byte_data.ctypes.data_as(POINTER(c_uint8)), size) diff --git a/human/__init__.py b/human/__init__.py index 503a041..06c491a 100644 --- a/human/__init__.py +++ b/human/__init__.py @@ -4,4 +4,5 @@ from .human_context import HumanContext from .audio_mal_handler import AudioMalHandler from 
.audio_inference_handler import AudioInferenceHandler from .audio_inference_onnx_handler import AudioInferenceOnnxHandler -from .human_render import HumanRender +from .huaman_status import HumanStatusEnum, HumanStatus +from .human_render import HumanRender, RenderStatus diff --git a/human/audio_inference_handler.py b/human/audio_inference_handler.py index 82a8d51..1f57b6f 100644 --- a/human/audio_inference_handler.py +++ b/human/audio_inference_handler.py @@ -3,15 +3,16 @@ import logging import os import queue import time -from queue import Queue from threading import Event, Thread +import cv2 import numpy as np import torch from eventbus import EventBus from human_handler import AudioHandler from utils import load_model, mirror_index, get_device, SyncQueue +from .huaman_status import HumanStatus logger = logging.getLogger(__name__) current_file_path = os.path.dirname(os.path.abspath(__file__)) @@ -22,7 +23,6 @@ class AudioInferenceHandler(AudioHandler): super().__init__(context, handler) EventBus().register('stop', self._on_stop) - EventBus().register('clear_cache', self.on_clear_cache) self._mal_queue = SyncQueue(1, "AudioInferenceHandler_Mel") self._audio_queue = SyncQueue(context.batch_size * 2, "AudioInferenceHandler_Audio") @@ -36,15 +36,10 @@ class AudioInferenceHandler(AudioHandler): def __del__(self): EventBus().unregister('stop', self._on_stop) - EventBus().unregister('clear_cache', self.on_clear_cache) def _on_stop(self, *args, **kwargs): self.stop() - def on_clear_cache(self, *args, **kwargs): - self._mal_queue.clear() - self._audio_queue.clear() - def on_handle(self, stream, type_): if not self._is_running: return @@ -67,12 +62,13 @@ class AudioInferenceHandler(AudioHandler): logger.info("Model loaded") face_list_cycle = self._context.face_list_cycle - length = len(face_list_cycle) index = 0 count = 0 count_time = 0 logger.info('start inference') + silence_length = 133 + human_status = HumanStatus(length, silence_length) device = get_device() logger.info(f'use device:{device}') @@ -107,16 +103,22 @@ class AudioInferenceHandler(AudioHandler): for i in range(batch_size): if not self._is_running: break - self.on_next_handle((None, mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]), - 0) - index = index + 1 + # self.on_next_handle((None, mirror_index(length, index), + self.on_next_handle((None, human_status.get_index(), + audio_frames[i * 2:i * 2 + 2]), 0) + # index = index + 1 else: + human_status.start_talking() logger.info(f'infer======= {current_text}') + # human_status.try_to_talk() t = time.perf_counter() img_batch = [] + index_list = [] # for i in range(batch_size): for i in range(len(mel_batch)): - idx = mirror_index(length, index + i) + # idx = mirror_index(length, index + i) + idx = human_status.get_index() + index_list.append(idx) face = face_list_cycle[idx] img_batch.append(face) @@ -150,9 +152,11 @@ class AudioInferenceHandler(AudioHandler): if not self._is_running: break self.on_next_handle( - (res_frame, mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]), + # (res_frame, mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]), + (res_frame, index_list[i], audio_frames[i * 2:i * 2 + 2]), 0) - index = index + 1 + # index = index + 1 + logger.info(f'total batch time: {time.perf_counter() - start_time}') else: time.sleep(1) @@ -171,6 +175,4 @@ class AudioInferenceHandler(AudioHandler): def pause_talk(self): print('AudioInferenceHandler pause_talk', self._audio_queue.size(), self._mal_queue.size()) self._audio_queue.clear() - 
print('AudioInferenceHandler111') self._mal_queue.clear() - print('AudioInferenceHandler222') diff --git a/human/audio_inference_onnx_handler.py b/human/audio_inference_onnx_handler.py index f7d28db..826fdae 100644 --- a/human/audio_inference_onnx_handler.py +++ b/human/audio_inference_onnx_handler.py @@ -5,7 +5,9 @@ import queue import time from threading import Event, Thread -from gfpgan import GFPGANer +import cv2 + +# from gfpgan import GFPGANer from eventbus import EventBus from human_handler import AudioHandler from utils import load_model, mirror_index, get_device, SyncQueue @@ -16,32 +18,32 @@ current_file_path = os.path.dirname(os.path.abspath(__file__)) def load_gfpgan_model(model_path): logger.info(f'load_gfpgan_model, path:{model_path}') - model = GFPGANer( - model_path=model_path, - upscale=1, - arch='clean', - channel_multiplier=2, - bg_upsampler=None, - ) - return model + # model = GFPGANer( + # model_path=model_path, + # upscale=1, + # arch='clean', + # channel_multiplier=2, + # bg_upsampler=None, + # ) + return None + #model def load_model(model_path): - import onnxruntime as ort - - sess_opt = ort.SessionOptions() - sess_opt.intra_op_num_threads = 8 - sess = ort.InferenceSession(model_path, sess_options=sess_opt, providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) - - return sess + # import onnxruntime as ort + # sess_opt = ort.SessionOptions() + # sess_opt.intra_op_num_threads = 8 + # sess = ort.InferenceSession(model_path, sess_options=sess_opt, providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) + # + # return sess + return None class AudioInferenceOnnxHandler(AudioHandler): def __init__(self, context, handler): super().__init__(context, handler) EventBus().register('stop', self._on_stop) - EventBus().register('clear_cache', self.on_clear_cache) self._mal_queue = SyncQueue(1, "AudioInferenceHandler_Mel") self._audio_queue = SyncQueue(context.batch_size * 2, "AudioInferenceHandler_Audio") @@ -55,15 +57,10 @@ class AudioInferenceOnnxHandler(AudioHandler): def __del__(self): EventBus().unregister('stop', self._on_stop) - EventBus().unregister('clear_cache', self.on_clear_cache) def _on_stop(self, *args, **kwargs): self.stop() - def on_clear_cache(self, *args, **kwargs): - self._mal_queue.clear() - self._audio_queue.clear() - def on_handle(self, stream, type_): if not self._is_running: return @@ -92,9 +89,9 @@ class AudioInferenceOnnxHandler(AudioHandler): gfpgan_model = load_gfpgan_model(gfpgan_model_path) face_list_cycle = self._context.face_list_cycle - - length = len(face_list_cycle) + for i in range(length): + cv2.imwrite(f'face_{i}.png', face_list_cycle[i]) index = 0 count = 0 count_time = 0 @@ -156,18 +153,6 @@ class AudioInferenceOnnxHandler(AudioHandler): onnx_out = model_g.run(onnx_names, onnx_input)[0] pred = onnx_out - # onnxruntime_inputs = {"audio_seqs__0": mel_batch, } - # onnxruntime_names = [output.name for output in model_a.get_outputs()] - # embeddings = model_a.run(onnxruntime_names, onnxruntime_inputs)[0] - # - # onnxruntime_inputs = {"audio_embedings__0": embeddings, "img_seqs__1": img_batch} - # onnxruntime_names = [output.name for output in model_g.get_outputs()] - # - # start_model = time.time() - # onnxruntime_output = model_g.run(onnxruntime_names, onnxruntime_inputs)[0] - # end_model = time.time() - # pred = onnxruntime_output - count_time += (time.perf_counter() - t) count += batch_size @@ -205,3 +190,4 @@ class AudioInferenceOnnxHandler(AudioHandler): print('AudioInferenceHandler111') self._mal_queue.clear() 
print('AudioInferenceHandler222') + super().pause_talk() diff --git a/human/audio_mal_handler.py b/human/audio_mal_handler.py index 679f27d..43104cf 100644 --- a/human/audio_mal_handler.py +++ b/human/audio_mal_handler.py @@ -55,8 +55,7 @@ class AudioMalHandler(AudioHandler): logging.info('chunk2mal run') while self._exit_event.is_set() and self._is_running: self._run_step() - time.sleep(0.02) - + # time.sleep(0.01) logging.info('chunk2mal exit') def _run_step(self): @@ -80,7 +79,6 @@ class AudioMalHandler(AudioHandler): # print(mel.shape[0],mel.shape,len(mel[0]),len(self.frames)) # cut off stride left = max(0, self._context.stride_left_size * 80 / self._context.fps) - right = min(len(mel[0]), len(mel[0]) - self._context.stride_right_size * 80 / 50) mel_idx_multiplier = 80. * 2 / self._context.fps mel_step_size = 16 i = 0 @@ -107,6 +105,7 @@ class AudioMalHandler(AudioHandler): chunk = np.zeros(self.chunk, dtype=np.float32) frame = (chunk, '') type_ = 1 + # time.sleep(0.02) # logging.info(f'AudioMalHandler get_audio_frame type:{type_}') return frame, type_ @@ -124,3 +123,4 @@ class AudioMalHandler(AudioHandler): def pause_talk(self): print('AudioMalHandler pause_talk', self._queue.size()) self._queue.clear() + super().pause_talk() diff --git a/human/huaman_status.py b/human/huaman_status.py new file mode 100644 index 0000000..fe24139 --- /dev/null +++ b/human/huaman_status.py @@ -0,0 +1,60 @@ +#encoding = utf8 + +import logging + + +from enum import Enum + + +class HumanStatusEnum(Enum): + silence = 1 + talking = 2 + + +class HumanStatus: + def __init__(self, total_frames=0, silence_length=0): + self._status = HumanStatusEnum.silence + self._total_frames = total_frames + self._silence_length = silence_length + self._talking_length = total_frames - silence_length + self._current_frame = 0 + self._is_talking = False + + def get_status(self): + return self._status + + def set_status(self, status): + self._status = status + return self._status + + def try_to_talk(self): + if self._status == HumanStatusEnum.silence: + if self._current_frame - self._silence_length < 0: + return False + self._status = HumanStatusEnum.talking + return True + + def get_index(self): + if not self._is_talking: + index = self._current_frame % self._silence_length + + if self._current_frame >= self._silence_length: + self._is_talking = True + self._current_frame = 0 + + else: + index = self._silence_length + (self._current_frame - self._silence_length) % self._talking_length + + if self._current_frame >= self._silence_length + self._talking_length: + self._is_talking = False + self._current_frame = 0 + + self._current_frame = (self._current_frame + 1) % self._total_frames + return index + + def start_talking(self): + self._is_talking = True + + def stop_talking(self): + self._is_talking = False + self._current_frame = 0 diff --git a/human/human_context.py b/human/human_context.py index e76e8ad..6d55e8e 100644 --- a/human/human_context.py +++ b/human/human_context.py @@ -4,11 +4,9 @@ import os from asr import SherpaNcnnAsr from eventbus import EventBus -from .audio_inference_onnx_handler import AudioInferenceOnnxHandler from .audio_inference_handler import AudioInferenceHandler from .audio_mal_handler import AudioMalHandler -from .human_render import HumanRender -from nlp import PunctuationSplit, DouBao +from nlp import PunctuationSplit, DouBao, Kimi from tts import TTSEdge, TTSAudioSplitHandle, TTSEdgeHttp from utils import load_avatar, get_device, object_stop, load_avatar_from_processed, load_avatar_from_256_processed 
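
Note on the new human/huaman_status.py above: it replaces the old mirror_index-based frame selection. get_index() walks a leading block of idle "silence" frames until start_talking() switches it onto the talking segment, and stop_talking() rewinds to the idle segment. A minimal usage sketch, assuming the human package and its dependencies import cleanly; the frame counts here are made up for illustration:

    from human.huaman_status import HumanStatus

    # 10 avatar frames in total: indices 0-3 form the idle/silence segment,
    # indices 4-9 form the talking segment.
    status = HumanStatus(total_frames=10, silence_length=4)
    print([status.get_index() for _ in range(4)])  # [0, 1, 2, 3] - cycling the silence segment
    status.start_talking()                         # inference calls this before emitting lip-synced frames
    print([status.get_index() for _ in range(3)])  # [4, 5, 6] - now inside the talking segment
    status.stop_talking()                          # reset back to frame 0 of the silence segment
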
@@ -18,13 +16,12 @@ current_file_path = os.path.dirname(os.path.abspath(__file__)) class HumanContext: def __init__(self): - self._fps = 25 # 20 ms per frame + self._fps = 50 # 20 ms per frame self._image_size = 288 - self._batch_size = 16 + self._batch_size = 6 self._sample_rate = 16000 - self._stride_left_size = 10 - self._stride_right_size = 10 - self._render_batch = 5 + self._stride_left_size = 4 + self._stride_right_size = 4 self._asr = None self._nlp = None @@ -64,10 +61,6 @@ class HumanContext: def image_size(self): return self._image_size - @property - def render_batch(self): - return self._render_batch - @property def device(self): return self._device @@ -118,14 +111,15 @@ class HumanContext: else: logger.info(f'notify message:{message}') - def build(self): - self._render_handler = HumanRender(self, None) + def build(self, render_handler): + self._render_handler = render_handler self._infer_handler = AudioInferenceHandler(self, self._render_handler) self._mal_handler = AudioMalHandler(self, self._infer_handler) self._tts_handle = TTSAudioSplitHandle(self, self._mal_handler) self._tts = TTSEdgeHttp(self._tts_handle) split = PunctuationSplit() self._nlp = DouBao(self, split, self._tts) + # self._nlp = Kimi(self, split, self._tts) self._asr = SherpaNcnnAsr() self._asr.attach(self._nlp) diff --git a/human/human_render.py b/human/human_render.py index 186ddbc..efe3a63 100644 --- a/human/human_render.py +++ b/human/human_render.py @@ -2,66 +2,73 @@ import logging import time +from enum import Enum from queue import Empty -from threading import Event, Thread from eventbus import EventBus -from human.message_type import MessageType from human_handler import AudioHandler -from render import VoiceRender, VideoRender, PlayClock from utils import SyncQueue logger = logging.getLogger(__name__) +class RenderStatus(Enum): + E_Normal = 0, + E_Full = 1, + E_Empty = 2 + + class HumanRender(AudioHandler): def __init__(self, context, handler): super().__init__(context, handler) EventBus().register('stop', self._on_stop) - EventBus().register('clear_cache', self.on_clear_cache) - play_clock = PlayClock() - self._voice_render = VoiceRender(play_clock, context) - self._video_render = VideoRender(play_clock, context, self) - self._is_running = True self._queue = SyncQueue(context.batch_size, "HumanRender_queue") - self._exit_event = Event() - self._thread = Thread(target=self._on_run, name="AudioMalHandlerThread") - self._exit_event.set() - self._thread.start() - self._image_render = None - self._last_audio_ps = 0 - self._last_video_ps = 0 self._empty_log = True + self._should_exit = False + self._render_status = RenderStatus.E_Empty def __del__(self): EventBus().unregister('stop', self._on_stop) - EventBus().unregister('clear_cache', self.on_clear_cache) def _on_stop(self, *args, **kwargs): + self._should_exit = True self.stop() - def on_clear_cache(self, *args, **kwargs): - self._queue.clear() + def _render(self, video_frame, voice_frame): + pass - def _on_run(self): + def run(self): logging.info('human render run') - while self._exit_event.is_set() and self._is_running: + delay = 1000 / self._context.fps * 0.001 + while not self._should_exit: + if self._render_status is RenderStatus.E_Full: + time.sleep(delay) + continue + + t = time.perf_counter() self._run_step() - delay = 0.04 - time.sleep(delay) + use = time.perf_counter() - t + if self._render_status is RenderStatus.E_Empty: + continue + real_delay = delay - use + # print(f'send voice {use}') + if real_delay > 0: + time.sleep(real_delay) + # 
else: + # print(f'send voice {real_delay}') logging.info('human render exit') def _run_step(self): try: - value = self._queue.get(timeout=.005) + value = self._queue.get(timeout=1) if value is None: return res_frame, idx, audio_frames = value if not self._empty_log: self._empty_log = True - logging.info('render render:') + logging.info('human render:') except Empty: if self._empty_log: self._empty_log = False @@ -71,47 +78,25 @@ class HumanRender(AudioHandler): type_ = 1 if audio_frames[0][1] != 0 and audio_frames[1][1] != 0: type_ = 0 - if self._voice_render is not None: - self._voice_render.render(audio_frames, self._last_audio_ps) - self._last_audio_ps = self._last_audio_ps + 0.4 - if self._video_render is not None: - self._video_render.render((res_frame, idx, type_), self._last_video_ps) - self._last_video_ps = self._last_video_ps + 0.4 - def set_image_render(self, render): - self._image_render = render - - def put_image(self, image): - if self._image_render is not None: - self._image_render.on_render(image) + self._render((res_frame, idx, type_), audio_frames) def on_message(self, message): super().on_message(message) def on_handle(self, stream, index): - if not self._is_running: + if self._should_exit: return - self._queue.put(stream) def pause_talk(self): logging.info('human pause_talk') - # self._voice_render.pause_talk() - # self._video_render.pause_talk() + self._queue.clear() + super().pause_talk() def stop(self): logging.info('human render stop') - self._is_running = False - if self._exit_event is None: - return - + self._should_exit = True self._queue.clear() - self._exit_event.clear() - if self._thread.is_alive(): - self._thread.join() - logging.info('human render stop') - # self._voice_render.stop() - # self._video_render.stop() - # self._exit_event.clear() - # self._thread.join() + logging.info('human render stop') diff --git a/human_handler/audio_handler.py b/human_handler/audio_handler.py index f2e307e..eb5eff6 100644 --- a/human_handler/audio_handler.py +++ b/human_handler/audio_handler.py @@ -29,4 +29,8 @@ class AudioHandler(ABC): logging.info(f'_handler is None') def pause_talk(self): - pass + if self._handler is not None: + self._handler.pause_talk() + else: + logging.info(f'next_pause_talk _handler is None') + diff --git a/ipc/__init__.py b/ipc/__init__.py new file mode 100644 index 0000000..3cd60af --- /dev/null +++ b/ipc/__init__.py @@ -0,0 +1,3 @@ +#encoding = utf8 + +from .ipc_util import IPCUtil diff --git a/ipc/ipc.dll b/ipc/ipc.dll new file mode 100644 index 0000000..1fc4150 Binary files /dev/null and b/ipc/ipc.dll differ diff --git a/ipc/ipc.exp b/ipc/ipc.exp new file mode 100644 index 0000000..6d11e2b Binary files /dev/null and b/ipc/ipc.exp differ diff --git a/ipc/ipc.lib b/ipc/ipc.lib new file mode 100644 index 0000000..8f4c600 Binary files /dev/null and b/ipc/ipc.lib differ diff --git a/ipc/ipc.pdb b/ipc/ipc.pdb new file mode 100644 index 0000000..4a63a4a Binary files /dev/null and b/ipc/ipc.pdb differ diff --git a/ipc/ipc_mem.py b/ipc/ipc_mem.py new file mode 100644 index 0000000..3abb62e --- /dev/null +++ b/ipc/ipc_mem.py @@ -0,0 +1,62 @@ +#encoding = utf8 + +import logging +import os + +from ctypes import * + +current = os.path.dirname(__file__) +dynamic_path = os.path.join(current, 'ipc.dll') + + +class IPCMem: + def __init__(self, sender, receiver): + self.__ipc_obj = WinDLL(dynamic_path) + print(self.__ipc_obj) + if self.__ipc_obj is not None: + self.__ipc_obj.initialize.argtypes = [c_char_p, c_char_p] + self.__ipc_obj.initialize.restype = c_bool +
print('IPCUtil init', sender.encode('utf-8'), receiver.encode('utf-8')) + self.__init = self.__ipc_obj.initialize(sender.encode('utf-8'), receiver.encode('utf-8')) + print('IPCUtil init', self.__init) + + def __del__(self): + print('IPCUtil __del__') + if self.__ipc_obj is None: + return + if self.__init: + self.__ipc_obj.uninitialize() + + def listen(self): + if not self.__init: + return False + self.__ipc_obj.listen.restype = c_bool + return self.__ipc_obj.listen() + + def send_text(self, data): + if not self.__init: + return False + self.__ipc_obj.send.argtypes = [c_char_p, c_uint] + self.__ipc_obj.send.restype = c_bool + send_data = data.encode('utf-8') + send_len = len(send_data) + 1 + if not self.__ipc_obj.send(send_data, send_len): + self.__ipc_obj.reConnect() + return True + + def send_binary(self, data, size): + if not self.__init: + return False + self.__ipc_obj.send.argtypes = [c_char_p, c_uint] + self.__ipc_obj.send.restype = c_bool + data_ptr = cast(data, c_char_p) + return self.__ipc_obj.send(data_ptr, size) + + def set_reader_callback(self, callback): + if not self.__init: + return False + CALLBACK_TYPE = CFUNCTYPE(None, c_char_p, c_uint) + self.c_callback = CALLBACK_TYPE(callback) # Store the callback to prevent garbage collection + self.__ipc_obj.setReaderCallback.argtypes = [CALLBACK_TYPE] + self.__ipc_obj.setReaderCallback.restype = c_bool + return self.__ipc_obj.setReaderCallback(self.c_callback) diff --git a/ipc/ipc_util.py b/ipc/ipc_util.py new file mode 100644 index 0000000..9ae5d80 --- /dev/null +++ b/ipc/ipc_util.py @@ -0,0 +1,71 @@ +#encoding = utf8 +import os +import time +from ctypes import * + +current = os.path.dirname(__file__) +dynamic_path = os.path.join(current, 'ipc.dll') + + +class IPCUtil: + def __init__(self, sender, receiver): + self.__ipc_obj = WinDLL(dynamic_path) + print(self.__ipc_obj) + if self.__ipc_obj is not None: + self.__ipc_obj.initialize.argtypes = [c_char_p, c_char_p] + self.__ipc_obj.initialize.restype = c_bool + print('IPCUtil init', sender.encode('utf-8'), receiver.encode('utf-8')) + self.__init = self.__ipc_obj.initialize(sender.encode('utf-8'), receiver.encode('utf-8')) + print('IPCUtil init', self.__init) + + def __del__(self): + print('IPCUtil __del__') + if self.__ipc_obj is None: + return + if self.__init: + self.__ipc_obj.uninitialize() + + def listen(self): + if not self.__init: + return False + self.__ipc_obj.listen.restype = c_bool + return self.__ipc_obj.listen() + + def send_text(self, data): + if not self.__init: + return False + self.__ipc_obj.send.argtypes = [c_char_p, c_uint] + self.__ipc_obj.send.restype = c_bool + send_data = data.encode('utf-8') + send_len = len(send_data) + 1 + if not self.__ipc_obj.trySend(send_data, send_len): + self.__ipc_obj.reConnect() + return True + + def send_binary(self, data, size): + if not self.__init: + return False + self.__ipc_obj.send.argtypes = [c_char_p, c_uint] + self.__ipc_obj.send.restype = c_bool + data_ptr = cast(data, c_char_p) + return self.__ipc_obj.trySend(data_ptr, size) + + def set_reader_callback(self, callback): + if not self.__init: + return False + CALLBACK_TYPE = CFUNCTYPE(None, c_char_p, c_uint) + self.c_callback = CALLBACK_TYPE(callback) # Store the callback to prevent garbage collection + self.__ipc_obj.setReaderCallback.argtypes = [CALLBACK_TYPE] + self.__ipc_obj.setReaderCallback.restype = c_bool + return self.__ipc_obj.setReaderCallback(self.c_callback) + + +# def ipc_log_callback(log, size): +# print(f'log={log.decode("utf-8")}, len={size}') +# +# +# util = 
IPCUtil('ipc_sender', 'ipc_sender') +# util.set_reader_callback(ipc_log_callback) +# print(util.listen()) +# print(util.send_text('hello')) +# time.sleep(200) diff --git a/main.py b/main.py new file mode 100644 index 0000000..959b63f --- /dev/null +++ b/main.py @@ -0,0 +1,22 @@ +#encoding = utf8 + +import logging +import os + +from human import HumanContext +from ui import IpcRender +from utils import config_logging + +logger = logging.getLogger(__name__) +current_file_path = os.path.dirname(os.path.abspath(__file__)) + +if __name__ == '__main__': + config_logging('./logs/info.log', logging.INFO, logging.INFO) + + logger.info('------------start------------') + context = HumanContext() + render = IpcRender(context) + context.build(render) + render.run() + render.stop() + logger.info('------------finish------------') \ No newline at end of file diff --git a/nlp/__init__.py b/nlp/__init__.py index 3ebb749..9da22f8 100644 --- a/nlp/__init__.py +++ b/nlp/__init__.py @@ -2,4 +2,5 @@ from .nlp_callback import NLPCallback from .nlp_doubao import DouBao +from .nlp_kimi import Kimi from .nlp_split import PunctuationSplit diff --git a/nlp/nlp_base.py b/nlp/nlp_base.py index 5534901..377eb27 100644 --- a/nlp/nlp_base.py +++ b/nlp/nlp_base.py @@ -28,7 +28,9 @@ class NLPBase(AsrObserver): def on_clear_cache(self, *args, **kwargs): logger.info('NLPBase clear_cache') - self._ask_queue.clear() + self.pause_talk() + if self._callback is not None: + self._callback.on_clear() @property def callback(self): @@ -74,6 +76,5 @@ class NLPBase(AsrObserver): def pause_talk(self): logger.info('NLPBase pause_talk') - self._is_running = False self._ask_queue.clear() \ No newline at end of file diff --git a/nlp/nlp_callback.py b/nlp/nlp_callback.py index 96619b2..854a89e 100644 --- a/nlp/nlp_callback.py +++ b/nlp/nlp_callback.py @@ -7,3 +7,7 @@ class NLPCallback(ABC): @abstractmethod def on_message(self, txt: str): pass + + @abstractmethod + def on_clear(self): + pass diff --git a/nlp/nlp_doubao.py b/nlp/nlp_doubao.py index 26244ca..b97a4a2 100644 --- a/nlp/nlp_doubao.py +++ b/nlp/nlp_doubao.py @@ -80,7 +80,7 @@ class DouBaoHttp: } data = { - "model": "ep-20241008152048-fsgzf", + "model": "ep-20241207182221-mhhzq", "messages": question, 'stream': True } @@ -93,7 +93,7 @@ class DouBaoHttp: self._requesting = True logger.info(f'-------dou_bao ask:{question}') msg_list = [ - {"role": "system", "content": "你是测试客服,是由字节跳动开发的 AI 人工智能助手"}, + {"role": "system", "content": "你是库里,是一个学习小老师,擅长初中历史和数学。"}, {"role": "user", "content": question} ] self._response = self.__request(msg_list) diff --git a/nlp/nlp_kimi.py b/nlp/nlp_kimi.py new file mode 100644 index 0000000..833bc03 --- /dev/null +++ b/nlp/nlp_kimi.py @@ -0,0 +1,117 @@ +#encoding = utf8 +import json +import logging +import time + +import requests + +from nlp.nlp_base import NLPBase + +logger = logging.getLogger(__name__) + + +class KimiHttp: + def __init__(self, token): + self.__token = token + self._response = None + self._requesting = False + + def __request(self, question): + url = "https://api.moonshot.cn/v1/chat/completions" + headers = { + "Authorization": "Bearer " + self.__token, + "Content-Type": "application/json" + } + + data = { + "model": "moonshot-v1-8k", + "messages": question, + 'stream': True, + "temperature": 0.3 + } + + response = requests.post(url, headers=headers, json=data, stream=True) + return response + + def request(self, question, handle, callback): + t = time.time() + self._requesting = True + logger.info(f'-------dou_bao ask:{question}') + 
msg_list = [ + {"role": "system", "content": "你是 Kimi,由 Moonshot AI 提供的人工智能助手,你更擅长中文和英文的对话。" + "你会为用户提供安全,有帮助,准确的回答。同时,你会拒绝一切涉及恐怖主义,种族歧视," + "黄色暴力等问题的回答。Moonshot AI 为专有名词,不可翻译成其他语言。"}, + {"role": "user", "content": question} + ] + self._response = self.__request(msg_list) + if not self._response.ok: + logger.error(f"请求失败,状态码:{self._response.status_code}") + return + sec = '' + for chunk in self._response.iter_lines(): + content = chunk.decode("utf-8").strip() + if len(content) < 1: + continue + content = content[5:] + content = content.strip() + if content == '[DONE]': + break + + try: + content = json.loads(content) + except Exception as e: + logger.error(f"json解析失败,错误信息:{e, content}") + continue + sec = sec + content["choices"][0]["delta"]["content"] + sec, message = handle.handle(sec) + if len(message) > 0: + logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s') + callback(message) + if len(sec) > 0: + callback(sec) + + self._requesting = False + logger.info(f'-------dou_bao nlp time:{time.time() - t:.4f}s') + + def close(self): + if self._response is not None and self._requesting: + self._response.close() + + def aclose(self): + if self._response is not None and self._requesting: + self._response.close() + logger.info('DouBaoHttp close') + + +class Kimi(NLPBase): + def __init__(self, context, split, callback=None): + super().__init__(context, split, callback) + logger.info("DouBao init") + # Access Key ID + # AKLTYTdmOTBmNWFjODkxNDE2Zjk3MjU0NjRhM2JhM2IyN2Y + # AKLTNDZjNTdhNDlkZGE3NDZjMDlkMzk5YWQ3MDA4MTY1ZDc + # Secret Access Key + # WmpRelltRXhNbVkyWWpnNU5HRmpNamc0WTJZMFpUWmpOV1E1TTJFME1tTQ== + # TkRJMk1tTTFZamt4TkRVNE5HRTNZMkUyTnpFeU5qQmxNMkUwWXpaak1HRQ== + # endpoint_id + # ep-20241008152048-fsgzf + # api_key + # c9635f9e-0f9e-4ca1-ac90-8af25a541b74 + # api_ky + # eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJhcmstY29uc29sZSIsImV4cCI6MTczMDk2NTMxOSwiaWF0IjoxNzI4MzczMzE5LCJ0IjoidXNlciIsImt2IjoxLCJhaWQiOiIyMTAyMjc3NDc1IiwidWlkIjoiMCIsImlzX291dGVyX3VzZXIiOnRydWUsInJlc291cmNlX3R5cGUiOiJlbmRwb2ludCIsInJlc291cmNlX2lkcyI6WyJlcC0yMDI0MTAwODE1MjA0OC1mc2d6ZiJdfQ.BHgFj-UKeu7IGG5VL2e6iPQEMNMkQrgmM46zYmTpoNG_ySgSFJLWYzbrIABZmqVDB4Rt58j8kvoORs-RHJUz81rXUlh3BYl9-ZwbggtAU7Z1pm54_qZ00jF0jQ6r-fUSXZo2PVCLxb_clNuEh06NyaV7ullZwUCyLKx3vhCsxPAuEvQvLc_qDBx-IYNT-UApVADaqMs-OyewoxahqQ7RvaHFF14R6ihmg9H0uvl00_JiGThJveszKvy_T-Qk6iPOy-EDI2pwJxdHMZ7By0bWK5EfZoK2hOvOSRD0BNTYnvrTfI0l2JgS0nwCVEPR4KSTXxU_oVVtuUSZp1UHvvkhvA + self.__token = 'sk-yCx0lZUmfGx0ECEQAp8jTnAisHwUIokoDXN7XNBuvMILxWnN' + self._dou_bao = KimiHttp(self.__token) + + def _request(self, question): + self._dou_bao.request(question, self._split_handle, self._on_callback) + + def _on_close(self): + if self._dou_bao is not None: + self._dou_bao.close() + logger.info('AsyncArk close') + + def on_clear_cache(self, *args, **kwargs): + super().on_clear_cache(*args, **kwargs) + if self._dou_bao is not None: + self._dou_bao.aclose() + logger.info('DouBao clear_cache') diff --git a/nlp/nlp_split.py b/nlp/nlp_split.py index 75f6327..d938088 100644 --- a/nlp/nlp_split.py +++ b/nlp/nlp_split.py @@ -15,6 +15,7 @@ class PunctuationSplit(NLPSplit): def handle(self, message: str): message = message.replace('*', '') + message = message.replace('#', '') match = re.search(self._pattern, message) if match: pos = match.start() + 1 diff --git a/render/__init__.py b/render/__init__.py deleted file mode 100644 index 8d7f244..0000000 --- a/render/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -#encoding = utf8 - -from .voice_render import VoiceRender -from .video_render 
import VideoRender -from .play_clock import PlayClock diff --git a/render/base_render.py b/render/base_render.py deleted file mode 100644 index 3bccaf5..0000000 --- a/render/base_render.py +++ /dev/null @@ -1,25 +0,0 @@ -#encoding = utf8 -import logging -import time -from abc import ABC, abstractmethod -from queue import Queue -from threading import Event, Thread - -from utils import SyncQueue - -logger = logging.getLogger(__name__) - - -class BaseRender(ABC): - def __init__(self, play_clock, context, type_): - self._play_clock = play_clock - self._context = context - # self._queue = SyncQueue(context.batch_size, f'{type_}RenderQueue') - # self._exit_event = Event() - # self._thread = Thread(target=self._on_run, name=thread_name) - # self._exit_event.set() - # self._thread.start() - - @abstractmethod - def render(self, frame, ps): - pass diff --git a/render/play_clock.py b/render/play_clock.py deleted file mode 100644 index 870aee4..0000000 --- a/render/play_clock.py +++ /dev/null @@ -1,37 +0,0 @@ -#encoding = utf8 -import time - - -class PlayClock: - def __init__(self): - self._start = time.time() - self._current_time = 0 - self._display_time = self._start - self._audio_diff_threshold = 0.01 - - @property - def start_time(self): - return self._start - - @property - def current_time(self): - return self._current_time - - @current_time.setter - def current_time(self, v): - self._current_time = v - - @property - def audio_diff_threshold(self): - return self._audio_diff_threshold - - @property - def display_time(self): - return self._display_time - - def update_display_time(self): - self._display_time = time.time() - - def clock_time(self): - elapsed = time.time() - self._display_time - return self.current_time + elapsed diff --git a/render/video_render.py b/render/video_render.py deleted file mode 100644 index 91426bd..0000000 --- a/render/video_render.py +++ /dev/null @@ -1,23 +0,0 @@ -#encoding = utf8 -import copy -import logging -import time - -import cv2 -import numpy as np - -from .base_render import BaseRender - - -class VideoRender(BaseRender): - def __init__(self, play_clock, context, human_render): - super().__init__(play_clock, context, 'Video') - self._human_render = human_render - self.index = 0 - - def render(self, frame, ps): - if self._human_render is not None: - self._human_render.put_image(frame) - - # image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA) - diff --git a/render/voice_render.py b/render/voice_render.py deleted file mode 100644 index eff4f19..0000000 --- a/render/voice_render.py +++ /dev/null @@ -1,39 +0,0 @@ -#encoding = utf8 -import logging -import time -from queue import Empty - -import numpy as np - -from audio_render import AudioRender -from human.message_type import MessageType -from .base_render import BaseRender - -logger = logging.getLogger(__name__) - - -class VoiceRender(BaseRender): - def __init__(self, play_clock, context): - self._audio_render = AudioRender() - super().__init__(play_clock, context, 'Voice') - self._current_text = '' - - def render(self, frame, ps): - self._play_clock.update_display_time() - self._play_clock.current_time = ps - - for audio_frame in frame: - frame, type_ = audio_frame - chunk, txt = frame - if txt != self._current_text: - self._current_text = txt - logging.info(f'VoiceRender: {txt}') - chunk = (chunk * 32767).astype(np.int16) - - if self._audio_render is not None: - try: - chunk_len = int(chunk.shape[0] * 2) - # print('audio frame:', frame.shape, chunk_len) - self._audio_render.write(chunk.tobytes(), chunk_len) - except 
Exception as e: - logging.error(f'Error writing audio frame: {e}') diff --git a/tts/tts_audio_handle.py b/tts/tts_audio_handle.py index 13ba8e4..1860aee 100644 --- a/tts/tts_audio_handle.py +++ b/tts/tts_audio_handle.py @@ -20,18 +20,13 @@ class TTSAudioHandle(AudioHandler): self._index = -1 EventBus().register('stop', self._on_stop) - EventBus().register('clear_cache', self.on_clear_cache) def __del__(self): EventBus().unregister('stop', self._on_stop) - EventBus().unregister('clear_cache', self.on_clear_cache) def _on_stop(self, *args, **kwargs): self.stop() - def on_clear_cache(self, *args, **kwargs): - self._index = -1 - @property def sample_rate(self): return self._sample_rate @@ -51,7 +46,8 @@ class TTSAudioHandle(AudioHandler): pass def pause_talk(self): - pass + self._index = -1 + super().pause_talk() class TTSAudioSplitHandle(TTSAudioHandle): @@ -76,7 +72,7 @@ class TTSAudioSplitHandle(TTSAudioHandle): if chunks is not None: for chunk in chunks: self.on_next_handle((chunk, txt), 0) - time.sleep(0.01) # Sleep briefly to prevent busy-waiting + time.sleep(0.001) # Sleep briefly to prevent busy-waiting def on_handle(self, stream, index): if not self._is_running: @@ -103,8 +99,8 @@ class TTSAudioSplitHandle(TTSAudioHandle): self._is_running = False self._thread.join() - def on_clear_cache(self, *args, **kwargs): - super().on_clear_cache() + def pause_talk(self): + super().pause_talk() with self._lock: self._current = 0 self._priority_queue.clear() diff --git a/tts/tts_base.py b/tts/tts_base.py index 62af6c7..5ee0bab 100644 --- a/tts/tts_base.py +++ b/tts/tts_base.py @@ -16,18 +16,18 @@ class TTSBase(NLPCallback): self._message_queue = AsyncTaskQueue('TTSBaseQueue', 5) self._is_running = True EventBus().register('stop', self.on_stop) - EventBus().register('clear_cache', self.on_clear_cache) def __del__(self): EventBus().unregister('stop', self.on_stop) - EventBus().unregister('clear_cache', self.on_clear_cache) def on_stop(self, *args, **kwargs): self.stop() - def on_clear_cache(self, *args, **kwargs): - logger.info('TTSBase clear_cache') - self._message_queue.clear() + def on_clear(self): + self.pause_talk() + + if self._handle is not None: + self._handle.pause_talk() @property def handle(self): diff --git a/tts/tts_edge_http.py b/tts/tts_edge_http.py index c3ad86a..6d669d1 100644 --- a/tts/tts_edge_http.py +++ b/tts/tts_edge_http.py @@ -97,8 +97,8 @@ class TTSEdgeHttp(TTSBase): # if self._byte_stream is not None and not self._byte_stream.closed: # self._byte_stream.close() - def on_clear_cache(self, *args, **kwargs): + def on_clear(self): logger.info('TTSEdgeHttp clear_cache') - super().on_clear_cache(*args, **kwargs) for response in self._response_list: response.close() + super().on_clear() diff --git a/ui/__init__.py b/ui/__init__.py index 865208a..d16e6ce 100644 --- a/ui/__init__.py +++ b/ui/__init__.py @@ -1,3 +1,3 @@ #encoding = utf8 - +from .ipc_render import IpcRender diff --git a/ui/ipc_render.py b/ui/ipc_render.py new file mode 100644 index 0000000..6b42433 --- /dev/null +++ b/ui/ipc_render.py @@ -0,0 +1,79 @@ +#encoding = utf8 + +import os +import logging +import time + +import numpy as np + +from human import HumanRender, RenderStatus +from ipc import IPCUtil +from utils import render_image + +logger = logging.getLogger(__name__) +current_file_path = os.path.dirname(os.path.abspath(__file__)) + + +class IpcRender(HumanRender): + def __init__(self, context): + super().__init__(context, None) + self._ipc = IPCUtil('human_product', 'human_render') + self._current_text = '' + + 
def _send_image(self, image): + height, width, channels = image.shape + t = time.perf_counter() + width_bytes = width.to_bytes(4, byteorder='little') + height_bytes = height.to_bytes(4, byteorder='little') + bit_depth_bytes = channels.to_bytes(4, byteorder='little') + + img_bytes = image.tobytes() + identifier = b'\x01' + data = identifier + width_bytes + height_bytes + bit_depth_bytes + img_bytes + self._ipc.send_binary(data, len(data)) + + def _send_voice(self, voice): + voice_identifier = b'\x02' + data = voice_identifier + for audio_frame in voice: + frame, type_ = audio_frame + chunk, txt = frame + if txt != self._current_text: + self._current_text = txt + logging.info(f'VoiceRender: {txt}') + chunk = (chunk * 32767).astype(np.int16) + voice_bytes = chunk.tobytes() + data = data + voice_bytes + + self._ipc.send_binary(data, len(data)) + + def _on_reader_callback(self, data_str, size): + data_str = data_str.decode('utf-8') + # print(f'on_reader_callback: {data_str}, size:{size}') + if 'quit' == data_str: + self._context.stop() + elif 'heartbeat' == data_str: + pass + elif 'full' == data_str: + if self._render_status != RenderStatus.E_Full: + # logger.info(f'change to E_Full status') + self._render_status = RenderStatus.E_Full + elif 'empty' == data_str: + if self._render_status != RenderStatus.E_Empty: + # logger.info(f'change to E_Full status') + self._render_status = RenderStatus.E_Empty + elif 'normal' == data_str: + if self._render_status != RenderStatus.E_Normal: + # logger.info(f'change to E_Normal status') + self._render_status = RenderStatus.E_Normal + + def run(self): + self._ipc.set_reader_callback(self._on_reader_callback) + logger.info(f'ipc listen:{self._ipc.listen()}') + super().run() + + def _render(self, video_frame, voice_frame): + image = render_image(self._context, video_frame) + self._send_image(image) + self._send_voice(voice_frame) + diff --git a/ui/pygame_ui.py b/ui/pygame_ui.py index 134b580..02bb31b 100644 --- a/ui/pygame_ui.py +++ b/ui/pygame_ui.py @@ -1,70 +1,51 @@ #encoding = utf8 -import copy import logging import os -import time from queue import Queue import cv2 -import numpy as np import pygame from pygame.locals import * from human import HumanContext -from utils import config_logging +from ipc import IPCUtil +from utils import config_logging, render_image logger = logging.getLogger(__name__) current_file_path = os.path.dirname(os.path.abspath(__file__)) - -def img_warp_back_inv_m(img, img_to, inv_m): - h_up, w_up, c = img_to.shape - mask = np.ones_like(img).astype(np.float32) - inv_mask = cv2.warpAffine(mask, inv_m, (w_up, h_up)) - inv_img = cv2.warpAffine(img, inv_m, (w_up, h_up)) - mask_indices = inv_mask == 1 - if 4 == c: - img_to[:, :, :3][mask_indices] = inv_img[mask_indices] - else: - img_to[inv_mask == 1] = inv_img[inv_mask == 1] - return img_to +ipc = IPCUtil('ipc_sender', 'ipc_sender') -def render_image(context, frame): - res_frame, idx, type_ = frame +def send_image(identifier, image): + height, width, channels = image.shape - if type_ == 0: - combine_frame = context.frame_list_cycle[idx] - else: - bbox = context.coord_list_cycle[idx] - combine_frame = copy.deepcopy(context.frame_list_cycle[idx]) - af = context.align_frames[idx] - inv_m = context.inv_m_frames[idx] - y1, y2, x1, x2 = bbox - try: - t = time.perf_counter() - res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1)) - af[y1:y2, x1:x2] = res_frame - combine_frame = img_warp_back_inv_m(af, combine_frame, inv_m) - except Exception as e: - logging.error(f'resize 
error:{e}') - return + width_bytes = width.to_bytes(4, byteorder='little') + height_bytes = height.to_bytes(4, byteorder='little') + bit_depth_bytes = channels.to_bytes(4, byteorder='little') - image = combine_frame - return image + img_bytes = image.tobytes() + data = identifier + width_bytes + height_bytes + bit_depth_bytes + img_bytes + ipc.send_binary(data, len(data)) + + +def cal_box(inv_m, p): + x = inv_m[0][0] * p[0] + inv_m[0][1] * p[1] + inv_m[0][2] + y = inv_m[1][0] * p[0] + inv_m[1][1] * p[1] + inv_m[1][2] + return x, y class PyGameUI: def __init__(self): self._human_context = None self._queue = None - self.screen_ = pygame.display.set_mode((1920, 1080), HWSURFACE | DOUBLEBUF | RESIZABLE) + self.screen_ = pygame.display.set_mode((920, 860), HWSURFACE | DOUBLEBUF | RESIZABLE) self.clock = pygame.time.Clock() background = os.path.join(current_file_path, '..', 'data', 'background', 'background.jpg') logger.info(f'background: {background}') self._background = pygame.image.load(background).convert() - self.background_display_ = pygame.transform.scale(self._background, (1920, 1080)) + self.background_display_ = pygame.transform.scale(self._background, (920, 860)) self._human_image = None self.running = True @@ -87,7 +68,7 @@ class PyGameUI: self.screen_.blit(self.background_display_, (0, 0)) self._update_human() if self._human_image is not None: - self.screen_.blit(self._human_image, (760, -300)) + self.screen_.blit(self._human_image, (0, -300)) fps = self.clock.get_fps() pygame.display.set_caption('fps:{:.2f}'.format(fps)) diff --git a/utils/__init__.py b/utils/__init__.py index 2f3ca11..c3c95a3 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -5,4 +5,5 @@ from .sync_queue import SyncQueue from .utils import mirror_index, load_model, get_device, load_avatar, config_logging from .utils import read_image, object_stop from .utils import load_avatar_from_processed, load_avatar_from_256_processed +from .utils import render_image from .audio_utils import melspectrogram, save_wav diff --git a/utils/utils.py b/utils/utils.py index eb307ba..23b117e 100644 --- a/utils/utils.py +++ b/utils/utils.py @@ -1,4 +1,5 @@ #encoding = utf8 +import copy import glob import logging import os @@ -274,3 +275,38 @@ def config_logging(file_name: str, console_level: int = logging.INFO, file_level def object_stop(obj): if obj is not None: obj.stop() + + +def img_warp_back_inv_m(img, img_to, inv_m): + h_up, w_up, c = img_to.shape + mask = np.ones_like(img).astype(np.float32) + inv_mask = cv2.warpAffine(mask, inv_m, (w_up, h_up)) + inv_img = cv2.warpAffine(img, inv_m, (w_up, h_up)) + mask_indices = inv_mask == 1 + if 4 == c: + img_to[:, :, :3][mask_indices] = inv_img[mask_indices] + else: + img_to[inv_mask == 1] = inv_img[inv_mask == 1] + return img_to + + +def render_image(context, frame): + res_frame, idx, type_ = frame + + if type_ == 0: + combine_frame = context.frame_list_cycle[idx] + else: + bbox = context.coord_list_cycle[idx] + combine_frame = copy.deepcopy(context.frame_list_cycle[idx]) + af = context.align_frames[idx] + inv_m = context.inv_m_frames[idx] + y1, y2, x1, x2 = bbox + try: + res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1)) + af[y1:y2, x1:x2] = res_frame + combine_frame = img_warp_back_inv_m(af, combine_frame, inv_m) + except Exception as e: + logging.error(f'resize error:{e}') + return None + + return combine_frame
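
Every frame that IpcRender pushes through ipc.dll starts with a one-byte identifier followed by a little-endian header. The sketch below shows how a receiving process could decode those frames, mirroring the layout written by _send_image and _send_voice in ui/ipc_render.py; parse_frame is an illustrative helper, not part of this patch:

    import numpy as np

    def parse_frame(data: bytes):
        # 0x01 = video frame: width, height and channel count as 4-byte
        # little-endian integers, then the raw uint8 pixels.
        if data[0] == 0x01:
            width = int.from_bytes(data[1:5], 'little')
            height = int.from_bytes(data[5:9], 'little')
            channels = int.from_bytes(data[9:13], 'little')
            pixels = np.frombuffer(data, dtype=np.uint8, offset=13)
            return 'video', pixels.reshape((height, width, channels))
        # 0x02 = voice frame: concatenated int16 PCM chunks (float samples
        # scaled by 32767 on the sending side).
        if data[0] == 0x02:
            return 'voice', np.frombuffer(data, dtype=np.int16, offset=1)
        raise ValueError(f'unknown frame identifier: {data[0]}')
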
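The reworked HumanRender.run (human/human_render.py above) also changes the pacing model: instead of a fixed 0.04 s sleep, each iteration sleeps only the unused remainder of its 1/fps slot and backs off entirely while the receiver reports E_Full. The same timing pattern in isolation, with made-up names, as a reference for that loop:

    import time

    def paced_loop(step, fps=50, should_exit=lambda: False):
        # Run step() once per 1/fps slot, sleeping only the time the step
        # itself did not consume, so slow steps are not penalised twice.
        delay = 1.0 / fps
        while not should_exit():
            t = time.perf_counter()
            step()
            remainder = delay - (time.perf_counter() - t)
            if remainder > 0:
                time.sleep(remainder)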