
[SwitchTransformer] Significant performance improvement on MoE blocks #31173

Merged: 4 commits, Jun 6, 2024

Conversation

ranggihwang
Contributor

What does this PR do?

This is an edited version of the previously closed PR (#30490)

This PR includes a performant implementation of SwitchTransformersSparseMLP in the Google SwitchTransformer.
In the current implementation of the SwitchTransformer, the MoE block iterates over all possible experts, including the inactive ones:

for idx, expert in enumerate(self.experts.values()):
    token_indices = router_mask[:, :, idx].bool()
    next_states[token_indices] = expert(hidden_states[token_indices]).to(next_states.dtype)

This results in serious performance degradation of the SwitchTransformer.

[Screenshot 2024-04-26 2:16 AM] As shown in this figure, the current implementation of the SwitchTransformer iterates over inactive experts, unnecessarily increasing latency. [Screenshot 2024-04-26 2:17 AM] This issue can be particularly severe in models with a larger number of experts, since even more inactive experts are visited needlessly.

My custom implementation of SwitchTransformersSparseMLP, by contrast, accesses and computes only the active experts.
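
For illustration, here is a minimal, self-contained sketch of the idea, not the exact code in this PR (the name sparse_mlp_forward and the experts argument are assumptions): first find the experts that actually received tokens, then run only those.

import torch

def sparse_mlp_forward(hidden_states, router_mask, experts):
    # Illustrative sketch; names and signature are assumptions, not the PR's exact code.
    # hidden_states: (batch, seq_len, d_model)
    # router_mask: one-hot routing mask of shape (batch, seq_len, num_experts)
    # experts: sequence of expert modules, indexed by expert id
    next_states = hidden_states.clone()
    router_mask = router_mask.bool()
    batch_size, seq_len, num_experts = router_mask.shape

    # Indices of experts that received at least one token
    token_counts = router_mask.reshape(batch_size * seq_len, num_experts).sum(dim=0)
    active_experts = torch.nonzero(token_counts, as_tuple=True)[0].tolist()

    # Dispatch tokens only to the active experts; inactive experts are never touched
    for idx in active_experts:
        token_indices = router_mask[:, :, idx]
        next_states[token_indices] = experts[idx](hidden_states[token_indices]).to(next_states.dtype)
    return next_states

The per-expert dispatch inside the loop matches the original code; the only change is that the loop now runs over the active experts instead of all of them.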

Advantages

  • This can significantly reduce the latency of the SwitchTransformer and make the model more accessible to a broader range of users.
  • This change achieves greater latency reductions when expert parameters are offloaded to the CPU or SSD.
  • This change addresses the problem of latency increasing in proportion to the total number of experts.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@ArthurZucker and @younesbelkada

Contributor

@younesbelkada younesbelkada left a comment


Thanks a lot! Looks very good! Can you make sure the styling checks pass: make fixup && make fix-copies

@ranggihwang
Contributor Author

Thanks @younesbelkada
I ran make fixup && make fix-copies before in #30490 (review), but I was asked to revert it in the end.

How do I do this correctly? Would you please let me know?

@younesbelkada
Contributor

Hi @ranggihwang
Thanks! Hmm, I think there was a misunderstanding on my side at that time. If you could run the styling checks and push the results here (it should only change two files: switch_transformers and the gpt_san_japanese file), that would be great!

@ranggihwang
Contributor Author

Shouldn't the styling check only affect src/transformers/models/switch_transformers/modeling_switch_transformers.py?
I haven't changed anything except that file, and there doesn't seem to be any file named gpt_san_japanese in my repo.

@younesbelkada
Contributor

Since gpt_san_japanese uses blocks that are copied from switch_transformers, running make fix-copies will propagate the changes you introduced into that file as well; see: https://app.circleci.com/pipelines/github/huggingface/transformers/94615/workflows/e1e9d110-614a-411a-a0f5-b7d4146e4db8/jobs/1241937

@amyeroberts
Collaborator

amyeroberts commented Jun 3, 2024

@younesbelkada @ranggihwang gpt san has been deprecated, so we don't really want these changes to be propagated. I've just merged #31153, which removes the # Copied from headers for this model. Rebasing on main will include this and remove the need to run make fix-copies here. Thanks!

@younesbelkada
Contributor

Perfect, thanks for the heads up @amyeroberts!
@ranggihwang feel free to proceed as suggested by Amy 🙏

@ranggihwang
Contributor Author

@amyeroberts @younesbelkada
Thank you for your advice, Amy and Younes.

I've just rebased it onto main and committed. Would you please check whether it is correct?

@younesbelkada
Contributor

Thanks @ranggihwang! The styling checks are now failing; can you run make fixup and commit the changes?

@ranggihwang
Contributor Author

Okay, now make fixup is done!

Contributor

@younesbelkada younesbelkada left a comment


Thanks !

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@younesbelkada younesbelkada requested a review from amyeroberts June 3, 2024 10:22
@ArthurZucker
Collaborator

I'll review this since I reviewed the previous PR; I want to make sure the suggestions are all applied!

Collaborator

@ArthurZucker ArthurZucker left a comment


Could you apply the suggestion I made in the previous PR?

Comment on lines 298 to 305
router_mask = router_mask.bool()
idx_mask = router_mask.transpose(1, 2)  # Batch * experts * tokens
idx_mask = torch.cat(torch.split(idx_mask, 1, dim=0), dim=2)  # 1 * experts * (batch * tokens)
idx_mask = idx_mask.sum(dim=2)
idx_mask = idx_mask.squeeze()  # length: number of experts / value: number of tokens
idx_mask = torch.nonzero(idx_mask, as_tuple=True)[0].tolist()  # length: number of "activated" experts / value: index

Suggested change
idx_mask = router_mask.reshape(batch * seq_len, num_experts).transpose(0, 1).sum(dim=1)
idx_mask = torch.nonzero(idx_mask, as_tuple=True)[0].tolist()
  • and the comment about the shapes! 🤗

Contributor Author


The batch_size, seq_len, and num_experts are not defined in the function,
so I've derived them from router_mask and applied your suggestion.

Thank you @ArthurZucker !
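
For reference, a minimal runnable sketch of deriving those values from router_mask, in line with the suggestion above (the toy router_mask below is an assumption for illustration, not the PR's actual inputs):

import torch

# Toy one-hot router_mask: batch of 2, sequence length 3, 4 experts (illustrative values)
router_mask = torch.nn.functional.one_hot(torch.randint(0, 4, (2, 3)), num_classes=4)

batch_size, seq_len, num_experts = router_mask.shape
idx_mask = router_mask.reshape(batch_size * seq_len, num_experts).transpose(0, 1).sum(dim=1)
idx_mask = torch.nonzero(idx_mask, as_tuple=True)[0].tolist()  # indices of the activated experts
print(idx_mask)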

Collaborator


thanks!

@ranggihwang
Contributor Author

@ArthurZucker @younesbelkada
Please let me know if I need to rebase it again :)

@ranggihwang ranggihwang changed the title Significant performance improvement on MoE blocks of SwitchTransformer [SwitchTransformer] Significant performance improvement on MoE blocks of SwitchTransformer Jun 4, 2024
@ranggihwang ranggihwang changed the title [SwitchTransformer] Significant performance improvement on MoE blocks of SwitchTransformer [SwitchTransformer] Significant performance improvement on MoE blocks Jun 4, 2024
Contributor

@younesbelkada younesbelkada left a comment


Still LGTM! Let's wait for @ArthurZucker's final review!

Collaborator

@ArthurZucker ArthurZucker left a comment


Thanks a lot! 🤗

@ArthurZucker ArthurZucker merged commit 9b85e40 into huggingface:main Jun 6, 2024
21 checks passed
@ArthurZucker
Collaborator

Could this be propagated to the Qwen code, @ranggihwang? I know they have some variants with lots of experts!

@ranggihwang
Contributor Author

@ArthurZucker I think it can be adopted for many MoE models in Hugging Face, not only qwen-moe but also NLLB-MoE, Mixtral, etc.

@ArthurZucker
Collaborator

ArthurZucker commented Jun 6, 2024

Awesome! Then if you are interested, feel free to open a PR and ping me! 🤗
Some models need compile support, which might be a little bit tricky; we'll see.

zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request Jun 11, 2024
…ks (huggingface#31173)

* SwitchTransformer MoE layer performance improvement

* make fixup

* comments about shapes

* make fixup
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request Jun 14, 2024
…ks (huggingface#31173)

* SwitchTransformer MoE layer performance improvement

* make fixup

* comments about shapes

* make fixup
@ArthurZucker ArthurZucker mentioned this pull request Dec 13, 2024
1 task