Llama Attention Call should not pass **kwargs #30523

Closed
kiddyboots216 opened this issue Apr 28, 2024 · 13 comments · Fixed by #31161

Comments

@kiddyboots216

System Info

  • transformers version: 4.40.1
  • Platform: Linux-4.18.0-513.24.1.el8_9.x86_64-x86_64-with-glibc2.28
  • Python version: 3.10.13
  • Huggingface_hub version: 0.22.2
  • Safetensors version: 0.4.1
  • Accelerate version: 0.29.2
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.3.0+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: FSDP

Who can help?

@ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Wrapping a LlamaModel with FSDP results in the following error during a forward pass:

  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 849, in forward
    output = self._fsdp_wrapped_module(*args, **kwargs)
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1196, in forward
    outputs = self.model(
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1016, in forward
    layer_outputs = decoder_layer(
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 849, in forward
    output = self._fsdp_wrapped_module(*args, **kwargs)
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 168, in forward
    return self.checkpoint_fn(  # type: ignore[misc]
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/_compile.py", line 24, in inner
    return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 489, in checkpoint
    ret = function(*args, **kwargs)
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 739, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
TypeError: LlamaSdpaAttention.forward() got an unexpected keyword argument 'offload_to_cpu'

This occurs because we pass **kwargs (https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L749) to a function that does not accept **kwargs (https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L608).

If we use another model, e.g. Mistral, this issue does not occur, because we don't pass **kwargs there (https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L757C63-L757C77).
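
To make the failure mode concrete, here is a minimal toy sketch (illustrative functions only, not the actual Transformers code): a wrapper injects an extra keyword such as offload_to_cpu, the caller blindly forwards every keyword it received, and the callee rejects the unknown one.

def attention_forward(hidden_states, attention_mask=None):
    # Signature without **kwargs, analogous to LlamaSdpaAttention.forward
    return hidden_states

def decoder_layer_forward(hidden_states, **kwargs):
    # Forwards every keyword it received, including ones the callee does not know
    return attention_forward(hidden_states, **kwargs)

decoder_layer_forward("h", offload_to_cpu=False)
# TypeError: attention_forward() got an unexpected keyword argument 'offload_to_cpu'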

Expected behavior

Remove line 749 or add **kwargs to forward().
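
Continuing the toy sketch above (again illustrative only, not the Transformers source), either remedy makes the call succeed:

# Fix 1: the caller stops forwarding arbitrary keywords.
def attention_forward(hidden_states, attention_mask=None):
    return hidden_states

def decoder_layer_forward(hidden_states, **kwargs):
    return attention_forward(hidden_states)  # unknown keywords are dropped here

decoder_layer_forward("h", offload_to_cpu=False)  # no TypeError

# Fix 2: the callee tolerates extra keywords.
def tolerant_attention_forward(hidden_states, attention_mask=None, **kwargs):
    return hidden_states

tolerant_attention_forward("h", offload_to_cpu=False)  # no TypeError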

@ArthurZucker
Collaborator

Kwargs should indeed not be passed. I would need a reproducer but feel free to open a PR for a fix! 😉

@kiddyboots216
Author

> Kwargs should indeed not be passed. I would need a reproducer but feel free to open a PR for a fix! 😉

I will open a PR after cataloguing all the models that have this issue. GPTNeoX also has this issue. The reproducer is to wrap a model in FSDP and then do a forward pass on any data.

@vikram71198

vikram71198 commented May 9, 2024

Yep, can confirm I also see the same issue with LLaMA-3-8b-Instruct with FSDP + Gradient Checkpointing.

The Yi series of models also has this issue, I just checked, which makes perfect sense since they follow the LLaMA architecture.

@ArthurZucker
Collaborator

We'll remove the kwargs! cc @zhenglongjiepheonix, who is working on something related!
We can open a separate PR for this and link this issue.

@vikram71198

@ArthurZucker are there any updates on this? I don't see a PR for this yet.

@ArthurZucker
Collaborator

#30743 closed this 😉

@vikram71198

@ArthurZucker there's a slightly different error this time around with transformers==4.41.0

LlamaDecoderLayer.forward() got an unexpected keyword argument 'offload_to_cpu'

There's something going on here.

@ArthurZucker
Collaborator

                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                )

That seems pretty impossible TBH, we don't have kwargs or whatnot there. Make sure you are on 4.41.1.

@vikram71198

vikram71198 commented May 24, 2024

Yeah, even with transformers == 4.41.1, I get the error:

TypeError: LlamaDecoderLayer.forward() got an unexpected keyword argument 'offload_to_cpu'

And given the signature of this layer's forward() in your message above, it indeed doesn't take an offload_to_cpu argument, so the error message actually makes sense.

Is it the case that LLaMA models are not compatible with gradient checkpointing when training with FSDP, since offload_to_cpu is a parameter that comes from the gradient checkpointing wrapper?

@ArthurZucker
Collaborator

cc @younesbelkada

@younesbelkada reopened this on May 28, 2024
@younesbelkada
Contributor

On it!

@vikram71198

vikram71198 commented May 28, 2024

Okay, with transformers == 4.41.1, I'm now getting the same error with Mistral models as well:

MistralDecoderLayer.forward() got an unexpected keyword argument 'offload_to_cpu'

This wasn't happening with Mistral before.

Again, this is when I use Gradient Checkpointing with FSDP.

FYI, this is a rough template of how I apply Gradient Checkpointing with FSDP:

from functools import partial

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    CheckpointImpl,
    apply_activation_checkpointing,
)
from transformers.models.mistral.modeling_mistral import MistralDecoderLayer

# Only wrap the decoder layers in activation checkpointing
check_fn = lambda submodule: isinstance(submodule, MistralDecoderLayer)

non_reentrant_wrapper = partial(
    checkpoint_wrapper,
    offload_to_cpu=False,
    checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)

apply_activation_checkpointing(
    model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
)

This is the method suggested by the official FSDP tutorials from the PyTorch team here.

With transformers==4.40.2, this issue does not appear with Mistral models, but with LLaMA models (including Yi) I see the error described at the very start of this issue.

@younesbelkada
Contributor

I was able to repro:

from functools import partial
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from accelerate import Accelerator

# verify we have FSDP activation support ready by importing:
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    CheckpointImpl,
    apply_activation_checkpointing,
)

from transformers.models.llama.modeling_llama import LlamaDecoderLayer

model_id = "HuggingFaceM4/tiny-random-Llama3ForCausalLM"

model = AutoModelForCausalLM.from_pretrained(model_id)

model.train()
model.gradient_checkpointing_enable()

accelerator = Accelerator()
model = accelerator.prepare(model)

check_fn = lambda submodule: isinstance(submodule, LlamaDecoderLayer)

non_reentrant_wrapper = partial(
    checkpoint_wrapper,
    offload_to_cpu=False,
    checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)

apply_activation_checkpointing(
    model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
)

print(model)

rand_input = torch.LongTensor([[0, 1, 0, 1]]).to(0)

model(rand_input)

#31161 fixes the bug
