Fix resize_token_embeddings (#26861) (#26865)
* Fix `resize_token_embeddings` handling of `requires_grad`

The method `resize_token_embeddings` should keep `requires_grad`
unchanged for all embedding parameters.

Previously, `resize_token_embeddings` always set `requires_grad`
to `True`. After this fix, `resize_token_embeddings` copies the
`requires_grad` attribute from the old embeddings to the resized ones.
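
For context, a minimal usage sketch of the behavior described above; the checkpoint name "gpt2" and the new vocabulary size are placeholders for illustration, not part of the commit:

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")

# Freeze the input embeddings before resizing.
model.get_input_embeddings().requires_grad_(False)

# With this fix, resizing preserves the frozen state instead of resetting it to True.
model.resize_token_embeddings(model.config.vocab_size + 8)
assert model.get_input_embeddings().weight.requires_grad is False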
czy-orange authored Nov 21, 2023
1 parent d2a980e · commit c5be38c
1 changed file: src/transformers/modeling_utils.py (4 additions, 0 deletions)
@@ -1586,6 +1586,8 @@ def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None):
         if hasattr(old_embeddings, "_hf_hook"):
             hook = old_embeddings._hf_hook
             add_hook_to_module(new_embeddings, hook)
+        old_embeddings_requires_grad = old_embeddings.weight.requires_grad
+        new_embeddings.requires_grad_(old_embeddings_requires_grad)
         self.set_input_embeddings(new_embeddings)

         # Update new_num_tokens with the actual size of new_embeddings
@@ -1605,6 +1607,8 @@ def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None):
             if hasattr(old_lm_head, "_hf_hook"):
                 hook = old_lm_head._hf_hook
                 add_hook_to_module(new_lm_head, hook)
+            old_lm_head_requires_grad = old_lm_head.weight.requires_grad
+            new_lm_head.requires_grad_(old_lm_head_requires_grad)
             self.set_output_embeddings(new_lm_head)

         return self.get_input_embeddings()
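
The added lines in both hunks follow the same copy pattern. Here is an isolated sketch of that pattern for the lm_head case, using plain PyTorch modules; the layer shapes are arbitrary placeholders chosen only for illustration:

import torch.nn as nn

old_lm_head = nn.Linear(16, 100, bias=False)
old_lm_head.requires_grad_(False)  # e.g. a previously frozen output head

new_lm_head = nn.Linear(16, 108, bias=False)  # resized copy
old_lm_head_requires_grad = old_lm_head.weight.requires_grad
new_lm_head.requires_grad_(old_lm_head_requires_grad)

# The resized head stays frozen, matching the old one.
assert new_lm_head.weight.requires_grad is False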
