modify human image
This commit is contained in:
parent
db61dc5329
commit
d5db3a3020
Before Width: | Height: | Size: 114 KiB After Width: | Height: | Size: 114 KiB |
Before Width: | Height: | Size: 258 KiB After Width: | Height: | Size: 258 KiB |
@ -103,7 +103,7 @@ class AudioInferenceHandler(AudioHandler):
|
||||
|
||||
img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
|
||||
mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
|
||||
print('img_batch:', img_batch.shape, 'mel_batch:', mel_batch.shape)
|
||||
# print('img_batch:', img_batch.shape, 'mel_batch:', mel_batch.shape)
|
||||
with torch.no_grad():
|
||||
pred = model(mel_batch, img_batch)
|
||||
|
||||
|
@ -8,7 +8,7 @@ from .audio_mal_handler import AudioMalHandler
|
||||
from .human_render import HumanRender
|
||||
from nlp import PunctuationSplit, DouBao
|
||||
from tts import TTSEdge, TTSAudioSplitHandle, TTSEdgeHttp
|
||||
from utils import load_avatar, get_device
|
||||
from utils import load_avatar, get_device, object_stop
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
current_file_path = os.path.dirname(os.path.abspath(__file__))
|
||||
@ -24,6 +24,14 @@ class HumanContext:
|
||||
self._stride_right_size = 10
|
||||
self._render_batch = 5
|
||||
|
||||
self._asr = None
|
||||
self._nlp = None
|
||||
self._tts = None
|
||||
self._tts_handle = None
|
||||
self._mal_handler = None
|
||||
self._infer_handler = None
|
||||
self._render_handler = None
|
||||
|
||||
self._device = get_device()
|
||||
print(f'device:{self._device}')
|
||||
base_path = os.path.join(current_file_path, '..', 'face')
|
||||
@ -36,23 +44,16 @@ class HumanContext:
|
||||
logging.info(f'face images length: {face_images_length}')
|
||||
print(f'face images length: {face_images_length}')
|
||||
|
||||
self._asr = None
|
||||
self._nlp = None
|
||||
self._tts = None
|
||||
self._tts_handle = None
|
||||
self._mal_handler = None
|
||||
self._infer_handler = None
|
||||
self._render_handler = None
|
||||
|
||||
def __del__(self):
|
||||
print(f'HumanContext: __del__')
|
||||
self._asr.stop()
|
||||
self._nlp.stop()
|
||||
self._tts.stop()
|
||||
self._tts_handle.stop()
|
||||
self._mal_handler.stop()
|
||||
self._infer_handler.stop()
|
||||
self._render_handler.stop()
|
||||
object_stop(self._asr)
|
||||
object_stop(self._nlp)
|
||||
object_stop(self._tts)
|
||||
object_stop(self._tts_handle)
|
||||
object_stop(self._mal_handler)
|
||||
object_stop(self._infer_handler)
|
||||
object_stop(self._render_handler)
|
||||
|
||||
@property
|
||||
def fps(self):
|
||||
|
@ -13,7 +13,7 @@ from human.message_type import MessageType
|
||||
|
||||
class VideoRender(BaseRender):
|
||||
def __init__(self, play_clock, context, human_render):
|
||||
super().__init__(play_clock, context, 'Video', 0.03, "VideoRenderThread")
|
||||
super().__init__(play_clock, context, 'Video', 0.038, "VideoRenderThread")
|
||||
self._human_render = human_render
|
||||
self._diff_avg_count = 0
|
||||
|
||||
@ -31,7 +31,7 @@ class VideoRender(BaseRender):
|
||||
clock_time = self._play_clock.clock_time()
|
||||
time_difference = clock_time - ps
|
||||
if abs(time_difference) > self._play_clock.audio_diff_threshold:
|
||||
if self._diff_avg_count < 5:
|
||||
if self._diff_avg_count < 3:
|
||||
self._diff_avg_count += 1
|
||||
else:
|
||||
if time_difference < -self._play_clock.audio_diff_threshold:
|
||||
@ -65,7 +65,7 @@ class VideoRender(BaseRender):
|
||||
combine_frame[y1:y2, x1:x2] = res_frame
|
||||
|
||||
image = combine_frame
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
# image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
if self._human_render is not None:
|
||||
self._human_render.put_image(image)
|
||||
return
|
||||
|
@ -3,5 +3,5 @@
|
||||
from .async_task_queue import AsyncTaskQueue
|
||||
from .sync_queue import SyncQueue
|
||||
from .utils import mirror_index, load_model, get_device, load_avatar, config_logging
|
||||
from .utils import read_image
|
||||
from .utils import read_image, object_stop
|
||||
from .audio_utils import melspectrogram, save_wav
|
||||
|
@ -34,7 +34,10 @@ def read_images(img_list):
|
||||
print('reading images...')
|
||||
for img_path in tqdm(img_list):
|
||||
print(f'read image path:{img_path}')
|
||||
frame = cv2.imread(img_path)
|
||||
# frame = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
|
||||
frame = Image.open(img_path)
|
||||
frame = frame.convert("RGBA")
|
||||
frame = np.array(frame)
|
||||
frames.append(frame)
|
||||
return frames
|
||||
|
||||
@ -201,3 +204,8 @@ def config_logging(file_name: str, console_level: int = logging.INFO, file_level
|
||||
level=min(console_level, file_level),
|
||||
handlers=[file_handler, console_handler],
|
||||
)
|
||||
|
||||
|
||||
def object_stop(obj):
|
||||
if obj is not None:
|
||||
obj.stop()
|
||||
|
Loading…
Reference in New Issue
Block a user