
[FA-2] Fix fa-2 issue when passing config to `from_pretrained` #28043

Merged (12 commits, Dec 15, 2023)
13 changes: 13 additions & 0 deletions — src/transformers/modeling_utils.py

@@ -2955,6 +2955,19 @@ def from_pretrained(
**kwargs,
)
else:
# In case one passes a config to `from_pretrained` together with `attn_implementation`,
# override the config's `_attn_implementation` attribute with the value from the kwargs.
# Please see: https://github.com/huggingface/transformers/issues/28038

# Overwrite `config._attn_implementation` with the one from the kwargs. In the auto-factory
# we pop `attn_implementation` from the kwargs, but this handles the case where users
# manually pass the config to `from_pretrained`.
config = copy.deepcopy(config)

if kwargs.get("attn_implementation", None) is not None and config._attn_implementation != kwargs.get(
Collaborator: Do we want to `get` here or `pop` from the kwargs?

Contributor Author: Good catch, I think `pop` would work best here.
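The `get` vs `pop` distinction the reviewers are discussing can be shown in isolation (a minimal sketch of plain `dict` semantics, not transformers code):

```python
kwargs = {"attn_implementation": "flash_attention_2", "torch_dtype": "float16"}

# dict.get reads the value but leaves the key in the dict...
value = kwargs.get("attn_implementation", None)
assert value == "flash_attention_2"
assert "attn_implementation" in kwargs  # still present

# ...while dict.pop reads the value and removes the key, so it cannot
# later leak into model_kwargs and reach the model constructor.
value = kwargs.pop("attn_implementation", None)
assert value == "flash_attention_2"
assert "attn_implementation" not in kwargs  # consumed
```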

"attn_implementation", None
):
config._attn_implementation = kwargs.get("attn_implementation", None)
younesbelkada marked this conversation as resolved.
model_kwargs = kwargs

quantizer = None
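The override logic added in this hunk can be sketched as a standalone helper (the name `resolve_attn_implementation` and the stand-in `Cfg` class are hypothetical; this is a simplified sketch of the behavior, not the transformers implementation):

```python
import copy

def resolve_attn_implementation(config, kwargs):
    # Deep-copy first so the caller's config object is never mutated
    # (mirrors the `copy.deepcopy(config)` added in this hunk).
    config = copy.deepcopy(config)
    requested = kwargs.get("attn_implementation", None)
    # Override only when the kwarg is set and disagrees with the config.
    if requested is not None and getattr(config, "_attn_implementation", None) != requested:
        config._attn_implementation = requested
    return config

class Cfg:  # tiny stand-in for a PretrainedConfig
    _attn_implementation = "eager"

cfg = Cfg()
new_cfg = resolve_attn_implementation(cfg, {"attn_implementation": "flash_attention_2"})
assert new_cfg._attn_implementation == "flash_attention_2"
assert cfg._attn_implementation == "eager"  # original config untouched
```

The deep copy is what makes the kwarg override safe: without it, a user-supplied config would be silently mutated by `from_pretrained`.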
10 changes: 10 additions & 0 deletions — tests/test_modeling_utils.py

@@ -1823,6 +1823,16 @@ def test_error_no_flash_available(self):

self.assertTrue("does not support Flash Attention 2.0" in str(cm.exception))

def test_error_no_flash_available_with_config(self):
Contributor (@fxmarty, Dec 14, 2023): Can you add a test for e.g. llama + passing a config + attn_implementation="flash_attention_2" that checks the correct class is loaded?

Contributor Author: You mean without AutoModel?

Contributor: I mean a test for an architecture that does support FA2, passing both a config + attn_implementation="flash_attention_2".

Contributor Author: Done!

with self.assertRaises(ValueError) as cm:
config = AutoConfig.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel")

_ = AutoModel.from_pretrained(
"hf-tiny-model-private/tiny-random-MCTCTModel", config=config, attn_implementation="flash_attention_2"
)

self.assertTrue("does not support Flash Attention 2.0" in str(cm.exception))

def test_error_wrong_attn_implementation(self):
with self.assertRaises(ValueError) as cm:
_ = AutoModel.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="foo")
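The `assertRaises` pattern these tests rely on can be reproduced with a self-contained stand-in (the `load_model` helper and its supported set are hypothetical, mimicking an architecture without FA2 support like the tiny MCTCT model above; they are not transformers internals):

```python
import unittest

# Hypothetical: implementations this stand-in architecture supports / knows about.
SUPPORTED = {"eager", "sdpa"}
KNOWN = {"eager", "sdpa", "flash_attention_2"}

def load_model(attn_implementation="eager"):
    # Mirror the two failure modes exercised by the tests above.
    if attn_implementation not in KNOWN:
        raise ValueError(f"Unknown attn_implementation: {attn_implementation!r}")
    if attn_implementation not in SUPPORTED:
        raise ValueError("This model does not support Flash Attention 2.0")
    return "model"

class AttnImplementationTests(unittest.TestCase):
    def test_error_no_flash_available(self):
        with self.assertRaises(ValueError) as cm:
            load_model(attn_implementation="flash_attention_2")
        self.assertIn("does not support Flash Attention 2.0", str(cm.exception))

    def test_error_wrong_attn_implementation(self):
        with self.assertRaises(ValueError) as cm:
            load_model(attn_implementation="foo")
        self.assertIn("Unknown attn_implementation", str(cm.exception))
```

`assertRaises` as a context manager captures the exception on `cm.exception`, which is why the real tests can assert on the error message after the `with` block exits.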