Fix attention mask handling in the Hybrid Engine Bloom flow #5101

Merged

Conversation

deepcharm
Contributor

The Bloom flow in Hybrid Engine applies the same transformation of the input mask that is already performed earlier by the transformers BloomModel::forward.

This results in non-convergence of scores, specifically in DeepSpeed Chat, on different accelerators, including CUDA and HPU.

The fix removes the redundant mask transformation and application, producing correct convergence.

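To make the failure mode concrete, here is a toy sketch (not the actual transformers or DeepSpeed code) of why applying the same mask transformation twice corrupts attention. It assumes a 2D padding mask where 1 marks valid tokens and 0 marks padding, and a transformation that converts it into an additive mask (0.0 for positions to attend, a large negative value for masked ones):

```python
import torch

def to_additive_mask(mask: torch.Tensor) -> torch.Tensor:
    """Toy version of the transform: 1 (attend) -> 0.0, 0 (pad) -> large negative."""
    return (1.0 - mask.float()) * torch.finfo(torch.float32).min

# Batch of 1, sequence length 4; the last token is padding.
padding_mask = torch.tensor([[1, 1, 1, 0]])

once = to_additive_mask(padding_mask)   # what a BloomModel::forward-style transform produces
twice = to_additive_mask(once)          # redundant second pass, as in the buggy flow

print(once)   # valid tokens stay 0.0, only the padded position is hugely negative
print(twice)  # every position is now hugely negative, so attention is effectively broken
```

Applied once, the mask keeps valid positions at 0.0; applied a second time, the already-huge negative values feed back into the formula and the valid positions are pushed to large negatives as well, which is the kind of corruption that surfaces as non-converging scores.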
@lekurile lekurile self-requested a review February 9, 2024 17:20
@lekurile
Contributor

lekurile commented Feb 9, 2024

Hello @deepcharm,

Thank you for the contribution. I've studied the code in BloomModel::forward and can see that the call to _prepare_4d_causal_attention_mask is made there. I believe the referenced 1 - mask operation happens inside the _prepare_4d_causal_attention_mask function.

I think the change makes sense in this case, since we don't want to perform the redundant mask processing operation. However, I want to be careful that we don't fundamentally change the masking behavior in ds_attention.py, and that we account for model support beyond BLOOM, since I'm not sure whether all transformer model implementations do this mask processing outside the transformers block.

One option is to add an optional config parameter in our inference config.py that enables skipping this operation. We can set this parameter to False by default and to True only in the BLOOM container.

If more models need this behavior, we can enable it in the corresponding model-specific container.
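For illustration only, here is a minimal sketch of the kind of opt-out flag being suggested; the names below (InferenceMaskConfig, skip_mask_transform, build_bloom_container_config) are hypothetical and not the actual DeepSpeed config.py or container APIs:

```python
from dataclasses import dataclass

@dataclass
class InferenceMaskConfig:
    # False by default: every model keeps the existing mask-transformation behavior.
    skip_mask_transform: bool = False

def build_bloom_container_config() -> InferenceMaskConfig:
    # Only the BLOOM container opts out, because transformers' BloomModel::forward
    # has already transformed the mask before it reaches the inference attention path.
    return InferenceMaskConfig(skip_mask_transform=True)
```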

Please let me know if you have feedback or questions.

Thanks,
Lev

@deepcharm
Contributor Author

Hi @lekurile,

Thanks for your feedback. I will implement the option that you have described and submit another patch.

Max

@tjruwase
Contributor

@deepcharm, thanks for improving the PR. Please ping us again when it is ready for review.

deepcharm and others added 2 commits March 4, 2024 13:17
The BLOOM flow in Hybrid Engine applies the same transformation
of the input mask already performed earlier in the transformers
BloomModel::forward.

This results in the non-convergence of scores, specifically in
Deepspeed Chat on different accelerators, including CUDA and HPU.

An optional config parameter invert_mask is introduced into
DeepSpeedInferenceConfig (True by default), which enables skipping
the invert operation for some transformer implementations,
such as BLOOM.
@deepcharm
Contributor Author

Hi @lekurile, @tjruwase

As advised, I've added an optional config parameter invert_mask into DeepSpeedInferenceConfig (True by default),
which enables skipping the invert operation for some transformer implementations, such as BLOOM.

Kindly review the change. Thanks.
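As a rough sketch of how such a flag can gate the redundant operation: the class below is illustrative only and not the actual ds_attention.py code, while the field name and default follow the description above (invert_mask on DeepSpeedInferenceConfig, True by default, disabled for BLOOM):

```python
import torch

class MaskPreparer:
    """Illustrative only: shows the invert_mask gating pattern, not DeepSpeed internals."""

    def __init__(self, invert_mask: bool = True):
        # True (default): treat the input as a raw 0/1 mask that still needs inverting.
        # False (e.g. BLOOM): the mask was already transformed by the HF model's forward.
        self.invert_mask = invert_mask

    def prepare(self, attention_mask: torch.Tensor) -> torch.Tensor:
        if self.invert_mask:
            return (1.0 - attention_mask.float()) * torch.finfo(torch.float32).min
        return attention_mask  # pass through untouched for BLOOM-style flows
```

With invert_mask=False, the mask that BloomModel::forward already prepared is passed through untouched, which matches the behavior that restores convergence in DeepSpeed Chat.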

@lekurile
Contributor

lekurile commented Mar 4, 2024


@deepcharm Thank you! Looks good to me. Approved and running checks.

@lekurile lekurile enabled auto-merge March 12, 2024 22:57
@lekurile lekurile added this pull request to the merge queue Mar 12, 2024
Merged via the queue into microsoft:master with commit d9e12d3 Mar 13, 2024
12 checks passed
@deepcharm deepcharm deleted the fix-bloom-attention-mask-hybrid-engine branch March 14, 2024 16:42
rraminen pushed a commit to ROCm/DeepSpeed that referenced this pull request May 9, 2024
dbyoung18 pushed a commit to dbyoung18/DeepSpeed that referenced this pull request Jun 11, 2024