Fix DETA save_pretrained #30326

Merged · 3 commits · Apr 22, 2024
Changes from 2 commits
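
Context note (not part of the PR page): the view below doesn't quote the original failure, but the one-line model change suggests what "fix save_pretrained" means. DETA's two-stage variant ties its `class_embed` heads across decoder layers, and shared tensors not listed in `_tied_weights_keys` trip up safetensors-based saving. A hedged round-trip sketch; the checkpoint id is a real Hub checkpoint, but the pre-fix failure mode is inferred rather than quoted:

    from transformers import DetaForObjectDetection

    # DETA needs torchvision in addition to torch.
    model = DetaForObjectDetection.from_pretrained("jozhang97/deta-resnet-50")

    # Pre-fix, saving with safetensors would presumably reject the shared
    # `class_embed` tensors as unlisted tied weights. With `class_embed\.\d+`
    # added to `_tied_weights_keys`, save_pretrained knows which duplicate
    # entries it may safely drop, so the round trip succeeds.
    model.save_pretrained("/tmp/deta-checkpoint", safe_serialization=True)
    reloaded = DetaForObjectDetection.from_pretrained("/tmp/deta-checkpoint")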
2 changes: 1 addition & 1 deletion in src/transformers/models/deta/modeling_deta.py

@@ -1888,7 +1888,7 @@ def forward(
 )
 class DetaForObjectDetection(DetaPreTrainedModel):
     # When using clones, all layers > 0 will be clones, but layer 0 *is* required
-    _tied_weights_keys = [r"bbox_embed\.\d+"]
+    _tied_weights_keys = [r"bbox_embed\.\d+", r"class_embed\.\d+"]
     # We can't initialize the model on meta device as some weights are modified during the initialization
     _no_split_modules = None
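Side note (not part of the diff): in the two-stage setup, DETA "clones" a head by registering the same module instance once per decoder layer, so every `class_embed.N` state-dict entry aliases one underlying storage. A minimal sketch of that sharing, using the same `id_tensor_storage` helper the new test imports; the head dimensions here are made up for the demo:

    import collections

    from torch import nn

    from transformers.pytorch_utils import id_tensor_storage

    head = nn.Linear(256, 91)  # hypothetical sizes, illustration only
    # Registering the *same* instance three times, as DETA does when cloning
    # heads, makes all three state-dict entries share one storage.
    clones = nn.ModuleList([head for _ in range(3)])

    ptrs = collections.defaultdict(list)
    for name, tensor in clones.state_dict().items():
        ptrs[id_tensor_storage(tensor)].append(name)

    # Name groups aliasing one storage: exactly what `_tied_weights_keys` must cover.
    print([names for names in ptrs.values() if len(names) > 1])
    # -> [['0.weight', '1.weight', '2.weight'], ['0.bias', '1.bias', '2.bias']]
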
41 changes: 41 additions & 0 deletions in tests/models/deta/test_modeling_deta.py

@@ -15,8 +15,10 @@
 """ Testing suite for the PyTorch DETA model. """


+import collections
 import inspect
 import math
+import re
 import unittest

 from transformers import DetaConfig, ResNetConfig, is_torch_available, is_torchvision_available, is_vision_available

@@ -32,6 +34,8 @@
 if is_torch_available():
     import torch

+    from transformers.pytorch_utils import id_tensor_storage
+
 if is_torchvision_available():
     from transformers import DetaForObjectDetection, DetaModel
@@ -520,6 +524,43 @@ def test_initialization(self):
                     msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                 )

+    # Inspired by tests.test_modeling_common.ModelTesterMixin.test_tied_weights_keys
+    def test_tied_weights_keys(self):
+        for model_class in self.all_model_classes:
+            # We need to pass the model class name to correctly initialize the config.
+            # If we don't pass it, the config for `DetaForObjectDetection` will be initialized
+            # with `two_stage=False` and the test will fail, because in that case the
+            # `class_embed` weights are not tied.
+            config, _ = self.model_tester.prepare_config_and_inputs_for_common(model_class_name=model_class.__name__)
+            config.tie_word_embeddings = True
+
+            model_tied = model_class(config)
+
+            ptrs = collections.defaultdict(list)
+            for name, tensor in model_tied.state_dict().items():
+                ptrs[id_tensor_storage(tensor)].append(name)
+
+            # These are all the pointers of shared tensors.
+            tied_params = [names for _, names in ptrs.items() if len(names) > 1]
+
+            tied_weight_keys = model_tied._tied_weights_keys if model_tied._tied_weights_keys is not None else []
+            # Detect that we get a hit for each key
+            for key in tied_weight_keys:
+                if not any(re.search(key, p) for group in tied_params for p in group):
+                    raise ValueError(f"{key} is not a tied weight key for {model_class}.")

Collaborator commented:

    We shouldn't really be raising exceptions in a test. Exceptions are for breaking out of code
    when there's incorrect input or code behaviour, which we can then choose to handle. In tests
    we're really performing sanity checks, which should always hold given the test.

    Either we should test this behaviour with an assert, or just remove it.

qubvel (Member, Author) replied on Apr 22, 2024:

    Thanks, Amy, I fixed this test case and the original one too 👍 (replaced with assert)

+
+            # Remove tied weights found from tied params -> there should only be one left after
+            for key in tied_weight_keys:
+                for i in range(len(tied_params)):
+                    tied_params[i] = [p for p in tied_params[i] if re.search(key, p) is None]
+
+            tied_params = [group for group in tied_params if len(group) > 1]
+            self.assertListEqual(
+                tied_params,
+                [],
+                f"Missing `_tied_weights_keys` for {model_class}: add all of {tied_params} except one.",
+            )


 TOLERANCE = 1e-4

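Closing note on the thread above: the `raise ValueError` was replaced with an assertion in the third commit, which this two-commit view doesn't show. A sketch of what that check plausibly became inside the test, mirroring the assert style the reviewer asked for; the exact wording is assumed, not quoted from the merged commit:

    # Sketch only: slots in where the `raise ValueError` sits in the diff above.
    for key in tied_weight_keys:
        is_tied_key = any(re.search(key, p) for group in tied_params for p in group)
        self.assertTrue(is_tied_key, f"{key} is not a tied weight key for {model_class}.")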