From 4d524f28a91d5a90c9dac6c633231a15053ca66d Mon Sep 17 00:00:00 2001 From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com> Date: Mon, 19 Feb 2024 15:54:44 +0000 Subject: [PATCH] Fix tests --- src/transformers/models/dpt/modeling_dpt.py | 2 +- .../models/conditional_detr/test_modeling_conditional_detr.py | 2 ++ tests/models/detr/test_modeling_detr.py | 1 - .../table_transformer/test_modeling_table_transformer.py | 4 +++- 4 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/dpt/modeling_dpt.py b/src/transformers/models/dpt/modeling_dpt.py index 595712bb721a5e..73b32833249750 100755 --- a/src/transformers/models/dpt/modeling_dpt.py +++ b/src/transformers/models/dpt/modeling_dpt.py @@ -1079,7 +1079,7 @@ def __init__(self, config): super().__init__(config) self.backbone = None - if config.is_hybrid: + if config.is_hybrid or config.backbone_config is None: self.dpt = DPTModel(config, add_pooling_layer=False) else: self.backbone = load_backbone(config) diff --git a/tests/models/conditional_detr/test_modeling_conditional_detr.py b/tests/models/conditional_detr/test_modeling_conditional_detr.py index f297634a2e7553..cfa2b444861708 100644 --- a/tests/models/conditional_detr/test_modeling_conditional_detr.py +++ b/tests/models/conditional_detr/test_modeling_conditional_detr.py @@ -443,7 +443,9 @@ def test_different_timm_backbone(self): # let's pick a random timm backbone config.backbone = "tf_mobilenetv3_small_075" + config.backbone_config = None config.use_timm_backbone = True + config.backbone_kwargs = {"out_indices": [2, 3, 4]} for model_class in self.all_model_classes: model = model_class(config) diff --git a/tests/models/detr/test_modeling_detr.py b/tests/models/detr/test_modeling_detr.py index b90b34a24f343f..34e8c28b1d79f8 100644 --- a/tests/models/detr/test_modeling_detr.py +++ b/tests/models/detr/test_modeling_detr.py @@ -473,7 +473,6 @@ def test_greyscale_images(self): ) # let's set num_channels to 1 - config.num_channels = 1 config.backbone_config.num_channels = 1 for model_class in self.all_model_classes: diff --git a/tests/models/table_transformer/test_modeling_table_transformer.py b/tests/models/table_transformer/test_modeling_table_transformer.py index eb5e80c93886b9..a214f831aa0bd7 100644 --- a/tests/models/table_transformer/test_modeling_table_transformer.py +++ b/tests/models/table_transformer/test_modeling_table_transformer.py @@ -455,6 +455,9 @@ def test_different_timm_backbone(self): # let's pick a random timm backbone config.backbone = "tf_mobilenetv3_small_075" + config.backbone_config = None + config.use_timm_backbone = True + config.backbone_kwargs = {"out_indices": [2, 3, 4]} for model_class in self.all_model_classes: model = model_class(config) @@ -482,7 +485,6 @@ def test_greyscale_images(self): ) # let's set num_channels to 1 - config.num_channels = 1 config.backbone_config.num_channels = 1 for model_class in self.all_model_classes: