🚨All attention refactor🚨 #35235
Conversation
Force-pushed from 0dc9253 to d1aa9ce
src/transformers/modeling_utils.py (Outdated)
    class GradientCheckpointLayer(torch.nn.Module):
This should help with kwargs as well
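For context, a minimal sketch of what a kwargs-aware gradient-checkpointing wrapper could look like — an illustration only, not the PR's actual implementation; the class name comes from the diff above, everything else is assumed:

```python
import torch
from torch.utils.checkpoint import checkpoint


class GradientCheckpointLayer(torch.nn.Module):
    """Wraps a layer so its forward can be gradient-checkpointed while still
    forwarding keyword arguments (sketch, not the PR code)."""

    def __init__(self, layer: torch.nn.Module):
        super().__init__()
        self.layer = layer

    def forward(self, *args, **kwargs):
        if self.training and torch.is_grad_enabled():
            # Non-reentrant checkpointing accepts **kwargs directly.
            return checkpoint(self.layer.__call__, *args, use_reentrant=False, **kwargs)
        return self.layer(*args, **kwargs)
```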
Force-pushed from 8b56823 to ecd814b
Very impressive work, kudos to you both!
Confirmed slow tests with Llama, everything is similar to main!
run-slow: vit (just a check unrelated to this PR)
Great work guys! I think there might be value in keeping some comments, e.g. why contiguous is called for SDPA, and clarifying the flash-attention recast to half precision (which originates from PEFT and/or RoPE).
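For reference, a rough sketch of the two patterns that comment refers to; the function names and tensor shapes are assumptions, not the PR's actual code:

```python
import torch
import torch.nn.functional as F


def sdpa_attention_forward(query, key, value, attention_mask=None, scaling=None):
    # Assumed layout: (batch, num_heads, seq_len, head_dim).
    # Some SDPA backends have issues with non-contiguous q/k/v when a custom
    # mask is passed, which is why .contiguous() is called here.
    query, key, value = query.contiguous(), key.contiguous(), value.contiguous()
    out = F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, scale=scaling
    )
    return out.transpose(1, 2).contiguous()


def maybe_downcast_for_flash_attention(query, key, value, target_dtype=torch.float16):
    # PEFT adapters or fp32 layer norms / RoPE can silently upcast activations
    # to float32; flash-attention kernels only accept fp16/bf16, hence the recast.
    if query.dtype == torch.float32:
        query, key, value = (t.to(target_dtype) for t in (query, key, value))
    return query, key, value
```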
    attention_mask: Optional[torch.Tensor],
    scaling: Optional[float] = None,
    softcap: Optional[float] = None,
    **kwargs,
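As a quick illustration of what `scaling` and `softcap` typically control (a sketch of the usual semantics, e.g. Gemma 2-style logit softcapping, not the exact code behind this signature):

```python
import torch


def attention_scores(query, key, scaling=None, softcap=None):
    # query/key assumed shaped (batch, num_heads, seq_len, head_dim).
    scores = torch.matmul(query, key.transpose(-1, -2))
    if scaling is not None:
        scores = scores * scaling  # usually 1/sqrt(head_dim)
    if softcap is not None:
        scores = softcap * torch.tanh(scores / softcap)  # soft-cap the logits
    return scores
```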
Head mask could be added as well, as done in gpt_neox.
It's more that gpt_neox will get a flex attention function that supports the head mask! But good point!
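For reference, flex attention exposes a `score_mod` hook that receives the head index, which is the natural place to route a per-head mask or bias through. A minimal sketch assuming torch >= 2.5 — this only illustrates the mechanism, not the actual transformers integration:

```python
import torch
from torch.nn.attention.flex_attention import flex_attention


def flex_attention_with_head_bias(query, key, value, head_bias, scaling=None):
    # head_bias: (num_heads,) additive bias applied to every score of a head.
    # The real head-mask handling in transformers may differ from this sketch.
    def score_mod(score, batch, head, q_idx, kv_idx):
        return score + head_bias[head]

    return flex_attention(query, key, value, score_mod=score_mod, scale=scaling)
```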
@ArthurZucker @Cyrilvallez do you plan to refactor
Release will happen on 🎅🏻 🎁 !
I'm going on holidays too so I'll wait for January ^^ Happy holidays!
Happy holidays!
I noticed that there is still torch.reshape used in quite a few places, for instance in
Hey! Indeed, for gpt2 I used a `reshape` instead of the usual `view` because I hit an edge case that wasn't compatible with viewing at some point (I will recheck whether it still appears with the latest developments, it might have been an artifact of the debugging process). Whenever possible (most of the time), `reshape` is equivalent to `view`, so no worries there anyway 😉
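A tiny illustration of the difference (independent of this PR): `view` requires the requested shape to be expressible over the existing memory layout, while `reshape` silently falls back to a copy when it is not:

```python
import torch

x = torch.arange(6).reshape(2, 3)
t = x.transpose(0, 1)      # non-contiguous

print(t.reshape(6))        # works: reshape copies when a view is impossible
try:
    t.view(6)              # raises: view cannot change the memory layout
except RuntimeError as err:
    print("view failed:", err)
```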
What does this PR do?
Todo in this PR: