
[FA-2] Fix fa-2 issue when passing config to from_pretrained #28043

Merged (12 commits) on Dec 15, 2023
13 changes: 13 additions & 0 deletions src/transformers/modeling_utils.py
@@ -2955,6 +2955,19 @@ def from_pretrained(
**kwargs,
)
else:
# In case one passes both a config and "attn_implementation" to `from_pretrained`,
# override the config's `_attn_implementation` attribute with the `attn_implementation` kwarg.
# Please see: https://github.com/huggingface/transformers/issues/28038

# Overwrite `config._attn_implementation` with the one from the kwargs. The auto-factory
# pops `attn_implementation` from the kwargs, but this handles the case where users
# pass the config to `from_pretrained` manually.
config = copy.deepcopy(config)

if kwargs.get("attn_implementation", None) is not None and getattr(
config, "_attn_implementation", None
) != kwargs.get("attn_implementation", None):
Contributor Author:

This handles the case where users pass a config object to from_pretrained. Note that AutoModelxxx.from_pretrained pops attn_implementation from the kwargs when one does not pass a config, but does not pop it if a config is passed.

Therefore this handles that corner case as well (passing a config, so attn_implementation does not get popped, together with attn_implementation through the from_pretrained kwargs). In that case we should overwrite the config's attn_implementation with the one from the kwargs, assuming the user knows what they are doing.

https://github.com/huggingface/transformers/blob/main/src/transformers/models/auto/auto_factory.py#L516-L540
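The override described above can be sketched in isolation. This is a minimal, hypothetical sketch of the logic in the diff: `DummyConfig` and `resolve_attn_implementation` are stand-ins invented here for illustration, not the real `PretrainedConfig` or any actual transformers helper.

```python
import copy


class DummyConfig:
    """Hypothetical stand-in for a transformers PretrainedConfig."""

    def __init__(self, attn_implementation="eager"):
        self._attn_implementation = attn_implementation


def resolve_attn_implementation(config, **kwargs):
    # Deep-copy so the caller's config object is never mutated in place,
    # mirroring the `copy.deepcopy(config)` added in the diff.
    config = copy.deepcopy(config)
    requested = kwargs.get("attn_implementation", None)
    # Only overwrite when the kwarg is set and differs from the config value,
    # matching the condition in the diff above.
    if requested is not None and getattr(config, "_attn_implementation", None) != requested:
        config._attn_implementation = requested
    return config


cfg = DummyConfig()
resolved = resolve_attn_implementation(cfg, attn_implementation="flash_attention_2")
print(resolved._attn_implementation)  # flash_attention_2
print(cfg._attn_implementation)       # eager -- the original config is untouched
```

The deep copy is the key design choice: without it, a user-supplied config object would be silently mutated by `from_pretrained`.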

younesbelkada marked this conversation as resolved.
config._attn_implementation = kwargs.get("attn_implementation", None)
model_kwargs = kwargs

quantizer = None
10 changes: 10 additions & 0 deletions tests/test_modeling_utils.py
@@ -1823,6 +1823,16 @@ def test_error_no_flash_available(self):

self.assertTrue("does not support Flash Attention 2.0" in str(cm.exception))

def test_error_no_flash_available_with_config(self):
Contributor (@fxmarty, Dec 14, 2023):

Can you add a test, e.g. llama, passing a config and attn_implementation="flash_attention_2", checking that the correct class is loaded?

Contributor Author:

You mean without AutoModel?

Contributor (@fxmarty):

I mean a test for an architecture that does support FA2, passing both a config and attn_implementation="flash_attention_2".

Contributor Author:

Done!

with self.assertRaises(ValueError) as cm:
config = AutoConfig.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel")

_ = AutoModel.from_pretrained(
"hf-tiny-model-private/tiny-random-MCTCTModel", config=config, attn_implementation="flash_attention_2"
)

self.assertTrue("does not support Flash Attention 2.0" in str(cm.exception))
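The error path this test exercises can be illustrated with a hypothetical, self-contained sketch: models that do not support Flash Attention 2.0 should raise a ValueError even when the request arrives via a user-supplied config. The `SUPPORTS_FLASH_ATTN_2` mapping and `check_attn_implementation` helper below are invented for illustration and do not exist in transformers.

```python
# Assumed support mapping for illustration only; the real library checks a
# per-model `_supports_flash_attn_2` class attribute instead.
SUPPORTS_FLASH_ATTN_2 = {"llama": True, "mctct": False}


def check_attn_implementation(model_type, attn_implementation):
    # Reject flash_attention_2 for architectures that do not support it,
    # regardless of whether the value came from a kwarg or a config.
    if attn_implementation == "flash_attention_2" and not SUPPORTS_FLASH_ATTN_2.get(model_type, False):
        raise ValueError(f"{model_type} does not support Flash Attention 2.0")
    return attn_implementation


try:
    check_attn_implementation("mctct", "flash_attention_2")
except ValueError as exc:
    print(exc)  # mctct does not support Flash Attention 2.0
```

Before this PR, passing the config manually could bypass this check because the kwarg was never merged into the config; the test above guards against that regression.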

def test_error_wrong_attn_implementation(self):
with self.assertRaises(ValueError) as cm:
_ = AutoModel.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="foo")