Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
amyeroberts committed Feb 19, 2024
1 parent 0a38d16 commit 4d524f2
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/transformers/models/dpt/modeling_dpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -1079,7 +1079,7 @@ def __init__(self, config):
super().__init__(config)

self.backbone = None
if config.is_hybrid:
if config.is_hybrid or config.backbone_config is None:
self.dpt = DPTModel(config, add_pooling_layer=False)
else:
self.backbone = load_backbone(config)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,9 @@ def test_different_timm_backbone(self):

# let's pick a random timm backbone
config.backbone = "tf_mobilenetv3_small_075"
config.backbone_config = None
config.use_timm_backbone = True
config.backbone_kwargs = {"out_indices": [2, 3, 4]}

for model_class in self.all_model_classes:
model = model_class(config)
Expand Down
1 change: 0 additions & 1 deletion tests/models/detr/test_modeling_detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,6 @@ def test_greyscale_images(self):
)

# let's set num_channels to 1
config.num_channels = 1
config.backbone_config.num_channels = 1

for model_class in self.all_model_classes:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,9 @@ def test_different_timm_backbone(self):

# let's pick a random timm backbone
config.backbone = "tf_mobilenetv3_small_075"
config.backbone_config = None
config.use_timm_backbone = True
config.backbone_kwargs = {"out_indices": [2, 3, 4]}

for model_class in self.all_model_classes:
model = model_class(config)
Expand Down Expand Up @@ -482,7 +485,6 @@ def test_greyscale_images(self):
)

# let's set num_channels to 1
config.num_channels = 1
config.backbone_config.num_channels = 1

for model_class in self.all_model_classes:
Expand Down

0 comments on commit 4d524f2

Please sign in to comment.