From ca19481096cf78c035cfe0305ec49b822976451b Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Thu, 2 Nov 2023 15:49:25 +0100 Subject: [PATCH] Remove attn mask patching (#1509) * Remove _prepare_decoder_attention_mask patching * Add specific warning for exports with sequence_length set to 1 * Style * Remove Falcon attention mask patching * lots of cleaning * fix mistral * fix legacy * more fixes * fix make_causal patching * remove unused method --------- Co-authored-by: baskrahmer --- optimum/exporters/onnx/__main__.py | 33 ++-- optimum/exporters/onnx/base.py | 30 ++- optimum/exporters/onnx/config.py | 18 +- optimum/exporters/onnx/model_configs.py | 77 ++++---- optimum/exporters/onnx/model_patcher.py | 242 ++++++++---------------- optimum/exporters/onnx/utils.py | 17 +- optimum/onnxruntime/modeling_decoder.py | 2 +- optimum/utils/modeling_utils.py | 164 ---------------- 8 files changed, 181 insertions(+), 402 deletions(-) diff --git a/optimum/exporters/onnx/__main__.py b/optimum/exporters/onnx/__main__.py index df5c2498eff..654f9a649e1 100644 --- a/optimum/exporters/onnx/__main__.py +++ b/optimum/exporters/onnx/__main__.py @@ -24,6 +24,7 @@ from ...commands.export.onnx import parse_args_onnx from ...utils import DEFAULT_DUMMY_SHAPES, ONNX_WEIGHTS_NAME, logging +from ...utils.modeling_utils import MODEL_TO_PATCH_FOR_PAST from ...utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors from ..error_utils import AtolError, OutputMatchError, ShapeError from ..tasks import TasksManager @@ -83,16 +84,12 @@ def _get_submodels_and_onnx_configs( onnx_config_constructor = TasksManager.get_exporter_config_constructor( model=model, exporter="onnx", task=task ) - onnx_config_kwargs = {} - if task.startswith("text-generation") and legacy: - onnx_config_kwargs["no_position_ids"] = legacy - onnx_config = onnx_config_constructor( model.config, int_dtype=int_dtype, float_dtype=float_dtype, preprocessors=preprocessors, - **onnx_config_kwargs, + legacy=legacy, ) onnx_config.variant = _variant @@ -317,13 +314,6 @@ def main_export( model_name_or_path, subfolder=subfolder, library_name=library_name ) - # get the shapes to be used to generate dummy inputs - input_shapes = {} - for input_name in DEFAULT_DUMMY_SHAPES.keys(): - input_shapes[input_name] = ( - kwargs_shapes[input_name] if input_name in kwargs_shapes else DEFAULT_DUMMY_SHAPES[input_name] - ) - torch_dtype = None if fp16 is False else torch.float16 if task == "auto": @@ -382,6 +372,25 @@ def main_export( is_stable_diffusion = "stable-diffusion" in task model_type = "stable-diffusion" if is_stable_diffusion else model.config.model_type.replace("_", "-") + # For MODEL_TO_PATCH_FOR_PAST architectures, when exporting the model with an input of sequence length of 1, a tracer that does not handle + # controlflows will trace incorrectly the mask generation, resulting in incorrect attention masks for other sequence lengthss. + # Reference: https://github.com/huggingface/transformers/blob/af3de8d87c717c4bb090f037d0d89413c195a42f/src/transformers/modeling_attn_mask_utils.py#L94 + input_shapes = {} + for input_name in DEFAULT_DUMMY_SHAPES.keys(): + input_shapes[input_name] = ( + kwargs_shapes[input_name] if input_name in kwargs_shapes else DEFAULT_DUMMY_SHAPES[input_name] + ) + + # TODO: this may be moved rather to the OnnxConfig to avoid bloating this script. + if ( + model_type in MODEL_TO_PATCH_FOR_PAST + and input_name == "sequence_length" + and kwargs_shapes.get(input_name) == 1 + ): + raise ValueError( + f"Exporting with a sequence length of 1 a {model_type} model is not supported and can yield unexpected results." + ) + if legacy and model_type in MODEL_TYPES_REQUIRING_POSITION_IDS and task.startswith("text-generation"): logger.warning( f"legacy=True was specified in the ONNX export, although the model {model_name_or_path} (model type {model_type}) requires position_ids for batched inference. Passing `legacy=True` is strongly discouraged, and this option will be removed in a future release. Reference: https://github.com/huggingface/optimum/pull/1381" diff --git a/optimum/exporters/onnx/base.py b/optimum/exporters/onnx/base.py index b623d3bd22f..6765f3310cd 100644 --- a/optimum/exporters/onnx/base.py +++ b/optimum/exporters/onnx/base.py @@ -200,6 +200,7 @@ def __init__( preprocessors: Optional[List[Any]] = None, int_dtype: str = "int64", float_dtype: str = "fp32", + legacy: bool = False, ): self.task = task self.int_dtype = int_dtype @@ -209,6 +210,7 @@ def __init__( self._preprocessors = preprocessors self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config) self.variant = "default" + self.legacy = legacy def _create_dummy_input_generator_classes(self, **kwargs) -> List[DummyInputGenerator]: """ @@ -565,6 +567,7 @@ def __init__( use_past: bool = False, use_past_in_inputs: bool = False, preprocessors: Optional[List[Any]] = None, + legacy: bool = False, ): self.use_past = use_past self.use_past_in_inputs = use_past_in_inputs @@ -572,7 +575,12 @@ def __init__( self.is_merged = False self.use_cache_branch = None super().__init__( - config=config, task=task, int_dtype=int_dtype, float_dtype=float_dtype, preprocessors=preprocessors + config=config, + task=task, + int_dtype=int_dtype, + float_dtype=float_dtype, + preprocessors=preprocessors, + legacy=legacy, ) @property @@ -628,11 +636,11 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs): and "attention_mask" in dummy_inputs ): # Obtain the past sequence length from the value instead of the key (Bloom). - past_length = dummy_inputs["past_key_values"][0][1].shape[-2] + past_present_length = dummy_inputs["input_ids"].shape[1] + dummy_inputs["past_key_values"][0][1].shape[-2] dummy_inputs["attention_mask"] = DummyInputGenerator.pad_input_on_dim( dummy_inputs["attention_mask"], - desired_length=past_length + 1, + desired_length=past_present_length, dim=1, dtype=dummy_inputs["attention_mask"].dtype, ) @@ -658,11 +666,15 @@ def overwrite_shape_and_generate_input( # models from TextSeq2SeqOnnxConfig use decoder_input_ids as input name # while models from TextDecoderOnnxConfig use input_ids, hence the check for both + + # TODO: The check `self.task != "text-generation" and self.legacy` is added following the use of a single ONNX for both without/with KV cache, without subgraphs. + # This overwrite may be moved to OnnxSeq2SeqConfigWithPast, but I am afraid it would break encoder-decoder models. if ( self.use_past and self.use_past_in_inputs and self.use_cache_branch is not False and input_name in ["decoder_input_ids", "input_ids", "position_ids"] + and ((self.task == "text-generation" and self.legacy) or self.task != "text-generation") ): sequence_length = dummy_input_gen.sequence_length # Use a sequence length of 1 when the KV cache is already populated. @@ -768,6 +780,7 @@ def __init__( use_past_in_inputs: bool = False, behavior: ConfigBehavior = ConfigBehavior.MONOLITH, preprocessors: Optional[List[Any]] = None, + legacy: bool = False, ): super().__init__( config=config, @@ -777,6 +790,7 @@ def __init__( use_past=use_past, use_past_in_inputs=use_past_in_inputs, preprocessors=preprocessors, + legacy=legacy, ) self._behavior = behavior @@ -816,6 +830,7 @@ def with_behavior( use_past_in_inputs=use_past_in_inputs, behavior=behavior, preprocessors=self._preprocessors, + legacy=self.legacy, ) onnx_config.variant = self.variant return onnx_config @@ -1003,7 +1018,7 @@ class OnnxConfigWithLoss(OnnxConfig, ABC): DUMMY_EXTRA_INPUT_GENERATOR_CLASSES = (DummyLabelsGenerator,) - def __init__(self, config: OnnxConfig, int_dtype: str = "int64", float_dtype: str = "fp32"): + def __init__(self, config: OnnxConfig, int_dtype: str = "int64", float_dtype: str = "fp32", legacy: bool = False): self._onnx_config = config self.task = self._onnx_config.task self.int_dtype = int_dtype @@ -1011,6 +1026,7 @@ def __init__(self, config: OnnxConfig, int_dtype: str = "int64", float_dtype: st self._normalized_config = self._onnx_config._normalized_config self.PATCHING_SPECS = self._onnx_config.PATCHING_SPECS self.variant = "default" + self.legacy = legacy @classmethod def from_onnx_config(cls, config: OnnxConfig) -> "OnnxConfigWithLoss": @@ -1037,7 +1053,11 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs): batch_size = dummy_inputs[input_name].shape[0] # TODO: doesn't this break attention_mask generation? - if isinstance(self._onnx_config, OnnxConfigWithPast) and self._onnx_config.use_past_in_inputs is True: + if ( + isinstance(self._onnx_config, OnnxConfigWithPast) + and self._onnx_config.use_past_in_inputs is True + and self.task != "text-generation" + ): kwargs["sequence_length"] = 1 else: for input_name, dynamic_axes in self._tasks_to_extra_inputs[self.task].items(): diff --git a/optimum/exporters/onnx/config.py b/optimum/exporters/onnx/config.py index 7b7d8b19a50..2eaa78d85e4 100644 --- a/optimum/exporters/onnx/config.py +++ b/optimum/exporters/onnx/config.py @@ -35,11 +35,14 @@ ) from .base import ConfigBehavior, OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast from .constants import ONNX_DECODER_MERGED_NAME, ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME +from .model_patcher import DecoderModelPatcher if TYPE_CHECKING: from transformers import PretrainedConfig, PreTrainedModel + from .model_patcher import ModelPatcher + if is_tf_available(): from transformers import TFPreTrainedModel @@ -75,7 +78,7 @@ def __init__( use_past: bool = False, use_past_in_inputs: bool = False, preprocessors: Optional[List[Any]] = None, - no_position_ids: bool = False, + legacy: bool = False, ): super().__init__( config=config, @@ -85,9 +88,8 @@ def __init__( use_past=use_past, use_past_in_inputs=use_past_in_inputs, preprocessors=preprocessors, + legacy=legacy, ) - # TODO: remove no_position_ids once optimum is sufficiently above 1.13 - self.no_position_ids = no_position_ids @property def inputs(self) -> Dict[str, Dict[int, str]]: @@ -154,6 +156,12 @@ def post_process_exported_models( return models_and_onnx_configs, onnx_files_subpaths + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ) -> "ModelPatcher": + # Refer to DecoderModelPatcher. + return DecoderModelPatcher(self, model, model_kwargs=model_kwargs) + class TextDecoderWithPositionIdsOnnxConfig(TextDecoderOnnxConfig): @property @@ -163,7 +171,7 @@ def inputs(self) -> Dict[str, Dict[int, str]]: # Decoders based on GPT2 require a position_ids input to avoid # generating wrong position_ids in the model itself: # https://github.com/huggingface/transformers/blob/v4.33.1/src/transformers/models/gpt2/modeling_gpt2.py#L802 - if not self.no_position_ids and self.task in ["text-generation", "feature-extraction"]: + if not self.legacy and self.task in ["text-generation", "feature-extraction"]: common_inputs["position_ids"] = {0: "batch_size", 1: "sequence_length"} return common_inputs @@ -316,6 +324,7 @@ def __init__( use_past_in_inputs: bool = False, behavior: ConfigBehavior = ConfigBehavior.MONOLITH, preprocessors: Optional[List[Any]] = None, + legacy: bool = False, ): super().__init__( config=config, @@ -326,6 +335,7 @@ def __init__( use_past_in_inputs=use_past_in_inputs, behavior=behavior, preprocessors=preprocessors, + legacy=legacy, ) from ..tasks import TasksManager diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 2da3f5bea6b..b5d67e50409 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -61,12 +61,7 @@ VisionOnnxConfig, ) from .model_patcher import ( - BartModelPatcher, - BloomModelPatcher, FalconModelPatcher, - LlamaModelPatcher, - MistralModelPatcher, - OPTModelPatcher, SAMModelPatcher, SpeechT5ModelPatcher, VisionEncoderDecoderPatcher, @@ -230,11 +225,6 @@ class OPTOnnxConfig(TextDecoderOnnxConfig): DEFAULT_ONNX_OPSET = 13 NORMALIZED_CONFIG_CLASS = NormalizedTextConfig - def patch_model_for_export( - self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None - ) -> "ModelPatcher": - return OPTModelPatcher(self, model, model_kwargs=model_kwargs) - class LlamaOnnxConfig(TextDecoderWithPositionIdsOnnxConfig): DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator) @@ -242,13 +232,11 @@ class LlamaOnnxConfig(TextDecoderWithPositionIdsOnnxConfig): DEFAULT_ONNX_OPSET = 13 NORMALIZED_CONFIG_CLASS = NormalizedTextConfig - def patch_model_for_export( - self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None - ) -> "ModelPatcher": - return LlamaModelPatcher(self, model, model_kwargs=model_kwargs) - class MistralOnnxConfig(TextDecoderWithPositionIdsOnnxConfig): + # This is because of the patching of torch.triu in AttentionMaskConverter, that exists from transformers>=4.35 + MIN_TRANSFORMERS_VERSION = version.parse("4.34.99") + # The ONNX export of this architecture needs the Trilu operator support, available since opset 14 DEFAULT_ONNX_OPSET = 14 DUMMY_INPUT_GENERATOR_CLASSES = ( @@ -257,11 +245,6 @@ class MistralOnnxConfig(TextDecoderWithPositionIdsOnnxConfig): DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_key_value_heads="num_key_value_heads", allow_new=True) - def patch_model_for_export( - self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None - ) -> "ModelPatcher": - return MistralModelPatcher(self, model, model_kwargs=model_kwargs) - class MPTOnnxConfig(TextDecoderOnnxConfig): # MPT does not require position_ids input. @@ -270,11 +253,6 @@ class MPTOnnxConfig(TextDecoderOnnxConfig): num_attention_heads="n_heads", hidden_size="d_model", num_layers="n_layers" ) - def patch_model_for_export( - self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None - ) -> "ModelPatcher": - return BloomModelPatcher(self, model, model_kwargs=model_kwargs) - class BloomOnnxConfig(TextDecoderOnnxConfig): # Bloom does not require position_ids input. @@ -305,11 +283,6 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire 1: decoder_sequence_name, } - def patch_model_for_export( - self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None - ) -> "ModelPatcher": - return BloomModelPatcher(self, model, model_kwargs=model_kwargs) - class GPTBigCodeOnnxConfig(TextDecoderWithPositionIdsOnnxConfig): DUMMY_INPUT_GENERATOR_CLASSES = ( @@ -341,6 +314,9 @@ def flatten_past_key_values(self, flattened_output, name, idx, t): class FalconOnnxConfig(TextDecoderOnnxConfig): + # This is because of the patching that uses _prepare_4d_causal_attention_mask from transformers>=4.35 + MIN_TRANSFORMERS_VERSION = version.parse("4.34.99") + DUMMY_INPUT_GENERATOR_CLASSES = ( MultiQueryPastKeyValuesGenerator, ) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES @@ -357,7 +333,7 @@ def __init__( use_past: bool = False, use_past_in_inputs: bool = False, preprocessors: Optional[List[Any]] = None, - no_position_ids: bool = False, + legacy: bool = False, ): super().__init__( config=config, @@ -367,7 +343,7 @@ def __init__( use_past=use_past, use_past_in_inputs=use_past_in_inputs, preprocessors=preprocessors, - no_position_ids=no_position_ids, + legacy=legacy, ) # For some reason Falcon config.num_kv_heads can not be trusted, see in Transformers: # https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/models/falcon/modeling_falcon.py#L337 @@ -381,11 +357,7 @@ def __init__( def inputs(self) -> Dict[str, Dict[int, str]]: common_inputs = super().inputs - if ( - not self.no_position_ids - and not self._config.alibi - and self.task in ["text-generation", "feature-extraction"] - ): + if not self.legacy and not self._config.alibi and self.task in ["text-generation", "feature-extraction"]: # When alibi is used, position_ids are not used in Falcon. # Reference: https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/models/falcon/modeling_falcon.py#L1116 common_inputs["position_ids"] = {0: "batch_size", 1: "sequence_length"} @@ -655,10 +627,7 @@ def flatten_past_key_values(self, flattened_output, name, idx, t): class BartOnnxConfig(M2M100OnnxConfig): - def patch_model_for_export( - self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None - ) -> "ModelPatcher": - return BartModelPatcher(self, model, model_kwargs=model_kwargs) + pass class MBartOnnxConfig(BartOnnxConfig): @@ -1033,9 +1002,15 @@ def __init__( int_dtype: str = "int64", float_dtype: str = "fp32", preprocessors: Optional[List[Any]] = None, + legacy: bool = False, ): super().__init__( - config=config, task=task, int_dtype=int_dtype, float_dtype=float_dtype, preprocessors=preprocessors + config=config, + task=task, + int_dtype=int_dtype, + float_dtype=float_dtype, + preprocessors=preprocessors, + legacy=legacy, ) if task == "zero-shot-object-detection": logger.warning( @@ -1174,9 +1149,15 @@ def __init__( int_dtype: str = "int64", float_dtype: str = "fp32", preprocessors: Optional[List[Any]] = None, + legacy: bool = False, ): super().__init__( - config=config, task=task, int_dtype=int_dtype, float_dtype=float_dtype, preprocessors=preprocessors + config=config, + task=task, + int_dtype=int_dtype, + float_dtype=float_dtype, + preprocessors=preprocessors, + legacy=legacy, ) self.is_generating_dummy_inputs = False @@ -1351,6 +1332,7 @@ def __init__( behavior: ConfigBehavior = ConfigBehavior.MONOLITH, preprocessors: Optional[List[Any]] = None, is_postnet_and_vocoder: bool = False, + legacy: bool = False, ): super().__init__( config=config, @@ -1361,6 +1343,7 @@ def __init__( use_past_in_inputs=use_past_in_inputs, behavior=behavior, preprocessors=preprocessors, + legacy=legacy, ) if float_dtype == "fp16": raise ValueError( @@ -1595,9 +1578,15 @@ def __init__( variant: str = "split", vision_encoder: Optional[bool] = None, preprocessors: Optional[List[Any]] = None, + legacy: bool = False, ): super().__init__( - config=config, task=task, int_dtype=int_dtype, float_dtype=float_dtype, preprocessors=preprocessors + config=config, + task=task, + int_dtype=int_dtype, + float_dtype=float_dtype, + preprocessors=preprocessors, + legacy=legacy, ) self.variant = variant self.vision_encoder = vision_encoder diff --git a/optimum/exporters/onnx/model_patcher.py b/optimum/exporters/onnx/model_patcher.py index 4a5f4d1ace4..09cdddc95fe 100644 --- a/optimum/exporters/onnx/model_patcher.py +++ b/optimum/exporters/onnx/model_patcher.py @@ -18,26 +18,26 @@ import types from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union -import transformers +from packaging import version from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions -from transformers.models.falcon.modeling_falcon import FalconModel, build_alibi_tensor +from transformers.models.falcon.modeling_falcon import build_alibi_tensor from transformers.models.speecht5.modeling_speecht5 import SpeechT5EncoderWithSpeechPrenet from transformers.utils import is_torch_available -from ...utils.modeling_utils import ( - _falcon_prepare_attn_mask, - _prepare_attn_mask, - _prepare_decoder_attention_mask, - _prepare_decoder_sliding_window_attention_mask, -) - if is_torch_available(): import torch +from ...configuration_utils import _transformers_version from ...utils import logging +if _transformers_version > version.parse("4.34.99"): + from transformers.modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask +else: + _prepare_4d_causal_attention_mask = None + AttentionMaskConverter = None + if TYPE_CHECKING: from transformers import PreTrainedModel, TFPreTrainedModel @@ -249,31 +249,6 @@ def __init__( model.decoder.model.decoder.config.use_cache = True -def _make_causal_mask_falcon_patched( - input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int -) -> torch.BoolTensor: - """ - Make causal mask used for self-attention. This mask does not take the existing attention mask into account - it - just blocks tokens from attending forwards in the sequence. The output shape will be `[batch_size, 1, - target_length, target_length+past_key_values_length]`. - """ - batch_size, target_length = input_ids_shape - - # NOTE: ONNX Runtime is not able to run ONNX Trilu node with bool input. As a workaround, we pass a float input - # and cast to bool here. Reference: https://github.com/microsoft/onnxruntime/issues/16189 - mask = torch.triu(torch.ones((target_length, target_length), dtype=torch.float, device=device), diagonal=1).to( - torch.bool - ) - - # If past_key_values_length is 0 this is an empty tensor and the concatenation is a no-op. - # This code style is an unfortunate consequence of getting your TF engineer to port models; doing it this - # way avoids a data-dependent conditional, which will help me when I have to port this to XLA later. - past_mask = torch.zeros((target_length, past_key_values_length), dtype=torch.bool, device=device) - mask = torch.cat([past_mask, mask], dim=-1) - expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length) - return expanded_mask - - def falcon_model_forward_without_kv_reformatting( self, input_ids: Optional[torch.LongTensor] = None, @@ -287,6 +262,8 @@ def falcon_model_forward_without_kv_reformatting( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ): + # TODO: We may remove this patch once https://github.com/huggingface/transformers/pull/26933 is merged & released in Transformers. + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -346,10 +323,9 @@ def falcon_model_forward_without_kv_reformatting( else: position_ids = position_ids.view(-1, seq_length).long() - causal_mask = self._prepare_attn_mask( - attention_mask, - input_shape=(batch_size, seq_length), - past_key_values_length=past_key_values_length, + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length ) for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): @@ -359,7 +335,7 @@ def falcon_model_forward_without_kv_reformatting( outputs = block( hidden_states, layer_past=layer_past, - attention_mask=causal_mask, + attention_mask=attention_mask, position_ids=position_ids, head_mask=head_mask[i], use_cache=use_cache, @@ -393,26 +369,77 @@ def falcon_model_forward_without_kv_reformatting( ) -class FalconModelPatcher(ModelPatcher): +def _make_causal_mask_patched( + input_ids_shape: torch.Size, + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0, + sliding_window: Optional[int] = None, +): + """ + Make causal mask used for bi-directional self-attention. + """ + # We add self in the signature because `self._make_causal_mask` is used elsewhere in the class definition, despite the method being a staticmethod. + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + + # add lower triangular sliding window mask if necessary + if sliding_window is not None: + diagonal = past_key_values_length - sliding_window + 1 + + # NOTE: adding dtype=torch.int64 here for triu to be supported by ORT: https://github.com/microsoft/onnxruntime/issues/16189 + context_mask = 1 - torch.triu(torch.ones_like(mask, dtype=torch.int64), diagonal=diagonal) + mask.masked_fill_(context_mask.bool(), torch.finfo(dtype).min) + + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + +_make_causal_mask_patched = staticmethod(_make_causal_mask_patched) + + +class DecoderModelPatcher(ModelPatcher): def __enter__(self): - self.patch_ops() + # TODO: Remove this if once transformers if much above 4.35 + if AttentionMaskConverter is not None: + AttentionMaskConverter._make_causal_mask = _make_causal_mask_patched - transformers.models.falcon.modeling_falcon._make_causal_mask = _make_causal_mask_falcon_patched + def __exit__(self, exc_type, exc_value, traceback): + # TODO: Remove this if once transformers if much above 4.35 + if AttentionMaskConverter is not None: + AttentionMaskConverter._make_causal_mask = self.original_make_causal + + def __init__( + self, + config: "OnnxConfig", + model: Union["PreTrainedModel", "TFPreTrainedModel"], + model_kwargs: Optional[Dict[str, Any]] = None, + ): + super().__init__(config, model, model_kwargs) + + # TODO: Remove this if once transformers if much above 4.35 + if AttentionMaskConverter is not None: + self.original_make_causal = AttentionMaskConverter._make_causal_mask + + +class FalconModelPatcher(DecoderModelPatcher): + def __enter__(self): + super().__enter__() + self.patch_ops() if self.real_config.task == "text-generation": self._model.transformer.forward = types.MethodType( falcon_model_forward_without_kv_reformatting, self._model.transformer ) - # In order to use a single decoder, we need to patch the _prepare_attn_mask function to behave independently of the sequence length. - if isinstance(self._model, FalconModel): - self._model._prepare_attn_mask = _falcon_prepare_attn_mask - else: - self._model.transformer._prepare_attn_mask = _falcon_prepare_attn_mask - - setattr(self._model, self.orig_forward_name, self.patched_forward) - def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) self.restore_ops() setattr(self._model, self.orig_forward_name, self.orig_forward) @@ -422,14 +449,6 @@ def __exit__(self, exc_type, exc_value, traceback): self.original_model_transformer_forward, self._model.transformer ) - transformers.models.falcon.modeling_falcon._make_causal_mask = self.original_make_causal - - # In order to use a single decoder, we need to patch the _prepare_attn_mask function to behave independently of the sequence length. - if isinstance(self._model, FalconModel): - self._model._prepare_attn_mask = self.original_falcon_prepare_attn_mask - else: - self._model.transformer._prepare_attn_mask = self.original_falcon_prepare_attn_mask - def __init__( self, config: "OnnxConfig", @@ -441,13 +460,6 @@ def __init__( if config.task == "text-generation": self.original_model_transformer_forward = model.transformer.forward - self.original_make_causal = transformers.models.falcon.modeling_falcon._make_causal_mask - - if isinstance(model, FalconModel): - self.original_falcon_prepare_attn_mask = model._prepare_attn_mask - else: - self.original_falcon_prepare_attn_mask = model.transformer._prepare_attn_mask - self._model = model self.orig_forward_name = "forward" if hasattr(self._model, "forward") else "call" @@ -763,103 +775,3 @@ def patched_forward( return filterd_outputs self.patched_forward = patched_forward - - -class CausalAttentionMaskModelPatcher(ModelPatcher): - def __init__( - self, - config: "OnnxConfig", - model: Union["PreTrainedModel", "TFPreTrainedModel"], - model_kwargs: Optional[Dict[str, Any]] = None, - ): - super().__init__(config, model, model_kwargs) - self.patch = self.real_config.task == "text-generation" and self.real_config.use_past - - def __enter__(self): - super().__enter__() - if self.patch: - setattr(self._model_to_patch, self._orig_func_name, self._patch_func.__get__(self._model_to_patch)) - - def __exit__(self, exc_type, exc_value, traceback): - super().__exit__(exc_type, exc_value, traceback) - if self.patch: - setattr(self._model_to_patch, self._orig_func_name, self._orig_func.__get__(self._model_to_patch)) - - -class BloomModelPatcher(CausalAttentionMaskModelPatcher): - def __init__( - self, - config: "OnnxConfig", - model: Union["PreTrainedModel", "TFPreTrainedModel"], - model_kwargs: Optional[Dict[str, Any]] = None, - ): - super().__init__(config, model, model_kwargs) - if self.patch: - self._model_to_patch = model.transformer - self._patch_func = _prepare_attn_mask - self._orig_func_name = "_prepare_attn_mask" - self._orig_func = self._model_to_patch._prepare_attn_mask - - -class OPTModelPatcher(CausalAttentionMaskModelPatcher): - def __init__( - self, - config: "OnnxConfig", - model: Union["PreTrainedModel", "TFPreTrainedModel"], - model_kwargs: Optional[Dict[str, Any]] = None, - ): - super().__init__(config, model, model_kwargs) - - if self.patch: - self._model_to_patch = model.model.decoder - self._patch_func = _prepare_decoder_attention_mask - self._orig_func_name = "_prepare_decoder_attention_mask" - self._orig_func = self._model_to_patch._prepare_decoder_attention_mask - - -class LlamaModelPatcher(CausalAttentionMaskModelPatcher): - def __init__( - self, - config: "OnnxConfig", - model: Union["PreTrainedModel", "TFPreTrainedModel"], - model_kwargs: Optional[Dict[str, Any]] = None, - ): - super().__init__(config, model, model_kwargs) - - if self.patch: - self._model_to_patch = model.model - self._patch_func = _prepare_decoder_attention_mask - self._orig_func_name = "_prepare_decoder_attention_mask" - self._orig_func = self._model_to_patch._prepare_decoder_attention_mask - - -class MistralModelPatcher(CausalAttentionMaskModelPatcher): - def __init__( - self, - config: "OnnxConfig", - model: Union["PreTrainedModel", "TFPreTrainedModel"], - model_kwargs: Optional[Dict[str, Any]] = None, - ): - super().__init__(config, model, model_kwargs) - - if self.patch: - self._model_to_patch = model.model - self._patch_func = _prepare_decoder_sliding_window_attention_mask - self._orig_func_name = "_prepare_decoder_attention_mask" - self._orig_func = self._model_to_patch._prepare_decoder_attention_mask - - -class BartModelPatcher(CausalAttentionMaskModelPatcher, Seq2SeqModelPatcher): - def __init__( - self, - config: "OnnxConfig", - model: Union["PreTrainedModel", "TFPreTrainedModel"], - model_kwargs: Optional[Dict[str, Any]] = None, - ): - super().__init__(config, model, model_kwargs) - - if self.patch: - self._model_to_patch = model.model.decoder - self._patch_func = _prepare_decoder_attention_mask - self._orig_func_name = "_prepare_decoder_attention_mask" - self._orig_func = self._model_to_patch._prepare_decoder_attention_mask diff --git a/optimum/exporters/onnx/utils.py b/optimum/exporters/onnx/utils.py index ef6206e8d06..c1737fc087c 100644 --- a/optimum/exporters/onnx/utils.py +++ b/optimum/exporters/onnx/utils.py @@ -30,7 +30,6 @@ logging, ) from ...utils.import_utils import _diffusers_version -from ...utils.modeling_utils import _prepare_attn_mask, _prepare_decoder_attention_mask # noqa: F401 from ..tasks import TasksManager from .constants import ONNX_DECODER_NAME, ONNX_DECODER_WITH_PAST_NAME, ONNX_ENCODER_NAME @@ -254,9 +253,12 @@ def get_decoder_models_for_export( models_for_export = _get_submodels_for_export_decoder(model, use_past=config.use_past, legacy=legacy) - onnx_kwargs = {"task": config.task, "float_dtype": config.float_dtype, "int_dtype": config.int_dtype} - if model.config.model_type.replace("_", "-") in MODEL_TYPES_REQUIRING_POSITION_IDS: - onnx_kwargs["no_position_ids"] = config.no_position_ids + onnx_kwargs = { + "task": config.task, + "float_dtype": config.float_dtype, + "int_dtype": config.int_dtype, + "legacy": legacy, + } if legacy: onnx_config = config.__class__( @@ -389,14 +391,14 @@ def get_sam_models_for_export(model: Union["PreTrainedModel", "TFPreTrainedModel models_for_export = _get_submodels_for_export_sam(model, config.variant) if config.variant == "monolith": - onnx_config = config.__class__(model.config, task=config.task) + onnx_config = config.__class__(model.config, task=config.task, legacy=config.legacy) models_for_export["model"] = (models_for_export["model"], onnx_config) else: vision_encoder_onnx_config = config.__class__( - model.config, task=config.task, variant=config.variant, vision_encoder=True + model.config, task=config.task, variant=config.variant, vision_encoder=True, legacy=config.legacy ) prompt_encoder_mask_decoder_onnx_config = config.__class__( - model.config, task=config.task, variant=config.variant, vision_encoder=False + model.config, task=config.task, variant=config.variant, vision_encoder=False, legacy=config.legacy ) models_for_export["vision_encoder"] = (models_for_export["vision_encoder"], vision_encoder_onnx_config) models_for_export["prompt_encoder_mask_decoder"] = ( @@ -454,6 +456,7 @@ def get_speecht5_models_for_export( behavior=config._behavior, # Irrelevant here. preprocessors=config._preprocessors, is_postnet_and_vocoder=True, + legacy=config.legacy, ) postnet_and_vocoder_onnx_config.variant = config.variant models_for_export["decoder_postnet_and_vocoder"] = ( diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index 13aef3546a5..94418a96afe 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -473,7 +473,7 @@ def _from_pretrained( if file_name == ONNX_DECODER_WITH_PAST_NAME and config.model_type in MODEL_TO_PATCH_FOR_PAST: raise ValueError( - f"{ONNX_DECODER_WITH_PAST_NAME} not supported for the following architecture : {', '.join(MODEL_TO_PATCH_FOR_PAST)}. Please re-export your model or set use_cache=False." + f"ONNX Runtime inference using {ONNX_DECODER_WITH_PAST_NAME} has been deprecated for {config.model_type} architecture. Please re-export your model with optimum>=1.14.0 or set use_cache=False. For details about the deprecation, please refer to https://github.com/huggingface/optimum/releases/tag/v1.14.0." ) regular_file_names = [] diff --git a/optimum/utils/modeling_utils.py b/optimum/utils/modeling_utils.py index 336ad31e5a7..dae5b5d633a 100644 --- a/optimum/utils/modeling_utils.py +++ b/optimum/utils/modeling_utils.py @@ -13,9 +13,6 @@ # limitations under the License. import functools -from typing import Tuple - -import torch MODEL_TO_PATCH_FOR_PAST = { @@ -55,164 +52,3 @@ def recurse_setattr(module, name, value): else: name, rest = name.split(".", 1) recurse_setattr(getattr(module, name), rest, value) - - -# Modified from transformers.models.bloom.modeling_bloom._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, - device: torch.device, - past_key_values_length: int, - dtype: torch.dtype = torch.bool, -) -> torch.BoolTensor: - """ - Make causal mask used for bi-directional self-attention. - """ - batch_size, target_length = input_ids_shape - mask = torch.zeros((target_length, target_length + past_key_values_length), dtype=dtype, device=device) - seq_ids = torch.arange(target_length, device=device) - - mask[:, past_key_values_length:] = ( - (seq_ids[:, None] < seq_ids[None, :]) * torch.finfo(dtype).min - if torch.is_floating_point(mask) - else seq_ids[:, None] < seq_ids[None, :] - ) - - return mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length) - - -# NOTE: For MODEL_TO_PATCH_FOR_PAST architectures, when exporting the model with an input of sequence length of 1, the attention masks will be generated incorrectly for other sequence length -# https://github.com/huggingface/transformers/blob/0ee45906845c8d58b9bd2df5acd90e09b00047ff/src/transformers/models/bloom/modeling_bloom.py#L654 -# The method taking care of the decoder mask generation of the models from these architectures must be patched during export for sequence length of 1. - - -# Modified from transformers.models.bloom.modeling_bloom._prepare_attn_mask -def _prepare_attn_mask( - self, - attention_mask: torch.Tensor, - input_shape: Tuple[int, int], - past_key_values_length: int, -) -> torch.BoolTensor: - from transformers.models.bloom.modeling_bloom import _expand_mask - - # create causal mask - # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length] - combined_attention_mask = None - device = attention_mask.device - _, src_length = input_shape - - combined_attention_mask = _make_causal_mask( - input_shape, device=device, past_key_values_length=past_key_values_length - ) - # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length] - expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask - ) - - return combined_attention_mask - - -# Modified from transformers.models.llama.modeling_llama._prepare_decoder_attention_mask -def _prepare_decoder_attention_mask( - self, - attention_mask: torch.Tensor, - input_shape: Tuple[int, int], - inputs_embeds: torch.Tensor, - past_key_values_length: int, -): - from transformers.models.llama.modeling_llama import _expand_mask - - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - - combined_attention_mask = _make_causal_mask( - input_shape, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - dtype=inputs_embeds.dtype, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - - -# Modified from transformers.models.mistral.modeling_mistral._prepare_decoder_sliding_window_attention_mask -def _prepare_decoder_sliding_window_attention_mask( - self, - attention_mask: torch.Tensor, - input_shape: Tuple[int, int], - inputs_embeds: torch.Tensor, - past_key_values_length: int, - sliding_window: int, -): - from transformers.models.mistral.modeling_mistral import _expand_mask, _make_sliding_window_causal_mask - - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - - combined_attention_mask = _make_sliding_window_causal_mask( - input_shape, - device=inputs_embeds.device, - dtype=inputs_embeds.dtype, - past_key_values_length=past_key_values_length, - sliding_window=sliding_window, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - - -def _falcon_prepare_attn_mask( - attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int -) -> torch.BoolTensor: - from transformers.models.falcon.modeling_falcon import ( - _expand_mask, - ) - - # NOTE: there is no "copied from" for falcon in transformers which makes no sense to me. - - # Create a causal mask - # The attention mask we receive as input should cover the whole extended sequence, including any past - # cache, so its shape should be [batch_size, seq_length + past_key_values_length] - # The output shape will be [batch_size, 1, seq_length, seq_length + past_key_values_length] - if input_shape[1] + past_key_values_length != attention_mask.shape[1]: - raise ValueError( - "Attention mask shape should be (batch_size, seq_length + past_key_values_length)" - f" but is {attention_mask.shape} with input_ids shape {input_shape} and past length" - f" {past_key_values_length}." - ) - combined_attention_mask = None - device = attention_mask.device - _, seq_length = input_shape - - # if seq_length > 1: - # NOTE: we remove here the `if seq_length > 1` to allow to use a single decoder. - combined_attention_mask = _make_causal_mask( - input_shape, device=device, past_key_values_length=past_key_values_length - ) - - # [batch_size, seq_length + past_key_values_length] -> [batch_size, 1, seq_length, seq_length + past_key_values_length] - expanded_attn_mask = _expand_mask(attention_mask, past_key_values_length=past_key_values_length) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask - ) - - return combined_attention_mask