
🚨All attention refactor🚨 #35235

Merged: 99 commits merged into main on Dec 18, 2024
Conversation

@ArthurZucker (Collaborator) commented Dec 12, 2024:

What does this PR do?

Todo in this PR:

  • Cohere
  • Chameleon
  • DBRX
  • Gemma
  • Gemma2
  • GLM (modular, so nothing to do I believe)
  • gpt_neox and GPT2
  • Granite
  • Jamba
  • JetMoe
  • Mimi
  • Mistral
  • Mixtral
  • Mllama
  • Moshi
  • Nemotron
  • OPT
  • Phi
  • Phi3
  • PhiMoe
  • Qwen2
  • qwen2Moe
  • qwen2VL
  • StableLM
  • StarCoder2 -> modular, should normally be OK
  • Idefics1,2,3
  • Olmo
  • Olmo2
  • Siglip
  • Whisper

@ArthurZucker force-pushed the all-attention-refactor branch from 0dc9253 to d1aa9ce on December 12, 2024 at 13:49

class GradientCheckpointLayer(torch.nn.Module):
@ArthurZucker (Collaborator, Author):
This should help with kwargs as well
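
(For context, a minimal sketch of a checkpointing wrapper that forwards keyword arguments; this is an illustrative assumption, not the exact implementation added in this PR.)

import torch
from torch.utils.checkpoint import checkpoint


class GradientCheckpointLayerSketch(torch.nn.Module):
    """Illustrative only: runs the wrapped layer under activation checkpointing
    while still accepting arbitrary **kwargs (attention_mask, position_ids, ...)."""

    def __init__(self, layer: torch.nn.Module):
        super().__init__()
        self.layer = layer

    def forward(self, *args, **kwargs):
        # The non-reentrant checkpoint API (use_reentrant=False) is what allows
        # keyword arguments to be passed through to the wrapped forward.
        return checkpoint(self.layer.__call__, *args, use_reentrant=False, **kwargs)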

@Cyrilvallez force-pushed the all-attention-refactor branch from 8b56823 to ecd814b on December 16, 2024 at 11:28
@LysandreJik (Member) left a comment:
Very impressive work, kudos to you both!

@Cyrilvallez merged commit 2c47618 into main on Dec 18, 2024 (25 checks passed)
@Cyrilvallez deleted the all-attention-refactor branch on December 18, 2024 at 15:53
@Cyrilvallez (Member):
Confirmed slow tests with Llama, everything is similar to main!

@ydshieh (Collaborator) commented Dec 18, 2024:

run-slow: vit

(just a check unrelated to this PR)

@Cyrilvallez mentioned this pull request on Dec 18, 2024.
@vasqu (Contributor) left a comment:

Great work, guys! I think there is value in keeping some of the old comments, e.g. why .contiguous() is called on the SDPA path, and why the Flash Attention path recasts inputs to half precision (which originates from PEFT and/or RoPE).
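
(A sketch of the two points being referred to, paraphrased from memory rather than quoted from the source; the function names below are illustrative.)

import torch


def sdpa_attention_sketch(query, key, value, attention_mask=None):
    # Illustration only. SDPA's memory-efficient backend has been buggy with
    # non-contiguous query/key/value tensors when a custom attn_mask is passed
    # on some torch versions, hence the explicit .contiguous() calls.
    query, key, value = query.contiguous(), key.contiguous(), value.contiguous()
    return torch.nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask
    )


def maybe_downcast_for_flash_attention(hidden_states, target_dtype=torch.float16):
    # Flash Attention only supports fp16/bf16 inputs. PEFT commonly keeps layer
    # norms in float32 (and RoPE may be computed in float32), which silently
    # upcasts the hidden states, so they are cast back to half precision here.
    if hidden_states.dtype == torch.float32:
        hidden_states = hidden_states.to(target_dtype)
    return hidden_states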

attention_mask: Optional[torch.Tensor],
scaling: Optional[float] = None,
softcap: Optional[float] = None,
**kwargs,
@vasqu (Contributor):
A head mask could be added here as well, as is done in GPT-NeoX.

@ArthurZucker (Collaborator, Author):
It's more that gpt_neox will get a flex attention function that supports the head mask! But good point!
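
(As an illustration of that calling convention, here is a sketch of an eager attention function with head-mask support registered in the new ALL_ATTENTION_FUNCTIONS mapping; the function name and registration key are made up, and the mapping is assumed to behave like a plain dict, as introduced in this PR.)

import torch
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS


def eager_attention_with_head_mask(
    module,
    query,            # (batch, num_heads, q_len, head_dim)
    key,              # assumes key/value already expanded to num_heads (no GQA handling)
    value,
    attention_mask,   # additive mask or None
    scaling=None,
    dropout=0.0,
    head_mask=None,
    **kwargs,
):
    # Sketch only, not the actual gpt_neox implementation.
    if scaling is None:
        scaling = query.size(-1) ** -0.5
    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask[..., : key.shape[-2]]
    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    if head_mask is not None:
        # Zero out the heads selected by the mask, GPT-NeoX style.
        attn_weights = attn_weights * head_mask
    attn_weights = torch.nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value).transpose(1, 2).contiguous()
    return attn_output, attn_weights


# Registering under a custom (made-up) key makes it available to models that
# dispatch through the mapping.
ALL_ATTENTION_FUNCTIONS["eager_with_head_mask"] = eager_attention_with_head_mask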

@Cyrilvallez (Member):
Thanks a lot for the feedback @vasqu! Please have a look at the fixes in #35342!

@SimJeg commented Dec 20, 2024:

@ArthurZucker @Cyrilvallez do you plan to refactor modeling_phi3.py using the ALL_ATTENTION_FUNCTIONS mapping you recently introduced? I see Phi3 is in the list shared by @ArthurZucker at the beginning of this PR. It would be very helpful for our kvpress package. We plan to update it to be compatible with the upcoming v4.48; any idea when it will be released (December? January?)?
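
(For readers unfamiliar with the mechanism: roughly, the refactored attention modules dispatch through the mapping based on config._attn_implementation. The sketch below is simplified; the scaling and attention_dropout attribute names are assumptions modeled on the Llama-style modules, not the actual Phi-3 code.)

from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS


def run_attention(module, query_states, key_states, value_states, attention_mask, **kwargs):
    # Look up the backend selected in the config ("sdpa", "flash_attention_2",
    # "flex_attention", or a registered custom key). The refactored models
    # special-case "eager" with a local implementation, which is elided here.
    attention_interface = ALL_ATTENTION_FUNCTIONS[module.config._attn_implementation]
    attn_output, attn_weights = attention_interface(
        module,
        query_states,     # (batch, num_heads, seq_len, head_dim)
        key_states,
        value_states,
        attention_mask,
        dropout=0.0 if not module.training else module.attention_dropout,
        scaling=module.scaling,
        **kwargs,
    )
    return attn_output, attn_weights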

@ArthurZucker (Collaborator, Author):
The release will happen on 🎅🏻 🎁!
Yes, for sure. Unless you submit a PR first 👀, I'm not sure it will make this release since we are all going on holidays, but it will be included in January's release.

@SimJeg commented Dec 20, 2024:

I'm going on holidays too, so I'll wait for January ^^ Happy holidays!

@ArthurZucker (Collaborator, Author):
Happy holidays!

@poedator (Contributor):
I noticed that torch.reshape is still used in quite a few places, for instance in modeling_gpt2.py. Won't that be an obstacle to compiling these models? Why not replace it with einops and use einops._torch_specific.allow_ops_in_compiled_graph()?

@Cyrilvallez (Member):
Hey! Indeed, for GPT2 I used a reshape instead of the usual view because at some point I hit an edge case that wasn't compatible with viewing (I will recheck whether it still occurs with the latest developments; it might have been an artifact of the debugging process). Whenever possible (i.e. most of the time), reshape is equivalent to view, so no worries there anyway 😉
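
(A small PyTorch check of that reshape-vs-view point, for readers following along; this is an illustration, not code from the PR.)

import torch

x = torch.arange(12).reshape(3, 4)                    # contiguous tensor
assert x.reshape(4, 3).data_ptr() == x.data_ptr()     # same storage: acts like .view()

t = x.t()                                             # transpose -> non-contiguous
# t.view(12) would raise a RuntimeError here; reshape silently falls back to a copy.
y = t.reshape(12)
assert y.data_ptr() != t.data_ptr()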
