From 0548af54ccc81d31e6264bb394e74c89c477518d Mon Sep 17 00:00:00 2001
From: Nate Cibik <50897218+FoamoftheSea@users.noreply.github.com>
Date: Mon, 29 Jan 2024 02:10:40 -0800
Subject: [PATCH] Enable Gradient Checkpointing in Deformable DETR (#28686)

* Enabled gradient checkpointing in Deformable DETR

* Enabled gradient checkpointing in Deformable DETR encoder

* Removed # Copied from headers in modeling_deta.py to break dependence on Deformable DETR code
---
 .../modeling_deformable_detr.py               | 38 ++++++++++++++-----
 src/transformers/models/deta/modeling_deta.py |  3 --
 2 files changed, 28 insertions(+), 13 deletions(-)

diff --git a/src/transformers/models/deformable_detr/modeling_deformable_detr.py b/src/transformers/models/deformable_detr/modeling_deformable_detr.py
index 3767eef0392f6a..abd431dcd81460 100755
--- a/src/transformers/models/deformable_detr/modeling_deformable_detr.py
+++ b/src/transformers/models/deformable_detr/modeling_deformable_detr.py
@@ -1048,6 +1048,7 @@ class DeformableDetrPreTrainedModel(PreTrainedModel):
     config_class = DeformableDetrConfig
     base_model_prefix = "model"
     main_input_name = "pixel_values"
+    supports_gradient_checkpointing = True
     _no_split_modules = [r"DeformableDetrConvEncoder", r"DeformableDetrEncoderLayer", r"DeformableDetrDecoderLayer"]
 
     def _init_weights(self, module):
@@ -1143,6 +1144,7 @@ class DeformableDetrEncoder(DeformableDetrPreTrainedModel):
 
     def __init__(self, config: DeformableDetrConfig):
         super().__init__(config)
+        self.gradient_checkpointing = False
 
         self.dropout = config.dropout
         self.layers = nn.ModuleList([DeformableDetrEncoderLayer(config) for _ in range(config.encoder_layers)])
@@ -1235,15 +1237,27 @@ def forward(
         for i, encoder_layer in enumerate(self.layers):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
-            layer_outputs = encoder_layer(
-                hidden_states,
-                attention_mask,
-                position_embeddings=position_embeddings,
-                reference_points=reference_points,
-                spatial_shapes=spatial_shapes,
-                level_start_index=level_start_index,
-                output_attentions=output_attentions,
-            )
+            if self.gradient_checkpointing and self.training:
+                layer_outputs = self._gradient_checkpointing_func(
+                    encoder_layer.__call__,
+                    hidden_states,
+                    attention_mask,
+                    position_embeddings,
+                    reference_points,
+                    spatial_shapes,
+                    level_start_index,
+                    output_attentions,
+                )
+            else:
+                layer_outputs = encoder_layer(
+                    hidden_states,
+                    attention_mask,
+                    position_embeddings=position_embeddings,
+                    reference_points=reference_points,
+                    spatial_shapes=spatial_shapes,
+                    level_start_index=level_start_index,
+                    output_attentions=output_attentions,
+                )
 
             hidden_states = layer_outputs[0]
 
@@ -1368,9 +1382,13 @@ def forward(
                 layer_outputs = self._gradient_checkpointing_func(
                     decoder_layer.__call__,
                     hidden_states,
+                    position_embeddings,
+                    reference_points_input,
+                    spatial_shapes,
+                    level_start_index,
                     encoder_hidden_states,
                     encoder_attention_mask,
-                    None,
+                    output_attentions,
                 )
             else:
                 layer_outputs = decoder_layer(
diff --git a/src/transformers/models/deta/modeling_deta.py b/src/transformers/models/deta/modeling_deta.py
index 330ccfe3f0c389..7ffe4485110ad5 100644
--- a/src/transformers/models/deta/modeling_deta.py
+++ b/src/transformers/models/deta/modeling_deta.py
@@ -942,7 +942,6 @@ def forward(self, hidden_states: torch.Tensor):
         return hidden_states
 
 
-# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrPreTrainedModel with DeformableDetrConvEncoder->DetaBackboneWithPositionalEncodings,DeformableDetr->Deta
 class DetaPreTrainedModel(PreTrainedModel):
     config_class = DetaConfig
     base_model_prefix = "model"
@@ -1028,7 +1027,6 @@ def _init_weights(self, module):
 """
 
 
-# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrEncoder with DeformableDetr->Deta
 class DetaEncoder(DetaPreTrainedModel):
     """
     Transformer encoder consisting of *config.encoder_layers* deformable attention layers. Each layer is a
@@ -1159,7 +1157,6 @@ def forward(
         )
 
 
-# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrDecoder with DeformableDetr->Deta,Deformable DETR->DETA
 class DetaDecoder(DetaPreTrainedModel):
     """
     Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`DetaDecoderLayer`].
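
Usage note (not part of the patch): the short sketch below shows how the checkpointing path enabled by this change would be switched on from user code. It is a minimal, hedged example; the SenseTime/deformable-detr checkpoint name, the dummy 800x800 image, the single-box label, and the call sequence are illustrative assumptions, not anything specified in this PR.

# Minimal sketch: turning on the gradient checkpointing added by this patch.
# Assumes the public "SenseTime/deformable-detr" checkpoint and a dummy batch;
# all values below are illustrative only.
import torch
from transformers import DeformableDetrForObjectDetection

model = DeformableDetrForObjectDetection.from_pretrained("SenseTime/deformable-detr")
model.gradient_checkpointing_enable()  # available now that supports_gradient_checkpointing = True
model.train()  # the checkpointed branch only runs when self.training is True

pixel_values = torch.randn(1, 3, 800, 800)  # dummy image batch
labels = [{"class_labels": torch.tensor([1]), "boxes": torch.tensor([[0.5, 0.5, 0.2, 0.2]])}]

outputs = model(pixel_values=pixel_values, labels=labels)
outputs.loss.backward()  # layer activations are recomputed here instead of being kept in memory

The trade-off is the usual one for checkpointing: each encoder and decoder layer's activations are recomputed during the backward pass rather than stored, which is also why the checkpointed calls in the diff pass position_embeddings, reference_points, spatial_shapes, and level_start_index positionally to _gradient_checkpointing_func.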