Skip to content

Commit

Permalink
fix speecht5 failure issue in test_peft_gradient_checkpointing_enable_disable (#34454)
Browse files Browse the repository at this point in the history

* fix speecht5 failure issue in test_peft_gradient_checkpointing_enable_disable

Signed-off-by: Wang, Yi <[email protected]>

* [run-slow] speecht5

---------

Signed-off-by: Wang, Yi <[email protected]>
Co-authored-by: Matt <[email protected]>
  • Loading branch information
sywangyi and Rocketknight1 authored Dec 3, 2024
1 parent 7a7f276 commit 125de41
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 13 deletions.
2 changes: 1 addition & 1 deletion src/transformers/models/speecht5/modeling_speecht5.py
Original file line number Diff line number Diff line change
Expand Up @@ -2114,7 +2114,7 @@ def get_input_embeddings(self):
return self.encoder.get_input_embeddings()
if isinstance(self.decoder, SpeechT5DecoderWithTextPrenet):
return self.decoder.get_input_embeddings()
return None
raise NotImplementedError

def set_input_embeddings(self, value):
if isinstance(self.encoder, SpeechT5EncoderWithTextPrenet):
Expand Down
12 changes: 0 additions & 12 deletions tests/models/speecht5/test_modeling_speecht5.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,12 +237,6 @@ def test_torchscript_output_hidden_state(self):
def test_torchscript_simple(self):
pass

@unittest.skip(
reason="Model returns None for input_embeds, check: https://github.com/huggingface/transformers/issues/33527"
)
def test_peft_gradient_checkpointing_enable_disable(self):
pass


@require_torch
class SpeechT5ForSpeechToTextTester:
Expand Down Expand Up @@ -1741,12 +1735,6 @@ def test_training_gradient_checkpointing_use_reentrant(self):
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass

@unittest.skip(
reason="Model returns None for input_embeds, check: https://github.com/huggingface/transformers/issues/33527"
)
def test_peft_gradient_checkpointing_enable_disable(self):
pass

# overwrite from test_modeling_common
def _mock_init_weights(self, module):
if hasattr(module, "weight") and module.weight is not None:
Expand Down

0 comments on commit 125de41

Please sign in to comment.