🚨All attention refactor🚨 #35235

Merged · 99 commits · Dec 18, 2024

Changes shown are from 93 commits.

Commits (99)
79cb53c
refactor LlamaAttention
ArthurZucker Nov 28, 2024
4bb485b
minimal changes
ArthurZucker Dec 12, 2024
f370907
fix llama
ArthurZucker Dec 12, 2024
d3ef539
update
ArthurZucker Dec 12, 2024
45eac58
modular gemmas
ArthurZucker Dec 12, 2024
e52af49
modular nits
ArthurZucker Dec 12, 2024
5ed37ae
modular updates
ArthurZucker Dec 12, 2024
38cafc1
nits
ArthurZucker Dec 12, 2024
a862eac
simplify
ArthurZucker Dec 12, 2024
5639b81
gpt2
ArthurZucker Dec 12, 2024
452d8ed
more modualr and fixes
ArthurZucker Dec 12, 2024
81a0b66
granite
ArthurZucker Dec 12, 2024
bc72c3f
modular modular modular
ArthurZucker Dec 12, 2024
48caa89
nits
ArthurZucker Dec 12, 2024
df68dd0
update
ArthurZucker Dec 13, 2024
0325dc4
qwen2 + starcoder2
Cyrilvallez Dec 13, 2024
ecd814b
mostly gemma2
Cyrilvallez Dec 16, 2024
f5fc638
Update image_processing_auto.py
Cyrilvallez Dec 16, 2024
5e56d9c
fix
Cyrilvallez Dec 16, 2024
598b7bb
Update modular_starcoder2.py
Cyrilvallez Dec 16, 2024
0f565fb
fix
Cyrilvallez Dec 16, 2024
c9ac84d
remove all copied from attentions
Cyrilvallez Dec 16, 2024
d189fe7
remove gcv
ArthurZucker Dec 16, 2024
9c83d96
make fix-copies
ArthurZucker Dec 16, 2024
138368e
oups
ArthurZucker Dec 16, 2024
7225a4f
oups2.0
ArthurZucker Dec 16, 2024
a3b9195
fix some modulars + all copied from
Cyrilvallez Dec 16, 2024
8d93708
should be good now
ArthurZucker Dec 16, 2024
3cc2b4d
Merge branch 'all-attention-refactor' of github.com:huggingface/trans…
Cyrilvallez Dec 16, 2024
074e469
Merge branch 'all-attention-refactor' of github.com:huggingface/trans…
Cyrilvallez Dec 16, 2024
54d9b95
revert unwanted changes
Cyrilvallez Dec 16, 2024
944e26e
Update modeling_decision_transformer.py
Cyrilvallez Dec 16, 2024
911833f
finish cleanup
Cyrilvallez Dec 16, 2024
ea26910
Update modeling_olmo.py
Cyrilvallez Dec 16, 2024
bc421af
consistency
Cyrilvallez Dec 16, 2024
8664ddc
re-add gradient checkpointing attribute
Cyrilvallez Dec 16, 2024
607e928
fix
Cyrilvallez Dec 16, 2024
4612595
style
Cyrilvallez Dec 16, 2024
20c376c
make config necessary
Cyrilvallez Dec 16, 2024
0ac9db2
bis
Cyrilvallez Dec 16, 2024
349b7ab
bis
Cyrilvallez Dec 16, 2024
defa88f
Update modeling_my_new_model2.py
Cyrilvallez Dec 16, 2024
fbf4b55
is_causal attr
Cyrilvallez Dec 16, 2024
9104d0a
fix
Cyrilvallez Dec 16, 2024
0b09340
remove past kv return from decoder layer
Cyrilvallez Dec 16, 2024
46a0df7
fix
Cyrilvallez Dec 16, 2024
aedd88a
default rope config
Cyrilvallez Dec 16, 2024
57e9b49
correctly fix rope config
Cyrilvallez Dec 16, 2024
fe90ec0
fix bias
Cyrilvallez Dec 16, 2024
a3f50d0
fix gpt2 attention output
Cyrilvallez Dec 16, 2024
6a92c70
fix test
Cyrilvallez Dec 16, 2024
a28ad19
fix inits
Cyrilvallez Dec 16, 2024
9bd6c94
fix default sdpa
Cyrilvallez Dec 16, 2024
fae05e1
fix default sdpa implementation
Cyrilvallez Dec 16, 2024
838d211
harmonize classes
Cyrilvallez Dec 16, 2024
e0d10f6
fix mistral
Cyrilvallez Dec 16, 2024
b275fdc
fix sliding window models
Cyrilvallez Dec 16, 2024
71eb6a2
mixtral
Cyrilvallez Dec 16, 2024
4e25753
be more explicit
Cyrilvallez Dec 16, 2024
1e8712b
style
Cyrilvallez Dec 16, 2024
854537b
fix
Cyrilvallez Dec 16, 2024
99bddf0
several fixes
Cyrilvallez Dec 17, 2024
2f666b3
Update modeling_dbrx.py
Cyrilvallez Dec 17, 2024
bafa020
fix test
Cyrilvallez Dec 17, 2024
00a98e7
olmo + phi
Cyrilvallez Dec 17, 2024
8c25411
rotary
Cyrilvallez Dec 17, 2024
4bb2f25
syle
Cyrilvallez Dec 17, 2024
44ff5e3
phi
Cyrilvallez Dec 17, 2024
95f7b96
phi again
Cyrilvallez Dec 17, 2024
7d55036
again
Cyrilvallez Dec 17, 2024
24ac9ab
kwargs
Cyrilvallez Dec 17, 2024
bd8ede8
Update test_modeling_common.py
Cyrilvallez Dec 17, 2024
0d3d3e3
skip fx tracing tests
Cyrilvallez Dec 17, 2024
49135d0
Update modeling_utils.py
Cyrilvallez Dec 17, 2024
f80a2c3
gemma 2
Cyrilvallez Dec 17, 2024
3e461bd
again
Cyrilvallez Dec 17, 2024
7a882d5
Update modeling_recurrent_gemma.py
Cyrilvallez Dec 17, 2024
7870073
gemma2
Cyrilvallez Dec 17, 2024
5b4ebaa
granite
Cyrilvallez Dec 17, 2024
7bdf61c
style
Cyrilvallez Dec 17, 2024
7d5b0b5
starcoder
Cyrilvallez Dec 17, 2024
70ef2fd
Update sdpa_attention.py
Cyrilvallez Dec 17, 2024
b8429c5
switch args
Cyrilvallez Dec 17, 2024
533657c
Update modeling_mllama.py
Cyrilvallez Dec 17, 2024
fe20d63
fix
Cyrilvallez Dec 17, 2024
248a607
cache type tests
Cyrilvallez Dec 17, 2024
4646014
gpt2
Cyrilvallez Dec 17, 2024
ad16b1b
Update test_modeling_common.py
Cyrilvallez Dec 17, 2024
1df6e29
fix
Cyrilvallez Dec 17, 2024
6c01005
consistency
Cyrilvallez Dec 17, 2024
f651cd0
fix shape with encoder
Cyrilvallez Dec 17, 2024
98b7f97
should be the last one
Cyrilvallez Dec 17, 2024
88e2fe5
tests non model
Cyrilvallez Dec 17, 2024
5a3bdc4
most comments
Cyrilvallez Dec 18, 2024
f3923b6
small oupsi
Cyrilvallez Dec 18, 2024
a6a2ff9
be more explicit in modulars
Cyrilvallez Dec 18, 2024
aeea33b
more explicit modulars
Cyrilvallez Dec 18, 2024
ec3bef3
CIs! it works locally
Cyrilvallez Dec 18, 2024
fc74e39
add kwargs to _flash_attention_forward
Cyrilvallez Dec 18, 2024
448 changes: 88 additions & 360 deletions examples/modular-transformers/modeling_dummy.py

Large diffs are not rendered by default.

450 changes: 88 additions & 362 deletions examples/modular-transformers/modeling_multimodal1.py

Large diffs are not rendered by default.

522 changes: 155 additions & 367 deletions examples/modular-transformers/modeling_my_new_model2.py

Large diffs are not rendered by default.

34 changes: 24 additions & 10 deletions examples/modular-transformers/modeling_new_task_model.py
@@ -10,7 +10,7 @@
 import torch
 from torch import nn
 
-from ...cache_utils import Cache, StaticCache
+from ...cache_utils import Cache, HybridCache, StaticCache
 from ...generation import GenerationMixin
 from ...modeling_utils import PreTrainedModel
 from ...utils import (
@@ -253,19 +253,28 @@ def tie_weights(self):
         return self.language_model.tie_weights()
 
     def _update_causal_mask(
-        self, attention_mask, token_type_ids, inputs_embeds, past_key_values, cache_position, is_training: bool = False
+        self,
+        attention_mask,
+        token_type_ids,
+        past_key_values,
+        cache_position,
+        input_ids=None,
+        inputs_embeds=None,
+        is_training: bool = False,
     ):
         if self.config.text_config._attn_implementation == "flash_attention_2":
             if attention_mask is not None and 0.0 in attention_mask:
                 return attention_mask
             return None
 
         using_static_cache = isinstance(past_key_values, StaticCache)
-        dtype = inputs_embeds.dtype
-        min_dtype = torch.finfo(dtype).min
-        sequence_length = inputs_embeds.shape[1]
+        min_dtype = torch.finfo(self.dtype).min
+        inputs_lead_dim = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]
+        sequence_length = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
         if using_static_cache:
             target_length = past_key_values.get_max_cache_shape()
+        elif isinstance(past_key_values, HybridCache):
+            target_length = past_key_values.get_max_cache_shape()
         else:
             target_length = (
                 attention_mask.shape[-1]
@@ -278,7 +287,7 @@ def _update_causal_mask(
             return attention_mask
 
         causal_mask = torch.full(
-            (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+            (sequence_length, target_length), fill_value=min_dtype, dtype=self.dtype, device=cache_position.device
         )
         # Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
         if sequence_length != 1:
@@ -288,7 +297,7 @@ def _update_causal_mask(
             causal_mask[:, :sequence_length] = 0.0
 
         causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
-        causal_mask = causal_mask[None, None, :, :].expand(inputs_embeds.shape[0], 1, -1, -1)
+        causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
         if attention_mask is not None:
             causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
             mask_length = attention_mask.shape[-1]
@@ -317,7 +326,7 @@ def get_image_features(self, pixel_values: torch.FloatTensor):
         image_outputs = self.vision_tower(pixel_values)
         selected_image_feature = image_outputs.last_hidden_state
         image_features = self.multi_modal_projector(selected_image_feature)
-        image_features = image_features / (self.config.hidden_size**0.5)
+        image_features = image_features / (self.config.text_config.hidden_size**0.5)
         return image_features
 
     @add_start_docstrings_to_model_forward(NEW_TASK_MODEL_INPUTS_DOCSTRING)
@@ -414,6 +423,7 @@ def prepare_inputs_for_generation(
         token_type_ids=None,
         use_cache=True,
         num_logits_to_keep=None,
+        labels=None,
         **kwargs,
     ):
         # Overwritten -- custom `position_ids` and `pixel_values` handling
@@ -433,12 +443,16 @@ def prepare_inputs_for_generation(
         # position_ids in NewTaskModel are 1-indexed
         if model_inputs.get("position_ids") is not None:
             model_inputs["position_ids"] += 1
-
         # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
         # Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always
         if cache_position[0] == 0:
             model_inputs["pixel_values"] = pixel_values
-
+        is_training = token_type_ids is not None and labels is not None
+        if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
+            causal_mask = self._update_causal_mask(
+                attention_mask, token_type_ids, past_key_values, cache_position, input_ids, inputs_embeds, is_training
+            )
+            model_inputs["attention_mask"] = causal_mask
         return model_inputs
 
     def resize_token_embeddings(
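As a quick illustration of what the `_update_causal_mask` hunks above compute, here is a minimal standalone sketch of building the 4D additive mask. It is not the repository code: the function name `build_causal_mask` and its argument list are hypothetical, and the training-vs-prefix branching from the diff is simplified to the plain causal case.

```python
# Standalone sketch (assumptions noted above): build a (batch, 1, seq_len, target_len) additive mask
# where allowed positions are 0 and disallowed positions are the most negative value of `dtype`.
import torch


def build_causal_mask(attention_mask, sequence_length, target_length, cache_position, dtype, device, batch_size):
    min_dtype = torch.finfo(dtype).min
    # Start fully masked, then keep only the strict upper triangle masked (causal structure).
    causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
    if sequence_length != 1:
        causal_mask = torch.triu(causal_mask, diagonal=1)
    # Also unmask cache slots at or before each query's write position.
    causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
    causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
    if attention_mask is not None:
        causal_mask = causal_mask.clone()  # contiguous copy for in-place edits
        mask_length = attention_mask.shape[-1]
        # Positions that are causally allowed (0) but padded (0 in attention_mask) get re-masked.
        padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] == 0
        causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(padding_mask, min_dtype)
    return causal_mask
```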