From 0212921f4c35afd08351860cd57349ae5cefb2b6 Mon Sep 17 00:00:00 2001 From: NielsRogge Date: Thu, 22 Dec 2022 13:23:33 +0000 Subject: [PATCH] More improvements --- src/transformers/models/auto/feature_extraction_auto.py | 1 - src/transformers/models/deta/modeling_deta.py | 7 +++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/auto/feature_extraction_auto.py b/src/transformers/models/auto/feature_extraction_auto.py index 5620f63f76d7f0..a33affe3ec9b2f 100644 --- a/src/transformers/models/auto/feature_extraction_auto.py +++ b/src/transformers/models/auto/feature_extraction_auto.py @@ -49,7 +49,6 @@ ("data2vec-vision", "BeitFeatureExtractor"), ("deformable_detr", "DeformableDetrFeatureExtractor"), ("deit", "DeiTFeatureExtractor"), - ("deta", "DetaFeatureExtractor"), ("detr", "DetrFeatureExtractor"), ("dinat", "ViTFeatureExtractor"), ("donut-swin", "DonutFeatureExtractor"), diff --git a/src/transformers/models/deta/modeling_deta.py b/src/transformers/models/deta/modeling_deta.py index e4a3828b6695b7..2a458fe8cc602e 100644 --- a/src/transformers/models/deta/modeling_deta.py +++ b/src/transformers/models/deta/modeling_deta.py @@ -384,16 +384,15 @@ def __init__(self, config): super().__init__() backbone = AutoBackbone.from_config(config.backbone_config) - # TODO replace batch norm by frozen batch norm - # with torch.no_grad(): - # replace_batch_norm(backbone) + with torch.no_grad(): + replace_batch_norm(backbone) self.model = backbone self.intermediate_channel_sizes = self.model.channels # TODO fix this if config.backbone_config.model_type == "resnet": for name, parameter in self.model.named_parameters(): - if "layer2" not in name and "layer3" not in name and "layer4" not in name: + if "stages.1" not in name and "stages.2" not in name and "stages.3" not in name: parameter.requires_grad_(False) self.position_embedding = build_position_encoding(config)