From 125de4164364420854d7fe537a9bd2fdaf7369d4 Mon Sep 17 00:00:00 2001
From: "Wang, Yi"
Date: Tue, 3 Dec 2024 21:58:54 +0800
Subject: [PATCH] =?UTF-8?q?fix=20speecht5=20failure=20issue=20in=20test=5F?=
 =?UTF-8?q?peft=5Fgradient=5Fcheckpointing=5Fenable=E2=80=A6=20(#34454)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* fix speecht5 failure issue in test_peft_gradient_checkpointing_enable_disable

Signed-off-by: Wang, Yi

* [run-slow] speecht5

---------

Signed-off-by: Wang, Yi
Co-authored-by: Matt
---
 .../models/speecht5/modeling_speecht5.py        |  2 +-
 tests/models/speecht5/test_modeling_speecht5.py | 12 ------------
 2 files changed, 1 insertion(+), 13 deletions(-)

diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py
index 63b536d185a379..72cbe6b14a93be 100644
--- a/src/transformers/models/speecht5/modeling_speecht5.py
+++ b/src/transformers/models/speecht5/modeling_speecht5.py
@@ -2114,7 +2114,7 @@ def get_input_embeddings(self):
             return self.encoder.get_input_embeddings()
         if isinstance(self.decoder, SpeechT5DecoderWithTextPrenet):
             return self.decoder.get_input_embeddings()
-        return None
+        raise NotImplementedError
 
     def set_input_embeddings(self, value):
         if isinstance(self.encoder, SpeechT5EncoderWithTextPrenet):
diff --git a/tests/models/speecht5/test_modeling_speecht5.py b/tests/models/speecht5/test_modeling_speecht5.py
index 97abf1a2cf2c2c..38f75ac5c01dc1 100644
--- a/tests/models/speecht5/test_modeling_speecht5.py
+++ b/tests/models/speecht5/test_modeling_speecht5.py
@@ -237,12 +237,6 @@ def test_torchscript_output_hidden_state(self):
     def test_torchscript_simple(self):
         pass
 
-    @unittest.skip(
-        reason="Model returns None for input_embeds, check: https://github.com/huggingface/transformers/issues/33527"
-    )
-    def test_peft_gradient_checkpointing_enable_disable(self):
-        pass
-
 
 @require_torch
 class SpeechT5ForSpeechToTextTester:
@@ -1741,12 +1735,6 @@ def test_training_gradient_checkpointing_use_reentrant(self):
     def test_training_gradient_checkpointing_use_reentrant_false(self):
         pass
 
-    @unittest.skip(
-        reason="Model returns None for input_embeds, check: https://github.com/huggingface/transformers/issues/33527"
-    )
-    def test_peft_gradient_checkpointing_enable_disable(self):
-        pass
-
     # overwrite from test_modeling_common
     def _mock_init_weights(self, module):
         if hasattr(module, "weight") and module.weight is not None:
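
Note on the fix: SpeechT5Model.get_input_embeddings used to return None when
neither the encoder nor the decoder carries a text prenet, so a caller that
immediately dereferences the result (PreTrainedModel.enable_input_require_grads,
for instance, registers a forward hook on the returned embedding module) failed
with an opaque AttributeError on None; that is the breakage tracked in issue
#33527 and hit by test_peft_gradient_checkpointing_enable_disable. Raising
NotImplementedError makes the "no input embeddings" case explicit and
catchable, which is why the two per-model skips above can be dropped. Below is
a minimal sketch of the pattern; the SpeechOnlyModel class is a hypothetical
stand-in, not the actual transformers model or test code.

# Illustrative sketch, not the actual transformers code: why `return None`
# broke callers, and what `raise NotImplementedError` buys them.
from torch import nn


class SpeechOnlyModel(nn.Module):
    """Hypothetical stand-in for a SpeechT5 model whose encoder and decoder
    both use speech prenets, i.e. there is no text embedding table."""

    def get_input_embeddings(self):
        # Before the patch this returned None; callers then crashed with
        # "AttributeError: 'NoneType' object has no attribute
        # 'register_forward_hook'". Raising is explicit and catchable.
        raise NotImplementedError

    def enable_input_require_grads(self):
        # Mirrors what PreTrainedModel.enable_input_require_grads does: it
        # registers a hook on the module from get_input_embeddings().
        def make_inputs_require_grads(module, inputs, output):
            output.requires_grad_(True)

        self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)


model = SpeechOnlyModel()
try:
    model.enable_input_require_grads()
except NotImplementedError:
    # A caller (or a common test) can now skip gracefully instead of
    # tripping over an opaque AttributeError deep inside hook setup.
    print("no input embeddings on this model; skipping")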