[FA-2] Final fix for FA2 dtype #26846
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
Thanks, I think we can simplify a bit and remove the warning?
```diff
 logger.warning_once(
-    "The input hidden states seems to be silently casted in float32, this might be related to"
-    " the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
-    " float16."
+    f"The input hidden states seems to be silently casted in float32, this might be related to"
+    f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+    f" {target_dtype}."
 )
```
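For context, in the models touched by this PR this warning sits right before the query/key/value states are cast back to `target_dtype` for the Flash Attention 2 kernel, which only accepts fp16/bf16 inputs. Below is a rough, simplified sketch of that pattern, not the exact model code: the helper name and the standard-library logger (standing in for transformers' `logger.warning_once`) are assumptions.

```python
import logging
from typing import Tuple

import torch

logger = logging.getLogger(__name__)


def cast_inputs_for_flash_attn(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    target_dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Hypothetical helper: cast fp32 attention inputs back before the FA-2 kernel.

    Flash Attention 2 only supports fp16/bf16. With PEFT + quantization
    fine-tuning, embeddings/LayerNorms are often kept in float32 for stability,
    which silently upcasts the hidden states, so we cast back and warn.
    """
    if query_states.dtype == torch.float32:
        logger.warning(
            "The input hidden states were silently cast to float32, likely because "
            "embedding or layer norm layers were upcast to float32; casting back to "
            f"{target_dtype}."
        )
        query_states = query_states.to(target_dtype)
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)
    return query_states, key_states, value_states
```

Keeping the warning (the question in this thread) means users learn that their inputs were silently upcast rather than having it handled invisibly.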
I think we can remove this now, no?
Hmm, I think we need to keep it to inform users about that.
* final fix for FA2 dtype
* try
* oops
* Update src/transformers/models/falcon/modeling_falcon.py

Co-authored-by: Arthur <[email protected]>

* apply fix everywhere

---------

Co-authored-by: Arthur <[email protected]>
What does this PR do?
Replaces #26560
Fixes #26451
Proposes a simpler fix for FA-2 + PEFT + quantization fine-tuning, where users usually cast all other modules (e.g. LayerNorms) to fp32 for training stability.
With #26761 introduced, it is now much simpler to retrieve the model's original dtype. Note also that `self.config._pre_quantization_dtype` remains the single source of truth, as `to()` is not supported for quantized models (see the sketch below).

cc @ArthurZucker @pacman100
Also added a test for this.
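To make the dtype retrieval concrete, here is a minimal illustrative sketch, not the PR's exact code: the `resolve_target_dtype` name and the `q_proj` projection are assumptions (Falcon, for instance, uses a fused `query_key_value` projection), and the PR applies this logic inline in each model's FA-2 attention forward.

```python
import torch


def resolve_target_dtype(attn_module: torch.nn.Module, config) -> torch.dtype:
    """Illustrative helper: pick the dtype FA-2 inputs should be cast back to.

    For quantized models the weight dtype is no longer meaningful (e.g. int8)
    and `.to(dtype)` is not supported, so `config._pre_quantization_dtype`,
    recorded before quantization, is the single source of truth. Otherwise the
    dtype of an attention projection weight is used.
    """
    if hasattr(config, "_pre_quantization_dtype"):
        return config._pre_quantization_dtype
    return attn_module.q_proj.weight.dtype
```

The resolved dtype is what replaces the hardcoded `float16` in the warning message shown in the review diff above.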