diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index cee472036b2742..e2b27de7d1e51d 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1586,6 +1586,8 @@ def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None):
         if hasattr(old_embeddings, "_hf_hook"):
             hook = old_embeddings._hf_hook
             add_hook_to_module(new_embeddings, hook)
+        old_embeddings_requires_grad = old_embeddings.weight.requires_grad
+        new_embeddings.requires_grad_(old_embeddings_requires_grad)
         self.set_input_embeddings(new_embeddings)
 
         # Update new_num_tokens with the actual size of new_embeddings
@@ -1605,6 +1607,8 @@ def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None):
             if hasattr(old_lm_head, "_hf_hook"):
                 hook = old_lm_head._hf_hook
                 add_hook_to_module(new_lm_head, hook)
+            old_lm_head_requires_grad = old_lm_head.weight.requires_grad
+            new_lm_head.requires_grad_(old_lm_head_requires_grad)
             self.set_output_embeddings(new_lm_head)
 
         return self.get_input_embeddings()
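
A minimal usage sketch of the behavior this patch preserves (the "gpt2" checkpoint and the added-token count are illustrative assumptions, not part of the diff; only the public resize_token_embeddings API is exercised):

    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("gpt2")

    # Freeze the embedding matrix, e.g. when only other parts of the model are fine-tuned.
    model.get_input_embeddings().requires_grad_(False)

    # Without the patch, the freshly created embedding module comes back with
    # requires_grad=True; with it, the frozen state is carried over to the new module.
    model.resize_token_embeddings(model.config.vocab_size + 8)

    assert model.get_input_embeddings().weight.requires_grad is False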