Fix DETA save_pretrained (#30326)
* Add class_embed to tied weights for DETA

* Fix test_tied_weights_keys for DETA model

* Replace error raise with assert statement
qubvel authored Apr 22, 2024
1 parent 6c7335e commit 13b3b90
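
For context: with `two_stage=True`, `DetaForObjectDetection` clones its `class_embed` prediction head across decoder layers, so several state-dict entries alias the same tensor. The safetensors path of `save_pretrained` refuses to serialize aliased tensors unless they are declared in `_tied_weights_keys`, which is the likely failure mode behind this fix. A minimal sketch of that serialization behavior (hypothetical entry names, assuming the `safetensors` package):

import torch
from safetensors.torch import save_file

# Two state-dict entries that alias one tensor, as happens when a prediction
# head is cloned across decoder layers (the entry names are hypothetical).
shared = torch.randn(4, 4)
state_dict = {"class_embed.0.weight": shared, "class_embed.1.weight": shared}

try:
    save_file(state_dict, "weights.safetensors")
except RuntimeError as err:
    # safetensors rejects aliased tensors; save_pretrained avoids this by
    # dropping the aliases declared in `_tied_weights_keys` and re-tying them on load.
    print(err)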
Showing 3 changed files with 44 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/transformers/models/deta/modeling_deta.py
@@ -1888,7 +1888,7 @@ def forward(
 )
 class DetaForObjectDetection(DetaPreTrainedModel):
     # When using clones, all layers > 0 will be clones, but layer 0 *is* required
-    _tied_weights_keys = [r"bbox_embed\.\d+"]
+    _tied_weights_keys = [r"bbox_embed\.\d+", r"class_embed\.\d+"]
     # We can't initialize the model on meta device as some weights are modified during the initialization
     _no_split_modules = None

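To illustrate why the new `class_embed\.\d+` pattern is needed, here is a toy module (a sketch, not the real DETA code) in which one head module is reused across layers; every indexed state-dict entry then shares the same parameters, and the regex matches each clone:

import re
import torch.nn as nn

# Toy stand-in for clone-based prediction heads (hypothetical, not DETA itself).
class ToyHead(nn.Module):
    def __init__(self, num_layers=3, hidden=8, num_labels=5):
        super().__init__()
        head = nn.Linear(hidden, num_labels)
        # Reusing the same module in every slot ties all layers together.
        self.class_embed = nn.ModuleList([head for _ in range(num_layers)])

model = ToyHead()
names = list(model.state_dict())
# ['class_embed.0.weight', 'class_embed.0.bias', 'class_embed.1.weight', ...]

pattern = re.compile(r"class_embed\.\d+")
assert all(pattern.search(name) for name in names)

# Every clone aliases the parameters of layer 0:
assert all(m.weight.data_ptr() == model.class_embed[0].weight.data_ptr() for m in model.class_embed)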
41 changes: 41 additions & 0 deletions tests/models/deta/test_modeling_deta.py
@@ -15,8 +15,10 @@
 """ Testing suite for the PyTorch DETA model. """


+import collections
 import inspect
 import math
+import re
 import unittest

 from transformers import DetaConfig, ResNetConfig, is_torch_available, is_torchvision_available, is_vision_available
@@ -32,6 +34,8 @@
 if is_torch_available():
     import torch

+    from transformers.pytorch_utils import id_tensor_storage
+
 if is_torchvision_available():
     from transformers import DetaForObjectDetection, DetaModel
@@ -520,6 +524,43 @@ def test_initialization(self):
                     msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                 )

+    # Inspired by tests.test_modeling_common.ModelTesterMixin.test_tied_weights_keys
+    def test_tied_weights_keys(self):
+        for model_class in self.all_model_classes:
+            # We need to pass the model class name to correctly initialize the config.
+            # If we don't pass it, the config for `DetaForObjectDetection` will be initialized
+            # with `two_stage=False` and the test will fail because in that case the
+            # `class_embed` weights are not tied.
+            config, _ = self.model_tester.prepare_config_and_inputs_for_common(model_class_name=model_class.__name__)
+            config.tie_word_embeddings = True
+
+            model_tied = model_class(config)
+
+            ptrs = collections.defaultdict(list)
+            for name, tensor in model_tied.state_dict().items():
+                ptrs[id_tensor_storage(tensor)].append(name)
+
+            # These are all the pointers of shared tensors.
+            tied_params = [names for _, names in ptrs.items() if len(names) > 1]
+
+            tied_weight_keys = model_tied._tied_weights_keys if model_tied._tied_weights_keys is not None else []
+            # Detect we get a hit for each key
+            for key in tied_weight_keys:
+                is_tied_key = any(re.search(key, p) for group in tied_params for p in group)
+                self.assertTrue(is_tied_key, f"{key} is not a tied weight key for {model_class}.")
+
+            # Removed tied weights found from tied params -> there should only be one left after
+            for key in tied_weight_keys:
+                for i in range(len(tied_params)):
+                    tied_params[i] = [p for p in tied_params[i] if re.search(key, p) is None]
+
+            tied_params = [group for group in tied_params if len(group) > 1]
+            self.assertListEqual(
+                tied_params,
+                [],
+                f"Missing `_tied_weights_keys` for {model_class}: add all of {tied_params} except one.",
+            )
+


 TOLERANCE = 1e-4

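The new test groups state-dict entries with `id_tensor_storage` rather than `Tensor.data_ptr()`, so tensors that are views into the same storage are also counted as tied. A small demonstration of the difference, assuming the helper keeps its current behavior in `transformers.pytorch_utils`:

import torch
from transformers.pytorch_utils import id_tensor_storage

base = torch.randn(4, 4)
view = base[1:]  # shares storage with `base` but starts at a different offset

print(base.data_ptr() == view.data_ptr())                  # False: the offsets differ
print(id_tensor_storage(base) == id_tensor_storage(view))  # True: same underlying storage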
4 changes: 2 additions & 2 deletions tests/test_modeling_common.py
@@ -2025,8 +2025,8 @@ def test_tied_weights_keys(self):
         tied_weight_keys = model_tied._tied_weights_keys if model_tied._tied_weights_keys is not None else []
         # Detect we get a hit for each key
         for key in tied_weight_keys:
-            if not any(re.search(key, p) for group in tied_params for p in group):
-                raise ValueError(f"{key} is not a tied weight key for {model_class}.")
+            is_tied_key = any(re.search(key, p) for group in tied_params for p in group)
+            self.assertTrue(is_tied_key, f"{key} is not a tied weight key for {model_class}.")

         # Removed tied weights found from tied params -> there should only be one left after
         for key in tied_weight_keys:
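The `raise` → `assertTrue` swap matters for reporting: under `unittest`, a failed assertion shows up as a test failure with the message attached, while an uncaught `ValueError` shows up as an error. A minimal illustration (a hypothetical test, not part of the suite):

import unittest

class TiedKeyExample(unittest.TestCase):
    def test_reported_as_failure(self):
        is_tied_key = False
        # Reported as FAIL with the message below, rather than ERROR:
        self.assertTrue(is_tied_key, r"class_embed\.\d+ is not a tied weight key.")

if __name__ == "__main__":
    unittest.main()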
