Llama Attention Call should not pass **kwargs #30523
Comments
Kwargs should indeed not be passed. I would need a reproducer, but feel free to open a PR for a fix! 😉
I will open a PR after cataloguing all the models that have this issue. GPT-NeoX also has this issue. The reproducer is to wrap a model in FSDP and then do a forward pass on any data.
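For reference, here is a minimal sketch of that reproducer. It assumes a single-process Gloo process group and reuses the tiny random Llama checkpoint from the repro further down this thread; whether a CPU-only FSDP run works may depend on your PyTorch version:

```python
import os

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import AutoModelForCausalLM

# Single-process process group so FSDP can be constructed (assumption: CPU + Gloo).
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = AutoModelForCausalLM.from_pretrained("HuggingFaceM4/tiny-random-Llama3ForCausalLM")
fsdp_model = FSDP(model)  # wrap the whole model in FSDP

# Any forward pass should surface the TypeError if unexpected **kwargs
# are forwarded down the call stack.
dummy_input = torch.randint(0, model.config.vocab_size, (1, 8))
fsdp_model(dummy_input)

dist.destroy_process_group()
```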
Yep, can confirm I also see the same issue with LLaMA-3-8B-Instruct with FSDP + Gradient Checkpointing. The Yi series of models also has this issue, I just checked. And it makes perfect sense, since they follow the LLaMA architecture.
We'll remove the kwargs! cc @zhenglongjiepheonix, who is working on something related!
@ArthurZucker are there any updates on this? I don't see a PR for this yet. |
#30743 closed this 😉
@ArthurZucker there's a slightly different error this time around: `LlamaDecoderLayer.forward() got an unexpected keyword argument 'offload_to_cpu'`. There's something going on here.
That seems pretty impossible TBH, we don't have kwargs or whatnot there. Make sure you are on 4.41.1.
Yeah, even with 4.41.1 I still get `TypeError: LlamaDecoderLayer.forward() got an unexpected keyword argument 'offload_to_cpu'`. And given the signature of `LlamaDecoderLayer.forward()`, is it the case that LLaMA models are not compatible with Gradient Checkpointing when training with FSDP, because the extra keyword argument is not accepted?
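As a quick way to check what the layer accepts on your installed version, the signature can be inspected directly (an illustrative snippet, not from the original thread):

```python
import inspect

from transformers.models.llama.modeling_llama import LlamaDecoderLayer

# Show the parameters LlamaDecoderLayer.forward accepts on this transformers version.
# Without a **kwargs catch-all, any extra keyword (e.g. offload_to_cpu) raises a TypeError.
print(inspect.signature(LlamaDecoderLayer.forward))
```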
On it!
Okay, now I get `MistralDecoderLayer.forward() got an unexpected keyword argument 'offload_to_cpu'`. This wasn't happening with Mistral before. Again, this is when I use Gradient Checkpointing with FSDP. FYI, this is a rough template of how I apply Gradient Checkpointing with FSDP:

```python
from functools import partial

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)
from transformers.models.mistral.modeling_mistral import MistralDecoderLayer

# `model` is the (FSDP-prepared) model being trained.
check_fn = lambda submodule: isinstance(submodule, MistralDecoderLayer)
non_reentrant_wrapper = partial(
    checkpoint_wrapper,
    offload_to_cpu=False,
    checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)
apply_activation_checkpointing(
    model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
)
```

This is the method suggested by the official FSDP tutorials from the PyTorch team here.
I was able to repro:

```python
from functools import partial

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from accelerate import Accelerator

# verify we have FSDP activation support ready by importing:
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    CheckpointImpl,
    apply_activation_checkpointing,
)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

model_id = "HuggingFaceM4/tiny-random-Llama3ForCausalLM"
model = AutoModelForCausalLM.from_pretrained(model_id)
model.train()
model.gradient_checkpointing_enable()

accelerator = Accelerator()
model = accelerator.prepare(model)

check_fn = lambda submodule: isinstance(submodule, LlamaDecoderLayer)
non_reentrant_wrapper = partial(
    checkpoint_wrapper,
    offload_to_cpu=False,
    checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)
apply_activation_checkpointing(
    model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
)

print(model)
rand_input = torch.LongTensor([[0, 1, 0, 1]]).to(0)
model(rand_input)
```

#31161 fixes the bug.
System Info

`transformers` version: 4.40.1

Who can help?

@ArthurZucker

Information

Tasks

- An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)

Reproduction
Wrapping a LlamaModel with FSDP results in the following error during a forward pass:

This occurs because we pass `**kwargs` (https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L749) to a function whose signature does not accept `**kwargs` (https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L608).

If we use another model, e.g. Mistral, this issue does not occur, because there we don't pass `**kwargs` (https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L757C63-L757C77).
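As a plain-Python illustration of why this raises a `TypeError` (simplified stand-ins, not the actual transformers code):

```python
def decoder_forward(hidden_states, attention_mask=None):
    # Stand-in for a layer forward() with a fixed signature and no **kwargs catch-all.
    return hidden_states

extra_kwargs = {"offload_to_cpu": False}  # extra keyword injected by an outer wrapper

# Forwarding unexpected keywords into a fixed signature fails immediately:
decoder_forward("hidden", **extra_kwargs)
# TypeError: decoder_forward() got an unexpected keyword argument 'offload_to_cpu'
```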
Expected behavior
Remove line 749, or add `**kwargs` to `forward()`.
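A rough sketch of what the second option could look like on a simplified layer (illustrative only, not the actual patch that landed):

```python
import torch


class DecoderLayer(torch.nn.Module):
    """Simplified stand-in, not the real LlamaDecoderLayer."""

    def forward(self, hidden_states, attention_mask=None, **kwargs):
        # Accepting **kwargs lets outer wrappers (FSDP, checkpoint_wrapper, ...)
        # pass extra keyword arguments without triggering a TypeError;
        # they are simply ignored here.
        return hidden_states
```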