Commit 14b19c4

working :D

gante committed Apr 25, 2024
1 parent 16c8e17 commit 14b19c4
Showing 4 changed files with 154 additions and 157 deletions.
74 changes: 31 additions & 43 deletions src/transformers/cache_utils.py
@@ -6,7 +6,6 @@
from .configuration_utils import PretrainedConfig
from .utils import logging


logger = logging.get_logger(__name__)


@@ -61,6 +60,14 @@ def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -
return max_length - new_seq_length
return previous_seq_length

def reorder_cache(self, beam_idx: torch.LongTensor):
"""Reorders the cache for beam search, given the selected beam indices."""
for layer_idx in range(len(self.key_cache)):
device = self.key_cache[layer_idx].device
self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
device = self.value_cache[layer_idx].device
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))

@property
def seen_tokens(self):
logger.warning_once(
@@ -158,14 +165,6 @@ def get_max_length(self) -> Optional[int]:
"""Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
return None

def reorder_cache(self, beam_idx: torch.LongTensor):
"""Reorders the cache for beam search, given the selected beam indices."""
for layer_idx in range(len(self.key_cache)):
device = self.key_cache[layer_idx].device
self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
device = self.value_cache[layer_idx].device
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))

def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
"""Converts the `DynamicCache` instance into the its equivalent in the legacy cache format."""
legacy_cache = ()
@@ -332,23 +331,14 @@ def update(

return self.key_cache[layer_idx], self.value_cache[layer_idx]

def reorder_cache(self, beam_idx: torch.LongTensor):
"""Reorders the cache for beam search, given the selected beam indices."""
for layer_idx in range(len(self.key_cache)):
device = self.key_cache[layer_idx].device
self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
device = self.value_cache[layer_idx].device
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))


class StaticCache(Cache):
"""
Static Cache class to be used with `torch.compile(model)`.
Parameters:
config (`PretrainedConfig`):
The configuration file defining the `max_position_embeddings`, `hidden_size` and `num_attention_heads`
required to initialize the static cache.
The configuration file defining the shape-related attributes required to initialize the static cache.
max_batch_size (`int`):
The maximum batch size with which the model will be used.
max_cache_len (`int`):
@@ -373,9 +363,18 @@ def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len:
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
)

self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
self.key_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
self.value_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
for _ in range(config.num_hidden_layers):
# Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
# breaks when updating the cache.
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
self.key_cache.append(new_layer_key_cache)
torch._dynamo.mark_static_address(new_layer_key_cache)
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
self.value_cache.append(new_layer_value_cache)
torch._dynamo.mark_static_address(new_layer_value_cache)

def update(
self,
@@ -394,42 +393,31 @@ def update(
value_states (`torch.Tensor`):
The new value states to cache.
layer_idx (`int`):
The index of the layer to cache the states for. Kept for backward compatibility
The index of the layer to cache the states for.
cache_kwargs (`Dict[str, Any]`, `optional`):
Additional arguments for the cache subclass. The `StaticCache` just needs the `q_len`
to know how much of the cache it should overwrite.
Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input
to know where to write in the cache.
Return:
A tuple containing the updated key and value states.
"""
new_cache_positions = cache_kwargs.get("cache_position")
k_out = self.key_cache
v_out = self.value_cache
cache_position = cache_kwargs.get("cache_position")
k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx]

k_out[:, :, new_cache_positions] = key_states
v_out[:, :, new_cache_positions] = value_states
k_out[:, :, cache_position] = key_states
v_out[:, :, cache_position] = value_states

return k_out, v_out

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC"""
"""Returns the sequence length of the cached states that were seen by the model."""
# Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
# limit the check to the first batch member and head dimension.
# TODO: This is error prone, a filled cache may be `0.0`. Let's use a stateless integer instead, after
# https://github.com/pytorch/pytorch/issues/120248 is fixed
return (self.key_cache[0, 0].any(dim=-1)).sum()
return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()

def get_max_length(self) -> Optional[int]:
"""Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
"""Returns the maximum sequence length of the cached states."""
return self.max_cache_len

def reorder_cache(self, beam_idx: torch.LongTensor):
"""Reorders the cache for beam search, given the selected beam indices."""
device = self.key_cache.device
self.key_cache = self.key_cache.index_select(0, beam_idx.to(device))
device = self.value_cache.device
self.value_cache = self.value_cache.index_select(0, beam_idx.to(device))

def to_legacy_cache(self):
"""Dummy function for BC. We have to keep it because otherwise the call in the forward of models will break it"""
return None
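
Taken together, the cache_utils.py changes make `StaticCache` hold one pre-allocated tensor per layer (each tagged with `torch._dynamo.mark_static_address`), write new states at the slots given by `cache_position`, and inherit `reorder_cache` from the base `Cache` class. The snippet below is a minimal sketch of driving the reworked cache by hand; it is not part of the commit, assumes a transformers install at this revision, and uses a tiny `LlamaConfig` purely for illustration.

```python
import torch
from transformers import LlamaConfig
from transformers.cache_utils import StaticCache

# Tiny config just to give the cache its shape-related attributes.
config = LlamaConfig(hidden_size=64, num_attention_heads=4, num_key_value_heads=4, num_hidden_layers=2)
cache = StaticCache(config, max_batch_size=1, max_cache_len=16, device="cpu", dtype=torch.float32)

# Prefill layer 0 with 5 positions; `cache_position` tells update() which slots to overwrite.
head_dim = config.hidden_size // config.num_attention_heads  # 16
key = torch.randn(1, config.num_key_value_heads, 5, head_dim)
value = torch.randn(1, config.num_key_value_heads, 5, head_dim)
k_out, v_out = cache.update(key, value, layer_idx=0, cache_kwargs={"cache_position": torch.arange(5)})

print(k_out.shape)                  # torch.Size([1, 4, 16, 16]) -- the full pre-allocated buffer
print(int(cache.get_seq_length()))  # 5 -- inferred from non-zero slots, as the TODO above notes

# reorder_cache now lives on the base Cache class and works on the per-layer list directly
# (a no-op here since the batch size is 1).
cache.reorder_cache(torch.tensor([0]))
```

Because `update()` returns the full fixed-size buffers, the attention layers see constant shapes at every decoding step; that, plus the `mark_static_address` tagging above, is what keeps `torch.compile(model)` from breaking cuda graphs.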
41 changes: 22 additions & 19 deletions src/transformers/generation/utils.py
@@ -1514,19 +1514,30 @@ def generate(
input_ids_length=input_ids_length,
)

if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
if generation_config.cache_implementation is not None and model_kwargs.get("past_key_values") is not None:
raise ValueError(
"Passing both `cache_implementation` (used to initialize certain caches) and `past_key_values` (a "
"Cache object) is unsupported. Please use only one of the two."
)
elif generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
if not self._supports_cache_class:
raise ValueError(
"This model does not support the `cache_implementation` argument. Please check the following "
"issue: https://github.com/huggingface/transformers/issues/28981."
)
if generation_config.cache_implementation == "static":
if model_kwargs.get("past_key_values", False) is not False:
raise ValueError(
"Using `past_key_values` argument with `generate()` when using a static KV cache is not supported. Please open an issue in Transformers GitHub repository."
)
cache_cls = NEED_SETUP_CACHE_CLASSES_MAPPING["static"]
if not callable(getattr(self, "_setup_cache", None)):
raise ValueError(
"The `generation_config` defines a `cache_implementation` that is not compatible with this model."
" Make sure it has a `_setup_cache` function."
)
self._setup_cache(cache_cls, max_batch_size=batch_size, max_cache_len=generation_config.max_length)
if hasattr(self.config, "_pre_quantization_dtype"):
cache_dtype = self.config._pre_quantization_dtype
else:
cache_dtype = self.dtype
model_kwargs["past_key_values"] = cache_cls(
config=self.config,
max_batch_size=batch_size,
max_cache_len=generation_config.max_length,
device=self.device,
dtype=cache_dtype,
)

self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)

@@ -1844,14 +1855,6 @@ def typeerror():
**model_kwargs,
)

if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
if not callable(getattr(self, "_reset_cache", None)):
raise ValueError(
"A `static_cache` was used to generate but there was a failure when trying to release the cache. "
" Make sure this model implements a `_reset_cache` function."
)
self._reset_cache()

return result

def _has_unfinished_sequences(self, this_peer_finished: bool, synced_gpus: bool, device: torch.device) -> bool:
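With the model-side `_setup_cache`/`_reset_cache` hooks gone, `generate()` now builds the static cache itself and hands it to the model through `model_kwargs["past_key_values"]`. A minimal usage sketch follows; it is not part of the diff, the checkpoint name is a placeholder, and it assumes a model that sets `_supports_cache_class` and a transformers install at this commit.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

inputs = tokenizer("The static KV cache", return_tensors="pt").to(model.device)

# `cache_implementation="static"` takes the new branch above: a StaticCache is built with
# max_cache_len=generation_config.max_length and passed to the model as past_key_values.
out = model.generate(**inputs, max_new_tokens=20, cache_implementation="static")
print(tokenizer.decode(out[0], skip_special_tokens=True))

# Passing both a Cache object in past_key_values and cache_implementation now raises the
# ValueError added in this hunk.
```

Since the cache no longer lives on the attention modules, there is also nothing to tear down after generation, which is why the `_reset_cache()` block at the end of `generate()` could be deleted.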
78 changes: 26 additions & 52 deletions src/transformers/models/llama/modeling_llama.py
@@ -428,6 +428,13 @@ def forward(
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:

if isinstance(past_key_value, StaticCache):
raise ValueError(
`static` cache implementation is not compatible with `attn_implementation==flash_attention_2`;
make sure to use `sdpa` in the meantime, and open an issue at https://github.com/huggingface/transformers
)

output_attentions = False

bsz, q_len, _ = hidden_states.size()
@@ -809,27 +816,6 @@ def _init_weights(self, module):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()

def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None):
if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache:
raise ValueError(
"`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
"make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
)

for layer in self.model.layers:
device = layer.input_layernorm.weight.device
if hasattr(self.config, "_pre_quantization_dtype"):
dtype = self.config._pre_quantization_dtype
else:
dtype = layer.self_attn.o_proj.weight.dtype
layer.self_attn.past_key_value = cache_cls(
self.config, max_batch_size, max_cache_len, device=device, dtype=dtype
)

def _reset_cache(self):
for layer in self.model.layers:
layer.self_attn.past_key_value = None


LLAMA_INPUTS_DOCSTRING = r"""
Args:
@@ -944,7 +930,7 @@ def forward(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
past_key_values: Optional[Cache | List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
@@ -973,23 +959,18 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

past_seen_tokens = 0
if use_cache: # kept for BC (cache positions)
if not isinstance(past_key_values, StaticCache):
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_seen_tokens = past_key_values.get_seq_length()
if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
past_key_values = DynamicCache.from_legacy_cache(past_key_values)

if cache_position is None:
if isinstance(past_key_values, StaticCache):
raise ValueError("cache_position is a required argument when using StaticCache.")
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)

if position_ids is None:
position_ids = cache_position.unsqueeze(0)

causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens)
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values)

# embed positions
hidden_states = inputs_embeds
@@ -1042,7 +1023,7 @@ def forward(
next_cache = None
if use_cache:
next_cache = (
next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, DynamicCache) else next_decoder_cache
)
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
@@ -1058,7 +1039,7 @@ def _update_causal_mask(
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_seen_tokens: int,
past_key_values: Cache,
):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode step due to the dynamic shapes.
@@ -1070,9 +1051,12 @@
return attention_mask
return None

if self.config._attn_implementation == "sdpa":
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument,
# in order to dispatch on Flash Attention 2.
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)
if self.config._attn_implementation == "sdpa" and not using_static_cache:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens
):
@@ -1081,9 +1065,9 @@
dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"): # static cache
target_length = self.config.max_position_embeddings
else: # dynamic cache
if using_static_cache:
target_length = past_key_values.get_max_length()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
@@ -1164,7 +1148,7 @@ def forward(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
past_key_values: Optional[Cache | List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
@@ -1262,13 +1246,6 @@ def prepare_inputs_for_generation(
use_cache=True,
**kwargs,
):
# With static cache, the `past_key_values` is None
# TODO joao: standardize interface for the different Cache classes and remove of this if
has_static_cache = False
if past_key_values is None:
past_key_values = getattr(getattr(self.model.layers[0], "self_attn", {}), "past_key_value", None)
has_static_cache = past_key_values is not None

past_length = 0
if past_key_values is not None:
if isinstance(past_key_values, Cache):
@@ -1327,9 +1304,6 @@ def prepare_inputs_for_generation(
elif use_cache:
cache_position = cache_position[-input_length:]

if has_static_cache:
past_key_values = None

model_inputs.update(
{
"position_ids": position_ids,
@@ -1388,7 +1362,7 @@ def forward(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
past_key_values: Optional[Cache | List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
@@ -1505,7 +1479,7 @@ def forward(
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
past_key_values: Optional[Cache | List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
start_positions: Optional[torch.LongTensor] = None,
end_positions: Optional[torch.LongTensor] = None,
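On the modeling side, `_update_causal_mask` now receives the cache object and, for a `StaticCache`, pads the mask's key dimension to `past_key_values.get_max_length()` instead of `config.max_position_embeddings`, while dynamic caches still only cover the tokens seen so far. The torch-only sketch below is a simplified stand-in (not code from the commit) that shows the resulting shape difference.

```python
import torch

def build_causal_mask(sequence_length: int, target_length: int,
                      cache_position: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Fully mask everything, then unmask key slots at or before each query's position."""
    min_dtype = torch.finfo(dtype).min
    mask = torch.full((sequence_length, target_length), min_dtype, dtype=dtype)
    visible = torch.arange(target_length) <= cache_position.reshape(-1, 1)
    return mask.masked_fill(visible, 0.0)

# Decoding one new token at position 5 with a static cache of max_cache_len=16:
static_mask = build_causal_mask(1, 16, cache_position=torch.tensor([5]))
print(static_mask.shape, int((static_mask == 0).sum()))  # torch.Size([1, 16]) 6

# Same step with a dynamic cache: the mask only spans the 6 tokens seen so far.
dynamic_mask = build_causal_mask(1, 6, cache_position=torch.tensor([5]))
print(dynamic_mask.shape)                                # torch.Size([1, 6])
```

Keeping the static-cache mask at a fixed width is what the TODO at the top of `_update_causal_mask` is about: the compiled/cudagraph path needs constant shapes to avoid recaptures at each decoding step.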
