Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enable StaticCache for assisted generation #34797

Open
wants to merge 51 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 39 commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
c205b2e
enable StaticCache for assisted generation
yao-matrix Nov 19, 2024
30021dd
update
yao-matrix Nov 19, 2024
620c861
remove warnings import
yao-matrix Nov 19, 2024
b5283e9
enable StaticCache for assisted generation
yao-matrix Nov 19, 2024
71b7d22
update
yao-matrix Nov 19, 2024
c967bbe
remove warnings import
yao-matrix Nov 19, 2024
980aa08
done
yao-matrix Nov 20, 2024
c79411d
done
yao-matrix Nov 20, 2024
c717652
fix review comments
yao-matrix Nov 22, 2024
67618e5
Merge branch 'main' of https://github.com/yao-matrix/transformers
yao-matrix Nov 22, 2024
c8e2428
Merge branch 'main' into main
yao-matrix Nov 26, 2024
fde7ebd
Merge branch 'main' into main
yao-matrix Nov 26, 2024
e1169a3
Merge branch 'main' into main
yao-matrix Nov 29, 2024
8a9a753
add static cache ci
Nov 29, 2024
b74a7fe
Merge branch 'main' of https://github.com/yao-matrix/transformers
Nov 29, 2024
e67a3fd
ship Gemma2 StaticCache CI since it uses HybridCache
Nov 29, 2024
177634c
ruff format
Nov 29, 2024
45d0410
fix phimoe ci
yao-matrix Nov 29, 2024
a33f660
fix mixtral ci
yao-matrix Nov 29, 2024
ff07e47
fix ci
yao-matrix Nov 29, 2024
fbef806
cont.
yao-matrix Nov 29, 2024
3564a87
ci
yao-matrix Nov 29, 2024
b9cf597
fix ci
yao-matrix Dec 2, 2024
803166d
ci
yao-matrix Dec 2, 2024
df1594c
ci
yao-matrix Dec 2, 2024
87b7f15
ci
yao-matrix Dec 2, 2024
e60b1fe
Merge branch 'main' into main
yao-matrix Dec 2, 2024
5e195c2
ci
yao-matrix Dec 2, 2024
759da36
ci
yao-matrix Dec 2, 2024
093b647
ci
yao-matrix Dec 2, 2024
3775dc2
ci
yao-matrix Dec 2, 2024
817d303
ci
yao-matrix Dec 2, 2024
6e2ad2a
add # Ignore copy
yao-matrix Dec 2, 2024
99b6bc2
using a smarter way, ignore in test_utils
yao-matrix Dec 2, 2024
af33391
ci
yao-matrix Dec 2, 2024
587b55f
skip Gemma2, it declars support static cache, but it's hybrid cache a…
yao-matrix Dec 2, 2024
0a49d6f
refine error message
yao-matrix Dec 2, 2024
1deeb55
Merge branch 'main' into main
yao-matrix Dec 4, 2024
62b70e4
Merge branch 'main' into main
yao-matrix Dec 5, 2024
210c2e0
Merge branch 'main' into main
yao-matrix Dec 10, 2024
7b97aa4
add test case test_assisted_decoding_compile
yao-matrix Dec 11, 2024
9cb45da
Merge branch 'main' into main
yao-matrix Dec 11, 2024
93cd7bf
fix bug
yao-matrix Dec 11, 2024
3cc23d7
Merge branch 'main' into main
yao-matrix Dec 13, 2024
b45336c
Merge branch 'main' into main
yao-matrix Dec 15, 2024
04f2ea1
Merge branch 'main' into main
yao-matrix Dec 18, 2024
b08d1fc
Merge branch 'main' into main
yao-matrix Dec 19, 2024
0904268
Merge branch 'main' into main
yao-matrix Dec 20, 2024
dd148a8
Merge branch 'main' into main
yao-matrix Dec 22, 2024
3171476
cohere2 is HybridCache
Dec 23, 2024
4e064ec
Merge branch 'main' into main
yao-matrix Jan 2, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions src/transformers/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1237,6 +1237,31 @@ def update(

return k_out, v_out

def crop(self, max_length: int):
"""Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be
negative to remove `max_length` tokens. This is used in assisted decoding and contrastive search."""
seq_length = self.get_seq_length()
# In case it is negative
if max_length < 0:
max_length = seq_length - abs(max_length)

if seq_length <= max_length:
return

begin = max_length
end = seq_length + 1
index = torch.arange(begin, end, device=self.key_cache[0].device)

self._seen_tokens = max_length
for idx in range(len(self.key_cache)):
try:
self.key_cache[idx].index_fill_(2, index, 0)
self.value_cache[idx].index_fill_(2, index, 0)
except NotImplementedError:
# The operator 'aten::index_fill' is not currently implemented for the MPS device.
self.key_cache[idx][:, :, index] = 0
self.value_cache[idx][:, :, index] = 0

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states that were seen by the model."""
# Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
Expand Down
8 changes: 3 additions & 5 deletions src/transformers/generation/candidate_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import numpy as np
import torch

from ..cache_utils import DynamicCache
from ..cache_utils import Cache
from ..pytorch_utils import isin_mps_friendly
from .logits_process import LogitsProcessorList, MinLengthLogitsProcessor

Expand Down Expand Up @@ -177,9 +177,6 @@ def __init__(
"Please pass in `min_length` into `.generate()` instead"
)

# We need to roll back the cache in assisted generation, only DynamicCache is supported
self.generation_config.cache_implementation = None

def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
"""
Fetches the candidates to be tried for the current input.
Expand Down Expand Up @@ -229,6 +226,7 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor,

# 3. Update variables for the next round of candidate generation
self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values
self.generation_config.cache_implementation = None

# 4. Prepare variables for output
candidate_logits = torch.stack(assistant_output.scores, dim=1)
Expand Down Expand Up @@ -748,7 +746,7 @@ def _crop_past_key_values(model, past_key_values, max_length):
else:
for idx in range(len(past_key_values)):
past_key_values[idx] = past_key_values[idx][:, :, :max_length, :]
elif isinstance(past_key_values, DynamicCache):
elif isinstance(past_key_values, Cache):
past_key_values.crop(max_length)
elif past_key_values is not None:
for idx in range(len(past_key_values)):
Expand Down
26 changes: 13 additions & 13 deletions src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,7 @@ def prepare_inputs_for_generation(
model_input = model_input.clone(memory_format=torch.contiguous_format)
model_inputs[model_input_name] = model_input

# 6. Create 4D attention mask is we are using a `StaticCache` (important for performant compiled forward pass)
# 6. Create 4D attention mask if we are using a `StaticCache` (important for performant compiled forward pass)
if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
if model_inputs["inputs_embeds"] is not None:
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
Expand Down Expand Up @@ -1749,16 +1749,6 @@ def _prepare_cache_for_generation(
return

# Otherwise we NEED to prepare a cache, based on `generation_config.cache_implementation`

# TODO(joao): support static caches in assisted generation. assisted generation needs to roll back caches,
# which is only supported in dynamic caches atm
if assistant_model is not None and generation_config.cache_implementation is not None:
logger.warning_once(
"An assistant model is provided, using a dynamic cache instead of a cache of type="
f"'{generation_config.cache_implementation}'."
)
generation_config.cache_implementation = None

if generation_config.cache_implementation is not None:
if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
if generation_config.cache_implementation == "static" and not self._supports_static_cache:
Expand All @@ -1773,6 +1763,15 @@ def _prepare_cache_for_generation(
device=device,
model_kwargs=model_kwargs,
)
if assistant_model is not None:
assistant_model._get_cache(
cache_implementation=generation_config.cache_implementation,
batch_size=max(generation_config.num_beams, generation_config.num_return_sequences)
* batch_size,
max_cache_len=max_cache_length,
device=device,
model_kwargs=model_kwargs,
)
elif generation_config.cache_implementation == "quantized":
if not self._supports_quantized_cache:
raise ValueError(
Expand Down Expand Up @@ -2119,6 +2118,7 @@ def generate(
and not self.config.is_encoder_decoder
):
max_cache_length += inputs_tensor.shape[1]

self._prepare_cache_for_generation(
generation_config, model_kwargs, assistant_model, batch_size, max_cache_length, device
)
Expand Down Expand Up @@ -2172,8 +2172,8 @@ def generate(
raise ValueError("assisted generate is only supported for batch_size = 1")
if not model_kwargs["use_cache"]:
raise ValueError("assisted generate requires `use_cache=True`")
if generation_config.cache_implementation in ["static", "hybrid", "sliding_window"]:
raise ValueError("assisted generate is not supported with Static cache classes`")
if generation_config.cache_implementation in ["hybrid", "sliding_window"]:
raise ValueError("assisted generate is not supported with hybrid & sliding_window cache classes`")
if self._is_stateful:
# In assisted generation we need the ability to confirm whether the model would pick certain tokens,
# which is not possible with stateful models (they can't reset to a previous subset of generated text)
Expand Down
7 changes: 6 additions & 1 deletion tests/generation/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2088,13 +2088,16 @@ def test_generate_methods_with_num_logits_to_keep(self):
without_all_logits = model.generate(**inputs_dict, **generation_kwargs)
self.assertEqual(with_all_logits.tolist(), without_all_logits.tolist())

@parameterized.expand([("static", False), (None, True)])
@pytest.mark.generate
def test_assisted_decoding_with_num_logits_to_keep(self):
def test_assisted_decoding_with_num_logits_to_keep(self, cache_implementation, return_legacy_cache):
for model_class in self.all_generative_model_classes:
if "num_logits_to_keep" not in set(inspect.signature(model_class.forward).parameters.keys()):
self.skipTest(reason="This model does not support `num_logits_to_keep` argument.")
if model_class._is_stateful:
self.skipTest(reason="Stateful models don't support assisted generation")
if cache_implementation == "static" and not model_class._supports_static_cache:
self.skipTest(reason="This model does not support `cache_implementation=static`.")

config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
# NOTE: assisted generation only works with cache on at the moment.
Expand All @@ -2114,6 +2117,8 @@ def test_assisted_decoding_with_num_logits_to_keep(self):
"assistant_model": assistant_model,
"return_dict_in_generate": True,
"output_scores": True,
"cache_implementation": cache_implementation,
"return_legacy_cache": return_legacy_cache,
}

# Setting num_logits_to_keep at 0 keeps all logits (old behavior)
Expand Down
5 changes: 5 additions & 0 deletions tests/models/gemma2/test_modeling_gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,11 @@ def test_prompt_lookup_decoding_matches_greedy_search(self, assistant_type):
def test_assisted_decoding_sample(self):
pass

@parameterized.expand([("static", False)])
@unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding StaticCache")
def test_assisted_decoding_with_num_logits_to_keep(self, cache_implementation, return_legacy_cache):
pass

@unittest.skip("Gemma2 has HybridCache which is not compatible with dola decoding")
def test_dola_decoding_sample(self):
pass
Expand Down