[Core] Refactor modeling code #34987

Draft · wants to merge 42 commits into main from llama-refactor

Commits (42)
f14637a  refactor LlamaAttention (ArthurZucker, Nov 28, 2024)
f446bd4  only change lLlama (ArthurZucker, Dec 11, 2024)
0384db9  more refactoring (ArthurZucker, Dec 11, 2024)
4e681b9  nits (ArthurZucker, Dec 11, 2024)
893ef38  nits (ArthurZucker, Dec 11, 2024)
13a195a  _output_embedding and _input_embeding (ArthurZucker, Dec 11, 2024)
39ab8b7  oupts (ArthurZucker, Dec 11, 2024)
0418f97  make auto for causal lm work (ArthurZucker, Dec 11, 2024)
341b8ce  nits (ArthurZucker, Dec 11, 2024)
556aa4e  updates (ArthurZucker, Dec 11, 2024)
f61a5fe  pass attention (ArthurZucker, Dec 11, 2024)
dcf7a37  cache concatenates on the wrong axis (ArthurZucker, Dec 11, 2024)
1baabd3  update (ArthurZucker, Dec 11, 2024)
38dd294  fix (ArthurZucker, Dec 11, 2024)
4015481  revert some stuff (ArthurZucker, Dec 11, 2024)
28829d2  there was an issue with tie weight keys (ArthurZucker, Dec 11, 2024)
1ef18f4  style (ArthurZucker, Dec 11, 2024)
4b9a429  style (ArthurZucker, Dec 11, 2024)
e5d60b4  fix (ArthurZucker, Dec 11, 2024)
3bbae39  remove tanh (ArthurZucker, Dec 11, 2024)
89d32d6  fix auto set (ArthurZucker, Dec 11, 2024)
7a911ef  update (ArthurZucker, Dec 11, 2024)
20c512b  clean (ArthurZucker, Dec 11, 2024)
d915636  mm (ArthurZucker, Dec 11, 2024)
6018982  fix! (ArthurZucker, Dec 11, 2024)
e9d751a  fix attention_mask (ArthurZucker, Dec 11, 2024)
7a608da  update (ArthurZucker, Dec 11, 2024)
6028e85  fixup (ArthurZucker, Dec 11, 2024)
725d00c  fix some stuff (ArthurZucker, Dec 12, 2024)
c224f36  fix some tests (ArthurZucker, Dec 12, 2024)
3f68c7c  9 left! (ArthurZucker, Dec 12, 2024)
1a5a834  fix auto? (ArthurZucker, Dec 12, 2024)
53450ac  fix (ArthurZucker, Dec 12, 2024)
2016bc4  default init weights (ArthurZucker, Dec 12, 2024)
4f36712  nit? (ArthurZucker, Dec 12, 2024)
f7395cc  Merge branch 'main' into llama-refactor (ArthurZucker, Dec 12, 2024)
9461039  nits (ArthurZucker, Dec 12, 2024)
57eece6  Merge branch 'llama-refactor' of github.com:huggingface/transformers … (ArthurZucker, Dec 12, 2024)
584b443  fix unpack imoprt (ArthurZucker, Dec 12, 2024)
95cb944  be permissive (ArthurZucker, Dec 12, 2024)
caaa5e5  tgi update (Cyrilvallez, Dec 12, 2024)
5060a33  remove layer_idx (Cyrilvallez, Dec 13, 2024)
src/transformers/__init__.py (11 changes: 3 additions & 8 deletions)

@@ -191,6 +191,7 @@
     "AutoImageProcessor",
     "AutoProcessor",
     "AutoTokenizer",
+    "AutoForCausalLM",
 ],
 "models.autoformer": ["AutoformerConfig"],
 "models.bark": [
@@ -2611,10 +2612,6 @@
 )
 _import_structure["models.llama"].extend(
     [
-        "LlamaForCausalLM",
-        "LlamaForQuestionAnswering",
-        "LlamaForSequenceClassification",
-        "LlamaForTokenClassification",
         "LlamaModel",
         "LlamaPreTrainedModel",
     ]
@@ -5084,6 +5081,7 @@
     TOKENIZER_MAPPING,
     AutoConfig,
     AutoFeatureExtractor,
+    AutoForCausalLM,
     AutoImageProcessor,
     AutoProcessor,
     AutoTokenizer,
@@ -6469,6 +6467,7 @@
     AutoModelForZeroShotObjectDetection,
     AutoModelWithLMHead,
 )
+from .models.auto.modeling_task import AutoForCausalLM
 from .models.autoformer import (
     AutoformerForPrediction,
     AutoformerModel,
@@ -7336,10 +7335,6 @@
     LiltPreTrainedModel,
 )
 from .models.llama import (
-    LlamaForCausalLM,
-    LlamaForQuestionAnswering,
-    LlamaForSequenceClassification,
-    LlamaForTokenClassification,
     LlamaModel,
     LlamaPreTrainedModel,
 )
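Taken together, these hunks remove the Llama task heads from the top-level re-exports and expose the new AutoForCausalLM there instead. A minimal import check on this branch (illustrative only; it assumes the PR branch is installed and says nothing about whether the task heads remain importable from their submodules):

# Sketch: what this export change implies for user code on the PR branch.
from transformers import AutoForCausalLM, LlamaModel  # newly exported / still exported

# LlamaForCausalLM is no longer re-exported at the top level by this diff;
# whether it stays importable from transformers.models.llama is not shown in these hunks.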
src/transformers/integrations/flash_attention.py (new file: 41 additions)

@@ -0,0 +1,41 @@
import torch

from ..modeling_flash_attention_utils import _flash_attention_forward


def flash_attention_forward(
    config, query, key, value, attention_mask, target_dtype=torch.float16, training=False, layer_idx=0, **kwargs
):
    if attention_mask is not None:
        seq_len = attention_mask.shape[1]
        query = query[:, :, :seq_len]
        value = value[:, :, :seq_len]
    else:
        seq_len = query.shape[1]

    # Re-transpose them
    query = query.transpose(1, 2)
    key = key.transpose(1, 2)
    value = value.transpose(1, 2)

    dropout_rate = config.attention_dropout if training else 0.0

    input_dtype = query.dtype
    if input_dtype == torch.float32:
        query = query.to(target_dtype)
        key = key.to(target_dtype)
        value = value.to(target_dtype)

    attn_output = _flash_attention_forward(
        query,
        key,
        value,
        attention_mask,
        seq_len,
        config=config,
        dropout=dropout_rate,
        layer_idx=layer_idx,
        **kwargs,
    )

    return attn_output, None
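The slicing and transposes at the top of this function imply that callers pass tensors in the same (batch, num_heads, seq_len, head_dim) layout the other backends use, and that the function re-transposes them to the (batch, seq_len, num_heads, head_dim) layout consumed by _flash_attention_forward. A standalone illustration of just that layout handling, with dummy shapes chosen here (not part of the PR):

import torch

# Dummy shapes for illustration only.
batch, num_heads, seq_len, head_dim = 2, 8, 16, 64
query = torch.randn(batch, num_heads, seq_len, head_dim)
attention_mask = torch.ones(batch, seq_len, dtype=torch.long)  # 2D padding mask

q_len = attention_mask.shape[1]              # sequence length taken from the mask
query = query[:, :, :q_len].transpose(1, 2)  # -> (batch, seq_len, num_heads, head_dim)
print(query.shape)                           # torch.Size([2, 16, 8, 64])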
src/transformers/integrations/flex_attention.py (new file: 28 additions)

@@ -0,0 +1,28 @@
from ..utils import is_torch_greater_or_equal


if is_torch_greater_or_equal("2.5"):
    from torch.nn.attention.flex_attention import flex_attention


def flex_attention_forward(module, query, key, value, attention_mask, output_attentions=False, **_kwargs):
    causal_mask = attention_mask
    if causal_mask is not None:
        causal_mask = causal_mask[:, :, :, : key.shape[-2]]

    def causal_mod(score, b, h, q_idx, kv_idx):
        if causal_mask is not None:
            score += causal_mask[b][0][q_idx][kv_idx]
        return score

    attn_output, attention_weights = flex_attention(
        query,
        key,
        value,
        score_mod=causal_mod,
        enable_gqa=True,
        scale=module.scaling,
        return_lse=True,
    )
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, attention_weights
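A hedged usage sketch for the flex backend (not part of the PR): it assumes torch >= 2.5 and calls flex_attention in its uncompiled eager path, which is slow but enough to check shapes; the stand-in module only carries the scaling attribute this function reads, and the shapes are arbitrary.

import torch
from types import SimpleNamespace

from transformers.utils import is_torch_greater_or_equal

if is_torch_greater_or_equal("2.5"):
    from transformers.integrations.flex_attention import flex_attention_forward

    batch, num_heads, seq_len, head_dim = 1, 4, 8, 32
    module = SimpleNamespace(scaling=head_dim**-0.5)  # standard 1/sqrt(head_dim) scale

    query = torch.randn(batch, num_heads, seq_len, head_dim)
    key = torch.randn(batch, num_heads, seq_len, head_dim)
    value = torch.randn(batch, num_heads, seq_len, head_dim)

    # attention_mask=None exercises the score_mod path with no additive mask applied
    attn_output, lse = flex_attention_forward(module, query, key, value, attention_mask=None)
    print(attn_output.shape)  # (batch, seq_len, num_heads, head_dim) after the final transpose
    print(lse.shape)          # per-head, per-query log-sum-exp, returned because return_lse=True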
src/transformers/integrations/sdpa_attention.py (new file: 39 additions)

@@ -0,0 +1,39 @@
import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def sdpa_attention_forward(module, query, key, value, attention_mask=None, **_kwargs):
    key = repeat_kv(key, module.num_key_value_groups)
    value = repeat_kv(value, module.num_key_value_groups)

    causal_mask = attention_mask
    if attention_mask is not None:
        causal_mask = causal_mask[:, :, :, : key.shape[-2]]

    query = query.contiguous()
    key = key.contiguous()
    value = value.contiguous()

    is_causal = True if causal_mask is None and query.shape[1] > 1 else False
    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=causal_mask,
        dropout_p=module.config.attention_dropout if module.training else 0.0,
        is_causal=is_causal,
        scale=module.scaling,
    )
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, None
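A minimal, self-contained sketch (not from the PR) of calling the SDPA backend with a stand-in module object. The attribute names mirror exactly what the function above reads (num_key_value_groups, scaling, training, config.attention_dropout), and the grouped-query shapes show repeat_kv expanding 2 key/value heads to match 8 query heads.

import torch
from types import SimpleNamespace

from transformers.integrations.sdpa_attention import sdpa_attention_forward

batch, num_heads, num_kv_heads, seq_len, head_dim = 2, 8, 2, 16, 64
module = SimpleNamespace(
    num_key_value_groups=num_heads // num_kv_heads,  # 4 query heads share each KV head
    scaling=head_dim**-0.5,                          # standard 1/sqrt(head_dim) scale
    training=False,
    config=SimpleNamespace(attention_dropout=0.0),
)

query = torch.randn(batch, num_heads, seq_len, head_dim)
key = torch.randn(batch, num_kv_heads, seq_len, head_dim)
value = torch.randn(batch, num_kv_heads, seq_len, head_dim)

# attention_mask=None takes the is_causal=True branch above
attn_output, attn_weights = sdpa_attention_forward(module, query, key, value, attention_mask=None)
print(attn_output.shape)  # (batch, seq_len, num_heads, head_dim) after the final transpose
print(attn_weights)       # None: SDPA does not materialize attention probabilities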
src/transformers/modeling_flash_attention_utils.py (2 changes: 1 addition & 1 deletion)

@@ -276,7 +276,7 @@ def _flash_attention_forward(
     if not use_top_left_mask:
         causal = is_causal
     else:
-        # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__.
+        # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in transformers.models.mistral.modeling_mistral.MistralFlashAttention2.__init__.
         causal = is_causal and query_length != 1

     # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
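The three new integration modules above give each attention backend a plain function with a near-uniform signature, which points at the direction of the refactor: attention modules select a backend function rather than subclassing once per backend. The sketch below shows one way such a dispatch could look; the registry name, the keys, and the wiring are illustrative assumptions, not necessarily how this PR hooks things up (note that flash_attention_forward above takes the config rather than the module as its first argument).

from transformers.integrations.flash_attention import flash_attention_forward
from transformers.integrations.flex_attention import flex_attention_forward
from transformers.integrations.sdpa_attention import sdpa_attention_forward

# Hypothetical registry, keyed here by the value of config._attn_implementation.
ATTENTION_BACKENDS = {
    "sdpa": sdpa_attention_forward,
    "flex_attention": flex_attention_forward,
    "flash_attention_2": flash_attention_forward,
}


def dispatch_attention(module, query, key, value, attention_mask, **kwargs):
    impl = module.config._attn_implementation
    if impl == "flash_attention_2":
        # Per the diff above, the flash backend takes the config (not the module)
        # and an explicit training flag instead of reading module.training itself.
        return flash_attention_forward(
            module.config, query, key, value, attention_mask, training=module.training, **kwargs
        )
    return ATTENTION_BACKENDS[impl](module, query, key, value, attention_mask, **kwargs)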