diff --git a/.github/workflows/test_onnx.yml b/.github/workflows/test_onnx.yml
index 5a21f12d015..9aa8b307235 100644
--- a/.github/workflows/test_onnx.yml
+++ b/.github/workflows/test_onnx.yml
@@ -27,7 +27,7 @@ jobs:
           python-version: ${{ matrix.python-version }}
       - name: Install dependencies
         run: |
-          pip install .[tests,onnxruntime,exporters-tf]
+          pip install .[tests,exporters]
       - name: Test with unittest
         working-directory: tests
         run: |
diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py
index 421b7c9010a..c66e54b323c 100644
--- a/optimum/exporters/onnx/model_configs.py
+++ b/optimum/exporters/onnx/model_configs.py
@@ -72,6 +72,7 @@
 from .constants import ONNX_DECODER_MERGED_NAME, ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME
 from .model_patcher import (
     FalconModelPatcher,
+    MistralModelPatcher,
     MusicgenModelPatcher,
     SAMModelPatcher,
     SentenceTransformersCLIPPatcher,
@@ -237,7 +238,7 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
 
 
 class GPT2OnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
-    DEFAULT_ONNX_OPSET = 13
+    DEFAULT_ONNX_OPSET = 14  # uses SDPA in Transformers, hence opset>=14.
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_layers="n_layer", num_attention_heads="n_head")
 
 
@@ -259,7 +260,7 @@ class GPTNeoOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
 
 
 class GPTNeoXOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
-    DEFAULT_ONNX_OPSET = 13
+    DEFAULT_ONNX_OPSET = 14  # uses SDPA in Transformers, hence opset>=14.
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
 
 
@@ -312,6 +313,11 @@ class MistralOnnxConfig(TextDecoderWithPositionIdsOnnxConfig):
     DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
     NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_key_value_heads="num_key_value_heads", allow_new=True)
 
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return MistralModelPatcher(self, model, model_kwargs=model_kwargs)
+
 
 class MPTOnnxConfig(TextDecoderOnnxConfig):
     # MPT does not require position_ids input.
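For context on the `patch_model_for_export` hook added above: the exporter uses the returned patcher as a context manager, so the Mistral-specific patches defined in `model_patcher.py` below are only active while the model is traced to ONNX. A minimal sketch of that flow, with assumed helper names rather than the exact optimum internals:

```python
import torch

# Minimal sketch (assumed flow, not the exact optimum internals): the OnnxConfig
# hands back a ModelPatcher, and the export runs inside its context so that
# __enter__/__exit__ install and then restore the Mistral-specific patches.
def export_decoder(model, onnx_config, dummy_inputs: dict, output_path: str):
    with onnx_config.patch_model_for_export(model):  # returns MistralModelPatcher for Mistral
        torch.onnx.export(
            model,
            (dummy_inputs,),
            output_path,
            opset_version=onnx_config.DEFAULT_ONNX_OPSET,  # >= 14 for models using SDPA
        )
```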
diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py
index 0a105343546..215d65549f8 100644
--- a/optimum/exporters/onnx/model_patcher.py
+++ b/optimum/exporters/onnx/model_patcher.py
@@ -42,6 +42,9 @@
     _prepare_4d_causal_attention_mask_for_sdpa = None
     AttentionMaskConverter = None
 
+if _transformers_version >= version.parse("4.42"):
+    from transformers.cache_utils import SlidingWindowCache, StaticCache
+
 if TYPE_CHECKING:
     from transformers import PreTrainedModel, TFPreTrainedModel
 
@@ -746,6 +749,20 @@ def patched_forward(
 
 
 class SentenceTransformersTransformerPatcher(ModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+        if _transformers_version >= version.parse("4.42") and self.real_config._config.model_type == "mistral":
+            self._model[0].auto_model._update_causal_mask = types.MethodType(
+                _update_causal_mask_patched, self._model[0].auto_model
+            )
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        if _transformers_version >= version.parse("4.42") and self.real_config._config.model_type == "mistral":
+            self._model[0].auto_model._update_causal_mask = types.MethodType(
+                self._update_causal_mask_original, self._model[0].auto_model
+            )
+
     def __init__(
         self,
         config: "OnnxConfig",
@@ -754,6 +771,8 @@ def __init__(
     ):
         super().__init__(config, model, model_kwargs)
 
+        self._update_causal_mask_original = self._model[0].auto_model._update_causal_mask
+
         def patched_forward(input_ids, attention_mask):
             result = self.orig_forward({"input_ids": input_ids, "attention_mask": attention_mask})
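The `__enter__`/`__exit__` pair above swaps `_update_causal_mask` on a single module instance with `types.MethodType` and restores the original afterwards. Reduced to a self-contained toy example (the class and function names here are invented for illustration):

```python
import types

class Toy:
    def greet(self):
        return "original"

def greet_patched(self):
    # stands in for _update_causal_mask_patched: a module-level function
    # that gets bound to one specific instance
    return "patched"

obj = Toy()
original = obj.greet                               # keep a handle to restore later (as in __init__)
obj.greet = types.MethodType(greet_patched, obj)   # patch only this instance (as in __enter__)
assert obj.greet() == "patched"
obj.greet = original                               # restore (as in __exit__)
assert obj.greet() == "original"
```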
@@ -931,3 +950,181 @@ def patched_forward(
             return {"audio_values": audio_values}
 
         self.patched_forward = patched_forward
+
+
+def _update_causal_mask_patched(
+    self,
+    attention_mask: torch.Tensor,
+    input_tensor: torch.Tensor,
+    cache_position: torch.Tensor,
+    past_key_values,
+    use_cache: bool,
+    output_attentions: bool,
+):
+    # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
+    # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
+    # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
+    # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
+
+    if self._attn_implementation == "flash_attention_2":
+        if attention_mask is not None and use_cache:
+            is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
+            if is_padding_right:
+                raise ValueError(
+                    "You are attempting to perform batched generation with padding_side='right'"
+                    " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
+                    " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
+                )
+        if attention_mask is not None and 0.0 in attention_mask:
+            return attention_mask
+        return None
+
+    # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+    # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+    # to infer the attention mask.
+
+    # cache_position must be valid here no matter which cache we use
+    past_seen_tokens = cache_position[0] if past_key_values is not None else 0
+    using_static_cache = isinstance(past_key_values, StaticCache)
+    using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
+
+    if (
+        self.config._attn_implementation == "sdpa"
+        and not (using_static_cache or using_sliding_window_cache)
+        and not output_attentions
+    ):
+        if AttentionMaskConverter._ignore_causal_mask_sdpa(
+            attention_mask,
+            inputs_embeds=input_tensor,
+            past_key_values_length=past_seen_tokens,
+            sliding_window=self.config.sliding_window,
+            is_training=self.training,
+        ):
+            return None
+
+    dtype, device = input_tensor.dtype, input_tensor.device
+    min_dtype = torch.finfo(dtype).min
+    sequence_length = input_tensor.shape[1]
+    # SlidingWindowCache
+    if using_sliding_window_cache:
+        target_length = max(sequence_length, self.config.sliding_window)
+    # StaticCache
+    elif using_static_cache:
+        target_length = past_key_values.get_max_length()
+    # DynamicCache or no cache
+    else:
+        target_length = (
+            attention_mask.shape[-1]
+            if isinstance(attention_mask, torch.Tensor)
+            else past_seen_tokens + sequence_length + 1
+        )
+
+    if attention_mask is not None and attention_mask.dim() == 4:
+        # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
+        if attention_mask.max() != 0:
+            raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
+        causal_mask = attention_mask
+    else:
+        causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
+        exclude_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+        if self.config.sliding_window is not None:
+            if not using_sliding_window_cache or sequence_length > self.config.sliding_window:
+                # ---------------- NOTE: This part is patched -----------------------------
+                exclude_mask.bitwise_or_(
+                    torch.arange(target_length, device=device)
+                    <= (cache_position.reshape(-1, 1) - self.config.sliding_window)
+                )
+                # ---------------- NOTE: patch end ----------------------------------------
+
+        causal_mask *= exclude_mask
+        causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
+        if attention_mask is not None:
+            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+            if attention_mask.dim() == 2:
+                mask_length = attention_mask.shape[-1]
+                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+                padding_mask = padding_mask == 0
+                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                    padding_mask, min_dtype
+                )
+
+    # if (
+    #     self.config._attn_implementation == "sdpa"
+    #     and attention_mask is not None
+    #     and attention_mask.device.type == "cuda"
+    #     and not output_attentions
+    # ):
+    #     # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+    #     # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+    #     # Details: https://github.com/pytorch/pytorch/issues/110213
+    #     causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+    return causal_mask
+
+
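The region marked `NOTE: This part is patched` builds the sliding-window exclusion as a broadcasted comparison combined with an in-place `bitwise_or_`. A small standalone illustration of what `exclude_mask` evaluates to, with made-up sizes:

```python
import torch

# Toy sizes: 5 query positions attending over 5 key positions, sliding_window = 3.
target_length = 5
sliding_window = 3
cache_position = torch.arange(5)  # absolute positions of the current query tokens

# True where a key position is in the future of a query position...
exclude_mask = torch.arange(target_length) > cache_position.reshape(-1, 1)
# ...or further back than the sliding window allows (the patched branch above).
exclude_mask.bitwise_or_(
    torch.arange(target_length) <= (cache_position.reshape(-1, 1) - sliding_window)
)

print(exclude_mask.long())
# tensor([[0, 1, 1, 1, 1],
#         [0, 0, 1, 1, 1],
#         [0, 0, 0, 1, 1],
#         [1, 0, 0, 0, 1],
#         [1, 1, 0, 0, 0]])  1 = masked out, so each query sees at most sliding_window tokens
```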
+class MistralModelPatcher(ModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+        if AttentionMaskConverter is not None:
+            # TODO: Remove this _make_causal_mask patch if once transformers if much above 4.35
+            AttentionMaskConverter._make_causal_mask = _make_causal_mask_patched_staticmethod
+
+            if _transformers_version >= version.parse("4.36"):
+                AttentionMaskConverter._unmask_unattended = _unmask_unattended_patched_staticmethod
+
+        if _transformers_version >= version.parse("4.36"):
+            patch_everywhere(
+                "_prepare_4d_causal_attention_mask_for_sdpa", _prepare_4d_causal_attention_mask_for_sdpa_patched
+            )
+
+        if _transformers_version >= version.parse("4.42"):
+            if hasattr(self._model, "model"):
+                self._model.model._update_causal_mask = types.MethodType(
+                    _update_causal_mask_patched, self._model.model
+                )
+            else:
+                self._model._update_causal_mask = types.MethodType(_update_causal_mask_patched, self._model)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        if AttentionMaskConverter is not None:
+            # TODO: Remove this _make_causal_mask patch if once transformers if much above 4.35
+            AttentionMaskConverter._make_causal_mask = staticmethod(self.original_make_causal)
+
+            if _transformers_version >= version.parse("4.36"):
+                AttentionMaskConverter._unmask_unattended = staticmethod(self.original_unmask_unattended)
+
+        if _transformers_version >= version.parse("4.36"):
+            patch_everywhere(
+                "_prepare_4d_causal_attention_mask_for_sdpa", self.original_prepare_4d_causal_attention_mask_for_sdpa
+            )
+
+        if _transformers_version >= version.parse("4.42"):
+            if hasattr(self._model, "model"):
+                self._model.model._update_causal_mask = types.MethodType(
+                    self._update_causal_mask_original, self._model.model
+                )
+            else:
+                self._model._update_causal_mask = types.MethodType(self._update_causal_mask_original, self._model)
+
+    def __init__(
+        self,
+        config: "OnnxConfig",
+        model: Union["PreTrainedModel", "TFPreTrainedModel"],
+        model_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        super().__init__(config, model, model_kwargs)
+
+        if _transformers_version >= version.parse("4.36"):
+            self.original_prepare_4d_causal_attention_mask_for_sdpa = _prepare_4d_causal_attention_mask_for_sdpa
+            self.original_unmask_unattended = AttentionMaskConverter._unmask_unattended
+
+        # TODO: Remove this if once transformers if much above 4.35
+        if AttentionMaskConverter is not None:
+            self.original_make_causal = AttentionMaskConverter._make_causal_mask
+
+        if _transformers_version >= version.parse("4.42"):
+            if hasattr(self._model, "model"):
+                self._update_causal_mask_original = self._model.model._update_causal_mask
+            else:
+                self._update_causal_mask_original = self._model._update_causal_mask
diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py
index fd7e741d7c0..6a0dcbba2f0 100644
--- a/optimum/onnxruntime/modeling_decoder.py
+++ b/optimum/onnxruntime/modeling_decoder.py
@@ -121,6 +121,7 @@ class ORTModelForCausalLM(ORTModel, GenerationMixin):
 
     auto_model_class = AutoModelForCausalLM
     main_input_name = "input_ids"
+    _supports_cache_class = False
 
     def __init__(
         self,
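On `_supports_cache_class = False` (added here and in `modeling_seq2seq.py` further down): the assumed rationale is that recent transformers `GenerationMixin` reads this flag to decide whether `past_key_values` may be handled as a `Cache` object, and the ORT models need it set explicitly since they do not inherit it from `PreTrainedModel`; `False` keeps the legacy tuple format that the exported decoder graphs consume. A rough sketch of that kind of check (not the actual transformers source):

```python
from transformers.cache_utils import DynamicCache

# Hedged sketch (not the actual transformers source) of how such a flag is
# typically consulted: only models that opt in get the new Cache objects,
# everything else keeps the legacy tuple-of-tuples past_key_values.
def prepare_cache(model, past_key_values):
    if getattr(model, "_supports_cache_class", False):
        return DynamicCache.from_legacy_cache(past_key_values)
    return past_key_values  # what the ONNX decoder graphs expect
```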
diff --git a/optimum/onnxruntime/modeling_ort.py b/optimum/onnxruntime/modeling_ort.py
index bfdfbff1b11..14bcad682c7 100644
--- a/optimum/onnxruntime/modeling_ort.py
+++ b/optimum/onnxruntime/modeling_ort.py
@@ -1092,9 +1092,10 @@ def forward(
         model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs)
 
         if "last_hidden_state" in self.output_names:
-            last_hidden_state = model_outputs[self.output_names["last_hidden_state"]]
+            last_hidden_state = model_outputs["last_hidden_state"]
         else:
-            last_hidden_state = model_outputs[0]
+            # TODO: This allows to support sentence-transformers models (sentence embedding), but is not validated.
+            last_hidden_state = next(iter(model_outputs.values()))
 
         # converts output to namedtuple for pipelines post-processing
         return BaseModelOutput(last_hidden_state=last_hidden_state)
diff --git a/optimum/onnxruntime/modeling_seq2seq.py b/optimum/onnxruntime/modeling_seq2seq.py
index 89a0ae44d58..3b1af05d0f5 100644
--- a/optimum/onnxruntime/modeling_seq2seq.py
+++ b/optimum/onnxruntime/modeling_seq2seq.py
@@ -570,6 +570,7 @@ class ORTModelForConditionalGeneration(ORTModel, ABC):
 
     # Used in from_transformers to export model to onnx
     base_model_prefix = "onnx_model"
+    _supports_cache_class = False
 
     def __init__(
         self,
diff --git a/setup.py b/setup.py
index b6a5b07f932..cc88760d614 100644
--- a/setup.py
+++ b/setup.py
@@ -15,7 +15,7 @@
 REQUIRED_PKGS = [
     "coloredlogs",
     "sympy",
-    "transformers[sentencepiece]>=4.26.0,<4.42.0",
+    "transformers[sentencepiece]>=4.26.0,<4.43.0",
     "torch>=1.11",
     "packaging",
     "numpy<2.0",  # transformers requires numpy<2.0 https://github.com/huggingface/transformers/pull/31569
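Taken together, these changes let the Mistral ONNX export and ONNX Runtime generation run with transformers below 4.43. A quick way to exercise the new path from Python (the CLI equivalent is `optimum-cli export onnx --model <model_id> <output_dir>`; the checkpoint name below is only an example):

```python
from optimum.onnxruntime import ORTModelForCausalLM
from transformers import AutoTokenizer

# export=True triggers the ONNX export, which now goes through
# MistralOnnxConfig.patch_model_for_export / MistralModelPatcher.
model_id = "mistralai/Mistral-7B-v0.1"  # any Mistral-architecture checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = ORTModelForCausalLM.from_pretrained(model_id, export=True)

inputs = tokenizer("ONNX export of Mistral", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```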