Adding _tie_weights() to prediction heads to support low_cpu_mem_usage=True #29024
```diff
@@ -877,8 +877,12 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return prediction_scores
 
     def _tie_weights(self) -> None:
-        # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
-        self.bias = self.decoder.bias
+        # For accelerate compatibility and to not break backward compatibility
+        if self.decoder.bias.device.type == "meta":
+            self.decoder.bias = self.bias
+        else:
+            # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
+            self.bias = self.decoder.bias
 
 
 class AlbertSOPHead(nn.Module):
```
```diff
@@ -915,6 +919,7 @@ def get_output_embeddings(self) -> nn.Linear:
 
     def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
         self.predictions.decoder = new_embeddings
+        self.predictions.bias = new_embeddings.bias
```
Comment on lines 921 to +922:

That is a bit strange, we should only set the embedding here, not the bias.

Yea, I was caught a bit off guard here too. The reason for this is because of

Indeed. I just never saw one that had a bias, but you are right, it's entirely possible.
```diff
 
     def get_input_embeddings(self) -> nn.Embedding:
         return self.albert.embeddings.word_embeddings
```
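For context, the loading path this change targets looks roughly like the sketch below; the checkpoint name and the final check are illustrative assumptions, not taken from the PR itself:

```python
from transformers import AlbertForMaskedLM

# low_cpu_mem_usage=True builds the model on the meta device first and then
# materializes the weights from the checkpoint, which is the path that needs
# the _tie_weights() change above. "albert-base-v2" is just an example checkpoint.
model = AlbertForMaskedLM.from_pretrained("albert-base-v2", low_cpu_mem_usage=True)

# After loading, the tied bias should be materialized rather than left on meta.
assert model.predictions.decoder.bias.device.type != "meta"
```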
For my own understanding, could you provide a bit more information here about the two cases? In particular, my questions are:
I copied this from the existing tie_weights():

transformers/src/transformers/models/roberta/modeling_roberta.py (line 1141 at f3aa7db)

Here is more context on where that's added: #19906.
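For reference, the block being pointed to looks roughly like this (quoted from memory, so treat it as approximate rather than a verbatim copy of RobertaLMHead):

```python
# Approximate quote of the RobertaLMHead helper referenced above; see
# modeling_roberta.py at the linked commit for the exact lines.
def _tie_weights(self):
    # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
    self.bias = self.decoder.bias
```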
I agree that the existing implementation of tie_weights() doesn't make sense. However, I added the if-statement to keep backward compatibility. Here's my 2 cents on how this all works:
When loading the model, only one pointer out of the tied params would be loaded with the correct values (when device=meta). Let's call this the "canonical" pointer. All the other tied params must copy from the canonical pointer, and it can't really be the other way around (at least for the device=meta case).

The canonical pointer is the key stored in the loaded `state_dict`. In the current logic for `save_pretrained()`, the "canonical" pointer is the weight key that's not in the `_tied_weights_keys` list. In this case, that would make `self.bias` the canonical pointer. However, I imagine it's still possible for some (older?) pretrained model to have some other logic for choosing which is the "canonical" pointer, so I added the if-statement for backwards compatibility.
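To make the "canonical pointer" idea concrete, here is a hedged sketch of the pattern; the class name and key strings below are illustrative assumptions, not copied from modeling_albert.py:

```python
from transformers import PreTrainedModel

# Hedged sketch: keys listed in _tied_weights_keys are the tied duplicates.
# The key that is NOT listed (e.g. "predictions.bias") is the "canonical"
# pointer that survives in the saved checkpoint, so _tie_weights() has to
# restore the listed duplicates after loading.
class SomeMaskedLMModel(PreTrainedModel):
    _tied_weights_keys = ["predictions.decoder.weight", "predictions.decoder.bias"]
```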
One note about the case where device != meta: when that happens, the direction of the pointer assignment doesn't matter as much. Whether you call `self.bias = self.decoder.bias` or `self.decoder.bias = self.bias`, when the params are loaded they show up in both, since the two names almost act like a pointer reference.

When device == meta, the assignment doesn't have the same pointer-reference relationship; it just copies the sizes/info of the other tensor. When one of those tensors is loaded (the canonical one), the other one needs to copy the pointer reference accordingly; it is not automatically loaded the same way.
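A minimal, self-contained sketch of that scenario (the toy class and the manual "loading" step are illustrative assumptions, not the actual Albert modules or the real from_pretrained path):

```python
import torch
from torch import nn


class TinyHead(nn.Module):
    """Toy stand-in for a prediction head with a bias tied to its decoder."""

    def __init__(self):
        super().__init__()
        self.decoder = nn.Linear(4, 10)
        self.bias = nn.Parameter(torch.zeros(10))
        self.decoder.bias = self.bias  # tied at init

    def _tie_weights(self):
        if self.decoder.bias.device.type == "meta":
            # decoder.bias was never materialized: re-point it at the
            # canonical, already-loaded parameter.
            self.decoder.bias = self.bias
        else:
            self.bias = self.decoder.bias


# With low_cpu_mem_usage=True the module is first built on the meta device
# (requires PyTorch 2.x for the device context manager).
with torch.device("meta"):
    head = TinyHead()

# Pretend the loader materialized only the canonical key ("bias") ...
head.bias = nn.Parameter(torch.zeros(10))
# ... so the tied duplicate is still a storage-less meta tensor:
assert head.decoder.bias.device.type == "meta"

head._tie_weights()  # hits the meta branch and re-ties the two names
assert head.decoder.bias is head.bias
assert head.decoder.bias.device.type == "cpu"
```

On CPU (the else branch), the same call is effectively a no-op in this toy because both names already point at one tensor, which matches the "direction doesn't matter" behavior described above.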