From c34ab9922eb0c4013d29e2885b9f2add8da6cd32 Mon Sep 17 00:00:00 2001
From: Pavel Iakubovskii
Date: Thu, 18 Apr 2024 17:37:22 +0000
Subject: [PATCH 1/3] Add class_embed to tied weights for DETA

---
 src/transformers/models/deta/modeling_deta.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/models/deta/modeling_deta.py b/src/transformers/models/deta/modeling_deta.py
index ce0a5e79aa4eb1..b90a62dfa5342c 100644
--- a/src/transformers/models/deta/modeling_deta.py
+++ b/src/transformers/models/deta/modeling_deta.py
@@ -1888,7 +1888,7 @@ def forward(
 )
 class DetaForObjectDetection(DetaPreTrainedModel):
     # When using clones, all layers > 0 will be clones, but layer 0 *is* required
-    _tied_weights_keys = [r"bbox_embed\.\d+"]
+    _tied_weights_keys = [r"bbox_embed\.\d+", r"class_embed\.\d+"]
     # We can't initialize the model on meta device as some weights are modified during the initialization
     _no_split_modules = None
 

From c6214b204035f6b8f2453e10d14b97b73fb4fa91 Mon Sep 17 00:00:00 2001
From: Pavel Iakubovskii
Date: Thu, 18 Apr 2024 21:42:52 +0000
Subject: [PATCH 2/3] Fix test_tied_weights_keys for DETA model

---
 tests/models/deta/test_modeling_deta.py | 41 +++++++++++++++++++++++++
 1 file changed, 41 insertions(+)

diff --git a/tests/models/deta/test_modeling_deta.py b/tests/models/deta/test_modeling_deta.py
index 3a3a957dd012e2..6cc5b7247906ff 100644
--- a/tests/models/deta/test_modeling_deta.py
+++ b/tests/models/deta/test_modeling_deta.py
@@ -15,8 +15,10 @@
 """ Testing suite for the PyTorch DETA model. """
 
 
+import collections
 import inspect
 import math
+import re
 import unittest
 
 from transformers import DetaConfig, ResNetConfig, is_torch_available, is_torchvision_available, is_vision_available
@@ -32,6 +34,8 @@
 if is_torch_available():
     import torch
 
+    from transformers.pytorch_utils import id_tensor_storage
+
     if is_torchvision_available():
         from transformers import DetaForObjectDetection, DetaModel
 
@@ -520,6 +524,43 @@ def test_initialization(self):
                         msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                     )
 
+    # Inspired by tests.test_modeling_common.ModelTesterMixin.test_tied_weights_keys
+    def test_tied_weights_keys(self):
+        for model_class in self.all_model_classes:
+            # We need to pass model class name to correctly initialize the config.
+            # If we don't pass it, the config for `DetaForObjectDetection`` will be initialized
+            # with `two_stage=False` and the test will fail because for that case `class_embed`
+            # weights are not tied.
+            config, _ = self.model_tester.prepare_config_and_inputs_for_common(model_class_name=model_class.__name__)
+            config.tie_word_embeddings = True
+
+            model_tied = model_class(config)
+
+            ptrs = collections.defaultdict(list)
+            for name, tensor in model_tied.state_dict().items():
+                ptrs[id_tensor_storage(tensor)].append(name)
+
+            # These are all the pointers of shared tensors.
+            tied_params = [names for _, names in ptrs.items() if len(names) > 1]
+
+            tied_weight_keys = model_tied._tied_weights_keys if model_tied._tied_weights_keys is not None else []
+            # Detect we get a hit for each key
+            for key in tied_weight_keys:
+                if not any(re.search(key, p) for group in tied_params for p in group):
+                    raise ValueError(f"{key} is not a tied weight key for {model_class}.")
+
+            # Removed tied weights found from tied params -> there should only be one left after
+            for key in tied_weight_keys:
+                for i in range(len(tied_params)):
+                    tied_params[i] = [p for p in tied_params[i] if re.search(key, p) is None]
+
+            tied_params = [group for group in tied_params if len(group) > 1]
+            self.assertListEqual(
+                tied_params,
+                [],
+                f"Missing `_tied_weights_keys` for {model_class}: add all of {tied_params} except one.",
+            )
+
 
 TOLERANCE = 1e-4
 

From 3b13b3831ec0375e0ada771035cab2f49880cf9f Mon Sep 17 00:00:00 2001
From: Pavel Iakubovskii
Date: Mon, 22 Apr 2024 15:03:58 +0000
Subject: [PATCH 3/3] Replace error raise with assert statement

---
 tests/models/deta/test_modeling_deta.py | 4 ++--
 tests/test_modeling_common.py | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/tests/models/deta/test_modeling_deta.py b/tests/models/deta/test_modeling_deta.py
index 6cc5b7247906ff..655bb50bb52dbb 100644
--- a/tests/models/deta/test_modeling_deta.py
+++ b/tests/models/deta/test_modeling_deta.py
@@ -546,8 +546,8 @@ def test_tied_weights_keys(self):
             tied_weight_keys = model_tied._tied_weights_keys if model_tied._tied_weights_keys is not None else []
             # Detect we get a hit for each key
             for key in tied_weight_keys:
-                if not any(re.search(key, p) for group in tied_params for p in group):
-                    raise ValueError(f"{key} is not a tied weight key for {model_class}.")
+                is_tied_key = any(re.search(key, p) for group in tied_params for p in group)
+                self.assertTrue(is_tied_key, f"{key} is not a tied weight key for {model_class}.")
 
             # Removed tied weights found from tied params -> there should only be one left after
             for key in tied_weight_keys:
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 71cb28d7548555..c5f22c5eb23cdb 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -2025,8 +2025,8 @@ def test_tied_weights_keys(self):
             tied_weight_keys = model_tied._tied_weights_keys if model_tied._tied_weights_keys is not None else []
             # Detect we get a hit for each key
             for key in tied_weight_keys:
-                if not any(re.search(key, p) for group in tied_params for p in group):
-                    raise ValueError(f"{key} is not a tied weight key for {model_class}.")
+                is_tied_key = any(re.search(key, p) for group in tied_params for p in group)
+                self.assertTrue(is_tied_key, f"{key} is not a tied weight key for {model_class}.")
 
             # Removed tied weights found from tied params -> there should only be one left after
             for key in tied_weight_keys: