Missing weights are not properly initialized when using model.from_pretrained() #35437

Open
4 tasks done
YifanXu74 opened this issue Dec 27, 2024 · 1 comment
Labels: bug; Core: Modeling (Internals of the library; Models.)

Comments

YifanXu74 commented Dec 27, 2024

System Info

  • transformers version: 4.47.1
  • Platform: Linux-5.15.0-122-generic-x86_64-with-glibc2.35
  • Python version: 3.12.2
  • Huggingface_hub version: 0.27.0
  • Safetensors version: 0.4.5
  • Accelerate version: 1.1.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.2.1+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: yes
  • Using GPU in script?: yes
  • GPU type: NVIDIA A100-PCIE-40GB

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig

class Config(PretrainedConfig):
    def __init__(self, use_new=False, **kwargs):
        self.use_new = use_new
        super().__init__(**kwargs)


class Model(PreTrainedModel):
    config_class = Config
    def __init__(self, config: Config):
        super().__init__(config)
        self.use_new = config.use_new

        self.proj = nn.Linear(10, 10, bias=False)
        if self.use_new:
            self.new_proj = nn.Linear(20, 20, bias=False)
        
        self.post_init()
    
    def post_init(self):
        nn.init.constant_(self.proj.weight, 0)
        if self.use_new:
            nn.init.constant_(self.new_proj.weight, 0)

if __name__ == "__main__":
    # 1. Pretrain a base model
    config = Config(use_new=False)
    original_model = Model(config)
    print(original_model.proj.weight.data.max()) # 0

    # 2. Save the pretrained weights
    original_model.save_pretrained("./original_model/")

    # 3. Load the pretrained weights, and finetune the model with a newly added layer
    new_model1 = Model.from_pretrained("./original_model/", use_new=True)
    print(new_model1.proj.weight.data.max()) # 0
    print(new_model1.new_proj.weight.data.max()) # nan - BUG: This is unexpected!

    # 4. A trick to work around this problem: pass _fast_init=False into from_pretrained()
    new_model2 = Model.from_pretrained("./original_model/", use_new=True, _fast_init=False)
    print(new_model2.proj.weight.data.max()) # 0
    print(new_model2.new_proj.weight.data.max()) # 0

Expected behavior

Weights that are missing from the checkpoint should be initialized by self.post_init() when the model is loaded with from_pretrained(), but they are not.
In this case, I want to fine-tune a pretrained model after adding some new parameters (self.new_proj.weight), which is a very common scenario.
The missing weight (self.new_proj.weight) is expected to be initialized to 0. Instead, its values appear to be frozen during from_pretrained() and are never properly initialized (the maximum prints as nan in the reproduction above).
Passing _fast_init=False to from_pretrained() works around the problem, but I noticed that this argument is deprecated, so there should be a more appropriate solution.
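
For reference, a hedged sketch of one possible alternative to _fast_init=False: move the custom initialization into _init_weights(module), the per-module hook of PreTrainedModel, instead of overriding post_init(). This rests on an assumption about library internals that the report does not confirm: that from_pretrained() disables the torch.nn.init functions while the model is being constructed (which would make the overridden post_init() a no-op) and afterwards applies _init_weights only to modules whose weights were absent from the checkpoint. The classes mirror the reproduction above; output_loading_info=True is used only to list the missing keys.

```python
import tempfile

import torch.nn as nn
from transformers import PretrainedConfig, PreTrainedModel


class Config(PretrainedConfig):
    def __init__(self, use_new=False, **kwargs):
        self.use_new = use_new
        super().__init__(**kwargs)


class Model(PreTrainedModel):
    config_class = Config

    def __init__(self, config: Config):
        super().__init__(config)
        self.proj = nn.Linear(10, 10, bias=False)
        if config.use_new:
            self.new_proj = nn.Linear(20, 20, bias=False)
        # Keep the library's own post_init(); it is not overridden here.
        self.post_init()

    def _init_weights(self, module):
        # Per-module initialization hook. Assumption: from_pretrained()
        # calls this only for modules whose weights were not loaded.
        if isinstance(module, nn.Linear):
            nn.init.constant_(module.weight, 0)


with tempfile.TemporaryDirectory() as tmp_dir:
    # Save a base checkpoint that does not contain new_proj.
    Model(Config(use_new=False)).save_pretrained(tmp_dir)

    # Reload with the extra layer and ask for the loading report.
    new_model, loading_info = Model.from_pretrained(
        tmp_dir, use_new=True, output_loading_info=True
    )

# Keys present in the model but absent from the checkpoint.
print(loading_info["missing_keys"])
# Should be 0 if the hook ran for the missing layer.
print(new_model.new_proj.weight.abs().max())
```

If the hook cannot be used, the parameters listed in loading_info["missing_keys"] could instead be re-initialized explicitly after loading.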

YifanXu74 added the bug label Dec 27, 2024

LysandreJik (Member) commented

cc @ArthurZucker I believe

LysandreJik added the Core: Modeling label Dec 29, 2024

2 participants