Cannot specify config and attn_implementation simultaneously #28038

Closed
hiyouga opened this issue Dec 14, 2023 · 5 comments · Fixed by #28043

@hiyouga
Contributor

hiyouga commented Dec 14, 2023

System Info

  • transformers version: 4.36.1
  • Platform: Linux-5.15.0-89-generic-x86_64-with-glibc2.35
  • Python version: 3.10.10
  • Huggingface_hub version: 0.19.4
  • Safetensors version: 0.4.0
  • Accelerate version: 0.23.0
  • Accelerate config:
    • compute_environment: LOCAL_MACHINE
    • distributed_type: MULTI_GPU
    • mixed_precision: bf16
    • use_cpu: False
    • debug: False
    • num_processes: 8
    • machine_rank: 0
    • num_machines: 1
    • gpu_ids: all
    • rdzv_backend: static
    • same_network: True
    • main_training_function: main
    • downcast_bf16: no
    • tpu_use_cluster: False
    • tpu_use_sudo: False
    • tpu_env: []
  • PyTorch version (GPU?): 2.0.1+cu117 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed

Who can help?

@ArthurZucker @younesbelkada

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

from transformers import AutoConfig, AutoModelForCausalLM
config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    config=config,
    device_map="auto",
    torch_dtype="auto",
    low_cpu_mem_usage=True,
    attn_implementation="flash_attention_2"
)

Running this fails with:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "lib/python3.10/site-packages/transformers/models/auto/auto_factory.py", line 566, in from_pretrained
    return model_class.from_pretrained(
  File "lib/python3.10/site-packages/transformers/modeling_utils.py", line 3450, in from_pretrained
    model = cls(config, *model_args, **model_kwargs)
TypeError: LlamaForCausalLM.__init__() got an unexpected keyword argument 'attn_implementation'

Expected behavior

What should I do if I want to specify both of them?

Besides, FA2 cannot be enabled by modifying the model config with config.attn_implementation="flash_attention_2".

However, it works if I pass the deprecated parameter use_flash_attention_2=True while also specifying the config.
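The last frame of the traceback shows from_pretrained forwarding its leftover keyword arguments to the model constructor (cls(config, *model_args, **model_kwargs)). The toy loader below is a minimal sketch of how that routing could produce all three observations above. It assumes, without confirmation from this thread, that attn_implementation is only absorbed when the config is built inside the loader, while use_flash_attention_2 is popped up front. ToyConfig, ToyModel and toy_from_pretrained are hypothetical names, not transformers APIs.

```python
class ToyConfig:
    """Hypothetical stand-in for a model config; not transformers' PretrainedConfig."""

    def __init__(self, attn_implementation=None):
        self.attn_implementation = attn_implementation or "eager"


class ToyModel:
    """Hypothetical model whose constructor, like the one in the traceback, takes only a config."""

    def __init__(self, config):
        self.config = config


def toy_from_pretrained(config=None, **kwargs):
    # Assumption: the deprecated flag is popped up front, so it never reaches the constructor.
    use_fa2 = kwargs.pop("use_flash_attention_2", False)
    if config is None:
        # Assumption: building the config inside the loader absorbs attn_implementation.
        config = ToyConfig(attn_implementation=kwargs.pop("attn_implementation", None))
    # Whatever is left (including attn_implementation when a config object was
    # passed in) is forwarded to the model constructor.
    if use_fa2:
        config.attn_implementation = "flash_attention_2"
    return ToyModel(config, **kwargs)


print(toy_from_pretrained(attn_implementation="flash_attention_2").config.attn_implementation)  # flash_attention_2
print(toy_from_pretrained(config=ToyConfig(), use_flash_attention_2=True).config.attn_implementation)  # flash_attention_2
try:
    toy_from_pretrained(config=ToyConfig(), attn_implementation="flash_attention_2")
except TypeError as err:
    print("TypeError:", err)
```

The first two calls mirror what works in the report (no config object, or a config plus the deprecated flag); the last one raises the same kind of TypeError as the traceback, because attn_implementation ends up as an unexpected constructor argument.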

@younesbelkada
Contributor

Hi @hiyouga,
Thanks a lot for the issue!
I don't think you can pass both config and attn_implementation. Can you elaborate a bit on why you would like to pass both the config and attn_implementation into from_pretrained? The canonical way to load an FA2 model is:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "tiiuae/falcon-7b"
tokenizer = AutoTokenizer.from_pretrained(model_id)

model = AutoModelForCausalLM.from_pretrained(
    model_id, 
    torch_dtype=torch.bfloat16, 
    attn_implementation="flash_attention_2",
)

Enabling FA2 directly through the config is also not recommended. However, you can enable FA2 by passing attn_implementation="flash_attention_2" to the from_config method: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L1235

@hiyouga
Contributor Author

hiyouga commented Dec 14, 2023

Thanks for replying!
I want to modify the model config before loading the pre-trained model, e.g. to set rope_scaling and torch_dtype.
I also wonder why passing both config and use_flash_attention_2 to from_pretrained works.

@younesbelkada
Contributor

Thanks @hiyouga,
Indeed, this used to be possible, so this is a regression. I just opened #28043, which should resolve your problem.
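One way to remove such a regression, sketched on a toy loader: when a config object is passed, pop attn_implementation from the keyword arguments and apply it to a copy of that config instead of forwarding it to the model constructor. This is a minimal illustration under that assumption, not the code in #28043; ToyConfig, ToyModel and toy_from_pretrained are hypothetical names, not transformers APIs.

```python
import copy


class ToyConfig:
    """Hypothetical stand-in for a model config (not transformers' PretrainedConfig)."""

    def __init__(self, attn_implementation=None, rope_scaling=None):
        self.attn_implementation = attn_implementation or "eager"
        self.rope_scaling = rope_scaling


class ToyModel:
    """Hypothetical model whose constructor accepts only a config."""

    def __init__(self, config):
        self.config = config


def toy_from_pretrained(config=None, **kwargs):
    if config is None:
        # No config object: build one and let it absorb attn_implementation.
        config = ToyConfig(attn_implementation=kwargs.pop("attn_implementation", None))
    else:
        # Copy so the caller's config is not mutated, then let an explicit
        # attn_implementation kwarg override the value stored on the config.
        config = copy.deepcopy(config)
        requested = kwargs.pop("attn_implementation", None)
        if requested is not None:
            config.attn_implementation = requested
    # attn_implementation has been consumed on both paths, so it no longer
    # reaches the constructor.
    return ToyModel(config, **kwargs)


# Mirrors the reporter's workflow: edit the config first, then request FA2.
user_config = ToyConfig(rope_scaling={"type": "linear", "factor": 2.0})
model = toy_from_pretrained(config=user_config, attn_implementation="flash_attention_2")
print(model.config.attn_implementation)  # flash_attention_2
print(model.config.rope_scaling)         # {'type': 'linear', 'factor': 2.0}
print(user_config.attn_implementation)   # eager
```

Copying the config before overriding it is a design choice in this sketch: it keeps the caller's object unchanged, so the same edited config can be reused for other loads.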

@amyeroberts
Collaborator

Hi @xingniandage, could you open a new issue, including information about the running environment and a minimal reproducer?

@xingniandage

Thanks for your reply.
I don't know why the error was reported, but after I changed the versions of torch and torchvision, the error disappeared.
