
Commit

mostly gemma2
Cyrilvallez committed Dec 16, 2024
1 parent 0325dc4 commit ecd814b
Showing 21 changed files with 272 additions and 260 deletions.
2 changes: 1 addition & 1 deletion examples/modular-transformers/modeling_dummy.py
@@ -272,7 +272,7 @@ def forward(
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states
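For context, the position_embeddings argument whose comment changes above is the (cos, sin) tuple computed once by the model's rotary embedding and passed into every decoder layer. The sketch below is illustrative only (not code from this commit; head_dim, base, and shapes are assumptions) and shows how such a pair is typically produced.

# Illustrative sketch, not part of this commit: computing the (cos, sin) pair passed
# to decoder layers as `position_embeddings`. head_dim/base/shapes are assumptions.
import torch

def rotary_position_embeddings(position_ids: torch.Tensor, head_dim: int, base: float = 10000.0):
    # position_ids: (batch, seq_len) integer positions
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
    freqs = position_ids[:, :, None].float() * inv_freq[None, None, :]  # (batch, seq_len, head_dim // 2)
    emb = torch.cat((freqs, freqs), dim=-1)                             # (batch, seq_len, head_dim)
    return emb.cos(), emb.sin()

cos, sin = rotary_position_embeddings(torch.arange(8)[None, :], head_dim=64)
print(cos.shape, sin.shape)  # torch.Size([1, 8, 64]) torch.Size([1, 8, 64])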
2 changes: 1 addition & 1 deletion examples/modular-transformers/modeling_multimodal1.py
@@ -272,7 +272,7 @@ def forward(
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states
92 changes: 46 additions & 46 deletions examples/modular-transformers/modeling_my_new_model2.py
@@ -207,7 +207,7 @@ def forward(
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states
@@ -245,6 +245,51 @@ def forward(
return outputs


+MY_NEW_MODEL2_START_DOCSTRING = r"""
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+    and behavior.
+    Parameters:
+        config ([`MyNewModel2Config`]):
+            Model configuration class with all the parameters of the model. Initializing with a config file does not
+            load the weights associated with the model, only the configuration. Check out the
+            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+    "The bare MyNewModel2 Model outputting raw hidden-states without any specific head on top.",
+    MY_NEW_MODEL2_START_DOCSTRING,
+)
+class MyNewModel2PreTrainedModel(PreTrainedModel):
+    config_class = MyNewModel2Config
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["MyNewModel2DecoderLayer"]
+    _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn_2 = True
+    _supports_sdpa = True
+    _supports_cache_class = True
+    _supports_quantized_cache = True
+    _supports_static_cache = True
+
+    def _init_weights(self, module):
+        std = self.config.initializer_range
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+
+
class MyNewModel2RotaryEmbedding(nn.Module):
def __init__(
self,
@@ -310,51 +355,6 @@ def forward(self, x, position_ids):
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


-MY_NEW_MODEL2_START_DOCSTRING = r"""
-    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
-    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
-    etc.)
-    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
-    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
-    and behavior.
-    Parameters:
-        config ([`MyNewModel2Config`]):
-            Model configuration class with all the parameters of the model. Initializing with a config file does not
-            load the weights associated with the model, only the configuration. Check out the
-            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
-"""
-
-
-@add_start_docstrings(
-    "The bare MyNewModel2 Model outputting raw hidden-states without any specific head on top.",
-    MY_NEW_MODEL2_START_DOCSTRING,
-)
-class MyNewModel2PreTrainedModel(PreTrainedModel):
-    config_class = MyNewModel2Config
-    base_model_prefix = "model"
-    supports_gradient_checkpointing = True
-    _no_split_modules = ["MyNewModel2DecoderLayer"]
-    _skip_keys_device_placement = ["past_key_values"]
-    _supports_flash_attn_2 = True
-    _supports_sdpa = True
-    _supports_cache_class = True
-    _supports_quantized_cache = True
-    _supports_static_cache = True
-
-    def _init_weights(self, module):
-        std = self.config.initializer_range
-        if isinstance(module, nn.Linear):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.bias is not None:
-                module.bias.data.zero_()
-        elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()
-
-
MY_NEW_MODEL2_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
2 changes: 1 addition & 1 deletion examples/modular-transformers/modeling_super.py
@@ -269,7 +269,7 @@ def forward(
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states
31 changes: 22 additions & 9 deletions src/transformers/integrations/flash_attention.py
@@ -1,13 +1,25 @@
+from typing import Optional
+
import torch

from ..modeling_flash_attention_utils import _flash_attention_forward


def flash_attention_forward(
-    config, query, key, value, attentions_mask, target_dtype=torch.float16, training=False, layer_idx=0, **kwargs
+    module: torch.nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    dropout: float = 0.0,
+    scaling: Optional[float] = None,
+    sliding_window: Optional[int] = None,
+    softcap: Optional[float] = None,
+    target_dtype: torch.dtype = torch.float16,
+    **kwargs,
):
-    if attentions_mask is not None:
-        seq_len = attentions_mask.shape[1]
+    if attention_mask is not None:
+        seq_len = attention_mask.shape[1]
query = query[:, :, :seq_len]
value = value[:, :, :seq_len]
else:
@@ -18,8 +30,6 @@ def flash_attention_forward(
key = key.transpose(1, 2)
value = value.transpose(1, 2)

-    dropout_rate = config.attention_dropout if training else 0.0
-
input_dtype = query.dtype
if input_dtype == torch.float32:
query = query.to(target_dtype)
@@ -30,11 +40,14 @@ def flash_attention_forward(
query,
key,
value,
-        attentions_mask,
+        attention_mask,
seq_len,
-        config=config,
-        dropout=dropout_rate,
-        layer_idx=layer_idx,
+        module.is_causal,
+        dropout=dropout,
+        softmax_scale=scaling,
+        sliding_window=sliding_window,
+        softcap=softcap,
+        use_top_left_mask=module._flash_attn_uses_top_left_mask,
**kwargs,
)

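The hunks above replace the config/training-based arguments of the flash-attention backend with an explicit keyword signature (module, query, key, value, attention_mask, dropout, scaling, sliding_window, softcap, ...). As a rough illustration of that calling convention, a hypothetical eager backend with the same signature might look like the sketch below; it is not code from this commit, and the mask and shape conventions are assumptions.

# Illustrative sketch, not part of this commit: a hypothetical eager backend using the
# same keyword signature as the refactored attention functions above.
from typing import Optional

import torch


def eager_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,   # (batch, num_heads, q_len, head_dim)
    key: torch.Tensor,     # (batch, num_heads, kv_len, head_dim)
    value: torch.Tensor,   # (batch, num_heads, kv_len, head_dim)
    attention_mask: Optional[torch.Tensor] = None,  # additive mask, broadcastable to the scores
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    **kwargs,
):
    if scaling is None:
        scaling = query.shape[-1] ** -0.5
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask[:, :, :, : key.shape[-2]]
    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = torch.nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value).transpose(1, 2).contiguous()
    return attn_output, attn_weights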
20 changes: 18 additions & 2 deletions src/transformers/integrations/flex_attention.py
@@ -1,16 +1,31 @@
+from typing import Optional
+
import torch

from ..utils import is_torch_greater_or_equal


if is_torch_greater_or_equal("2.5"):
from torch.nn.attention.flex_attention import flex_attention


-def flex_attention_forward(module, query, key, value, attention_mask, output_attentions=False, **_kwargs):
+def flex_attention_forward(
+    module: torch.nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    scaling: Optional[float] = None,
+    softcap: Optional[float] = None,
+    **kwargs,
+):
causal_mask = attention_mask
if causal_mask is not None:
causal_mask = causal_mask[:, :, :, : key.shape[-2]]

def causal_mod(score, b, h, q_idx, kv_idx):
+        if softcap is not None:
+            score = softcap * torch.tanh(score / softcap)
if causal_mask is not None:
score += causal_mask[b][0][q_idx][kv_idx]
return score
@@ -21,8 +36,9 @@ def causal_mod(score, b, h, q_idx, kv_idx):
value,
score_mod=causal_mod,
enable_gqa=True,
-        scale=module.scaling,
+        scale=scaling,
+        return_lse=True,
)
attn_output = attn_output.transpose(1, 2).contiguous()

return attn_output, attention_weights
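The changes above pass scaling and softcap into flex attention explicitly instead of reading them from the module. A minimal standalone version of the soft-capped, causal score_mod pattern is sketched below; it assumes PyTorch 2.5+, and the shapes and softcap value are illustrative, not taken from this commit.

# Illustrative sketch, not part of this commit: a soft-capped, causal score_mod with an
# explicit scale, assuming PyTorch 2.5+ for torch.nn.attention.flex_attention.
import torch
from torch.nn.attention.flex_attention import flex_attention

batch, num_heads, seq_len, head_dim = 1, 4, 16, 32
query = torch.randn(batch, num_heads, seq_len, head_dim)
key = torch.randn(batch, num_heads, seq_len, head_dim)
value = torch.randn(batch, num_heads, seq_len, head_dim)
softcap = 50.0

def score_mod(score, b, h, q_idx, kv_idx):
    score = softcap * torch.tanh(score / softcap)               # soft-cap the attention logits
    return torch.where(q_idx >= kv_idx, score, -float("inf"))   # causal masking

attn_output = flex_attention(query, key, value, score_mod=score_mod, scale=head_dim**-0.5)
print(attn_output.shape)  # torch.Size([1, 4, 16, 32])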
18 changes: 15 additions & 3 deletions src/transformers/integrations/sdpa_attention.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
import torch


@@ -13,7 +15,16 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


-def sdpa_attention_forward(module, query, key, value, attention_mask=None, **_kwargs):
+def sdpa_attention_forward(
+    module: torch.nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    dropout: float = 0.0,
+    scaling: Optional[float] = None,
+    **kwargs,
+):
key = repeat_kv(key, module.num_key_value_groups)
value = repeat_kv(value, module.num_key_value_groups)

@@ -31,9 +42,10 @@ def sdpa_attention_forward(
key,
value,
attn_mask=causal_mask,
-        dropout_p=module.config.attention_dropout if module.training else 0.0,
+        dropout_p=dropout,
+        scale=scaling,
is_causal=is_causal,
-        scale=module.scaling,
)
attn_output = attn_output.transpose(1, 2).contiguous()

return attn_output, None
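For reference, the refactored SDPA path above (grouped-query key/value repetition followed by scaled_dot_product_attention with explicit dropout_p and scale) can be reproduced with plain PyTorch as sketched below; head counts and shapes are illustrative assumptions, not values from this commit.

# Illustrative sketch, not part of this commit: repeat_kv for grouped-query attention
# followed by SDPA with explicit dropout_p/scale, mirroring sdpa_attention_forward above.
import torch
import torch.nn.functional as F


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, num_kv_heads, seq_len, head_dim) -> (batch, num_kv_heads * n_rep, seq_len, head_dim)
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)


batch, num_heads, num_kv_heads, seq_len, head_dim = 2, 8, 2, 16, 64
query = torch.randn(batch, num_heads, seq_len, head_dim)
key = torch.randn(batch, num_kv_heads, seq_len, head_dim)
value = torch.randn(batch, num_kv_heads, seq_len, head_dim)

key = repeat_kv(key, num_heads // num_kv_heads)
value = repeat_kv(value, num_heads // num_kv_heads)

attn_output = F.scaled_dot_product_attention(
    query,
    key,
    value,
    attn_mask=None,
    dropout_p=0.0,          # now passed in explicitly rather than read from module.config
    scale=head_dim**-0.5,   # explicit scaling argument
    is_causal=True,
)
attn_output = attn_output.transpose(1, 2).contiguous()  # (batch, seq_len, num_heads, head_dim)
print(attn_output.shape)  # torch.Size([2, 16, 8, 64])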
2 changes: 1 addition & 1 deletion src/transformers/modeling_utils.py
@@ -30,7 +30,7 @@
from dataclasses import dataclass
from functools import partial, wraps
from threading import Thread
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from zipfile import is_zipfile

import torch
2 changes: 1 addition & 1 deletion src/transformers/models/aria/modeling_aria.py
@@ -588,7 +588,7 @@ def forward(
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states
4 changes: 0 additions & 4 deletions src/transformers/models/cohere2/modeling_cohere2.py
@@ -659,10 +659,6 @@ def __init__(self, config: Cohere2Config):
[Cohere2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
        self.norm = Cohere2LayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)
-
-        self.gradient_checkpointing = False
-        if getattr(config, "pretraining_tp", 1) != 1:
-            logger.warn("`pretraining_tp` is deprecated, please use `model.tensor_parallel` instead.")
self.rotary_emb = Cohere2RotaryEmbedding(config=config)

# Initialize weights and apply final processing
(Diffs for the remaining 11 changed files are not loaded here.)
