From 8b02bb6e748842cfb74ec9cba808201b893fa0af Mon Sep 17 00:00:00 2001
From: Jacky Lee <39754370+jla524@users.noreply.github.com>
Date: Mon, 22 Apr 2024 02:57:27 -0700
Subject: [PATCH] Enable multi-device for more models (#30379)

* feat: support for vitmatte

* feat: support for vivit

* feat: support for beit

* feat: support for blip :D

* feat: support for data2vec

---
 src/transformers/models/beit/modeling_beit.py                | 1 +
 src/transformers/models/blip/modeling_blip_text.py           | 1 +
 src/transformers/models/data2vec/modeling_data2vec_vision.py | 1 +
 src/transformers/models/vitmatte/modeling_vitmatte.py        | 1 +
 src/transformers/models/vivit/modeling_vivit.py              | 1 +
 5 files changed, 5 insertions(+)

diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py
index d04717039ec909..c23d4f4ea4cdee 100755
--- a/src/transformers/models/beit/modeling_beit.py
+++ b/src/transformers/models/beit/modeling_beit.py
@@ -563,6 +563,7 @@ class BeitPreTrainedModel(PreTrainedModel):
     base_model_prefix = "beit"
     main_input_name = "pixel_values"
     supports_gradient_checkpointing = True
+    _no_split_modules = ["BeitLayer"]
 
     def _init_weights(self, module):
         """Initialize the weights"""
diff --git a/src/transformers/models/blip/modeling_blip_text.py b/src/transformers/models/blip/modeling_blip_text.py
index 808c33f8104fc1..3eb6ad45791030 100644
--- a/src/transformers/models/blip/modeling_blip_text.py
+++ b/src/transformers/models/blip/modeling_blip_text.py
@@ -549,6 +549,7 @@ class BlipTextPreTrainedModel(PreTrainedModel):
 
     config_class = BlipTextConfig
     base_model_prefix = "bert"
+    _no_split_modules = []
 
     def _init_weights(self, module):
         """Initialize the weights"""
diff --git a/src/transformers/models/data2vec/modeling_data2vec_vision.py b/src/transformers/models/data2vec/modeling_data2vec_vision.py
index 44088d498f6035..c7f4f6390aad64 100644
--- a/src/transformers/models/data2vec/modeling_data2vec_vision.py
+++ b/src/transformers/models/data2vec/modeling_data2vec_vision.py
@@ -574,6 +574,7 @@ class Data2VecVisionPreTrainedModel(PreTrainedModel):
     base_model_prefix = "data2vec_vision"
     main_input_name = "pixel_values"
     supports_gradient_checkpointing = True
+    _no_split_modules = ["Data2VecVisionLayer"]
 
     def _init_weights(self, module):
         """Initialize the weights"""
diff --git a/src/transformers/models/vitmatte/modeling_vitmatte.py b/src/transformers/models/vitmatte/modeling_vitmatte.py
index f371c608607a5f..4d204a8e563a8d 100644
--- a/src/transformers/models/vitmatte/modeling_vitmatte.py
+++ b/src/transformers/models/vitmatte/modeling_vitmatte.py
@@ -73,6 +73,7 @@ class VitMattePreTrainedModel(PreTrainedModel):
     config_class = VitMatteConfig
     main_input_name = "pixel_values"
     supports_gradient_checkpointing = True
+    _no_split_modules = []
 
     def _init_weights(self, module):
         if isinstance(module, nn.Conv2d):
diff --git a/src/transformers/models/vivit/modeling_vivit.py b/src/transformers/models/vivit/modeling_vivit.py
index 08efb85e1f0254..aa962373568aee 100755
--- a/src/transformers/models/vivit/modeling_vivit.py
+++ b/src/transformers/models/vivit/modeling_vivit.py
@@ -387,6 +387,7 @@ class VivitPreTrainedModel(PreTrainedModel):
     base_model_prefix = "vivit"
     main_input_name = "pixel_values"
     supports_gradient_checkpointing = True
+    _no_split_modules = []
 
     def _init_weights(self, module):
         """Initialize the weights"""