
[DETR] Remove timm hardcoded logic in modeling files #29038

Merged

Changes from all commits (32 commits)
5e8d40c
Enable instantiating model with pretrained backbone weights
amyeroberts Dec 22, 2023
11ede69
Clarify pretrained import
amyeroberts Jan 3, 2024
e1067a4
Use load_backbone instead
amyeroberts Jan 4, 2024
d19fc39
Add backbone_kwargs to config
amyeroberts Jan 4, 2024
59ba869
Fix up
amyeroberts Jan 31, 2024
899c9bd
Add tests
amyeroberts Feb 1, 2024
6fc904d
Tidy up
amyeroberts Feb 1, 2024
ad201dd
Enable instantiating model with pretrained backbone weights
amyeroberts Dec 22, 2023
589007d
Update tests so backbone checkpoint isn't passed in
amyeroberts Jan 2, 2024
e593007
Clarify pretrained import
amyeroberts Jan 3, 2024
4a601dc
Update configs - docs and validation check
amyeroberts Jan 4, 2024
4506d08
Update src/transformers/utils/backbone_utils.py
amyeroberts Jan 4, 2024
cbc0f6a
Clarify exception message
amyeroberts Jan 4, 2024
89804ae
Update config init in tests
amyeroberts Jan 4, 2024
6300aef
Add test for when use_timm_backbone=True
amyeroberts Jan 4, 2024
6870317
Use load_backbone instead
amyeroberts Jan 4, 2024
aa3376c
Add use_timm_backbone to the model configs
amyeroberts Jan 4, 2024
737fd0c
Add backbone_kwargs to config
amyeroberts Jan 4, 2024
62d79f7
Pass kwargs to constructors
amyeroberts Jan 4, 2024
19fd92d
Draft
amyeroberts Jan 4, 2024
f9cbc01
Fix tests
amyeroberts Feb 19, 2024
cffd51f
Add back timm - weight naming
amyeroberts Mar 7, 2024
5bf5329
More tidying up
amyeroberts Mar 8, 2024
7d4e93a
Whoops
amyeroberts Mar 8, 2024
ac56450
Tidy up
amyeroberts Mar 8, 2024
1c55822
Handle when kwargs are none
amyeroberts Mar 8, 2024
30d3232
Update tests
amyeroberts Mar 8, 2024
8183436
Revert test changes
amyeroberts Mar 8, 2024
61fe673
Deformable detr test - don't use default
amyeroberts Mar 8, 2024
94309d9
Don't mutate; correct model attributes
amyeroberts Mar 12, 2024
80b32cc
Add some clarifying comments
amyeroberts Apr 26, 2024
5a9799c
nit - grammar is hard
amyeroberts Apr 26, 2024
src/transformers/models/conditional_detr/configuration_conditional_detr.py

@@ -192,10 +192,16 @@ def __init__(
if backbone_config is not None and use_timm_backbone:
raise ValueError("You can't specify both `backbone_config` and `use_timm_backbone`.")

if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None:
raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")

if not use_timm_backbone:
# We default to values which were previously hard-coded in the model. This enables configurability of the config
# while keeping the default behavior the same.
if use_timm_backbone and backbone_kwargs is None:
amyeroberts (Collaborator, Author) commented:

These replicate the defaults used to load a timm backbone in the modeling file. This PR makes it possible to configure the loaded timm backbone through the standard backbone API; the defaults here are kept for backwards compatibility.
backbone_kwargs = {}
if dilation:
backbone_kwargs["output_stride"] = 16
backbone_kwargs["out_indices"] = [1, 2, 3, 4]
backbone_kwargs["in_chans"] = num_channels
# Backwards compatibility
elif not use_timm_backbone and backbone in (None, "resnet50"):
if backbone_config is None:
logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage4"])
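A minimal standalone sketch of the defaulting logic added above (illustrative only: the helper name is invented, the behaviour mirrors the diff):

```python
def default_backbone_kwargs(use_timm_backbone, backbone_kwargs, dilation, num_channels):
    # Replicate the timm defaults that used to be hard-coded in the modeling file.
    if use_timm_backbone and backbone_kwargs is None:
        backbone_kwargs = {}
        if dilation:
            backbone_kwargs["output_stride"] = 16
        backbone_kwargs["out_indices"] = [1, 2, 3, 4]
        backbone_kwargs["in_chans"] = num_channels
    return backbone_kwargs

# Defaults are only filled in when nothing was specified...
assert default_backbone_kwargs(True, None, True, 3) == {
    "output_stride": 16,
    "out_indices": [1, 2, 3, 4],
    "in_chans": 3,
}
# ...while user-supplied kwargs pass through untouched.
assert default_backbone_kwargs(True, {"out_indices": [2, 3, 4]}, True, 3) == {"out_indices": [2, 3, 4]}
```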
src/transformers/models/conditional_detr/modeling_conditional_detr.py

@@ -338,12 +338,12 @@ def replace_batch_norm(model):
replace_batch_norm(module)


# Copied from transformers.models.detr.modeling_detr.DetrConvEncoder
# Copied from transformers.models.detr.modeling_detr.DetrConvEncoder with Detr->ConditionalDetr
class ConditionalDetrConvEncoder(nn.Module):
"""
Convolutional backbone, using either the AutoBackbone API or one from the timm library.

nn.BatchNorm2d layers are replaced by DetrFrozenBatchNorm2d as defined above.
nn.BatchNorm2d layers are replaced by ConditionalDetrFrozenBatchNorm2d as defined above.

"""

@@ -352,17 +352,23 @@ def __init__(self, config):

self.config = config

# For backwards compatibility we have to use the timm library directly instead of the AutoBackbone API
if config.use_timm_backbone:
amyeroberts (Collaborator, Author) commented:

Unfortunately, we can't remove the timm logic here and use load_backbone instead. When using load_backbone, a timm model is loaded as a TimmBackbone class, so the loaded weight names differ from those produced by the create_model call here. For backwards compatibility (being able to load existing checkpoints) we need to leave this as-is.

Instead, to be compatible with the backbone API and remove the hard-coding, we allow specifying the backbone behaviour through backbone_kwargs.

Member replied:

Ok, makes sense!
# We default to values which were previously hard-coded. This enables configurability from the config
# using backbone arguments, while keeping the default behavior the same.
requires_backends(self, ["timm"])
kwargs = {}
kwargs = getattr(config, "backbone_kwargs", {})
kwargs = {} if kwargs is None else kwargs.copy()
out_indices = kwargs.pop("out_indices", (1, 2, 3, 4))
num_channels = kwargs.pop("in_chans", config.num_channels)
Comment on lines -357 to +363

Member commented:

I would maybe add a few comments here to explain what's happening, for posterity.
if config.dilation:
kwargs["output_stride"] = 16
kwargs["output_stride"] = kwargs.get("output_stride", 16)
backbone = create_model(
config.backbone,
pretrained=config.use_pretrained_backbone,
features_only=True,
out_indices=(1, 2, 3, 4),
in_chans=config.num_channels,
out_indices=out_indices,
in_chans=num_channels,
**kwargs,
)
else:
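A hedged sketch of the weight-naming issue described in the author's comment above (the key names are examples, not taken from the diff):

```python
import timm

# Loading directly through timm keeps the original parameter names.
backbone = timm.create_model("resnet50", features_only=True, pretrained=False)
print(next(iter(backbone.state_dict())))  # e.g. "conv1.weight"

# Loading the same network through load_backbone wraps it in a TimmBackbone
# module that stores the model under `self._backbone` (see the timm_backbone
# diff below), so the same parameter would live under a prefixed key such as
# "_backbone.conv1.weight". Existing checkpoints saved with one layout cannot
# be loaded with the other without remapping, hence create_model stays.
```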
src/transformers/models/deformable_detr/configuration_deformable_detr.py

@@ -212,14 +212,24 @@ def __init__(
if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None:
raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")

if not use_timm_backbone:
# We default to values which were previously hard-coded in the model. This enables configurability of the config
# while keeping the default behavior the same.
if use_timm_backbone and backbone_kwargs is None:
backbone_kwargs = {}
if dilation:
backbone_kwargs["output_stride"] = 16
backbone_kwargs["out_indices"] = [2, 3, 4] if num_feature_levels > 1 else [4]
backbone_kwargs["in_chans"] = num_channels
# Backwards compatibility
elif not use_timm_backbone and backbone in (None, "resnet50"):
if backbone_config is None:
logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage4"])
elif isinstance(backbone_config, dict):
backbone_model_type = backbone_config.get("model_type")
config_class = CONFIG_MAPPING[backbone_model_type]
backbone_config = config_class.from_dict(backbone_config)

self.use_timm_backbone = use_timm_backbone
self.backbone_config = backbone_config
self.num_channels = num_channels
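The Deformable DETR defaulting differs from the other models in one respect: out_indices depends on num_feature_levels. A standalone sketch mirroring the diff above:

```python
def default_out_indices(num_feature_levels):
    # Multi-scale deformable attention consumes features from stages 2-4;
    # with a single feature level only the final stage is returned.
    return [2, 3, 4] if num_feature_levels > 1 else [4]

assert default_out_indices(4) == [2, 3, 4]
assert default_out_indices(1) == [4]
```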
49 changes: 30 additions & 19 deletions src/transformers/models/deformable_detr/modeling_deformable_detr.py
@@ -88,11 +88,31 @@ def load_cuda_kernels():
if is_vision_available():
from transformers.image_transforms import center_to_corners_format


if is_accelerate_available():
from accelerate import PartialState
from accelerate.utils import reduce


if is_timm_available():
from timm import create_model


if is_scipy_available():
from scipy.optimize import linear_sum_assignment


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "DeformableDetrConfig"
_CHECKPOINT_FOR_DOC = "sensetime/deformable-detr"

DEFORMABLE_DETR_PRETRAINED_MODEL_ARCHIVE_LIST = [
"sensetime/deformable-detr",
# See all Deformable DETR models at https://huggingface.co/models?filter=deformable-detr
]


class MultiScaleDeformableAttentionFunction(Function):
@staticmethod
def forward(
@@ -141,21 +161,6 @@ def backward(context, grad_output):
return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None


if is_scipy_available():
amyeroberts (Collaborator, Author) commented:

Moves these above the MultiScaleDeformableAttentionFunction definition, better matching library patterns.
from scipy.optimize import linear_sum_assignment

if is_timm_available():
from timm import create_model

logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "DeformableDetrConfig"
_CHECKPOINT_FOR_DOC = "sensetime/deformable-detr"


from ..deprecated._archive_maps import DEFORMABLE_DETR_PRETRAINED_MODEL_ARCHIVE_LIST # noqa: F401, E402


@dataclass
class DeformableDetrDecoderOutput(ModelOutput):
"""
@@ -420,17 +425,23 @@ def __init__(self, config):

self.config = config

# For backwards compatibility we have to use the timm library directly instead of the AutoBackbone API
if config.use_timm_backbone:
# We default to values which were previously hard-coded. This enables configurability from the config
# using backbone arguments, while keeping the default behavior the same.
requires_backends(self, ["timm"])
kwargs = {}
kwargs = getattr(config, "backbone_kwargs", {})
kwargs = {} if kwargs is None else kwargs.copy()
out_indices = kwargs.pop("out_indices", (2, 3, 4) if config.num_feature_levels > 1 else (4,))
num_channels = kwargs.pop("in_chans", config.num_channels)
if config.dilation:
kwargs["output_stride"] = 16
kwargs["output_stride"] = kwargs.get("output_stride", 16)
Comment on lines -425 to +438

Member commented:

Same here.

backbone = create_model(
config.backbone,
pretrained=config.use_pretrained_backbone,
features_only=True,
out_indices=(2, 3, 4) if config.num_feature_levels > 1 else (4,),
in_chans=config.num_channels,
out_indices=out_indices,
in_chans=num_channels,
**kwargs,
)
else:
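A small sketch of why the encoder copies the kwargs before popping, per the "Don't mutate" commit (values are illustrative):

```python
config_backbone_kwargs = {"out_indices": (2, 3, 4), "in_chans": 3}

kwargs = config_backbone_kwargs.copy()
out_indices = kwargs.pop("out_indices", (2, 3, 4))
num_channels = kwargs.pop("in_chans", 3)

# The config's dict is untouched, so instantiating a second model from the
# same config still sees the full backbone_kwargs.
assert config_backbone_kwargs == {"out_indices": (2, 3, 4), "in_chans": 3}
```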
14 changes: 12 additions & 2 deletions src/transformers/models/detr/configuration_detr.py
@@ -193,16 +193,26 @@ def __init__(
if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None:
raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")

if not use_timm_backbone:
# We default to values which were previously hard-coded in the model. This enables configurability of the config
# while keeping the default behavior the same.
if use_timm_backbone and backbone_kwargs is None:
backbone_kwargs = {}
if dilation:
backbone_kwargs["output_stride"] = 16
backbone_kwargs["out_indices"] = [1, 2, 3, 4]
backbone_kwargs["in_chans"] = num_channels
# Backwards compatibility
elif not use_timm_backbone and backbone in (None, "resnet50"):
if backbone_config is None:
logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage4"])
elif isinstance(backbone_config, dict):
backbone_model_type = backbone_config.get("model_type")
config_class = CONFIG_MAPPING[backbone_model_type]
backbone_config = config_class.from_dict(backbone_config)
backbone = None
# set timm attributes to None
dilation, backbone, use_pretrained_backbone = None, None, None
dilation = None

self.use_timm_backbone = use_timm_backbone
self.backbone_config = backbone_config
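A standalone sketch of the attribute-reset change above (the function name is invented; the logic mirrors the diff):

```python
def resolve_timm_attrs(use_timm_backbone, backbone, dilation, use_pretrained_backbone):
    if not use_timm_backbone and backbone in (None, "resnet50"):
        backbone = None
        # Before: dilation, backbone, use_pretrained_backbone = None, None, None
        # After: only the timm-specific `dilation` is cleared.
        dilation = None
    return backbone, dilation, use_pretrained_backbone

backbone, dilation, pretrained = resolve_timm_attrs(False, None, True, True)
assert backbone is None and dilation is None
assert pretrained is True  # now preserved for the backbone API
```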
16 changes: 12 additions & 4 deletions src/transformers/models/detr/modeling_detr.py
@@ -49,9 +49,11 @@
if is_scipy_available():
from scipy.optimize import linear_sum_assignment


if is_timm_available():
from timm import create_model


if is_vision_available():
from transformers.image_transforms import center_to_corners_format

@@ -345,17 +347,23 @@ def __init__(self, config):

self.config = config

# For backwards compatibility we have to use the timm library directly instead of the AutoBackbone API
if config.use_timm_backbone:
# We default to values which were previously hard-coded. This enables configurability from the config
# using backbone arguments, while keeping the default behavior the same.
requires_backends(self, ["timm"])
kwargs = {}
kwargs = getattr(config, "backbone_kwargs", {})
kwargs = {} if kwargs is None else kwargs.copy()
out_indices = kwargs.pop("out_indices", (1, 2, 3, 4))
num_channels = kwargs.pop("in_chans", config.num_channels)
if config.dilation:
kwargs["output_stride"] = 16
kwargs["output_stride"] = kwargs.get("output_stride", 16)
backbone = create_model(
config.backbone,
pretrained=config.use_pretrained_backbone,
features_only=True,
out_indices=(1, 2, 3, 4),
in_chans=config.num_channels,
out_indices=out_indices,
in_chans=num_channels,
**kwargs,
)
else:
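The switch from a plain assignment to kwargs.get keeps a user-supplied output_stride while preserving the old default, as this minimal sketch shows:

```python
kwargs = {"output_stride": 8}  # user override via backbone_kwargs
kwargs["output_stride"] = kwargs.get("output_stride", 16)
assert kwargs["output_stride"] == 8  # the override wins

kwargs = {}  # nothing specified
kwargs["output_stride"] = kwargs.get("output_stride", 16)
assert kwargs["output_stride"] == 16  # the previously hard-coded default
```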
6 changes: 3 additions & 3 deletions src/transformers/models/dpt/modeling_dpt.py
@@ -1075,10 +1075,10 @@ def __init__(self, config):
super().__init__(config)

self.backbone = None
if config.backbone_config is not None and config.is_hybrid is False:
self.backbone = load_backbone(config)
else:
if config.is_hybrid or config.backbone_config is None:
self.dpt = DPTModel(config, add_pooling_layer=False)
else:
self.backbone = load_backbone(config)

# Neck
self.neck = DPTNeck(config)
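The DPT change only inverts the branch order; a quick standalone check of the equivalence (stub functions, not the real classes):

```python
def old_branch(backbone_config, is_hybrid):
    return "backbone" if backbone_config is not None and is_hybrid is False else "dpt"

def new_branch(backbone_config, is_hybrid):
    return "dpt" if is_hybrid or backbone_config is None else "backbone"

# Both orderings pick the same path for every combination.
for backbone_config in (None, object()):
    for is_hybrid in (False, True):
        assert old_branch(backbone_config, is_hybrid) == new_branch(backbone_config, is_hybrid)
```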
src/transformers/models/table_transformer/configuration_table_transformer.py

@@ -193,16 +193,26 @@ def __init__(
if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None:
raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")

if not use_timm_backbone:
# We default to values which were previously hard-coded in the model. This enables configurability of the config
# while keeping the default behavior the same.
if use_timm_backbone and backbone_kwargs is None:
backbone_kwargs = {}
if dilation:
backbone_kwargs["output_stride"] = 16
backbone_kwargs["out_indices"] = [1, 2, 3, 4]
backbone_kwargs["in_chans"] = num_channels
# Backwards compatibility
elif not use_timm_backbone and backbone in (None, "resnet50"):
if backbone_config is None:
logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage4"])
elif isinstance(backbone_config, dict):
backbone_model_type = backbone_config.get("model_type")
config_class = CONFIG_MAPPING[backbone_model_type]
backbone_config = config_class.from_dict(backbone_config)
backbone = None
# set timm attributes to None
dilation, backbone, use_pretrained_backbone = None, None, None
dilation = None

self.use_timm_backbone = use_timm_backbone
self.backbone_config = backbone_config
src/transformers/models/table_transformer/modeling_table_transformer.py

@@ -279,17 +279,23 @@ def __init__(self, config):

self.config = config

# For backwards compatibility we have to use the timm library directly instead of the AutoBackbone API
if config.use_timm_backbone:
# We default to values which were previously hard-coded. This enables configurability from the config
# using backbone arguments, while keeping the default behavior the same.
requires_backends(self, ["timm"])
kwargs = {}
kwargs = getattr(config, "backbone_kwargs", {})
kwargs = {} if kwargs is None else kwargs.copy()
out_indices = kwargs.pop("out_indices", (1, 2, 3, 4))
num_channels = kwargs.pop("in_chans", config.num_channels)
if config.dilation:
kwargs["output_stride"] = 16
kwargs["output_stride"] = kwargs.get("output_stride", 16)
backbone = create_model(
config.backbone,
pretrained=config.use_pretrained_backbone,
features_only=True,
out_indices=(1, 2, 3, 4),
in_chans=config.num_channels,
out_indices=out_indices,
in_chans=num_channels,
**kwargs,
)
else:
src/transformers/models/timm_backbone/modeling_timm_backbone.py

@@ -63,12 +63,13 @@ def __init__(self, config, **kwargs):
# We just take the final layer by default. This matches the default for the transformers models.
out_indices = config.out_indices if getattr(config, "out_indices", None) is not None else (-1,)

in_chans = kwargs.pop("in_chans", config.num_channels)
self._backbone = timm.create_model(
config.backbone,
pretrained=pretrained,
# This is currently not possible for transformer architectures.
features_only=config.features_only,
in_chans=config.num_channels,
in_chans=in_chans,
out_indices=out_indices,
**kwargs,
)
@@ -79,7 +80,9 @@

# These are used to control the output of the model when called. If output_hidden_states is True, then
# return_layers is modified to include all layers.
self._return_layers = self._backbone.return_layers
self._return_layers = {
layer["module"]: str(layer["index"]) for layer in self._backbone.feature_info.get_dicts()
}
self._all_layers = {layer["module"]: str(i) for i, layer in enumerate(self._backbone.feature_info.info)}
super()._init_backbone(config)

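A sketch of the shape of the rebuilt _return_layers mapping (the entries below are an illustrative stand-in for self._backbone.feature_info.get_dicts(); real entries also carry keys like "num_chs" and "reduction"):

```python
feature_dicts = [
    {"module": "layer1", "index": 1},
    {"module": "layer2", "index": 2},
    {"module": "layer3", "index": 3},
]

# Remap to the {module_name: index_str} form used as return_layers.
return_layers = {layer["module"]: str(layer["index"]) for layer in feature_dicts}
assert return_layers == {"layer1": "1", "layer2": "2", "layer3": "3"}
```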
10 changes: 10 additions & 0 deletions tests/models/conditional_detr/test_modeling_conditional_detr.py
@@ -444,7 +444,9 @@ def test_different_timm_backbone(self):

# let's pick a random timm backbone
config.backbone = "tf_mobilenetv3_small_075"
config.backbone_config = None
config.use_timm_backbone = True
config.backbone_kwargs = {"out_indices": [2, 3, 4]}

for model_class in self.all_model_classes:
model = model_class(config)
@@ -460,6 +462,14 @@
self.model_tester.num_labels,
)
self.assertEqual(outputs.logits.shape, expected_shape)
# Confirm out_indices was propagated to backbone
self.assertEqual(len(model.model.backbone.conv_encoder.intermediate_channel_sizes), 3)
elif model_class.__name__ == "ConditionalDetrForSegmentation":
# Confirm out_indices was propagated to backbone
self.assertEqual(len(model.conditional_detr.model.backbone.conv_encoder.intermediate_channel_sizes), 3)
else:
# Confirm out_indices was propagated to backbone
self.assertEqual(len(model.backbone.conv_encoder.intermediate_channel_sizes), 3)

self.assertTrue(outputs)

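The new assertions count one intermediate channel size per requested index; a minimal sketch of the arithmetic behind the expected values (3 here, 4 in the deformable test below):

```python
conditional_detr_out_indices = [2, 3, 4]    # set in the test above
deformable_detr_out_indices = [1, 2, 3, 4]  # set in the test below

# One feature map, and hence one intermediate channel size, per index.
assert len(conditional_detr_out_indices) == 3
assert len(deformable_detr_out_indices) == 4
```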
11 changes: 10 additions & 1 deletion tests/models/deformable_detr/test_modeling_deformable_detr.py
@@ -521,8 +521,9 @@ def test_different_timm_backbone(self):

# let's pick a random timm backbone
config.backbone = "tf_mobilenetv3_small_075"
config.use_timm_backbone = True
config.backbone_config = None
config.use_timm_backbone = True
config.backbone_kwargs = {"out_indices": [1, 2, 3, 4]}

for model_class in self.all_model_classes:
model = model_class(config)
@@ -538,6 +539,14 @@
self.model_tester.num_labels,
)
self.assertEqual(outputs.logits.shape, expected_shape)
# Confirm out_indices was propagated to backbone
self.assertEqual(len(model.model.backbone.conv_encoder.intermediate_channel_sizes), 4)
elif model_class.__name__ == "ConditionalDetrForSegmentation":
# Confirm out_indices was propagated to backbone
self.assertEqual(len(model.deformable_detr.model.backbone.conv_encoder.intermediate_channel_sizes), 4)
else:
# Confirm out_indices was propagated to backbone
self.assertEqual(len(model.backbone.conv_encoder.intermediate_channel_sizes), 4)

self.assertTrue(outputs)

Expand Down