Skip to content

Commit

Permalink
Cache: models return input cache type (#30716)
Browse files Browse the repository at this point in the history
  • Loading branch information
gante authored May 8, 2024
1 parent 71c1985 commit f26e407
Show file tree
Hide file tree
Showing 11 changed files with 30 additions and 70 deletions.
11 changes: 6 additions & 5 deletions src/transformers/models/cohere/modeling_cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -881,7 +881,9 @@ def forward(
inputs_embeds = self.embed_tokens(input_ids)

past_seen_tokens = 0
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)

if cache_position is None:
Expand Down Expand Up @@ -943,11 +945,10 @@ def forward(
if output_hidden_states:
all_hidden_states += (hidden_states,)

next_cache = None
if use_cache:
next_cache = (
next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
)
next_cache = next_decoder_cache if use_cache else None
if return_legacy_cache:
next_cache = next_cache.to_legacy_cache()

if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
Expand Down
13 changes: 6 additions & 7 deletions src/transformers/models/dbrx/modeling_dbrx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1115,7 +1115,9 @@ def forward(

inputs_embeds = nn.functional.dropout(inputs_embeds, p=self.emb_pdrop, training=self.training)

return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)

if cache_position is None:
Expand Down Expand Up @@ -1182,13 +1184,10 @@ def forward(
if output_hidden_states:
all_hidden_states += (hidden_states,)

next_cache = None
if use_cache:
next_cache = (
next_decoder_cache.to_legacy_cache()
if isinstance(next_decoder_cache, DynamicCache)
else next_decoder_cache
)
next_cache = next_decoder_cache if use_cache else None
if return_legacy_cache:
next_cache = next_cache.to_legacy_cache()

if not return_dict:
return tuple(
v
Expand Down
13 changes: 6 additions & 7 deletions src/transformers/models/gemma/modeling_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -865,7 +865,9 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)

if cache_position is None:
Expand Down Expand Up @@ -933,13 +935,10 @@ def forward(
if output_hidden_states:
all_hidden_states += (hidden_states,)

next_cache = None
if use_cache:
next_cache = (
next_decoder_cache.to_legacy_cache()
if isinstance(next_decoder_cache, DynamicCache)
else next_decoder_cache
)
next_cache = next_decoder_cache if use_cache else None
if return_legacy_cache:
next_cache = next_cache.to_legacy_cache()

if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
Expand Down
13 changes: 6 additions & 7 deletions src/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -960,7 +960,9 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)

if cache_position is None:
Expand Down Expand Up @@ -1021,13 +1023,10 @@ def forward(
if output_hidden_states:
all_hidden_states += (hidden_states,)

next_cache = None
if use_cache:
next_cache = (
next_decoder_cache.to_legacy_cache()
if isinstance(next_decoder_cache, DynamicCache)
else next_decoder_cache
)
next_cache = next_decoder_cache if use_cache else None
if return_legacy_cache:
next_cache = next_cache.to_legacy_cache()

if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
Expand Down
13 changes: 6 additions & 7 deletions src/transformers/models/olmo/modeling_olmo.py
Original file line number Diff line number Diff line change
Expand Up @@ -938,7 +938,9 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)

if cache_position is None:
Expand Down Expand Up @@ -999,13 +1001,10 @@ def forward(
if output_hidden_states:
all_hidden_states += (hidden_states,)

next_cache = None
if use_cache:
next_cache = (
next_decoder_cache.to_legacy_cache()
if isinstance(next_decoder_cache, DynamicCache)
else next_decoder_cache
)
next_cache = next_decoder_cache if use_cache else None
if return_legacy_cache:
next_cache = next_cache.to_legacy_cache()

if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
Expand Down
7 changes: 0 additions & 7 deletions tests/models/cohere/test_modeling_cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@

import unittest

from parameterized import parameterized

from transformers import CohereConfig, is_torch_available
from transformers.testing_utils import (
require_bitsandbytes,
Expand Down Expand Up @@ -296,11 +294,6 @@ def setUp(self):
def test_config(self):
self.config_tester.run_common_tests()

@unittest.skip("TODO @gante fix this for Cohere")
@parameterized.expand([(1, False), (1, True), (4, False)])
def test_new_cache_format(self, num_beams, do_sample):
pass

def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
Expand Down
7 changes: 0 additions & 7 deletions tests/models/dbrx/test_modeling_dbrx.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@

import unittest

from parameterized import parameterized

from transformers import DbrxConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device

Expand Down Expand Up @@ -357,11 +355,6 @@ def test_model_from_pretrained(self):
def test_tied_weights_keys(self):
pass

@unittest.skip("TODO @gante fix this for Llama")
@parameterized.expand([(1, False), (1, True), (4, False)])
def test_new_cache_format(self, num_beams, do_sample):
pass


@require_torch
class DbrxModelIntegrationTest(unittest.TestCase):
Expand Down
6 changes: 0 additions & 6 deletions tests/models/gemma/test_modeling_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import unittest

import pytest
from parameterized import parameterized

from transformers import AutoModelForCausalLM, AutoTokenizer, GemmaConfig, is_torch_available
from transformers.testing_utils import (
Expand Down Expand Up @@ -367,11 +366,6 @@ def test_Gemma_sequence_classification_model_for_multi_label(self):
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

@unittest.skip("TODO @gante fix this for Llama")
@parameterized.expand([(1, False), (1, True), (4, False)])
def test_new_cache_format(self, num_beams, do_sample):
pass

@unittest.skip("Gemma buffers include complex numbers, which breaks this test")
def test_save_load_fast_init_from_base(self):
pass
Expand Down
5 changes: 0 additions & 5 deletions tests/models/llama/test_modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,11 +591,6 @@ def test_eager_matches_sdpa_generate(self):
msg=f"\n{tokenizer.batch_decode(res_eager)} \nvs\n{tokenizer.batch_decode(res_sdpa)}",
)

@unittest.skip("TODO @gante fix this for Llama")
@parameterized.expand([(1, False), (1, True), (4, False)])
def test_new_cache_format(self, num_beams, do_sample):
pass


@require_torch_gpu
class LlamaIntegrationTest(unittest.TestCase):
Expand Down
5 changes: 0 additions & 5 deletions tests/models/olmo/test_modeling_olmo.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,11 +353,6 @@ def test_model_rope_scaling(self, scaling_type):
# The output should be different for long inputs
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))

@unittest.skip("TODO @gante fix this for OLMo")
@parameterized.expand([(1, False), (1, True), (4, False)])
def test_new_cache_format(self, num_beams, do_sample):
pass


@require_torch
class OlmoIntegrationTest(unittest.TestCase):
Expand Down
7 changes: 0 additions & 7 deletions tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
""" Testing suite for the PyTorch RecurrentGemma model. """
import unittest

from parameterized import parameterized

from transformers import AutoModelForCausalLM, AutoTokenizer, RecurrentGemmaConfig, is_torch_available, set_seed
from transformers.testing_utils import (
require_bitsandbytes,
Expand Down Expand Up @@ -330,11 +328,6 @@ def test_model_various_embeddings(self):
config_and_inputs[0].position_embedding_type = type
self.model_tester.create_and_check_model(*config_and_inputs)

@unittest.skip("Recurrent gemma does not use legacy cache")
@parameterized.expand([(1, False), (1, True), (4, False)])
def test_new_cache_format(self, num_beams, do_sample):
pass

def test_save_load_fast_init_from_base(self):
pass

Expand Down

0 comments on commit f26e407

Please sign in to comment.