Adding _tie_weights() to more models
I had to run the following command to see all the failures from the newly
added test:

pytest -k "test_save_load_low_cpu_mem_usage" tests/

The test also took a while. Going forward, we should probably ask people to
run this command when they modify a common test.
hackyon committed Feb 14, 2024
1 parent 5f06053 commit 63a6230
Showing 12 changed files with 78 additions and 20 deletions.
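Every file below applies the same two-part pattern: `set_output_embeddings` now also re-points the prediction head's standalone `bias` at the new embeddings' bias, and `_tie_weights` re-ties the standalone bias and the decoder bias in whichever direction is safe (if the decoder bias is still on the `meta` device, e.g. during accelerate/low_cpu_mem_usage loading, the materialized `bias` is assigned to `decoder.bias`; otherwise the old behavior is kept). The sketch below is a minimal, model-agnostic illustration of that pattern, assuming a generic prediction head; the class and attribute names here are hypothetical and not taken verbatim from any single model in this diff.

import torch
from torch import nn


class ToyLMPredictionHead(nn.Module):
    """Hypothetical prediction head mirroring the _tie_weights pattern in this commit."""

    def __init__(self, hidden_size: int, vocab_size: int):
        super().__init__()
        self.decoder = nn.Linear(hidden_size, vocab_size)
        self.bias = nn.Parameter(torch.zeros(vocab_size))
        # Link the two so the bias is resized together with `resize_token_embeddings`.
        self.decoder.bias = self.bias

    def _tie_weights(self) -> None:
        # For accelerate compatibility: if the decoder bias still lives on the
        # "meta" device (low_cpu_mem_usage loading), point it at the materialized
        # bias; otherwise keep the old behavior and re-tie the standalone bias to
        # the decoder's (TPU / bias-resize case).
        if self.decoder.bias.device.type == "meta":
            self.decoder.bias = self.bias
        else:
            self.bias = self.decoder.bias

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.decoder(hidden_states)


class ToyModelWithLMHead(nn.Module):
    """Hypothetical wrapper showing the matching set_output_embeddings change."""

    def __init__(self, hidden_size: int, vocab_size: int):
        super().__init__()
        self.lm_head = ToyLMPredictionHead(hidden_size, vocab_size)

    def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
        self.lm_head.decoder = new_embeddings
        # Keep the standalone bias in sync with the new decoder's bias.
        self.lm_head.bias = new_embeddings.bias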
9 changes: 7 additions & 2 deletions src/transformers/models/albert/modeling_albert.py
@@ -887,8 +887,12 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return prediction_scores
 
     def _tie_weights(self) -> None:
-        # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
-        self.bias = self.decoder.bias
+        # For accelerate compatibility and to not break backward compatibility
+        if self.decoder.bias.device.type == "meta":
+            self.decoder.bias = self.bias
+        else:
+            # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
+            self.bias = self.decoder.bias
 
 
 class AlbertSOPHead(nn.Module):
@@ -925,6 +929,7 @@ def get_output_embeddings(self) -> nn.Linear:
 
     def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
         self.predictions.decoder = new_embeddings
+        self.predictions.bias = new_embeddings.bias
 
     def get_input_embeddings(self) -> nn.Embedding:
         return self.albert.embeddings.word_embeddings
8 changes: 6 additions & 2 deletions src/transformers/models/bert_generation/modeling_bert_generation.py
@@ -844,8 +844,11 @@ def forward(self, hidden_states):
         return logits
 
     def _tie_weights(self):
-        # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
-        self.bias = self.decoder.bias
+        # For accelerate compatibility and to not break backward compatibility
+        if self.decoder.bias.device.type == "meta":
+            self.decoder.bias = self.bias
+        else:
+            self.bias = self.decoder.bias
 
 
 @add_start_docstrings(
@@ -872,6 +875,7 @@ def get_output_embeddings(self):
 
     def set_output_embeddings(self, new_embeddings):
         self.lm_head.decoder = new_embeddings
+        self.lm_head.bias = new_embeddings.bias
 
     @add_start_docstrings_to_model_forward(BERT_GENERATION_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
4 changes: 4 additions & 0 deletions src/transformers/models/deberta/modeling_deberta.py
@@ -1028,6 +1028,7 @@ def get_output_embeddings(self):
 
     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias
 
     @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
@@ -1124,6 +1125,9 @@ def __init__(self, config):
         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias
 
+    def _tie_weights(self):
+        self.decoder.bias = self.bias
+
     def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.decoder(hidden_states)
4 changes: 4 additions & 0 deletions src/transformers/models/deberta_v2/modeling_deberta_v2.py
@@ -1124,6 +1124,7 @@ def get_output_embeddings(self):
 
     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias
 
     @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
@@ -1221,6 +1222,9 @@ def __init__(self, config):
         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias
 
+    def _tie_weights(self):
+        self.decoder.bias = self.bias
+
     def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.decoder(hidden_states)
12 changes: 9 additions & 3 deletions src/transformers/models/fnet/modeling_fnet.py
@@ -358,9 +358,13 @@ def forward(self, hidden_states):
         hidden_states = self.decoder(hidden_states)
         return hidden_states
 
-    def _tie_weights(self):
-        # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
-        self.bias = self.decoder.bias
+    def _tie_weights(self) -> None:
+        # For accelerate compatibility and to not break backward compatibility
+        if self.decoder.bias.device.type == "meta":
+            self.decoder.bias = self.bias
+        else:
+            # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
+            self.bias = self.decoder.bias
 
 
 class FNetOnlyMLMHead(nn.Module):
@@ -627,6 +631,7 @@ def get_output_embeddings(self):
 
     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias
 
     @add_start_docstrings_to_model_forward(FNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=FNetForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
@@ -721,6 +726,7 @@ def get_output_embeddings(self):
 
     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias
 
     @add_start_docstrings_to_model_forward(FNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
11 changes: 8 additions & 3 deletions src/transformers/models/ibert/modeling_ibert.py
@@ -871,6 +871,7 @@ def get_output_embeddings(self):
 
     def set_output_embeddings(self, new_embeddings):
         self.lm_head.decoder = new_embeddings
+        self.lm_head.bias = new_embeddings.bias
 
     @add_start_docstrings_to_model_forward(IBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
@@ -955,9 +956,13 @@ def forward(self, features, **kwargs):
 
         return x
 
-    def _tie_weights(self):
-        # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
-        self.bias = self.decoder.bias
+    def _tie_weights(self) -> None:
+        # For accelerate compatibility and to not break backward compatibility
+        if self.decoder.bias.device.type == "meta":
+            self.decoder.bias = self.bias
+        else:
+            # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
+            self.bias = self.decoder.bias
 
 
 @add_start_docstrings(
13 changes: 9 additions & 4 deletions src/transformers/models/mobilebert/modeling_mobilebert.py
@@ -649,6 +649,9 @@ def __init__(self, config):
         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias
 
+    def _tie_weights(self) -> None:
+        self.decoder.bias = self.bias
+
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.transform(hidden_states)
         hidden_states = hidden_states.matmul(torch.cat([self.decoder.weight.t(), self.dense.weight], dim=0))
@@ -938,8 +941,9 @@ def __init__(self, config):
     def get_output_embeddings(self):
         return self.cls.predictions.decoder
 
-    def set_output_embeddings(self, new_embeddigs):
-        self.cls.predictions.decoder = new_embeddigs
+    def set_output_embeddings(self, new_embeddings):
+        self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias
 
     def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> nn.Embedding:
         # resize dense output embedings at first
@@ -1047,8 +1051,9 @@ def __init__(self, config):
     def get_output_embeddings(self):
         return self.cls.predictions.decoder
 
-    def set_output_embeddings(self, new_embeddigs):
-        self.cls.predictions.decoder = new_embeddigs
+    def set_output_embeddings(self, new_embeddings):
+        self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias
 
     def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> nn.Embedding:
         # resize dense output embedings at first
4 changes: 4 additions & 0 deletions src/transformers/models/realm/modeling_realm.py
@@ -799,6 +799,9 @@ def __init__(self, config):
         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias
 
+    def _tie_weights(self):
+        self.decoder.bias = self.bias
+
     def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.decoder(hidden_states)
@@ -1393,6 +1396,7 @@ def get_output_embeddings(self):
 
     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias
 
     @add_start_docstrings_to_model_forward(
         REALM_INPUTS_DOCSTRING.format("batch_size, num_candidates, sequence_length")
12 changes: 9 additions & 3 deletions src/transformers/models/reformer/modeling_reformer.py
@@ -1771,9 +1771,13 @@ def forward_chunk(self, hidden_states):
         hidden_states = self.decoder(hidden_states)
         return hidden_states
 
-    def _tie_weights(self):
-        # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
-        self.bias = self.decoder.bias
+    def _tie_weights(self) -> None:
+        # For accelerate compatibility and to not break backward compatibility
+        if self.decoder.bias.device.type == "meta":
+            self.decoder.bias = self.bias
+        else:
+            # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
+            self.bias = self.decoder.bias
 
 
 class ReformerPreTrainedModel(PreTrainedModel):
@@ -2211,6 +2215,7 @@ def get_output_embeddings(self):
 
     def set_output_embeddings(self, new_embeddings):
         self.lm_head.decoder = new_embeddings
+        self.lm_head.bias = new_embeddings.bias
 
     @add_start_docstrings_to_model_forward(REFORMER_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
@@ -2331,6 +2336,7 @@ def get_output_embeddings(self):
 
     def set_output_embeddings(self, new_embeddings):
         self.lm_head.decoder = new_embeddings
+        self.lm_head.bias = new_embeddings.bias
 
     @add_start_docstrings_to_model_forward(REFORMER_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
5 changes: 5 additions & 0 deletions src/transformers/models/roformer/modeling_roformer.py
@@ -664,6 +664,9 @@ def __init__(self, config):
         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias
 
+    def _tie_weights(self) -> None:
+        self.decoder.bias = self.bias
+
     def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.decoder(hidden_states)
@@ -961,6 +964,7 @@ def get_output_embeddings(self):
 
     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias
 
     @add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
@@ -1060,6 +1064,7 @@ def get_output_embeddings(self):
 
     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias
 
     @add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
4 changes: 4 additions & 0 deletions src/transformers/models/squeezebert/modeling_squeezebert.py
@@ -403,6 +403,9 @@ def __init__(self, config):
         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias
 
+    def _tie_weights(self) -> None:
+        self.decoder.bias = self.bias
+
     def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.decoder(hidden_states)
@@ -661,6 +664,7 @@ def get_output_embeddings(self):
 
     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias
 
     @add_start_docstrings_to_model_forward(SQUEEZEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
12 changes: 9 additions & 3 deletions src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py
@@ -855,6 +855,7 @@ def get_output_embeddings(self):
 
     def set_output_embeddings(self, new_embeddings):
         self.lm_head.decoder = new_embeddings
+        self.lm_head.bias = new_embeddings.bias
 
     @add_start_docstrings_to_model_forward(XLM_ROBERTA_XL_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
@@ -1014,6 +1015,7 @@ def get_output_embeddings(self):
 
     def set_output_embeddings(self, new_embeddings):
         self.lm_head.decoder = new_embeddings
+        self.lm_head.bias = new_embeddings.bias
 
     @add_start_docstrings_to_model_forward(XLM_ROBERTA_XL_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
@@ -1102,9 +1104,13 @@ def forward(self, features, **kwargs):
 
         return x
 
-    def _tie_weights(self):
-        # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
-        self.bias = self.decoder.bias
+    def _tie_weights(self) -> None:
+        # For accelerate compatibility and to not break backward compatibility
+        if self.decoder.bias.device.type == "meta":
+            self.decoder.bias = self.bias
+        else:
+            # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
+            self.bias = self.decoder.bias
 
 
 @add_start_docstrings(
