diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py
index 1dbcbc76f3c248..8a5cdcbc2eb92a 100755
--- a/src/transformers/models/jamba/modeling_jamba.py
+++ b/src/transformers/models/jamba/modeling_jamba.py
@@ -1261,7 +1261,6 @@ class JambaPreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
     _supports_sdpa = True
-    _supports_cache_class = True

     def _init_weights(self, module):
         std = self.config.initializer_range
diff --git a/tests/models/jamba/test_modeling_jamba.py b/tests/models/jamba/test_modeling_jamba.py
index ffe859bb59d62e..0fe515e516cd44 100644
--- a/tests/models/jamba/test_modeling_jamba.py
+++ b/tests/models/jamba/test_modeling_jamba.py
@@ -502,6 +502,10 @@ def _prepare_model_kwargs(input_ids, attention_mask, signature):
             # They should result in very similar logits
             self.assertTrue(torch.allclose(next_logits_wo_padding, next_logits_with_padding, atol=3e-3))

+    @unittest.skip("Jamba has its own special cache type")  # FIXME: @gante
+    def test_assisted_decoding_matches_greedy_search_0_random(self):
+        pass
+
     @require_flash_attn
     @require_torch_gpu
     @require_bitsandbytes