
Add tie_weights() to LM heads and set bias in set_output_embeddings() #28948

Merged: 5 commits into huggingface:main from fix-bert-tie-weights, Feb 14, 2024

Conversation

@hackyon (Contributor) commented Feb 9, 2024

What does this PR do?

This fixes a bug where prediction heads can end up with the wrong bias in some situations: `predictions.bias` needs to be tied to `predictions.decoder.bias` inside `tie_weights()`.
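
For BERT, the change looks roughly like the sketch below (simplified here into a standalone module; the PR applies the same pattern to several models). The `_tie_weights()` hook re-ties the two bias parameters after loading, and `set_output_embeddings()` is likewise updated so the head adopts the new decoder's bias:

    import torch
    from torch import nn

    class LMPredictionHead(nn.Module):
        # Simplified stand-in for heads like BertLMPredictionHead.
        def __init__(self, hidden_size, vocab_size):
            super().__init__()
            self.decoder = nn.Linear(hidden_size, vocab_size, bias=False)
            self.bias = nn.Parameter(torch.zeros(vocab_size))
            # decoder.bias and self.bias must be the same Parameter object.
            self.decoder.bias = self.bias

        def _tie_weights(self):
            # Called as part of tie_weights(): re-tie the bias after the
            # decoder has been replaced or materialized from the meta device.
            self.decoder.bias = self.bias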

Repro Steps:

  1. Sync to HEAD in main
  2. Add the following test case to test_modeling_bert.py and run it:
    def test_save_load_bert_prediction_head(self):
        with tempfile.TemporaryDirectory() as tmpdirname:
            model_to_save = BertForMaskedLM.from_pretrained("bert-base-uncased")

            model_to_save.save_pretrained(tmpdirname)

            model = BertForMaskedLM.from_pretrained(
                tmpdirname,
                low_cpu_mem_usage=True,  # initializes the model skeleton on the meta device
            )
            model.to(torch_device)
  3. The following error is thrown:
FAILED tests/models/bert/test_modeling_bert.py::BertModelTest::test_save_load_bert_prediction_head - NotImplementedError: Cannot copy out of meta tensor; no data!
    def convert(t):
        if convert_to_format is not None and t.dim() in (4, 5):
            return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None,
                        non_blocking, memory_format=convert_to_format)
>       return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)
E       NotImplementedError: Cannot copy out of meta tensor; no data!

../venv/lib/python3.9/site-packages/torch/nn/modules/module.py:1158: NotImplementedError
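
The error itself is generic PyTorch behavior: a tensor on the meta device has no storage, so copying it to a real device fails. A minimal standalone repro (independent of transformers):

    import torch

    p = torch.nn.Parameter(torch.empty(10, device="meta"))
    p.to("cpu")  # NotImplementedError: Cannot copy out of meta tensor; no data!

With `low_cpu_mem_usage=True`, the model skeleton is created on the meta device and only weights present in the checkpoint (or re-tied via `tie_weights()`) are materialized, so an untied `predictions.bias` stays on meta and `model.to(torch_device)` fails as above.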

The issue was uncovered in #28802.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@ArthurZucker @amyeroberts

@hackyon mentioned this pull request Feb 9, 2024
@amyeroberts (Collaborator) left a comment:

Thanks for digging into this and fixing it for all these models!

Just two things I think we need to do before merging:

  • Run the slow model tests for these models to confirm this change is backwards compatible.
  • Apply the test to all models.

@@ -600,6 +600,24 @@ def test_model_from_pretrained(self):
model = BertModel.from_pretrained(model_name)
self.assertIsNotNone(model)

@slow
def test_save_and_load_low_cpu_mem_usage(self):
amyeroberts (Collaborator) commented:

This should be applied to all the models. Could you move the test to ModelTesterMixin?

hackyon (Contributor, Author) replied:

Done.

I also noticed that there were other non-slow tests in ModelTesterMixin that loaded/saved pretrained models.

I had a hunch that the saving/loading should actually be pretty fast, so I ran the tests without the `@slow` tag. The before/after runs on CircleCI showed similar timings for tests_torch, so my two cents is that we don't need to mark this test as `@slow`.

Before run (tests_torch ran in 6m58s):
https://app.circleci.com/pipelines/github/huggingface/transformers/84140/workflows/0998b8ba-9336-49e5-9680-2ddd86443669

After run (tests_torch ran in 6m48s):
https://app.circleci.com/pipelines/github/huggingface/transformers/84451/workflows/f5efa9af-c079-4303-8611-574f8a3bf7bd
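
For reference, the generalized test could look roughly like this inside ModelTesterMixin (a sketch following the repro above, not necessarily the exact code that was merged):

    def test_save_and_load_low_cpu_mem_usage(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()
        with tempfile.TemporaryDirectory() as tmpdirname:
            for model_class in self.all_model_classes:
                model_to_save = model_class(config)
                model_to_save.save_pretrained(tmpdirname)

                model = model_class.from_pretrained(
                    tmpdirname,
                    low_cpu_mem_usage=True,
                )

                # .to() raises NotImplementedError if any weight is still on
                # the meta device, so this line doubles as the assertion.
                model.to(torch_device)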

@hackyon force-pushed the fix-bert-tie-weights branch from eafefb4 to e40f605 on February 14, 2024 at 15:43
@@ -520,6 +520,10 @@ def test_initialization(self):
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)

@unittest.skip("Cannot be initialized on meta device as some weights are modified during the initialization")
hackyon (Contributor, Author) commented:

This came from an existing comment:

# We can't initialize the model on meta device as some weights are modified during the initialization

It seems like the DetaForObjectDetection model has logic to manipulate the params during init, and so it simply doesn't support initializing on a meta device.

With that said, should we add some kind of flag to PreTrainedModel that would throw an error when a model that doesn't support it is initialized on a meta device? So far, DetaModel seems to be the only such model, so it may not be worth the effort.
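
As a rough illustration of that idea (entirely hypothetical; the attribute name is made up):

    class PreTrainedModel:
        # Models that modify their parameters during __init__ (like Deta)
        # would override this with False.
        _supports_meta_device_init = True

        @classmethod
        def from_pretrained(cls, path, low_cpu_mem_usage=False, **kwargs):
            if low_cpu_mem_usage and not cls._supports_meta_device_init:
                raise ValueError(
                    f"{cls.__name__} cannot be initialized on the meta device; "
                    "load it with low_cpu_mem_usage=False."
                )
            ...  # usual loading path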

cc @SunMarc who added this: #27089

amyeroberts (Collaborator) replied:

Let's keep as-is for the moment, until we have more affected models. DETR and DETR-like models are special but I think they're the only case.

@hackyon force-pushed the fix-bert-tie-weights branch 2 times, most recently from 59b1fa2 to e454391 on February 14, 2024 at 17:45
@amyeroberts (Collaborator) left a comment:

Great work - thanks for adding this!


@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@amyeroberts merged commit 725f4ad into huggingface:main on Feb 14, 2024
18 checks passed
@hackyon deleted the fix-bert-tie-weights branch on February 14, 2024 at 20:43
sbucaille pushed a commit to sbucaille/transformers that referenced this pull request Feb 14, 2024
…huggingface#28948)

* Add tie_weights() to LM heads and set bias in set_output_embeddings()

The biases were not tied correctly in some LM heads, and this change should fix that.

* Moving test_save_and_load_low_cpu_mem_usage to ModelTesterMixin

* Adding _tie_weights() to MPNet and Vilt

* Skip test for low cpu mem usage for Deta/DeformableDetr since they cannot init on meta device

* Rename the test to save_load to match the convention
amyeroberts added a commit to amyeroberts/transformers that referenced this pull request Feb 20, 2024
amyeroberts added a commit that referenced this pull request Feb 20, 2024
* Revert "Add tie_weights() to LM heads and set bias in set_output_embeddings() (#28948)"

This reverts commit 725f4ad.

* Revert "Patch to skip failing `test_save_load_low_cpu_mem_usage` tests (#29043)"

This reverts commit 4156f51.
itazap pushed two commits that referenced this pull request May 14, 2024: the original change and its revert, with the same commit messages as the commits above.