diff --git a/docs/source/onnxruntime/usage_guides/optimization.mdx b/docs/source/onnxruntime/usage_guides/optimization.mdx index f3868b6a72d..5fff6fcf8b6 100644 --- a/docs/source/onnxruntime/usage_guides/optimization.mdx +++ b/docs/source/onnxruntime/usage_guides/optimization.mdx @@ -79,6 +79,12 @@ Here is a list of the possible optimizations you can enable: - Add Bias and Gelu / FastGelu fusion with `disable_bias_gelu_fusion=False`, - Gelu approximation with `enable_gelu_approximation=True`. + + +Attention fusion is designed for right-side padding for BERT-like architectures (eg. BERT, RoBERTa, VIT, etc.) and for left-side padding for generative models (GPT-like). If you are not following the convention, please set `use_raw_attention_mask=True` to avoid potential accuracy issues but sacrifice the performance. + + + While [`~onnxruntime.configuration.OptimizationConfig`] gives you full control on how to do optimization, it can be hard to know what to enable / disable. Instead, you can use [`~onnxruntime.configuration.AutoOptimizationConfig`] which provides four common optimization levels: - O1: basic general optimizations. - O2: basic and extended general optimizations, transformers-specific fusions. diff --git a/optimum/utils/input_generators.py b/optimum/utils/input_generators.py index 2fa98eed1b6..1a9024db7a1 100644 --- a/optimum/utils/input_generators.py +++ b/optimum/utils/input_generators.py @@ -180,6 +180,59 @@ def random_int_tensor( else: return np.random.randint(min_value, high=max_value, size=shape, dtype=DTYPE_MAPPER.np(dtype)) + @staticmethod + @check_framework_is_available + def random_mask_tensor(shape: List[int], padding_side: str = "right", framework: str = "pt", dtype: str = "int64"): + """ + Generates a mask tensor either right or left padded. + + Args: + shape (`List[int]`): + The shape of the random tensor. + padding_side (`str`, defaults to "right"): + The side on which the padding is applied. + framework (`str`, defaults to `"pt"`): + The requested framework. + dtype (`str`, defaults to `"int64"`): + The dtype of the generated integer tensor. Could be "int64", "int32", "int8". + + Returns: + A random mask tensor either left padded or right padded in the requested framework. + """ + shape = tuple(shape) + mask_length = random.randint(1, shape[-1] - 1) + if framework == "pt": + mask_tensor = torch.cat( + [ + torch.ones(*shape[:-1], shape[-1] - mask_length, dtype=DTYPE_MAPPER.pt(dtype)), + torch.zeros(*shape[:-1], mask_length, dtype=DTYPE_MAPPER.pt(dtype)), + ], + dim=-1, + ) + if padding_side == "left": + mask_tensor = torch.flip(mask_tensor, [-1]) + elif framework == "tf": + mask_tensor = tf.concat( + [ + tf.ones((*shape[:-1], shape[-1] - mask_length), dtype=DTYPE_MAPPER.tf(dtype)), + tf.zeros((*shape[:-1], mask_length), dtype=DTYPE_MAPPER.tf(dtype)), + ], + axis=-1, + ) + if padding_side == "left": + mask_tensor = tf.reverse(mask_tensor, [-1]) + else: + mask_tensor = np.concatenate( + [ + np.ones((*shape[:-1], shape[-1] - mask_length), dtype=DTYPE_MAPPER.np(dtype)), + np.zeros((*shape[:-1], mask_length), dtype=DTYPE_MAPPER.np(dtype)), + ], + axis=-1, + ) + if padding_side == "left": + mask_tensor = np.flip(mask_tensor, [-1]) + return mask_tensor + @staticmethod @check_framework_is_available def random_float_tensor( @@ -344,6 +397,7 @@ def __init__( random_batch_size_range: Optional[Tuple[int, int]] = None, random_sequence_length_range: Optional[Tuple[int, int]] = None, random_num_choices_range: Optional[Tuple[int, int]] = None, + padding_side: str = "right", **kwargs, ): self.task = task @@ -363,14 +417,24 @@ def __init__( self.num_choices = random.randint(low, high) else: self.num_choices = num_choices + self.padding_side = padding_side - def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + def generate( + self, + input_name: str, + framework: str = "pt", + int_dtype: str = "int64", + float_dtype: str = "fp32", + ): min_value = 0 max_value = 2 if input_name != "input_ids" else self.vocab_size shape = [self.batch_size, self.sequence_length] if self.task == "multiple-choice": shape = [self.batch_size, self.num_choices, self.sequence_length] - return self.random_int_tensor(shape, max_value, min_value=min_value, framework=framework, dtype=int_dtype) + if "mask" in input_name: + return self.random_mask_tensor(shape, padding_side=self.padding_side, framework=framework, dtype=int_dtype) + else: + return self.random_int_tensor(shape, max_value, min_value=min_value, framework=framework, dtype=int_dtype) class DummyDecoderTextInputGenerator(DummyTextInputGenerator):