From f611479dca4e9dc629e20bc1a4c2c2ab90a3ef98 Mon Sep 17 00:00:00 2001 From: Ajay Arasanipalai Date: Tue, 5 Mar 2024 16:06:10 -0800 Subject: [PATCH 1/2] feat: accept kwargs for whisper mixin --- audiotools/core/whisper.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/audiotools/core/whisper.py b/audiotools/core/whisper.py index 46c071f9..e3508e7a 100644 --- a/audiotools/core/whisper.py +++ b/audiotools/core/whisper.py @@ -21,7 +21,7 @@ def setup_whisper( ).to(self.whisper_device) self.is_initialized = True - def get_whisper_features(self) -> torch.Tensor: + def get_whisper_features(self, **kwargs) -> torch.Tensor: """Preprocess audio signal as per the whisper model's training config. Returns @@ -49,11 +49,12 @@ def get_whisper_features(self) -> torch.Tensor: raw_speech, sampling_rate=self.whisper_processor.feature_extractor.sampling_rate, return_tensors="pt", + **kwargs ).input_features return input_features - def get_whisper_transcript(self) -> str: + def get_whisper_transcript(self, **kwargs) -> str: """Get the transcript of the audio signal using the whisper model. Returns @@ -69,12 +70,12 @@ def get_whisper_transcript(self) -> str: with torch.inference_mode(): input_features = input_features.to(self.whisper_device) - generated_ids = self.whisper_model.generate(inputs=input_features) + generated_ids = self.whisper_model.generate(input_features=input_features, **kwargs) transcription = self.whisper_processor.batch_decode(generated_ids) return transcription[0] - def get_whisper_embeddings(self) -> torch.Tensor: + def get_whisper_embeddings(self, **kwargs) -> torch.Tensor: """Get the last hidden state embeddings of the audio signal using the whisper model. Returns @@ -92,6 +93,6 @@ def get_whisper_embeddings(self) -> torch.Tensor: with torch.inference_mode(): input_features = input_features.to(self.whisper_device) - embeddings = encoder(input_features) + embeddings = encoder(input_features, **kwargs) return embeddings.last_hidden_state From 11f77d03d582c065966ede4e46a7e7d6a1b358c1 Mon Sep 17 00:00:00 2001 From: Ajay Arasanipalai Date: Tue, 5 Mar 2024 16:10:53 -0800 Subject: [PATCH 2/2] feat: skip special tokens --- audiotools/core/whisper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/audiotools/core/whisper.py b/audiotools/core/whisper.py index e3508e7a..53287e28 100644 --- a/audiotools/core/whisper.py +++ b/audiotools/core/whisper.py @@ -72,7 +72,7 @@ def get_whisper_transcript(self, **kwargs) -> str: input_features = input_features.to(self.whisper_device) generated_ids = self.whisper_model.generate(input_features=input_features, **kwargs) - transcription = self.whisper_processor.batch_decode(generated_ids) + transcription = self.whisper_processor.batch_decode(generated_ids, skip_special_tokens=True) return transcription[0] def get_whisper_embeddings(self, **kwargs) -> torch.Tensor: