Skip to content

Commit

Permalink
Don't mutate; correct model attributes
Browse files Browse the repository at this point in the history
  • Loading branch information
amyeroberts committed Apr 8, 2024
1 parent f745417 commit 51eb6d3
Show file tree
Hide file tree
Showing 8 changed files with 33 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ def __init__(self, config):
if config.use_timm_backbone:
requires_backends(self, ["timm"])
kwargs = getattr(config, "backbone_kwargs", {})
kwargs = {} if kwargs is None else kwargs
kwargs = {} if kwargs is None else kwargs.copy()
out_indices = kwargs.pop("out_indices", (1, 2, 3, 4))
num_channels = kwargs.pop("in_chans", config.num_channels)
if config.dilation:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ def __init__(self, config):
if config.use_timm_backbone:
requires_backends(self, ["timm"])
kwargs = getattr(config, "backbone_kwargs", {})
kwargs = {} if kwargs is None else kwargs
kwargs = {} if kwargs is None else kwargs.copy()
out_indices = kwargs.pop("out_indices", (2, 3, 4) if config.num_feature_levels > 1 else (4,))
num_channels = kwargs.pop("in_chans", config.num_channels)
if config.dilation:
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/detr/modeling_detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ def __init__(self, config):
if config.use_timm_backbone:
requires_backends(self, ["timm"])
kwargs = getattr(config, "backbone_kwargs", {})
kwargs = {} if kwargs is None else kwargs
kwargs = {} if kwargs is None else kwargs.copy()
out_indices = kwargs.pop("out_indices", (1, 2, 3, 4))
num_channels = kwargs.pop("in_chans", config.num_channels)
if config.dilation:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ def __init__(self, config):
if config.use_timm_backbone:
requires_backends(self, ["timm"])
kwargs = getattr(config, "backbone_kwargs", {})
kwargs = {} if kwargs is None else kwargs
kwargs = {} if kwargs is None else kwargs.copy()
out_indices = kwargs.pop("out_indices", (1, 2, 3, 4))
num_channels = kwargs.pop("in_chans", config.num_channels)
if config.dilation:
Expand Down
11 changes: 8 additions & 3 deletions tests/models/conditional_detr/test_modeling_conditional_detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,12 +462,17 @@ def test_different_timm_backbone(self):
self.model_tester.num_labels,
)
self.assertEqual(outputs.logits.shape, expected_shape)
# Confirm out_indices was propogated to backbone
self.assertEqual(len(model.model.backbone.conv_encoder.intermediate_channel_sizes), 3)
elif model_class.__name__ == "ConditionalDetrForSegmentation":
# Confirm out_indices was propogated to backbone
self.assertEqual(len(model.conditional_detr.model.backbone.conv_encoder.intermediate_channel_sizes), 3)
else:
# Confirm out_indices was propogated to backbone
self.assertEqual(len(model.backbone.conv_encoder.intermediate_channel_sizes), 3)

self.assertTrue(outputs)

# Confirm out_indices was propogated to backbone
self.assertEqual(len(model.backbone.intermediate_channel_sizes), 3)

def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

Expand Down
11 changes: 8 additions & 3 deletions tests/models/deformable_detr/test_modeling_deformable_detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,12 +539,17 @@ def test_different_timm_backbone(self):
self.model_tester.num_labels,
)
self.assertEqual(outputs.logits.shape, expected_shape)
# Confirm out_indices was propogated to backbone
self.assertEqual(len(model.model.backbone.conv_encoder.intermediate_channel_sizes), 4)
elif model_class.__name__ == "ConditionalDetrForSegmentation":
# Confirm out_indices was propogated to backbone
self.assertEqual(len(model.deformable_detr.model.backbone.conv_encoder.intermediate_channel_sizes), 4)
else:
# Confirm out_indices was propogated to backbone
self.assertEqual(len(model.backbone.conv_encoder.intermediate_channel_sizes), 4)

self.assertTrue(outputs)

# Confirm out_indices was propogated to backbone
self.assertEqual(len(model.backbone.intermediate_channel_sizes), 4)

def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

Expand Down
11 changes: 8 additions & 3 deletions tests/models/detr/test_modeling_detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,12 +462,17 @@ def test_different_timm_backbone(self):
self.model_tester.num_labels + 1,
)
self.assertEqual(outputs.logits.shape, expected_shape)
# Confirm out_indices was propogated to backbone
self.assertEqual(len(model.model.backbone.conv_encoder.intermediate_channel_sizes), 3)
elif model_class.__name__ == "DetrForSegmentation":
# Confirm out_indices was propogated to backbone
self.assertEqual(len(model.detr.model.backbone.conv_encoder.intermediate_channel_sizes), 3)
else:
# Confirm out_indices was propogated to backbone
self.assertEqual(len(model.backbone.conv_encoder.intermediate_channel_sizes), 3)

self.assertTrue(outputs)

# Confirm out_indices was propogated to backbone
self.assertEqual(len(model.backbone.intermediate_channel_sizes), 3)

def test_greyscale_images(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -474,12 +474,14 @@ def test_different_timm_backbone(self):
self.model_tester.num_labels + 1,
)
self.assertEqual(outputs.logits.shape, expected_shape)
# Confirm out_indices was propogated to backbone
self.assertEqual(len(model.model.backbone.conv_encoder.intermediate_channel_sizes), 3)
else:
# Confirm out_indices was propogated to backbone
self.assertEqual(len(model.backbone.conv_encoder.intermediate_channel_sizes), 3)

self.assertTrue(outputs)

# Confirm out_indices was propogated to backbone
self.assertEqual(len(model.backbone.intermediate_channel_sizes), 3)

def test_greyscale_images(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

Expand Down

0 comments on commit 51eb6d3

Please sign in to comment.