diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 6349eb67e75a3d..177afef836918f 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -1344,13 +1344,18 @@ def _get_static_cache(self, max_batch_size: int, max_cache_len: int) -> StaticCa
                 cache_dtype = self.config._pre_quantization_dtype
             else:
                 cache_dtype = self.dtype
-            self._static_cache = StaticCache(
-                config=self.config,
-                max_batch_size=max_batch_size,
-                max_cache_len=max_cache_len,
-                device=self.device,
-                dtype=cache_dtype,
-            )
+            try:
+                self._static_cache = StaticCache(
+                    config=self.config,
+                    max_batch_size=max_batch_size,
+                    max_cache_len=max_cache_len,
+                    device=self.device,
+                    dtype=cache_dtype,
+                )
+            except AttributeError:
+                raise ValueError(
+                    f"This model's class ({self.__class__.__name__}) does not support static cache for generation."
+                )
         else:
             self._static_cache.reset()  # reset the cache for a new generation
         return self._static_cache
diff --git a/src/transformers/models/dbrx/configuration_dbrx.py b/src/transformers/models/dbrx/configuration_dbrx.py
index b03d2c17b09e07..61c1a8ced4ccf9 100644
--- a/src/transformers/models/dbrx/configuration_dbrx.py
+++ b/src/transformers/models/dbrx/configuration_dbrx.py
@@ -249,6 +249,7 @@ def __init__(
         self.use_cache = use_cache
         self.initializer_range = initializer_range
         self.output_router_logits = output_router_logits
+        self.num_key_value_heads = self.attn_config.kv_n_heads
 
         tie_word_embeddings = kwargs.pop("tie_word_embeddings", False)
         if tie_word_embeddings:
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index c605a165840464..bb55e417882a8c 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -1655,6 +1655,7 @@ def test_new_cache_format(self, num_beams, do_sample):
 
     @require_torch_gpu
     @slow
+    @is_flaky()  # compilation may result in equivalent (!= same) FP ops, causing the argmax in generate to be flaky
     def test_generate_compile_fullgraph(self):
         """Tests that `.generate` is compatible with torch.compile without graph breaks, keeping the same results"""
         for model_class in self.all_generative_model_classes:
diff --git a/tests/models/dbrx/test_modeling_dbrx.py b/tests/models/dbrx/test_modeling_dbrx.py
index 4c6b74a4d7baf2..cf07ea63a85835 100644
--- a/tests/models/dbrx/test_modeling_dbrx.py
+++ b/tests/models/dbrx/test_modeling_dbrx.py
@@ -355,6 +355,10 @@ def test_model_from_pretrained(self):
     def test_tied_weights_keys(self):
         pass
 
+    @unittest.skip("Dbrx does not support `torch.compile` with `fullgraph=True`.")
+    def test_generate_compile_fullgraph(self):
+        pass
+
 
 @require_torch
 class DbrxModelIntegrationTest(unittest.TestCase):
diff --git a/tests/models/jamba/test_modeling_jamba.py b/tests/models/jamba/test_modeling_jamba.py
index ffe859bb59d62e..ab463f451c36d2 100644
--- a/tests/models/jamba/test_modeling_jamba.py
+++ b/tests/models/jamba/test_modeling_jamba.py
@@ -644,6 +644,10 @@ def test_flash_attn_2_inference_equivalence_right_padding(self):
     def test_new_cache_format(self, num_beams, do_sample):
         pass
 
+    @unittest.skip("Jamba has its own special cache type")
+    def test_generate_compile_fullgraph(self):
+        pass
+
 
 @require_torch
 class JambaModelIntegrationTest(unittest.TestCase):
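
Not part of the patch itself, but a minimal usage sketch of the code path the first hunk touches: passing cache_implementation="static" to generate routes through _get_static_cache, so models whose config cannot satisfy StaticCache (for example Jamba, which has its own cache type) now surface a clear ValueError instead of an opaque AttributeError. The checkpoint name below is only an example; any model with static cache support works.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # example checkpoint; substitute any static-cache-capable model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

inputs = tokenizer("Static caches keep KV tensors at a fixed shape so that", return_tensors="pt").to(model.device)

# "static" triggers _get_static_cache; with this patch, unsupported models fail here
# with a ValueError naming the model class rather than a bare AttributeError.
out = model.generate(**inputs, max_new_tokens=16, cache_implementation="static")
print(tokenizer.decode(out[0], skip_special_tokens=True))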