modiy crash
This commit is contained in:
parent
273dbfb0ec
commit
cb6dd54497
@ -21,7 +21,7 @@ class AudioInferenceHandler(AudioHandler):
|
||||
super().__init__(context, handler)
|
||||
|
||||
self._mal_queue = Queue()
|
||||
self._audio_queue = SyncQueue(context.render_batch)
|
||||
self._audio_queue = SyncQueue(context.batch_size * 2)
|
||||
|
||||
self._exit_event = Event()
|
||||
self._run_thread = Thread(target=self.__on_run)
|
||||
@ -34,6 +34,7 @@ class AudioInferenceHandler(AudioHandler):
|
||||
self._mal_queue.put(stream)
|
||||
elif type_ == 0:
|
||||
self._audio_queue.put(stream)
|
||||
print('AudioInferenceHandler on_handle', type_)
|
||||
|
||||
def on_message(self, message):
|
||||
super().on_message(message)
|
||||
@ -60,7 +61,7 @@ class AudioInferenceHandler(AudioHandler):
|
||||
start_time = time.perf_counter()
|
||||
batch_size = self._context.batch_size
|
||||
try:
|
||||
mel_batch = self._mal_queue.get(block=True, timeout=0.1)
|
||||
mel_batch = self._mal_queue.get(block=True, timeout=1)
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
@ -85,6 +86,7 @@ class AudioInferenceHandler(AudioHandler):
|
||||
face = face_list_cycle[idx]
|
||||
img_batch.append(face)
|
||||
|
||||
print('orign img_batch:', len(img_batch), 'origin mel_batch:', len(mel_batch))
|
||||
img_batch = np.asarray(img_batch)
|
||||
mel_batch = np.asarray(mel_batch)
|
||||
img_masked = img_batch.copy()
|
||||
@ -96,7 +98,7 @@ class AudioInferenceHandler(AudioHandler):
|
||||
|
||||
img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
|
||||
mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
|
||||
|
||||
print('img_batch:', img_batch.shape, 'mel_batch:', mel_batch.shape)
|
||||
with torch.no_grad():
|
||||
pred = model(mel_batch, img_batch)
|
||||
|
||||
|
@ -18,7 +18,7 @@ class AudioMalHandler(AudioHandler):
|
||||
def __init__(self, context, handler):
|
||||
super().__init__(context, handler)
|
||||
|
||||
self._queue = SyncQueue(context.render_batch)
|
||||
self._queue = SyncQueue(context.batch_size)
|
||||
self._exit_event = Event()
|
||||
self._thread = Thread(target=self._on_run)
|
||||
self._exit_event.set()
|
||||
@ -32,6 +32,7 @@ class AudioMalHandler(AudioHandler):
|
||||
super().on_message(message)
|
||||
|
||||
def on_handle(self, stream, index):
|
||||
print('AudioMalHandler on_handle', index)
|
||||
self._queue.put(stream)
|
||||
|
||||
def _on_run(self):
|
||||
@ -51,12 +52,13 @@ class AudioMalHandler(AudioHandler):
|
||||
if len(self.frames) <= self._context.stride_left_size + self._context.stride_right_size:
|
||||
return
|
||||
|
||||
print('AudioMalHandler _run_step')
|
||||
inputs = np.concatenate(self.frames) # [N * chunk]
|
||||
mel = melspectrogram(inputs)
|
||||
# print(mel.shape[0],mel.shape,len(mel[0]),len(self.frames))
|
||||
# cut off stride
|
||||
left = max(0, self._context.stride_left_size * 80 / 50)
|
||||
right = min(len(mel[0]), len(mel[0]) - self._context.stride_right_size * 80 / 50)
|
||||
# right = min(len(mel[0]), len(mel[0]) - self._context.stride_right_size * 80 / 50)
|
||||
mel_idx_multiplier = 80. * 2 / self._context.fps
|
||||
mel_step_size = 16
|
||||
i = 0
|
||||
@ -76,12 +78,15 @@ class AudioMalHandler(AudioHandler):
|
||||
|
||||
def get_audio_frame(self):
|
||||
try:
|
||||
frame = self._queue.get(block=True, timeout=0.01)
|
||||
frame = self._queue.get()
|
||||
type_ = 0
|
||||
if frame is None:
|
||||
frame = np.zeros(self.chunk, dtype=np.float32)
|
||||
type_ = 1
|
||||
except queue.Empty:
|
||||
frame = np.zeros(self.chunk, dtype=np.float32)
|
||||
type_ = 1
|
||||
|
||||
print('AudioMalHandler get_audio_frame type:', type_)
|
||||
return frame, type_
|
||||
|
||||
def stop(self):
|
||||
|
@ -16,7 +16,7 @@ class BaseRender(ABC):
|
||||
self._context = context
|
||||
self._type = type_
|
||||
self._delay = delay
|
||||
self._queue = SyncQueue(context.render_batch)
|
||||
self._queue = SyncQueue(context.batch_size)
|
||||
self._exit_event = Event()
|
||||
self._thread = Thread(target=self._on_run)
|
||||
self._exit_event.set()
|
||||
|
@ -20,7 +20,10 @@ class VideoRender(BaseRender):
|
||||
def _run_step(self):
|
||||
while self._exit_event.is_set():
|
||||
try:
|
||||
frame, ps = self._queue.get(block=True, timeout=0.02)
|
||||
value = self._queue.get()
|
||||
if value is None:
|
||||
return
|
||||
frame, ps = value
|
||||
res_frame, idx, type_ = frame
|
||||
except Empty:
|
||||
return
|
||||
@ -45,7 +48,7 @@ class VideoRender(BaseRender):
|
||||
self._diff_avg_count = 0
|
||||
|
||||
print('video render:', ps, ' ', clock_time, ' ', time_difference,
|
||||
'get face', self._queue.qsize(), self._diff_avg_count)
|
||||
'get face', self._queue.size(), self._diff_avg_count)
|
||||
|
||||
if type_ == 0:
|
||||
combine_frame = self._context.frame_list_cycle[idx]
|
||||
|
@ -19,11 +19,14 @@ class VoiceRender(BaseRender):
|
||||
super().__init__(play_clock, context, 'Voice')
|
||||
|
||||
def is_full(self):
|
||||
return self._queue.qsize() >= self._context.render_batch * 2
|
||||
return self._queue.size() >= self._context.render_batch * 2
|
||||
|
||||
def _run_step(self):
|
||||
try:
|
||||
audio_frames, ps = self._queue.get(block=True, timeout=0.01)
|
||||
value = self._queue.get()
|
||||
if value is None:
|
||||
return
|
||||
audio_frames, ps = value
|
||||
# print('voice render queue size', self._queue.qsize())
|
||||
except Empty:
|
||||
self._context.notify({'msg_id': MessageType.Video_Render_Queue_Empty})
|
||||
@ -37,9 +40,9 @@ class VoiceRender(BaseRender):
|
||||
self._is_empty = False
|
||||
|
||||
status = MessageType.Video_Render_Queue_Not_Empty
|
||||
if self._queue.qsize() < self._context.render_batch:
|
||||
if self._queue.size() < self._context.render_batch:
|
||||
status = MessageType.Video_Render_Queue_Empty
|
||||
elif self._queue.qsize() >= self._context.render_batch * 2:
|
||||
elif self._queue.size() >= self._context.render_batch * 2:
|
||||
status = MessageType.Video_Render_Queue_Full
|
||||
self._context.notify({'msg_id': status})
|
||||
|
||||
|
@ -2,12 +2,12 @@
|
||||
|
||||
import time
|
||||
|
||||
from tts import TTSEdge, TTSAudioSaveHandle
|
||||
from tts import TTSEdge, TTSAudioSaveHandle, TTSEdgeHttp
|
||||
|
||||
|
||||
def main():
|
||||
handle = TTSAudioSaveHandle()
|
||||
tts = TTSEdge(handle)
|
||||
handle = TTSAudioSaveHandle(None, None)
|
||||
tts = TTSEdgeHttp(handle)
|
||||
tts.message('你好,')
|
||||
tts.message('请问有什么可以帮到您,')
|
||||
tts.message('很高兴为您服务。')
|
||||
|
@ -12,7 +12,7 @@ logger = logging.getLogger(__name__)
|
||||
class TTSBase(NLPCallback):
|
||||
def __init__(self, handle):
|
||||
self._handle = handle
|
||||
self._message_queue = AsyncTaskQueue(5)
|
||||
self._message_queue = AsyncTaskQueue(1)
|
||||
|
||||
@property
|
||||
def handle(self):
|
||||
|
@ -21,7 +21,7 @@ class TTSEdgeHttp(TTSBase):
|
||||
logger.info(f"TTSEdge init, {voice}")
|
||||
|
||||
async def _on_request(self, txt: str):
|
||||
print('_on_request, txt')
|
||||
print('TTSEdgeHttp, _on_request, txt:', txt)
|
||||
data = {
|
||||
"model": "tts-1",
|
||||
"input": txt,
|
||||
@ -31,6 +31,7 @@ class TTSEdgeHttp(TTSBase):
|
||||
}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(self._url, json=data) as response:
|
||||
print('TTSEdgeHttp, _on_request, response:', response)
|
||||
if response.status == 200:
|
||||
stream = BytesIO(await response.read())
|
||||
return stream
|
||||
@ -43,7 +44,7 @@ class TTSEdgeHttp(TTSBase):
|
||||
try:
|
||||
stream.seek(0)
|
||||
byte_stream = self.__create_bytes_stream(stream)
|
||||
print('-------tts start push chunk')
|
||||
print('-------tts start push chunk', index)
|
||||
self._handle.on_handle(byte_stream, index)
|
||||
stream.seek(0)
|
||||
stream.truncate()
|
||||
|
@ -3,20 +3,24 @@
|
||||
import threading
|
||||
from queue import Queue
|
||||
|
||||
|
||||
'''
|
||||
class SyncQueue:
|
||||
def __init__(self, maxsize):
|
||||
self._queue = Queue(maxsize)
|
||||
# self._queue = Queue()
|
||||
self._condition = threading.Condition()
|
||||
|
||||
def put(self, item):
|
||||
# self._queue.put(item)
|
||||
with self._condition:
|
||||
while self._queue.full():
|
||||
print('put wait')
|
||||
self._condition.wait()
|
||||
self._queue.put(item)
|
||||
self._condition.notify()
|
||||
|
||||
def get(self):
|
||||
# return self._queue.get(block=True, timeout=0.01)
|
||||
with self._condition:
|
||||
while self._queue.empty():
|
||||
self._condition.wait()
|
||||
@ -25,11 +29,29 @@ class SyncQueue:
|
||||
return item
|
||||
|
||||
def clear(self):
|
||||
# self._queue.queue.clear()
|
||||
with self._condition:
|
||||
while not self._queue.empty():
|
||||
self._queue.get()
|
||||
self._queue.task_done()
|
||||
self._queue.queue.clear()
|
||||
self._condition.notify_all()
|
||||
|
||||
def size(self):
|
||||
return self._queue.qsize()
|
||||
'''
|
||||
|
||||
|
||||
class SyncQueue:
|
||||
def __init__(self, maxsize):
|
||||
self._queue = Queue()
|
||||
|
||||
def put(self, item):
|
||||
self._queue.put(item)
|
||||
|
||||
def get(self):
|
||||
return self._queue.get(block=True, timeout=0.2)
|
||||
|
||||
def clear(self):
|
||||
self._queue.queue.clear()
|
||||
|
||||
def size(self):
|
||||
return self._queue.qsize()
|
||||
|
Loading…
Reference in New Issue
Block a user