Skip to content

Commit

Permalink
Add Transformers v4.45 support (#2023)
Browse files Browse the repository at this point in the history
* transformers v4.45 support

* fix transformers v4.45 compatibility

* update opset

* update model

* Add generation config saving

* fix codegen

* bump default opset m2m100

* fix codegen

* fix bettertransformers

* add warning deprecation bettertransformer

* bettertransformers fixes

* disable transformers 4.45 for onnx export

* update model ID
  • Loading branch information
echarlaix authored Sep 30, 2024
1 parent c6b4678 commit 049b00f
Show file tree
Hide file tree
Showing 14 changed files with 223 additions and 92 deletions.
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ REAL_CLONE_URL = $(if $(CLONE_URL),$(CLONE_URL),$(DEFAULT_CLONE_URL))
# Run code quality checks
style_check:
black --check .
ruff .
ruff check .

style:
black .
ruff . --fix
ruff check . --fix

# Run tests for the library
test:
Expand Down
84 changes: 77 additions & 7 deletions optimum/bettertransformer/models/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,71 @@ def gpt2_wrapped_scaled_dot_product(
return sdpa_result, None


# Adapted from transformers.models.gptj.modeling_gptj.GPTJAttention._attn
def gptj_wrapped_scaled_dot_product(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    head_mask: Optional[torch.Tensor] = None,
):
    """Compute GPT-J attention with `torch.nn.functional.scaled_dot_product_attention`.

    Drop-in replacement for `GPTJAttention._attn`. The first dimension of `query` is
    treated as the batch size and its second-to-last dimension as the query length.

    Args:
        query: Query tensor; cast to `value.dtype` when `self.downcast_qk` is set.
        key: Key tensor; cast to `value.dtype` when `self.downcast_qk` is set.
        value: Value tensor; its dtype determines the masking fill value.
        attention_mask: Optional additive mask, indexed here as a 4-D tensor.
        head_mask: Forwarded to `raise_on_head_mask` for validation; not used otherwise.

    Returns:
        A tuple `(sdpa_result, None)`; no attention weights are returned.

    Raises:
        ValueError: If the batch size is 1 and the last element of `attention_mask`
            is below -1 (i.e. the input was padded to `max_length`).
    """
    raise_on_head_mask(head_mask)
    batch_size = query.shape[0]

    # Most negative finite value of the value dtype, as a 0-dim tensor for `torch.where`.
    mask_value = torch.finfo(value.dtype).min
    mask_value = torch.full([], mask_value, dtype=value.dtype)

    # in gpt-neo-x and gpt-j the query and keys are always in fp32
    # thus we need to cast them to the value dtype
    if self.downcast_qk:
        query = query.to(value.dtype)
        key = key.to(value.dtype)

    if batch_size == 1 and attention_mask is not None and attention_mask[0, 0, -1, -1] < -1:
        raise ValueError("BetterTransformer does not support padding='max_length' with a batch size of 1.")

    dropout_p = self.dropout_prob_attn if self.training else 0.0
    if batch_size == 1 or self.training:
        # No explicit mask on this path: rely on SDPA's built-in causal masking, which is
        # only needed when more than one query position is processed at once.
        sdpa_result = torch.nn.functional.scaled_dot_product_attention(
            query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=query.shape[2] > 1
        )
    else:
        query_length, key_length = query.size(-2), key.size(-2)

        # causal_mask is always [True, ..., True] otherwise, so executing this
        # is unnecessary
        if query_length > 1:
            if not check_if_transformers_greater("4.44.99"):
                # transformers<4.45: the layer still carries the `bias` causal buffer.
                causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)

                causal_mask = torch.where(causal_mask, 0, mask_value)

                # torch.Tensor.expand does no memory copy
                causal_mask = causal_mask.expand(batch_size, -1, -1, -1)

                # Always apply the causal mask. Previously it was dropped when no
                # attention mask was given, which made the attention non-causal.
                if attention_mask is None:
                    attention_mask = causal_mask
                else:
                    attention_mask = causal_mask + attention_mask

            else:
                # transformers>=4.45: only crop the mask to the key length.
                # NOTE(review): assumes the caller always supplies an (already causal)
                # mask on this path; a `None` mask raises here — confirm against caller.
                attention_mask = attention_mask[:, :, :, : key.shape[-2]]

        sdpa_result = torch.nn.functional.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=False
        )

    # Make sure the output is returned in the value dtype when downcasting is enabled.
    if self.downcast_qk:
        sdpa_result = sdpa_result.to(value.dtype)

    return sdpa_result, None


# Adapted from transformers.models.bark.modeling_bark.BarkSelfAttention._attn
def bark_wrapped_scaled_dot_product(
self,
Expand Down Expand Up @@ -195,7 +260,7 @@ def codegen_wrapped_scaled_dot_product(
query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=True
)
else:
# in this case, which is the later decoding steps, the `causal_mask`` in
# in this case, which is the later decoding steps, the `causal_mask` in
# https://github.com/huggingface/transformers/blob/ae54e3c3b18bac0832ad62ea9b896dfd52a09850/src/transformers/models/gpt2/modeling_gpt2.py#L195
# is [True, ..., True] so actually not causal
sdpa_result = torch.nn.functional.scaled_dot_product_attention(
Expand All @@ -207,15 +272,20 @@ def codegen_wrapped_scaled_dot_product(
# causal_mask is always [True, ..., True] otherwise, so executing this
# is unnecessary
if query_length > 1:
causal_mask = self.causal_mask[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
if not check_if_transformers_greater("4.44.99"):
causal_mask = self.causal_mask[:, :, key_length - query_length : key_length, :key_length].to(
torch.bool
)

causal_mask = torch.where(causal_mask, 0, mask_value)
causal_mask = torch.where(causal_mask, 0, mask_value)

# torch.Tensor.expand does no memory copy
causal_mask = causal_mask.expand(batch_size, -1, -1, -1)
# torch.Tensor.expand does no memory copy
causal_mask = causal_mask.expand(batch_size, -1, -1, -1)

# we use torch.min to avoid having tensor(-inf)
attention_mask = torch.min(causal_mask, attention_mask)
# we use torch.min to avoid having tensor(-inf)
attention_mask = torch.min(causal_mask, attention_mask)
else:
attention_mask = attention_mask[:, :, :, : key.shape[-2]]

sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=False
Expand Down
35 changes: 31 additions & 4 deletions optimum/bettertransformer/models/decoder_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
codegen_wrapped_scaled_dot_product,
gpt2_wrapped_scaled_dot_product,
gpt_neo_wrapped_scaled_dot_product,
gptj_wrapped_scaled_dot_product,
opt_forward,
t5_forward,
)
Expand Down Expand Up @@ -82,7 +83,7 @@ def forward(self, *args, **kwargs):


class GPTJAttentionLayerBetterTransformer(BetterTransformerBaseLayer, GPTJAttention, nn.Module):
_attn = gpt2_wrapped_scaled_dot_product
_attn = gptj_wrapped_scaled_dot_product

def __init__(self, layer: "nn.Module", config: "PretrainedConfig"):
super().__init__(config)
Expand All @@ -96,14 +97,22 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"):
"out_proj",
"attn_dropout",
"resid_dropout",
"bias",
"scale_attn",
"masked_bias",
]
# Attribute only for transformers>=4.28
if hasattr(layer, "embed_positions"):
submodules.append("embed_positions")

# Attribute only for transformers<4.45
if hasattr(layer, "bias"):
submodules.append("bias")
if hasattr(layer, "masked_bias"):
submodules.append("masked_bias")

# Attribute only for transformers>=4.45
if hasattr(layer, "layer_idx"):
submodules.append("layer_idx")

for attr in submodules:
setattr(self, attr, getattr(layer, attr))

Expand All @@ -127,6 +136,11 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"):

self.module_mapping = None
submodules = ["rotary_emb", "query_key_value", "dense", "bias", "masked_bias", "norm_factor"]

# Attribute only for transformers>=4.45
if hasattr(layer, "layer_idx"):
submodules.append("layer_idx")

for attr in submodules:
setattr(self, attr, getattr(layer, attr))

Expand Down Expand Up @@ -155,6 +169,11 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"):

self.module_mapping = None
submodules = ["attn_dropout", "resid_dropout", "k_proj", "v_proj", "q_proj", "out_proj", "bias", "masked_bias"]

# Attribute only for transformers>=4.45
if hasattr(layer, "layer_id"):
submodules.append("layer_id")

for attr in submodules:
setattr(self, attr, getattr(layer, attr))

Expand Down Expand Up @@ -238,12 +257,20 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"):
super(BetterTransformerBaseLayer, self).__init__(config)

self.module_mapping = None
submodules = ["attn_dropout", "resid_dropout", "qkv_proj", "out_proj", "causal_mask", "scale_attn"]
submodules = ["attn_dropout", "resid_dropout", "qkv_proj", "out_proj", "scale_attn"]

# Attribute only for transformers>=4.28
if hasattr(layer, "embed_positions"):
submodules.append("embed_positions")

# Attribute only for transformers<4.45
if hasattr(layer, "causal_mask"):
submodules.append("causal_mask")

# Attribute only for transformers>=4.45
if hasattr(layer, "layer_idx"):
submodules.append("layer_idx")

for attr in submodules:
setattr(self, attr, getattr(layer, attr))

Expand Down
4 changes: 4 additions & 0 deletions optimum/bettertransformer/transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,10 @@ def transform(
The converted model if the conversion has been successful.
"""

logger.warning(
"The class `optimum.bettertransformers.transformation.BetterTransformer` is deprecated and will be removed in a future release."
)

hf_config = model.config
if hf_config.model_type in ["falcon", "gpt_bigcode", "llama", "whisper"]:
raise ValueError(
Expand Down
18 changes: 18 additions & 0 deletions optimum/exporters/onnx/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

import numpy as np
import onnx
import transformers
from transformers.modeling_utils import get_parameter_dtype
from transformers.utils import is_tf_available, is_torch_available

Expand All @@ -34,6 +35,7 @@
DEFAULT_DUMMY_SHAPES,
ONNX_WEIGHTS_NAME,
TORCH_MINIMUM_VERSION,
check_if_transformers_greater,
is_diffusers_available,
is_torch_onnx_support_available,
logging,
Expand Down Expand Up @@ -999,6 +1001,10 @@ def onnx_export_from_model(
>>> onnx_export_from_model(model, output="gpt2_onnx/")
```
"""
if check_if_transformers_greater("4.44.99"):
raise ImportError(
f"ONNX conversion disabled for now for transformers version greater than v4.45, found {transformers.__version__}"
)

TasksManager.standardize_model_attributes(model)

Expand Down Expand Up @@ -1120,6 +1126,18 @@ def onnx_export_from_model(
if isinstance(atol, dict):
atol = atol[task.replace("-with-past", "")]

if check_if_transformers_greater("4.44.99"):
misplaced_generation_parameters = model.config._get_non_default_generation_parameters()
if model.can_generate() and len(misplaced_generation_parameters) > 0:
logger.warning(
"Moving the following attributes in the config to the generation config: "
f"{misplaced_generation_parameters}. You are seeing this warning because you've set "
"generation parameters in the model config, as opposed to in the generation config.",
)
for param_name, param_value in misplaced_generation_parameters.items():
setattr(model.generation_config, param_name, param_value)
setattr(model.config, param_name, None)

# Saving the model config and preprocessor as this is needed sometimes.
model.config.save_pretrained(output)
generation_config = getattr(model, "generation_config", None)
Expand Down
11 changes: 6 additions & 5 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def inputs(self) -> Dict[str, Dict[int, str]]:


class AlbertOnnxConfig(BertOnnxConfig):
DEFAULT_ONNX_OPSET = 11
DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for torch>=2.1.1.


class ConvBertOnnxConfig(BertOnnxConfig):
Expand Down Expand Up @@ -171,11 +171,11 @@ class MPNetOnnxConfig(DistilBertOnnxConfig):


class RobertaOnnxConfig(DistilBertOnnxConfig):
pass
DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for torch>=2.1.1.


class CamembertOnnxConfig(DistilBertOnnxConfig):
pass
DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for torch>=2.1.1.


class FlaubertOnnxConfig(BertOnnxConfig):
Expand All @@ -187,7 +187,7 @@ class IBertOnnxConfig(DistilBertOnnxConfig):


class XLMRobertaOnnxConfig(DistilBertOnnxConfig):
pass
DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for torch>=2.1.1.


class DebertaOnnxConfig(BertOnnxConfig):
Expand Down Expand Up @@ -257,7 +257,7 @@ class ImageGPTOnnxConfig(GPT2OnnxConfig):


class GPTNeoOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
DEFAULT_ONNX_OPSET = 13
DEFAULT_ONNX_OPSET = 14
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_attention_heads="num_heads")


Expand Down Expand Up @@ -564,6 +564,7 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int


class M2M100OnnxConfig(TextSeq2SeqOnnxConfig):
DEFAULT_ONNX_OPSET = 14 # now uses F.scaled_dot_product_attention by default for torch>=2.1.1.
NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig.with_args(
encoder_num_layers="encoder_layers",
decoder_num_layers="decoder_layers",
Expand Down
3 changes: 3 additions & 0 deletions optimum/modeling_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,9 @@ def from_pretrained(
export = from_transformers

if len(model_id.split("@")) == 2:
logger.warning(
f"Specifying the `revision` as @{model_id.split('@')[1]} is deprecated and will be removed in v1.23, please use the `revision` argument instead."
)
if revision is not None:
logger.warning(
f"The argument `revision` was set to {revision} but will be ignored for {model_id.split('@')[1]}"
Expand Down
Loading

0 comments on commit 049b00f

Please sign in to comment.