Add diffllama #34083
base: main
Conversation
I am coding now, but it's my first time contributing to transformers or any other OSS project. I may ask you for some help.
Force-pushed from 765db6a to 269055e
I still have an error in modeling_diffllama.py at line 377 (apply_rotary_pos_emb): query_states must be torch.Size([2, 32, 10, 128]) but it is torch.Size([2, 64, 10, 64]). I need to change either query_states or cos & sin.
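For reference, here is a minimal shape sketch of that mismatch, assuming the usual [batch, num_heads, seq_len, head_dim] layout; the cos tensor below is just a random stand-in for the rotary table:

import torch

batch, seq_len = 2, 10
q_expected = torch.randn(batch, 32, seq_len, 128)  # shape apply_rotary_pos_emb expects
q_actual = torch.randn(batch, 64, seq_len, 64)     # shape produced by the halved heads

cos = torch.randn(seq_len, 128)  # rotary table built for head_dim=128
print((q_expected * cos).shape)  # broadcasts fine: torch.Size([2, 32, 10, 128])
# (q_actual * cos) raises a RuntimeError because 64 and 128 cannot broadcast,
# so either cos/sin must be rebuilt for head_dim=64, or query_states must keep
# the 32 x 128 layout and be split later (the approach suggested further down).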
Hey! I think this would be an awesome fit for modular transformers!
A bit of doc here: https://huggingface.co/docs/transformers/en/modular_transformers
This would help isolate the changes!
I've finished implementing the normal/eager attention, and I can run it with AutoModelForCausalLM.generate(). I've also adapted it to fit modular transformers.
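For reference, a minimal generation sketch; the checkpoint path and prompt are placeholders, not a released model:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("path/to/diffllama-checkpoint")
model = AutoModelForCausalLM.from_pretrained("path/to/diffllama-checkpoint")

inputs = tokenizer("Differential attention reduces", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))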
You don't need to divide by 2 if we use the same number of attention heads as Llama; instead you can just split in forward. Co-authored-by: Minho Ryu <[email protected]>
fit to changing the "num_heads // 2" place Co-authored-by: Minho Ryu <[email protected]>
new code is more meaningful than before Co-authored-by: Minho Ryu <[email protected]>
new code is more meaningful than before Co-authored-by: Minho Ryu <[email protected]>
fit to changing the "num_heads // 2" place Co-authored-by: Minho Ryu <[email protected]>
fix dividing by sqrt(self.head_dim) two times Co-authored-by: Minho Ryu <[email protected]>
fix dividing by sqrt(self.head_dim) two times Co-authored-by: Minho Ryu <[email protected]>
fit to changing the "num_heads // 2" place, and more visible Co-authored-by: Minho Ryu <[email protected]>
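Following the suggestion above (keep Llama's number of attention heads, split inside forward, and divide by sqrt(head_dim) only once), here is a minimal sketch of the differential-attention math; the function name, the plain-softmax formulation, the fixed scalar lambda, and the value layout (half the heads with doubled head dim, as in the Differential Transformer paper) are assumptions, not this PR's exact implementation:

import math
import torch
import torch.nn.functional as F

def differential_attention(q, k, v, lam):
    # q, k: [batch, num_heads, seq_len, head_dim] -- same head count as Llama,
    #       split into two halves here instead of halving the projections.
    # v:    [batch, num_heads // 2, seq_len, 2 * head_dim]
    # lam:  scalar weight of the second softmax map (causal mask omitted for brevity).
    q1, q2 = torch.chunk(q, 2, dim=1)
    k1, k2 = torch.chunk(k, 2, dim=1)
    scale = 1.0 / math.sqrt(q.size(-1))  # applied exactly once
    a1 = F.softmax(q1 @ k1.transpose(-1, -2) * scale, dim=-1)
    a2 = F.softmax(q2 @ k2.transpose(-1, -2) * scale, dim=-1)
    return (a1 - lam * a2) @ v  # [batch, num_heads // 2, seq_len, 2 * head_dim]

b, h, s, d = 2, 32, 10, 128
out = differential_attention(
    torch.randn(b, h, s, d), torch.randn(b, h, s, d),
    torch.randn(b, h // 2, s, 2 * d), lam=0.5,
)
print(out.shape)  # torch.Size([2, 16, 10, 256])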
implemented flash and sdpa attention as well.
Co-authored-by: Minho Ryu <[email protected]>
Co-authored-by: Minho Ryu <[email protected]>
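Once the flash/SDPA variants are wired up, the backend is selected the standard way at load time; a hedged usage sketch with a hypothetical checkpoint path:

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "path/to/diffllama-checkpoint",
    attn_implementation="sdpa",  # or "flash_attention_2" / "eager"
)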
All of your review comments have been implemented. I've tried the tests many times, but they don't pass. What should I do?
Hey! Sorry, we were all off for a week on a company-wide offsite! 🤗 @Cyrilvallez should be back on Monday!
I wonder whether this PR is still a work in progress, or whether most of the implementation has been finalized and it is waiting for the test coverage review?
BTW sorry for being late! Overall super good; what's left to do IMO is to use modular transformers (https://huggingface.co/docs/transformers/en/modular_transformers) to make it simpler, as a lot can inherit from Llama! Let me know if I can help!
Hey, sorry for the delay! You can simply put class DiffLlamaRotaryEmbedding(LlamaRotaryEmbedding): pass in the modular file. In your case, you will probably only need to rewrite the attention classes 😉
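A hedged sketch of what that modular file could look like; the import paths are the usual Llama module, but the DiffLlamaAttention body and the exact set of classes are assumptions:

# modular_diffllama.py (sketch): classes identical to Llama are subclassed with
# `pass`; the modular converter expands them into full standalone code.
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaRotaryEmbedding,
)


class DiffLlamaRotaryEmbedding(LlamaRotaryEmbedding):
    pass


class DiffLlamaAttention(LlamaAttention):
    # only the attention classes need a real rewrite (differential attention)
    def forward(self, *args, **kwargs):
        raise NotImplementedError("differential attention goes here")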
Are you still working on this PR, @weak-kajuma?
@Cyrilvallez Could you review again? I made the modular file.
Hey! A great first modular! But you can still cut a lot of code: the only difference here is the attention classes, so it's perfect for modular to pick up on everything else by itself!
LMK if you run into any issues.
You may need to rebase/merge on main.
@Cyrilvallez Could you review again? Modular transformers is very easy and good. I can also pass all the tests after merging the latest changes.
@Cyrilvallez any plans to review this PR?
Alright, very good! Final comments 🤗
class DiffLlamaRMSNorm(LlamaRMSNorm):
    pass


ALL_LAYERNORM_LAYERS.append(DiffLlamaRMSNorm)


class DiffLlamaRotaryEmbedding(LlamaRotaryEmbedding):
    pass


class DiffLlamaMLP(MistralMLP):
    pass
Should be removed!
If I remove DiffLlamaMLP, then AttributeError: 'DiffLlamaConfig' object has no attribute 'mlp_bias' occurs, so I cannot remove it.
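A hedged alternative to keeping DiffLlamaMLP would be to define the missing attribute on the config so the inherited Llama-style MLP finds it; the attribute names mirror LlamaConfig and the defaults below are assumptions:

from transformers import PretrainedConfig


class DiffLlamaConfig(PretrainedConfig):
    model_type = "diffllama"

    def __init__(self, hidden_size=4096, intermediate_size=11008, mlp_bias=False, **kwargs):
        # mlp_bias is the attribute the inherited MLP looks up; defining it here
        # would remove the AttributeError without a custom DiffLlamaMLP class.
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.mlp_bias = mlp_bias
        super().__init__(**kwargs)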
Force-pushed from 7b0da01 to b4ff5f3
What does this PR do?
This PR adds the code for DiffLlama, which is a Llama model with the Differential Transformer attention mechanism. Please refer to the Differential Transformer paper. @ArthurZucker