[DETR] Remove timm hardcoded logic in modeling files (#29038)
* Enable instantiating model with pretrained backbone weights

* Clarify pretrained import

* Use load_backbone instead

* Add backbone_kwargs to config

* Fix up

* Add tests

* Tidy up

* Enable instantiating model with pretrained backbone weights

* Update tests so backbone checkpoint isn't passed in

* Clarify pretrained import

* Update configs - docs and validation check

* Update src/transformers/utils/backbone_utils.py

Co-authored-by: Arthur <[email protected]>

* Clarify exception message

* Update config init in tests

* Add test for when use_timm_backbone=True

* Use load_backbone instead

* Add use_timm_backbone to the model configs

* Add backbone_kwargs to config

* Pass kwargs to constructors

* Draft

* Fix tests

* Add back timm - weight naming

* More tidying up

* Whoops

* Tidy up

* Handle when kwargs are none

* Update tests

* Revert test changes

* Deformable detr test - don't use default

* Don't mutate; correct model attributes

* Add some clarifying comments

* nit - grammar is hard

---------

Co-authored-by: Arthur <[email protected]>
amyeroberts and ArthurZucker authored Apr 26, 2024
1 parent 77ff304 · commit aafa7ce
Showing 14 changed files with 156 additions and 48 deletions.
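As a usage sketch of what this change enables (hypothetical model and kwargs values, not taken from the commit): timm backbone arguments now flow through the config's backbone_kwargs instead of being hard-coded in the modeling files.

    from transformers import DetrConfig, DetrForObjectDetection

    # Hypothetical example: choose a timm backbone and forward custom kwargs to timm.create_model
    config = DetrConfig(
        use_timm_backbone=True,
        backbone="resnet50",
        use_pretrained_backbone=False,
        backbone_kwargs={"out_indices": [2, 3, 4]},  # previously fixed to (1, 2, 3, 4) in the modeling file
    )
    model = DetrForObjectDetection(config)  # the conv encoder now reads out_indices from the config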
@@ -192,10 +192,16 @@ def __init__(
         if backbone_config is not None and use_timm_backbone:
             raise ValueError("You can't specify both `backbone_config` and `use_timm_backbone`.")

         if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None:
             raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")

-        if not use_timm_backbone:
+        # We default to values which were previously hard-coded in the model. This enables configurability of the config
+        # while keeping the default behavior the same.
+        if use_timm_backbone and backbone_kwargs is None:
+            backbone_kwargs = {}
+            if dilation:
+                backbone_kwargs["output_stride"] = 16
+            backbone_kwargs["out_indices"] = [1, 2, 3, 4]
+            backbone_kwargs["in_chans"] = num_channels
+        # Backwards compatibility
+        elif not use_timm_backbone and backbone in (None, "resnet50"):
             if backbone_config is None:
                 logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
                 backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage4"])
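To make the defaulting concrete, here is a hypothetical trace of the new branch (not part of the commit) with use_timm_backbone=True, backbone_kwargs=None, dilation=True and num_channels=3:

    # Hypothetical trace of the defaulting branch above
    backbone_kwargs = {}
    backbone_kwargs["output_stride"] = 16  # set only because dilation is truthy
    backbone_kwargs["out_indices"] = [1, 2, 3, 4]
    backbone_kwargs["in_chans"] = 3  # num_channels
    # exactly the values the modeling file used to hard-code when calling timm's create_model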
@@ -338,12 +338,12 @@ def replace_batch_norm(model):
         replace_batch_norm(module)


-# Copied from transformers.models.detr.modeling_detr.DetrConvEncoder
+# Copied from transformers.models.detr.modeling_detr.DetrConvEncoder with Detr->ConditionalDetr
 class ConditionalDetrConvEncoder(nn.Module):
     """
     Convolutional backbone, using either the AutoBackbone API or one from the timm library.

-    nn.BatchNorm2d layers are replaced by DetrFrozenBatchNorm2d as defined above.
+    nn.BatchNorm2d layers are replaced by ConditionalDetrFrozenBatchNorm2d as defined above.
     """

@@ -352,17 +352,23 @@ def __init__(self, config):

         self.config = config

         # For backwards compatibility we have to use the timm library directly instead of the AutoBackbone API
         if config.use_timm_backbone:
+            # We default to values which were previously hard-coded. This enables configurability from the config
+            # using backbone arguments, while keeping the default behavior the same.
             requires_backends(self, ["timm"])
-            kwargs = {}
+            kwargs = getattr(config, "backbone_kwargs", {})
+            kwargs = {} if kwargs is None else kwargs.copy()
+            out_indices = kwargs.pop("out_indices", (1, 2, 3, 4))
+            num_channels = kwargs.pop("in_chans", config.num_channels)
             if config.dilation:
-                kwargs["output_stride"] = 16
+                kwargs["output_stride"] = kwargs.get("output_stride", 16)
             backbone = create_model(
                 config.backbone,
                 pretrained=config.use_pretrained_backbone,
                 features_only=True,
-                out_indices=(1, 2, 3, 4),
-                in_chans=config.num_channels,
+                out_indices=out_indices,
+                in_chans=num_channels,
                 **kwargs,
             )
         else:
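The pop/get combination above is what lets user-supplied backbone_kwargs override the old hard-coded values, while the legacy defaults only fill gaps. A minimal sketch with hypothetical values:

    # Hypothetical: config.backbone_kwargs = {"out_indices": [3, 4], "output_stride": 8}
    kwargs = {"out_indices": [3, 4], "output_stride": 8}.copy()  # copy so the config isn't mutated
    out_indices = kwargs.pop("out_indices", (1, 2, 3, 4))  # user value wins: [3, 4]
    num_channels = kwargs.pop("in_chans", 3)  # falls back to config.num_channels: 3
    kwargs["output_stride"] = kwargs.get("output_stride", 16)  # 8 survives; 16 is only a fallback
    # whatever remains in kwargs is forwarded verbatim to timm's create_model(**kwargs)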
@@ -212,14 +212,24 @@ def __init__(
         if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None:
             raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")

-        if not use_timm_backbone:
+        # We default to values which were previously hard-coded in the model. This enables configurability of the config
+        # while keeping the default behavior the same.
+        if use_timm_backbone and backbone_kwargs is None:
+            backbone_kwargs = {}
+            if dilation:
+                backbone_kwargs["output_stride"] = 16
+            backbone_kwargs["out_indices"] = [2, 3, 4] if num_feature_levels > 1 else [4]
+            backbone_kwargs["in_chans"] = num_channels
+        # Backwards compatibility
+        elif not use_timm_backbone and backbone in (None, "resnet50"):
             if backbone_config is None:
                 logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
                 backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage4"])
             elif isinstance(backbone_config, dict):
                 backbone_model_type = backbone_config.get("model_type")
                 config_class = CONFIG_MAPPING[backbone_model_type]
                 backbone_config = config_class.from_dict(backbone_config)

         self.use_timm_backbone = use_timm_backbone
         self.backbone_config = backbone_config
         self.num_channels = num_channels
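A hedged sketch of the multi-scale defaulting, assuming the config stores the computed defaults on self.backbone_kwargs (that assignment sits below the shown hunk):

    from transformers import DeformableDetrConfig

    # Hypothetical: a single feature level requests only the last backbone stage
    config = DeformableDetrConfig(use_timm_backbone=True, backbone="resnet50", num_feature_levels=1)
    assert config.backbone_kwargs["out_indices"] == [4]

    # Hypothetical: the default multi-scale setup requests three stages
    config = DeformableDetrConfig(use_timm_backbone=True, backbone="resnet50", num_feature_levels=4)
    assert config.backbone_kwargs["out_indices"] == [2, 3, 4]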
49 changes: 30 additions & 19 deletions src/transformers/models/deformable_detr/modeling_deformable_detr.py
@@ -88,11 +88,31 @@ def load_cuda_kernels():
 if is_vision_available():
     from transformers.image_transforms import center_to_corners_format


 if is_accelerate_available():
     from accelerate import PartialState
     from accelerate.utils import reduce


+if is_timm_available():
+    from timm import create_model
+
+
+if is_scipy_available():
+    from scipy.optimize import linear_sum_assignment
+
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "DeformableDetrConfig"
+_CHECKPOINT_FOR_DOC = "sensetime/deformable-detr"
+
+DEFORMABLE_DETR_PRETRAINED_MODEL_ARCHIVE_LIST = [
+    "sensetime/deformable-detr",
+    # See all Deformable DETR models at https://huggingface.co/models?filter=deformable-detr
+]


 class MultiScaleDeformableAttentionFunction(Function):
     @staticmethod
     def forward(
@@ -141,21 +161,6 @@ def backward(context, grad_output):
         return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None


-if is_scipy_available():
-    from scipy.optimize import linear_sum_assignment
-
-if is_timm_available():
-    from timm import create_model
-
-logger = logging.get_logger(__name__)
-
-_CONFIG_FOR_DOC = "DeformableDetrConfig"
-_CHECKPOINT_FOR_DOC = "sensetime/deformable-detr"
-
-
-from ..deprecated._archive_maps import DEFORMABLE_DETR_PRETRAINED_MODEL_ARCHIVE_LIST  # noqa: F401, E402


 @dataclass
 class DeformableDetrDecoderOutput(ModelOutput):
     """
@@ -420,17 +425,23 @@ def __init__(self, config):

         self.config = config

         # For backwards compatibility we have to use the timm library directly instead of the AutoBackbone API
         if config.use_timm_backbone:
+            # We default to values which were previously hard-coded. This enables configurability from the config
+            # using backbone arguments, while keeping the default behavior the same.
             requires_backends(self, ["timm"])
-            kwargs = {}
+            kwargs = getattr(config, "backbone_kwargs", {})
+            kwargs = {} if kwargs is None else kwargs.copy()
+            out_indices = kwargs.pop("out_indices", (2, 3, 4) if config.num_feature_levels > 1 else (4,))
+            num_channels = kwargs.pop("in_chans", config.num_channels)
             if config.dilation:
-                kwargs["output_stride"] = 16
+                kwargs["output_stride"] = kwargs.get("output_stride", 16)
             backbone = create_model(
                 config.backbone,
                 pretrained=config.use_pretrained_backbone,
                 features_only=True,
-                out_indices=(2, 3, 4) if config.num_feature_levels > 1 else (4,),
-                in_chans=config.num_channels,
+                out_indices=out_indices,
+                in_chans=num_channels,
                 **kwargs,
             )
         else:
14 changes: 12 additions & 2 deletions src/transformers/models/detr/configuration_detr.py
@@ -193,16 +193,26 @@ def __init__(
         if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None:
             raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")

-        if not use_timm_backbone:
+        # We default to values which were previously hard-coded in the model. This enables configurability of the config
+        # while keeping the default behavior the same.
+        if use_timm_backbone and backbone_kwargs is None:
+            backbone_kwargs = {}
+            if dilation:
+                backbone_kwargs["output_stride"] = 16
+            backbone_kwargs["out_indices"] = [1, 2, 3, 4]
+            backbone_kwargs["in_chans"] = num_channels
+        # Backwards compatibility
+        elif not use_timm_backbone and backbone in (None, "resnet50"):
             if backbone_config is None:
                 logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
                 backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage4"])
             elif isinstance(backbone_config, dict):
                 backbone_model_type = backbone_config.get("model_type")
                 config_class = CONFIG_MAPPING[backbone_model_type]
                 backbone_config = config_class.from_dict(backbone_config)
+            backbone = None
             # set timm attributes to None
-            dilation, backbone, use_pretrained_backbone = None, None, None
+            dilation = None

         self.use_timm_backbone = use_timm_backbone
         self.backbone_config = backbone_config
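One behavioral detail worth noting (an observation about the hunk above, not from the commit text): the old code nulled dilation, backbone and use_pretrained_backbone together on this non-timm path; the new code still clears backbone (it only names a timm checkpoint) but keeps use_pretrained_backbone as passed, so load_backbone can honor it.

    # Hypothetical before/after for DetrConfig(use_timm_backbone=False, use_pretrained_backbone=False):
    # before this commit: config.use_pretrained_backbone is None   (silently overwritten)
    # after this commit:  config.use_pretrained_backbone is False  (kept as passed)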
16 changes: 12 additions & 4 deletions src/transformers/models/detr/modeling_detr.py
@@ -49,9 +49,11 @@
 if is_scipy_available():
     from scipy.optimize import linear_sum_assignment

+
 if is_timm_available():
     from timm import create_model

+
 if is_vision_available():
     from transformers.image_transforms import center_to_corners_format

@@ -345,17 +347,23 @@ def __init__(self, config):

         self.config = config

         # For backwards compatibility we have to use the timm library directly instead of the AutoBackbone API
         if config.use_timm_backbone:
+            # We default to values which were previously hard-coded. This enables configurability from the config
+            # using backbone arguments, while keeping the default behavior the same.
             requires_backends(self, ["timm"])
-            kwargs = {}
+            kwargs = getattr(config, "backbone_kwargs", {})
+            kwargs = {} if kwargs is None else kwargs.copy()
+            out_indices = kwargs.pop("out_indices", (1, 2, 3, 4))
+            num_channels = kwargs.pop("in_chans", config.num_channels)
             if config.dilation:
-                kwargs["output_stride"] = 16
+                kwargs["output_stride"] = kwargs.get("output_stride", 16)
             backbone = create_model(
                 config.backbone,
                 pretrained=config.use_pretrained_backbone,
                 features_only=True,
-                out_indices=(1, 2, 3, 4),
-                in_chans=config.num_channels,
+                out_indices=out_indices,
+                in_chans=num_channels,
                 **kwargs,
             )
         else:
6 changes: 3 additions & 3 deletions src/transformers/models/dpt/modeling_dpt.py
@@ -1075,10 +1075,10 @@ def __init__(self, config):
         super().__init__(config)

         self.backbone = None
-        if config.backbone_config is not None and config.is_hybrid is False:
-            self.backbone = load_backbone(config)
-        else:
+        if config.is_hybrid or config.backbone_config is None:
             self.dpt = DPTModel(config, add_pooling_layer=False)
+        else:
+            self.backbone = load_backbone(config)

         # Neck
         self.neck = DPTNeck(config)
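The rewritten DPT condition is the negation of the old one with the branches swapped, so behavior is unchanged. A quick hypothetical sanity check that both orderings pick the same branch:

    # A = (backbone_config is not None), H = config.is_hybrid
    for A in (False, True):
        for H in (False, True):
            old_uses_backbone = A and not H          # old: backbone branch first
            new_uses_backbone = not (H or not A)     # new: dpt branch first, backbone in else
            assert old_uses_backbone == new_uses_backbone  # De Morgan: not(A and not H) == (H or not A)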
@@ -193,16 +193,26 @@ def __init__(
         if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None:
             raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")

-        if not use_timm_backbone:
+        # We default to values which were previously hard-coded in the model. This enables configurability of the config
+        # while keeping the default behavior the same.
+        if use_timm_backbone and backbone_kwargs is None:
+            backbone_kwargs = {}
+            if dilation:
+                backbone_kwargs["output_stride"] = 16
+            backbone_kwargs["out_indices"] = [1, 2, 3, 4]
+            backbone_kwargs["in_chans"] = num_channels
+        # Backwards compatibility
+        elif not use_timm_backbone and backbone in (None, "resnet50"):
             if backbone_config is None:
                 logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
                 backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage4"])
             elif isinstance(backbone_config, dict):
                 backbone_model_type = backbone_config.get("model_type")
                 config_class = CONFIG_MAPPING[backbone_model_type]
                 backbone_config = config_class.from_dict(backbone_config)
+            backbone = None
             # set timm attributes to None
-            dilation, backbone, use_pretrained_backbone = None, None, None
+            dilation = None

         self.use_timm_backbone = use_timm_backbone
         self.backbone_config = backbone_config
@@ -279,17 +279,23 @@ def __init__(self, config):

         self.config = config

         # For backwards compatibility we have to use the timm library directly instead of the AutoBackbone API
         if config.use_timm_backbone:
+            # We default to values which were previously hard-coded. This enables configurability from the config
+            # using backbone arguments, while keeping the default behavior the same.
             requires_backends(self, ["timm"])
-            kwargs = {}
+            kwargs = getattr(config, "backbone_kwargs", {})
+            kwargs = {} if kwargs is None else kwargs.copy()
+            out_indices = kwargs.pop("out_indices", (1, 2, 3, 4))
+            num_channels = kwargs.pop("in_chans", config.num_channels)
             if config.dilation:
-                kwargs["output_stride"] = 16
+                kwargs["output_stride"] = kwargs.get("output_stride", 16)
             backbone = create_model(
                 config.backbone,
                 pretrained=config.use_pretrained_backbone,
                 features_only=True,
-                out_indices=(1, 2, 3, 4),
-                in_chans=config.num_channels,
+                out_indices=out_indices,
+                in_chans=num_channels,
                 **kwargs,
             )
         else:
@@ -63,12 +63,13 @@ def __init__(self, config, **kwargs):
         # We just take the final layer by default. This matches the default for the transformers models.
         out_indices = config.out_indices if getattr(config, "out_indices", None) is not None else (-1,)

+        in_chans = kwargs.pop("in_chans", config.num_channels)
         self._backbone = timm.create_model(
             config.backbone,
             pretrained=pretrained,
             # This is currently not possible for transformer architectures.
             features_only=config.features_only,
-            in_chans=config.num_channels,
+            in_chans=in_chans,
             out_indices=out_indices,
             **kwargs,
         )
@@ -79,7 +80,9 @@

         # These are used to control the output of the model when called. If output_hidden_states is True, then
         # return_layers is modified to include all layers.
-        self._return_layers = self._backbone.return_layers
+        self._return_layers = {
+            layer["module"]: str(layer["index"]) for layer in self._backbone.feature_info.get_dicts()
+        }
         self._all_layers = {layer["module"]: str(i) for i, layer in enumerate(self._backbone.feature_info.info)}
         super()._init_backbone(config)
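A sketch of why building _return_layers from feature_info is sturdier than reading the wrapper's private return_layers attribute. The exact dict keys are a timm detail, so this assumes a recent timm where FeatureInfo.get_dicts() entries carry "module" and "index":

    import timm

    # Hypothetical inspection with an arbitrary small model
    backbone = timm.create_model("resnet18", features_only=True, out_indices=(1, 2, 3))
    for info in backbone.feature_info.get_dicts():
        print(info["module"], info["index"])  # e.g. layer1 1, layer2 2, layer3 3
    # {module: str(index)} is the mapping the backbone wrapper needs, built from
    # timm's public feature_info rather than its internal return_layers state.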

10 changes: 10 additions & 0 deletions tests/models/conditional_detr/test_modeling_conditional_detr.py
@@ -444,7 +444,9 @@ def test_different_timm_backbone(self):

         # let's pick a random timm backbone
         config.backbone = "tf_mobilenetv3_small_075"
+        config.backbone_config = None
         config.use_timm_backbone = True
+        config.backbone_kwargs = {"out_indices": [2, 3, 4]}

         for model_class in self.all_model_classes:
             model = model_class(config)
@@ -460,6 +462,14 @@
                     self.model_tester.num_labels,
                 )
                 self.assertEqual(outputs.logits.shape, expected_shape)
+                # Confirm out_indices was propagated to backbone
+                self.assertEqual(len(model.model.backbone.conv_encoder.intermediate_channel_sizes), 3)
+            elif model_class.__name__ == "ConditionalDetrForSegmentation":
+                # Confirm out_indices was propagated to backbone
+                self.assertEqual(len(model.conditional_detr.model.backbone.conv_encoder.intermediate_channel_sizes), 3)
+            else:
+                # Confirm out_indices was propagated to backbone
+                self.assertEqual(len(model.backbone.conv_encoder.intermediate_channel_sizes), 3)

             self.assertTrue(outputs)
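The length assertion holds because timm's features_only wrapper exposes one feature map (and one feature_info entry) per requested index. A hypothetical standalone version of the same check:

    import timm

    # Hypothetical: mirrors what the test asserts via intermediate_channel_sizes
    backbone = timm.create_model("tf_mobilenetv3_small_075", features_only=True, out_indices=[2, 3, 4])
    assert len(backbone.feature_info.channels()) == 3  # one channel count per requested stage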

11 changes: 10 additions & 1 deletion tests/models/deformable_detr/test_modeling_deformable_detr.py
@@ -521,8 +521,9 @@ def test_different_timm_backbone(self):

         # let's pick a random timm backbone
         config.backbone = "tf_mobilenetv3_small_075"
-        config.use_timm_backbone = True
         config.backbone_config = None
+        config.use_timm_backbone = True
+        config.backbone_kwargs = {"out_indices": [1, 2, 3, 4]}

         for model_class in self.all_model_classes:
             model = model_class(config)
@@ -538,6 +539,14 @@
                     self.model_tester.num_labels,
                 )
                 self.assertEqual(outputs.logits.shape, expected_shape)
+                # Confirm out_indices was propagated to backbone
+                self.assertEqual(len(model.model.backbone.conv_encoder.intermediate_channel_sizes), 4)
+            elif model_class.__name__ == "ConditionalDetrForSegmentation":
+                # Confirm out_indices was propagated to backbone
+                self.assertEqual(len(model.deformable_detr.model.backbone.conv_encoder.intermediate_channel_sizes), 4)
+            else:
+                # Confirm out_indices was propagated to backbone
+                self.assertEqual(len(model.backbone.conv_encoder.intermediate_channel_sizes), 4)

             self.assertTrue(outputs)

