diff --git a/audiotools/core/whisper.py b/audiotools/core/whisper.py
index 46c071f9..53287e28 100644
--- a/audiotools/core/whisper.py
+++ b/audiotools/core/whisper.py
@@ -21,7 +21,7 @@ def setup_whisper(
         ).to(self.whisper_device)
         self.is_initialized = True
 
-    def get_whisper_features(self) -> torch.Tensor:
+    def get_whisper_features(self, **kwargs) -> torch.Tensor:
         """Preprocess audio signal as per the whisper model's training config.
 
         Returns
@@ -49,11 +49,12 @@ def get_whisper_features(self) -> torch.Tensor:
                 raw_speech,
                 sampling_rate=self.whisper_processor.feature_extractor.sampling_rate,
                 return_tensors="pt",
+                **kwargs
             ).input_features
 
         return input_features
 
-    def get_whisper_transcript(self) -> str:
+    def get_whisper_transcript(self, **kwargs) -> str:
         """Get the transcript of the audio signal using the whisper model.
 
         Returns
@@ -69,12 +70,12 @@ def get_whisper_transcript(self) -> str:
         with torch.inference_mode():
             input_features = input_features.to(self.whisper_device)
-            generated_ids = self.whisper_model.generate(inputs=input_features)
+            generated_ids = self.whisper_model.generate(input_features=input_features, **kwargs)
 
-        transcription = self.whisper_processor.batch_decode(generated_ids)
+        transcription = self.whisper_processor.batch_decode(generated_ids, skip_special_tokens=True)
 
         return transcription[0]
 
-    def get_whisper_embeddings(self) -> torch.Tensor:
+    def get_whisper_embeddings(self, **kwargs) -> torch.Tensor:
         """Get the last hidden state embeddings of the audio signal using the whisper model.
 
         Returns
@@ -92,6 +93,6 @@ def get_whisper_embeddings(self) -> torch.Tensor:
 
         with torch.inference_mode():
             input_features = input_features.to(self.whisper_device)
-            embeddings = encoder(input_features)
+            embeddings = encoder(input_features, **kwargs)
 
         return embeddings.last_hidden_state