Commit

Adding _tie_weights() to prediction heads to support low_cpu_mem_usage=True (#29024)

* Adding _tie_weights() to prediction heads to support low_cpu_mem_usage=True

* Testing for the non-safe-tensors case, since the default is safe-tensors already

* Running fixup/fix-copies

* Adding accelerate annotations to tests
hackyon authored and Ita Zaporozhets committed May 14, 2024
1 parent 5c9fbee commit 8d8bd8f
Showing 42 changed files with 366 additions and 20 deletions.
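
The gist of the change: with low_cpu_mem_usage=True the model skeleton is first created on the meta device, and the prediction-head bias only becomes a real tensor once the checkpoint is loaded, so each head's _tie_weights() now points decoder.bias at the materialized bias parameter instead of the other way around. Below is a minimal sketch of the path this exercises (it assumes accelerate is installed and uses a tiny randomly initialized BERT plus the legacy non-safetensors format mentioned in the commit message; the config values and temporary directory are illustrative, not taken from the actual tests):

import tempfile

from transformers import BertConfig, BertForMaskedLM

# Tiny, randomly initialized model so the sketch stays self-contained.
config = BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64)
model = BertForMaskedLM(config)

with tempfile.TemporaryDirectory() as checkpoint_dir:
    # Force the legacy .bin format, since safetensors is already the default path.
    model.save_pretrained(checkpoint_dir, safe_serialization=False)

    # low_cpu_mem_usage=True builds the model on the "meta" device first and then
    # loads the real weights into it; this is where _tie_weights() comes into play.
    reloaded = BertForMaskedLM.from_pretrained(checkpoint_dir, low_cpu_mem_usage=True)

head = reloaded.cls.predictions
assert head.decoder.bias is head.bias    # bias and decoder bias are one shared parameter
assert head.bias.device.type != "meta"   # and it was actually materialized
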
9 changes: 7 additions & 2 deletions src/transformers/models/albert/modeling_albert.py
@@ -877,8 +877,12 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return prediction_scores

    def _tie_weights(self) -> None:
-        # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
-        self.bias = self.decoder.bias
+        # For accelerate compatibility and to not break backward compatibility
+        if self.decoder.bias.device.type == "meta":
+            self.decoder.bias = self.bias
+        else:
+            # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
+            self.bias = self.decoder.bias


class AlbertSOPHead(nn.Module):
@@ -915,6 +919,7 @@ def get_output_embeddings(self) -> nn.Linear:

def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
self.predictions.decoder = new_embeddings
self.predictions.bias = new_embeddings.bias

def get_input_embeddings(self) -> nn.Embedding:
return self.albert.embeddings.word_embeddings
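
As an aside on the guarded tie in the _tie_weights hunk above: parameters created on the meta device have no storage, so nothing can be copied into them; the head has to re-point decoder.bias at the loaded parameter instead. A standalone PyTorch illustration of that check (plain torch, not code from this commit; the layer sizes are arbitrary):

import torch
from torch import nn

# Build a layer on the meta device, the way low_cpu_mem_usage=True builds the model skeleton.
with torch.device("meta"):
    decoder = nn.Linear(8, 32)

# A materialized bias, standing in for the parameter loaded from the checkpoint.
bias = nn.Parameter(torch.zeros(32))

if decoder.bias.device.type == "meta":
    decoder.bias = bias      # re-point the meta parameter at the real one
else:
    bias = decoder.bias      # normal case: keep both names on the same parameter

assert decoder.bias is bias and decoder.bias.device.type != "meta"
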
6 changes: 6 additions & 0 deletions src/transformers/models/bert/modeling_bert.py
@@ -778,6 +778,9 @@ def __init__(self, config):
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias

def _tie_weights(self):
self.decoder.bias = self.bias

def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
@@ -1186,6 +1189,7 @@ def get_output_embeddings(self):

def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
self.cls.predictions.bias = new_embeddings.bias

@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=BertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
@@ -1295,6 +1299,7 @@ def get_output_embeddings(self):

def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
self.cls.predictions.bias = new_embeddings.bias

@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
@@ -1448,6 +1453,7 @@ def get_output_embeddings(self):

def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
self.cls.predictions.bias = new_embeddings.bias

@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
src/transformers/models/bert_generation/modeling_bert_generation.py
@@ -851,8 +851,12 @@ def forward(self, hidden_states):
return logits

    def _tie_weights(self):
-        # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
-        self.bias = self.decoder.bias
+        # For accelerate compatibility and to not break backward compatibility
+        if self.decoder.bias.device.type == "meta":
+            self.decoder.bias = self.bias
+        else:
+            # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
+            self.bias = self.decoder.bias


@add_start_docstrings(
@@ -879,6 +883,7 @@ def get_output_embeddings(self):

def set_output_embeddings(self, new_embeddings):
self.lm_head.decoder = new_embeddings
self.lm_head.bias = new_embeddings.bias

@add_start_docstrings_to_model_forward(BERT_GENERATION_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
6 changes: 6 additions & 0 deletions src/transformers/models/big_bird/modeling_big_bird.py
@@ -1704,6 +1704,9 @@ def __init__(self, config):
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias

def _tie_weights(self):
self.decoder.bias = self.bias

def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
@@ -2263,6 +2266,7 @@ def get_output_embeddings(self):

def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
self.cls.predictions.bias = new_embeddings.bias

@add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=BigBirdForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
@@ -2375,6 +2379,7 @@ def get_output_embeddings(self):

def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
self.cls.predictions.bias = new_embeddings.bias

@add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
@@ -2516,6 +2521,7 @@ def get_output_embeddings(self):

def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
self.cls.predictions.bias = new_embeddings.bias

@add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
4 changes: 4 additions & 0 deletions src/transformers/models/blip/modeling_blip_text.py
@@ -523,6 +523,9 @@ def __init__(self, config):
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias

def _tie_weights(self):
self.decoder.bias = self.bias

def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
@@ -818,6 +821,7 @@ def get_output_embeddings(self):

def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
self.cls.predictions.bias = new_embeddings.bias

def forward(
self,
4 changes: 4 additions & 0 deletions src/transformers/models/deberta/modeling_deberta.py
@@ -1021,6 +1021,7 @@ def get_output_embeddings(self):

def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
self.cls.predictions.bias = new_embeddings.bias

@add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
@@ -1117,6 +1118,9 @@ def __init__(self, config):
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias

def _tie_weights(self):
self.decoder.bias = self.bias

def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
4 changes: 4 additions & 0 deletions src/transformers/models/deberta_v2/modeling_deberta_v2.py
@@ -1120,6 +1120,7 @@ def get_output_embeddings(self):

def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
self.cls.predictions.bias = new_embeddings.bias

@add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
Expand Down Expand Up @@ -1217,6 +1218,9 @@ def __init__(self, config):
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias

def _tie_weights(self):
self.decoder.bias = self.bias

def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
6 changes: 6 additions & 0 deletions src/transformers/models/ernie/modeling_ernie.py
@@ -603,6 +603,9 @@ def __init__(self, config):
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias

def _tie_weights(self):
self.decoder.bias = self.bias

def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
@@ -990,6 +993,7 @@ def get_output_embeddings(self):
# Copied from transformers.models.bert.modeling_bert.BertForPreTraining.set_output_embeddings
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
self.cls.predictions.bias = new_embeddings.bias

@add_start_docstrings_to_model_forward(ERNIE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=ErnieForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
@@ -1104,6 +1108,7 @@ def get_output_embeddings(self):
# Copied from transformers.models.bert.modeling_bert.BertLMHeadModel.set_output_embeddings
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
self.cls.predictions.bias = new_embeddings.bias

@add_start_docstrings_to_model_forward(ERNIE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
@@ -1264,6 +1269,7 @@ def get_output_embeddings(self):
# Copied from transformers.models.bert.modeling_bert.BertForMaskedLM.set_output_embeddings
def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
self.cls.predictions.bias = new_embeddings.bias

@add_start_docstrings_to_model_forward(ERNIE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
3 changes: 3 additions & 0 deletions src/transformers/models/flava/modeling_flava.py
@@ -1661,6 +1661,9 @@ def __init__(self, config, weight=None):
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias

def _tie_weights(self):
self.decoder.bias = self.bias

def forward(self, x):
x = self.transform(x)
x = self.decoder(x)
12 changes: 9 additions & 3 deletions src/transformers/models/fnet/modeling_fnet.py
@@ -355,9 +355,13 @@ def forward(self, hidden_states):
hidden_states = self.decoder(hidden_states)
return hidden_states

-    def _tie_weights(self):
-        # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
-        self.bias = self.decoder.bias
+    def _tie_weights(self) -> None:
+        # For accelerate compatibility and to not break backward compatibility
+        if self.decoder.bias.device.type == "meta":
+            self.decoder.bias = self.bias
+        else:
+            # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
+            self.bias = self.decoder.bias


class FNetOnlyMLMHead(nn.Module):
@@ -624,6 +628,7 @@ def get_output_embeddings(self):

def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
self.cls.predictions.bias = new_embeddings.bias

@add_start_docstrings_to_model_forward(FNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=FNetForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
@@ -718,6 +723,7 @@ def get_output_embeddings(self):

def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
self.cls.predictions.bias = new_embeddings.bias

@add_start_docstrings_to_model_forward(FNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
3 changes: 3 additions & 0 deletions src/transformers/models/fsmt/modeling_fsmt.py
@@ -692,6 +692,9 @@ def __init__(self, config: FSMTConfig, embed_tokens: nn.Embedding):
self.output_projection = nn.Linear(embed_tokens_weight_shape[1], embed_tokens_weight_shape[0], bias=False)
self.output_projection.weight = self.embed_tokens.weight

def _tie_weights(self):
self.embed_tokens.weight = self.output_projection.weight

def forward(
self,
input_ids: torch.Tensor,
11 changes: 8 additions & 3 deletions src/transformers/models/ibert/modeling_ibert.py
@@ -868,6 +868,7 @@ def get_output_embeddings(self):

def set_output_embeddings(self, new_embeddings):
self.lm_head.decoder = new_embeddings
self.lm_head.bias = new_embeddings.bias

@add_start_docstrings_to_model_forward(IBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
@@ -952,9 +953,13 @@ def forward(self, features, **kwargs):

return x

-    def _tie_weights(self):
-        # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
-        self.bias = self.decoder.bias
+    def _tie_weights(self) -> None:
+        # For accelerate compatibility and to not break backward compatibility
+        if self.decoder.bias.device.type == "meta":
+            self.decoder.bias = self.bias
+        else:
+            # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
+            self.bias = self.decoder.bias


@add_start_docstrings(
4 changes: 4 additions & 0 deletions src/transformers/models/layoutlm/modeling_layoutlm.py
@@ -594,6 +594,9 @@ def __init__(self, config):
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias

def _tie_weights(self):
self.decoder.bias = self.bias

def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
@@ -873,6 +876,7 @@ def get_output_embeddings(self):

def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
self.cls.predictions.bias = new_embeddings.bias

@add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
3 changes: 3 additions & 0 deletions src/transformers/models/markuplm/modeling_markuplm.py
@@ -316,6 +316,9 @@ def __init__(self, config):
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias

def _tie_weights(self):
self.decoder.bias = self.bias

def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
src/transformers/models/megatron_bert/modeling_megatron_bert.py
@@ -657,6 +657,9 @@ def __init__(self, config):
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias

def _tie_weights(self):
self.decoder.bias = self.bias

def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
@@ -1021,6 +1024,7 @@ def get_output_embeddings(self):

def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
self.cls.predictions.bias = new_embeddings.bias

@add_start_docstrings_to_model_forward(MEGATRON_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=MegatronBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
@@ -1130,6 +1134,7 @@ def get_output_embeddings(self):

def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
self.cls.predictions.bias = new_embeddings.bias

@add_start_docstrings_to_model_forward(MEGATRON_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
@@ -1288,6 +1293,7 @@ def get_output_embeddings(self):

def set_output_embeddings(self, new_embeddings):
self.cls.predictions.decoder = new_embeddings
self.cls.predictions.bias = new_embeddings.bias

@add_start_docstrings_to_model_forward(MEGATRON_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
13 changes: 9 additions & 4 deletions src/transformers/models/mobilebert/modeling_mobilebert.py
@@ -650,6 +650,9 @@ def __init__(self, config):
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias

def _tie_weights(self) -> None:
self.decoder.bias = self.bias

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.transform(hidden_states)
hidden_states = hidden_states.matmul(torch.cat([self.decoder.weight.t(), self.dense.weight], dim=0))
@@ -938,8 +941,9 @@ def __init__(self, config):
def get_output_embeddings(self):
return self.cls.predictions.decoder

-    def set_output_embeddings(self, new_embeddigs):
-        self.cls.predictions.decoder = new_embeddigs
+    def set_output_embeddings(self, new_embeddings):
+        self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> nn.Embedding:
# resize dense output embedings at first
@@ -1047,8 +1051,9 @@ def __init__(self, config):
def get_output_embeddings(self):
return self.cls.predictions.decoder

-    def set_output_embeddings(self, new_embeddigs):
-        self.cls.predictions.decoder = new_embeddigs
+    def set_output_embeddings(self, new_embeddings):
+        self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> nn.Embedding:
# resize dense output embedings at first
4 changes: 4 additions & 0 deletions src/transformers/models/mpnet/modeling_mpnet.py
@@ -584,6 +584,7 @@ def get_output_embeddings(self):

def set_output_embeddings(self, new_embeddings):
self.lm_head.decoder = new_embeddings
self.lm_head.bias = new_embeddings.bias

@add_start_docstrings_to_model_forward(MPNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
@@ -656,6 +657,9 @@ def __init__(self, config):
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias

def _tie_weights(self):
self.decoder.bias = self.bias

def forward(self, features, **kwargs):
x = self.dense(features)
x = gelu(x)