diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py
index 5620f63f76d7f0..a33affe3ec9b2f 100644
--- a/src/transformers/models/auto/feature_extraction_auto.py
+++ b/src/transformers/models/auto/feature_extraction_auto.py
@@ -49,7 +49,6 @@
         ("data2vec-vision", "BeitFeatureExtractor"),
         ("deformable_detr", "DeformableDetrFeatureExtractor"),
         ("deit", "DeiTFeatureExtractor"),
-        ("deta", "DetaFeatureExtractor"),
         ("detr", "DetrFeatureExtractor"),
         ("dinat", "ViTFeatureExtractor"),
         ("donut-swin", "DonutFeatureExtractor"),
diff --git a/src/transformers/models/deta/modeling_deta.py b/src/transformers/models/deta/modeling_deta.py
index e4a3828b6695b7..2a458fe8cc602e 100644
--- a/src/transformers/models/deta/modeling_deta.py
+++ b/src/transformers/models/deta/modeling_deta.py
@@ -384,16 +384,15 @@ def __init__(self, config):
         super().__init__()
 
         backbone = AutoBackbone.from_config(config.backbone_config)
-        # TODO replace batch norm by frozen batch norm
-        # with torch.no_grad():
-        #     replace_batch_norm(backbone)
+        with torch.no_grad():
+            replace_batch_norm(backbone)
         self.model = backbone
         self.intermediate_channel_sizes = self.model.channels
 
         # TODO fix this
         if config.backbone_config.model_type == "resnet":
             for name, parameter in self.model.named_parameters():
-                if "layer2" not in name and "layer3" not in name and "layer4" not in name:
+                if "stages.1" not in name and "stages.2" not in name and "stages.3" not in name:
                     parameter.requires_grad_(False)
 
         self.position_embedding = build_position_encoding(config)
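
For context on the `modeling_deta.py` hunk, two hedged sketches follow; the helper names and defaults below are illustrative, not copied from the PR.

The first change re-enables the DETR-style replacement of every `nn.BatchNorm2d` in the backbone with a frozen variant whose statistics and affine weights live in buffers. Copying the trained values into those buffers is why the call is wrapped in `torch.no_grad()`. A minimal sketch of that pattern, simplified from the DETR family (the actual module in this file is `DetaFrozenBatchNorm2d`):

```python
import torch
from torch import nn


class FrozenBatchNorm2d(nn.Module):
    """BatchNorm2d with statistics and affine parameters fixed as buffers."""

    def __init__(self, num_features: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.register_buffer("weight", torch.ones(num_features))
        self.register_buffer("bias", torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # y = (x - mean) / sqrt(var + eps) * weight + bias, folded into scale/shift
        scale = self.weight * (self.running_var + self.eps).rsqrt()
        shift = self.bias - self.running_mean * scale
        return x * scale.reshape(1, -1, 1, 1) + shift.reshape(1, -1, 1, 1)


def replace_batch_norm(module: nn.Module) -> None:
    """Recursively swap nn.BatchNorm2d children for frozen copies (sketch)."""
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            frozen = FrozenBatchNorm2d(child.num_features, child.eps)
            # copying from trainable parameters is why the caller wraps this
            # in torch.no_grad()
            frozen.weight.copy_(child.weight)
            frozen.bias.copy_(child.bias)
            frozen.running_mean.copy_(child.running_mean)
            frozen.running_var.copy_(child.running_var)
            setattr(module, name, frozen)
        else:
            replace_batch_norm(child)
```

The second change fixes the partial backbone freezing: the backbone built via `AutoBackbone` is HF's `ResNetBackbone`, whose parameters are named `embedder.*` and `encoder.stages.N.*`, so the old torchvision-style substrings `layer2`/`layer3`/`layer4` never matched and the whole backbone ended up frozen. Matching on `stages.1`/`stages.2`/`stages.3` keeps the last three stages trainable and freezes only the stem and the first stage. A small standalone repro, assuming `AutoBackbone` and `ResNetConfig` are available:

```python
from transformers import AutoBackbone, ResNetConfig

# Build a ResNet backbone the same way DetaBackboneWithPositionalEncodings does,
# here from a default config rather than config.backbone_config.
backbone = AutoBackbone.from_config(ResNetConfig())

# Freeze everything outside the last three stages. Parameter names look like
# "embedder..." or "encoder.stages.3.layers...", so the "stages.N" substrings
# actually match, unlike the torchvision-style "layerN" names.
for name, parameter in backbone.named_parameters():
    if "stages.1" not in name and "stages.2" not in name and "stages.3" not in name:
        parameter.requires_grad_(False)

frozen = sum(1 for p in backbone.parameters() if not p.requires_grad)
total = sum(1 for _ in backbone.parameters())
print(f"frozen {frozen} of {total} parameter tensors")
```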