enable StaticCache for assisted generation #34797

Open

wants to merge 51 commits into base: main

Commits (51)
c205b2e
enable StaticCache for assisted generation
yao-matrix Nov 19, 2024
30021dd
update
yao-matrix Nov 19, 2024
620c861
remove warnings import
yao-matrix Nov 19, 2024
b5283e9
enable StaticCache for assisted generation
yao-matrix Nov 19, 2024
71b7d22
update
yao-matrix Nov 19, 2024
c967bbe
remove warnings import
yao-matrix Nov 19, 2024
980aa08
done
yao-matrix Nov 20, 2024
c79411d
done
yao-matrix Nov 20, 2024
c717652
fix review comments
yao-matrix Nov 22, 2024
67618e5
Merge branch 'main' of https://github.com/yao-matrix/transformers
yao-matrix Nov 22, 2024
c8e2428
Merge branch 'main' into main
yao-matrix Nov 26, 2024
fde7ebd
Merge branch 'main' into main
yao-matrix Nov 26, 2024
e1169a3
Merge branch 'main' into main
yao-matrix Nov 29, 2024
8a9a753
add static cache ci
Nov 29, 2024
b74a7fe
Merge branch 'main' of https://github.com/yao-matrix/transformers
Nov 29, 2024
e67a3fd
skip Gemma2 StaticCache CI since it uses HybridCache
Nov 29, 2024
177634c
ruff format
Nov 29, 2024
45d0410
fix phimoe ci
yao-matrix Nov 29, 2024
a33f660
fix mixtral ci
yao-matrix Nov 29, 2024
ff07e47
fix ci
yao-matrix Nov 29, 2024
fbef806
cont.
yao-matrix Nov 29, 2024
3564a87
ci
yao-matrix Nov 29, 2024
b9cf597
fix ci
yao-matrix Dec 2, 2024
803166d
ci
yao-matrix Dec 2, 2024
df1594c
ci
yao-matrix Dec 2, 2024
87b7f15
ci
yao-matrix Dec 2, 2024
e60b1fe
Merge branch 'main' into main
yao-matrix Dec 2, 2024
5e195c2
ci
yao-matrix Dec 2, 2024
759da36
ci
yao-matrix Dec 2, 2024
093b647
ci
yao-matrix Dec 2, 2024
3775dc2
ci
yao-matrix Dec 2, 2024
817d303
ci
yao-matrix Dec 2, 2024
6e2ad2a
add # Ignore copy
yao-matrix Dec 2, 2024
99b6bc2
using a smarter way, ignore in test_utils
yao-matrix Dec 2, 2024
af33391
ci
yao-matrix Dec 2, 2024
587b55f
skip Gemma2, it declares support static cache, but it's hybrid cache a…
yao-matrix Dec 2, 2024
0a49d6f
refine error message
yao-matrix Dec 2, 2024
1deeb55
Merge branch 'main' into main
yao-matrix Dec 4, 2024
62b70e4
Merge branch 'main' into main
yao-matrix Dec 5, 2024
210c2e0
Merge branch 'main' into main
yao-matrix Dec 10, 2024
7b97aa4
add test case test_assisted_decoding_compile
yao-matrix Dec 11, 2024
9cb45da
Merge branch 'main' into main
yao-matrix Dec 11, 2024
93cd7bf
fix bug
yao-matrix Dec 11, 2024
3cc23d7
Merge branch 'main' into main
yao-matrix Dec 13, 2024
b45336c
Merge branch 'main' into main
yao-matrix Dec 15, 2024
04f2ea1
Merge branch 'main' into main
yao-matrix Dec 18, 2024
b08d1fc
Merge branch 'main' into main
yao-matrix Dec 19, 2024
0904268
Merge branch 'main' into main
yao-matrix Dec 20, 2024
dd148a8
Merge branch 'main' into main
yao-matrix Dec 22, 2024
3171476
cohere2 is HybridCache
Dec 23, 2024
4e064ec
Merge branch 'main' into main
yao-matrix Jan 2, 2025
19 changes: 19 additions & 0 deletions src/transformers/cache_utils.py
@@ -1226,6 +1226,25 @@ def update(

        return k_out, v_out

    def crop(self, max_length: int):
        """Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be
        negative to remove `max_length` tokens. This is used in assisted decoding and contrastive search."""
        # In case it is negative
        if max_length < 0:
            max_length = self.get_seq_length() - abs(max_length)

        if self.get_seq_length() <= max_length:
            return

        begin = max_length
        end = self.get_seq_length() + 1
        index = torch.arange(begin, end, device=self.key_cache[0].device)

        self._seen_tokens = max_length
        for idx in range(len(self.key_cache)):
            self.key_cache[idx].index_fill_(2, index, 0)
            self.value_cache[idx].index_fill_(2, index, 0)

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states that were seen by the model."""
        # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
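As a reading aid, here is a minimal, self-contained sketch of the crop logic above, using a plain tensor as a stand-in for one layer of the pre-allocated cache (shapes and numbers are illustrative only, not the real StaticCache class):

```python
import torch

# Stand-in for one layer of a StaticCache buffer: (batch, heads, max_cache_len, head_dim)
max_cache_len = 8
key_cache = torch.randn(1, 2, max_cache_len, 4)
seen_tokens = 6                # slots written so far

max_length = -2                # negative means "drop the last 2 tokens"
if max_length < 0:
    max_length = seen_tokens - abs(max_length)   # -> 4

if seen_tokens > max_length:
    # Zero-fill the rolled-back slots. Since occupancy is detected via non-zero
    # values along dim 2, the cropped positions stop counting as seen tokens.
    index = torch.arange(max_length, seen_tokens, device=key_cache.device)
    key_cache.index_fill_(2, index, 0)
    seen_tokens = max_length

print(seen_tokens)  # 4
```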
8 changes: 5 additions & 3 deletions src/transformers/generation/candidate_generator.py
@@ -19,7 +19,7 @@
import numpy as np
import torch

from ..cache_utils import DynamicCache
from ..cache_utils import DynamicCache, StaticCache
from ..pytorch_utils import isin_mps_friendly
from .logits_process import LogitsProcessorList, MinLengthLogitsProcessor

@@ -176,10 +176,10 @@ def __init__(
                    "Passing `MinLengthLogitsProcessor` when using `assisted_generation is disabled. "
                    "Please pass in `min_length` into `.generate()` instead"
                )

        # We need to roll back the cache in assisted generation, only DynamicCache is supported
        # assume cache created while _prepare_cache_for_generation is called
        self.generation_config.cache_implementation = None


    def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
        """
        Fetches the candidates to be tried for the current input.
@@ -696,6 +696,8 @@ def _crop_past_key_values(model, past_key_values, max_length):
                past_key_values[idx] = past_key_values[idx][:, :, :max_length, :]
    elif isinstance(past_key_values, DynamicCache):
        past_key_values.crop(max_length)
    elif isinstance(past_key_values, StaticCache):
        past_key_values.crop(max_length)
    elif past_key_values is not None:
        for idx in range(len(past_key_values)):
            if past_key_values[idx] != ([], []):
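The DynamicCache branch can be exercised in isolation with toy tensors; the new StaticCache branch is intended to behave the same way from the caller's perspective, except that it zero-fills its pre-allocated buffers instead of truncating them. A small sketch:

```python
import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
# Write 6 toy tokens into layer 0: shapes are (batch, num_heads, seq_len, head_dim).
cache.update(torch.randn(1, 2, 6, 4), torch.randn(1, 2, 6, 4), layer_idx=0)
print(cache.get_seq_length())  # 6

# Roll back to the 4-token prefix, as _crop_past_key_values does after a
# speculation round in which only some candidate tokens were accepted.
cache.crop(4)
print(cache.get_seq_length())  # 4
```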
23 changes: 11 additions & 12 deletions src/transformers/generation/utils.py
@@ -423,7 +423,7 @@ def prepare_inputs_for_generation(
                    model_input = model_input.clone(memory_format=torch.contiguous_format)
                model_inputs[model_input_name] = model_input

        # 6. Create 4D attention mask is we are using a `StaticCache` (important for performant compiled forward pass)
        # 6. Create 4D attention mask if we are using a `StaticCache` (important for performant compiled forward pass)
        if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
            if model_inputs["inputs_embeds"] is not None:
                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
@@ -1727,16 +1727,6 @@ def _prepare_cache_for_generation(
            return

        # Otherwise we NEED to prepare a cache, based on `generation_config.cache_implementation`

        # TODO(joao): support static caches in assisted generation. assisted generation needs to roll back caches,
        # which is only supported in dynamic caches atm
        if assistant_model is not None and generation_config.cache_implementation is not None:
            logger.warning_once(
                "An assistant model is provided, using a dynamic cache instead of a cache of type="
                f"'{generation_config.cache_implementation}'."
            )
            generation_config.cache_implementation = None

        if generation_config.cache_implementation is not None:
            if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
                if generation_config.cache_implementation == "static" and not self._supports_static_cache:
@@ -1751,6 +1741,14 @@
                    device=device,
                    model_kwargs=model_kwargs,
                )
                if assistant_model is not None:
                    assistant_model._get_cache(
                        cache_implementation=generation_config.cache_implementation,
                        batch_size=max(generation_config.num_beams, generation_config.num_return_sequences) * batch_size,
                        max_cache_len=max_cache_length,
                        device=device,
                        model_kwargs=model_kwargs,
                    )
Member:

hmm, I think it will be called on the assistant model when we call assistant.generate(), so there is no need. We only need to remove `self.generation_config.cache_implementation = None` in the candidate generator.

Author:

The thing is: if we leave it to assistant_model.generate() (called from get_candidates) to create the cache, then on the first call max_new_tokens is set to max_new_tokens = min(int(self.num_assistant_tokens), self.generation_config.max_length - new_cur_len - 1), so the cache length ends up as int(self.num_assistant_tokens) + prompt_len. That is smaller than the cache length actually needed, max_token_length + prompt_length, and generation asserts out later. So the key point is that the assistant model's cache length has to match the main model's, which is why the cache is created here. I also noticed that this function takes assistant_model as an argument but never uses it; I suspect it was added for cases exactly like this. That's the rationale.

Member:

Oh, I see, that makes sense. Then we can leave the cache init here.

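To make the sizing argument concrete with made-up numbers (all values hypothetical, purely to illustrate why the assistant's cache must be allocated with the main model's max_cache_length rather than sized by its own first generate() call):

```python
# Hypothetical numbers, for illustration only.
prompt_len = 100
max_new_tokens = 200          # the main model's overall generation budget
num_assistant_tokens = 5      # candidate tokens proposed per speculation round

# If the assistant's StaticCache were created lazily inside its first generate()
# call, it would be sized for a single speculation round:
lazily_sized_cache = prompt_len + num_assistant_tokens        # 105 slots

# But the assistant's prefix grows with every accepted token, so over the whole
# run it needs as many slots as the main model:
required_cache = prompt_len + max_new_tokens                  # 300 slots

assert lazily_sized_cache < required_cache  # later rounds would overflow the cache
```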
            elif generation_config.cache_implementation == "quantized":
                if not self._supports_quantized_cache:
                    raise ValueError(
@@ -2097,6 +2095,7 @@ def generate(
            and not self.config.is_encoder_decoder
        ):
            max_cache_length += inputs_tensor.shape[1]

        self._prepare_cache_for_generation(
            generation_config, model_kwargs, assistant_model, batch_size, max_cache_length, device
        )
@@ -2150,7 +2149,7 @@
                raise ValueError("assisted generate is only supported for batch_size = 1")
            if not model_kwargs["use_cache"]:
                raise ValueError("assisted generate requires `use_cache=True`")
            if generation_config.cache_implementation in ["static", "hybrid", "sliding_window"]:
            if generation_config.cache_implementation in ["hybrid", "sliding_window"]:
                raise ValueError("assisted generate is not supported with Static cache classes`")
            if self._is_stateful:
                # In assisted generation we need the ability to confirm whether the model would pick certain tokens,
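For context, once assisted generation accepts a static cache, usage would presumably look like the sketch below. The checkpoints are placeholders (any target/assistant pair sharing a tokenizer and supporting StaticCache should work in principle), and the exact user-facing behavior depends on the final form of this PR:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoints: a large target model and a small assistant sharing a tokenizer.
tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B-Instruct", torch_dtype=torch.bfloat16, device_map="auto")
assistant = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", torch_dtype=torch.bfloat16, device_map="auto")

inputs = tok("Speculative decoding works by", return_tensors="pt").to(model.device)
out = model.generate(
    **inputs,
    assistant_model=assistant,
    cache_implementation="static",  # previously forced back to a dynamic cache when an assistant was passed
    max_new_tokens=64,
)
print(tok.decode(out[0], skip_special_tokens=True))
```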