Enable Gradient Checkpointing in Deformable DETR (#28686)
* Enabled gradient checkpointing in Deformable DETR

* Enabled gradient checkpointing in Deformable DETR encoder

* Removed # Copied from headers in modeling_deta.py to break dependence on Deformable DETR code
FoamoftheSea authored Jan 29, 2024
1 parent f72c7c2 commit 0548af5
Showing 2 changed files with 28 additions and 13 deletions.
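
For orientation, a minimal usage sketch of what this commit enables. It assumes the standard `gradient_checkpointing_enable()` toggle inherited from `PreTrainedModel`, the `SenseTime/deformable-detr` checkpoint on the Hub, and the usual head -> base-model attribute layout (`model.model.encoder`):

from transformers import DeformableDetrForObjectDetection

# Load the model and switch on gradient checkpointing before training.
# Checkpointed layers recompute activations during the backward pass
# instead of storing them, trading extra compute for lower memory use.
model = DeformableDetrForObjectDetection.from_pretrained("SenseTime/deformable-detr")
model.gradient_checkpointing_enable()
model.train()

# The flag introduced by this commit should now be set on the encoder too:
print(model.model.encoder.gradient_checkpointing)  # expected: True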
38 changes: 28 additions & 10 deletions src/transformers/models/deformable_detr/modeling_deformable_detr.py
@@ -1048,6 +1048,7 @@ class DeformableDetrPreTrainedModel(PreTrainedModel):
     config_class = DeformableDetrConfig
     base_model_prefix = "model"
     main_input_name = "pixel_values"
+    supports_gradient_checkpointing = True
     _no_split_modules = [r"DeformableDetrConvEncoder", r"DeformableDetrEncoderLayer", r"DeformableDetrDecoderLayer"]
 
     def _init_weights(self, module):
@@ -1143,6 +1144,7 @@ class DeformableDetrEncoder(DeformableDetrPreTrainedModel):
 
     def __init__(self, config: DeformableDetrConfig):
         super().__init__(config)
+        self.gradient_checkpointing = False
 
         self.dropout = config.dropout
         self.layers = nn.ModuleList([DeformableDetrEncoderLayer(config) for _ in range(config.encoder_layers)])
@@ -1235,15 +1237,27 @@ def forward(
         for i, encoder_layer in enumerate(self.layers):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
-            layer_outputs = encoder_layer(
-                hidden_states,
-                attention_mask,
-                position_embeddings=position_embeddings,
-                reference_points=reference_points,
-                spatial_shapes=spatial_shapes,
-                level_start_index=level_start_index,
-                output_attentions=output_attentions,
-            )
+            if self.gradient_checkpointing and self.training:
+                layer_outputs = self._gradient_checkpointing_func(
+                    encoder_layer.__call__,
+                    hidden_states,
+                    attention_mask,
+                    position_embeddings,
+                    reference_points,
+                    spatial_shapes,
+                    level_start_index,
+                    output_attentions,
+                )
+            else:
+                layer_outputs = encoder_layer(
+                    hidden_states,
+                    attention_mask,
+                    position_embeddings=position_embeddings,
+                    reference_points=reference_points,
+                    spatial_shapes=spatial_shapes,
+                    level_start_index=level_start_index,
+                    output_attentions=output_attentions,
+                )
 
             hidden_states = layer_outputs[0]
 
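A note on the helper used in the new branch: `_gradient_checkpointing_func` is not defined in this file. `PreTrainedModel.gradient_checkpointing_enable()` attaches it to every submodule that exposes a `gradient_checkpointing` attribute, and in recent transformers versions it essentially wraps `torch.utils.checkpoint.checkpoint` (configurable via `gradient_checkpointing_kwargs`). A self-contained sketch of that mechanism, with a toy block standing in for `DeformableDetrEncoderLayer`:

import torch
import torch.nn as nn
import torch.utils.checkpoint

# Toy stand-in for an encoder layer; shapes are illustrative only.
layer = nn.Sequential(nn.Linear(256, 1024), nn.ReLU(), nn.Linear(1024, 256))
hidden_states = torch.randn(2, 300, 256, requires_grad=True)

# Equivalent in spirit to the `self._gradient_checkpointing_func(...)` branch:
# the forward pass runs without storing intermediate activations, which are
# recomputed when backward() needs them.
out = torch.utils.checkpoint.checkpoint(layer, hidden_states, use_reentrant=False)
out.sum().backward()
print(hidden_states.grad.shape)  # gradients match the plain, un-checkpointed call
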
@@ -1368,9 +1382,13 @@ def forward(
                 layer_outputs = self._gradient_checkpointing_func(
                     decoder_layer.__call__,
                     hidden_states,
+                    position_embeddings,
+                    reference_points_input,
+                    spatial_shapes,
+                    level_start_index,
                     encoder_hidden_states,
                     encoder_attention_mask,
-                    None,
                     output_attentions,
                 )
             else:
                 layer_outputs = decoder_layer(
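
To see the effect end to end, a rough, hardware-dependent comparison of peak GPU memory for one forward/backward pass with and without checkpointing. The randomly initialized config, the 800x800 dummy input, and the availability of a CUDA device plus `timm` (for the default ResNet-50 backbone) are assumptions made for this illustration:

import torch
from transformers import DeformableDetrConfig, DeformableDetrModel

config = DeformableDetrConfig(use_pretrained_backbone=False)  # random weights suffice for a memory test
model = DeformableDetrModel(config).cuda().train()
pixel_values = torch.randn(1, 3, 800, 800, device="cuda")

for enabled in (False, True):
    model.gradient_checkpointing_disable()
    if enabled:
        model.gradient_checkpointing_enable()
    model.zero_grad(set_to_none=True)
    torch.cuda.reset_peak_memory_stats()
    model(pixel_values).last_hidden_state.sum().backward()
    print(f"gradient_checkpointing={enabled}: "
          f"peak {torch.cuda.max_memory_allocated() / 2**20:.0f} MiB")
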
3 changes: 0 additions & 3 deletions src/transformers/models/deta/modeling_deta.py
@@ -942,7 +942,6 @@ def forward(self, hidden_states: torch.Tensor):
         return hidden_states
 
 
-# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrPreTrainedModel with DeformableDetrConvEncoder->DetaBackboneWithPositionalEncodings,DeformableDetr->Deta
 class DetaPreTrainedModel(PreTrainedModel):
     config_class = DetaConfig
     base_model_prefix = "model"
@@ -1028,7 +1027,6 @@ def _init_weights(self, module):
 """
 
 
-# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrEncoder with DeformableDetr->Deta
 class DetaEncoder(DetaPreTrainedModel):
     """
     Transformer encoder consisting of *config.encoder_layers* deformable attention layers. Each layer is a
@@ -1159,7 +1157,6 @@ def forward(
         )
 
 
-# Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrDecoder with DeformableDetr->Deta,Deformable DETR->DETA
 class DetaDecoder(DetaPreTrainedModel):
     """
     Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`DetaDecoderLayer`].
