From aba2be10b61e8c3d4383b5716da6e4e3d4d28318 Mon Sep 17 00:00:00 2001 From: movchan74 Date: Thu, 9 Nov 2023 12:04:48 +0000 Subject: [PATCH] Changed whisper output structure --- aana/configs/endpoints.py | 6 ++- aana/configs/pipeline.py | 26 ++++++++-- aana/deployments/whisper_deployment.py | 53 +++++++++++++++----- aana/models/pydantic/asr_output.py | 68 +++++++++++++++++--------- 4 files changed, 112 insertions(+), 41 deletions(-) diff --git a/aana/configs/endpoints.py b/aana/configs/endpoints.py index e003061c..02ff43da 100644 --- a/aana/configs/endpoints.py +++ b/aana/configs/endpoints.py @@ -45,7 +45,11 @@ name="whisper_transcribe", path="/video/transcribe", summary="Transcribe a video using Whisper Medium", - outputs=["video_transcriptions_whisper_medium"], + outputs=[ + "video_transcriptions_whisper_medium", + "video_transcriptions_segments_whisper_medium", + "video_transcriptions_info_whisper_medium", + ], ) ], } diff --git a/aana/configs/pipeline.py b/aana/configs/pipeline.py index 6e50e3e9..23f1c906 100644 --- a/aana/configs/pipeline.py +++ b/aana/configs/pipeline.py @@ -3,7 +3,11 @@ It is used to generate the pipeline and the API endpoints. """ -from aana.models.pydantic.asr_output import AsrOutputList +from aana.models.pydantic.asr_output import ( + AsrSegmentsList, + AsrTranscriptionInfoList, + AsrTranscriptionList, +) from aana.models.pydantic.captions import CaptionsList, VideoCaptionsList from aana.models.pydantic.image_input import ImageInputList from aana.models.pydantic.prompt import Prompt @@ -281,12 +285,24 @@ } ], "outputs": [ + { + "name": "video_transcriptions_segments_whisper_medium", + "key": "segments", + "path": "video_batch.videos.[*].segments", + "data_model": AsrSegmentsList, + }, + { + "name": "video_transcriptions_info_whisper_medium", + "key": "transcription_info", + "path": "video_batch.videos.[*].transcription_info", + "data_model": AsrTranscriptionInfoList, + }, { "name": "video_transcriptions_whisper_medium", - "key": "asr_outputs", - "path": "video_batch.videos.[*].asr_output", - "data_model": AsrOutputList, - } + "key": "transcription", + "path": "video_batch.videos.[*].transcription", + "data_model": AsrTranscriptionList, + }, ], }, ] diff --git a/aana/deployments/whisper_deployment.py b/aana/deployments/whisper_deployment.py index d8faac13..0dcd96d9 100644 --- a/aana/deployments/whisper_deployment.py +++ b/aana/deployments/whisper_deployment.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Any, Dict, List +from typing import Any, Dict, List, Union, cast from faster_whisper import WhisperModel from pydantic import BaseModel, Field from ray import serve @@ -8,7 +8,11 @@ from aana.deployments.base_deployment import BaseDeployment from aana.exceptions.general import InferenceException from aana.models.core.video import Video -from aana.models.pydantic.asr_output import AsrOutput +from aana.models.pydantic.asr_output import ( + AsrSegment, + AsrTranscription, + AsrTranscriptionInfo, +) from aana.models.pydantic.whisper_params import WhisperParams @@ -111,7 +115,7 @@ async def apply_config(self, config: Dict[str, Any]): # TODO: add audio support async def transcribe( self, media: Video, params: WhisperParams = WhisperParams() - ) -> Dict[str, AsrOutput]: + ) -> Dict[str, Union[List[AsrSegment], AsrTranscriptionInfo, AsrTranscription]]: """ Transcribe the media with the whisper model. @@ -120,8 +124,11 @@ async def transcribe( params (WhisperParams): The parameters for the whisper model. Returns: - Dict[str, Any]: The transcription output as a dictionary: - - asr_output (AsrOutput): The ASR output. + Dict[str, Union[List[AsrSegment], AsrTranscriptionInfo, AsrTranscription]]: + The transcription output as a dictionary: + segments (List[AsrSegment]): The ASR segments. + transcription_info (AsrTranscriptionInfo): The ASR transcription info. + transcription (AsrTranscription): The ASR transcription. Raises: InferenceException: If the inference fails. @@ -132,14 +139,25 @@ async def transcribe( segments, info = self.model.transcribe(media_path, **params.dict()) except Exception as e: raise InferenceException(self.model_name) from e - asr_output = AsrOutput.from_whisper(segments, info) + + asr_segments = [AsrSegment.from_whisper(seg) for seg in segments] + asr_transcription_info = AsrTranscriptionInfo.from_whisper(info) + transcription = "".join([seg.text for seg in asr_segments]) + asr_transcription = AsrTranscription(text=transcription) return { - "asr_output": asr_output, + "segments": asr_segments, + "transcription_info": asr_transcription_info, + "transcription": asr_transcription, } async def transcribe_batch( self, media: List[Video], params: WhisperParams = WhisperParams() - ) -> Dict[str, List[AsrOutput]]: + ) -> Dict[ + str, + Union[ + List[List[AsrSegment]], List[AsrTranscriptionInfo], List[AsrTranscription] + ], + ]: """ Transcribe the batch of media with the whisper model. @@ -148,17 +166,26 @@ async def transcribe_batch( params (WhisperParams): The parameters for the whisper model. Returns: - Dict[str, List[AsrOutput]]: The transcription outputs for each media as a dictionary: - - asr_output (List[AsrOutput]): The ASR outputs. + Dict[str, Union[List[List[AsrSegment]], List[AsrTranscriptionInfo], List[AsrTranscription]]]: + The transcription output as a dictionary: + segments (List[List[AsrSegment]]): The ASR segments for each media. + transcription_info (List[AsrTranscriptionInfo]): The ASR transcription info for each media. + transcription (List[AsrTranscription]): The ASR transcription for each media. Raises: InferenceException: If the inference fails. """ - asr_outputs = [] + segments: List[List[AsrSegment]] = [] + infos: List[AsrTranscriptionInfo] = [] + transcriptions: List[AsrTranscription] = [] for m in media: output = await self.transcribe(m, params) - asr_outputs.append(output["asr_output"]) + segments.append(cast(List[AsrSegment], output["segments"])) + infos.append(cast(AsrTranscriptionInfo, output["transcription_info"])) + transcriptions.append(cast(AsrTranscription, output["transcription"])) return { - "asr_outputs": asr_outputs, + "segments": segments, + "transcription_info": infos, + "transcription": transcriptions, } diff --git a/aana/models/pydantic/asr_output.py b/aana/models/pydantic/asr_output.py index 3137afdd..600476d6 100644 --- a/aana/models/pydantic/asr_output.py +++ b/aana/models/pydantic/asr_output.py @@ -121,42 +121,66 @@ class Config: } -class AsrOutput(BaseModel): +class AsrTranscription(BaseModel): """ - Pydantic schema for ASR output. + Pydantic schema for Transcription/Translation. """ - segments: List[AsrSegment] = Field(description="List of segments") - transcription_info: AsrTranscriptionInfo = Field(description="Transcription info") + text: str = Field(description="The text of the transcription/translation") - @classmethod - def from_whisper( - cls, - segments: List[WhisperSegment], - transcription_info: WhisperTranscriptionInfo, - ) -> "AsrOutput": - """ - Convert Whisper output to ASR output. - """ - return cls( - segments=[AsrSegment.from_whisper(seg) for seg in segments], - transcription_info=AsrTranscriptionInfo.from_whisper(transcription_info), - ) + class Config: + schema_extra = { + "description": "Transcription/Translation", + } + + +class AsrSegments(BaseListModel): + """ + Pydantic schema for the list of ASR segments. + """ + + __root__: List[AsrSegment] + + class Config: + schema_extra = { + "description": "List of ASR segments", + } + + +class AsrSegmentsList(BaseListModel): + """ + Pydantic schema for the list of lists of ASR segments. + """ + + __root__: List[AsrSegments] + + class Config: + schema_extra = { + "description": "List of lists of ASR segments", + } + + +class AsrTranscriptionInfoList(BaseListModel): + """ + Pydantic schema for the list of ASR transcription info. + """ + + __root__: List[AsrTranscriptionInfo] class Config: schema_extra = { - "description": "ASR output", + "description": "List of ASR transcription info", } -class AsrOutputList(BaseListModel): +class AsrTranscriptionList(BaseListModel): """ - Pydantic schema for the list of ASR outputs. + Pydantic schema for the list of ASR transcription. """ - __root__: List[AsrOutput] + __root__: List[AsrTranscription] class Config: schema_extra = { - "description": "List of ASR outputs", + "description": "List of ASR transcription", }