[FA-2] Fix fa-2 issue when passing config to from_pretrained (#28043)
Changes from 7 commits
@@ -1823,6 +1823,16 @@ def test_error_no_flash_available(self):

         self.assertTrue("does not support Flash Attention 2.0" in str(cm.exception))

+    def test_error_no_flash_available_with_config(self):
+        with self.assertRaises(ValueError) as cm:
+            config = AutoConfig.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel")
+
+            _ = AutoModel.from_pretrained(
+                "hf-tiny-model-private/tiny-random-MCTCTModel", config=config, attn_implementation="flash_attention_2"
+            )
+
+        self.assertTrue("does not support Flash Attention 2.0" in str(cm.exception))
+
     def test_error_wrong_attn_implementation(self):
         with self.assertRaises(ValueError) as cm:
             _ = AutoModel.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="foo")

Review discussion on test_error_no_flash_available_with_config:

Reviewer: Can you add a test for e.g. llama + passing a config + attn_implementation="flash_attention_2"?

Author: You mean without ...?

Reviewer: I mean a test for an architecture that does support FA2, passing both a config + attn_implementation="flash_attention_2".

Author: Done!
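The test the reviewer asked for is not visible in this hunk. A minimal sketch of what such a test plausibly looks like, assuming a tiny llama checkpoint such as hf-internal-testing/tiny-random-LlamaForCausalLM (the checkpoint name, dtype, and assertion are assumptions, not the merged code; FA2 additionally requires a CUDA device and the flash-attn package):

import torch

from transformers import AutoConfig, AutoModelForCausalLM


def test_fa2_with_config_and_kwarg():
    # Hypothetical test: an architecture that does support FA2 (llama),
    # passing both an explicit config and attn_implementation via kwargs.
    checkpoint = "hf-internal-testing/tiny-random-LlamaForCausalLM"  # assumed checkpoint
    config = AutoConfig.from_pretrained(checkpoint)

    model = AutoModelForCausalLM.from_pretrained(
        checkpoint,
        config=config,
        attn_implementation="flash_attention_2",
        torch_dtype=torch.float16,  # FA2 requires fp16/bf16
    )

    # The kwarg should win over the config's default attention backend.
    assert model.config._attn_implementation == "flash_attention_2"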
This handles the case where users pass a config object to from_pretrained. Note that AutoModelXxx.from_pretrained pops attn_implementation from the kwargs when no config is passed, but leaves it in the kwargs when a config is passed. This change therefore also covers that corner case (passing a config --> attn_implementation is not popped and arrives through the from_pretrained kwargs). In that case we should overwrite the config's attn_implementation with the one from the kwargs, assuming the user knows what they are doing.

https://github.com/huggingface/transformers/blob/main/src/transformers/models/auto/auto_factory.py#L516-L540