Skip to content

Commit

Permalink
Changed whisper output structure
Browse files Browse the repository at this point in the history
  • Loading branch information
movchan74 committed Nov 9, 2023
1 parent dde9313 commit aba2be1
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 41 deletions.
6 changes: 5 additions & 1 deletion aana/configs/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,11 @@
name="whisper_transcribe",
path="/video/transcribe",
summary="Transcribe a video using Whisper Medium",
outputs=["video_transcriptions_whisper_medium"],
outputs=[
"video_transcriptions_whisper_medium",
"video_transcriptions_segments_whisper_medium",
"video_transcriptions_info_whisper_medium",
],
)
],
}
26 changes: 21 additions & 5 deletions aana/configs/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
It is used to generate the pipeline and the API endpoints.
"""

from aana.models.pydantic.asr_output import AsrOutputList
from aana.models.pydantic.asr_output import (
AsrSegmentsList,
AsrTranscriptionInfoList,
AsrTranscriptionList,
)
from aana.models.pydantic.captions import CaptionsList, VideoCaptionsList
from aana.models.pydantic.image_input import ImageInputList
from aana.models.pydantic.prompt import Prompt
Expand Down Expand Up @@ -281,12 +285,24 @@
}
],
"outputs": [
{
"name": "video_transcriptions_segments_whisper_medium",
"key": "segments",
"path": "video_batch.videos.[*].segments",
"data_model": AsrSegmentsList,
},
{
"name": "video_transcriptions_info_whisper_medium",
"key": "transcription_info",
"path": "video_batch.videos.[*].transcription_info",
"data_model": AsrTranscriptionInfoList,
},
{
"name": "video_transcriptions_whisper_medium",
"key": "asr_outputs",
"path": "video_batch.videos.[*].asr_output",
"data_model": AsrOutputList,
}
"key": "transcription",
"path": "video_batch.videos.[*].transcription",
"data_model": AsrTranscriptionList,
},
],
},
]
53 changes: 40 additions & 13 deletions aana/deployments/whisper_deployment.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import Enum
from typing import Any, Dict, List
from typing import Any, Dict, List, Union, cast
from faster_whisper import WhisperModel
from pydantic import BaseModel, Field
from ray import serve
Expand All @@ -8,7 +8,11 @@
from aana.deployments.base_deployment import BaseDeployment
from aana.exceptions.general import InferenceException
from aana.models.core.video import Video
from aana.models.pydantic.asr_output import AsrOutput
from aana.models.pydantic.asr_output import (
AsrSegment,
AsrTranscription,
AsrTranscriptionInfo,
)
from aana.models.pydantic.whisper_params import WhisperParams


Expand Down Expand Up @@ -111,7 +115,7 @@ async def apply_config(self, config: Dict[str, Any]):
# TODO: add audio support
async def transcribe(
self, media: Video, params: WhisperParams = WhisperParams()
) -> Dict[str, AsrOutput]:
) -> Dict[str, Union[List[AsrSegment], AsrTranscriptionInfo, AsrTranscription]]:
"""
Transcribe the media with the whisper model.
Expand All @@ -120,8 +124,11 @@ async def transcribe(
params (WhisperParams): The parameters for the whisper model.
Returns:
Dict[str, Any]: The transcription output as a dictionary:
- asr_output (AsrOutput): The ASR output.
Dict[str, Union[List[AsrSegment], AsrTranscriptionInfo, AsrTranscription]]:
The transcription output as a dictionary:
segments (List[AsrSegment]): The ASR segments.
transcription_info (AsrTranscriptionInfo): The ASR transcription info.
transcription (AsrTranscription): The ASR transcription.
Raises:
InferenceException: If the inference fails.
Expand All @@ -132,14 +139,25 @@ async def transcribe(
segments, info = self.model.transcribe(media_path, **params.dict())
except Exception as e:
raise InferenceException(self.model_name) from e
asr_output = AsrOutput.from_whisper(segments, info)

asr_segments = [AsrSegment.from_whisper(seg) for seg in segments]
asr_transcription_info = AsrTranscriptionInfo.from_whisper(info)
transcription = "".join([seg.text for seg in asr_segments])
asr_transcription = AsrTranscription(text=transcription)
return {
"asr_output": asr_output,
"segments": asr_segments,
"transcription_info": asr_transcription_info,
"transcription": asr_transcription,
}

async def transcribe_batch(
self, media: List[Video], params: WhisperParams = WhisperParams()
) -> Dict[str, List[AsrOutput]]:
) -> Dict[
str,
Union[
List[List[AsrSegment]], List[AsrTranscriptionInfo], List[AsrTranscription]
],
]:
"""
Transcribe the batch of media with the whisper model.
Expand All @@ -148,17 +166,26 @@ async def transcribe_batch(
params (WhisperParams): The parameters for the whisper model.
Returns:
Dict[str, List[AsrOutput]]: The transcription outputs for each media as a dictionary:
- asr_output (List[AsrOutput]): The ASR outputs.
Dict[str, Union[List[List[AsrSegment]], List[AsrTranscriptionInfo], List[AsrTranscription]]]:
The transcription output as a dictionary:
segments (List[List[AsrSegment]]): The ASR segments for each media.
transcription_info (List[AsrTranscriptionInfo]): The ASR transcription info for each media.
transcription (List[AsrTranscription]): The ASR transcription for each media.
Raises:
InferenceException: If the inference fails.
"""

asr_outputs = []
segments: List[List[AsrSegment]] = []
infos: List[AsrTranscriptionInfo] = []
transcriptions: List[AsrTranscription] = []
for m in media:
output = await self.transcribe(m, params)
asr_outputs.append(output["asr_output"])
segments.append(cast(List[AsrSegment], output["segments"]))
infos.append(cast(AsrTranscriptionInfo, output["transcription_info"]))
transcriptions.append(cast(AsrTranscription, output["transcription"]))
return {
"asr_outputs": asr_outputs,
"segments": segments,
"transcription_info": infos,
"transcription": transcriptions,
}
68 changes: 46 additions & 22 deletions aana/models/pydantic/asr_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,42 +121,66 @@ class Config:
}


class AsrOutput(BaseModel):
class AsrTranscription(BaseModel):
"""
Pydantic schema for ASR output.
Pydantic schema for Transcription/Translation.
"""

segments: List[AsrSegment] = Field(description="List of segments")
transcription_info: AsrTranscriptionInfo = Field(description="Transcription info")
text: str = Field(description="The text of the transcription/translation")

@classmethod
def from_whisper(
cls,
segments: List[WhisperSegment],
transcription_info: WhisperTranscriptionInfo,
) -> "AsrOutput":
"""
Convert Whisper output to ASR output.
"""
return cls(
segments=[AsrSegment.from_whisper(seg) for seg in segments],
transcription_info=AsrTranscriptionInfo.from_whisper(transcription_info),
)
class Config:
schema_extra = {
"description": "Transcription/Translation",
}


class AsrSegments(BaseListModel):
    """
    Pydantic schema for the list of ASR segments.

    Wraps a plain ``List[AsrSegment]`` so it can serve as a standalone
    model (pydantic custom-root pattern, as indicated by ``__root__``).
    Represents the per-segment transcription output for a single media item.
    """

    # Custom root: the model *is* the list of segments, not an object
    # containing one — presumably pydantic v1, given __root__/schema_extra.
    __root__: List[AsrSegment]

    class Config:
        # Human-readable description surfaced in the generated OpenAPI schema.
        schema_extra = {
            "description": "List of ASR segments",
        }


class AsrSegmentsList(BaseListModel):
    """
    Pydantic schema for the list of lists of ASR segments.

    Batch-level container: one ``AsrSegments`` entry per media item in a
    batch, each of which is itself a list of per-segment results.
    """

    # Custom root (pydantic v1 pattern): the model is the outer list itself.
    __root__: List[AsrSegments]

    class Config:
        # Human-readable description surfaced in the generated OpenAPI schema.
        schema_extra = {
            "description": "List of lists of ASR segments",
        }


class AsrTranscriptionInfoList(BaseListModel):
"""
Pydantic schema for the list of ASR transcription info.
"""

__root__: List[AsrTranscriptionInfo]

class Config:
schema_extra = {
"description": "ASR output",
"description": "List of ASR transcription info",
}


class AsrOutputList(BaseListModel):
class AsrTranscriptionList(BaseListModel):
"""
Pydantic schema for the list of ASR outputs.
Pydantic schema for the list of ASR transcription.
"""

__root__: List[AsrOutput]
__root__: List[AsrTranscription]

class Config:
schema_extra = {
"description": "List of ASR outputs",
"description": "List of ASR transcription",
}

0 comments on commit aba2be1

Please sign in to comment.