🚨All attention refactor🚨 (huggingface#35235)
* refactor LlamaAttention

* minimal changes

* fix llama

* update

* modular gemmas

* modular nits

* modular updates

* nits

* simplify

* gpt2

* more modular and fixes

* granite

* modular modular modular

* nits

* update

* qwen2 + starcoder2

* mostly gemma2

* Update image_processing_auto.py

* fix

* Update modular_starcoder2.py

* fix

* remove all copied from attentions

* remove gcv

* make fix-copies

* oups

* oups2.0

* fix some modulars + all copied from

* should be good now

* revert unwanted changes

* Update modeling_decision_transformer.py

* finish cleanup

* Update modeling_olmo.py

* consistency

* re-add gradient checkpointing attribute

* fix

* style

* make config necessary

* bis

* bis

* Update modeling_my_new_model2.py

* is_causal attr

* fix

* remove past kv return from decoder layer

* fix

* default rope config

* correctly fix rope config

* fix bias

* fix gpt2 attention output

* fix test

* fix inits

* fix default sdpa

* fix default sdpa implementation

* harmonize classes

* fix mistral

* fix sliding window models

* mixtral

* be more explicit

* style

* fix

* several fixes

* Update modeling_dbrx.py

* fix test

* olmo + phi

* rotary

* style

* phi

* phi again

* again

* kwargs

* Update test_modeling_common.py

* skip fx tracing tests

* Update modeling_utils.py

* gemma 2

* again

* Update modeling_recurrent_gemma.py

* gemma2

* granite

* style

* starcoder

* Update sdpa_attention.py

* switch args

* Update modeling_mllama.py

* fix

* cache type tests

* gpt2

* Update test_modeling_common.py

* fix

* consistency

* fix shape with encoder

* should be the last one

* tests non model

* most comments

* small oupsi

* be more explicit in modulars

* more explicit modulars

* CIs! it works locally

* add kwargs to _flash_attention_forward

---------

Co-authored-by: Cyril Vallez <[email protected]>
ArthurZucker and Cyrilvallez authored Dec 18, 2024
1 parent 75be5a0 commit 2c47618
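
None of the per-file diffs below shows the core of the refactor on its own, so here is a rough, hypothetical sketch of the pattern the commit messages describe: one shared attention path that looks up an eager/SDPA backend from a registry and forwards extra kwargs. The names ATTENTION_FUNCTIONS, eager_attention and sdpa_attention are illustrative only, not the library's actual identifiers.

import torch
import torch.nn.functional as F

def eager_attention(query, key, value, attention_mask=None, scaling=1.0, dropout=0.0, **kwargs):
    # Plain matmul + softmax attention; unknown kwargs are accepted and ignored.
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask[:, :, :, : key.shape[-2]]
    attn_weights = F.softmax(attn_weights, dim=-1)
    if dropout:
        attn_weights = F.dropout(attn_weights, p=dropout)
    return torch.matmul(attn_weights, value), attn_weights

def sdpa_attention(query, key, value, attention_mask=None, scaling=1.0, dropout=0.0, **kwargs):
    # Delegates to PyTorch's fused kernel (the `scale` argument needs PyTorch >= 2.1).
    out = F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=dropout, scale=scaling
    )
    return out, None

# Hypothetical registry, keyed the same way config._attn_implementation is ("eager", "sdpa", ...).
ATTENTION_FUNCTIONS = {"eager": eager_attention, "sdpa": sdpa_attention}

def attention_forward(attn_implementation, query, key, value, **kwargs):
    return ATTENTION_FUNCTIONS[attn_implementation](query, key, value, **kwargs)

# Toy usage: batch=1, heads=2, seq_len=4, head_dim=8.
q = k = v = torch.randn(1, 2, 4, 8)
out, _ = attention_forward("sdpa", q, k, v, scaling=8**-0.5)
print(out.shape)  # torch.Size([1, 2, 4, 8])

With a registry like this, per-model eager/SDPA/flash attention subclasses and their "copied from" blocks become unnecessary, which is roughly what the deletion-heavy file stats below reflect.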
Showing 107 changed files with 5,635 additions and 9,778 deletions.
445 changes: 85 additions & 360 deletions examples/modular-transformers/modeling_dummy.py

Large diffs are not rendered by default.

447 changes: 85 additions & 362 deletions examples/modular-transformers/modeling_multimodal1.py

Large diffs are not rendered by default.

519 changes: 152 additions & 367 deletions examples/modular-transformers/modeling_my_new_model2.py

Large diffs are not rendered by default.

34 changes: 24 additions & 10 deletions examples/modular-transformers/modeling_new_task_model.py
@@ -10,7 +10,7 @@
import torch
from torch import nn

from ...cache_utils import Cache, StaticCache
from ...cache_utils import Cache, HybridCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_utils import PreTrainedModel
from ...utils import (
@@ -253,19 +253,28 @@ def tie_weights(self):
        return self.language_model.tie_weights()

    def _update_causal_mask(
        self, attention_mask, token_type_ids, inputs_embeds, past_key_values, cache_position, is_training: bool = False
        self,
        attention_mask,
        token_type_ids,
        past_key_values,
        cache_position,
        input_ids=None,
        inputs_embeds=None,
        is_training: bool = False,
    ):
        if self.config.text_config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        using_static_cache = isinstance(past_key_values, StaticCache)
        dtype = inputs_embeds.dtype
        min_dtype = torch.finfo(dtype).min
        sequence_length = inputs_embeds.shape[1]
        min_dtype = torch.finfo(self.dtype).min
        inputs_lead_dim = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]
        sequence_length = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
        if using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
        elif isinstance(past_key_values, HybridCache):
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = (
                attention_mask.shape[-1]
@@ -278,7 +287,7 @@ def _update_causal_mask(
            return attention_mask

        causal_mask = torch.full(
            (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
            (sequence_length, target_length), fill_value=min_dtype, dtype=self.dtype, device=cache_position.device
        )
        # Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
        if sequence_length != 1:
@@ -288,7 +297,7 @@ def _update_causal_mask(
                causal_mask[:, :sequence_length] = 0.0

        causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
        causal_mask = causal_mask[None, None, :, :].expand(inputs_embeds.shape[0], 1, -1, -1)
        causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
        if attention_mask is not None:
            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
            mask_length = attention_mask.shape[-1]
@@ -317,7 +326,7 @@ def get_image_features(self, pixel_values: torch.FloatTensor):
        image_outputs = self.vision_tower(pixel_values)
        selected_image_feature = image_outputs.last_hidden_state
        image_features = self.multi_modal_projector(selected_image_feature)
        image_features = image_features / (self.config.hidden_size**0.5)
        image_features = image_features / (self.config.text_config.hidden_size**0.5)
        return image_features

    @add_start_docstrings_to_model_forward(NEW_TASK_MODEL_INPUTS_DOCSTRING)
@@ -414,6 +423,7 @@ def prepare_inputs_for_generation(
        token_type_ids=None,
        use_cache=True,
        num_logits_to_keep=None,
        labels=None,
        **kwargs,
    ):
        # Overwritten -- custom `position_ids` and `pixel_values` handling
@@ -433,12 +443,16 @@ def prepare_inputs_for_generation(
        # position_ids in NewTaskModel are 1-indexed
        if model_inputs.get("position_ids") is not None:
            model_inputs["position_ids"] += 1

        # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
        # Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always
        if cache_position[0] == 0:
            model_inputs["pixel_values"] = pixel_values

        is_training = token_type_ids is not None and labels is not None
        if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
            causal_mask = self._update_causal_mask(
                attention_mask, token_type_ids, past_key_values, cache_position, input_ids, inputs_embeds, is_training
            )
            model_inputs["attention_mask"] = causal_mask
        return model_inputs

    def resize_token_embeddings(
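
For readers following the _update_causal_mask hunks above, here is a minimal standalone sketch of the mask arithmetic on toy sizes. Variable names mirror the diff; the sizes, the CPU-only setup, and the inference (non-training) branch are assumptions for illustration, not the model's actual code path.

import torch

sequence_length, target_length, inputs_lead_dim = 3, 6, 2  # toy prefill over a static cache of 6 slots
min_dtype = torch.finfo(torch.float32).min
cache_position = torch.arange(sequence_length)  # positions of the current tokens

causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype)
if sequence_length != 1:
    causal_mask[:, :sequence_length] = 0.0  # attend to the whole prefix (the non-training branch shown above)
# Zero out (make visible) every key slot up to the token's own cache position; strictly-future slots keep min_dtype.
causal_mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
print(causal_mask.shape)  # torch.Size([2, 1, 3, 6]); 0.0 = visible, min_dtype = masked

The leading dimension now comes from input_ids when it is available (inputs_lead_dim in the diff), which is what lets the mask be built before embeddings are computed.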