From a6953da2a29316d03a2a677d54863066e689a6a6 Mon Sep 17 00:00:00 2001
From: Aleksandr Movchan
Date: Thu, 14 Nov 2024 10:51:05 +0000
Subject: [PATCH 1/3] Add health check mechanism to WhisperDeployment

---
 aana/deployments/whisper_deployment.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/aana/deployments/whisper_deployment.py b/aana/deployments/whisper_deployment.py
index d0251d58..5e2c1c22 100644
--- a/aana/deployments/whisper_deployment.py
+++ b/aana/deployments/whisper_deployment.py
@@ -161,6 +161,12 @@ async def apply_config(self, config: dict[str, Any]):
         self.batched_model = BatchedInferencePipeline(
             model=self.model,
         )
+        self.healthy = True
+
+    async def check_health(self):
+        """Check the health of the deployment."""
+        if not self.healthy:
+            raise RuntimeError(f"Whisper model {self.model_name} is unhealthy.")  # noqa: TRY003
 
     async def transcribe(
         self, audio: Audio, params: WhisperParams | None = None
@@ -230,6 +236,9 @@ async def transcribe_stream(
             segments, info = self.model.transcribe(
                 audio_array, **params.model_dump()
             )
+        except torch.OutOfMemoryError as e:
+            self.healthy = False
+            raise InferenceException(self.model_name) from e
         except Exception as e:
             raise InferenceException(self.model_name) from e
 
@@ -326,6 +335,9 @@ async def transcribe_in_chunks(
                 batch_size=batch_size,
                 **params.model_dump(),
             )
+        except torch.OutOfMemoryError as e:
+            self.healthy = False
+            raise InferenceException(self.model_name) from e
         except Exception as e:
             raise InferenceException(self.model_name) from e
         asr_transcription_info = AsrTranscriptionInfo.from_whisper(info)

From eb8403f947cb1ef55d2dacb0ef0bbaee576ceaab Mon Sep 17 00:00:00 2001
From: Aleksandr Movchan
Date: Thu, 14 Nov 2024 11:20:27 +0000
Subject: [PATCH 2/3] Move segment loop under try-except in WhisperDeployment
 to catch inference errors.
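faster-whisper's transcribe() returns a lazy generator: no inference runs when transcribe() is called, and errors only surface once the segments are iterated. With the segment loop outside the try-except, a mid-stream failure therefore escaped as a raw exception instead of being wrapped in InferenceException. A minimal, self-contained illustration of the failure mode; the generator below is hypothetical, standing in for model.transcribe():

```python
# Hypothetical lazy generator standing in for faster-whisper's
# model.transcribe(): nothing executes until iteration starts.
def lazy_segments():
    yield "segment 1"
    raise RuntimeError("inference failed mid-stream")

segments = lazy_segments()  # no error raised here; the body has not run yet

try:
    for segment in segments:  # the failure only surfaces during iteration,
        print(segment)        # so the loop must live inside the try block
except RuntimeError as e:
    print(f"caught during iteration: {e}")
```

Moving the loop under the try block routes these mid-stream failures through the same InferenceException path as call-time failures.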
---
 aana/deployments/whisper_deployment.py | 55 ++++++++++----------------
 1 file changed, 21 insertions(+), 34 deletions(-)

diff --git a/aana/deployments/whisper_deployment.py b/aana/deployments/whisper_deployment.py
index 5e2c1c22..df32466d 100644
--- a/aana/deployments/whisper_deployment.py
+++ b/aana/deployments/whisper_deployment.py
@@ -161,12 +161,6 @@ async def apply_config(self, config: dict[str, Any]):
         self.batched_model = BatchedInferencePipeline(
             model=self.model,
         )
-        self.healthy = True
-
-    async def check_health(self):
-        """Check the health of the deployment."""
-        if not self.healthy:
-            raise RuntimeError(f"Whisper model {self.model_name} is unhealthy.")  # noqa: TRY003
 
     async def transcribe(
         self, audio: Audio, params: WhisperParams | None = None
@@ -236,24 +230,20 @@ async def transcribe_stream(
             segments, info = self.model.transcribe(
                 audio_array, **params.model_dump()
             )
-        except torch.OutOfMemoryError as e:
-            self.healthy = False
-            raise InferenceException(self.model_name) from e
+            asr_transcription_info = AsrTranscriptionInfo.from_whisper(info)
+            for segment in segments:
+                await asyncio.sleep(0)
+                asr_segments = [AsrSegment.from_whisper(segment)]
+                asr_transcription = AsrTranscription(text=segment.text)
+
+                yield WhisperOutput(
+                    segments=asr_segments,
+                    transcription_info=asr_transcription_info,
+                    transcription=asr_transcription,
+                )
         except Exception as e:
             raise InferenceException(self.model_name) from e
-        asr_transcription_info = AsrTranscriptionInfo.from_whisper(info)
-        for segment in segments:
-            await asyncio.sleep(0)
-            asr_segments = [AsrSegment.from_whisper(segment)]
-            asr_transcription = AsrTranscription(text=segment.text)
-
-            yield WhisperOutput(
-                segments=asr_segments,
-                transcription_info=asr_transcription_info,
-                transcription=asr_transcription,
-            )
-
 
     async def transcribe_batch(
         self, audio_batch: list[Audio], params: WhisperParams | None = None
     ) -> WhisperBatchOutput:
@@ -335,18 +325,15 @@ async def transcribe_in_chunks(
                 batch_size=batch_size,
                 **params.model_dump(),
             )
-        except torch.OutOfMemoryError as e:
-            self.healthy = False
-            raise InferenceException(self.model_name) from e
+            asr_transcription_info = AsrTranscriptionInfo.from_whisper(info)
+            for segment in segments:
+                await asyncio.sleep(0)
+                asr_segments = [AsrSegment.from_whisper(segment)]
+                asr_transcription = AsrTranscription(text=segment.text)
+                yield WhisperOutput(
+                    segments=asr_segments,
+                    transcription_info=asr_transcription_info,
+                    transcription=asr_transcription,
+                )
         except Exception as e:
             raise InferenceException(self.model_name) from e

From e8126a622ed04228de422744bf40547d959fad7a Mon Sep 17 00:00:00 2001
From: Aleksandr Movchan
Date: Fri, 15 Nov 2024 13:00:43 +0000
Subject: [PATCH 3/3] Reduce default batch size from 16 to 4 in
 WhisperDeployment.
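Smaller batches lower peak GPU memory during batched inference at some cost in throughput, presumably reducing the risk of out-of-memory errors like the torch.OutOfMemoryError patch 1 handled. Callers with memory headroom can still override the default per call. A hedged usage sketch, assuming the Aana SDK's AanaDeploymentHandle pattern; the deployment name "asr_deployment" is an illustration, not part of this patch:

```python
from aana.deployments.aana_deployment_handle import AanaDeploymentHandle

async def transcribe_with_headroom(audio):
    # "asr_deployment" is a hypothetical deployment name.
    handle = await AanaDeploymentHandle.create("asr_deployment")
    # batch_size now defaults to 4; pass 16 explicitly to restore the old
    # throughput when the GPU has memory to spare.
    async for whisper_output in handle.transcribe_in_chunks(
        audio=audio, batch_size=16
    ):
        ...  # each WhisperOutput carries segments, transcription_info, transcription
```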
---
 aana/deployments/whisper_deployment.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/aana/deployments/whisper_deployment.py b/aana/deployments/whisper_deployment.py
index df32466d..8f8eaa9c 100644
--- a/aana/deployments/whisper_deployment.py
+++ b/aana/deployments/whisper_deployment.py
@@ -283,7 +283,7 @@ async def transcribe_in_chunks(
         self,
         audio: Audio,
         vad_segments: list[VadSegment] | None = None,
-        batch_size: int = 16,
+        batch_size: int = 4,
         params: BatchedWhisperParams | None = None,
     ) -> AsyncGenerator[WhisperOutput, None]:
         """Transcribe a single audio by segmenting it into chunks (4x faster) in streaming mode.
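For reference, the health-check mechanism that patch 1 introduced (and patch 2 removed again in favor of catching errors inside the segment loop) follows Ray Serve's replica health-check protocol: Serve periodically calls a deployment's check_health() and restarts the replica if the call raises. A minimal sketch of that pattern detached from the Whisper code; the deployment options are standard Ray Serve parameters, while the class and helper names are hypothetical:

```python
from ray import serve

@serve.deployment(health_check_period_s=10, health_check_timeout_s=30)
class FlakyModelDeployment:  # hypothetical name
    def __init__(self):
        self.healthy = True  # assume healthy until inference says otherwise

    async def check_health(self):
        """Raise to tell Ray Serve this replica should be restarted."""
        if not self.healthy:
            raise RuntimeError("Model replica is unhealthy.")

    async def __call__(self, request) -> str:
        try:
            return self._run_inference(request)
        except RuntimeError:
            # e.g. a CUDA out-of-memory error that leaves the device in a bad
            # state: mark the replica unhealthy so the controller replaces it.
            self.healthy = False
            raise

    def _run_inference(self, request) -> str:
        return "ok"  # placeholder for the real model call
```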