Skip to content

Commit

Permalink
Merge pull request #204 from mobiusml/whisper_health_check
Browse files Browse the repository at this point in the history
Improved Error Handling in WhisperDeployment
  • Loading branch information
movchan74 authored Nov 15, 2024
2 parents 62a5ee6 + e8126a6 commit 5024bbf
Showing 1 changed file with 22 additions and 23 deletions.
45 changes: 22 additions & 23 deletions aana/deployments/whisper_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,21 +230,20 @@ async def transcribe_stream(
segments, info = self.model.transcribe(
audio_array, **params.model_dump()
)
asr_transcription_info = AsrTranscriptionInfo.from_whisper(info)
for segment in segments:
await asyncio.sleep(0)
asr_segments = [AsrSegment.from_whisper(segment)]
asr_transcription = AsrTranscription(text=segment.text)

yield WhisperOutput(
segments=asr_segments,
transcription_info=asr_transcription_info,
transcription=asr_transcription,
)
except Exception as e:
raise InferenceException(self.model_name) from e

asr_transcription_info = AsrTranscriptionInfo.from_whisper(info)
for segment in segments:
await asyncio.sleep(0)
asr_segments = [AsrSegment.from_whisper(segment)]
asr_transcription = AsrTranscription(text=segment.text)

yield WhisperOutput(
segments=asr_segments,
transcription_info=asr_transcription_info,
transcription=asr_transcription,
)

async def transcribe_batch(
self, audio_batch: list[Audio], params: WhisperParams | None = None
) -> WhisperBatchOutput:
Expand Down Expand Up @@ -284,7 +283,7 @@ async def transcribe_in_chunks(
self,
audio: Audio,
vad_segments: list[VadSegment] | None = None,
batch_size: int = 16,
batch_size: int = 4,
params: BatchedWhisperParams | None = None,
) -> AsyncGenerator[WhisperOutput, None]:
"""Transcribe a single audio by segmenting it into chunks (4x faster) in streaming mode.
Expand Down Expand Up @@ -326,15 +325,15 @@ async def transcribe_in_chunks(
batch_size=batch_size,
**params.model_dump(),
)
asr_transcription_info = AsrTranscriptionInfo.from_whisper(info)
for segment in segments:
await asyncio.sleep(0)
asr_segments = [AsrSegment.from_whisper(segment)]
asr_transcription = AsrTranscription(text=segment.text)
yield WhisperOutput(
segments=asr_segments,
transcription_info=asr_transcription_info,
transcription=asr_transcription,
)
except Exception as e:
raise InferenceException(self.model_name) from e
asr_transcription_info = AsrTranscriptionInfo.from_whisper(info)
for segment in segments:
await asyncio.sleep(0)
asr_segments = [AsrSegment.from_whisper(segment)]
asr_transcription = AsrTranscription(text=segment.text)
yield WhisperOutput(
segments=asr_segments,
transcription_info=asr_transcription_info,
transcription=asr_transcription,
)

0 comments on commit 5024bbf

Please sign in to comment.