Add end-to-end conversion of TTIR broadcast to TTNN repeat op. #1638

Open · uazizTT wants to merge 8 commits into main from uaziz/broadcast-repeat-lowering

Conversation

uazizTT (Contributor) commented Dec 18, 2024:

Fixes #1348
Fixes #1345
Fixes #1235

This provides a general solution for lowering ttir.broadcast by splitting it into ttnn.reshape (where required) and ttnn.repeat, giving a general lowering path for the ttir.broadcast op.

A subsequent change will add an optimization that folds the broadcast into binary or ternary eltwise ops whenever possible.

Example:

module @jit_broadcast attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<1xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<512x512xf32> {mhlo.layout_mode = "default"}) -> (tensor<512x512xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.broadcast_in_dim %arg0, dims = [1] : (tensor<1xf32>) -> tensor<512x512xf32>
    %1 = stablehlo.maximum %0, %arg1 : tensor<512x512xf32>
    // CHECK: %[[C:.*]] = "ttir.broadcast"[[C:.*]]
    return %1 : tensor<512x512xf32>
  }
}

This gets lowered to the following TTNN graph:

module @jit_broadcast attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
    func.func public @main(%arg0: tensor<1xf32, #ttnn_layout> {mhlo.layout_mode = "default"}, %arg1: tensor<512x512xf32, #ttnn_layout1> {mhlo.layout_mode = "default"}) -> (tensor<512x512xf32, #ttnn_layout1> {jax.result_info = "", mhlo.layout_mode = "default"}) {
      %0 = "ttnn.get_device"() <{mesh_shape = #ttnn<mesh_shape 1x1>}> : () -> !tt.device<#device>
      %1 = "ttnn.to_device"(%arg0, %0) <{memory_config = #ttnn.memory_config<#dram, <<1x1>>, <interleaved>>}> : (tensor<1xf32, #ttnn_layout>, !tt.device<#device>) -> tensor<1xf32, #ttnn_layout2>
      %2 = "ttnn.to_layout"(%1) <{layout = #ttnn.layout<tile>}> : (tensor<1xf32, #ttnn_layout2>) -> tensor<1xf32, #ttnn_layout2>
      "ttnn.deallocate"(%1) <{force = false}> : (tensor<1xf32, #ttnn_layout2>) -> ()
      %3 = "ttnn.reshape"(%2) <{shape = [1 : i32, 1 : i32]}> : (tensor<1xf32, #ttnn_layout2>) -> tensor<1x1xf32, #ttnn_layout3>
      "ttnn.deallocate"(%2) <{force = false}> : (tensor<1xf32, #ttnn_layout2>) -> ()
      %4 = "ttnn.repeat"(%3) <{shape = [512, 512]}> : (tensor<1x1xf32, #ttnn_layout3>) -> tensor<512x512xf32, #ttnn_layout4>
      "ttnn.deallocate"(%3) <{force = false}> : (tensor<1x1xf32, #ttnn_layout3>) -> ()
      %5 = "ttnn.to_device"(%arg1, %0) <{memory_config = #ttnn.memory_config<#dram, <<16x16>>, <interleaved>>}> : (tensor<512x512xf32, #ttnn_layout1>, !tt.device<#device>) -> tensor<512x512xf32, #ttnn_layout4>
      %6 = "ttnn.to_layout"(%5) <{layout = #ttnn.layout<tile>}> : (tensor<512x512xf32, #ttnn_layout4>) -> tensor<512x512xf32, #ttnn_layout4>
      "ttnn.deallocate"(%5) <{force = false}> : (tensor<512x512xf32, #ttnn_layout4>) -> ()
      %7 = "ttnn.empty"(%0) <{dtype = #tt.supportedDataTypes<f32>, layout = #ttnn.layout<tile>, memory_config = #ttnn.memory_config<#dram, <<16x16>>, <interleaved>>, shape = #ttnn.shape<512x512>}> : (!tt.device<#device>) -> tensor<512x512xf32, #ttnn_layout4>
      %8 = "ttnn.maximum"(%4, %6, %7) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<512x512xf32, #ttnn_layout4>, tensor<512x512xf32, #ttnn_layout4>, tensor<512x512xf32, #ttnn_layout4>) -> tensor<512x512xf32, #ttnn_layout4>
      "ttnn.deallocate"(%6) <{force = false}> : (tensor<512x512xf32, #ttnn_layout4>) -> ()
      "ttnn.deallocate"(%4) <{force = false}> : (tensor<512x512xf32, #ttnn_layout4>) -> ()
      %9 = "ttnn.from_device"(%8) : (tensor<512x512xf32, #ttnn_layout4>) -> tensor<512x512xf32, #ttnn_layout1>
      "ttnn.deallocate"(%7) <{force = false}> : (tensor<512x512xf32, #ttnn_layout4>) -> ()
      %10 = "ttnn.to_layout"(%9) <{layout = #ttnn.layout<row_major>}> : (tensor<512x512xf32, #ttnn_layout1>) -> tensor<512x512xf32, #ttnn_layout1>
      "ttnn.deallocate"(%9) <{force = false}> : (tensor<512x512xf32, #ttnn_layout1>) -> ()
      return %10 : tensor<512x512xf32, #ttnn_layout1>
    }
}
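
For readers following along, here is a small standalone sketch (not code from this PR; the struct and function names are illustrative only) of the shape arithmetic behind the decomposition: the reshape left-pads the input shape with 1s to match the target rank, and the repeat factors are the per-dimension ratios.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative only: plan a broadcast as "reshape (rank-align) + repeat".
struct BroadcastPlan {
  std::vector<int64_t> reshapedShape; // input shape left-padded with 1s
  std::vector<int64_t> repeatFactors; // per-dimension repetition counts
};

BroadcastPlan planBroadcast(const std::vector<int64_t> &inputShape,
                            const std::vector<int64_t> &targetShape) {
  assert(inputShape.size() <= targetShape.size());
  BroadcastPlan plan;
  // Reshape step: prepend 1s so both shapes have the same rank.
  plan.reshapedShape.assign(targetShape.size() - inputShape.size(), 1);
  plan.reshapedShape.insert(plan.reshapedShape.end(), inputShape.begin(),
                            inputShape.end());
  // Repeat step: every size-1 dimension is repeated up to the target extent.
  for (std::size_t i = 0; i < targetShape.size(); ++i) {
    int64_t dim = plan.reshapedShape[i];
    assert(dim == targetShape[i] || dim == 1);
    plan.repeatFactors.push_back(dim == 1 ? targetShape[i] : 1);
  }
  return plan;
}

// For the example above: planBroadcast({1}, {512, 512}) yields
// reshapedShape = {1, 1} and repeatFactors = {512, 512}.
```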

@uazizTT uazizTT marked this pull request as draft December 18, 2024 20:28
@uazizTT uazizTT force-pushed the uaziz/broadcast-repeat-lowering branch from 0612e3f to 294b2b1 Compare December 19, 2024 16:59
@uazizTT uazizTT marked this pull request as ready for review December 19, 2024 17:00
@uazizTT uazizTT force-pushed the uaziz/broadcast-repeat-lowering branch 8 times, most recently from 054b9e0 to b3f07cc Compare December 21, 2024 02:05
mtopalovicTT (Contributor) commented:

@uazizTT Please provide a better PR description with some IR before and after the decomposition.

@uazizTT uazizTT force-pushed the uaziz/broadcast-repeat-lowering branch 2 times, most recently from 6cb7d13 to 092ca7d Compare December 23, 2024 19:24
Collapsed review threads (outdated, resolved):
lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp (×4)
lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp (×2)
lib/Dialect/TTNN/IR/TTNNOps.cpp
test/ttmlir/Dialect/TTNN/simple_broadcast.mlir (×3)
@uazizTT uazizTT force-pushed the uaziz/broadcast-repeat-lowering branch from 092ca7d to 3b43c5a Compare December 24, 2024 17:54
@uazizTT uazizTT force-pushed the uaziz/broadcast-repeat-lowering branch from 3850f61 to f741c50 Compare December 30, 2024 16:43
mtopalovicTT (Contributor) left a comment:

@uazizTT Approved, but have one of @sdjordjevicTT, @svuckovicTT, or @azecevicTT check the changes before checking in.

Collapsed review threads:
include/ttmlir/Dialect/TTNN/IR/TTNNOps.td (resolved)
lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp (outdated, resolved)
lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp (outdated, resolved)
test/ttmlir/Dialect/TTNN/simple_broadcast.mlir (×5, outdated, resolved)
azecevicTT (Contributor) left a comment:

Overall it doesn't look bad from a functional point of view. The SHLO lowering needs some refactoring, and there is also a double conversion during lowering, but the logic looks good.

I've also left a comment regarding the TTNN interface. I would like others to get involved in that discussion, as I see it as a problem without a clear solution at the moment; I am okay with either proposal from that comment until we finally resolve the question.

}];

  let arguments = (ins AnyRankedTensor:$input,
                       AnyRankedTensor:$output,
-                      I64ArrayAttr:$dimension);
+                      I64ArrayAttr:$repeat_dimensions);
Reviewer comment:

Use DenseI64ArrayAttr; it's much nicer to deal with than I64ArrayAttr.

uazizTT (Contributor, Author) replied Jan 6, 2025:

I tried to use DenseI64ArrayAttr but found that we don't yet have the infrastructure to handle it in TTNNToFlatBuffer, so I suggest we keep it as I32ArrayAttr, which is the same type used in ReshapeOp and ReductionOp, and then refactor them all to DenseI64ArrayAttr as a follow-up task.
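
As context for the suggestion above, a minimal sketch (generic MLIR API usage, not tt-mlir code; the function name is illustrative) of why DenseI64ArrayAttr is easier to work with than I64ArrayAttr:

```cpp
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"

// Reading the same logical list of integers from the two attribute kinds.
void readDims(mlir::Builder &builder, mlir::ArrayAttr arrayAttr,
              mlir::DenseI64ArrayAttr denseAttr) {
  // I64ArrayAttr: every element is a boxed IntegerAttr that must be
  // unwrapped one by one.
  llvm::SmallVector<int64_t> fromArrayAttr;
  for (mlir::Attribute element : arrayAttr)
    fromArrayAttr.push_back(mlir::cast<mlir::IntegerAttr>(element).getInt());

  // DenseI64ArrayAttr: the payload is already a flat ArrayRef<int64_t>.
  llvm::ArrayRef<int64_t> fromDenseAttr = denseAttr.asArrayRef();

  // Construction is similarly direct.
  mlir::DenseI64ArrayAttr rebuilt = builder.getDenseI64ArrayAttr(fromDenseAttr);
  (void)fromArrayAttr;
  (void)rebuilt;
}
```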

include/ttmlir/Dialect/TTIR/IR/TTIROps.td (outdated, resolved)
}];

  let arguments = (ins AnyRankedTensor:$input,
                       I64ArrayAttr:$shape);
Reviewer comment:

The TTNN op takes a Shape as a parameter. At least from our standpoint, Shape is a wrapper around std::vector<uint32_t>. We should decide whether to model it as a TTNN_ShapeAttr in the dialect. One problem I have with modeling unsigned arrays is that MLIR intentionally didn't include DenseUIXArrayAttr (and it's pretty much impossible to extend without forking the whole project). So our options are:

  • Model it as closely as possible, which means making a TTNN_ShapeAttr that is a wrapper around an unsigned array attr. There isn't a nice interface for unsigned arrays; from my experience it means relying on the user to interpret the APInt as an unsigned integer (in every place it's used).
  • Ignore signedness and width, make every integer attribute int64_t, and by extension every integer array attribute DenseI64ArrayAttr.

For TTIR the second option is a no-brainer. For TTNN the first option has the advantage that we can catch some errors at compile time, but it requires considerably more effort than the second option.

We should definitely talk to folks from the Metal team. I know they used std::span<int64_t> in some places, which is the great and ideal option, and I would like to hear their rationale for using unsigned types in other places. Other than addresses (if we use them anywhere) and bitwise ops (but we are specifically talking about metadata/attrs here), I really can't think of a use case where unsigned is a better choice than signed. I guess it's just a matter of how much effort it would require to fix.

@svuckovicTT I think we should discuss this with the Metal team before we move forward with the first option.

Either way, I64ArrayAttr is a bad choice here; use either DenseI64ArrayAttr or I32ArrayAttr.
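
A minimal sketch of what the second option could look like at the boundary, assuming the dialect keeps signed int64_t attributes and only the TTNN/flatbuffer side needs unsigned 32-bit dims (the function name is illustrative, not tt-mlir code):

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <vector>

#include "llvm/ADT/ArrayRef.h"

// Convert a signed shape attribute payload (e.g. from DenseI64ArrayAttr)
// into the unsigned vector a ttnn-style Shape wrapper expects, validating
// the range only at this boundary instead of carrying unsigned types
// through the dialect.
std::vector<uint32_t> toUnsignedShape(llvm::ArrayRef<int64_t> signedShape) {
  std::vector<uint32_t> shape;
  shape.reserve(signedShape.size());
  for (int64_t dim : signedShape) {
    assert(dim >= 0 &&
           dim <= static_cast<int64_t>(std::numeric_limits<uint32_t>::max()));
    shape.push_back(static_cast<uint32_t>(dim));
  }
  return shape;
}
```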

lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp (outdated, resolved)
table RepeatOp {
  in: tt.target.TensorRef;
  out: tt.target.TensorRef;
  shape: [int32];
Reviewer comment:

Taking into account the comment about ttnn.repeat, this should be either [uint32] or [int64].

test/ttmlir/Conversion/StableHLOToTTIR/broadcast_op.mlir (outdated, resolved)
test/ttmlir/Dialect/TTNN/simple_repeat.mlir (outdated, resolved)
test/ttmlir/Silicon/TTNN/simple_repeat.mlir (outdated, resolved)
test/ttmlir/Silicon/StableHLO/broadcast_op.mlir (outdated, resolved)
@uazizTT uazizTT force-pushed the uaziz/broadcast-repeat-lowering branch from 6ed4be5 to 26a4297 Compare January 3, 2025 23:17
@uazizTT uazizTT force-pushed the uaziz/broadcast-repeat-lowering branch from 26a4297 to f25b452 Compare January 6, 2025 17:16
AleksKnezevic commented:

Sorry for chiming in late; I missed this discussion as it was being held. I agree that SHLO broadcast_in_dim should be lowered into multiple ops: an explicit rank change, when needed, and the actual broadcast.

I would prefer that for the broadcasting operation we use the name ttir.broadcast as opposed to ttir.repeat. I think it makes it clearer to users of the dialect what we intend to do, i.e. the value will be broadcast to all relevant cores/kernels, etc.

If a particular ttnn op does not support operand broadcast, that is a ttnn limitation that should not leak into the ttir dialect. In that case, we can repeat the value ahead of time and provide the full-sized input to the op. This should happen in ttnn.

uazizTT (Contributor, Author) commented Jan 6, 2025:

> If a particular ttnn op does not support operand broadcast, that is a ttnn limitation that should not leak into the ttir dialect. In that case, we can repeat the value ahead of time and provide the full-sized input to the op. This should happen in ttnn.

Are you suggesting that we always fold the ttir.broadcast into an eltwise op and then, during TTIRToTTNN conversion, somehow detect the cases where this folding is not possible due to a TTNN limitation and convert only those cases to ttnn.repeat?

The alternative solution is to apply this limitation at the TTIR level, using the TTIR_ImplicitBroadcastable attribute to mark only the ops that support implicit broadcasts; the remaining broadcasts would then be lowered to ttir.repeat. I have a draft commit that shows this implementation: ceb0ebb. A generic sketch of the idea is shown below.
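
This is only an illustration of the trait-based check, not the draft commit's code; the trait's concrete C++ name is project-specific, so it is left as a template parameter here. A broadcast can be folded only if every consumer opts into implicit broadcasting; otherwise it must be materialized as a repeat.

```cpp
#include "mlir/IR/Operation.h"

// Returns true if the explicit broadcast can be folded away, i.e. every
// user of its result opts into implicit broadcasting via the trait.
// The trait is passed as a template parameter because its concrete C++
// name (generated from TTIR_ImplicitBroadcastable) is project-specific.
template <template <typename ConcreteType> class ImplicitBroadcastableTrait>
bool canFoldBroadcast(mlir::Operation *broadcastOp) {
  for (mlir::Operation *user : broadcastOp->getUsers()) {
    if (!user->hasTrait<ImplicitBroadcastableTrait>())
      return false; // this consumer needs the repeated, full-sized tensor
  }
  return true;
}
```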

AleksKnezevic commented:

The ttir dialect is meant to represent what we want to do with the graph, and we should prevent ttnn limitations from leaking in. My preference would be:

  • Lower shlo broadcast_in_dim to ttir.reshape + ttir.broadcast
  • Lower ttir.broadcast to ttnn.broadcast
  • Run a broadcast folding pass on the ttnn graph and either fold the broadcast into the eltwise op or convert it to ttnn.repeat.

Since we don't currently have a ttnnToTtnnPipeline to do the third step, we could do it in the lowering and in the workarounds.

I do think we should keep the op that does broadcasting in ttir called broadcast.

AleksKnezevic commented:

I had a quick chat with @uazizTT offline and want to summarize here for others.

We both agree that folding in ttnn is probably the best approach in the long term because, presumably, different backends will have different broadcasting capabilities and the folder should be aware of that. That said, until we have the infrastructure in place to do the folding in ttnn, we can do it in ttir using the TTIR_ImplicitBroadcastable trait, as recommended by @azecevicTT.

I'm OK with this PR as is, as long as we change the name to ttir.broadcast. Once the folding PR lands, the overall flow would be:

  1. Convert shlo broadcast_in_dim to ttir.reshape (if needed) and ttir.broadcast
  2. Fold the broadcast in ttir using something similar to @uazizTT's commit above
  3. Lower remaining ttir.broadcast into ttnn.repeat

…ue to scalar to 1d-tensor conversion. Updated tests.
@uazizTT uazizTT force-pushed the uaziz/broadcast-repeat-lowering branch from f25b452 to 59d5fa5 Compare January 6, 2025 19:53