[BERT] Add support for SDPA #28802
Conversation
Hey @ArthurZucker @younesbelkada I was thinking SDPA (#28005) could be a good addition to BERT, so I drafted this change. It doesn't look too hairy so far. As @ArthurZucker mentioned, BERT doesn't have a lot of params so there might not be much of a speedup, but this didn't look too difficult to implement, so I figured whatever little improvement we get might still be helpful (as an aside, there's been some benchmarking of Flash Attention on training other BERT implementations, and it still shows decent improvements). Can you let me know if this is worth pursuing? If so, I'll add the tests and also fix the fix-copies dependencies. Thanks!
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)

# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a
This is fixed in torch 2.2.0 I think, maybe I should check for it and skip the calls?
I think it is fine to leave. We should probably bump the requirement for SDPA to torch>=2.2 in the future.
This got me thinking, and I ran an additional set of benchmarks (training and inference), given that FA2 is supported and the contiguous bug is fixed in torch 2.2.0. Both training and inference were ~5% faster with torch==2.2.0 (where FA2 should be supported). I also tried gating the `.contiguous()` calls behind a version check and saw an additional ~5-10% gain on top of that:

from packaging import version
from transformers.utils import get_torch_version

# Only force contiguous inputs on torch < 2.2.0, where SDPA's memory-efficient backend
# is bugged with non-contiguous inputs.
if version.parse(get_torch_version()) < version.parse("2.2.0"):
    query_layer = query_layer.contiguous()
    key_layer = key_layer.contiguous()
    value_layer = value_layer.contiguous()

I'm leaning towards adding the if-statement to gate the call, so users who upgrade to torch==2.2.0 first can get the benefits right away (before we bump the minimum torch version to 2.2.0). WDYT?
I added the if-statement for 2.2.0 in there. If you don't think it's a good idea, let me know and I'll remove it.
I think a good way to see if it is worth it is to benchmark your code and check whether you get speedups in different contexts!
Sounds good, lemme look into that.
@ArthurZucker I did some training and inference benchmarking for my change and posted the results in the PR description. It looks like there are decent improvements across the board (percentage-wise, but I think the improvements would add up if we're doing a lot of training/inference). I think it could be a good addition. Thoughts?
Sounds like a good addition then! I'll let @fxmarty review and will be doing the final pass!
Yes, it's similar. SDPA is built into pytorch, and can support Flash Attention (1) depending on the environment. AFAIK Flash Attention 2 isn't supported in SDPA yet, but there is a possibility for it to be supported down the road (and that would come built into pytorch, so it shouldn't need many changes on our end).
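For context, a minimal sketch (not part of this PR; shapes and the CUDA/fp16 setup are just illustrative) of how torch's built-in SDPA can be pinned to a single backend such as Flash Attention:

```python
import torch
import torch.nn.functional as F

# Toy tensors: (batch, num_heads, seq_len, head_dim); fp16 on CUDA is what the flash kernel expects.
q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)

# Restrict SDPA to the Flash Attention kernel only (the same context manager the
# transformers SDPA tests use); it raises if that kernel cannot run on this input.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
```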
Thanks, I think it is now.
Oh nice, so I guess we could get FA2 for free eventually (when we upgrade pytorch). Thanks for the links to similar work. I think they could cause some merge conflicts, so I'll message them and try to resolve it before this goes in.
It looks in good shape, thank you! Left a few comments.
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
embedding_output = self.embeddings(
I would probably move the `Copied from` just to the `__init__` and other methods, but not `forward`. For the forward, you can probably just add a comment that it is adapted from bert/roberta, and once bridge_tower supports SDPA we can put it back to `Copied from`.
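For reference, a rough sketch of the comment convention being discussed (the class and method paths here are only illustrative, not the exact ones in this PR):

```python
# Copied from transformers.models.bert.modeling_bert.BertModel.__init__
def __init__(self, config):
    ...

# Adapted from transformers.models.bert.modeling_bert.BertModel.forward
# (an "Adapted from" note is not enforced by the fix-copies tooling, so the body may
# diverge until this model also supports SDPA and the "Copied from" can be restored)
def forward(self, input_ids=None, attention_mask=None, **kwargs):
    ...
```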
There seem to be 8 methods that copy from `BertModel#forward()` exactly and have this section of changes.
I don't mind adding SDPA to them as well once this goes in and reinstating the copy-from. It shouldn't be that difficult (famous last words).
I've removed the fix-copies from the instances, and so the logic for sdpa attention masks should only be in BertModel now.
# Expand the attention mask for SDPA.
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
if self.config.is_decoder:
    extended_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
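As a side note, here is a hand-rolled sketch (not the transformers helper itself) of what that `[bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]` expansion produces in the non-causal case, with 0.0 for positions to attend to and the dtype minimum for masked positions:

```python
import torch

def expand_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int) -> torch.Tensor:
    # [bsz, src_len] 1/0 padding mask -> [bsz, 1, tgt_len, src_len] additive float mask
    bsz, src_len = mask.shape
    expanded = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
    return (1.0 - expanded) * torch.finfo(dtype).min

attention_mask = torch.tensor([[1, 1, 1, 0]])  # last position is padding
print(expand_mask_for_sdpa(attention_mask, torch.float32, tgt_len=4).shape)  # torch.Size([1, 1, 4, 4])
```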
@ArthurZucker there are `create_extended_attention_mask_for_decoder`, `invert_attention_mask`, and `get_extended_attention_mask` methods in modeling_utils.py that should probably be deprecated / redirected to modeling_attn_mask_utils.py.
Yea, I agree.
It'd be great if we could mark those old methods as deprecated, and slowly update them once we verify that the old methods and the new methods are always returning the same results.
For the `updated_attention_mask` for SDPA, why can't we keep the previous logic and just do the following (from Llama)?

# Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = causal_mask.mul(
    ~torch.all(causal_mask == torch.finfo(embedding_output.dtype).min, dim=-1, keepdim=True)
).to(dtype)

I'm not a super fan of the complexity of `_prepare_4d_causal_attention_mask_for_sdpa`, and we should not add it in our new code IMO.
@@ -451,12 +451,10 @@ def _prepare_4d_attention_mask_for_sdpa(mask: torch.Tensor, dtype: torch.dtype,
    # torch.jit.trace and torchdynamo with fullgraph=True are unable to capture the controlflow `is_causal=attention_mask is None and q_len > 1`
    # used as an SDPA argument. We keep compatibility with these tracing tools by always using SDPA's `attn_mask` argument in case we are tracing.
    # TODO: Fix this as well when using torchdynamo with fullgraph=True.
    is_tracing = torch.jit.is_tracing()
This code was changed to pass the fx tracing test (in the common tests). It would be good if you could help double-check the logic here. I think the idea is that we still have to use our own attention mask (rather than `None`) when tracing is active. The previous `pass` would cause the function to end without any return statement, which would have defaulted to returning `None`.
It looks OK to me, cc @fxmarty to confirm.
AFAICT, the difference here is coming from the additional `isinstance(mask, torch.fx.Proxy)` in the `is_tracing` check. I don't believe the reworking to remove `pass` should affect anything - the new code is equivalent.
Yes it is fine, see:

is_tracing = torch.jit.is_tracing() or isinstance(inputs_embeds, torch.fx.Proxy)
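To make the behaviour concrete, a rough, hypothetical sketch of that pattern (this is not the actual `_prepare_4d_attention_mask_for_sdpa` implementation, just the shape of the logic discussed above):

```python
from typing import Optional

import torch

def maybe_drop_mask(mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    # When torch.jit.trace or torch.fx is tracing, the data-dependent branch below cannot
    # be captured, so we always keep an explicit attn_mask instead of returning None.
    is_tracing = torch.jit.is_tracing() or isinstance(mask, torch.fx.Proxy)
    if not is_tracing and mask is not None and bool((mask == 1).all()):
        # Nothing is masked and we are not tracing: let SDPA run without an attn_mask.
        return None
    return mask
```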
@@ -692,6 +807,10 @@ def __init__(self, config):
        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
        self.decoder.bias = self.bias

    def _tie_weights(self):
This fix was added due to a test failure that uncovered an existing bug. The head was initialized but the weights weren't re-tied as necessary, which caused `self.decoder.bias` to be different from `self.bias`. When loading the pretrained model with `low_cpu_mem_usage=True`, `self.decoder.bias` had uninitialized params (with `device=meta`) whereas `self.bias` was set properly (with `device=cpu`).
I'm slightly concerned this will affect the output some users see when using this model. Please let me know what you think about this.
I pulled this out into its own PR here: #28948.
The issue is unrelated to SDPA (it was just uncovered by an SDPA test), so it made sense to handle it separately.
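For reference, a minimal sketch of the fix that was split out into #28948: re-tie the decoder bias to the module-level bias so both always reference the same tensor, including after loading with `low_cpu_mem_usage=True` (where `self.decoder.bias` would otherwise be left on `device=meta`):

```python
def _tie_weights(self):
    # Keep decoder.bias and bias pointing at the same parameter, mirroring the link
    # set up in __init__ so it survives re-initialization and meta-device loading.
    self.decoder.bias = self.bias
```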
Addition looks OK to me - thanks for digging into this.

> I'm slightly concerned this will affect the output some users see when using this model. Please let me know what you think about this.

Could you expand on what you think might be an issue?
I was initially concerned that users had been loading and using the model with a wrong bias (i.e. device=meta), and that this fix to use the correct bias would cause results to change between versions.
However, that seems unlikely after playing around with this a bit more: it turns out it was quite difficult to run the model at all when the bias had device=meta, so I doubt anyone was actually running the model in that configuration before the fix.
@@ -3560,8 +3564,9 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol):
    enable_math=True,
    enable_mem_efficient=enable_kernels,
):
    outputs_eager = model_eager(dummy_input, **other_inputs)
    outputs_sdpa = model_sdpa(dummy_input, **other_inputs)
    prepared_inputs = self._prepare_for_class(processed_inputs, model_class)
The `self._prepare_for_class` call is necessary to support the `BertForMultipleChoice` model.
I've rebased off of head and marked this as ready for review. I had to dig through a couple of issues to get the tests to pass, let me know if you want to chat about any of them. Thanks!
The tests are passing now. I also verified that test_modeling_bert passes with RUN_SLOW=1 (which contains the tests ensuring model output is the same for eager and SDPA attention). Please take another look when you get a chance. Thanks!
Thanks for all the work adding this @hackyon as well as the additional work to dig into weird errors and find solutions. Great work!
Some general comments:
- Let's wait for the merging of "Add tie_weights() to LM heads and set bias in set_output_embeddings()" (#28948) before merging this in
- It would be good to add the performance numbers in the PR description to BERT's model page, similar to what's done for Flash Attention e.g. [here](https://huggingface.co/docs/transformers/v4.37.2/en/model_doc/gpt_neox#using-flash-attention-2)
- `test_eager_matches_sdpa_inference` should be run for all existing models with SDPA implemented to confirm compatibility with the change in `processed_inputs`
- We shouldn't be setting `self._use_sdpa` for models that don't have an SDPA attention class. We can just about get away with it for the models which have an attention dict, but not for the other models.
Oh wow
LGTM, let's rebase on main!
Thanks! I merged with main/HEAD, and re-ran the RUN_SLOW tests for both bert and test_eager_matches_sdpa_inference, and they work as expected. There were existing failures for test_eager_matches_sdpa_inference with RUN_SLOW on main/HEAD, but nothing new introduced by this change. I'm not sure about the test_pipelines_tf failure. I haven't touched any code with tf, and I was able to get the failing test test_stop_sequence_stopping_criteria to pass locally, so I'm thinking it's a flake or unrelated to this change.
Hi @hackyon - great to see this ready to merge! The generation tests aren't related to this diff and are failing on other PRs. We're working to push a fix to main - will let you know when resolved, you can rebase and hopefully we have full 🟢 for merging 🤗
Thanks @amyeroberts @ArthurZucker Just remerged with main/HEAD, and the unrelated failing TF pipeline test now passes. I checked the bert tests again with RUN_SLOW for good measure, and they continue to pass. Let me know if there's anything else I could do here. Thanks!
@ArthurZucker Please let me know if there's anything else you'd like me to do for this PR. Thanks!
Remerged with the latest main, and fixed a test. @ArthurZucker @amyeroberts @fxmarty Please let me know if there's anything I can do here.
@hackyon Everything's green and two approvals, so we're good to merge. Thanks for all the effort in adding this and iterating with us. It's great to have this added to one of the most popular models ❤️
Thanks @amyeroberts for the merge! 🎉 I appreciate all the help from @fxmarty, @ArthurZucker, and you in getting this PR merged 🙏 I see you've submitted #30506 as a follow-up, and thank you for covering that. Please let me know if there's any other follow-up work, and I'd be happy to look into it.
As I mentioned previously, I've also drafted a PR for adding SDPA support to RoBERTa-based models at #30510. Almost all of the changes are "Copied from" BERT, and so there is a little less room for error.
* Adding SDPA support for BERT
* Using the proper input name for testing model input in inference()
* Adding documentation for SDPA in BERT model page
* Use the stable link for the documentation
* Adding a gate to only call .contiguous() for torch < 2.2.0
* Additions and fixes to the documentation
* Minor updates to documentation
* Adding extra requirements needed for the contiguous() bug
* Adding "Adapted from" in place of the "Copied from"
* Add benchmark speedup tables to the documentation
* Minor fixes to the documentation
* Use ClapText as a replacement for Bert in the Copied-From
* Some more fixes for the fix-copies references
* Overriding the test_eager_matches_sdpa_generate in bert tests to not load with low_cpu_mem_usage [test all]
* Undo changes to separate test
* Refactored SDPA self attention code for KV projections
* Change use_sdpa to attn_implementation
* Fix test_sdpa_can_dispatch_on_flash by preparing input (required for MultipleChoice models)
I appreciate your work! As Esm is a BERT-based model, I think SDPA can be added to Esm with little modification.
What does this PR do?
Adding support for SDPA (scaled dot product attention) for Bert. More context in #28005.
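A minimal usage sketch once this is released (the checkpoint name is just an example):

```python
from transformers import BertModel

# Request the SDPA attention implementation explicitly when loading a BERT checkpoint.
model = BertModel.from_pretrained("bert-base-uncased", attn_implementation="sdpa")
```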
Benchmarking results on A100-80GB, 12-core CPU, 96.6GB RAM, Ubuntu 22.04, using BertLMHeadModel
Training benchmark based on fxmarty's script:
Inference benchmark based on fxmarty's script:
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@ArthurZucker @younesbelkada
(cc @fxmarty)