
Commit bc8d155: Draft
amyeroberts committed Feb 15, 2024
1 parent 4422c82 commit bc8d155
Showing 23 changed files with 230 additions and 145 deletions.
src/transformers/models/conditional_detr/configuration_conditional_detr.py

@@ -99,7 +99,8 @@ class ConditionalDetrConfig(PretrainedConfig):
use_pretrained_backbone (`bool`, *optional*, defaults to `True`):
Whether to use pretrained weights for the backbone.
backbone_kwargs (`dict`, *optional*):
Keyword arguments to be passed to the backbone constructor e.g. `{'out_indices': (0, 1, 2, 3)}`.
Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
dilation (`bool`, *optional*, defaults to `False`):
Whether to replace stride with dilation in the last convolutional block (DC5). Only supported when
`use_timm_backbone` = `True`.
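The reworded docstring pins down the contract: `backbone_kwargs` is forwarded when the backbone is loaded from a checkpoint, and it cannot be combined with an explicit `backbone_config`. A minimal sketch of the two configurations (the `microsoft/resnet-50` checkpoint name is only an illustration, and exact validation can vary between versions):

```python
from transformers import ConditionalDetrConfig, ResNetConfig

# Backbone described by a checkpoint name: extra loading arguments go through
# `backbone_kwargs`.
config_a = ConditionalDetrConfig(
    use_timm_backbone=False,
    backbone="microsoft/resnet-50",
    backbone_kwargs={"out_indices": (1, 2, 3, 4)},
)

# Backbone described by an explicit config: `backbone_kwargs` must stay unset.
config_b = ConditionalDetrConfig(
    use_timm_backbone=False,
    backbone_config=ResNetConfig(out_features=["stage4"]),
)

# Setting both raises: "You can't specify both `backbone_kwargs` and `backbone_config`."
```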
@@ -194,10 +195,14 @@ def __init__(
if backbone_config is not None and use_timm_backbone:
raise ValueError("You can't specify both `backbone_config` and `use_timm_backbone`.")

if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None:
raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")

if not use_timm_backbone:
if use_timm_backbone and backbone_kwargs is None:
backbone_kwargs = {}
if dilation:
backbone_kwargs["output_stride"] = 16
backbone_kwargs["out_indices"] = [1, 2, 3, 4]
backbone_kwargs["in_chans"] = num_channels
# Backwards compatibility
elif not use_timm_backbone and backbone in (None, "resnet50"):
if backbone_config is None:
logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage4"])
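With this change the timm-specific defaults that used to live in the model are materialised on the config at construction time. A sketch of what the constructor in this diff would populate (the expected values are read off the branch above, so treat them as an assumption until the commit lands):

```python
from transformers import ConditionalDetrConfig

# On this branch, a timm backbone with dilation enabled should end up with the
# loading arguments pre-filled on the config itself.
config = ConditionalDetrConfig(
    use_timm_backbone=True,
    backbone="resnet50",
    dilation=True,
    num_channels=3,
)

# Expected per the __init__ logic above:
# config.backbone_kwargs == {"output_stride": 16, "out_indices": [1, 2, 3, 4], "in_chans": 3}
```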
@@ -208,7 +213,7 @@ def __init__(

self.use_timm_backbone = use_timm_backbone
self.backbone_config = backbone_config
self.num_channels = num_channels
self._num_channels = num_channels
self.num_queries = num_queries
self.d_model = d_model
self.encoder_ffn_dim = encoder_ffn_dim
@@ -230,8 +235,8 @@ def __init__(
self.position_embedding_type = position_embedding_type
self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone
self.backbone_kwargs = backbone_kwargs if backbone_kwargs is not None else {}
self.dilation = dilation
self.backbone_kwargs = backbone_kwargs
self._dilation = dilation
# Hungarian matcher
self.class_cost = class_cost
self.bbox_cost = bbox_cost
@@ -245,6 +250,16 @@ def __init__(
self.focal_alpha = focal_alpha
super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)

@property
def num_channels(self):
logger.warning("The `num_channels` attribute is deprecated and will be removed in v4.40")
return self._num_channels

@property
def dilation(self):
logger.warning("The `dilation` attribute is deprecated and will be removed in v4.40")
return self._dilation

@property
def num_attention_heads(self) -> int:
return self.encoder_attention_heads
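The two new properties keep `num_channels` and `dilation` readable while steering callers away from them: reads go through the underscored attributes and log a deprecation message. A small usage sketch, assuming the branch above (where no setters are defined):

```python
from transformers import ConditionalDetrConfig

config = ConditionalDetrConfig()

# On this branch these reads return the stored values but also log
# "The `num_channels` attribute is deprecated and will be removed in v4.40".
channels = config.num_channels   # -> 3 by default
dilated = config.dilation        # -> False by default

# There is no setter on the branch, so assigning to either property would fail;
# pass the values to the constructor instead.
```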
@@ -253,6 +268,12 @@ def num_attention_heads(self) -> int:
def hidden_size(self) -> int:
return self.d_model

def to_dict(self):
output = super().to_dict()
output.pop("_num_channels", None)
output.pop("_dilation", None)
return output


class ConditionalDetrOnnxConfig(OnnxConfig):
torch_onnx_minimum_version = version.parse("1.11")
src/transformers/models/conditional_detr/modeling_conditional_detr.py

@@ -31,7 +31,6 @@
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_scipy_available,
is_timm_available,
is_vision_available,
logging,
replace_return_docstrings,
@@ -44,9 +43,6 @@
if is_scipy_available():
from scipy.optimize import linear_sum_assignment

if is_timm_available():
from timm import create_model

if is_vision_available():
from ...image_transforms import center_to_corners_format

@@ -348,30 +344,13 @@ def __init__(self, config):
super().__init__()

self.config = config

if config.use_timm_backbone:
requires_backends(self, ["timm"])
kwargs = {}
if config.dilation:
kwargs["output_stride"] = 16
backbone = create_model(
config.backbone,
pretrained=config.use_pretrained_backbone,
features_only=True,
out_indices=(1, 2, 3, 4),
in_chans=config.num_channels,
**kwargs,
)
else:
backbone = load_backbone(config)
backbone = load_backbone(config)

# replace batch norm by frozen batch norm
with torch.no_grad():
replace_batch_norm(backbone)
self.model = backbone
self.intermediate_channel_sizes = (
self.model.feature_info.channels() if config.use_timm_backbone else self.model.channels
)
self.intermediate_channel_sizes = self.model.channels

backbone_model_type = config.backbone if config.use_timm_backbone else config.backbone_config.model_type
if "resnet" in backbone_model_type:
Expand All @@ -385,7 +364,7 @@ def __init__(self, config):

def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor):
# send pixel_values through the model to get list of feature maps
features = self.model(pixel_values) if self.config.use_timm_backbone else self.model(pixel_values).feature_maps
features = self.model(pixel_values).feature_maps

out = []
for feature_map in features:
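After the refactor the conv encoder no longer branches on `config.use_timm_backbone`; whatever `load_backbone` returns is expected to expose the standard backbone surface, namely a `channels` attribute and a forward pass that returns `feature_maps`. A quick illustration of that surface with a Transformers backbone (this downloads weights; `microsoft/resnet-50` is just an example checkpoint):

```python
import torch
from transformers import AutoBackbone

backbone = AutoBackbone.from_pretrained("microsoft/resnet-50", out_indices=(1, 2, 3, 4))

# Channel dimension of each requested stage, which is what the simplified
# `intermediate_channel_sizes` now reads from `self.model.channels`.
print(backbone.channels)

pixel_values = torch.randn(1, 3, 224, 224)
feature_maps = backbone(pixel_values).feature_maps
for fmap in feature_maps:
    print(fmap.shape)
```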
src/transformers/models/deformable_detr/configuration_deformable_detr.py

@@ -91,7 +91,8 @@ class DeformableDetrConfig(PretrainedConfig):
use_pretrained_backbone (`bool`, *optional*, defaults to `True`):
Whether to use pretrained weights for the backbone.
backbone_kwargs (`dict`, *optional*):
Keyword arguments to be passed to the backbone constructor e.g. `{'out_indices': (0, 1, 2, 3)}`.
Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
dilation (`bool`, *optional*, defaults to `False`):
Whether to replace stride with dilation in the last convolutional block (DC5). Only supported when
`use_timm_backbone` = `True`.
@@ -213,14 +214,22 @@ def __init__(
if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None:
raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")

if not use_timm_backbone:
if use_timm_backbone and backbone_kwargs is None:
backbone_kwargs = {}
if dilation:
backbone_kwargs["output_stride"] = 16
backbone_kwargs["out_indices"] = [2, 3, 4] if num_feature_levels > 1 else [4]
backbone_kwargs["in_chans"] = num_channels
# Backwards compatibility
elif not use_timm_backbone and backbone in (None, "resnet50"):
if backbone_config is None:
logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage4"])
elif isinstance(backbone_config, dict):
backbone_model_type = backbone_config.get("model_type")
config_class = CONFIG_MAPPING[backbone_model_type]
backbone_config = config_class.from_dict(backbone_config)

self.use_timm_backbone = use_timm_backbone
self.backbone_config = backbone_config
self.num_channels = num_channels
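For Deformable DETR the defaults depend on `num_feature_levels`: multi-scale models take the last three stages, single-scale models only the last one. A sketch of the two cases as implemented on this branch (an expectation of the new constructor, not released behaviour):

```python
from transformers import DeformableDetrConfig

# Multi-scale (the default, num_feature_levels=4): the last three stages are requested.
multi_scale = DeformableDetrConfig(use_timm_backbone=True, backbone="resnet50", num_feature_levels=4)
# expected on this branch: multi_scale.backbone_kwargs["out_indices"] == [2, 3, 4]

# Single-scale: only the final stage is requested.
single_scale = DeformableDetrConfig(use_timm_backbone=True, backbone="resnet50", num_feature_levels=1)
# expected on this branch: single_scale.backbone_kwargs["out_indices"] == [4]
```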
@@ -244,7 +253,7 @@ def __init__(
self.position_embedding_type = position_embedding_type
self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone
self.backbone_kwargs = backbone_kwargs if backbone_kwargs is not None else {}
self.backbone_kwargs = backbone_kwargs
self.dilation = dilation
# deformable attributes
self.num_feature_levels = num_feature_levels
src/transformers/models/deformable_detr/modeling_deformable_detr.py

@@ -33,7 +33,6 @@
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_scipy_available,
is_timm_available,
is_torch_cuda_available,
is_vision_available,
replace_return_docstrings,
@@ -117,8 +116,6 @@ def backward(context, grad_output):
if is_scipy_available():
from scipy.optimize import linear_sum_assignment

if is_timm_available():
from timm import create_model

logger = logging.get_logger(__name__)

@@ -394,30 +391,13 @@ def __init__(self, config):
super().__init__()

self.config = config

if config.use_timm_backbone:
requires_backends(self, ["timm"])
kwargs = {}
if config.dilation:
kwargs["output_stride"] = 16
backbone = create_model(
config.backbone,
pretrained=config.use_pretrained_backbone,
features_only=True,
out_indices=(2, 3, 4) if config.num_feature_levels > 1 else (4,),
in_chans=config.num_channels,
**kwargs,
)
else:
backbone = load_backbone(config)
backbone = load_backbone(config)

# replace batch norm by frozen batch norm
with torch.no_grad():
replace_batch_norm(backbone)
self.model = backbone
self.intermediate_channel_sizes = (
self.model.feature_info.channels() if config.use_timm_backbone else self.model.channels
)
self.intermediate_channel_sizes = self.model.channels

backbone_model_type = config.backbone if config.use_timm_backbone else config.backbone_config.model_type
if "resnet" in backbone_model_type:
Expand All @@ -432,7 +412,7 @@ def __init__(self, config):
# Copied from transformers.models.detr.modeling_detr.DetrConvEncoder.forward with Detr->DeformableDetr
def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor):
# send pixel_values through the model to get list of feature maps
features = self.model(pixel_values) if self.config.use_timm_backbone else self.model(pixel_values).feature_maps
features = self.model(pixel_values).feature_maps

out = []
for feature_map in features:
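Both conv encoders now delegate entirely to `load_backbone`, so the timm import and the `create_model` call disappear from the modeling files. Roughly, the helper dispatches on the config along these lines; this is a simplified sketch of the behaviour relied on above, not the actual implementation in `transformers.utils.backbone_utils`:

```python
from transformers import AutoBackbone

def load_backbone_sketch(config):
    """Rough approximation of the dispatch the models above rely on."""
    backbone_kwargs = config.backbone_kwargs or {}
    if config.backbone_config is not None:
        # Explicit backbone description, e.g. a ResNetConfig.
        return AutoBackbone.from_config(config.backbone_config)
    # Checkpoint name: `use_timm_backbone=True` routes to the TimmBackbone wrapper,
    # which exposes the same `.channels` / `.feature_maps` surface as Transformers
    # backbones. Pretrained-weight handling is omitted here for brevity.
    return AutoBackbone.from_pretrained(
        config.backbone, use_timm_backbone=config.use_timm_backbone, **backbone_kwargs
    )
```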
5 changes: 3 additions & 2 deletions src/transformers/models/deta/configuration_deta.py
@@ -50,7 +50,8 @@ class DetaConfig(PretrainedConfig):
Whether to load `backbone` from the timm library. If `False`, the backbone is loaded from the transformers
library.
backbone_kwargs (`dict`, *optional*):
Keyword arguments to be passed to the backbone constructor e.g. `{'out_indices': (0, 1, 2, 3)}`.
Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
num_queries (`int`, *optional*, defaults to 900):
Number of object queries, i.e. detection slots. This is the maximal number of objects [`DetaModel`] can
detect in a single image. In case `two_stage` is set to `True`, we use `two_stage_num_proposals` instead.
@@ -218,7 +219,7 @@ def __init__(
self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone
self.use_timm_backbone = use_timm_backbone
self.backbone_kwargs = backbone_kwargs if backbone_kwargs is not None else {}
self.backbone_kwargs = backbone_kwargs
self.num_queries = num_queries
self.max_position_embeddings = max_position_embeddings
self.d_model = d_model
37 changes: 31 additions & 6 deletions src/transformers/models/detr/configuration_detr.py
@@ -99,7 +99,8 @@ class DetrConfig(PretrainedConfig):
use_pretrained_backbone (`bool`, *optional*, defaults to `True`):
Whether to use pretrained weights for the backbone.
backbone_kwargs (`dict`, *optional*):
Keyword arguments to be passed to the backbone constructor e.g. `{'out_indices': (0, 1, 2, 3)}`.
Keyword arguments to be passed to AutoBackbone when loading from a checkpoint
e.g. `{'out_indices': (0, 1, 2, 3)}`. Cannot be specified if `backbone_config` is set.
dilation (`bool`, *optional*, defaults to `False`):
Whether to replace stride with dilation in the last convolutional block (DC5). Only supported when
`use_timm_backbone` = `True`.
@@ -194,20 +195,28 @@ def __init__(
if backbone_kwargs is not None and backbone_kwargs and backbone_config is not None:
raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")

if not use_timm_backbone:
if use_timm_backbone and backbone_kwargs is None:
backbone_kwargs = {}
if dilation:
backbone_kwargs["output_stride"] = 16
backbone_kwargs["out_indices"] = [1, 2, 3, 4]
backbone_kwargs["in_chans"] = num_channels
# Backwards compatibility
elif not use_timm_backbone and backbone in (None, "resnet50"):
if backbone_config is None:
logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage4"])
elif isinstance(backbone_config, dict):
backbone_model_type = backbone_config.get("model_type")
config_class = CONFIG_MAPPING[backbone_model_type]
backbone_config = config_class.from_dict(backbone_config)
backbone = None
# set timm attributes to None
dilation, backbone, use_pretrained_backbone = None, None, None
dilation = None

self.use_timm_backbone = use_timm_backbone
self.backbone_config = backbone_config
self.num_channels = num_channels
self._num_channels = num_channels
self.num_queries = num_queries
self.d_model = d_model
self.encoder_ffn_dim = encoder_ffn_dim
@@ -229,8 +238,8 @@ def __init__(
self.position_embedding_type = position_embedding_type
self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone
self.backbone_kwargs = backbone_kwargs if backbone_kwargs is not None else {}
self.dilation = dilation
self.backbone_kwargs = backbone_kwargs
self._dilation = dilation
# Hungarian matcher
self.class_cost = class_cost
self.bbox_cost = bbox_cost
@@ -243,6 +252,16 @@ def __init__(
self.eos_coefficient = eos_coefficient
super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs)

@property
def num_channels(self):
logger.warning("The `num_channels` attribute is deprecated and will be removed in v4.40")
return self._num_channels

@property
def dilation(self):
logger.warning("The `dilation` attribute is deprecated and will be removed in v4.40")
return self._dilation

@property
def num_attention_heads(self) -> int:
return self.encoder_attention_heads
@@ -263,6 +282,12 @@ def from_backbone_config(cls, backbone_config: PretrainedConfig, **kwargs):
"""
return cls(backbone_config=backbone_config, **kwargs)

def to_dict(self):
output = super().to_dict()
output.pop("_num_channels", None)
output.pop("_dilation", None)
return output


class DetrOnnxConfig(OnnxConfig):
torch_onnx_minimum_version = version.parse("1.11")
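`DetrConfig.from_backbone_config` remains the convenience path for wrapping an existing backbone config; since `backbone_kwargs` cannot be combined with `backbone_config`, it is simply left unset. A small sketch (version-dependent validation aside):

```python
from transformers import DetrConfig, ResNetConfig

backbone_config = ResNetConfig(out_features=["stage2", "stage3", "stage4"])

# Equivalent to DetrConfig(backbone_config=backbone_config, ...); `backbone_kwargs`
# stays unset because it conflicts with an explicit `backbone_config`.
config = DetrConfig.from_backbone_config(backbone_config, use_timm_backbone=False)
```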
