Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support transformers 4.42 #1929

Merged
merged 9 commits into from
Jul 1, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test_onnx.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
pip install .[tests,onnxruntime,exporters-tf]
pip install .[tests,onnxruntime,exporters]
fxmarty marked this conversation as resolved.
Show resolved Hide resolved
- name: Test with unittest
working-directory: tests
run: |
Expand Down
11 changes: 9 additions & 2 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
from .constants import ONNX_DECODER_MERGED_NAME, ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME
from .model_patcher import (
FalconModelPatcher,
MistralModelPatcher,
MusicgenModelPatcher,
SAMModelPatcher,
SentenceTransformersCLIPPatcher,
Expand Down Expand Up @@ -237,7 +238,7 @@ def inputs(self) -> Dict[str, Dict[int, str]]:


class GPT2OnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
DEFAULT_ONNX_OPSET = 13
DEFAULT_ONNX_OPSET = 14 # uses SDPA in Transformers, hence opset>=14.
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_layers="n_layer", num_attention_heads="n_head")


Expand All @@ -259,7 +260,7 @@ class GPTNeoOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):


class GPTNeoXOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
DEFAULT_ONNX_OPSET = 13
DEFAULT_ONNX_OPSET = 14 # uses SDPA in Transformers, hence opset>=14.
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig


Expand Down Expand Up @@ -312,6 +313,12 @@ class MistralOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_key_value_heads="num_key_value_heads", allow_new=True)

def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
print("IN patch_model_for_export")
fxmarty marked this conversation as resolved.
Show resolved Hide resolved
return MistralModelPatcher(self, model, model_kwargs=model_kwargs)


class MPTOnnxConfig(TextDecoderOnnxConfig):
# MPT does not require position_ids input.
Expand Down
198 changes: 198 additions & 0 deletions optimum/exporters/onnx/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@
_prepare_4d_causal_attention_mask_for_sdpa = None
AttentionMaskConverter = None

if _transformers_version >= version.parse("4.42"):
from transformers.cache_utils import SlidingWindowCache, StaticCache

if TYPE_CHECKING:
from transformers import PreTrainedModel, TFPreTrainedModel

Expand Down Expand Up @@ -746,6 +749,20 @@ def patched_forward(


class SentenceTransformersTransformerPatcher(ModelPatcher):
def __enter__(self):
super().__enter__()
if _transformers_version >= version.parse("4.42") and self.real_config._config.model_type == "mistral":
self._model[0].auto_model._update_causal_mask = types.MethodType(
_update_causal_mask_patched, self._model[0].auto_model
)

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
if _transformers_version >= version.parse("4.42") and self.real_config._config.model_type == "mistral":
self._model[0].auto_model._update_causal_mask = types.MethodType(
self._update_causal_mask_original, self._model[0].auto_model
)

def __init__(
self,
config: "OnnxConfig",
Expand All @@ -754,6 +771,8 @@ def __init__(
):
super().__init__(config, model, model_kwargs)

self._update_causal_mask_original = self._model[0].auto_model._update_causal_mask

def patched_forward(input_ids, attention_mask):
result = self.orig_forward({"input_ids": input_ids, "attention_mask": attention_mask})

Expand Down Expand Up @@ -931,3 +950,182 @@ def patched_forward(
return {"audio_values": audio_values}

self.patched_forward = patched_forward


def _update_causal_mask_patched(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values,
use_cache: bool,
output_attentions: bool,
):
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114

if self._attn_implementation == "flash_attention_2":
if attention_mask is not None and use_cache:
is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
if is_padding_right:
raise ValueError(
"You are attempting to perform batched generation with padding_side='right'"
" this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
)
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None

# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.

# cache_position must be valid here no matter which cache we use
past_seen_tokens = cache_position[0] if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)
using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)

if (
self.config._attn_implementation == "sdpa"
and not (using_static_cache or using_sliding_window_cache)
and not output_attentions
):
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
sliding_window=self.config.sliding_window,
is_training=self.training,
):
return None

dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
# SlidingWindowCache
if using_sliding_window_cache:
target_length = max(sequence_length, self.config.sliding_window)
# StaticCache
elif using_static_cache:
target_length = past_key_values.get_max_length()
# DynamicCache or no cache
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)

if attention_mask is not None and attention_mask.dim() == 4:
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
if attention_mask.max() != 0:
raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
causal_mask = attention_mask
else:
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
exclude_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
if self.config.sliding_window is not None:
if not using_sliding_window_cache or sequence_length > self.config.sliding_window:
# ---------------- NOTE: This part is patched -----------------------------
exclude_mask.bitwise_or_(
torch.arange(target_length, device=device)
<= (cache_position.reshape(-1, 1) - self.config.sliding_window)
)
# ---------------- NOTE: patch end ----------------------------------------

causal_mask *= exclude_mask
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
if attention_mask.dim() == 2:
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)

# if (
# self.config._attn_implementation == "sdpa"
# and attention_mask is not None
# and attention_mask.device.type == "cuda"
# and not output_attentions
# ):
# # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# # Details: https://github.com/pytorch/pytorch/issues/110213
# causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

return causal_mask


class MistralModelPatcher(ModelPatcher):
def __enter__(self):
super().__enter__()
if AttentionMaskConverter is not None:
# TODO: Remove this _make_causal_mask patch if once transformers if much above 4.35
AttentionMaskConverter._make_causal_mask = _make_causal_mask_patched_staticmethod

if _transformers_version >= version.parse("4.36"):
AttentionMaskConverter._unmask_unattended = _unmask_unattended_patched_staticmethod

if _transformers_version >= version.parse("4.36"):
patch_everywhere(
"_prepare_4d_causal_attention_mask_for_sdpa", _prepare_4d_causal_attention_mask_for_sdpa_patched
)

print("self._model in enter", self._model)
fxmarty marked this conversation as resolved.
Show resolved Hide resolved
if _transformers_version >= version.parse("4.42"):
if hasattr(self._model, "model"):
self._model.model._update_causal_mask = types.MethodType(
_update_causal_mask_patched, self._model.model
)
else:
self._model._update_causal_mask = types.MethodType(_update_causal_mask_patched, self._model)

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
if AttentionMaskConverter is not None:
# TODO: Remove this _make_causal_mask patch if once transformers if much above 4.35
AttentionMaskConverter._make_causal_mask = staticmethod(self.original_make_causal)

if _transformers_version >= version.parse("4.36"):
AttentionMaskConverter._unmask_unattended = staticmethod(self.original_unmask_unattended)

if _transformers_version >= version.parse("4.36"):
patch_everywhere(
"_prepare_4d_causal_attention_mask_for_sdpa", self.original_prepare_4d_causal_attention_mask_for_sdpa
)

if _transformers_version >= version.parse("4.42"):
if hasattr(self._model, "model"):
self._model.model._update_causal_mask = types.MethodType(
self._update_causal_mask_original, self._model.model
)
else:
self._model._update_causal_mask = types.MethodType(self._update_causal_mask_original, self._model)

def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Optional[Dict[str, Any]] = None,
):
super().__init__(config, model, model_kwargs)

if _transformers_version >= version.parse("4.36"):
self.original_prepare_4d_causal_attention_mask_for_sdpa = _prepare_4d_causal_attention_mask_for_sdpa
self.original_unmask_unattended = AttentionMaskConverter._unmask_unattended

# TODO: Remove this if once transformers if much above 4.35
if AttentionMaskConverter is not None:
self.original_make_causal = AttentionMaskConverter._make_causal_mask

if _transformers_version >= version.parse("4.42"):
if hasattr(self._model, "model"):
self._update_causal_mask_original = self._model.model._update_causal_mask
else:
self._update_causal_mask_original = self._model._update_causal_mask
1 change: 1 addition & 0 deletions optimum/onnxruntime/modeling_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ class ORTModelForCausalLM(ORTModel, GenerationMixin):

auto_model_class = AutoModelForCausalLM
main_input_name = "input_ids"
_supports_cache_class = False

def __init__(
self,
Expand Down
1 change: 1 addition & 0 deletions optimum/onnxruntime/modeling_seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,7 @@ class ORTModelForConditionalGeneration(ORTModel, ABC):

# Used in from_transformers to export model to onnxORTEncoder
base_model_prefix = "onnx_model"
_supports_cache_class = False

def __init__(
self,
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
REQUIRED_PKGS = [
"coloredlogs",
"sympy",
"transformers[sentencepiece]>=4.26.0,<4.42.0",
"transformers[sentencepiece]>=4.26.0,<4.43.0",
"torch>=1.11",
"packaging",
"numpy<2.0", # transformers requires numpy<2.0 https://github.com/huggingface/transformers/pull/31569
Expand Down
Loading