modify human image

jiegeaiai 2024-11-02 21:14:54 +08:00
parent db61dc5329
commit d5db3a3020
7 changed files with 30 additions and 21 deletions

View File

Binary image changed: 114 KiB before, 114 KiB after (dimensions not captured).

View File

Binary image changed: 258 KiB before, 258 KiB after (dimensions not captured).

View File

@@ -103,7 +103,7 @@ class AudioInferenceHandler(AudioHandler):
             img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
             mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
-            print('img_batch:', img_batch.shape, 'mel_batch:', mel_batch.shape)
+            # print('img_batch:', img_batch.shape, 'mel_batch:', mel_batch.shape)
             with torch.no_grad():
                 pred = model(mel_batch, img_batch)
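Note on the transpose in this hunk: (0, 3, 1, 2) reorders NumPy's NHWC layout (batch, height, width, channels) into the NCHW layout PyTorch convolutions expect. A minimal sketch with dummy shapes (the batch size and channel count here are illustrative, not taken from this repo):

import numpy as np
import torch

img_batch = np.zeros((5, 96, 96, 6), dtype=np.float32)           # NHWC
img_t = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2)))  # -> NCHW
print(img_t.shape)                                                # torch.Size([5, 6, 96, 96])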

View File

@@ -8,7 +8,7 @@ from .audio_mal_handler import AudioMalHandler
 from .human_render import HumanRender
 from nlp import PunctuationSplit, DouBao
 from tts import TTSEdge, TTSAudioSplitHandle, TTSEdgeHttp
-from utils import load_avatar, get_device
+from utils import load_avatar, get_device, object_stop
 
 logger = logging.getLogger(__name__)
 current_file_path = os.path.dirname(os.path.abspath(__file__))
@@ -24,6 +24,14 @@ class HumanContext:
         self._stride_right_size = 10
         self._render_batch = 5
+        self._asr = None
+        self._nlp = None
+        self._tts = None
+        self._tts_handle = None
+        self._mal_handler = None
+        self._infer_handler = None
+        self._render_handler = None
 
         self._device = get_device()
         print(f'device:{self._device}')
         base_path = os.path.join(current_file_path, '..', 'face')
@@ -36,23 +44,16 @@ class HumanContext:
         logging.info(f'face images length: {face_images_length}')
         print(f'face images length: {face_images_length}')
-        self._asr = None
-        self._nlp = None
-        self._tts = None
-        self._tts_handle = None
-        self._mal_handler = None
-        self._infer_handler = None
-        self._render_handler = None
 
     def __del__(self):
         print(f'HumanContext: __del__')
-        self._asr.stop()
-        self._nlp.stop()
-        self._tts.stop()
-        self._tts_handle.stop()
-        self._mal_handler.stop()
-        self._infer_handler.stop()
-        self._render_handler.stop()
+        object_stop(self._asr)
+        object_stop(self._nlp)
+        object_stop(self._tts)
+        object_stop(self._tts_handle)
+        object_stop(self._mal_handler)
+        object_stop(self._infer_handler)
+        object_stop(self._render_handler)
 
     @property
     def fps(self):
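Why the None assignments move to the top of __init__ (a hedged reading of this hunk, taken together with the object_stop helper added at the bottom of this commit): Python still runs __del__ on a partially constructed object when a later setup step raises, and object_stop(None) then degrades to a no-op instead of an AttributeError on a handler that was never built. A minimal standalone sketch, with a stub in place of the real handlers:

def object_stop(obj):
    if obj is not None:
        obj.stop()

class StubHandler:                      # stand-in for the real ASR/TTS handlers
    def stop(self):
        print('stopped')

class Fragile:
    def __init__(self, fail=False):
        self._handler = None            # assigned before anything can raise
        if fail:
            raise RuntimeError('setup failed')
        self._handler = StubHandler()

    def __del__(self):
        object_stop(self._handler)      # safe no-op when setup failed early

Fragile()                               # prints 'stopped' on collection
try:
    Fragile(fail=True)                  # __del__ still runs, harmlessly
except RuntimeError:
    pass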

View File

@@ -13,7 +13,7 @@ from human.message_type import MessageType
 class VideoRender(BaseRender):
     def __init__(self, play_clock, context, human_render):
-        super().__init__(play_clock, context, 'Video', 0.03, "VideoRenderThread")
+        super().__init__(play_clock, context, 'Video', 0.038, "VideoRenderThread")
         self._human_render = human_render
         self._diff_avg_count = 0

@@ -31,7 +31,7 @@ class VideoRender(BaseRender):
         clock_time = self._play_clock.clock_time()
         time_difference = clock_time - ps
         if abs(time_difference) > self._play_clock.audio_diff_threshold:
-            if self._diff_avg_count < 5:
+            if self._diff_avg_count < 3:
                 self._diff_avg_count += 1
             else:
                 if time_difference < -self._play_clock.audio_diff_threshold:

@@ -65,7 +65,7 @@ class VideoRender(BaseRender):
             combine_frame[y1:y2, x1:x2] = res_frame
             image = combine_frame
-            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+            # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
             if self._human_render is not None:
                 self._human_render.put_image(image)
             return
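Two tunings in this file are related: the per-frame delay passed to BaseRender rises from 0.03 s to 0.038 s (roughly 26 fps instead of about 33 fps), and the sync check now tolerates only 3 consecutive out-of-threshold frames before correcting, down from 5. The commented-out cvtColor call likely follows from the read_images change below, which now loads frames via PIL in RGB(A) order, so no BGR-to-RGB conversion is needed before put_image. A hedged sketch of the debounce pattern this hunk tunes, using simplified stand-in names rather than the repo's actual PlayClock API:

class DriftGuard:
    # Illustrative debounced A/V drift check, not the repo's class.
    def __init__(self, threshold=0.05, patience=3):
        self.threshold = threshold      # seconds of tolerated drift
        self.patience = patience        # frames to absorb before acting
        self._count = 0

    def check(self, frame_ts, clock_ts):
        diff = clock_ts - frame_ts
        if abs(diff) <= self.threshold:
            self._count = 0
            return 'in_sync'
        if self._count < self.patience:
            self._count += 1
            return 'wait'               # absorb transient spikes
        self._count = 0
        return 'sleep' if diff < 0 else 'drop'  # frame early vs. late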

View File

@@ -3,5 +3,5 @@
 from .async_task_queue import AsyncTaskQueue
 from .sync_queue import SyncQueue
 from .utils import mirror_index, load_model, get_device, load_avatar, config_logging
-from .utils import read_image
+from .utils import read_image, object_stop
 from .audio_utils import melspectrogram, save_wav
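With this re-export in place, callers can import the helper from the package root rather than from utils.utils, i.e. `from utils import object_stop`, which is the form the HumanContext change above relies on.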

View File

@@ -34,7 +34,10 @@ def read_images(img_list):
     print('reading images...')
     for img_path in tqdm(img_list):
         print(f'read image path:{img_path}')
-        frame = cv2.imread(img_path)
+        # frame = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
+        frame = Image.open(img_path)
+        frame = frame.convert("RGBA")
+        frame = np.array(frame)
         frames.append(frame)
     return frames
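The switch from cv2.imread to PIL in this hunk changes both channel order and channel count: OpenCV decodes to BGR and drops alpha by default, while the new path yields an RGBA NumPy array. A standalone sketch of the new loading behavior (the file path is hypothetical; Image is assumed to be PIL.Image, matching the calls in the hunk):

from PIL import Image
import numpy as np

def read_rgba(img_path):
    frame = Image.open(img_path)        # PIL decodes as RGB/RGBA, not BGR
    frame = frame.convert('RGBA')       # force 4 channels, opaque alpha if missing
    return np.array(frame)              # uint8 array of shape (H, W, 4)

# arr = read_rgba('face/0001.png')      # hypothetical path
# print(arr.shape, arr.dtype)           # e.g. (512, 512, 4) uint8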
@@ -201,3 +204,8 @@ def config_logging(file_name: str, console_level: int = logging.INFO, file_level
         level=min(console_level, file_level),
         handlers=[file_handler, console_handler],
     )
+
+
+def object_stop(obj):
+    if obj is not None:
+        obj.stop()