diff --git a/src/transformers/models/conditional_detr/configuration_conditional_detr.py b/src/transformers/models/conditional_detr/configuration_conditional_detr.py
index 945e5edb32ad30..4f95de3582f082 100644
--- a/src/transformers/models/conditional_detr/configuration_conditional_detr.py
+++ b/src/transformers/models/conditional_detr/configuration_conditional_detr.py
@@ -192,10 +192,16 @@ def __init__(
         if backbone_config is not None and use_timm_backbone:
             raise ValueError("You can't specify both `backbone_config` and `use_timm_backbone`.")
 
-        if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None:
-            raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")
-
-        if not use_timm_backbone:
+        # We default to values which were previously hard-coded in the model. This enables configurability of the config
+        # while keeping the default behavior the same.
+        if use_timm_backbone and backbone_kwargs is None:
+            backbone_kwargs = {}
+            if dilation:
+                backbone_kwargs["output_stride"] = 16
+            backbone_kwargs["out_indices"] = [1, 2, 3, 4]
+            backbone_kwargs["in_chans"] = num_channels
+        # Backwards compatibility
+        elif not use_timm_backbone and backbone in (None, "resnet50"):
             if backbone_config is None:
                 logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
                 backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage4"])
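
Note: a minimal sketch of how the default-filling above behaves, assuming `backbone_kwargs` is stored on the config as elsewhere in the library; the printed dicts follow directly from the hunk:

```python
from transformers import ConditionalDetrConfig

# When backbone_kwargs is left unset, the config fills in the values that used
# to be hard-coded in the model (output_stride when dilation is set, out_indices, in_chans).
config = ConditionalDetrConfig(use_timm_backbone=True, backbone="resnet50", dilation=True)
print(config.backbone_kwargs)
# Expected: {"output_stride": 16, "out_indices": [1, 2, 3, 4], "in_chans": 3}

# An explicit backbone_kwargs is left untouched, so users can override the defaults.
config = ConditionalDetrConfig(
    use_timm_backbone=True, backbone="resnet50", backbone_kwargs={"out_indices": [2, 3, 4]}
)
print(config.backbone_kwargs)
# Expected: {"out_indices": [2, 3, 4]}
```
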
diff --git a/src/transformers/models/conditional_detr/modeling_conditional_detr.py b/src/transformers/models/conditional_detr/modeling_conditional_detr.py
index d8ff371fad77d1..d723d3866ea416 100644
--- a/src/transformers/models/conditional_detr/modeling_conditional_detr.py
+++ b/src/transformers/models/conditional_detr/modeling_conditional_detr.py
@@ -338,12 +338,12 @@ def replace_batch_norm(model):
             replace_batch_norm(module)
 
 
-# Copied from transformers.models.detr.modeling_detr.DetrConvEncoder
+# Copied from transformers.models.detr.modeling_detr.DetrConvEncoder with Detr->ConditionalDetr
 class ConditionalDetrConvEncoder(nn.Module):
     """
     Convolutional backbone, using either the AutoBackbone API or one from the timm library.
 
-    nn.BatchNorm2d layers are replaced by DetrFrozenBatchNorm2d as defined above.
+    nn.BatchNorm2d layers are replaced by ConditionalDetrFrozenBatchNorm2d as defined above.
     """
 
@@ -352,17 +352,23 @@ def __init__(self, config):
 
         self.config = config
 
+        # For backwards compatibility we have to use the timm library directly instead of the AutoBackbone API
         if config.use_timm_backbone:
+            # We default to values which were previously hard-coded. This enables configurability from the config
+            # using backbone arguments, while keeping the default behavior the same.
             requires_backends(self, ["timm"])
-            kwargs = {}
+            kwargs = getattr(config, "backbone_kwargs", {})
+            kwargs = {} if kwargs is None else kwargs.copy()
+            out_indices = kwargs.pop("out_indices", (1, 2, 3, 4))
+            num_channels = kwargs.pop("in_chans", config.num_channels)
             if config.dilation:
-                kwargs["output_stride"] = 16
+                kwargs["output_stride"] = kwargs.get("output_stride", 16)
             backbone = create_model(
                 config.backbone,
                 pretrained=config.use_pretrained_backbone,
                 features_only=True,
-                out_indices=(1, 2, 3, 4),
-                in_chans=config.num_channels,
+                out_indices=out_indices,
+                in_chans=num_channels,
                 **kwargs,
             )
         else:
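
Note: the precedence rules in the encoder above can be read in isolation. The helper name `resolve_timm_kwargs` below is hypothetical, used only to mirror the merging logic: user-supplied `backbone_kwargs` win, and `dilation` only fills `output_stride` when the user has not set it.

```python
def resolve_timm_kwargs(backbone_kwargs, num_channels, dilation):
    """Illustrative mirror of the merging logic in ConditionalDetrConvEncoder."""
    kwargs = {} if backbone_kwargs is None else dict(backbone_kwargs)
    out_indices = kwargs.pop("out_indices", (1, 2, 3, 4))
    in_chans = kwargs.pop("in_chans", num_channels)
    if dilation:
        # setdefault-style: an explicit output_stride in backbone_kwargs wins
        kwargs["output_stride"] = kwargs.get("output_stride", 16)
    return out_indices, in_chans, kwargs


assert resolve_timm_kwargs(None, 3, dilation=True) == ((1, 2, 3, 4), 3, {"output_stride": 16})
assert resolve_timm_kwargs({"output_stride": 32}, 3, dilation=True) == ((1, 2, 3, 4), 3, {"output_stride": 32})
```
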
diff --git a/src/transformers/models/deformable_detr/configuration_deformable_detr.py b/src/transformers/models/deformable_detr/configuration_deformable_detr.py
index 6d32f6220df586..3f3ffff69ff2e9 100644
--- a/src/transformers/models/deformable_detr/configuration_deformable_detr.py
+++ b/src/transformers/models/deformable_detr/configuration_deformable_detr.py
@@ -212,7 +212,16 @@ def __init__(
         if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None:
             raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")
 
-        if not use_timm_backbone:
+        # We default to values which were previously hard-coded in the model. This enables configurability of the config
+        # while keeping the default behavior the same.
+        if use_timm_backbone and backbone_kwargs is None:
+            backbone_kwargs = {}
+            if dilation:
+                backbone_kwargs["output_stride"] = 16
+            backbone_kwargs["out_indices"] = [2, 3, 4] if num_feature_levels > 1 else [4]
+            backbone_kwargs["in_chans"] = num_channels
+        # Backwards compatibility
+        elif not use_timm_backbone and backbone in (None, "resnet50"):
             if backbone_config is None:
                 logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
                 backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage4"])
@@ -220,6 +229,7 @@ def __init__(
             backbone_model_type = backbone_config.get("model_type")
             config_class = CONFIG_MAPPING[backbone_model_type]
             backbone_config = config_class.from_dict(backbone_config)
+
         self.use_timm_backbone = use_timm_backbone
         self.backbone_config = backbone_config
         self.num_channels = num_channels
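
Note: for Deformable DETR the default `out_indices` depends on `num_feature_levels` (default 4). A sketch of the two cases, assuming the config's other defaults:

```python
from transformers import DeformableDetrConfig

# Multi-scale (default num_feature_levels=4): three backbone stages are returned.
config = DeformableDetrConfig(use_timm_backbone=True, backbone="resnet50")
print(config.backbone_kwargs["out_indices"])  # Expected: [2, 3, 4]

# Single-scale: only the last stage is needed.
config = DeformableDetrConfig(use_timm_backbone=True, backbone="resnet50", num_feature_levels=1)
print(config.backbone_kwargs["out_indices"])  # Expected: [4]
```
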
diff --git a/src/transformers/models/deformable_detr/modeling_deformable_detr.py b/src/transformers/models/deformable_detr/modeling_deformable_detr.py
index c0ac7cffc7ab44..7b2bbb9b1242c9 100755
--- a/src/transformers/models/deformable_detr/modeling_deformable_detr.py
+++ b/src/transformers/models/deformable_detr/modeling_deformable_detr.py
@@ -88,11 +88,31 @@ def load_cuda_kernels():
 if is_vision_available():
     from transformers.image_transforms import center_to_corners_format
 
+
 if is_accelerate_available():
     from accelerate import PartialState
     from accelerate.utils import reduce
 
 
+if is_timm_available():
+    from timm import create_model
+
+
+if is_scipy_available():
+    from scipy.optimize import linear_sum_assignment
+
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "DeformableDetrConfig"
+_CHECKPOINT_FOR_DOC = "sensetime/deformable-detr"
+
+DEFORMABLE_DETR_PRETRAINED_MODEL_ARCHIVE_LIST = [
+    "sensetime/deformable-detr",
+    # See all Deformable DETR models at https://huggingface.co/models?filter=deformable-detr
+]
+
+
 class MultiScaleDeformableAttentionFunction(Function):
     @staticmethod
     def forward(
@@ -141,21 +161,6 @@ def backward(context, grad_output):
         return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
 
 
-if is_scipy_available():
-    from scipy.optimize import linear_sum_assignment
-
-if is_timm_available():
-    from timm import create_model
-
-logger = logging.get_logger(__name__)
-
-_CONFIG_FOR_DOC = "DeformableDetrConfig"
-_CHECKPOINT_FOR_DOC = "sensetime/deformable-detr"
-
-
-from ..deprecated._archive_maps import DEFORMABLE_DETR_PRETRAINED_MODEL_ARCHIVE_LIST  # noqa: F401, E402
-
-
 @dataclass
 class DeformableDetrDecoderOutput(ModelOutput):
     """
@@ -420,17 +425,23 @@ def __init__(self, config):
 
         self.config = config
 
+        # For backwards compatibility we have to use the timm library directly instead of the AutoBackbone API
         if config.use_timm_backbone:
+            # We default to values which were previously hard-coded. This enables configurability from the config
+            # using backbone arguments, while keeping the default behavior the same.
             requires_backends(self, ["timm"])
-            kwargs = {}
+            kwargs = getattr(config, "backbone_kwargs", {})
+            kwargs = {} if kwargs is None else kwargs.copy()
+            out_indices = kwargs.pop("out_indices", (2, 3, 4) if config.num_feature_levels > 1 else (4,))
+            num_channels = kwargs.pop("in_chans", config.num_channels)
             if config.dilation:
-                kwargs["output_stride"] = 16
+                kwargs["output_stride"] = kwargs.get("output_stride", 16)
             backbone = create_model(
                 config.backbone,
                 pretrained=config.use_pretrained_backbone,
                 features_only=True,
-                out_indices=(2, 3, 4) if config.num_feature_levels > 1 else (4,),
-                in_chans=config.num_channels,
+                out_indices=out_indices,
+                in_chans=num_channels,
                 **kwargs,
             )
         else:
diff --git a/src/transformers/models/detr/configuration_detr.py b/src/transformers/models/detr/configuration_detr.py
index 9b9b5afacd0b7f..db180ef1d41fed 100644
--- a/src/transformers/models/detr/configuration_detr.py
+++ b/src/transformers/models/detr/configuration_detr.py
@@ -193,7 +193,16 @@ def __init__(
         if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None:
             raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")
 
-        if not use_timm_backbone:
+        # We default to values which were previously hard-coded in the model. This enables configurability of the config
+        # while keeping the default behavior the same.
+        if use_timm_backbone and backbone_kwargs is None:
+            backbone_kwargs = {}
+            if dilation:
+                backbone_kwargs["output_stride"] = 16
+            backbone_kwargs["out_indices"] = [1, 2, 3, 4]
+            backbone_kwargs["in_chans"] = num_channels
+        # Backwards compatibility
+        elif not use_timm_backbone and backbone in (None, "resnet50"):
             if backbone_config is None:
                 logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
                 backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage4"])
@@ -201,8 +210,9 @@ def __init__(
             backbone_model_type = backbone_config.get("model_type")
             config_class = CONFIG_MAPPING[backbone_model_type]
             backbone_config = config_class.from_dict(backbone_config)
+            backbone = None
             # set timm attributes to None
-            dilation, backbone, use_pretrained_backbone = None, None, None
+            dilation = None
 
         self.use_timm_backbone = use_timm_backbone
         self.backbone_config = backbone_config
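
Note: the backwards-compatibility branch only triggers for the legacy defaults (`backbone` unset or `"resnet50"` with `use_timm_backbone=False`); a sketch of that legacy path, assuming `DetrConfig`'s other defaults:

```python
from transformers import DetrConfig

# Legacy path: use_timm_backbone=False with the default backbone falls back to
# the default ResNet backbone_config, and the timm-only attributes are reset.
config = DetrConfig(use_timm_backbone=False)
print(config.backbone_config.model_type)  # Expected: "resnet"
print(config.backbone, config.dilation)   # Expected: None None
```
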
diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py
index d7fcdfc5bc7e83..0da702db8b67e2 100644
--- a/src/transformers/models/detr/modeling_detr.py
+++ b/src/transformers/models/detr/modeling_detr.py
@@ -49,9 +49,11 @@
 if is_scipy_available():
     from scipy.optimize import linear_sum_assignment
 
+
 if is_timm_available():
     from timm import create_model
 
+
 if is_vision_available():
     from transformers.image_transforms import center_to_corners_format
 
@@ -345,17 +347,23 @@ def __init__(self, config):
 
         self.config = config
 
+        # For backwards compatibility we have to use the timm library directly instead of the AutoBackbone API
         if config.use_timm_backbone:
+            # We default to values which were previously hard-coded. This enables configurability from the config
+            # using backbone arguments, while keeping the default behavior the same.
             requires_backends(self, ["timm"])
-            kwargs = {}
+            kwargs = getattr(config, "backbone_kwargs", {})
+            kwargs = {} if kwargs is None else kwargs.copy()
+            out_indices = kwargs.pop("out_indices", (1, 2, 3, 4))
+            num_channels = kwargs.pop("in_chans", config.num_channels)
             if config.dilation:
-                kwargs["output_stride"] = 16
+                kwargs["output_stride"] = kwargs.get("output_stride", 16)
             backbone = create_model(
                 config.backbone,
                 pretrained=config.use_pretrained_backbone,
                 features_only=True,
-                out_indices=(1, 2, 3, 4),
-                in_chans=config.num_channels,
+                out_indices=out_indices,
+                in_chans=num_channels,
                 **kwargs,
             )
         else:
diff --git a/src/transformers/models/dpt/modeling_dpt.py b/src/transformers/models/dpt/modeling_dpt.py
index aad3330279f051..ef6c8bb853abda 100755
--- a/src/transformers/models/dpt/modeling_dpt.py
+++ b/src/transformers/models/dpt/modeling_dpt.py
@@ -1075,10 +1075,10 @@ def __init__(self, config):
         super().__init__(config)
 
         self.backbone = None
-        if config.backbone_config is not None and config.is_hybrid is False:
-            self.backbone = load_backbone(config)
-        else:
+        if config.is_hybrid or config.backbone_config is None:
             self.dpt = DPTModel(config, add_pooling_layer=False)
+        else:
+            self.backbone = load_backbone(config)
 
         # Neck
         self.neck = DPTNeck(config)
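
Note: the DPT change reads as a pure De Morgan reshuffle of the branch rather than a behavior change, provided `is_hybrid` is a plain bool; a quick equivalence check:

```python
from itertools import product

for is_hybrid, has_backbone_config in product([True, False], repeat=2):
    backbone_config = object() if has_backbone_config else None
    # old: backbone branch taken when a backbone_config exists and not hybrid
    old_uses_backbone = backbone_config is not None and is_hybrid is False
    # new: backbone branch is the else of (is_hybrid or backbone_config is None)
    new_uses_backbone = not (is_hybrid or backbone_config is None)
    assert old_uses_backbone == new_uses_backbone
```
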
diff --git a/src/transformers/models/table_transformer/configuration_table_transformer.py b/src/transformers/models/table_transformer/configuration_table_transformer.py
index 9a2ff6bbab3b24..4963396024a57e 100644
--- a/src/transformers/models/table_transformer/configuration_table_transformer.py
+++ b/src/transformers/models/table_transformer/configuration_table_transformer.py
@@ -193,7 +193,16 @@ def __init__(
         if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None:
             raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")
 
-        if not use_timm_backbone:
+        # We default to values which were previously hard-coded in the model. This enables configurability of the config
+        # while keeping the default behavior the same.
+        if use_timm_backbone and backbone_kwargs is None:
+            backbone_kwargs = {}
+            if dilation:
+                backbone_kwargs["output_stride"] = 16
+            backbone_kwargs["out_indices"] = [1, 2, 3, 4]
+            backbone_kwargs["in_chans"] = num_channels
+        # Backwards compatibility
+        elif not use_timm_backbone and backbone in (None, "resnet50"):
             if backbone_config is None:
                 logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
                 backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage4"])
@@ -201,8 +210,9 @@ def __init__(
             backbone_model_type = backbone_config.get("model_type")
             config_class = CONFIG_MAPPING[backbone_model_type]
             backbone_config = config_class.from_dict(backbone_config)
+            backbone = None
             # set timm attributes to None
-            dilation, backbone, use_pretrained_backbone = None, None, None
+            dilation = None
 
         self.use_timm_backbone = use_timm_backbone
         self.backbone_config = backbone_config
diff --git a/src/transformers/models/table_transformer/modeling_table_transformer.py b/src/transformers/models/table_transformer/modeling_table_transformer.py
index 8e577a65a5fe00..9a684ee121ddca 100644
--- a/src/transformers/models/table_transformer/modeling_table_transformer.py
+++ b/src/transformers/models/table_transformer/modeling_table_transformer.py
@@ -279,17 +279,23 @@ def __init__(self, config):
 
         self.config = config
 
+        # For backwards compatibility we have to use the timm library directly instead of the AutoBackbone API
         if config.use_timm_backbone:
+            # We default to values which were previously hard-coded. This enables configurability from the config
+            # using backbone arguments, while keeping the default behavior the same.
             requires_backends(self, ["timm"])
-            kwargs = {}
+            kwargs = getattr(config, "backbone_kwargs", {})
+            kwargs = {} if kwargs is None else kwargs.copy()
+            out_indices = kwargs.pop("out_indices", (1, 2, 3, 4))
+            num_channels = kwargs.pop("in_chans", config.num_channels)
             if config.dilation:
-                kwargs["output_stride"] = 16
+                kwargs["output_stride"] = kwargs.get("output_stride", 16)
             backbone = create_model(
                 config.backbone,
                 pretrained=config.use_pretrained_backbone,
                 features_only=True,
-                out_indices=(1, 2, 3, 4),
-                in_chans=config.num_channels,
+                out_indices=out_indices,
+                in_chans=num_channels,
                 **kwargs,
             )
         else:
diff --git a/src/transformers/models/timm_backbone/modeling_timm_backbone.py b/src/transformers/models/timm_backbone/modeling_timm_backbone.py
index 0c6fe67b75731f..e8e0b28e042d6f 100644
--- a/src/transformers/models/timm_backbone/modeling_timm_backbone.py
+++ b/src/transformers/models/timm_backbone/modeling_timm_backbone.py
@@ -63,12 +63,13 @@ def __init__(self, config, **kwargs):
         # We just take the final layer by default. This matches the default for the transformers models.
         out_indices = config.out_indices if getattr(config, "out_indices", None) is not None else (-1,)
 
+        in_chans = kwargs.pop("in_chans", config.num_channels)
         self._backbone = timm.create_model(
             config.backbone,
             pretrained=pretrained,
             # This is currently not possible for transformer architectures.
             features_only=config.features_only,
-            in_chans=config.num_channels,
+            in_chans=in_chans,
             out_indices=out_indices,
             **kwargs,
         )
@@ -79,7 +80,9 @@ def __init__(self, config, **kwargs):
 
         # These are used to control the output of the model when called. If output_hidden_states is True, then
         # return_layers is modified to include all layers.
-        self._return_layers = self._backbone.return_layers
+        self._return_layers = {
+            layer["module"]: str(layer["index"]) for layer in self._backbone.feature_info.get_dicts()
+        }
         self._all_layers = {layer["module"]: str(i) for i, layer in enumerate(self._backbone.feature_info.info)}
 
         super()._init_backbone(config)
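
Note: `_return_layers` is now rebuilt from the public `feature_info` metadata instead of timm's internal `return_layers` attribute, so it tracks the configured `out_indices`. An illustrative sketch with stand-in metadata (the `module`/`index` keys follow the diff; the concrete module names and channel counts are made up):

```python
# Hypothetical feature_info entries for a backbone created with out_indices=(2, 3, 4)
feature_dicts = [
    {"module": "layer2", "index": 2, "num_chs": 512, "reduction": 8},
    {"module": "layer3", "index": 3, "num_chs": 1024, "reduction": 16},
    {"module": "layer4", "index": 4, "num_chs": 2048, "reduction": 32},
]

# Same comprehension as in the diff: map module name -> stage index (as a string)
return_layers = {layer["module"]: str(layer["index"]) for layer in feature_dicts}
print(return_layers)
# {"layer2": "2", "layer3": "3", "layer4": "4"}
```
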
diff --git a/tests/models/conditional_detr/test_modeling_conditional_detr.py b/tests/models/conditional_detr/test_modeling_conditional_detr.py
index d1152ed8622b9c..c3f77614b4dd31 100644
--- a/tests/models/conditional_detr/test_modeling_conditional_detr.py
+++ b/tests/models/conditional_detr/test_modeling_conditional_detr.py
@@ -444,7 +444,9 @@ def test_different_timm_backbone(self):
 
         # let's pick a random timm backbone
         config.backbone = "tf_mobilenetv3_small_075"
+        config.backbone_config = None
         config.use_timm_backbone = True
+        config.backbone_kwargs = {"out_indices": [2, 3, 4]}
 
         for model_class in self.all_model_classes:
             model = model_class(config)
@@ -460,6 +462,14 @@ def test_different_timm_backbone(self):
                     self.model_tester.num_labels,
                 )
                 self.assertEqual(outputs.logits.shape, expected_shape)
+                # Confirm out_indices was propagated to backbone
+                self.assertEqual(len(model.model.backbone.conv_encoder.intermediate_channel_sizes), 3)
+            elif model_class.__name__ == "ConditionalDetrForSegmentation":
+                # Confirm out_indices was propagated to backbone
+                self.assertEqual(len(model.conditional_detr.model.backbone.conv_encoder.intermediate_channel_sizes), 3)
+            else:
+                # Confirm out_indices was propagated to backbone
+                self.assertEqual(len(model.backbone.conv_encoder.intermediate_channel_sizes), 3)
 
             self.assertTrue(outputs)
diff --git a/tests/models/deformable_detr/test_modeling_deformable_detr.py b/tests/models/deformable_detr/test_modeling_deformable_detr.py
index 7a83c4f1ed80a8..36be099790a45b 100644
--- a/tests/models/deformable_detr/test_modeling_deformable_detr.py
+++ b/tests/models/deformable_detr/test_modeling_deformable_detr.py
@@ -521,8 +521,9 @@ def test_different_timm_backbone(self):
 
         # let's pick a random timm backbone
         config.backbone = "tf_mobilenetv3_small_075"
-        config.use_timm_backbone = True
         config.backbone_config = None
+        config.use_timm_backbone = True
+        config.backbone_kwargs = {"out_indices": [1, 2, 3, 4]}
 
         for model_class in self.all_model_classes:
             model = model_class(config)
@@ -538,6 +539,11 @@ def test_different_timm_backbone(self):
                     self.model_tester.num_labels,
                 )
                 self.assertEqual(outputs.logits.shape, expected_shape)
+                # Confirm out_indices was propagated to backbone
+                self.assertEqual(len(model.model.backbone.conv_encoder.intermediate_channel_sizes), 4)
+            else:
+                # Confirm out_indices was propagated to backbone
+                self.assertEqual(len(model.backbone.conv_encoder.intermediate_channel_sizes), 4)
 
             self.assertTrue(outputs)
diff --git a/tests/models/detr/test_modeling_detr.py b/tests/models/detr/test_modeling_detr.py
index 59b071e031aa8a..27092c626dd46d 100644
--- a/tests/models/detr/test_modeling_detr.py
+++ b/tests/models/detr/test_modeling_detr.py
@@ -444,6 +444,9 @@ def test_different_timm_backbone(self):
 
         # let's pick a random timm backbone
         config.backbone = "tf_mobilenetv3_small_075"
+        config.backbone_config = None
+        config.use_timm_backbone = True
+        config.backbone_kwargs = {"out_indices": [2, 3, 4]}
 
         for model_class in self.all_model_classes:
             model = model_class(config)
@@ -459,6 +462,14 @@ def test_different_timm_backbone(self):
                     self.model_tester.num_labels + 1,
                 )
                 self.assertEqual(outputs.logits.shape, expected_shape)
+                # Confirm out_indices was propagated to backbone
+                self.assertEqual(len(model.model.backbone.conv_encoder.intermediate_channel_sizes), 3)
+            elif model_class.__name__ == "DetrForSegmentation":
+                # Confirm out_indices was propagated to backbone
+                self.assertEqual(len(model.detr.model.backbone.conv_encoder.intermediate_channel_sizes), 3)
+            else:
+                # Confirm out_indices was propagated to backbone
+                self.assertEqual(len(model.backbone.conv_encoder.intermediate_channel_sizes), 3)
 
             self.assertTrue(outputs)
diff --git a/tests/models/table_transformer/test_modeling_table_transformer.py b/tests/models/table_transformer/test_modeling_table_transformer.py
index 79da1d191063ab..d323083eb7f1d4 100644
--- a/tests/models/table_transformer/test_modeling_table_transformer.py
+++ b/tests/models/table_transformer/test_modeling_table_transformer.py
@@ -456,6 +456,9 @@ def test_different_timm_backbone(self):
 
         # let's pick a random timm backbone
         config.backbone = "tf_mobilenetv3_small_075"
+        config.backbone_config = None
+        config.use_timm_backbone = True
+        config.backbone_kwargs = {"out_indices": [2, 3, 4]}
 
         for model_class in self.all_model_classes:
             model = model_class(config)
@@ -471,6 +474,11 @@ def test_different_timm_backbone(self):
                     self.model_tester.num_labels + 1,
                 )
                 self.assertEqual(outputs.logits.shape, expected_shape)
+                # Confirm out_indices was propagated to backbone
+                self.assertEqual(len(model.model.backbone.conv_encoder.intermediate_channel_sizes), 3)
+            else:
+                # Confirm out_indices was propagated to backbone
+                self.assertEqual(len(model.backbone.conv_encoder.intermediate_channel_sizes), 3)
 
             self.assertTrue(outputs)
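
Note: outside the test suite, the same propagation can be checked end to end; a sketch assuming timm is installed (with `use_pretrained_backbone=False` so no weights are downloaded), mirroring the updated tests:

```python
from transformers import DetrConfig, DetrForObjectDetection

config = DetrConfig(
    use_timm_backbone=True,
    backbone="tf_mobilenetv3_small_075",
    use_pretrained_backbone=False,
    backbone_kwargs={"out_indices": [2, 3, 4]},
)
model = DetrForObjectDetection(config)

# Three requested stages -> three feature maps out of the conv encoder
assert len(model.model.backbone.conv_encoder.intermediate_channel_sizes) == 3
```
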