
Adding _tie_weights() to prediction heads to support low_cpu_mem_usage=True #29024

Merged: 5 commits merged into huggingface:main from fix-more-tied-weights on May 7, 2024

Conversation

@hackyon (Contributor) commented Feb 14, 2024

What does this PR do?

This is a follow-up to #28948. That PR was rolled back, and this rolls it forward with all the necessary fixes.

I explicitly ran the following command (it takes quite a long time) and verified that all the tests pass:
pytest -k "test_save_load_low_cpu_mem_usage" tests/

See #28948 for repro steps of the bug.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@amyeroberts

@hackyon (Contributor, Author) commented Feb 14, 2024

Still working through another 10 failing cases (where the fix isn't so obvious), so I'm only marking this as a draft for now.

tests/models/flava/test_modeling_flava.py .....F
tests/models/encodec/test_modeling_encodec.py F
tests/models/fsmt/test_modeling_fsmt.py F
tests/models/lxmert/test_modeling_lxmert.py F
tests/models/marian/test_modeling_marian.py F.

tests/models/musicgen/test_modeling_musicgen.py .F

tests/models/sew/test_modeling_sew.py F
tests/models/sew_d/test_modeling_sew_d.py F

tests/models/timm_backbone/test_modeling_timm_backbone.py F

@hackyon hackyon force-pushed the fix-more-tied-weights branch from e113920 to 63a6230 on February 14, 2024 23:37
@hackyon (Contributor, Author) commented Feb 15, 2024

Cool, I fixed the remaining missing _tie_weights(), and also added a few more test skips for some special audio/vision models (many of which fail due to their use of nn.utils.weight_norm).

I ran the following and made sure all the tests pass:
pytest -k test_save_load_low_cpu_mem_usage tests/

Also, what do you think about documenting the need to run the above command when modifying the common tests? Perhaps it could go on the Contributions page? I can follow up on this if you think it makes sense.

@hackyon hackyon marked this pull request as ready for review February 15, 2024 15:04
@hackyon hackyon force-pushed the fix-more-tied-weights branch from 312d928 to 12b9e55 on February 15, 2024 19:04
@hackyon (Contributor, Author) commented Feb 15, 2024

The test failure in tests_tf looks unrelated. Any chance you can kick off a re-run of the CI checks? 🙏

Also, I've verified again that pytest -k test_save_load_low_cpu_mem_usage tests/ passes.

@amyeroberts (Collaborator) left a comment


Thanks for adding this!

I'm not super familiar with all the tie weights logic, so most of my review is questions to better understand.

cc @muellerzr to confirm expected behaviour with accelerate

# The low_cpu_mem_usage=True causes the model params to be initialized with device=meta, and then
# subsequently loaded with the correct values and onto the correct device. We check if there are any
# remaining params that were not properly loaded.
for name, param in model.named_parameters():
Collaborator:

Nice :)

Comment on lines +893 to +885
else:
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
self.bias = self.decoder.bias
Collaborator:

For my own understanding, could you provide a bit more information about the two cases here? In particular, my questions are:

  • Which of the two branches of the if/else here is for accelerate compatibility?
  • How was the accelerate compatibility issue caught, i.e. which tests?
  • Do we need both of these assignments? I'm assuming yes, but it's not clear to me why one works in one case and not the other.

@hackyon (Contributor, Author), Feb 15, 2024:

I copied this from the existing tie_weights():

Here is more context on where that was added: #19906.

I agree that the existing implementation of tie_weights() doesn't make much sense. However, I added the if-statement to preserve backward compatibility.

Here's my 2 cents on how this all works:

When loading the model, only one pointer out of the tied params gets loaded with the correct values (when device=meta). Let's call this the "canonical" pointer. All the other tied params must copy from the canonical pointer, and it can't really be the other way around (at least for the device=meta case).

The canonical pointer is the key stored in the loaded state_dict. In the current logic for save_pretrained(), the "canonical" pointer is the weight key that's not in the _tied_weights_keys list. In this case, that would make self.bias the canonical pointer.

However, I imagine it's still possible for some (older?) pretrained models to use some other logic for choosing the "canonical" pointer, so I added the if-statement for backwards compatibility.

@hackyon (Contributor, Author):

One note about the case where device != meta. When that happens, the direction of the pointer assignment doesn't matter as much. Whether you call self.bias = self.decoder.bias or self.decoder.bias = self.bias, the loaded values show up in both params, since the two attributes act like references to the same tensor.

When device == meta, the assignment doesn't establish the same reference relationship; it just copies the sizes/metadata of the other tensor. When one of those tensors is loaded (the canonical one), the other needs to copy the pointer reference explicitly; it is not automatically loaded the same way.
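
To make the direction of the re-tie concrete, here's a rough sketch of the pattern (not the exact diff; the meta-device check is shorthand for how the two cases get distinguished, and the class is a simplified stand-in for the real heads):

import torch
import torch.nn as nn

class ToyLMPredictionHead(nn.Module):
    # Simplified stand-in for the prediction heads touched in this PR.
    def __init__(self, hidden_size: int, vocab_size: int):
        super().__init__()
        self.decoder = nn.Linear(hidden_size, vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(vocab_size))
        # Tied at construction time, as the existing heads already do.
        self.decoder.bias = self.bias

    def _tie_weights(self):
        if self.decoder.bias.device.type == "meta":
            # low_cpu_mem_usage=True case: only the "canonical" param (self.bias,
            # the key kept out of _tied_weights_keys) gets materialized, so the
            # decoder bias has to point back at it.
            self.decoder.bias = self.bias
        else:
            # Existing behaviour, kept for backward compatibility: re-tie the two
            # if they get disconnected (on TPU or when the bias is resized).
            self.bias = self.decoder.bias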

src/transformers/models/lxmert/modeling_lxmert.py (comment outdated, resolved)
Comment on lines 709 to 710
if new_size != self.bias.shape[0]:
self.bias.data = nn.functional.pad(self.bias.data, (0, new_size - self.bias.shape[0]), "constant", 0)
Collaborator:

This seems a bit weird to me - why would we need to resize if we're tying weights?

@hackyon (Contributor, Author):

This is for the case where resize_token_embeddings() is called.

After resizing, tie_weights() is called. Because the embedding weights are resized, we also need to resize the bias so that it matches the new embedding size.

There is similar logic in the common _tie_or_clone_weights:

def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
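
As a tiny illustration with made-up sizes, mirroring the pad call quoted above:

import torch
import torch.nn as nn

bias = torch.zeros(30522)   # hypothetical original vocab size
new_size = 30530            # hypothetical size after resize_token_embeddings()
if new_size != bias.shape[0]:
    # Pad the tied bias with zeros so it matches the resized embedding/decoder.
    bias = nn.functional.pad(bias, (0, new_size - bias.shape[0]), "constant", 0)
assert bias.shape[0] == new_size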

Collaborator:

OK - thanks for explaining!

@muellerzr (Contributor):

cc @SunMarc :)

@hackyon (Contributor, Author) commented Feb 15, 2024

Thanks for the review!

I added the explanation of tie_weights() from my research, but it'd be great to get some feedback from someone who's more knowledgeable on this.

@hackyon hackyon mentioned this pull request Feb 16, 2024
@hackyon (Contributor, Author) commented Feb 16, 2024

cc @SunMarc @muellerzr

I don't mean to be pushy, but the tests for the models in this change are currently broken at main/HEAD, so I'd be grateful if you could give this a look in the next couple of days. Thanks!

@hackyon (Contributor, Author) commented Feb 19, 2024

cc @SunMarc @muellerzr

I don't mean to be pushy, but the tests for the models in this change are currently broken at main/HEAD, so I'd be grateful if you could give this a look in the next couple of days. Thanks!

I added PR #29118 to skip the failing tests, so we'd have more time to discuss this change. Feel free to comment/review whenever you get a chance. Thanks and greatly appreciate your input on this!

For more context, the tie_weights() I'm adding here should enable loading those models with low_cpu_mem_usage=True (it's currently unsupported for those models).

The tie_weights() should also be helpful in getting accelerate to work on more models, since we need it to properly infer the device map.

Cheers.

@ArthurZucker (Collaborator) left a comment


Sorry, I might have missed a few things, but:

  • this feels like a big regression: having to write and add def _tie_weights(self): for all of these models is a bit strange to me
  • are we not changing the API by supporting bias tying? If so, I am against it. That should only be done manually by the user!
  • could you sum up in 1 line what was wrong with the previous behaviour and why we need to tie biases for these models?
  • copied from should be used

Comment on lines 931 to +922
self.predictions.decoder = new_embeddings
self.predictions.bias = new_embeddings.bias
Collaborator:

That is a bit strange; we should only set the embedding, not the bias, here.

@hackyon (Contributor, Author):

Yea, I was caught a bit off guard here too.

The reason for this is resize_token_embeddings(), which uses these setters to set the new embeddings. We need to copy the new bias from the embeddings since the size may have changed.

Collaborator:

Indeed. I just never saw one that had a bias, but you are right, it's entirely possible.

@hackyon (Contributor, Author) commented Feb 20, 2024

Sorry, I added the wrong link in the PR description; this is a follow-up to #28948. There's context in that link (tl;dr: adding the tie_weights() enables those models to be loaded with low_cpu_mem_usage=True).

We're adding new functionality with these tie_weights(). We're basically adding support for low_cpu_mem_usage=True for all these models.

The functionality kind of snowballed out of another unrelated change for SDPA (#28802), since the common test for SDPA uses low_cpu_mem_usage. I looked into it and figured it would be a good idea to add support for low_cpu_mem_usage to a bunch of models while I was at it.

  • this feels like a big regression: having to write and add the def _tie_weights(self): for all of these model is a bit strange to me
  • are we not changing the API by supporting bias tie? If yes I am against it. That should only be done manually by the user!

The weights were already tied in __init__ (with self.decoder.bias = self.bias), but they get untied when loading through low_cpu_mem_usage and need to be re-tied.

If we are not loading with low_cpu_mem_usage, those biases will already be tied. If we save with save_pretrained, only one copy of those biases will be saved.

  • could you sum up in 1 line what was wrong with the previous behaviour and why we need to tie biases for these models?

Those models fail to load with low_cpu_mem_usage=True (there's a minimal repro sketch at the end of this comment).

  • copied from should be used

Good point. A lot of these prediction heads were more or less copied from other existing heads, but I'm not sure why they were not marked Copied-from. I'll see if I can add back some fix-copies annotations.
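
Here's a minimal repro sketch of the failure mode (the checkpoint name is just illustrative, not necessarily one of the models touched by this PR; low_cpu_mem_usage needs accelerate installed):

from transformers import AutoModelForMaskedLM

# Before this change, models whose heads tie decoder.bias to a separate bias param could
# end up with parameters stuck on the meta device (or error out) when loaded this way.
model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased", low_cpu_mem_usage=True)

# With _tie_weights() in place, no parameter should remain on the meta device.
assert all(p.device.type != "meta" for p in model.parameters())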

@ydshieh (Collaborator) commented Feb 20, 2024

Hi @hackyon, I haven't followed your work (btw, thank you for the contribution!), but I'm just jumping in to let you know:

if you put the prefix [test all] in a commit message, that commit will trigger a full CI run.

For example, a commit message like [test all] check my commit is perfect.

This way, you have an easy way to check if the changes are all good.

(Our test fetcher is a trade-off between coverage and speed, but we will try to improve it.)

@ArthurZucker (Collaborator):

Thanks for the details!
Alright, let's make sure we take everything into account: safe serialization and unsafe (bin) as well.
You can also use Copied from on single functions; the idea is to have the fix in a single place and the rest is just a copy of the fix. Same for the test: Copied from can be used for the new test.
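
(For reference, the "Copied from" marker is the comment that make fix-copies checks; an illustrative usage, with made-up class names, looks like the following.)

import torch.nn as nn

# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->SomeModel
class SomeModelLMPredictionHead(nn.Module):
    ...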

@hackyon hackyon force-pushed the fix-more-tied-weights branch from b7216df to 315246f on February 23, 2024 14:18
@hackyon hackyon changed the title from "Adding _tie_weights() to more models" to "Adding _tie_weights() to prediction heads to support low_cpu_mem_usage=True" on Feb 23, 2024
@hackyon hackyon force-pushed the fix-more-tied-weights branch from 315246f to b51bd11 on February 23, 2024 14:26
@hackyon (Contributor, Author) commented Feb 23, 2024

The latest commit should be based off of main/HEAD and has all the necessary changes to tie_weights().

It includes additional tests for safetensors and for checkpoint bins, plus more checks to ensure that the models are equivalent after loading. I added @slow to them since there are a lot more checks now and the whole thing takes longer to run.

I checked that there are already Copied-from annotations on many of the LM heads, and the ones that don't have them tend to have a good reason not to.

@hackyon hackyon force-pushed the fix-more-tied-weights branch from b51bd11 to a4e2492 on February 23, 2024 14:29
@SunMarc (Member) left a comment


Thanks for your PR! I left one comment. Could you have a look @younesbelkada, since you did this PR, which is very similar? Thanks for adding the tests too.

tests/test_modeling_common.py (comment resolved)
@hackyon hackyon force-pushed the fix-more-tied-weights branch from 044e30c to c252ce1 on February 28, 2024 19:09
@hackyon (Contributor, Author) commented Mar 6, 2024

@SunMarc @younesbelkada

Any other insights or concerns over the use of tie_weights() here? Thanks!

@hackyon hackyon force-pushed the fix-more-tied-weights branch from c252ce1 to e6a1e76 on March 7, 2024 14:56
@hackyon (Contributor, Author) commented Mar 19, 2024

@SunMarc @younesbelkada

I just synced with the latest HEAD. Would you mind taking a quick look? Thanks!

@SunMarc (Member) left a comment


The modifications on the models look good to me! The only thing I'm concerned about is the tests. Do we really want to add tests for low_cpu_mem_usage=True, since:

  • we would have to skip the tests for models where they don't pass
  • they take a lot of time to run
  • these tests are kind of already covered by the accelerate_tests (test_disk_offload_bin, test_disk_offload_safetensors, test_cpu_offload, test_model_parallelism) if we make the model compatible with device_map

If we manage to define _no_split_modules (should be pretty easy to make it work), low_cpu_mem_usage=True is the same as doing device_map={"": "cpu"} or device_map="cpu", and it would also work for multi-GPU, CPU, and disk offload.
Contrary to device_map, where the model needs _no_split_modules to be defined, the user can freely use the low_cpu_mem_usage argument. So we should definitely test this somehow, or warn users that it might not work properly. LMK what you think @hackyon @ArthurZucker!
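
To illustrate the equivalence being described (checkpoint name is illustrative; both calls assume accelerate is installed for the device_map path):

from transformers import AutoModel

# Stream weights in without a full random init first.
model_a = AutoModel.from_pretrained("bert-base-uncased", low_cpu_mem_usage=True)

# Roughly the same placement on CPU, but via accelerate's device_map machinery,
# which additionally needs the model class to define _no_split_modules.
model_b = AutoModel.from_pretrained("bert-base-uncased", device_map={"": "cpu"})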

tests/test_modeling_common.py (comment resolved)
@hackyon hackyon force-pushed the fix-more-tied-weights branch from 25049e3 to 52695f4 on April 16, 2024 17:29
@hackyon (Contributor, Author) commented Apr 16, 2024

Cool, thanks for your review @SunMarc!

@ArthurZucker - would you mind taking a quick look again? Thank you 🙏

The modifications on the models look good to me! The only thing I'm concerned about is the tests. Do we really want to add tests for low_cpu_mem_usage=True, since:

  • we would have to skip the tests for models where they don't pass
  • they take a lot of time to run
  • these tests are kind of already covered by the accelerate_tests (test_disk_offload_bin, test_disk_offload_safetensors, test_cpu_offload, test_model_parallelism) if we make the model compatible with device_map

Good point. I believe that just test_save_load_low_cpu_mem_usage would probably be sufficient, but I added the other tests per the suggestion in one of the review comments. The tests only run under RUN_SLOW, but if they are redundant, then maybe they're still not worth adding.

I'll leave it up to @ArthurZucker to decide.

If we manage to define _no_split_modules (should be pretty easy to make it work), low_cpu_mem_usage=True is the same as doing device_map={"": "cpu"} or device_map="cpu", and it would also work for multi-GPU, CPU, and disk offload. Contrary to device_map, where the model needs _no_split_modules to be defined, the user can freely use the low_cpu_mem_usage argument. So we should definitely test this somehow, or warn users that it might not work properly. LMK what you think @hackyon @ArthurZucker!

I believe low_cpu_mem_usage should work even without no_split_modules, no? AFAICT the low_cpu_mem_usage flow doesn't do any model splitting; it just tries to load the model with the meta device and keep only one copy of the weights in memory at a time.

Yea, I wouldn't mind helping to add more no_split_modules if that's something you think is relatively high priority. From what I can tell, it's just a matter of checking for residual connections and making sure those are not split? I get the sense that there's a bit of a backlog of PRs recently (and I understand some of you might be overworked!), so I wouldn't want to push on new initiatives unnecessarily.

@SunMarc (Member) commented Apr 17, 2024

Hi @hackyon, thanks for still working on this PR!

Yea, I wouldn't mind helping to add more no_split_modules if that's something you think is relatively high priority. From what I can tell, it's just a matter of checking for residual connections and making sure those are not split? I get the sense that there's a bit of a backlog of PRs recently (and I understand some of you might be overworked!), so I wouldn't want to push on new initiatives unnecessarily.

I think it would be great to add no_split_modules to the most used models on the Hub! Whenever a new model comes out, we try to make it compatible with device_map, but for older models we might have missed some of them. I can definitely review your PR if you do so! Related issue.

@hackyon (Contributor, Author) commented Apr 22, 2024

@SunMarc Nice! Sounds like there is some activity in #29786.

Do you have permissions to merge into main? If not, do you know who else might? I have some bandwidth to work on it over the next 2-3 weeks, but I'd like to get this current PR in first before starting anything else (so far I'm not getting responses on my PRs, so there's no point in contributing more at this stage).

@SunMarc (Member) commented Apr 22, 2024

Hi @hackyon, I can indeed merge, but the code needs to be validated by a core maintainer. Could you have a look, @ArthurZucker? Thanks again @hackyon for your patience.

@SunMarc SunMarc requested a review from ArthurZucker April 23, 2024 09:04
@ArthurZucker (Collaborator):

Let me have a look!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ArthurZucker (Collaborator) left a comment


I am sorry for the late review.
Really great work 🔥 It's clean and the test is well designed!

Kudos to you @hackyon


model_to_save.save_pretrained(saved_model_path, safe_serialization=False)
self._check_save_load_low_cpu_mem_usage(model_class, saved_model_path)

def _check_save_load_low_cpu_mem_usage(self, model_class, saved_model_path):
Collaborator:

Thanks for designing this thorough test! 🤗

@ArthurZucker ArthurZucker merged commit 54a2361 into huggingface:main May 7, 2024
21 checks passed
@hackyon (Contributor, Author) commented May 7, 2024

Thank you for taking a look and merging! 🙏

@hackyon hackyon deleted the fix-more-tied-weights branch May 8, 2024 12:36
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request May 10, 2024
Adding _tie_weights() to prediction heads to support low_cpu_mem_usage=True (huggingface#29024)

* Adding _tie_weights() to prediction heads to support low_cpu_mem_usage=True

* Testing for the non-safe-tensors case, since the default is safe-tensors already

* Running fixup/fix-copies

* Adding accelerate annotations to tests
itazap pushed a commit that referenced this pull request May 14, 2024
Adding _tie_weights() to prediction heads to support low_cpu_mem_usage=True (#29024)

* Adding _tie_weights() to prediction heads to support low_cpu_mem_usage=True

* Testing for the non-safe-tensors case, since the default is safe-tensors already

* Running fixup/fix-copies

* Adding accelerate annotations to tests