[DETR] Remove timm hardcoded logic in modeling files #29038
src/transformers/models/conditional_detr/modeling_conditional_detr.py

@@ -338,12 +338,12 @@ def replace_batch_norm(model):
         replace_batch_norm(module)


-# Copied from transformers.models.detr.modeling_detr.DetrConvEncoder
+# Copied from transformers.models.detr.modeling_detr.DetrConvEncoder with Detr->ConditionalDetr
 class ConditionalDetrConvEncoder(nn.Module):
     """
     Convolutional backbone, using either the AutoBackbone API or one from the timm library.

-    nn.BatchNorm2d layers are replaced by DetrFrozenBatchNorm2d as defined above.
+    nn.BatchNorm2d layers are replaced by ConditionalDetrFrozenBatchNorm2d as defined above.

     """

@@ -352,17 +352,23 @@ def __init__(self, config):
         self.config = config

+        # For backwards compatibility we have to use the timm library directly instead of the AutoBackbone API
         if config.use_timm_backbone:
+            # We default to values which were previously hard-coded. This enables configurability from the config
+            # using backbone arguments, while keeping the default behavior the same.
             requires_backends(self, ["timm"])
-            kwargs = {}
+            kwargs = getattr(config, "backbone_kwargs", {})
+            kwargs = {} if kwargs is None else kwargs.copy()
+            out_indices = kwargs.pop("out_indices", (1, 2, 3, 4))
+            num_channels = kwargs.pop("in_chans", config.num_channels)
             if config.dilation:
-                kwargs["output_stride"] = 16
+                kwargs["output_stride"] = kwargs.get("output_stride", 16)
             backbone = create_model(
                 config.backbone,
                 pretrained=config.use_pretrained_backbone,
                 features_only=True,
-                out_indices=(1, 2, 3, 4),
-                in_chans=config.num_channels,
+                out_indices=out_indices,
+                in_chans=num_channels,
                 **kwargs,
             )
         else:

Review comment (on the `if config.use_timm_backbone:` block): Unfortunately, we can't remove the timm logic here and use load_backbone instead. When using load_backbone, a timm model is loaded as a TimmBackbone class, so the loaded weight names differ from those produced by the create_model call here. For backwards compatibility - being able to load existing checkpoints - we need to leave this as-is. Instead, to be compatible with the backbone API and remove the hard-coding, we allow specifying the backbone behaviour through backbone_kwargs.

Reply: Ok, makes sense!

Review comment (on lines -357 to +363): I would maybe add a few comments here to explain what's happening for posterity.
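To make the backbone_kwargs mechanism discussed above concrete, here is a hedged usage sketch (not from the PR). It assumes `backbone_kwargs` is exposed on ConditionalDetrConfig (the modeling code falls back to an empty dict via getattr if it is not), and shows how the previously hard-coded timm arguments can now be overridden from the config.

```python
from transformers import ConditionalDetrConfig, ConditionalDetrModel

# Assumed sketch: backbone_kwargs overrides the defaults that used to be hard-coded
# (out_indices=(1, 2, 3, 4), in_chans=config.num_channels, output_stride=16 when dilated).
config = ConditionalDetrConfig(
    use_timm_backbone=True,
    backbone="resnet50",
    use_pretrained_backbone=False,  # skip downloading pretrained weights for this example
    backbone_kwargs={"out_indices": (2, 3, 4)},  # request three feature maps instead of four
)
model = ConditionalDetrModel(config)
```

Leaving `backbone_kwargs` unset keeps the old behavior, since the pop/get defaults in the diff replicate the previously hard-coded values.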
src/transformers/models/deformable_detr/modeling_deformable_detr.py

@@ -88,11 +88,31 @@ def load_cuda_kernels():
 if is_vision_available():
     from transformers.image_transforms import center_to_corners_format


 if is_accelerate_available():
     from accelerate import PartialState
     from accelerate.utils import reduce


+if is_timm_available():
+    from timm import create_model
+
+
+if is_scipy_available():
+    from scipy.optimize import linear_sum_assignment
+
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "DeformableDetrConfig"
+_CHECKPOINT_FOR_DOC = "sensetime/deformable-detr"
+
+DEFORMABLE_DETR_PRETRAINED_MODEL_ARCHIVE_LIST = [
+    "sensetime/deformable-detr",
+    # See all Deformable DETR models at https://huggingface.co/models?filter=deformable-detr
+]
+
+
 class MultiScaleDeformableAttentionFunction(Function):
     @staticmethod
     def forward(
@@ -141,21 +161,6 @@ def backward(context, grad_output):
         return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None


-if is_scipy_available():
-    from scipy.optimize import linear_sum_assignment
-
-if is_timm_available():
-    from timm import create_model
-
-logger = logging.get_logger(__name__)
-
-_CONFIG_FOR_DOC = "DeformableDetrConfig"
-_CHECKPOINT_FOR_DOC = "sensetime/deformable-detr"
-
-
-from ..deprecated._archive_maps import DEFORMABLE_DETR_PRETRAINED_MODEL_ARCHIVE_LIST  # noqa: F401, E402
-
-
 @dataclass
 class DeformableDetrDecoderOutput(ModelOutput):
     """

Review comment (on the removed import block): Moves these above the MultiScaleDeformableAttentionFunction definition - better matching library patterns.
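For context, a minimal sketch (not part of this diff) of the soft-dependency pattern the moved block follows: optional imports are guarded at module import time, and classes that need the backend call `requires_backends` when instantiated. The class and parameter names below are illustrative only.

```python
from transformers.utils import is_timm_available, requires_backends

# Guarded import: the module still imports cleanly when timm is not installed.
if is_timm_available():
    from timm import create_model


class TimmBackedEncoder:
    """Illustrative class; raises a clear error at construction if timm is missing."""

    def __init__(self, backbone_name: str):
        requires_backends(self, ["timm"])
        self.backbone = create_model(backbone_name, pretrained=False, features_only=True)
```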
@@ -420,17 +425,23 @@ def __init__(self, config):
         self.config = config

+        # For backwards compatibility we have to use the timm library directly instead of the AutoBackbone API
         if config.use_timm_backbone:
+            # We default to values which were previously hard-coded. This enables configurability from the config
+            # using backbone arguments, while keeping the default behavior the same.
             requires_backends(self, ["timm"])
-            kwargs = {}
+            kwargs = getattr(config, "backbone_kwargs", {})
+            kwargs = {} if kwargs is None else kwargs.copy()
+            out_indices = kwargs.pop("out_indices", (2, 3, 4) if config.num_feature_levels > 1 else (4,))
+            num_channels = kwargs.pop("in_chans", config.num_channels)
             if config.dilation:
-                kwargs["output_stride"] = 16
+                kwargs["output_stride"] = kwargs.get("output_stride", 16)
             backbone = create_model(
                 config.backbone,
                 pretrained=config.use_pretrained_backbone,
                 features_only=True,
-                out_indices=(2, 3, 4) if config.num_feature_levels > 1 else (4,),
-                in_chans=config.num_channels,
+                out_indices=out_indices,
+                in_chans=num_channels,
                 **kwargs,
             )
         else:

Review comment (on lines -425 to +438): Same here.

Review comment: These replicate the defaults that were used to load a timm backbone in the modeling file. This PR makes it possible to configure the loaded timm backbone using the standard backbone API; the defaults here are for backwards compatibility.
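As a worked illustration of that comment (a hedged sketch, not code from the PR), the standalone helper below mirrors the kwargs handling in the Deformable DETR diff: with no backbone_kwargs the old hard-coded values are reproduced, while any user-supplied keys take precedence. The helper name is hypothetical.

```python
def resolve_timm_backbone_kwargs(backbone_kwargs, num_channels, dilation, num_feature_levels):
    """Illustrative helper mirroring the diff's kwargs handling (not part of transformers)."""
    kwargs = {} if backbone_kwargs is None else dict(backbone_kwargs)
    # Defaults replicate the previously hard-coded behavior.
    out_indices = kwargs.pop("out_indices", (2, 3, 4) if num_feature_levels > 1 else (4,))
    in_chans = kwargs.pop("in_chans", num_channels)
    if dilation:
        # Only apply the old default stride if the user did not set one explicitly.
        kwargs["output_stride"] = kwargs.get("output_stride", 16)
    return out_indices, in_chans, kwargs


# With no backbone_kwargs, the old defaults come back unchanged:
assert resolve_timm_backbone_kwargs(None, 3, False, 4) == ((2, 3, 4), 3, {})
# A user-supplied output_stride wins over the dilation default:
assert resolve_timm_backbone_kwargs({"output_stride": 8}, 3, True, 4) == ((2, 3, 4), 3, {"output_stride": 8})
```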