diff --git a/src/transformers/models/vit_msn/modeling_vit_msn.py b/src/transformers/models/vit_msn/modeling_vit_msn.py index ca0f284f4c7449..fe5a880027fc60 100644 --- a/src/transformers/models/vit_msn/modeling_vit_msn.py +++ b/src/transformers/models/vit_msn/modeling_vit_msn.py @@ -461,7 +461,7 @@ class ViTMSNPreTrainedModel(PreTrainedModel): base_model_prefix = "vit" main_input_name = "pixel_values" supports_gradient_checkpointing = True - _no_split_modules = ["ViTMSNAttention"] + _no_split_modules = ["ViTMSNAttention", "ViTMSNSdpaAttention"] _supports_sdpa = True # todo: Resort to https://github.com/facebookresearch/msn/blob/main/src/deit.py#L200-#L211