From 7d718e2a1d270a50aa52eba1bf07541d6ec9a2e9 Mon Sep 17 00:00:00 2001
From: "JB (Don)" <1557853+hackyon@users.noreply.github.com>
Date: Tue, 7 May 2024 17:12:21 +0800
Subject: [PATCH] Adding _tie_weights() to prediction heads to support low_cpu_mem_usage=True (#29024)

* Adding _tie_weights() to prediction heads to support low_cpu_mem_usage=True
* Testing for the non-safe-tensors case, since the default is safe-tensors already
* Running fixup/fix-copies
* Adding accelerate annotations to tests
---
 .../models/albert/modeling_albert.py          |  9 +-
 src/transformers/models/bert/modeling_bert.py |  6 ++
 .../modeling_bert_generation.py               |  9 +-
 .../models/big_bird/modeling_big_bird.py      |  6 ++
 .../models/blip/modeling_blip_text.py         |  4 +
 .../models/deberta/modeling_deberta.py        |  4 +
 .../models/deberta_v2/modeling_deberta_v2.py  |  4 +
 .../models/ernie/modeling_ernie.py            |  6 ++
 .../models/flava/modeling_flava.py            |  3 +
 src/transformers/models/fnet/modeling_fnet.py | 12 ++-
 src/transformers/models/fsmt/modeling_fsmt.py |  3 +
 .../models/ibert/modeling_ibert.py            | 11 ++-
 .../models/layoutlm/modeling_layoutlm.py      |  4 +
 .../models/markuplm/modeling_markuplm.py      |  3 +
 .../megatron_bert/modeling_megatron_bert.py   |  6 ++
 .../models/mobilebert/modeling_mobilebert.py  | 13 ++-
 .../models/mpnet/modeling_mpnet.py            |  4 +
 src/transformers/models/mra/modeling_mra.py   |  4 +
 .../models/nezha/modeling_nezha.py            |  5 ++
 .../nystromformer/modeling_nystromformer.py   |  4 +
 .../models/qdqbert/modeling_qdqbert.py        |  5 ++
 .../models/realm/modeling_realm.py            |  4 +
 .../models/reformer/modeling_reformer.py      | 12 ++-
 .../models/roc_bert/modeling_roc_bert.py      |  6 ++
 .../models/roformer/modeling_roformer.py      |  5 ++
 .../squeezebert/modeling_squeezebert.py       |  4 +
 .../models/tapas/modeling_tapas.py            |  4 +
 src/transformers/models/vilt/modeling_vilt.py |  4 +
 .../visual_bert/modeling_visual_bert.py       |  4 +
 .../xlm_roberta_xl/modeling_xlm_roberta_xl.py | 12 ++-
 src/transformers/models/yoso/modeling_yoso.py |  4 +
 .../test_modeling_deformable_detr.py          | 12 +++
 tests/models/deta/test_modeling_deta.py       | 12 +++
 tests/models/encodec/test_modeling_encodec.py | 12 +++
 tests/models/lxmert/test_modeling_lxmert.py   | 12 +++
 tests/models/marian/test_modeling_marian.py   | 12 +++
 .../models/musicgen/test_modeling_musicgen.py | 12 +++
 .../test_modeling_musicgen_melody.py          | 12 +++
 tests/models/sew/test_modeling_sew.py         | 12 +++
 tests/models/sew_d/test_modeling_sew_d.py     | 12 +++
 .../test_modeling_timm_backbone.py            | 12 +++
 tests/test_modeling_common.py                 | 82 +++++++++++++++++++
 42 files changed, 366 insertions(+), 20 deletions(-)

diff --git a/src/transformers/models/albert/modeling_albert.py b/src/transformers/models/albert/modeling_albert.py
index 87f5a9e30c8f54..ff50f2f1293e17 100755
--- a/src/transformers/models/albert/modeling_albert.py
+++ b/src/transformers/models/albert/modeling_albert.py
@@ -877,8 +877,12 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return prediction_scores
 
     def _tie_weights(self) -> None:
-        # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
-        self.bias = self.decoder.bias
+        # For accelerate compatibility and to not break backward compatibility
+        if self.decoder.bias.device.type == "meta":
+            self.decoder.bias = self.bias
+        else:
+            # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
+            self.bias = self.decoder.bias
 
 
 class AlbertSOPHead(nn.Module):
@@ -915,6 +919,7 @@ def get_output_embeddings(self) -> nn.Linear:
 
     def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
         self.predictions.decoder = new_embeddings
+        self.predictions.bias = new_embeddings.bias
 
     def get_input_embeddings(self) -> nn.Embedding:
         return self.albert.embeddings.word_embeddings
diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py
index 9e2847b11b53f0..129336cc5280c0 100755
--- a/src/transformers/models/bert/modeling_bert.py
+++ b/src/transformers/models/bert/modeling_bert.py
@@ -778,6 +778,9 @@ def __init__(self, config):
         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias
 
+    def _tie_weights(self):
+        self.decoder.bias = self.bias
+
     def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.decoder(hidden_states)
@@ -1186,6 +1189,7 @@ def get_output_embeddings(self):
 
     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias
 
     @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=BertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
@@ -1295,6 +1299,7 @@ def get_output_embeddings(self):
 
     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias
 
     @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
@@ -1448,6 +1453,7 @@ def get_output_embeddings(self):
 
     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias
 
     @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py
index 73c4d1d1e5da91..a5fb3d0531153e 100755
--- a/src/transformers/models/bert_generation/modeling_bert_generation.py
+++ b/src/transformers/models/bert_generation/modeling_bert_generation.py
@@ -851,8 +851,12 @@ def forward(self, hidden_states):
         return logits
 
     def _tie_weights(self):
-        # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
-        self.bias = self.decoder.bias
+        # For accelerate compatibility and to not break backward compatibility
+        if self.decoder.bias.device.type == "meta":
+            self.decoder.bias = self.bias
+        else:
+            # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
+            self.bias = self.decoder.bias
 
 
 @add_start_docstrings(
@@ -879,6 +883,7 @@ def get_output_embeddings(self):
 
     def set_output_embeddings(self, new_embeddings):
         self.lm_head.decoder = new_embeddings
+        self.lm_head.bias = new_embeddings.bias
 
     @add_start_docstrings_to_model_forward(BERT_GENERATION_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py
index 510c98079501ef..6e5363f0bc6e57 100755
--- a/src/transformers/models/big_bird/modeling_big_bird.py
+++ b/src/transformers/models/big_bird/modeling_big_bird.py
@@ -1704,6 +1704,9 @@ def __init__(self, config):
         # Need a link between the two variables so that the bias is correctly resized with
`resize_token_embeddings` self.decoder.bias = self.bias + def _tie_weights(self): + self.decoder.bias = self.bias + def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -2263,6 +2266,7 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=BigBirdForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) @@ -2375,6 +2379,7 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC) @@ -2516,6 +2521,7 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( diff --git a/src/transformers/models/blip/modeling_blip_text.py b/src/transformers/models/blip/modeling_blip_text.py index 3eb6ad45791030..a800ba89825dcb 100644 --- a/src/transformers/models/blip/modeling_blip_text.py +++ b/src/transformers/models/blip/modeling_blip_text.py @@ -523,6 +523,9 @@ def __init__(self, config): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias + def _tie_weights(self): + self.decoder.bias = self.bias + def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -818,6 +821,7 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias def forward( self, diff --git a/src/transformers/models/deberta/modeling_deberta.py b/src/transformers/models/deberta/modeling_deberta.py index 42dae5c80894a8..02047a5cffd448 100644 --- a/src/transformers/models/deberta/modeling_deberta.py +++ b/src/transformers/models/deberta/modeling_deberta.py @@ -1021,6 +1021,7 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( @@ -1117,6 +1118,9 @@ def __init__(self, config): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias + def _tie_weights(self): + self.decoder.bias = self.bias + def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) diff --git a/src/transformers/models/deberta_v2/modeling_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_deberta_v2.py index dfe18b0d4964af..f898c33af09492 100644 --- a/src/transformers/models/deberta_v2/modeling_deberta_v2.py +++ b/src/transformers/models/deberta_v2/modeling_deberta_v2.py @@ -1120,6 +1120,7 @@ def 
get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( @@ -1217,6 +1218,9 @@ def __init__(self, config): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias + def _tie_weights(self): + self.decoder.bias = self.bias + def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) diff --git a/src/transformers/models/ernie/modeling_ernie.py b/src/transformers/models/ernie/modeling_ernie.py index 3db6501985604d..95e27121bc2046 100644 --- a/src/transformers/models/ernie/modeling_ernie.py +++ b/src/transformers/models/ernie/modeling_ernie.py @@ -603,6 +603,9 @@ def __init__(self, config): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias + def _tie_weights(self): + self.decoder.bias = self.bias + def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -990,6 +993,7 @@ def get_output_embeddings(self): # Copied from transformers.models.bert.modeling_bert.BertForPreTraining.set_output_embeddings def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(ERNIE_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=ErnieForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) @@ -1104,6 +1108,7 @@ def get_output_embeddings(self): # Copied from transformers.models.bert.modeling_bert.BertLMHeadModel.set_output_embeddings def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(ERNIE_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( @@ -1264,6 +1269,7 @@ def get_output_embeddings(self): # Copied from transformers.models.bert.modeling_bert.BertForMaskedLM.set_output_embeddings def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(ERNIE_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( diff --git a/src/transformers/models/flava/modeling_flava.py b/src/transformers/models/flava/modeling_flava.py index 19f19d4c9d5666..d967335d8e0068 100644 --- a/src/transformers/models/flava/modeling_flava.py +++ b/src/transformers/models/flava/modeling_flava.py @@ -1661,6 +1661,9 @@ def __init__(self, config, weight=None): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias + def _tie_weights(self): + self.decoder.bias = self.bias + def forward(self, x): x = self.transform(x) x = self.decoder(x) diff --git a/src/transformers/models/fnet/modeling_fnet.py b/src/transformers/models/fnet/modeling_fnet.py index 5724faee56cf85..a11b1c87a0254c 100755 --- a/src/transformers/models/fnet/modeling_fnet.py +++ b/src/transformers/models/fnet/modeling_fnet.py @@ -355,9 +355,13 @@ def forward(self, hidden_states): 
hidden_states = self.decoder(hidden_states) return hidden_states - def _tie_weights(self): - # To tie those two weights if they get disconnected (on TPU or when the bias is resized) - self.bias = self.decoder.bias + def _tie_weights(self) -> None: + # For accelerate compatibility and to not break backward compatibility + if self.decoder.bias.device.type == "meta": + self.decoder.bias = self.bias + else: + # To tie those two weights if they get disconnected (on TPU or when the bias is resized) + self.bias = self.decoder.bias class FNetOnlyMLMHead(nn.Module): @@ -624,6 +628,7 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(FNET_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=FNetForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) @@ -718,6 +723,7 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(FNET_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( diff --git a/src/transformers/models/fsmt/modeling_fsmt.py b/src/transformers/models/fsmt/modeling_fsmt.py index 4c180c52678b82..4a0e591d62f580 100644 --- a/src/transformers/models/fsmt/modeling_fsmt.py +++ b/src/transformers/models/fsmt/modeling_fsmt.py @@ -692,6 +692,9 @@ def __init__(self, config: FSMTConfig, embed_tokens: nn.Embedding): self.output_projection = nn.Linear(embed_tokens_weight_shape[1], embed_tokens_weight_shape[0], bias=False) self.output_projection.weight = self.embed_tokens.weight + def _tie_weights(self): + self.embed_tokens.weight = self.output_projection.weight + def forward( self, input_ids: torch.Tensor, diff --git a/src/transformers/models/ibert/modeling_ibert.py b/src/transformers/models/ibert/modeling_ibert.py index 54c37f507e3a63..f06557c2616078 100644 --- a/src/transformers/models/ibert/modeling_ibert.py +++ b/src/transformers/models/ibert/modeling_ibert.py @@ -868,6 +868,7 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.lm_head.decoder = new_embeddings + self.lm_head.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(IBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( @@ -952,9 +953,13 @@ def forward(self, features, **kwargs): return x - def _tie_weights(self): - # To tie those two weights if they get disconnected (on TPU or when the bias is resized) - self.bias = self.decoder.bias + def _tie_weights(self) -> None: + # For accelerate compatibility and to not break backward compatibility + if self.decoder.bias.device.type == "meta": + self.decoder.bias = self.bias + else: + # To tie those two weights if they get disconnected (on TPU or when the bias is resized) + self.bias = self.decoder.bias @add_start_docstrings( diff --git a/src/transformers/models/layoutlm/modeling_layoutlm.py b/src/transformers/models/layoutlm/modeling_layoutlm.py index 6914f5ee3efb62..98765b3f75ff29 100644 --- a/src/transformers/models/layoutlm/modeling_layoutlm.py +++ b/src/transformers/models/layoutlm/modeling_layoutlm.py @@ -594,6 +594,9 @@ def __init__(self, config): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias + def _tie_weights(self): 
+ self.decoder.bias = self.bias + def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -873,6 +876,7 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC) diff --git a/src/transformers/models/markuplm/modeling_markuplm.py b/src/transformers/models/markuplm/modeling_markuplm.py index 318110daf5d8d1..707f612459ddc0 100755 --- a/src/transformers/models/markuplm/modeling_markuplm.py +++ b/src/transformers/models/markuplm/modeling_markuplm.py @@ -316,6 +316,9 @@ def __init__(self, config): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias + def _tie_weights(self): + self.decoder.bias = self.bias + def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) diff --git a/src/transformers/models/megatron_bert/modeling_megatron_bert.py b/src/transformers/models/megatron_bert/modeling_megatron_bert.py index 528bcca3d9bc00..a9d228bf3bb652 100755 --- a/src/transformers/models/megatron_bert/modeling_megatron_bert.py +++ b/src/transformers/models/megatron_bert/modeling_megatron_bert.py @@ -657,6 +657,9 @@ def __init__(self, config): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias + def _tie_weights(self): + self.decoder.bias = self.bias + def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -1021,6 +1024,7 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(MEGATRON_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=MegatronBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) @@ -1130,6 +1134,7 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(MEGATRON_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) @@ -1288,6 +1293,7 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(MEGATRON_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( diff --git a/src/transformers/models/mobilebert/modeling_mobilebert.py b/src/transformers/models/mobilebert/modeling_mobilebert.py index 8dc0aafa70fc25..92a18dfe599041 100644 --- a/src/transformers/models/mobilebert/modeling_mobilebert.py +++ b/src/transformers/models/mobilebert/modeling_mobilebert.py @@ -650,6 +650,9 @@ def __init__(self, config): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias 
+ def _tie_weights(self) -> None: + self.decoder.bias = self.bias + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.transform(hidden_states) hidden_states = hidden_states.matmul(torch.cat([self.decoder.weight.t(), self.dense.weight], dim=0)) @@ -938,8 +941,9 @@ def __init__(self, config): def get_output_embeddings(self): return self.cls.predictions.decoder - def set_output_embeddings(self, new_embeddigs): - self.cls.predictions.decoder = new_embeddigs + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> nn.Embedding: # resize dense output embedings at first @@ -1047,8 +1051,9 @@ def __init__(self, config): def get_output_embeddings(self): return self.cls.predictions.decoder - def set_output_embeddings(self, new_embeddigs): - self.cls.predictions.decoder = new_embeddigs + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> nn.Embedding: # resize dense output embedings at first diff --git a/src/transformers/models/mpnet/modeling_mpnet.py b/src/transformers/models/mpnet/modeling_mpnet.py index d9b9f90d398d90..12e6bdbffaaa7b 100644 --- a/src/transformers/models/mpnet/modeling_mpnet.py +++ b/src/transformers/models/mpnet/modeling_mpnet.py @@ -584,6 +584,7 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.lm_head.decoder = new_embeddings + self.lm_head.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(MPNET_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( @@ -656,6 +657,9 @@ def __init__(self, config): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias + def _tie_weights(self): + self.decoder.bias = self.bias + def forward(self, features, **kwargs): x = self.dense(features) x = gelu(x) diff --git a/src/transformers/models/mra/modeling_mra.py b/src/transformers/models/mra/modeling_mra.py index 846578997c4a84..db918484d986cd 100644 --- a/src/transformers/models/mra/modeling_mra.py +++ b/src/transformers/models/mra/modeling_mra.py @@ -809,6 +809,9 @@ def __init__(self, config): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias + def _tie_weights(self): + self.decoder.bias = self.bias + def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -1042,6 +1045,7 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(MRA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( diff --git a/src/transformers/models/nezha/modeling_nezha.py b/src/transformers/models/nezha/modeling_nezha.py index 6d983bd2378903..5ab2dc8958dff0 100644 --- a/src/transformers/models/nezha/modeling_nezha.py +++ b/src/transformers/models/nezha/modeling_nezha.py @@ -674,6 +674,9 @@ def __init__(self, config): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = 
self.bias + def _tie_weights(self): + self.decoder.bias = self.bias + def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -1039,6 +1042,7 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(NEZHA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=NezhaForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) @@ -1147,6 +1151,7 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(NEZHA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( diff --git a/src/transformers/models/nystromformer/modeling_nystromformer.py b/src/transformers/models/nystromformer/modeling_nystromformer.py index 1da61bc59e6a7a..df0dd0e405c0ef 100755 --- a/src/transformers/models/nystromformer/modeling_nystromformer.py +++ b/src/transformers/models/nystromformer/modeling_nystromformer.py @@ -426,6 +426,9 @@ def __init__(self, config): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias + def _tie_weights(self): + self.decoder.bias = self.bias + def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -664,6 +667,7 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(NYSTROMFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( diff --git a/src/transformers/models/qdqbert/modeling_qdqbert.py b/src/transformers/models/qdqbert/modeling_qdqbert.py index c5e9af7025842b..7f1916dc80bf5c 100755 --- a/src/transformers/models/qdqbert/modeling_qdqbert.py +++ b/src/transformers/models/qdqbert/modeling_qdqbert.py @@ -681,6 +681,9 @@ def __init__(self, config): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias + def _tie_weights(self): + self.decoder.bias = self.bias + def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -1022,6 +1025,7 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) @@ -1188,6 +1192,7 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( diff --git a/src/transformers/models/realm/modeling_realm.py b/src/transformers/models/realm/modeling_realm.py index adec5647a28134..3753ba9dd28d01 100644 --- a/src/transformers/models/realm/modeling_realm.py 
+++ b/src/transformers/models/realm/modeling_realm.py @@ -797,6 +797,9 @@ def __init__(self, config): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias + def _tie_weights(self): + self.decoder.bias = self.bias + def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -1391,6 +1394,7 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward( REALM_INPUTS_DOCSTRING.format("batch_size, num_candidates, sequence_length") diff --git a/src/transformers/models/reformer/modeling_reformer.py b/src/transformers/models/reformer/modeling_reformer.py index e6768e897eca0c..bfb8fb5ebe1cfa 100755 --- a/src/transformers/models/reformer/modeling_reformer.py +++ b/src/transformers/models/reformer/modeling_reformer.py @@ -1768,9 +1768,13 @@ def forward_chunk(self, hidden_states): hidden_states = self.decoder(hidden_states) return hidden_states - def _tie_weights(self): - # To tie those two weights if they get disconnected (on TPU or when the bias is resized) - self.bias = self.decoder.bias + def _tie_weights(self) -> None: + # For accelerate compatibility and to not break backward compatibility + if self.decoder.bias.device.type == "meta": + self.decoder.bias = self.bias + else: + # To tie those two weights if they get disconnected (on TPU or when the bias is resized) + self.bias = self.decoder.bias class ReformerPreTrainedModel(PreTrainedModel): @@ -2208,6 +2212,7 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.lm_head.decoder = new_embeddings + self.lm_head.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(REFORMER_INPUTS_DOCSTRING) @add_code_sample_docstrings( @@ -2328,6 +2333,7 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.lm_head.decoder = new_embeddings + self.lm_head.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(REFORMER_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC) diff --git a/src/transformers/models/roc_bert/modeling_roc_bert.py b/src/transformers/models/roc_bert/modeling_roc_bert.py index 739e60b550baf3..9d8284461f6679 100644 --- a/src/transformers/models/roc_bert/modeling_roc_bert.py +++ b/src/transformers/models/roc_bert/modeling_roc_bert.py @@ -749,6 +749,9 @@ def __init__(self, config): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias + def _tie_weights(self): + self.decoder.bias = self.bias + def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -1094,6 +1097,7 @@ def get_output_embeddings(self): # Copied from transformers.models.bert.modeling_bert.BertForPreTraining.set_output_embeddings def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(ROC_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC) @@ -1286,6 +1290,7 @@ def get_output_embeddings(self): # Copied from 
transformers.models.bert.modeling_bert.BertForMaskedLM.set_output_embeddings def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(ROC_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) def forward( @@ -1423,6 +1428,7 @@ def get_output_embeddings(self): # Copied from transformers.models.bert.modeling_bert.BertLMHeadModel.set_output_embeddings def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(ROC_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) diff --git a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py index b2a63221a8dc90..f7589d4853d581 100644 --- a/src/transformers/models/roformer/modeling_roformer.py +++ b/src/transformers/models/roformer/modeling_roformer.py @@ -657,6 +657,9 @@ def __init__(self, config): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias + def _tie_weights(self) -> None: + self.decoder.bias = self.bias + def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -954,6 +957,7 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( @@ -1053,6 +1057,7 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(ROFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) diff --git a/src/transformers/models/squeezebert/modeling_squeezebert.py b/src/transformers/models/squeezebert/modeling_squeezebert.py index b5657f6e6f5003..d95a58daaf6164 100644 --- a/src/transformers/models/squeezebert/modeling_squeezebert.py +++ b/src/transformers/models/squeezebert/modeling_squeezebert.py @@ -400,6 +400,9 @@ def __init__(self, config): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias + def _tie_weights(self) -> None: + self.decoder.bias = self.bias + def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -658,6 +661,7 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(SQUEEZEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( diff --git a/src/transformers/models/tapas/modeling_tapas.py b/src/transformers/models/tapas/modeling_tapas.py index e2ce847926b38f..97636d8b28e18e 100644 --- a/src/transformers/models/tapas/modeling_tapas.py +++ b/src/transformers/models/tapas/modeling_tapas.py @@ -699,6 +699,9 @@ def 
__init__(self, config): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias + def _tie_weights(self): + self.decoder.bias = self.bias + def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -978,6 +981,7 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(TAPAS_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC) diff --git a/src/transformers/models/vilt/modeling_vilt.py b/src/transformers/models/vilt/modeling_vilt.py index 5545b881bd670a..e5f775cfc6f079 100755 --- a/src/transformers/models/vilt/modeling_vilt.py +++ b/src/transformers/models/vilt/modeling_vilt.py @@ -894,6 +894,7 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.mlm_score.decoder = new_embeddings + self.mlm_score.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(VILT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC) @@ -1040,6 +1041,9 @@ def __init__(self, config, weight=None): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias + def _tie_weights(self): + self.decoder.bias = self.bias + def forward(self, x): x = self.transform(x) x = self.decoder(x) diff --git a/src/transformers/models/visual_bert/modeling_visual_bert.py b/src/transformers/models/visual_bert/modeling_visual_bert.py index 07c8b7a4b5173c..33df2ac13cf5b9 100755 --- a/src/transformers/models/visual_bert/modeling_visual_bert.py +++ b/src/transformers/models/visual_bert/modeling_visual_bert.py @@ -489,6 +489,9 @@ def __init__(self, config): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias + def _tie_weights(self): + self.decoder.bias = self.bias + def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -869,6 +872,7 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(VISUAL_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=VisualBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py index d8994e335b1242..da76ca29ae27ee 100644 --- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py @@ -852,6 +852,7 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.lm_head.decoder = new_embeddings + self.lm_head.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(XLM_ROBERTA_XL_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) @@ -1011,6 +1012,7 @@ def 
get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.lm_head.decoder = new_embeddings + self.lm_head.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(XLM_ROBERTA_XL_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( @@ -1099,9 +1101,13 @@ def forward(self, features, **kwargs): return x - def _tie_weights(self): - # To tie those two weights if they get disconnected (on TPU or when the bias is resized) - self.bias = self.decoder.bias + def _tie_weights(self) -> None: + # For accelerate compatibility and to not break backward compatibility + if self.decoder.bias.device.type == "meta": + self.decoder.bias = self.bias + else: + # To tie those two weights if they get disconnected (on TPU or when the bias is resized) + self.bias = self.decoder.bias @add_start_docstrings( diff --git a/src/transformers/models/yoso/modeling_yoso.py b/src/transformers/models/yoso/modeling_yoso.py index b1fed0acc468df..3615ea80719be1 100644 --- a/src/transformers/models/yoso/modeling_yoso.py +++ b/src/transformers/models/yoso/modeling_yoso.py @@ -627,6 +627,9 @@ def __init__(self, config): # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` self.decoder.bias = self.bias + def _tie_weights(self): + self.decoder.bias = self.bias + def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) @@ -861,6 +864,7 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.cls.predictions.decoder = new_embeddings + self.cls.predictions.bias = new_embeddings.bias @add_start_docstrings_to_model_forward(YOSO_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( diff --git a/tests/models/deformable_detr/test_modeling_deformable_detr.py b/tests/models/deformable_detr/test_modeling_deformable_detr.py index 23e95267d8788c..5831d338dc3b50 100644 --- a/tests/models/deformable_detr/test_modeling_deformable_detr.py +++ b/tests/models/deformable_detr/test_modeling_deformable_detr.py @@ -578,6 +578,18 @@ def test_initialization(self): msg=f"Parameter {name} of model {model_class} seems not properly initialized", ) + @unittest.skip("No support for low_cpu_mem_usage=True.") + def test_save_load_low_cpu_mem_usage(self): + pass + + @unittest.skip("No support for low_cpu_mem_usage=True.") + def test_save_load_low_cpu_mem_usage_checkpoints(self): + pass + + @unittest.skip("No support for low_cpu_mem_usage=True.") + def test_save_load_low_cpu_mem_usage_no_safetensors(self): + pass + def test_two_stage_training(self): model_class = DeformableDetrForObjectDetection config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/deta/test_modeling_deta.py b/tests/models/deta/test_modeling_deta.py index 70c1009a508907..faab773efd79a3 100644 --- a/tests/models/deta/test_modeling_deta.py +++ b/tests/models/deta/test_modeling_deta.py @@ -528,6 +528,18 @@ def test_initialization(self): msg=f"Parameter {name} of model {model_class} seems not properly initialized", ) + @unittest.skip("No support for low_cpu_mem_usage=True.") + def test_save_load_low_cpu_mem_usage(self): + pass + + @unittest.skip("No support for low_cpu_mem_usage=True.") + def test_save_load_low_cpu_mem_usage_checkpoints(self): + pass + + @unittest.skip("No support for low_cpu_mem_usage=True.") + def test_save_load_low_cpu_mem_usage_no_safetensors(self): + pass + # Inspired by 
tests.test_modeling_common.ModelTesterMixin.test_tied_weights_keys def test_tied_weights_keys(self): for model_class in self.all_model_classes: diff --git a/tests/models/encodec/test_modeling_encodec.py b/tests/models/encodec/test_modeling_encodec.py index 0c021eaad21ab5..dd07f5a6e25246 100644 --- a/tests/models/encodec/test_modeling_encodec.py +++ b/tests/models/encodec/test_modeling_encodec.py @@ -325,6 +325,18 @@ def test_feed_forward_chunking(self): def test_hidden_states_output(self): pass + @unittest.skip("No support for low_cpu_mem_usage=True.") + def test_save_load_low_cpu_mem_usage(self): + pass + + @unittest.skip("No support for low_cpu_mem_usage=True.") + def test_save_load_low_cpu_mem_usage_checkpoints(self): + pass + + @unittest.skip("No support for low_cpu_mem_usage=True.") + def test_save_load_low_cpu_mem_usage_no_safetensors(self): + pass + def test_determinism(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/lxmert/test_modeling_lxmert.py b/tests/models/lxmert/test_modeling_lxmert.py index 723fef6061b3e4..a98643b33cd74a 100644 --- a/tests/models/lxmert/test_modeling_lxmert.py +++ b/tests/models/lxmert/test_modeling_lxmert.py @@ -766,6 +766,18 @@ def prepare_tf_inputs_from_pt_inputs(self, pt_inputs_dict): return tf_inputs_dict + @unittest.skip("No support for low_cpu_mem_usage=True.") + def test_save_load_low_cpu_mem_usage(self): + pass + + @unittest.skip("No support for low_cpu_mem_usage=True.") + def test_save_load_low_cpu_mem_usage_checkpoints(self): + pass + + @unittest.skip("No support for low_cpu_mem_usage=True.") + def test_save_load_low_cpu_mem_usage_no_safetensors(self): + pass + @require_torch class LxmertModelIntegrationTest(unittest.TestCase): diff --git a/tests/models/marian/test_modeling_marian.py b/tests/models/marian/test_modeling_marian.py index 593ef8e3405e38..3144dd48dab21f 100644 --- a/tests/models/marian/test_modeling_marian.py +++ b/tests/models/marian/test_modeling_marian.py @@ -372,6 +372,18 @@ def test_training_gradient_checkpointing_use_reentrant(self): def test_training_gradient_checkpointing_use_reentrant_false(self): pass + @unittest.skip("No support for low_cpu_mem_usage=True.") + def test_save_load_low_cpu_mem_usage(self): + pass + + @unittest.skip("No support for low_cpu_mem_usage=True.") + def test_save_load_low_cpu_mem_usage_checkpoints(self): + pass + + @unittest.skip("No support for low_cpu_mem_usage=True.") + def test_save_load_low_cpu_mem_usage_no_safetensors(self): + pass + def assert_tensors_close(a, b, atol=1e-12, prefix=""): """If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error.""" diff --git a/tests/models/musicgen/test_modeling_musicgen.py b/tests/models/musicgen/test_modeling_musicgen.py index b04f99c05ec1e7..8482072d73cfe7 100644 --- a/tests/models/musicgen/test_modeling_musicgen.py +++ b/tests/models/musicgen/test_modeling_musicgen.py @@ -1273,6 +1273,18 @@ def test_tied_model_weights_key_ignore(self): def test_tied_weights_keys(self): pass + @unittest.skip("No support for low_cpu_mem_usage=True.") + def test_save_load_low_cpu_mem_usage(self): + pass + + @unittest.skip("No support for low_cpu_mem_usage=True.") + def test_save_load_low_cpu_mem_usage_checkpoints(self): + pass + + @unittest.skip("No support for low_cpu_mem_usage=True.") + def test_save_load_low_cpu_mem_usage_no_safetensors(self): + pass + # override since changing `output_hidden_states` / `output_attentions` from the top-level model config 
won't work def test_retain_grad_hidden_states_attentions(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py index 628cc76a093ea7..b32b5082584668 100644 --- a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py +++ b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py @@ -1258,6 +1258,18 @@ def test_tied_model_weights_key_ignore(self): def test_tied_weights_keys(self): pass + @unittest.skip("No support for low_cpu_mem_usage=True.") + def test_save_load_low_cpu_mem_usage(self): + pass + + @unittest.skip("No support for low_cpu_mem_usage=True.") + def test_save_load_low_cpu_mem_usage_checkpoints(self): + pass + + @unittest.skip("No support for low_cpu_mem_usage=True.") + def test_save_load_low_cpu_mem_usage_no_safetensors(self): + pass + # override since changing `output_hidden_states` / `output_attentions` from the top-level model config won't work # Ignore copy def test_retain_grad_hidden_states_attentions(self): diff --git a/tests/models/sew/test_modeling_sew.py b/tests/models/sew/test_modeling_sew.py index 528d5f84185e6e..5342df9e088039 100644 --- a/tests/models/sew/test_modeling_sew.py +++ b/tests/models/sew/test_modeling_sew.py @@ -356,6 +356,18 @@ def test_resize_tokens_embeddings(self): def test_model_common_attributes(self): pass + @unittest.skip("No support for low_cpu_mem_usage=True.") + def test_save_load_low_cpu_mem_usage(self): + pass + + @unittest.skip("No support for low_cpu_mem_usage=True.") + def test_save_load_low_cpu_mem_usage_checkpoints(self): + pass + + @unittest.skip("No support for low_cpu_mem_usage=True.") + def test_save_load_low_cpu_mem_usage_no_safetensors(self): + pass + def test_retain_grad_hidden_states_attentions(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config.output_hidden_states = True diff --git a/tests/models/sew_d/test_modeling_sew_d.py b/tests/models/sew_d/test_modeling_sew_d.py index 6fda7963a8009c..1980bd3ab12166 100644 --- a/tests/models/sew_d/test_modeling_sew_d.py +++ b/tests/models/sew_d/test_modeling_sew_d.py @@ -460,6 +460,18 @@ def _mock_init_weights(self, module): def test_feed_forward_chunking(self): pass + @unittest.skip("No support for low_cpu_mem_usage=True.") + def test_save_load_low_cpu_mem_usage(self): + pass + + @unittest.skip("No support for low_cpu_mem_usage=True.") + def test_save_load_low_cpu_mem_usage_checkpoints(self): + pass + + @unittest.skip("No support for low_cpu_mem_usage=True.") + def test_save_load_low_cpu_mem_usage_no_safetensors(self): + pass + @slow def test_model_from_pretrained(self): model = SEWDModel.from_pretrained("asapp/sew-d-tiny-100k") diff --git a/tests/models/timm_backbone/test_modeling_timm_backbone.py b/tests/models/timm_backbone/test_modeling_timm_backbone.py index 60ab9e2a217eb5..1cd04cd4843933 100644 --- a/tests/models/timm_backbone/test_modeling_timm_backbone.py +++ b/tests/models/timm_backbone/test_modeling_timm_backbone.py @@ -169,6 +169,18 @@ def test_from_pretrained_no_checkpoint(self): def test_save_load(self): pass + @unittest.skip("No support for low_cpu_mem_usage=True.") + def test_save_load_low_cpu_mem_usage(self): + pass + + @unittest.skip("No support for low_cpu_mem_usage=True.") + def test_save_load_low_cpu_mem_usage_checkpoints(self): + pass + + @unittest.skip("No support for low_cpu_mem_usage=True.") + def 
test_save_load_low_cpu_mem_usage_no_safetensors(self):
+        pass
+
     @unittest.skip("model weights aren't tied in TimmBackbone.")
     def test_tie_model_weights(self):
         pass
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 66cac0d6c585a1..df585f4afc65e1 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -437,6 +437,88 @@ class CopyClass(model_class):
                     max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
                     self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
 
+    @slow
+    @require_accelerate
+    @mark.accelerate_tests
+    def test_save_load_low_cpu_mem_usage(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        with tempfile.TemporaryDirectory() as saved_model_path:
+            for model_class in self.all_model_classes:
+                model_to_save = model_class(config)
+                model_to_save.save_pretrained(saved_model_path)
+
+                self._check_save_load_low_cpu_mem_usage(model_class, saved_model_path)
+
+    @slow
+    @require_accelerate
+    @mark.accelerate_tests
+    def test_save_load_low_cpu_mem_usage_checkpoints(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        with tempfile.TemporaryDirectory() as saved_model_path:
+            for model_class in self.all_model_classes:
+                model_to_save = model_class(config)
+                model_to_save.config.save_pretrained(saved_model_path)
+                torch.save(model_to_save.state_dict(), os.path.join(saved_model_path, "pytorch_model.bin"))
+
+                self._check_save_load_low_cpu_mem_usage(model_class, saved_model_path)
+
+    @slow
+    @require_accelerate
+    @mark.accelerate_tests
+    def test_save_load_low_cpu_mem_usage_no_safetensors(self):
+        with tempfile.TemporaryDirectory() as saved_model_path:
+            for model_class in self.all_model_classes:
+                config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+                model_to_save = model_class(config)
+
+                model_to_save.save_pretrained(saved_model_path, safe_serialization=False)
+                self._check_save_load_low_cpu_mem_usage(model_class, saved_model_path)
+
+    def _check_save_load_low_cpu_mem_usage(self, model_class, saved_model_path):
+        # Load the low usage and the normal models.
+        model_low_usage, loading_info = model_class.from_pretrained(
+            saved_model_path,
+            low_cpu_mem_usage=True,
+            output_loading_info=True,
+        )
+        model_non_low_usage = model_class.from_pretrained(saved_model_path)
+
+        # Check that there were no missing keys.
+        self.assertEqual(loading_info["missing_keys"], [])
+
+        # The low_cpu_mem_usage=True causes the model params to be initialized with device=meta, and then
+        # subsequently loaded with the correct values and onto the correct device. We check if there are any
+        # remaining params that were not properly loaded.
+        for name, param in model_low_usage.named_parameters():
+            self.assertNotEqual(
+                param.device,
+                torch.device("meta"),
+                "Parameter '" + name + "' has not been properly loaded and has device=meta.",
+            )
+
+        # Tests moving the model to a device other than meta.
+        model_low_usage.to(torch_device)
+
+        # Check that the parameters are equal.
+        for p1, p2 in zip(model_low_usage.parameters(), model_non_low_usage.parameters()):
+            self.assertEquals(p1.data.ne(p2.data).sum(), 0)
+
+        # Check that the state dict keys are equal.
+        self.assertEqual(set(model_low_usage.state_dict().keys()), set(model_non_low_usage.state_dict().keys()))
+
+        # Check that the shared tensors are equal.
+        tensor_ptrs1 = collections.defaultdict(list)
+        for name, tensor in model_low_usage.state_dict().items():
+            tensor_ptrs1[id_tensor_storage(tensor)].append(name)
+        tied_params1 = [names for _, names in tensor_ptrs1.items() if len(names) > 1]
+
+        tensor_ptrs2 = collections.defaultdict(list)
+        for name, tensor in model_non_low_usage.state_dict().items():
+            tensor_ptrs2[id_tensor_storage(tensor)].append(name)
+        tied_params2 = [names for _, names in tensor_ptrs2.items() if len(names) > 1]
+
+        self.assertEqual(tied_params1, tied_params2)
+
     def test_fast_init_context_manager(self):
         # 1. Create a dummy class. Should have buffers as well? To make sure we test __init__
         class MyClass(PreTrainedModel):