
Add Reduce ops workaround for keepDim=false #1625

Merged 7 commits into main from mrakita/reduce_false on Dec 20, 2024
Conversation

@mrakitaTT (Contributor) commented Dec 18, 2024

This PR adds TTNN workarounds for the referenced Metal issues.

As part of this work I've also:

  • Enabled conversion of stablehlo.reduce op with multiple reduce dimensions
  • Added reduce ops verifiers in TTIR (see the illustrative sketch after this list)
  • Added a separate function in TTNNWorkarounds to run rewrite patterns for decomposition and layout workarounds
  • Added lots of unit tests for reduce ops to cover conversions and verifiers
  • Added lots of silicon tests for reduce ops
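
As an illustration of the kind of checks these verifiers perform, here is a minimal sketch; the parameter names, accessors, and diagnostic text are assumptions for illustration, not the actual TTIR code:

#include "llvm/ADT/SmallSet.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"
#include <optional>

// Illustrative only: verify that every reduce dimension is within the input
// rank and that no dimension is repeated.
static mlir::LogicalResult verifyReduceDims(mlir::Operation *op,
                                            std::optional<mlir::ArrayAttr> dimArg,
                                            int64_t inputRank) {
  if (!dimArg)
    return mlir::success(); // No dims supplied: reduce over all dimensions.

  llvm::SmallSet<int64_t, 4> seen;
  for (mlir::Attribute attr : *dimArg) {
    int64_t dim = mlir::cast<mlir::IntegerAttr>(attr).getInt();
    if (dim < -inputRank || dim >= inputRank)
      return op->emitOpError("reduce dimension is out of range");
    // Normalize negative dims before checking for duplicates.
    int64_t normalized = dim < 0 ? dim + inputRank : dim;
    if (!seen.insert(normalized).second)
      return op->emitOpError("duplicate reduce dimension");
  }
  return mlir::success();
}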

Opened issue #1624 (assigned to myself) to revert these workarounds once the Metal issues are fixed.

Closes #805, #848

After implementing these workarounds and running the tests, I encountered another Metal issue, this time in the reshape op. I've debugged it and have a local fix; I will send a PR to the Metal repo, as confirmed with the reshape op owners. I've opened issue #1640 to enable the reduce ops silicon tests once that fix is uplifted.

Another issue I encountered while working on this: if the workaround pass decompositions change the shapes of op tensors, their layouts need to change too, but the layout pass runs before the workaround pass. I've managed to solve it by reusing the layout of the input tensor, but I'm not sure that is a good solution; maybe we need to repeat some of the layout logic after the workaround decompositions. FYI @sdjordjevicTT

Here is the example TTNN IR before the workarounds:

%3 = "ttnn.sum"(%2) <{dim_arg = [0: i32, 1 : i32, 2: i32], keep_dim = false}> : (tensor<128x32x4xf32, #ttnn_layout2>) -> tensor<1xf32, #ttnn_layout2>

and after the workarounds:

%3 = "ttnn.sum"(%2) <{keep_dim = true}> : (tensor<128x32x4xf32, #ttnn_layout2>) -> tensor<1x1x1xf32, #ttnn_layout2>
%4 = "ttnn.reshape"(%3) <{shape = [1 : i32]}> : (tensor<1x1x1xf32, #ttnn_layout2>) -> tensor<1xf32, #ttnn_layout3>

@azecevicTT (Contributor) left a comment

Thank you for writing a great PR description.

There are some points that repeat throughout this PR review, mainly about SmallVector, SmallSet, const, ref, and the std::optional interface, so I didn't want to repeat myself. I would like to hear others' opinions on this matter. I believe we agreed on some of those points, but some of them weren't discussed.
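
For readers outside the review thread, a brief illustration of the constructs being referenced and the usual LLVM/MLIR conventions around them (general guidance only, not a quote of the specific review comments):

#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Value.h"
#include <optional>

// LLVM code generally prefers the ADT containers over their std counterparts
// for small, hot collections.
llvm::SmallVector<int64_t, 4> reduceDims;  // rather than std::vector<int64_t>
llvm::SmallSet<int64_t, 4> seenDims;       // rather than std::unordered_set<int64_t>

// MLIR handle types (Value, Type, Attribute) are cheap value-semantic wrappers,
// so they are conventionally passed by value rather than by const reference.
void process(mlir::Value operand);            // preferred
// void process(const mlir::Value &operand); // const-ref adds nothing here

// std::optional values are often checked and dereferenced directly instead of
// spelling out has_value()/value() at every use.
inline void use(std::optional<int64_t> maybeDim) {
  if (maybeDim)
    (void)*maybeDim;
}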

Resolved review threads:
  • lib/Dialect/TTIR/IR/TTIROps.cpp (5 threads)
  • test/ttmlir/Conversion/StableHLOToTTIR/reduce_add_op.mlir (1 thread)
  • test/ttmlir/Silicon/StableHLO/reduce_add_op.mlir (2 threads)
@mrakitaTT (Contributor, Author) commented Dec 19, 2024

@azecevicTT Thank you for the detailed review and the pointers to the LLVM docs! I've learned a few new things 😄 I've left one comment unresolved where I had a different opinion. Please let me know if you disagree, and also if I missed covering something.

@mrakitaTT mrakitaTT requested a review from azecevicTT December 19, 2024 02:03
@azecevicTT (Contributor) left a comment

Thank you for addressing the comments. Just go once more through the changes: there are still has_value/value calls on std::optional, uses of std::vector and std::unordered_set, and const refs on MLIR types in places where I didn't mark them.

Regarding the ReduceOps verification, it's okay on my end even if it stays as-is, because of the limited scope of that function.

Resolved review thread: lib/Dialect/TTIR/IR/TTIROps.cpp
@mrakitaTT mrakitaTT requested a review from azecevicTT December 19, 2024 15:10
@LPanosTT (Contributor) commented:
@mrakitaTT It looks like @bbradelTT has a fix in Metal to handle keepDim=False properly in this PR: tenstorrent/tt-metal#16163

It's not merged yet but we'll keep an eye on it.

@sdjordjevicTT (Contributor) commented:
Can you please attach the MLIR IR before and after your change? It is a bit hard to imagine these kinds of changes by just looking at the code :)

@mrakitaTT (Contributor, Author) commented:
> Can you please attach the MLIR IR before and after your change? It is a bit hard to imagine these kinds of changes by just looking at the code :)

@sdjordjevicTT Sure, I've added example IRs to the PR description and will do it for future PRs too :)

@bbradelTT commented:
I just merged in tenstorrent/tt-metal#16163 to support keepdim=False

It should work for many inputs, with the caveat that you need to use a release build to avoid asserts in many cases.

@mrakitaTT (Contributor, Author) commented Dec 20, 2024

> I just merged in tenstorrent/tt-metal#16163 to support keepdim=False
>
> It should work for many inputs, with the caveat that you need to use a release build to avoid asserts in many cases.

Thank you @bbradelTT. I wish you had let me know that you were going to start working on it.

In any case, I've tested this workaround in debug mode with lots of shapes and all tests pass. I think I am still going to merge it; then, once we uplift the new version of Metal, I will test your fix, and if it also passes our tests I can remove the workaround (tracked by #1624). I'll let you know if any tests fail and with which shapes/parameters.

@mrakitaTT mrakitaTT force-pushed the mrakita/reduce_false branch from 3d6c7db to 1a7fcd7 on December 20, 2024 21:01
@mrakitaTT mrakitaTT enabled auto-merge (squash) December 20, 2024 22:32
@mrakitaTT mrakitaTT merged commit cb3e406 into main Dec 20, 2024
21 checks passed
@mrakitaTT mrakitaTT deleted the mrakita/reduce_false branch December 20, 2024 22:36
Successfully merging this pull request may close these issues:
  • Stablehlo MINIST Softmax test is failing due to Reduction Op