From f26e4073707189c93915227779a4f6ea3c40d43b Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Wed, 8 May 2024 18:26:34 +0100
Subject: [PATCH] Cache: models return input cache type (#30716)

---
 src/transformers/models/cohere/modeling_cohere.py | 11 ++++++-----
 src/transformers/models/dbrx/modeling_dbrx.py     | 13 ++++++-------
 src/transformers/models/gemma/modeling_gemma.py   | 13 ++++++-------
 src/transformers/models/llama/modeling_llama.py   | 13 ++++++-------
 src/transformers/models/olmo/modeling_olmo.py     | 13 ++++++-------
 tests/models/cohere/test_modeling_cohere.py       |  7 -------
 tests/models/dbrx/test_modeling_dbrx.py           |  7 -------
 tests/models/gemma/test_modeling_gemma.py         |  6 ------
 tests/models/llama/test_modeling_llama.py         |  5 -----
 tests/models/olmo/test_modeling_olmo.py           |  5 -----
 .../test_modeling_recurrent_gemma.py              |  7 -------
 11 files changed, 30 insertions(+), 70 deletions(-)

diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py
index ce6fe29b18859c..d729def45a72c2 100644
--- a/src/transformers/models/cohere/modeling_cohere.py
+++ b/src/transformers/models/cohere/modeling_cohere.py
@@ -881,7 +881,9 @@ def forward(
             inputs_embeds = self.embed_tokens(input_ids)
 
         past_seen_tokens = 0
+        return_legacy_cache = False
         if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
+            return_legacy_cache = True
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
 
         if cache_position is None:
@@ -943,11 +945,10 @@ def forward(
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
 
-        next_cache = None
-        if use_cache:
-            next_cache = (
-                next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
-            )
+        next_cache = next_decoder_cache if use_cache else None
+        if return_legacy_cache:
+            next_cache = next_cache.to_legacy_cache()
+
         if not return_dict:
             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
         return BaseModelOutputWithPast(
diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py
index 14baf7ae19e441..6b737565caadc9 100644
--- a/src/transformers/models/dbrx/modeling_dbrx.py
+++ b/src/transformers/models/dbrx/modeling_dbrx.py
@@ -1115,7 +1115,9 @@ def forward(
 
         inputs_embeds = nn.functional.dropout(inputs_embeds, p=self.emb_pdrop, training=self.training)
 
+        return_legacy_cache = False
         if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
+            return_legacy_cache = True
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
 
         if cache_position is None:
@@ -1182,13 +1184,10 @@ def forward(
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
 
-        next_cache = None
-        if use_cache:
-            next_cache = (
-                next_decoder_cache.to_legacy_cache()
-                if isinstance(next_decoder_cache, DynamicCache)
-                else next_decoder_cache
-            )
+        next_cache = next_decoder_cache if use_cache else None
+        if return_legacy_cache:
+            next_cache = next_cache.to_legacy_cache()
+
         if not return_dict:
             return tuple(
                 v
diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py
index 9917c3d3ce5f63..a0640e99be33bf 100644
--- a/src/transformers/models/gemma/modeling_gemma.py
+++ b/src/transformers/models/gemma/modeling_gemma.py
@@ -865,7 +865,9 @@ def forward(
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
+        return_legacy_cache = False
         if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
+            return_legacy_cache = True
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
 
         if cache_position is None:
@@ -933,13 +935,10 @@ def forward(
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
 
-        next_cache = None
-        if use_cache:
-            next_cache = (
-                next_decoder_cache.to_legacy_cache()
-                if isinstance(next_decoder_cache, DynamicCache)
-                else next_decoder_cache
-            )
+        next_cache = next_decoder_cache if use_cache else None
+        if return_legacy_cache:
+            next_cache = next_cache.to_legacy_cache()
+
         if not return_dict:
             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
         return BaseModelOutputWithPast(
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 0e971478e26ef3..29ce8d38e87c26 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -960,7 +960,9 @@ def forward(
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
+        return_legacy_cache = False
         if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
+            return_legacy_cache = True
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
 
         if cache_position is None:
@@ -1021,13 +1023,10 @@ def forward(
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
 
-        next_cache = None
-        if use_cache:
-            next_cache = (
-                next_decoder_cache.to_legacy_cache()
-                if isinstance(next_decoder_cache, DynamicCache)
-                else next_decoder_cache
-            )
+        next_cache = next_decoder_cache if use_cache else None
+        if return_legacy_cache:
+            next_cache = next_cache.to_legacy_cache()
+
         if not return_dict:
             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
         return BaseModelOutputWithPast(
diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py
index 64890b8576d613..380be9e3dc3e12 100644
--- a/src/transformers/models/olmo/modeling_olmo.py
+++ b/src/transformers/models/olmo/modeling_olmo.py
@@ -938,7 +938,9 @@ def forward(
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
+        return_legacy_cache = False
         if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
+            return_legacy_cache = True
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
 
         if cache_position is None:
@@ -999,13 +1001,10 @@ def forward(
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
 
-        next_cache = None
-        if use_cache:
-            next_cache = (
-                next_decoder_cache.to_legacy_cache()
-                if isinstance(next_decoder_cache, DynamicCache)
-                else next_decoder_cache
-            )
+        next_cache = next_decoder_cache if use_cache else None
+        if return_legacy_cache:
+            next_cache = next_cache.to_legacy_cache()
+
         if not return_dict:
             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
         return BaseModelOutputWithPast(
diff --git a/tests/models/cohere/test_modeling_cohere.py b/tests/models/cohere/test_modeling_cohere.py
index 2165943e796ea6..07fd36372469ae 100644
--- a/tests/models/cohere/test_modeling_cohere.py
+++ b/tests/models/cohere/test_modeling_cohere.py
@@ -16,8 +16,6 @@
 
 import unittest
 
-from parameterized import parameterized
-
 from transformers import CohereConfig, is_torch_available
 from transformers.testing_utils import (
     require_bitsandbytes,
@@ -296,11 +294,6 @@ def setUp(self):
     def test_config(self):
         self.config_tester.run_common_tests()
 
-    @unittest.skip("TODO @gante fix this for Cohere")
-    @parameterized.expand([(1, False), (1, True), (4, False)])
-    def test_new_cache_format(self, num_beams, do_sample):
-        pass
-
     def test_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model(*config_and_inputs)
diff --git a/tests/models/dbrx/test_modeling_dbrx.py b/tests/models/dbrx/test_modeling_dbrx.py
index a66bf2acfc286c..4c6b74a4d7baf2 100644
--- a/tests/models/dbrx/test_modeling_dbrx.py
+++ b/tests/models/dbrx/test_modeling_dbrx.py
@@ -17,8 +17,6 @@
 
 import unittest
 
-from parameterized import parameterized
-
 from transformers import DbrxConfig, is_torch_available
 from transformers.testing_utils import require_torch, slow, torch_device
 
@@ -357,11 +355,6 @@ def test_model_from_pretrained(self):
     def test_tied_weights_keys(self):
         pass
 
-    @unittest.skip("TODO @gante fix this for Llama")
-    @parameterized.expand([(1, False), (1, True), (4, False)])
-    def test_new_cache_format(self, num_beams, do_sample):
-        pass
-
 
 @require_torch
 class DbrxModelIntegrationTest(unittest.TestCase):
diff --git a/tests/models/gemma/test_modeling_gemma.py b/tests/models/gemma/test_modeling_gemma.py
index e70dab3d95d722..80f275e54ce87e 100644
--- a/tests/models/gemma/test_modeling_gemma.py
+++ b/tests/models/gemma/test_modeling_gemma.py
@@ -17,7 +17,6 @@
 import unittest
 
 import pytest
-from parameterized import parameterized
 
 from transformers import AutoModelForCausalLM, AutoTokenizer, GemmaConfig, is_torch_available
 from transformers.testing_utils import (
@@ -367,11 +366,6 @@ def test_Gemma_sequence_classification_model_for_multi_label(self):
         result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
         self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
 
-    @unittest.skip("TODO @gante fix this for Llama")
-    @parameterized.expand([(1, False), (1, True), (4, False)])
-    def test_new_cache_format(self, num_beams, do_sample):
-        pass
-
     @unittest.skip("Gemma buffers include complex numbers, which breaks this test")
     def test_save_load_fast_init_from_base(self):
         pass
diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py
index 0592922e447093..e63e53797462b4 100644
--- a/tests/models/llama/test_modeling_llama.py
+++ b/tests/models/llama/test_modeling_llama.py
@@ -591,11 +591,6 @@ def test_eager_matches_sdpa_generate(self):
                 msg=f"\n{tokenizer.batch_decode(res_eager)} \nvs\n{tokenizer.batch_decode(res_sdpa)}",
             )
 
-    @unittest.skip("TODO @gante fix this for Llama")
-    @parameterized.expand([(1, False), (1, True), (4, False)])
-    def test_new_cache_format(self, num_beams, do_sample):
-        pass
-
 
 @require_torch_gpu
 class LlamaIntegrationTest(unittest.TestCase):
diff --git a/tests/models/olmo/test_modeling_olmo.py b/tests/models/olmo/test_modeling_olmo.py
index ce354db52b298b..906bd73a70d2a9 100644
--- a/tests/models/olmo/test_modeling_olmo.py
+++ b/tests/models/olmo/test_modeling_olmo.py
@@ -353,11 +353,6 @@ def test_model_rope_scaling(self, scaling_type):
         # The output should be different for long inputs
         self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
 
-    @unittest.skip("TODO @gante fix this for OLMo")
-    @parameterized.expand([(1, False), (1, True), (4, False)])
-    def test_new_cache_format(self, num_beams, do_sample):
-        pass
-
 
 @require_torch
 class OlmoIntegrationTest(unittest.TestCase):
diff --git a/tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py b/tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py
index c46718b6803884..599161e8d4059e 100644
--- a/tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py
+++ b/tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py
@@ -15,8 +15,6 @@
 """ Testing suite for the PyTorch RecurrentGemma model. """
 import unittest
 
-from parameterized import parameterized
-
 from transformers import AutoModelForCausalLM, AutoTokenizer, RecurrentGemmaConfig, is_torch_available, set_seed
 from transformers.testing_utils import (
     require_bitsandbytes,
@@ -330,11 +328,6 @@ def test_model_various_embeddings(self):
             config_and_inputs[0].position_embedding_type = type
         self.model_tester.create_and_check_model(*config_and_inputs)
 
-    @unittest.skip("Recurrent gemma does not use legacy cache")
-    @parameterized.expand([(1, False), (1, True), (4, False)])
-    def test_new_cache_format(self, num_beams, do_sample):
-        pass
-
     def test_save_load_fast_init_from_base(self):
         pass
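
The modeling changes above all implement the same contract: the model hands back `past_key_values` in whatever format it received them, so a `Cache` input stays a `Cache` while a legacy (tuple or absent) input still comes back as a legacy tuple. A minimal sketch of that round trip follows; it is not part of the patch, and the checkpoint name is an illustrative assumption, not something this PR prescribes.

# Sketch only: shows the cache round-trip behaviour this patch enforces.
# Assumption: the tiny Llama test checkpoint below is reachable; swap in any
# causal LM covered by the patch (Cohere, DBRX, Gemma, Llama, OLMo).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import DynamicCache

checkpoint = "hf-internal-testing/tiny-random-LlamaForCausalLM"  # illustrative choice
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)

inputs = tokenizer("Hello world", return_tensors="pt")

with torch.no_grad():
    # Passing a `Cache` instance: the returned past_key_values keeps that type.
    out_new = model(**inputs, past_key_values=DynamicCache(), use_cache=True)
    # Passing no cache (the legacy code path): the legacy tuple format is returned.
    out_legacy = model(**inputs, use_cache=True)

assert isinstance(out_new.past_key_values, DynamicCache)
assert isinstance(out_legacy.past_key_values, tuple)

Under this contract, existing callers that rely on legacy tuples keep working, while callers that already pass `Cache` objects are no longer silently converted back, which is why the `test_new_cache_format` skips removed above are no longer needed.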