From 64c06df325c9b59a8d3aae7670c85d9977c517bf Mon Sep 17 00:00:00 2001
From: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Date: Wed, 15 May 2024 13:57:28 +0100
Subject: [PATCH] Jamba - Skip 4d custom attention mask test (#30826)

* Jamba - Skip 4d custom attention mask test

* Skip assistant greedy test
---
 src/transformers/models/jamba/modeling_jamba.py | 1 -
 tests/models/jamba/test_modeling_jamba.py       | 4 ++++
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py
index 1dbcbc76f3c248..8a5cdcbc2eb92a 100755
--- a/src/transformers/models/jamba/modeling_jamba.py
+++ b/src/transformers/models/jamba/modeling_jamba.py
@@ -1261,7 +1261,6 @@ class JambaPreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
     _supports_sdpa = True
-    _supports_cache_class = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range
diff --git a/tests/models/jamba/test_modeling_jamba.py b/tests/models/jamba/test_modeling_jamba.py
index ffe859bb59d62e..0fe515e516cd44 100644
--- a/tests/models/jamba/test_modeling_jamba.py
+++ b/tests/models/jamba/test_modeling_jamba.py
@@ -502,6 +502,10 @@ def _prepare_model_kwargs(input_ids, attention_mask, signature):
         # They should result in very similar logits
         self.assertTrue(torch.allclose(next_logits_wo_padding, next_logits_with_padding, atol=3e-3))
 
+    @unittest.skip("Jamba has its own special cache type")  # FIXME: @gante
+    def test_assisted_decoding_matches_greedy_search_0_random(self):
+        pass
+
     @require_flash_attn
     @require_torch_gpu
     @require_bitsandbytes
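
For readers less familiar with the pattern used in the test hunk above, the following is a minimal, self-contained sketch of how unittest.skip opts a single inherited test out of a shared suite, which is what the added lines in test_modeling_jamba.py do. The class names CommonGenerationTests and JambaLikeModelTest are hypothetical stand-ins chosen for illustration, not names from the transformers test suite, and this sketch is not part of the patch itself.

# Illustrative sketch only; CommonGenerationTests and JambaLikeModelTest are
# hypothetical stand-ins for a shared generation-test mixin and the Jamba
# test class touched in the patch above.
import unittest


class CommonGenerationTests(unittest.TestCase):
    """Stands in for a shared test suite that many model test classes inherit."""

    def test_assisted_decoding_matches_greedy_search_0_random(self):
        # In a real suite this would compare assisted decoding against greedy
        # search; here it is a trivial placeholder.
        self.assertTrue(True)


class JambaLikeModelTest(CommonGenerationTests):
    """Inherits the shared tests but opts out of the one that assumes a standard cache."""

    @unittest.skip("Jamba has its own special cache type")
    def test_assisted_decoding_matches_greedy_search_0_random(self):
        # Overriding the inherited test with a skipped no-op removes it from
        # this class's run while leaving the shared suite untouched.
        pass


if __name__ == "__main__":
    unittest.main()

Running this module with python -m unittest -v reports the test as passing for CommonGenerationTests and as skipped (with the given reason) for JambaLikeModelTest, mirroring how the patched Jamba test class sidesteps the assisted-decoding check.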