diff --git a/docs/source/en/model_doc/beit.md b/docs/source/en/model_doc/beit.md index f7605ebcdf90d4..25b0eafb26a039 100644 --- a/docs/source/en/model_doc/beit.md +++ b/docs/source/en/model_doc/beit.md @@ -71,6 +71,44 @@ alt="drawing" width="600"/> BEiT pre-training. Taken from the original paper. +### Using Scaled Dot Product Attention (SDPA) + +PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function +encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the +[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) +or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention) +page for more information. + +SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set +`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used. + +```python +import torch +from transformers import BeitForImageClassification +model = BeitForImageClassification.from_pretrained("microsoft/beit-base-patch16-224", attn_implementation="sdpa", torch_dtype=torch.float16) +... +``` + +For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`). + +On a local benchmark (NVIDIA GeForce RTX 2060-8GB, PyTorch 2.5.1, OS Ubuntu 20.04) with `float16` and +`microsoft/beit-base-patch16-224` model, we saw the following improvements during training and inference: + +#### Training + +| num_training_steps | batch_size | image_size | is_cuda | Time per batch (eager - s) | Time per batch (sdpa - s) | Speedup (%) | Eager peak mem (MB) | SDPA peak mem (MB) | Mem saving (%) | +|--------------------|------------|--------------|---------|----------------------------|---------------------------|-------------|----------------------|--------------------|----------------| +| 50 | 2 | (1048, 640) | True | 0.984 | 0.746 | 31.975 | 6738.915 | 4319.886 | 55.998 | + +#### Inference + +| Image batch size | Eager (s/iter) | Eager CI, % | Eager memory (bytes) | SDPA (s/iter) | SDPA CI, % | SDPA memory (bytes) | SDPA speedup | SDPA memory saved (%) | +|-------------------:|-----------------:|:--------------|--------------------:|----------------:|:-------------|-------------------:|---------------:|----------------------:| +| 1 | 0.012 | ±0.3% | 3.76657e+08 | 0.011 | ±0.5% | 3.75739e+08 | 1.05 | 0.244 | +| 4 | 0.013 | ±0.1% | 4.03147e+08 | 0.011 | ±0.2% | 3.90554e+08 | 1.178 | 3.225 | +| 16 | 0.045 | ±0.1% | 4.96697e+08 | 0.035 | ±0.1% | 4.51232e+08 | 1.304 | 10.076 | +| 32 | 0.088 | ±0.1% | 6.24417e+08 | 0.066 | ±0.1% | 5.33488e+08 | 1.325 | 17.044 | + ## Resources A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with BEiT. diff --git a/docs/source/en/model_doc/data2vec.md b/docs/source/en/model_doc/data2vec.md index 517a51ce46a3a4..cb1dc675caa55e 100644 --- a/docs/source/en/model_doc/data2vec.md +++ b/docs/source/en/model_doc/data2vec.md @@ -48,6 +48,47 @@ The original code for vision can be found [here](https://github.com/facebookrese - For Data2VecText, preprocessing is identical to [`RobertaModel`], including tokenization. - For Data2VecVision, preprocessing is identical to [`BeitModel`], including feature extraction.
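Since Data2VecVision reuses the BEiT preprocessing pipeline mentioned in the tip above, a minimal sketch of what that looks like end to end is shown below. The `facebook/data2vec-vision-base-ft1k` checkpoint and `AutoImageProcessor` are illustrative choices, not part of this diff.

```python
import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor, Data2VecVisionForImageClassification

# BEiT-style preprocessing: resize and normalize the image into pixel_values.
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

processor = AutoImageProcessor.from_pretrained("facebook/data2vec-vision-base-ft1k")
model = Data2VecVisionForImageClassification.from_pretrained("facebook/data2vec-vision-base-ft1k")

inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
print(model.config.id2label[logits.argmax(-1).item()])
```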
+### Using Scaled Dot Product Attention (SDPA) + +PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function +encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the +[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) +or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention) +page for more information. + +SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set +`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used. + +The SDPA implementation is currently available for the Data2VecAudio and Data2VecVision models. + +```python +import torch +from transformers import Data2VecVisionForImageClassification +model = Data2VecVisionForImageClassification.from_pretrained("facebook/data2vec-vision-base", attn_implementation="sdpa", torch_dtype=torch.float16) +... +``` + +For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`). + +For the Data2VecVision model, on a local benchmark (NVIDIA GeForce RTX 2060-8GB, PyTorch 2.5.1, OS Ubuntu 20.04) +with `float16` and `facebook/data2vec-vision-base` model, we saw the following improvements during training and +inference: + +#### Training + +| num_training_steps | batch_size | image_size | is_cuda | Time per batch (eager - s) | Time per batch (sdpa - s) | Speedup (%) | Eager peak mem (MB) | SDPA peak mem (MB) | Mem saving (%) | +|--------------------|------------|--------------|---------|----------------------------|---------------------------|-------------|----------------------|--------------------|----------------| +| 50 | 2 | (1048, 640) | True | 0.996 | 0.754 | 32.147 | 6722.198 | 4264.653 | 57.626 | + +#### Inference + +| Image batch size | Eager (s/iter) | Eager CI, % | Eager memory (bytes) | SDPA (s/iter) | SDPA CI, % | SDPA memory (bytes) | SDPA speedup | SDPA memory saved (%) | +|-------------------:|-----------------:|:--------------|--------------------:|----------------:|:-------------|-------------------:|---------------:|--------------------:| +| 1 | 0.011 | ±0.3% | 3.76143e+08 | 0.01 | ±0.3% | 3.74397e+08 | 1.101 | 0.466 | +| 4 | 0.014 | ±0.1% | 4.02756e+08 | 0.012 | ±0.2% | 3.91373e+08 | 1.219 | 2.909 | +| 16 | 0.046 | ±0.3% | 4.96482e+08 | 0.035 | ±0.2% | 4.51017e+08 | 1.314 | 10.081 | +| 32 | 0.088 | ±0.1% | 6.23903e+08 | 0.067 | ±0.1% | 5.32974e+08 | 1.33 | 17.061 | + ## Resources A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with Data2Vec.
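The tables above come from a local benchmark script that is not part of this diff. As a rough way to sanity-check the eager vs. SDPA gap on your own hardware, a hand-rolled timing loop along these lines works; this is a sketch that assumes a CUDA device and the `microsoft/beit-base-patch16-224` checkpoint, with arbitrary batch size and iteration count:

```python
import time
import torch
from transformers import BeitForImageClassification

def time_model(attn_implementation: str, iters: int = 50, batch_size: int = 16) -> float:
    model = BeitForImageClassification.from_pretrained(
        "microsoft/beit-base-patch16-224",
        attn_implementation=attn_implementation,
        torch_dtype=torch.float16,
    ).to("cuda").eval()
    pixel_values = torch.randn(batch_size, 3, 224, 224, dtype=torch.float16, device="cuda")
    with torch.no_grad():
        for _ in range(5):  # warmup so CUDA kernels and caches are initialized
            model(pixel_values)
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(iters):
            model(pixel_values)
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters

eager = time_model("eager")
sdpa = time_model("sdpa")
print(f"eager: {eager:.4f} s/iter, sdpa: {sdpa:.4f} s/iter, speedup: {eager / sdpa:.2f}x")
```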
diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index 12f492ff29a5ee..493e73e0d4430e 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -218,6 +218,7 @@ For now, Transformers supports SDPA inference and training for the following arc * [Albert](https://huggingface.co/docs/transformers/model_doc/albert#transformers.AlbertModel) * [Audio Spectrogram Transformer](https://huggingface.co/docs/transformers/model_doc/audio-spectrogram-transformer#transformers.ASTModel) * [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel) +* [Beit](https://huggingface.co/docs/transformers/model_doc/beit#transformers.BeitModel) * [Bert](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel) * [BioGpt](https://huggingface.co/docs/transformers/model_doc/biogpt#transformers.BioGptModel) * [CamemBERT](https://huggingface.co/docs/transformers/model_doc/camembert#transformers.CamembertModel) @@ -226,6 +227,7 @@ For now, Transformers supports SDPA inference and training for the following arc * [GLM](https://huggingface.co/docs/transformers/model_doc/glm#transformers.GLMModel) * [Cohere](https://huggingface.co/docs/transformers/model_doc/cohere#transformers.CohereModel) * [data2vec_audio](https://huggingface.co/docs/transformers/main/en/model_doc/data2vec#transformers.Data2VecAudioModel) +* [data2vec_vision](https://huggingface.co/docs/transformers/main/en/model_doc/data2vec#transformers.Data2VecVisionModel) * [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel) * [DeiT](https://huggingface.co/docs/transformers/model_doc/deit#transformers.DeiTModel) * [Dinov2](https://huggingface.co/docs/transformers/en/model_doc/dinov2) diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py index f972e021f3e2b3..46478f272c7513 100755 --- a/src/transformers/models/beit/modeling_beit.py +++ b/src/transformers/models/beit/modeling_beit.py @@ -361,6 +361,68 @@ def forward( return outputs +class BeitSdpaSelfAttention(BeitSelfAttention): + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + relative_position_bias: Optional["BeitRelativePositionBias"] = None, + interpolate_pos_encoding: bool = False, + resolution: Optional[Tuple[int]] = None, + ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: + if output_attentions or head_mask is not None: + logger.warning_once( + "`BeitSdpaSelfAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not " + "support `output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, " + "but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. " + 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + hidden_states=hidden_states, + head_mask=head_mask, + output_attentions=output_attentions, + relative_position_bias=relative_position_bias, + interpolate_pos_encoding=interpolate_pos_encoding, + resolution=resolution, + ) + + mixed_query_layer = self.query(hidden_states) + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + attn_bias = None + if self.relative_position_bias is not None: + height, width = resolution + window_size = (height // self.config.patch_size, width // self.config.patch_size) + attn_bias = self.relative_position_bias( + window_size, interpolate_pos_encoding, dim_size=hidden_states.shape[1] + ) + + # Add shared relative position bias if provided. + if relative_position_bias is not None: + if attn_bias is None: + attn_bias = relative_position_bias + else: + attn_bias += relative_position_bias + + scaling = 1 / math.sqrt(self.attention_head_size) + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + attn_mask=attn_bias, + dropout_p=self.config.attention_probs_dropout_prob if self.training else 0.0, + is_causal=False, + scale=scaling, + ) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + return context_layer, None + + class BeitSelfOutput(nn.Module): """ The residual connection is defined in BeitLayer instead of here (as is the case with other models), due to the @@ -379,10 +441,16 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor, gamma return hidden_states +BEIT_SELF_ATTENTION_CLASSES = { + "eager": BeitSelfAttention, + "sdpa": BeitSdpaSelfAttention, +} + + class BeitAttention(nn.Module): def __init__(self, config: BeitConfig, window_size: Optional[tuple] = None) -> None: super().__init__() - self.attention = BeitSelfAttention(config, window_size=window_size) + self.attention = BEIT_SELF_ATTENTION_CLASSES[config._attn_implementation](config, window_size=window_size) self.output = BeitSelfOutput(config) self.pruned_heads = set() @@ -700,6 +768,7 @@ class BeitPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["BeitLayer"] _keys_to_ignore_on_load_unexpected = [r".*relative_position_index.*"] + _supports_sdpa = True def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/data2vec/modeling_data2vec_vision.py b/src/transformers/models/data2vec/modeling_data2vec_vision.py index 4d252ce1f19db7..770162285bf33b 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_vision.py +++ b/src/transformers/models/data2vec/modeling_data2vec_vision.py @@ -362,6 +362,69 @@ def forward( return outputs +# Copied from transformers.models.beit.modeling_beit.BeitSdpaSelfAttention with Beit->Data2VecVision +class Data2VecVisionSdpaSelfAttention(Data2VecVisionSelfAttention): + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + relative_position_bias: Optional["Data2VecVisionRelativePositionBias"] = None, + interpolate_pos_encoding: bool = False, + resolution: Optional[Tuple[int]] = None, + ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: + if output_attentions or head_mask is not None: + 
logger.warning_once( + "`Data2VecVisionSdpaSelfAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not " + "support `output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, " + "but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. " + 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + head_mask=head_mask, + output_attentions=output_attentions, + relative_position_bias=relative_position_bias, + interpolate_pos_encoding=interpolate_pos_encoding, + resolution=resolution, + ) + + mixed_query_layer = self.query(hidden_states) + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + attn_bias = None + if self.relative_position_bias is not None: + height, width = resolution + window_size = (height // self.config.patch_size, width // self.config.patch_size) + attn_bias = self.relative_position_bias( + window_size, interpolate_pos_encoding, dim_size=hidden_states.shape[1] + ) + + # Add shared relative position bias if provided. + if relative_position_bias is not None: + if attn_bias is None: + attn_bias = relative_position_bias + else: + attn_bias += relative_position_bias + + scaling = 1 / math.sqrt(self.attention_head_size) + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + attn_mask=attn_bias, + dropout_p=self.config.attention_probs_dropout_prob if self.training else 0.0, + is_causal=False, + scale=scaling, + ) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + return context_layer, None + + # Copied from transformers.models.beit.modeling_beit.BeitSelfOutput with Beit->Data2VecVision class Data2VecVisionSelfOutput(nn.Module): """ @@ -381,11 +444,19 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor, gamma return hidden_states -# Copied from transformers.models.beit.modeling_beit.BeitAttention with Beit->Data2VecVision +DATA2VEC_VISION_SELF_ATTENTION_CLASSES = { + "eager": Data2VecVisionSelfAttention, + "sdpa": Data2VecVisionSdpaSelfAttention, +} + + +# Copied from transformers.models.beit.modeling_beit.BeitAttention with Beit->Data2VecVision, BEIT->DATA2VEC_VISION class Data2VecVisionAttention(nn.Module): def __init__(self, config: Data2VecVisionConfig, window_size: Optional[tuple] = None) -> None: super().__init__() - self.attention = Data2VecVisionSelfAttention(config, window_size=window_size) + self.attention = DATA2VEC_VISION_SELF_ATTENTION_CLASSES[config._attn_implementation]( + config, window_size=window_size + ) self.output = Data2VecVisionSelfOutput(config) self.pruned_heads = set() @@ -711,6 +782,7 @@ class Data2VecVisionPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["Data2VecVisionLayer"] _keys_to_ignore_on_load_unexpected = [r".*relative_position_index.*"] + _supports_sdpa = True def _init_weights(self, module): """Initialize the weights""" diff --git a/tests/models/beit/test_modeling_beit.py b/tests/models/beit/test_modeling_beit.py index ac64f0fd3b0b11..e54273f7839965 100644 --- a/tests/models/beit/test_modeling_beit.py +++
b/tests/models/beit/test_modeling_beit.py @@ -14,18 +14,35 @@ # limitations under the License. """Testing suite for the PyTorch BEiT model.""" +import inspect +import tempfile import unittest +import numpy as np from datasets import load_dataset from packaging import version +from parameterized import parameterized from transformers import BeitConfig -from transformers.testing_utils import require_torch, require_torch_multi_gpu, require_vision, slow, torch_device -from transformers.utils import cached_property, is_torch_available, is_vision_available +from transformers.testing_utils import ( + require_torch, + require_torch_multi_gpu, + require_torch_sdpa, + require_vision, + slow, + torch_device, +) +from transformers.utils import ( + cached_property, + is_torch_available, + is_torch_bf16_available_on_device, + is_torch_fp16_available_on_device, + is_vision_available, +) from ...test_backbone_common import BackboneTesterMixin from ...test_configuration_common import ConfigTester -from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor +from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor, sdpa_kernel from ...test_pipeline_mixin import PipelineTesterMixin @@ -74,6 +91,8 @@ def __init__( scope=None, out_indices=[1, 2, 3, 4], out_features=["stage1", "stage2", "stage3", "stage4"], + attn_implementation="eager", + mask_ratio=0.5, ): self.parent = parent self.vocab_size = vocab_size @@ -100,6 +119,8 @@ def __init__( # in BeiT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token) num_patches = (image_size // patch_size) ** 2 self.seq_length = num_patches + 1 + self.num_masks = int(mask_ratio * self.seq_length) + self.attn_implementation = attn_implementation def prepare_config_and_inputs(self): pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) @@ -131,6 +152,7 @@ def get_config(self): initializer_range=self.initializer_range, out_indices=self.out_indices, out_features=self.out_features, + attn_implementation=self.attn_implementation, ) def create_and_check_model(self, config, pixel_values, labels, pixel_labels): @@ -387,6 +409,193 @@ def test_model_from_pretrained(self): model = BeitModel.from_pretrained(model_name) self.assertIsNotNone(model) + @parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) + @require_torch_sdpa + def test_eager_matches_sdpa_inference(self, torch_dtype: str): + # The common test modifies the num_hidden_layers to be 1. However, for Beit we want to + # avoid that because the num_hidden_layers is generally assumed to be 4. Also, the code + # related to attention masks in the original common tests is not required as the Beit + # model does not handle attention masks. Furthermore, some extra code like modifying + # the norm layers eps values for specialized configs and checking for the 'noise' + # has been omitted to simplify the test.
+ if not self.has_attentions: + self.skipTest(reason="Model architecture does not support attentions") + + if not self.all_model_classes[0]._supports_sdpa: + self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA") + + if torch_dtype == "float16" and not is_torch_fp16_available_on_device(torch_device): + self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)") + + if torch_dtype == "bfloat16" and not is_torch_bf16_available_on_device(torch_device): + self.skipTest( + f"bfloat16 not supported on {torch_device} (on the specific device currently used, e.g. Nvidia T4 GPU)" + ) + + # Not sure whether it's fine to put torch.XXX in a decorator if torch is not available so hacking it here instead. + if torch_dtype == "float16": + torch_dtype = torch.float16 + elif torch_dtype == "bfloat16": + torch_dtype = torch.bfloat16 + elif torch_dtype == "float32": + torch_dtype = torch.float32 + + atols = { + ("cpu", False, torch.float32): 1e-6, + ("cpu", False, torch.float16): 5e-3, + ("cpu", False, torch.bfloat16): 1e-2, + ("cpu", True, torch.float32): 1e-6, + ("cpu", True, torch.float16): 5e-3, + ("cpu", True, torch.bfloat16): 1e-2, + ("cuda", False, torch.float32): 1e-6, + ("cuda", False, torch.bfloat16): 1e-2, + ("cuda", False, torch.float16): 5e-3, + ("cuda", True, torch.float32): 1e-6, + ("cuda", True, torch.bfloat16): 1e-2, + ("cuda", True, torch.float16): 5e-3, + } + rtols = { + ("cpu", False, torch.float32): 1e-4, + ("cpu", False, torch.float16): 5e-3, + ("cpu", False, torch.bfloat16): 1e-2, + ("cpu", True, torch.float32): 1e-4, + ("cpu", True, torch.float16): 5e-3, + ("cpu", True, torch.bfloat16): 1e-2, + ("cuda", False, torch.float32): 1e-4, + ("cuda", False, torch.bfloat16): 1e-2, + ("cuda", False, torch.float16): 5e-3, + ("cuda", True, torch.float32): 1e-4, + ("cuda", True, torch.bfloat16): 3e-2, + ("cuda", True, torch.float16): 5e-3, + } + + def get_mean_reldiff(failcase, x, ref, atol, rtol): + return f"{failcase}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}" + + for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + config.rms_norm_eps = 1.0 + config.layer_norm_eps = 1.0 + config.norm_eps = 1.0 + config.norm_epsilon = 1.0 + config.layer_norm_epsilon = 1.0 + + model = model_class(config) + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype, use_mask_token=True) + model_sdpa = model_sdpa.eval().to(torch_device, dtype=torch_dtype) + + model_eager = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch_dtype, + attn_implementation="eager", + use_mask_token=True, + ) + model_eager = model_eager.eval().to(torch_device, dtype=torch_dtype) + + # Another way to make sure norm layers have desired epsilon. (Some models don't set it from its config.) 
+ for x in model_eager.modules(): + if isinstance(x, (nn.LayerNorm, nn.GroupNorm)): + x.eps = 1.0 + for x in model_sdpa.modules(): + if isinstance(x, (nn.LayerNorm, nn.GroupNorm)): + x.eps = 1.0 + + # We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving 16 times the model, + # but it would be nicer to have an efficient way to use parameterized.expand + fail_cases = [] + for padding_side in ["left", "right"]: + for use_mask in [False, True]: + for output_attentions in [True, False]: + can_output_attn = "output_attentions" in inspect.signature(model_sdpa.forward).parameters + if not (self.has_attentions and can_output_attn) and output_attentions: + continue + # TODO: if we can also check with `batch_size=1` without being flaky? + for batch_size in [7]: + dummy_input = inputs_dict[model.main_input_name] + + if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]: + dummy_input = dummy_input.to(torch_dtype) + + dummy_input = dummy_input[:batch_size] + for enable_kernels in [False, True]: + failcase = f"padding_side={padding_side}, use_mask={use_mask}, enable_kernels={enable_kernels}" + processed_inputs = { + model.main_input_name: dummy_input, + "output_hidden_states": True, + } + + if ( + self.has_attentions + and "output_attentions" in inspect.signature(model_sdpa.forward).parameters + ): + processed_inputs["output_attentions"] = output_attentions + + if "bool_masked_pos" in inspect.signature(model_eager.forward).parameters: + dummy_mask = torch.ones((self.model_tester.num_masks,)) + mask_length = self.model_tester.seq_length - 1 - dummy_mask.size(0) + dummy_mask = torch.cat([dummy_mask, torch.zeros(mask_length)]) + dummy_bool_masked_pos = dummy_mask.expand(batch_size, -1).bool() + processed_inputs["bool_masked_pos"] = dummy_bool_masked_pos.to(torch_device) + + with torch.no_grad(): + with sdpa_kernel( + enable_flash=enable_kernels, + enable_math=True, + enable_mem_efficient=enable_kernels, + ): + prepared_inputs = self._prepare_for_class(processed_inputs, model_class) + outputs_eager = model_eager(**prepared_inputs) + outputs_sdpa = model_sdpa(**prepared_inputs) + + logits_eager = outputs_eager.hidden_states[-1] + logits_sdpa = outputs_sdpa.hidden_states[-1] + if torch_device in ["cpu", "cuda"]: + atol = atols[torch_device, enable_kernels, torch_dtype] + rtol = rtols[torch_device, enable_kernels, torch_dtype] + elif torch_device == "xpu": + # As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH + # which is implemented on PyTorch level using aten operators and is + # device agnostic with respect to implementation of each aten operator. + atol = atols["cuda", False, torch_dtype] + rtol = rtols["cuda", False, torch_dtype] + else: + atol = 1e-7 + rtol = 1e-4 + + # Masked tokens output slightly deviates - we don't mind that. 
+ if use_mask: + _logits_sdpa = torch.zeros_like(input=logits_sdpa) + _logits_eager = torch.zeros_like(input=logits_eager) + + _logits_sdpa[:-1] = logits_sdpa[:-1] + _logits_eager[:-1] = logits_eager[:-1] + + if padding_side == "left": + _logits_sdpa[-1:, 2:] = logits_sdpa[-1:, 2:] + _logits_eager[-1:, 2:] = logits_eager[-1:, 2:] + + elif padding_side == "right": + _logits_sdpa[-1:, 2:] = logits_sdpa[-1:, :-2] + _logits_eager[-1:, 2:] = logits_eager[-1:, :-2] + + logits_sdpa = _logits_sdpa + logits_eager = _logits_eager + + results = [ + torch.allclose(_logits_sdpa, _logits_eager, atol=atol, rtol=rtol) + for (_logits_sdpa, _logits_eager) in zip(logits_sdpa, logits_eager) + ] + # If 80% batch elements have matched results, it's fine + if np.mean(results) < 0.8: + fail_cases.append( + get_mean_reldiff(failcase, logits_sdpa, logits_eager, atol, rtol) + ) + + self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases)) + # We will verify our results on an image of cute cats def prepare_img(): diff --git a/tests/models/data2vec/test_modeling_data2vec_vision.py b/tests/models/data2vec/test_modeling_data2vec_vision.py index c729d88d614fbc..02276d905fa402 100644 --- a/tests/models/data2vec/test_modeling_data2vec_vision.py +++ b/tests/models/data2vec/test_modeling_data2vec_vision.py @@ -14,14 +14,32 @@ # limitations under the License. """Testing suite for the PyTorch Data2VecVision model.""" +import inspect +import tempfile import unittest +import numpy as np +from parameterized import parameterized + from transformers import Data2VecVisionConfig -from transformers.testing_utils import require_torch, require_torch_multi_gpu, require_vision, slow, torch_device -from transformers.utils import cached_property, is_torch_available, is_vision_available +from transformers.testing_utils import ( + require_torch, + require_torch_multi_gpu, + require_torch_sdpa, + require_vision, + slow, + torch_device, +) +from transformers.utils import ( + cached_property, + is_torch_available, + is_torch_bf16_available_on_device, + is_torch_fp16_available_on_device, + is_vision_available, +) from ...test_configuration_common import ConfigTester -from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor +from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor, sdpa_kernel from ...test_pipeline_mixin import PipelineTesterMixin @@ -66,6 +84,8 @@ def __init__( num_labels=3, scope=None, out_indices=[0, 1, 2, 3], + attn_implementation="eager", + mask_ratio=0.5, ): self.parent = parent self.vocab_size = 100 @@ -91,6 +111,8 @@ def __init__( # in BeiT, the seq length equals the number of patches + 1 (we add 1 for the [CLS] token) num_patches = (image_size // patch_size) ** 2 self.seq_length = num_patches + 1 + self.num_masks = int(mask_ratio * self.seq_length) + self.attn_implementation = attn_implementation def prepare_config_and_inputs(self): pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size]) @@ -121,6 +143,7 @@ def get_config(self): is_decoder=False, initializer_range=self.initializer_range, out_indices=self.out_indices, + attn_implementation=self.attn_implementation, ) def create_and_check_model(self, config, pixel_values, labels, pixel_labels): @@ -300,6 +323,194 @@ def test_model_from_pretrained(self): model = Data2VecVisionModel.from_pretrained(model_name) self.assertIsNotNone(model) + @parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) + @require_torch_sdpa + # Copied from 
tests.models.beit.test_modeling_beit.BeitModelTest.test_eager_matches_sdpa_inference with Beit->Data2VecVision + def test_eager_matches_sdpa_inference(self, torch_dtype: str): + # The common test modifies the num_hidden_layers to be 1. However, for Data2VecVision we want to + # avoid that because the num_hidden_layers is generally assumed to be 4. Also, the code + # related to attention masks in the original common tests is not required as the Data2VecVision + # model does not handle attention masks. Furthermore, some extra code like modifying + # the norm layers eps values for specialized configs and checking for the 'noise' + # has been omitted to simplify the test. + if not self.has_attentions: + self.skipTest(reason="Model architecture does not support attentions") + + if not self.all_model_classes[0]._supports_sdpa: + self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA") + + if torch_dtype == "float16" and not is_torch_fp16_available_on_device(torch_device): + self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)") + + if torch_dtype == "bfloat16" and not is_torch_bf16_available_on_device(torch_device): + self.skipTest( + f"bfloat16 not supported on {torch_device} (on the specific device currently used, e.g. Nvidia T4 GPU)" + ) + + # Not sure whether it's fine to put torch.XXX in a decorator if torch is not available so hacking it here instead. + if torch_dtype == "float16": + torch_dtype = torch.float16 + elif torch_dtype == "bfloat16": + torch_dtype = torch.bfloat16 + elif torch_dtype == "float32": + torch_dtype = torch.float32 + + atols = { + ("cpu", False, torch.float32): 1e-6, + ("cpu", False, torch.float16): 5e-3, + ("cpu", False, torch.bfloat16): 1e-2, + ("cpu", True, torch.float32): 1e-6, + ("cpu", True, torch.float16): 5e-3, + ("cpu", True, torch.bfloat16): 1e-2, + ("cuda", False, torch.float32): 1e-6, + ("cuda", False, torch.bfloat16): 1e-2, + ("cuda", False, torch.float16): 5e-3, + ("cuda", True, torch.float32): 1e-6, + ("cuda", True, torch.bfloat16): 1e-2, + ("cuda", True, torch.float16): 5e-3, + } + rtols = { + ("cpu", False, torch.float32): 1e-4, + ("cpu", False, torch.float16): 5e-3, + ("cpu", False, torch.bfloat16): 1e-2, + ("cpu", True, torch.float32): 1e-4, + ("cpu", True, torch.float16): 5e-3, + ("cpu", True, torch.bfloat16): 1e-2, + ("cuda", False, torch.float32): 1e-4, + ("cuda", False, torch.bfloat16): 1e-2, + ("cuda", False, torch.float16): 5e-3, + ("cuda", True, torch.float32): 1e-4, + ("cuda", True, torch.bfloat16): 3e-2, + ("cuda", True, torch.float16): 5e-3, + } + + def get_mean_reldiff(failcase, x, ref, atol, rtol): + return f"{failcase}: mean relative difference: {((x - ref).abs() / (ref.abs() + 1e-12)).mean():.3e}, torch atol = {atol}, torch rtol = {rtol}" + + for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + config.rms_norm_eps = 1.0 + config.layer_norm_eps = 1.0 + config.norm_eps = 1.0 + config.norm_epsilon = 1.0 + config.layer_norm_epsilon = 1.0 + + model = model_class(config) + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype, use_mask_token=True) + model_sdpa = model_sdpa.eval().to(torch_device, dtype=torch_dtype) + + model_eager = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch_dtype, + attn_implementation="eager", + use_mask_token=True, + ) + model_eager =
model_eager.eval().to(torch_device, dtype=torch_dtype) + + # Another way to make sure norm layers have desired epsilon. (Some models don't set it from its config.) + for x in model_eager.modules(): + if isinstance(x, (nn.LayerNorm, nn.GroupNorm)): + x.eps = 1.0 + for x in model_sdpa.modules(): + if isinstance(x, (nn.LayerNorm, nn.GroupNorm)): + x.eps = 1.0 + + # We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving 16 times the model, + # but it would be nicer to have an efficient way to use parameterized.expand + fail_cases = [] + for padding_side in ["left", "right"]: + for use_mask in [False, True]: + for output_attentions in [True, False]: + can_output_attn = "output_attentions" in inspect.signature(model_sdpa.forward).parameters + if not (self.has_attentions and can_output_attn) and output_attentions: + continue + # TODO: if we can also check with `batch_size=1` without being flaky? + for batch_size in [7]: + dummy_input = inputs_dict[model.main_input_name] + + if dummy_input.dtype in [torch.float32, torch.bfloat16, torch.float16]: + dummy_input = dummy_input.to(torch_dtype) + + dummy_input = dummy_input[:batch_size] + for enable_kernels in [False, True]: + failcase = f"padding_side={padding_side}, use_mask={use_mask}, enable_kernels={enable_kernels}" + processed_inputs = { + model.main_input_name: dummy_input, + "output_hidden_states": True, + } + + if ( + self.has_attentions + and "output_attentions" in inspect.signature(model_sdpa.forward).parameters + ): + processed_inputs["output_attentions"] = output_attentions + + if "bool_masked_pos" in inspect.signature(model_eager.forward).parameters: + dummy_mask = torch.ones((self.model_tester.num_masks,)) + mask_length = self.model_tester.seq_length - 1 - dummy_mask.size(0) + dummy_mask = torch.cat([dummy_mask, torch.zeros(mask_length)]) + dummy_bool_masked_pos = dummy_mask.expand(batch_size, -1).bool() + processed_inputs["bool_masked_pos"] = dummy_bool_masked_pos.to(torch_device) + + with torch.no_grad(): + with sdpa_kernel( + enable_flash=enable_kernels, + enable_math=True, + enable_mem_efficient=enable_kernels, + ): + prepared_inputs = self._prepare_for_class(processed_inputs, model_class) + outputs_eager = model_eager(**prepared_inputs) + outputs_sdpa = model_sdpa(**prepared_inputs) + + logits_eager = outputs_eager.hidden_states[-1] + logits_sdpa = outputs_sdpa.hidden_states[-1] + if torch_device in ["cpu", "cuda"]: + atol = atols[torch_device, enable_kernels, torch_dtype] + rtol = rtols[torch_device, enable_kernels, torch_dtype] + elif torch_device == "xpu": + # As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH + # which is implemented on PyTorch level using aten operators and is + # device agnostic with respect to implementation of each aten operator. + atol = atols["cuda", False, torch_dtype] + rtol = rtols["cuda", False, torch_dtype] + else: + atol = 1e-7 + rtol = 1e-4 + + # Masked tokens output slightly deviates - we don't mind that. 
+ if use_mask: + _logits_sdpa = torch.zeros_like(input=logits_sdpa) + _logits_eager = torch.zeros_like(input=logits_eager) + + _logits_sdpa[:-1] = logits_sdpa[:-1] + _logits_eager[:-1] = logits_eager[:-1] + + if padding_side == "left": + _logits_sdpa[-1:, 2:] = logits_sdpa[-1:, 2:] + _logits_eager[-1:, 2:] = logits_eager[-1:, 2:] + + elif padding_side == "right": + _logits_sdpa[-1:, 2:] = logits_sdpa[-1:, :-2] + _logits_eager[-1:, 2:] = logits_eager[-1:, :-2] + + logits_sdpa = _logits_sdpa + logits_eager = _logits_eager + + results = [ + torch.allclose(_logits_sdpa, _logits_eager, atol=atol, rtol=rtol) + for (_logits_sdpa, _logits_eager) in zip(logits_sdpa, logits_eager) + ] + # If 80% batch elements have matched results, it's fine + if np.mean(results) < 0.8: + fail_cases.append( + get_mean_reldiff(failcase, logits_sdpa, logits_eager, atol, rtol) + ) + + self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases)) + # We will verify our results on an image of cute cats def prepare_img():
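The new `BeitSdpaSelfAttention` and `Data2VecVisionSdpaSelfAttention` classes, and the `test_eager_matches_sdpa_inference` tests above, rest on one numerical fact: passing the relative position bias as an additive `attn_mask` to `torch.nn.functional.scaled_dot_product_attention` reproduces the manual softmax attention. A small, self-contained sketch of that equivalence follows; the tensor shapes are illustrative and not taken from the models.

```python
import math
import torch
import torch.nn.functional as F

batch, heads, seq, dim = 2, 12, 197, 64
q = torch.randn(batch, heads, seq, dim)
k = torch.randn(batch, heads, seq, dim)
v = torch.randn(batch, heads, seq, dim)
bias = torch.randn(heads, seq, seq)  # stands in for the relative position bias

# Manual ("eager") attention: softmax(QK^T / sqrt(d) + bias) V
scores = q @ k.transpose(-1, -2) / math.sqrt(dim) + bias
eager_out = scores.softmax(dim=-1) @ v

# SDPA with the bias passed as an additive attention mask, mirroring attn_mask=attn_bias above.
sdpa_out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias, scale=1 / math.sqrt(dim))

print(torch.allclose(eager_out, sdpa_out, atol=1e-5))  # True up to floating-point noise
```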