From f22c4160920275be35e3beb84c0017c60317cf18 Mon Sep 17 00:00:00 2001
From: Jovan Serbedzija
Date: Thu, 12 Dec 2024 14:30:30 +0100
Subject: [PATCH] [TTIR] Remove TTIR operand_constraints (#1388)

---
 docs/src/overview.md | 3 +-
 include/ttmlir/Dialect/TTIR/IR/TTIROps.td | 109 +++++-------
 .../Dialect/TTIR/IR/TTIROpsInterfaces.td | 10 --
 .../StableHLOToTTIRPatterns.cpp | 124 ++-----------
 .../TTIRToTTIRDecomposition.cpp | 89 +++-------
 .../TosaToTTIR/TosaToTTIRPatterns.cpp | 36 +---
 lib/Dialect/TTIR/IR/TTIROps.cpp | 15 +-
 lib/Dialect/TTIR/Transforms/Constant.cpp | 7 +-
 lib/Dialect/TTIR/Transforms/Generic.cpp | 2 +-
 lib/Dialect/TTIR/Transforms/Layout.cpp | 27 +--
 lib/Dialect/TTNN/Transforms/TTNNLayout.cpp | 44 +----
 python/test_infra/ttir_builder.py | 28 ---
 test/python/smoketest.py | 2 +-
 .../StableHLOToTTIR/binary/concat_op.mlir | 18 +-
 .../Conversion/StableHLOToTTIR/clamp_op.mlir | 2 +-
 .../exponential_minus_one_op.mlir | 3 +-
 .../StableHLOToTTIR/log_plus_one_op.mlir | 3 +-
 .../StableHLOToTTIR/scatter_op.mlir | 3 +-
 .../Conversion/StableHLOToTTIR/select_op.mlir | 3 +-
 .../Conversion/StableHLOToTTIR/sign_op.mlir | 3 +-
 .../StableHLOToTTIR/unary/ceil_op.mlir | 2 +-
 .../StableHLOToTTIR/unary/cosine_op.mlir | 2 +-
 .../StableHLOToTTIR/unary/log_op.mlir | 2 +-
 .../StableHLOToTTIR/unary/logit_op.mlir | 2 +-
 .../StableHLOToTTIR/unary/sine_op.mlir | 2 +-
 .../StableHLOToTTIR/unary/tan_op.mlir | 2 +-
 .../StableHLOToTTIR/unary/tanh_op.mlir | 2 +-
 .../select_decomposition_tests.mlir | 5 +-
 .../TTIR/clamp/clamp_tests_negative.mlir | 3 +-
 .../ttmlir/Dialect/TTIR/constant_as_fill.mlir | 5 +-
 .../convolution_tests_negative.mlir | 7 -
 .../TTIR/index/index_tests_negative.mlir | 25 ++-
 .../TTIR/index/index_tests_positive.mlir | 17 +-
 .../TTIR/linear/linear_tests_negative.mlir | 51 ++----
 .../TTIR/matmul/matmul_tests_negative.mlir | 31 ++--
 .../TTIR/matmul/matmul_tests_positive.mlir | 25 ++-
 .../TTIR/select/select_tests_negative.mlir | 27 +--
 .../TTIR/select/select_tests_positive.mlir | 11 +-
 .../TTIR/slice/slice_tests_negative.mlir | 33 ++--
 .../TTIR/slice/slice_tests_positive.mlir | 17 +-
 test/ttmlir/Dialect/TTIR/test_allocate.mlir | 3 +-
 test/ttmlir/Dialect/TTIR/test_generic.mlir | 3 +-
 test/ttmlir/Dialect/TTIR/test_layout.mlir | 3 +-
 .../TTIR/test_remove_dead_values_pass.mlir | 11 +-
 .../TTIR/ttir_broadcastable_negative.mlir | 12 +-
 .../Dialect/TTIR/ttir_noperands_negative.mlir | 12 +-
 .../TTNN/arange/arange_tests_negative.mlir | 3 +-
 .../TTNN/arange/arange_tests_positive.mlir | 3 +-
 test/ttmlir/Dialect/TTNN/ccl/all_gather.mlir | 3 +-
 .../Dialect/TTNN/ccl/all_gather_negative.mlir | 3 +-
 test/ttmlir/Dialect/TTNN/ccl/all_reduce.mlir | 5 +-
 test/ttmlir/Dialect/TTNN/ccl/mesh_shard.mlir | 3 +-
 .../Dialect/TTNN/concat/concat_dim_oob.mlir | 3 +-
 .../TTNN/concat/concat_multiple_tensors.mlir | 3 +-
 .../TTNN/concat/concat_negative_dim.mlir | 3 +-
 .../TTNN/concat/concat_negative_dim_oob.mlir | 3 +-
 .../Dialect/TTNN/concat/simple_concat.mlir | 3 +-
 .../complex_conv_channel_first.mlir | 2 -
 .../Dialect/TTNN/convolution/simple_conv.mlir | 3 +-
 .../TTNN/convolution/simple_conv1d.mlir | 3 +-
 .../binary/logical_and/simple_and.mlir | 4 +-
 .../eltwise/binary/logical_or/simple_or.mlir | 4 +-
 .../binary/logical_xor/simple_xor.mlir | 3 +-
 .../binary/minimum/simple_minimum.mlir | 3 +-
 .../binary/remainder/simple_remainder.mlir | 3 +-
 .../TTNN/eltwise/operand_broadcasts.mlir | 5 +-
 .../eltwise/operand_broadcasts_negative.mlir | 3 +-
 .../TTNN/eltwise/unary/abs/simple_abs.mlir | 3 +-
 .../TTNN/eltwise/unary/cast/simple_cast.mlir | 9 +-
 .../TTNN/eltwise/unary/cbrt/simple_cbrt.mlir | 3 +-
 .../TTNN/eltwise/unary/ceil/simple_ceil.mlir | 3 +-
 .../TTNN/eltwise/unary/cos/simple_cos.mlir | 3 +-
 .../eltwise/unary/expm1/simple_expm1.mlir | 3 +-
 .../eltwise/unary/floor/simple_floor.mlir | 3 +-
 .../TTNN/eltwise/unary/gelu/simple_gelu.mlir | 3 +-
 .../unary/isfinite/simple_isfinite.mlir | 3 +-
 .../unary/leaky_relu/simple_leaky_relu.mlir | 3 +-
 .../eltwise/unary/log1p/simple_log1p.mlir | 3 +-
 .../eltwise/unary/logical_not/simple_not.mlir | 4 +-
 .../TTNN/eltwise/unary/negate/simple_neg.mlir | 3 +-
 .../unary/reciprocal/simple_reciprocal.mlir | 3 +-
 .../TTNN/eltwise/unary/relu/simple_relu.mlir | 3 +-
 .../eltwise/unary/rsqrt/simple_rsqrt.mlir | 3 +-
 .../eltwise/unary/sigmoid/simple_sigmoid.mlir | 3 +-
 .../TTNN/eltwise/unary/sign/simple_sign.mlir | 3 +-
 .../TTNN/eltwise/unary/sin/simple_sin.mlir | 3 +-
 .../TTNN/eltwise/unary/sqrt/simple_sqrt.mlir | 3 +-
 .../TTNN/eltwise/unary/tan/simple_tan.mlir | 3 +-
 .../TTNN/eltwise/unary/tanh/simple_tanh.mlir | 3 +-
 .../TTNN/embedding/embedding_1d_tensor.mlir | 3 +-
 .../TTNN/embedding/embedding_non_tile.mlir | 3 +-
 .../TTNN/embedding/gather_to_embedding.mlir | 10 +-
 .../gather_to_embedding_negative.mlir | 24 +--
 .../TTNN/embedding/simple_embedding.mlir | 3 +-
 .../TTNN/linear/linear_tests_positive.mlir | 33 ++--
 .../Dialect/TTNN/linear/simple_linear.mlir | 5 +-
 .../TTNN/matmul/matmul_tests_negative.mlir | 16 +-
 .../TTNN/matmul/matmul_tests_positive.mlir | 25 ++-
 .../Dialect/TTNN/matmul/simple_matmul.mlir | 3 +-
 test/ttmlir/Dialect/TTNN/multiple_func.mlir | 3 +-
 .../optimizer/input_layout_loc_override.mlir | 3 +-
 .../insert_memreconfig_override.mlir | 7 +-
 .../all_l1_interleaved_policy.mlir | 13 +-
 .../l1_interleaved_policy/fork_join.mlir | 11 +-
 .../mnist_l1_interleaved.mlir | 13 +-
 .../simple_join_tests/dram_ABC_l1_None.mlir | 7 +-
 .../simple_join_tests/dram_AB_l1_C.mlir | 7 +-
 .../simple_join_tests/dram_AC_l1_B.mlir | 7 +-
 .../simple_join_tests/dram_A_l1_BC.mlir | 7 +-
 .../simple_join_tests/dram_BC_l1_A.mlir | 7 +-
 .../simple_join_tests/dram_B_l1_AC.mlir | 7 +-
 .../simple_join_tests/dram_C_l1_AB.mlir | 7 +-
 .../simple_join_tests/dram_None_l1_ABC.mlir | 7 +-
 .../l1_interleaved_policy/single_op.mlir | 3 +-
 .../TTNN/optimizer/multiple_add_with_loc.mlir | 7 +-
 .../optimizer/output_layout_override.mlir | 7 +-
 .../partial_output_layout_override.mlir | 5 +-
 .../optimizer/sharding_matmul_override_0.mlir | 5 +-
 .../sharding_matmul_override_32.mlir | 5 +-
 .../TTNN/optimizer/ttir_to_ttnn_pipeline.mlir | 3 +-
 .../TTNN/pipelines/ttir_to_emitc_add.mlir | 4 +-
 .../Dialect/TTNN/pooling/complex_pooling.mlir | 6 +-
 .../TTNN/pooling/simple_maxpool2d.mlir | 3 +-
 .../Dialect/TTNN/pooling/simple_pooling.mlir | 4 +-
 test/ttmlir/Dialect/TTNN/remove_empty_op.mlir | 3 +-
 .../TTNN/reshape/reshape_folding_test.mlir | 3 +-
 .../ttmlir/Dialect/TTNN/simple_broadcast.mlir | 5 +-
 test/ttmlir/Dialect/TTNN/simple_clamp.mlir | 3 +-
 test/ttmlir/Dialect/TTNN/simple_compare.mlir | 15 +-
 test/ttmlir/Dialect/TTNN/simple_constant.mlir | 1 -
 test/ttmlir/Dialect/TTNN/simple_div.mlir | 3 +-
 .../TTNN/simple_get_dimension_size.mlir | 1 -
 test/ttmlir/Dialect/TTNN/simple_max.mlir | 3 +-
 test/ttmlir/Dialect/TTNN/simple_maximum.mlir | 3 +-
 test/ttmlir/Dialect/TTNN/simple_mean.mlir | 3 +-
 test/ttmlir/Dialect/TTNN/simple_multiply.mlir | 3 +-
 test/ttmlir/Dialect/TTNN/simple_reshape.mlir | 3 +-
 test/ttmlir/Dialect/TTNN/simple_scatter.mlir | 3 +-
 test/ttmlir/Dialect/TTNN/simple_slice.mlir | 3 +-
 test/ttmlir/Dialect/TTNN/simple_squeeze.mlir | 3 +-
 test/ttmlir/Dialect/TTNN/simple_subtract.mlir | 3 +-
 .../Dialect/TTNN/simple_subtract_to_add.mlir | 3 +-
 test/ttmlir/Dialect/TTNN/simple_sum.mlir | 3 +-
 .../ttmlir/Dialect/TTNN/simple_unsqueeze.mlir | 3 +-
 test/ttmlir/Dialect/TTNN/simple_where.mlir | 5 +-
 .../Dialect/TTNN/softmax/simple_softmax.mlir | 5 +-
 .../TTNN/softmax/softmax_negative_1.mlir | 3 +-
 .../TTNN/softmax/softmax_negative_2.mlir | 3 +-
 .../TTNN/transpose/simple_transpose.mlir | 3 +-
 .../simple_transpose_8x16_reverse_dims.mlir | 3 +-
 .../TTNN/transpose/simple_transpose_8x8.mlir | 3 +-
 .../simple_transpose_negative_dims.mlir | 3 +-
 .../TTNN/transpose/transpose_twice.mlir | 5 +-
 .../ttir_to_ttnn_pipeline_custom_opt.mlir | 5 +-
 .../Silicon/TTMetal/simple_constant.mlir | 4 +-
 .../Silicon/TTMetal/simple_eltwise.mlir | 11 +-
 test/ttmlir/Silicon/TTMetal/simple_max.mlir | 4 +-
 .../ttmlir/Silicon/TTMetal/simple_reduce.mlir | 13 +-
 .../Silicon/TTMetal/simple_reduce_1x1.mlir | 10 +-
 .../arange/simple_device_arange_dim2.mlir | 3 +-
 .../arange/simple_device_arange_dim3.mlir | 3 +-
 test/ttmlir/Silicon/TTNN/ccl/all_gather.mlir | 5 +-
 .../TTNN/complex_conv_channel_first.mlir | 2 -
 test/ttmlir/Silicon/TTNN/deallocate.mlir | 13 +-
 .../TTNN/embedding/embedding_1d_tensor.mlir | 3 +-
 .../TTNN/embedding/embedding_backward.mlir | 4 +-
 .../TTNN/embedding/embedding_non_tile.mlir | 3 +-
 .../TTNN/embedding/gather_to_embedding.mlir | 6 +-
 .../TTNN/embedding/simple_embedding.mlir | 3 +-
 .../ttmlir/Silicon/TTNN/emitc/simple_add.mlir | 4 +-
 test/ttmlir/Silicon/TTNN/emitc/two_fns.mlir | 6 +-
 .../Silicon/TTNN/kv_cache/fill_cache.mlir | 5 +-
 .../Silicon/TTNN/kv_cache/update_cache.mlir | 5 +-
 test/ttmlir/Silicon/TTNN/multi_device.mlir | 5 +-
 .../Silicon/TTNN/operand_broadcasts.mlir | 5 +-
 .../TTNN/optimizer/mnist_sharding.mlir | 13 +-
 .../TTNN/optimizer/simple_fork_join.mlir | 9 +-
 test/ttmlir/Silicon/TTNN/perf_unit/mnist.mlir | 13 +-
 .../Silicon/TTNN/perf_unit/test_perf_and.mlir | 6 +-
 .../TTNN/perf_unit/test_perf_ceil.mlir | 5 +-
 .../TTNN/perf_unit/test_perf_clamp.mlir | 5 +-
 .../TTNN/perf_unit/test_perf_concat.mlir | 5 +-
 .../TTNN/perf_unit/test_perf_conv.mlir | 3 +-
 .../TTNN/perf_unit/test_perf_cosine.mlir | 5 +-
 .../Silicon/TTNN/perf_unit/test_perf_div.mlir | 5 +-
 .../TTNN/perf_unit/test_perf_embedding.mlir | 3 +-
 .../Silicon/TTNN/perf_unit/test_perf_eq.mlir | 6 +-
 .../TTNN/perf_unit/test_perf_expm1.mlir | 5 +-
 .../TTNN/perf_unit/test_perf_floor.mlir | 5 +-
 .../Silicon/TTNN/perf_unit/test_perf_ge.mlir | 5 +-
 .../TTNN/perf_unit/test_perf_gelu.mlir | 4 +-
 .../Silicon/TTNN/perf_unit/test_perf_gt.mlir | 6 +-
 .../TTNN/perf_unit/test_perf_isfinite.mlir | 4 +-
 .../TTNN/perf_unit/test_perf_linear.mlir | 3 +-
 .../Silicon/TTNN/perf_unit/test_perf_log.mlir | 5 +-
 .../TTNN/perf_unit/test_perf_log1p.mlir | 5 +-
 .../Silicon/TTNN/perf_unit/test_perf_lt.mlir | 6 +-
 .../TTNN/perf_unit/test_perf_matmul.mlir | 3 +-
 .../Silicon/TTNN/perf_unit/test_perf_max.mlir | 4 +-
 .../TTNN/perf_unit/test_perf_maximum.mlir | 5 +-
 .../TTNN/perf_unit/test_perf_maxpool2d.mlir | 3 +-
 .../TTNN/perf_unit/test_perf_multiply.mlir | 5 +-
 .../Silicon/TTNN/perf_unit/test_perf_ne.mlir | 6 +-
 .../Silicon/TTNN/perf_unit/test_perf_neg.mlir | 5 +-
 .../Silicon/TTNN/perf_unit/test_perf_not.mlir | 6 +-
 .../Silicon/TTNN/perf_unit/test_perf_or.mlir | 6 +-
 .../TTNN/perf_unit/test_perf_reciprocal.mlir | 5 +-
 .../TTNN/perf_unit/test_perf_relu.mlir | 5 +-
 .../TTNN/perf_unit/test_perf_remainder.mlir | 5 +-
 .../TTNN/perf_unit/test_perf_rsqrt.mlir | 5 +-
 .../TTNN/perf_unit/test_perf_sigmoid.mlir | 5 +-
 .../TTNN/perf_unit/test_perf_sign.mlir | 5 +-
 .../TTNN/perf_unit/test_perf_sine.mlir | 5 +-
 .../TTNN/perf_unit/test_perf_slice.mlir | 3 +-
 .../TTNN/perf_unit/test_perf_softmax.mlir | 7 +-
 .../TTNN/perf_unit/test_perf_sqrt.mlir | 5 +-
 .../TTNN/perf_unit/test_perf_subtract.mlir | 5 +-
 .../Silicon/TTNN/perf_unit/test_perf_sum.mlir | 4 +-
 .../Silicon/TTNN/perf_unit/test_perf_tan.mlir | 4 +-
 .../TTNN/perf_unit/test_perf_tanh.mlir | 4 +-
 .../TTNN/perf_unit/test_perf_transpose.mlir | 5 +-
 .../TTNN/perf_unit/test_perf_typecast.mlir | 3 +-
 .../TTNN/perf_unit/test_perf_where.mlir | 6 +-
 .../Silicon/TTNN/perf_unit/test_perf_xor.mlir | 5 +-
 .../Silicon/TTNN/pooling/complex_pooling.mlir | 4 +-
 .../Silicon/TTNN/pooling/simple_pooling.mlir | 4 +-
 .../TTNN/sharded/simple_eltwise_sharded.mlir | 23 ++-
 test/ttmlir/Silicon/TTNN/simple_compare.mlir | 16 +-
 test/ttmlir/Silicon/TTNN/simple_conv.mlir | 3 +-
 test/ttmlir/Silicon/TTNN/simple_eltwise.mlir | 81 +++++----
 test/ttmlir/Silicon/TTNN/simple_index.mlir | 3 +-
 test/ttmlir/Silicon/TTNN/simple_linear.mlir | 5 +-
 test/ttmlir/Silicon/TTNN/simple_logical.mlir | 12 +-
 test/ttmlir/Silicon/TTNN/simple_matmul.mlir | 3 +-
 .../ttmlir/Silicon/TTNN/simple_maxpool2d.mlir | 3 +-
 test/ttmlir/Silicon/TTNN/simple_mean.mlir | 3 +-
 .../Silicon/TTNN/simple_reductions.mlir | 16 +-
 test/ttmlir/Silicon/TTNN/simple_slice.mlir | 3 +-
 test/ttmlir/Silicon/TTNN/simple_typecast.mlir | 3 +-
 test/ttmlir/Silicon/TTNN/transpose.mlir | 11 +-
 test/ttmlir/Translate/TTNN/1d_tensor.mlir | 4 +-
 .../unittests/TestScheduler/TestScheduler.cpp | 40 ++---
 .../test/models/forward_and_backward.mlir | 22 +--
 .../test/models/linear_autoencoder.mlir | 44 ++---
 .../models/open_llama_3b_single_layer.mlir | 163 +++++++++---------
 245 files changed, 754 insertions(+), 1462 deletions(-)

diff --git a/docs/src/overview.md b/docs/src/overview.md
index 42cbb5086..f2e87fa03 100644
--- a/docs/src/overview.md
+++ b/docs/src/overview.md
@@ -81,11 +81,10 @@ simpler.
 So what does MLIR look like, how does it work and get parsed?
The hierarchy of an MLIR Module is as shown: ``` -#any_device = #tt.operand_constraint module attributes {tt.system_desc = #tt.system_desc<[<#tt.arch, #tt.grid<8x8>>], [0], [], [<0, 0, 0, 0>]>} { func.func @forward(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> { %0 = tensor.empty() : tensor<64x128xf32> - %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } } diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td index 47e770bf7..ff1cc61be 100644 --- a/include/ttmlir/Dialect/TTIR/IR/TTIROps.td +++ b/include/ttmlir/Dialect/TTIR/IR/TTIROps.td @@ -185,8 +185,7 @@ class TTIR_ElementwiseOp traits = []> : }]; let arguments = (ins Variadic:$inputs, - Variadic:$outputs, - TT_OperandConstraintArrayAttr:$operand_constraints); + Variadic:$outputs); let results = (outs Variadic:$results); } @@ -199,9 +198,9 @@ class TTIR_ElementwiseTernaryOp traits = []> : let builders = [ - OpBuilder<(ins "Value": $first, "Value": $second, "Value": $third, "Value": $out, "ArrayAttr": $operand_constraints), + OpBuilder<(ins "Value": $first, "Value": $second, "Value": $third, "Value": $out), [{ - build($_builder, $_state, {out.getType()}, {first, second, third}, out, operand_constraints); + build($_builder, $_state, {out.getType()}, {first, second, third}, out); }]> ]; } @@ -222,9 +221,9 @@ class TTIR_ElementwiseUnaryOp traits = []> : let builders = [ - OpBuilder<(ins "Value": $in, "Value": $out, "ArrayAttr": $operand_constraints), + OpBuilder<(ins "Value": $in, "Value": $out), [{ - build($_builder, $_state, {out.getType()}, in, out, operand_constraints); + build($_builder, $_state, {out.getType()}, in, out); }]> ]; } @@ -408,14 +407,13 @@ class TTIR_ElementwiseUnaryWithFloatParameterOp tra let arguments = (ins Variadic:$inputs, Variadic:$outputs, - F32Attr:$parameter, - TT_OperandConstraintArrayAttr:$operand_constraints); + F32Attr:$parameter); let builders = [ - OpBuilder<(ins "Value": $in, "Value": $out, "FloatAttr":$parameter, "ArrayAttr": $operand_constraints), + OpBuilder<(ins "Value": $in, "Value": $out, "FloatAttr":$parameter), [{ - build($_builder, $_state, {out.getType()}, {in}, {out}, parameter, operand_constraints); + build($_builder, $_state, {out.getType()}, {in}, {out}, parameter); }]> ]; } @@ -452,9 +450,9 @@ class TTIR_ElementwiseBinaryOp traits = []> : let builders = [ - OpBuilder<(ins "Value": $lhs, "Value": $rhs, "Value": $out, "ArrayAttr": $operand_constraints), + OpBuilder<(ins "Value": $lhs, "Value": $rhs, "Value": $out), [{ - build($_builder, $_state, {out.getType()}, {lhs, rhs}, out, operand_constraints); + build($_builder, $_state, {out.getType()}, {lhs, rhs}, out); }]> ]; } @@ -568,8 +566,7 @@ class TTIR_ReductionOp traits = []> : let arguments = (ins AnyRankedTensor:$input, AnyRankedTensor:$output, BoolAttr:$keep_dim, - OptionalAttr:$dim_arg, - TT_OperandConstraintArrayAttr:$operand_constraints); + OptionalAttr:$dim_arg); let results = (outs AnyRankedTensor:$result); @@ -636,8 +633,7 @@ def TTIR_EmbeddingOp : TTIR_DPSOp<"embedding"> { let arguments = (ins AnyRankedTensor:$input, AnyRankedTensor:$weight, - AnyRankedTensor:$output, - 
TT_OperandConstraintArrayAttr:$operand_constraints); + AnyRankedTensor:$output); let results = (outs AnyRankedTensor:$result); @@ -657,8 +653,7 @@ def TTIR_EmbeddingBackwardOp : TTIR_DPSOp<"embedding_backward"> { let arguments = (ins AnyRankedTensor:$input, AnyRankedTensor:$weight, AnyRankedTensor:$in_gradient, - AnyRankedTensor:$output, - TT_OperandConstraintArrayAttr:$operand_constraints); + AnyRankedTensor:$output); let results = (outs AnyRankedTensor:$result); @@ -677,8 +672,7 @@ def TTIR_SoftmaxOp : TTIR_DPSOp<"softmax"> { let arguments = (ins AnyRankedTensor:$input, AnyRankedTensor:$output, - SI32Attr:$dimension, - TT_OperandConstraintArrayAttr:$operand_constraints); + SI32Attr:$dimension); let results = (outs AnyRankedTensor:$result); @@ -698,8 +692,7 @@ def TTIR_TransposeOp : TTIR_DPSOp<"transpose"> { let arguments = (ins AnyRankedTensor:$input, AnyRankedTensor:$output, SI32Attr:$dim0, - SI32Attr:$dim1, - TT_OperandConstraintArrayAttr:$operand_constraints); + SI32Attr:$dim1); let results = (outs AnyRankedTensor:$result); @@ -718,8 +711,7 @@ def TTIR_ConcatOp : TTIR_DPSOp<"concat"> { let arguments = (ins Variadic:$inputs, AnyRankedTensor:$output, - SI32Attr:$dim, - TT_OperandConstraintArrayAttr:$operand_constraints); + SI32Attr:$dim); let results = (outs AnyRankedTensor:$result); @@ -739,8 +731,7 @@ def TTIR_UpdateCacheOp : TTIR_DPSOp<"update_cache"> { let arguments = (ins AnyRankedTensor:$cache, AnyRankedTensor:$input, AnyRankedTensor:$update_index, - I32Attr:$batch_offset, - TT_OperandConstraintArrayAttr:$operand_constraints); + I32Attr:$batch_offset); let results = (outs AnyRankedTensor:$result); @@ -759,8 +750,7 @@ def TTIR_FillCacheOp : TTIR_DPSOp<"fill_cache"> { let arguments = (ins AnyRankedTensor:$cache, AnyRankedTensor:$input, - I32Attr:$batch_offset, - TT_OperandConstraintArrayAttr:$operand_constraints); + I32Attr:$batch_offset); let results = (outs AnyRankedTensor:$result); @@ -779,8 +769,7 @@ def TTIR_BroadcastOp : TTIR_DPSOp<"broadcast"> { let arguments = (ins AnyRankedTensor:$input, AnyRankedTensor:$output, - I64ArrayAttr:$dimension, - TT_OperandConstraintArrayAttr:$operand_constraints); + I64ArrayAttr:$dimension); let results = (outs AnyRankedTensor:$result); @@ -807,8 +796,7 @@ def TTIR_Conv2dOp : TTIR_DPSOp<"conv2d"> { SI32Attr:$padding_left, SI32Attr:$padding_right, SI32Attr:$padding_top, - SI32Attr:$padding_bottom, - TT_OperandConstraintArrayAttr:$operand_constraints); + SI32Attr:$padding_bottom); let results = (outs AnyRankedTensor:$result); @@ -841,8 +829,7 @@ def TTIR_ConvolutionOp : TTIR_DPSOp<"convolution"> { DenseBoolArrayAttr:$window_reversal, TTIR_ConvolutionLayoutAttr:$convolution_layout, ConfinedAttr:$feature_group_count, - ConfinedAttr:$batch_group_count, - TT_OperandConstraintArrayAttr:$operand_constraints + ConfinedAttr:$batch_group_count ); let results = (outs AnyRankedTensor); @@ -869,8 +856,7 @@ def TTIR_GatherOp: TTIR_DPSOp<"gather"> { DenseI64ArrayAttr:$start_index_map, SI64Attr:$index_vector_dim, DenseI64ArrayAttr:$slice_sizes, - BoolAttr:$indices_are_sorted, - TT_OperandConstraintArrayAttr:$operand_constraints); + BoolAttr:$indices_are_sorted); let results = (outs AnyRankedTensor:$result); let extraClassDeclaration = [{ MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); } @@ -891,8 +877,7 @@ def TTIR_PoolingOp : TTIR_DPSOp<"pooling", [AttrSizedOperandSegments]> { DenseI64ArrayAttr:$window_strides, DenseI64ArrayAttr:$base_dilations, DenseI64ArrayAttr:$window_dilations, - DenseI64ArrayAttr:$padding, - 
TT_OperandConstraintArrayAttr:$operand_constraints + DenseI64ArrayAttr:$padding ); let results = (outs Variadic); @@ -918,8 +903,7 @@ def TTIR_MaxPool2dOp : TTIR_DPSOp<"max_pool2d"> { SI32Attr:$padding_left, SI32Attr:$padding_right, SI32Attr:$padding_top, - SI32Attr:$padding_bottom, - TT_OperandConstraintArrayAttr:$operand_constraints); + SI32Attr:$padding_bottom); let results = (outs AnyRankedTensor:$result); @@ -938,8 +922,7 @@ def TTIR_ReshapeOp: TTIR_DPSOp<"reshape"> { let arguments = (ins AnyRankedTensor:$input, AnyRankedTensor:$output, - I32ArrayAttr:$shape, - TT_OperandConstraintArrayAttr:$operand_constraints); + I32ArrayAttr:$shape); let results = (outs AnyRankedTensor:$result); @@ -964,8 +947,7 @@ def TTIR_SliceOp: TTIR_DPSOp<"slice"> { AnyRankedTensor:$output, I32ArrayAttr:$begins, I32ArrayAttr:$ends, - I32ArrayAttr:$step, - TT_OperandConstraintArrayAttr:$operand_constraints); + I32ArrayAttr:$step); let results = (outs AnyRankedTensor:$result); @@ -991,8 +973,7 @@ def TTIR_SelectOp: TTIR_DPSOp<"select"> { SI32Attr:$dim, SI32Attr:$begin, SI32Attr:$length, - DefaultValuedOptionalAttr:$stride, - TT_OperandConstraintArrayAttr:$operand_constraints); + DefaultValuedOptionalAttr:$stride); let results = (outs AnyRankedTensor:$result); @@ -1017,8 +998,7 @@ def TTIR_IndexOp: TTIR_DPSOp<"index"> { I32Attr:$dim, I32Attr:$begin, I32Attr:$end, - I32Attr:$step, - TT_OperandConstraintArrayAttr:$operand_constraints); + I32Attr:$step); let results = (outs AnyRankedTensor:$result); @@ -1038,8 +1018,7 @@ def TTIR_SqueezeOp : TTIR_DPSOp<"squeeze"> { let arguments = (ins AnyRankedTensor:$input, AnyRankedTensor:$output, - SI32Attr:$dim, - TT_OperandConstraintArrayAttr:$operand_constraints); + SI32Attr:$dim); let results = (outs AnyRankedTensor:$result); @@ -1058,8 +1037,7 @@ def TTIR_UnsqueezeOp : TTIR_DPSOp<"unsqueeze"> { let arguments = (ins AnyRankedTensor:$input, AnyRankedTensor:$output, - SI32Attr:$dim, - TT_OperandConstraintArrayAttr:$operand_constraints); + SI32Attr:$dim); let results = (outs AnyRankedTensor:$result); @@ -1087,8 +1065,7 @@ def TTIR_ClampOp : TTIR_DPSOp<"clamp"> { let arguments = (ins AnyRankedTensor:$input, AnyRankedTensor:$output, F32Attr:$min, - F32Attr:$max, - TT_OperandConstraintArrayAttr:$operand_constraints); + F32Attr:$max); let extraClassDeclaration = [{ MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); } @@ -1191,8 +1168,7 @@ def TTIR_FillOp : TTIR_DPSOp<"fill", [AllShapesMatch<["value", "result"]>]> { }]; let arguments = (ins AnyRankedTensor:$output, - ElementsAttr:$value, - TT_OperandConstraintArrayAttr:$operand_constraints); + ElementsAttr:$value); let results = (outs AnyRankedTensor:$result); @@ -1217,8 +1193,7 @@ def TTIR_LinearOp : TTIR_DPSOp<"linear"> { let arguments = (ins AnyRankedTensor:$a, AnyRankedTensor:$b, Optional:$bias, - AnyRankedTensor:$output, - TT_OperandConstraintArrayAttr:$operand_constraints); + AnyRankedTensor:$output); let results = (outs AnyRankedTensor:$result); @@ -1238,8 +1213,7 @@ def TTIR_MatmulOp : TTIR_DPSOp<"matmul"> { let arguments = (ins AnyRankedTensor:$a, AnyRankedTensor:$b, - AnyRankedTensor:$output, - TT_OperandConstraintArrayAttr:$operand_constraints); + AnyRankedTensor:$output); let results = (outs AnyRankedTensor:$result); @@ -1362,8 +1336,7 @@ def TTIR_ScatterOp: TTIR_DPSOp<"scatter"> { I32Attr:$index_vector_dim, BoolAttr:$indices_are_sorted, BoolAttr:$unique_indices, - AnyRankedTensor:$output, - TT_OperandConstraintArrayAttr:$operand_constraints); + AnyRankedTensor:$output); let regions = (region 
SizedRegion<1>:$update_computation); @@ -1391,8 +1364,7 @@ def TTIR_KernelOp : TTIR_DPSOp<"kernel", [AttrSizedOperandSegments]> { let arguments = (ins FlatSymbolRefAttr:$op, FlatSymbolRefAttr:$kind, Variadic:$inputs, - Variadic:$outputs, - TT_OperandConstraintArrayAttr:$operand_constraints); + Variadic:$outputs); let results = (outs Variadic:$results); } @@ -1417,8 +1389,7 @@ def TTIR_AllGatherOp : TTIR_DPSOp<"all_gather"> { let arguments = (ins AnyRankedTensor:$input, AnyRankedTensor:$output, - SI32Attr:$dim, - TT_OperandConstraintArrayAttr:$operand_constraints); + SI32Attr:$dim); let results = (outs AnyRankedTensor:$result); @@ -1442,8 +1413,7 @@ def TTIR_AllReduceOp : TTIR_DPSOp<"all_reduce"> { SI32Attr:$dim, OptionalAttr:$channel_handle, UnitAttr:$use_global_device_ids, - TT_ReduceTypeAttr:$reduce_type, - TT_OperandConstraintArrayAttr:$operand_constraints + TT_ReduceTypeAttr:$reduce_type ); let results = (outs Variadic:$results); @@ -1490,8 +1460,7 @@ def TTIR_MeshShardOp : TTIR_DPSOp<"mesh_shard"> { AnyRankedTensor:$output, TT_MeshShardTypeAttr:$shard_type, TT_MeshShardDirectionAttr:$shard_direction, - TT_GridAttr:$shard_shape, - TT_OperandConstraintArrayAttr:$operand_constraints + TT_GridAttr:$shard_shape ); let results = (outs AnyRankedTensor:$result); diff --git a/include/ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.td b/include/ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.td index a130332f0..64c314279 100644 --- a/include/ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.td +++ b/include/ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.td @@ -11,16 +11,6 @@ include "ttmlir/Dialect/TT/IR/TTOpsTypes.td" def TTIROpInterface : OpInterface<"TTIROp"> { let cppNamespace = "::mlir::tt::ttir"; let methods = [ - InterfaceMethod< - /*desc=*/[{ - Return the constraints on the operands of this operation. - }], - /*retTy=*/"::mlir::ArrayAttr", - /*methodName=*/"getOperandConstraints", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/"" - >, InterfaceMethod< /*desc=*/[{ Get the device of the current scope. diff --git a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp index 96ef7ca01..4f2f82361 100644 --- a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp +++ b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp @@ -53,11 +53,7 @@ class StableHLOToTTIROpDefaultConversionPattern srcOp, TypeRange( this->getTypeConverter()->convertType(outputTensor.getType())), - adaptor.getOperands(), ValueRange(outputTensor), - rewriter.getArrayAttr( - SmallVector(adaptor.getOperands().size() + 1, - rewriter.getAttr( - OperandConstraint::AnyDeviceTile)))); + adaptor.getOperands(), ValueRange(outputTensor)); return success(); } }; @@ -125,18 +121,9 @@ class StableHLOToTTIRReduceOpConversionPattern ? adaptor.getDimensionsAttr()[0] : 1))); - // If someone changes definition of TTIR_ReductionOp this constant will - // become outdated, but I currently see no way to get this info (without - // manually constructing the adaptor for dest OP). 
- const std::size_t ttirReduceOpOperandsCount = 2; - mlir::ArrayAttr operandConstraints = - rewriter.getArrayAttr(SmallVector( - ttirReduceOpOperandsCount, rewriter.getAttr( - OperandConstraint::AnyDeviceTile))); - rewriter.replaceOpWithNewOp( srcOp, outputType, adaptor.getInputs().front(), outputTensor, - false /* keep_dim */, dimArg, operandConstraints); + false /* keep_dim */, dimArg); return success(); } @@ -171,11 +158,7 @@ class StableHLOToTTIRTransposeOpConversionPattern input = rewriter.create( srcOp.getLoc(), outputType, input, outputTensor, - rewriter.getSI32IntegerAttr(dim0), rewriter.getSI32IntegerAttr(dim1), - rewriter.getArrayAttr( - SmallVector(adaptor.getOperands().size() + 1, - rewriter.getAttr( - OperandConstraint::AnyDeviceTile)))); + rewriter.getSI32IntegerAttr(dim0), rewriter.getSI32IntegerAttr(dim1)); } rewriter.replaceOp(srcOp, input); return success(); @@ -218,11 +201,7 @@ class StableHLOToTTIRReshapeOpConversionPattern ArrayAttr new_shape_attr = rewriter.getI32ArrayAttr(new_shape_i32); rewriter.replaceOpWithNewOp( srcOp, getTypeConverter()->convertType(outputTensor.getType()), - adaptor.getOperand(), outputTensor, new_shape_attr, - rewriter.getArrayAttr( - SmallVector(adaptor.getOperands().size() + 1, - rewriter.getAttr( - OperandConstraint::AnyDeviceTile)))); + adaptor.getOperand(), outputTensor, new_shape_attr); return success(); } @@ -265,11 +244,7 @@ class StableHLOToTTIRDotGeneralOpConversionPattern rewriter.replaceOpWithNewOp( srcOp, getTypeConverter()->convertType(outputTensor.getType()), - adaptor.getLhs(), adaptor.getRhs(), Value(outputTensor), - rewriter.getArrayAttr( - SmallVector(adaptor.getOperands().size() + 1, - rewriter.getAttr( - OperandConstraint::AnyDeviceTile)))); + adaptor.getLhs(), adaptor.getRhs(), Value(outputTensor)); return success(); } @@ -583,11 +558,7 @@ class StableHLOToTTIRConvolutionOpConversionPattern dimNums.getOutputBatchDimension(), dimNums.getOutputFeatureDimension(), dimNums.getOutputSpatialDimensions()), - adaptor.getFeatureGroupCountAttr(), adaptor.getBatchGroupCountAttr(), - rewriter.getArrayAttr( - SmallVector(adaptor.getOperands().size() + 1, - rewriter.getAttr( - OperandConstraint::AnyDeviceTile)))); + adaptor.getFeatureGroupCountAttr(), adaptor.getBatchGroupCountAttr()); return success(); } @@ -683,10 +654,6 @@ class StableHLOToTTIRReduceWindowOpConversionPattern : rewriter.getDenseI64ArrayAttr( SmallVector(windowDimensions.size() * 2, 0)); - auto operandConstraints = rewriter.getArrayAttr(SmallVector( - adaptor.getOperands().size(), rewriter.getAttr( - OperandConstraint::AnyDeviceTile))); - mlir::tt::ttir::PoolingMethod poolingMethod; if (isMaxPool(srcOp)) { poolingMethod = mlir::tt::ttir::PoolingMethod::Max; @@ -701,7 +668,7 @@ class StableHLOToTTIRReduceWindowOpConversionPattern rewriter.replaceOpWithNewOp( srcOp, outputType, adaptor.getInputs(), outputs, poolingMethod, windowDimensions, windowStrides, baseDilations, window_dilations, - padding, operandConstraints); + padding); return success(); } @@ -836,11 +803,7 @@ class StableHLOToTTIRBroadcastInDimOpConversionPattern rewriter.replaceOpWithNewOp( srcOp, getTypeConverter()->convertType(outputTensor.getType()), - Value(adaptor.getOperand()), Value(outputTensor), dimArg, - rewriter.getArrayAttr( - SmallVector(adaptor.getOperands().size() + 1, - rewriter.getAttr( - OperandConstraint::AnyDeviceTile)))); + Value(adaptor.getOperand()), Value(outputTensor), dimArg); return success(); } @@ -932,11 +895,7 @@ class StableHLOToTTIRCompareOpConversionPattern srcOp, TypeRange( 
this->getTypeConverter()->convertType(outputTensor.getType())), - adaptor.getOperands(), ValueRange(outputTensor), - rewriter.getArrayAttr( - SmallVector(adaptor.getOperands().size() + 1, - rewriter.getAttr( - OperandConstraint::AnyDeviceTile)))); + adaptor.getOperands(), ValueRange(outputTensor)); return success(); } @@ -975,11 +934,7 @@ class StableHLOToTTIRConcatOpConversionPattern adaptor.getInputs(), // input values Value(outputTensor), // output value rewriter.getSI32IntegerAttr( - static_cast(adaptor.getDimension())), // dimension - rewriter.getArrayAttr( // operand constraints - SmallVector(adaptor.getOperands().size() + 1, - rewriter.getAttr( - OperandConstraint::AnyDeviceTile)))); + static_cast(adaptor.getDimension()))); // dimension return success(); } @@ -1035,11 +990,7 @@ class StableHLOToTTIROpLogicalOpConversionPattern srcOp, TypeRange( this->getTypeConverter()->convertType(outputTensor.getType())), - adaptor.getOperands(), ValueRange(outputTensor), - rewriter.getArrayAttr( - SmallVector(adaptor.getOperands().size() + 1, - rewriter.getAttr( - OperandConstraint::AnyDeviceTile)))); + adaptor.getOperands(), ValueRange(outputTensor)); return success(); } @@ -1197,14 +1148,6 @@ class StableHLOToTTIRAllReduceOpConversionPattern Attribute reduceTypeAttr = rewriter.getAttr(reduceType); ttirAttrs.push_back({reduceTypeAttrName, reduceTypeAttr}); - StringAttr operationConstraintAttrName = - StringAttr::get(this->getContext(), "operand_constraints"); - Attribute operationConstraintAttr = rewriter.getArrayAttr( - SmallVector(adaptor.getOperands().size() + 1, - rewriter.getAttr( - OperandConstraint::AnyDeviceTile))); - ttirAttrs.push_back({operationConstraintAttrName, operationConstraintAttr}); - auto ttirAllReduceOp = rewriter.create( srcOp.getLoc(), ttirTypes, ValueRange(ttirOperands.getAsOperandRange()), ttirAttrs); @@ -1463,15 +1406,6 @@ class StableHLOToTTIRCustomCallOpConversionPattern GridAttr::get(this->getContext(), shardAttrValue.shardShape); meshShardAttrs.push_back({shardShapeAttrName, shardShape}); - StringAttr operationConstraintAttrName = - StringAttr::get(this->getContext(), "operand_constraints"); - Attribute operationConstraintAttr = rewriter.getArrayAttr( - SmallVector(adaptor.getOperands().size() + 1, - rewriter.getAttr( - OperandConstraint::SystemScalar))); - meshShardAttrs.push_back( - {operationConstraintAttrName, operationConstraintAttr}); - auto meshShardOp = rewriter.create( srcOp.getLoc(), outputTypes, ValueRange(meshShardOperands.getAsOperandRange()), meshShardAttrs); @@ -1527,11 +1461,7 @@ class StableHLOToTTIRSliceOpConversionPattern adaptor.getOperand(), // input values outputTensor, // output value rewriter.getI32ArrayAttr(start_indices), - rewriter.getI32ArrayAttr(end_indices), rewriter.getI32ArrayAttr(step), - rewriter.getArrayAttr( // operand constraints - SmallVector(adaptor.getOperands().size() + 1, - rewriter.getAttr( - OperandConstraint::AnyDeviceTile)))); + rewriter.getI32ArrayAttr(end_indices), rewriter.getI32ArrayAttr(step)); return success(); } }; @@ -1575,31 +1505,19 @@ class StableHLOToTTIROpClampOpConversionPattern this->getTypeConverter()->convertType(outputTensor.getType()), Value(adaptor.getOperand()), Value(outputTensor), rewriter.getF32FloatAttr(minValue), - rewriter.getF32FloatAttr(maxValue), - rewriter.getArrayAttr( - SmallVector(adaptor.getOperands().size() + 1, - rewriter.getAttr( - OperandConstraint::AnyDeviceTile)))); + rewriter.getF32FloatAttr(maxValue)); return success(); } } ttir::MaximumOp maximumOp = rewriter.create( - 
srcOp->getLoc(), min, adaptor.getOperand(), outputTensor, - rewriter.getArrayAttr( - SmallVector(adaptor.getOperands().size() + 1, - rewriter.getAttr( - OperandConstraint::AnyDeviceTile)))); + srcOp->getLoc(), min, adaptor.getOperand(), outputTensor); tensor::EmptyOp finalOutputTensor = rewriter.create( srcOp.getLoc(), outputType.getShape(), outputType.getElementType()); rewriter.replaceOpWithNewOp( - srcOp, maximumOp->getResult(0), max, finalOutputTensor, - rewriter.getArrayAttr( - SmallVector(adaptor.getOperands().size() + 1, - rewriter.getAttr( - OperandConstraint::AnyDeviceTile)))); + srcOp, maximumOp->getResult(0), max, finalOutputTensor); return success(); } }; @@ -1629,11 +1547,7 @@ class StableHLOToTTIRGatherOpConversionPattern dimensionNumbers.getOperandBatchingDims(), dimensionNumbers.getStartIndicesBatchingDims(), dimensionNumbers.getStartIndexMap(), - dimensionNumbers.getIndexVectorDim(), srcOp.getSliceSizesAttr(), false, - rewriter.getArrayAttr( - SmallVector(adaptor.getOperands().size() + 1, - rewriter.getAttr( - OperandConstraint::AnyDeviceTile)))); + dimensionNumbers.getIndexVectorDim(), srcOp.getSliceSizesAttr(), false); return success(); } }; @@ -1686,9 +1600,6 @@ class StableHLOToTTIRScatterOpConversionPattern Value operand = srcOp.getInputs()[0]; Value scatterIndices = srcOp.getScatterIndices(); Value update = srcOp.getUpdates()[0]; - mlir::ArrayAttr binaryConstraints = rewriter.getArrayAttr( - SmallVector(4, rewriter.getAttr( - OperandConstraint::AnyDeviceTile))); auto updateWindowsDims = adaptor.getScatterDimensionNumbers().getUpdateWindowDims(); auto insertedWindowDims = @@ -1716,8 +1627,7 @@ class StableHLOToTTIRScatterOpConversionPattern convertArrayRefToInt32vector(scatterIndicesBatchingDims)), llvm::ArrayRef( convertArrayRefToInt32vector(scatterDimsToOperandDims)), - indexVectorDim, indicesAreSorted, uniqueIndices, outputTensor, - binaryConstraints); + indexVectorDim, indicesAreSorted, uniqueIndices, outputTensor); // Replaces with different types do not work and will fail silently, so we // manually set the second operand, since the type changes there from i32 to diff --git a/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp b/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp index 795ef6688..668801e0b 100644 --- a/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp +++ b/lib/Conversion/TTIRToTTIRDecomposition/TTIRToTTIRDecomposition.cpp @@ -65,7 +65,7 @@ struct IndexToSliceConversionPattern auto newOp = rewriter.create( op.getLoc(), op.getType(), adaptor.getInput(), adaptor.getOutput(), rewriter.getArrayAttr(begins), rewriter.getArrayAttr(ends), - rewriter.getArrayAttr(steps), adaptor.getOperandConstraints()); + rewriter.getArrayAttr(steps)); rewriter.replaceOp(op, newOp.getResult()); return success(); @@ -145,8 +145,7 @@ generateTransposeIndices(std::vector currentLayout, * result at the end of the sequence */ static Value generateTransposeOps(Value input, PatternRewriter &rewriter, - std::vector transposeIndices, - ::mlir::ArrayAttr operandConstraints) { + std::vector transposeIndices) { for (auto [dim0, dim1] : transposeIndices) { auto inputType = mlir::cast(input.getType()); @@ -163,8 +162,7 @@ static Value generateTransposeOps(Value input, PatternRewriter &rewriter, input.getLoc(), outputShape, outputType.getElementType()); input = rewriter .create(input.getLoc(), outputType, input, - dpsOutput, dim0Attr, dim1Attr, - operandConstraints) + dpsOutput, dim0Attr, dim1Attr) .getResult(); } @@ -317,12 +315,10 @@ 
struct Legalize1DConvolutionPattern : public ConvolutionDecompositionPattern { weightShape.end()); reshapeWeightShape.push_back(1); - ttir::ReshapeOp reshapeInput = - createReshapeOp(op.getLoc(), adaptor.getInput(), reshapeInputShape, - op.getOperandConstraints(), rewriter); - ttir::ReshapeOp reshapeWeight = - createReshapeOp(op.getLoc(), adaptor.getWeight(), reshapeWeightShape, - op.getOperandConstraints(), rewriter); + ttir::ReshapeOp reshapeInput = createReshapeOp( + op.getLoc(), adaptor.getInput(), reshapeInputShape, rewriter); + ttir::ReshapeOp reshapeWeight = createReshapeOp( + op.getLoc(), adaptor.getWeight(), reshapeWeightShape, rewriter); mlir::DenseI64ArrayAttr conv2dOpWindowsStridesAttr = addIntegerToDenseArrayAttr(rewriter, adaptor.getWindowStridesAttr(), 1); @@ -375,14 +371,9 @@ struct Legalize1DConvolutionPattern : public ConvolutionDecompositionPattern { convolutionLayout.getOutputFeatureDimension(), conv2dOutputSpatialDimensions), adaptor.getFeatureGroupCountAttr(), - adaptor.getBatchGroupCountAttr(), - rewriter.getArrayAttr( - SmallVector(adaptor.getOperands().size() + 1, - rewriter.getAttr( - OperandConstraint::AnyDeviceTile)))); + adaptor.getBatchGroupCountAttr()); ttir::ReshapeOp reshapeOutput = - createReshapeOp(op.getLoc(), new2dConvolutionOp, outputShape, - op.getOperandConstraints(), rewriter); + createReshapeOp(op.getLoc(), new2dConvolutionOp, outputShape, rewriter); rewriter.replaceOp(op, reshapeOutput); @@ -392,7 +383,6 @@ struct Legalize1DConvolutionPattern : public ConvolutionDecompositionPattern { private: ttir::ReshapeOp createReshapeOp(Location loc, Value tensor, llvm::ArrayRef target_input_shape, - ::mlir::ArrayAttr constraints, ConversionPatternRewriter &rewriter) const { auto inputType = mlir::cast(tensor.getType()); @@ -407,7 +397,7 @@ struct Legalize1DConvolutionPattern : public ConvolutionDecompositionPattern { loc, mlir::RankedTensorType::get(target_input_shape, inputType.getElementType()), - tensor, DPSReshapeOutput, shape_attr, constraints); + tensor, DPSReshapeOutput, shape_attr); } mlir::DenseI64ArrayAttr @@ -503,26 +493,23 @@ struct ConvolutionToConv2dPattern : public ConvolutionDecompositionPattern { auto transposeIndices = generateConvTransposeIndices(op, conv2dLayout); Value input = - generateTransposeOps(adaptor.getInput(), rewriter, transposeIndices, - adaptor.getOperandConstraints()); + generateTransposeOps(adaptor.getInput(), rewriter, transposeIndices); auto kernelTransposeIndices = generateConvKernelTransposeIndices(op, conv2dKernelLayout); Value weight = generateTransposeOps(adaptor.getWeight(), rewriter, - kernelTransposeIndices, - adaptor.getOperandConstraints()); + kernelTransposeIndices); ttir::Conv2dOp newConv = rewriter.create( op.getLoc(), outputType, input, weight, adaptor.getBias(), convDPSOutput, strideHeightAttr, strideWidthAttr, dilationHeightAttr, dilationWidthAttr, groupsAttr, paddingLeftAttr, paddingRightAttr, - paddingTopAttr, paddingBottomAttr, adaptor.getOperandConstraints()); + paddingTopAttr, paddingBottomAttr); // Applying the transposes in reverse order to the output will restore the // tensor to the original layout std::reverse(transposeIndices.begin(), transposeIndices.end()); Value output = - generateTransposeOps(newConv.getResult(), rewriter, transposeIndices, - adaptor.getOperandConstraints()); + generateTransposeOps(newConv.getResult(), rewriter, transposeIndices); rewriter.replaceOp(op, output); return success(); @@ -619,8 +606,7 @@ struct GatherToEmbeddingConversionPattern ttir::ReshapeOp 
createReshapeOp(PatternRewriter &rewriter, Location loc, Value input, - ::llvm::ArrayRef shapei64, - ::mlir::ArrayAttr operandConstraints) const { + ::llvm::ArrayRef shapei64) const { // reshape start indices (input) to remove the last dimension auto ty = mlir::cast(input.getType()); @@ -631,7 +617,7 @@ struct GatherToEmbeddingConversionPattern return rewriter.create( loc, mlir::RankedTensorType::get(shapei64, ty.getElementType()), input, - output, shape_attr, operandConstraints); + output, shape_attr); } /** @@ -681,8 +667,7 @@ struct GatherToEmbeddingConversionPattern startIndicesType.getShape().end() - 1); ttir::ReshapeOp reshapeOp = - createReshapeOp(rewriter, op.getLoc(), startIndices, newShapeI64, - op.getOperandConstraints()); + createReshapeOp(rewriter, op.getLoc(), startIndices, newShapeI64); assert(reshapeOp && "Failed to create reshape op"); reshapeOp->moveBefore(op); @@ -692,10 +677,7 @@ struct GatherToEmbeddingConversionPattern // convert gather to embedding, use reshaped input if needed ttir::EmbeddingOp embeddingOp = rewriter.create( op.getLoc(), op.getResult().getType(), input, op.getOperands()[0], - op.getOutput(), - rewriter.getArrayAttr(SmallVector( - op.getNumOperands() + 1, rewriter.getAttr( - OperandConstraint::AnyDeviceTile)))); + op.getOutput()); assert(embeddingOp != nullptr && "Failed to create embedding op"); rewriter.replaceOp(op, embeddingOp); @@ -841,12 +823,10 @@ struct PoolingToPool2dPattern : public OpConversionPattern { rewriter.getSI32IntegerAttr(op.getPadding()[2 * spatialDims[1]]); auto paddingRightAttr = rewriter.getSI32IntegerAttr(op.getPadding()[2 * spatialDims[1] + 1]); - auto operandConstraints = adaptor.getOperandConstraints(); std::vector outputs; for (Value input : adaptor.getInputs()) { - input = generateTransposeOps(input, rewriter, transposeIndices, - operandConstraints); + input = generateTransposeOps(input, rewriter, transposeIndices); auto outputType = mlir::cast(op.getResult(0).getType()); auto newOutputShape = outputType.getShape().vec(); @@ -864,14 +844,13 @@ struct PoolingToPool2dPattern : public OpConversionPattern { op.getLoc(), newOutputType, input, outputTensor, kernelHeightAttr, kernelWidthAttr, strideHeightAttr, strideWidthAttr, dilationHeightAttr, dilationWidthAttr, ceilModeAttr, paddingTopAttr, - paddingBottomAttr, paddingLeftAttr, paddingRightAttr, - operandConstraints); + paddingBottomAttr, paddingLeftAttr, paddingRightAttr); // Applying the transposes in reverse order to the output will restore the // tensor to the original layout std::reverse(transposeIndices.begin(), transposeIndices.end()); - Value output = generateTransposeOps(newPool.getResult(), rewriter, - transposeIndices, operandConstraints); + Value output = + generateTransposeOps(newPool.getResult(), rewriter, transposeIndices); // Reverse back so the proper input transposes are generated for the next // pool @@ -1043,7 +1022,7 @@ struct SelectToSliceConversionPattern auto newOp = rewriter.create( op.getLoc(), resultType, adaptor.getInput(), sliceDpsResult, rewriter.getI32ArrayAttr(begins), rewriter.getI32ArrayAttr(ends), - rewriter.getI32ArrayAttr(steps), adaptor.getOperandConstraints()); + rewriter.getI32ArrayAttr(steps)); slices.push_back(newOp->getResult(0)); } @@ -1054,12 +1033,7 @@ struct SelectToSliceConversionPattern auto concatOp = rewriter.create( op.getLoc(), outputType, slices, concatDpsResult, - rewriter.getSI32IntegerAttr(dim), - // Generate an array of AnyDeviceTile constraints for the output and - // all the slices. 
- rewriter.getArrayAttr(SmallVector( - slices.size() + 1, rewriter.getAttr( - OperandConstraint::AnyDeviceTile)))); + rewriter.getSI32IntegerAttr(dim)); rewriter.replaceOp(op, concatOp.getResult()); } else { @@ -1140,10 +1114,7 @@ struct ArangeForceLastDimensionPattern output = rewriter.create( op.getLoc(), transposeType, output, dpsOutput, arangeDimensionNegative + transposeShape.size(), - arangeOutputType.getRank() - 1, - rewriter.getArrayAttr(SmallVector( - 2, rewriter.getAttr( - OperandConstraint::AnyDeviceTile)))); + arangeOutputType.getRank() - 1); outputShape = transposeShape; } @@ -1167,10 +1138,7 @@ struct ArangeForceLastDimensionPattern reshapeType.getElementType()); output = rewriter.create( op.getLoc(), reshapeType, output, dpsOutput, - rewriter.getI32ArrayAttr(reshapeShape), - rewriter.getArrayAttr(SmallVector( - 2, rewriter.getAttr( - OperandConstraint::AnyDeviceTile)))); + rewriter.getI32ArrayAttr(reshapeShape)); outputShape = std::vector(reshapeShape.begin(), reshapeShape.end()); @@ -1193,10 +1161,7 @@ struct ArangeForceLastDimensionPattern output = rewriter.create( op.getLoc(), broadcastType, output, dpsOutput, - rewriter.getArrayAttr(broadcastDims), - rewriter.getArrayAttr(SmallVector( - 2, rewriter.getAttr( - OperandConstraint::AnyDeviceTile)))); + rewriter.getArrayAttr(broadcastDims)); assert(mlir::cast(output.getType()).getShape() == outputType.getShape() && diff --git a/lib/Conversion/TosaToTTIR/TosaToTTIRPatterns.cpp b/lib/Conversion/TosaToTTIR/TosaToTTIRPatterns.cpp index 1a00957eb..f8b18926b 100644 --- a/lib/Conversion/TosaToTTIR/TosaToTTIRPatterns.cpp +++ b/lib/Conversion/TosaToTTIR/TosaToTTIRPatterns.cpp @@ -46,11 +46,7 @@ class TosaToTTIRDefaultDPSOpConversionPattern srcOp.getLoc(), outputType.getShape(), outputType.getElementType()); rewriter.replaceOpWithNewOp( srcOp, TypeRange(outputTensor.getType()), adaptor.getOperands(), - ValueRange(outputTensor), - rewriter.getArrayAttr( - SmallVector(adaptor.getOperands().size() + 1, - rewriter.getAttr( - OperandConstraint::AnyDeviceTile)))); + ValueRange(outputTensor)); return success(); } @@ -97,11 +93,7 @@ class TosaToTTIRClampOpConversionPattern rewriter.replaceOpWithNewOp( srcOp, TypeRange(outputTensor.getType()), adaptor.getOperands()[0], - outputTensor, adaptor.getMinFp(), adaptor.getMaxFp(), - rewriter.getArrayAttr( - SmallVector(adaptor.getOperands().size() + 1, - rewriter.getAttr( - OperandConstraint::AnyDeviceTile)))); + outputTensor, adaptor.getMinFp(), adaptor.getMaxFp()); return success(); } }; @@ -122,11 +114,7 @@ class TosaToTTIRConcatOpConversionPattern rewriter.replaceOpWithNewOp( srcOp, TypeRange(outputTensor.getType()), adaptor.getOperands(), - Value(outputTensor), adaptor.getAxis(), - rewriter.getArrayAttr( - SmallVector(adaptor.getOperands().size() + 1, - rewriter.getAttr( - OperandConstraint::AnyDeviceTile)))); + Value(outputTensor), adaptor.getAxis()); return success(); } }; @@ -153,12 +141,7 @@ class TosaToTTIRMatmulOpConversionPattern rewriter.replaceOpWithNewOp( srcOp, TypeRange(outputTensor.getType()), operands[0], operands[1], - outputTensor, - - rewriter.getArrayAttr( - SmallVector(adaptor.getOperands().size() + 1, - rewriter.getAttr( - OperandConstraint::AnyDeviceTile)))); + outputTensor); return success(); } @@ -191,11 +174,8 @@ class TosaToTTIRReduceOpConversionPattern : public OpConversionPattern { rewriter.replaceOpWithNewOp( srcOp, outputTensor.getType(), adaptor.getInput(), outputTensor, true /*keepdim*/, - rewriter.getArrayAttr(SmallVector(1, adaptor.getAxisAttr())), 
rewriter.getArrayAttr( - SmallVector(adaptor.getOperands().size() + 1, - rewriter.getAttr( - OperandConstraint::AnyDeviceTile)))); + SmallVector(1, adaptor.getAxisAttr()))); return success(); } }; @@ -220,11 +200,7 @@ class TosaToTTIRMaxPool2DOpConversionPattern rewriter.replaceOpWithNewOp( srcOp, TypeRange(outputTensor.getType()), adaptor.getInput(), outputTensor, dims[0], dims[1], strides[0], strides[1], 1, 1, false, - pad[2], pad[3], pad[0], pad[1], - rewriter.getArrayAttr( - SmallVector(adaptor.getOperands().size() + 1, - rewriter.getAttr( - OperandConstraint::AnyDeviceTile)))); + pad[2], pad[3], pad[0], pad[1]); return success(); } }; diff --git a/lib/Dialect/TTIR/IR/TTIROps.cpp b/lib/Dialect/TTIR/IR/TTIROps.cpp index e8dbe338e..cfc149de1 100644 --- a/lib/Dialect/TTIR/IR/TTIROps.cpp +++ b/lib/Dialect/TTIR/IR/TTIROps.cpp @@ -1653,23 +1653,16 @@ void mlir::tt::ttir::MaximumOp::buildGenericRegion(::mlir::OpBuilder &opBuilder, static mlir::tt::ttir::KernelOp buildKernelOp(::mlir::OpBuilder &opBuilder, ::mlir::Location loc, ::mlir::StringRef kernelName, ::mlir::StringRef kernelKind, - ::mlir::ValueRange inputs, ::mlir::ValueRange outputs, - ::mlir::ArrayAttr operandConstraints) { + ::mlir::ValueRange inputs, ::mlir::ValueRange outputs) { return opBuilder.create( - loc, outputs.getTypes(), kernelName, kernelKind, inputs, outputs, - operandConstraints); + loc, outputs.getTypes(), kernelName, kernelKind, inputs, outputs); } // Reduce op kernel builder static void createReduceOp(::mlir::OpBuilder &opBuilder, ::mlir::Block *block, mlir::Location loc, ::mlir::StringRef kernelKind) { - auto kernelOp = - buildKernelOp(opBuilder, loc, "reduce", kernelKind, block->getArgument(0), - block->getArgument(1), - opBuilder.getArrayAttr(llvm::SmallVector( - block->getNumArguments(), - opBuilder.getAttr( - mlir::tt::OperandConstraint::AnyDeviceTile)))); + auto kernelOp = buildKernelOp(opBuilder, loc, "reduce", kernelKind, + block->getArgument(0), block->getArgument(1)); opBuilder.create(loc, kernelOp->getResults()); } diff --git a/lib/Dialect/TTIR/Transforms/Constant.cpp b/lib/Dialect/TTIR/Transforms/Constant.cpp index 775dda928..2151d2a9b 100644 --- a/lib/Dialect/TTIR/Transforms/Constant.cpp +++ b/lib/Dialect/TTIR/Transforms/Constant.cpp @@ -26,11 +26,8 @@ class TTIRConstantAsFillRewriter : public OpRewritePattern { auto empty = rewriter.create( op.getLoc(), resultTy.getShape(), resultTy.getElementType(), resultTy.getEncoding()); - auto operandConstraints = rewriter.getArrayAttr(SmallVector( - 1, - rewriter.getAttr(OperandConstraint::AnyDevice))); - rewriter.replaceOpWithNewOp( - op, resultTy, empty, op.getValue(), operandConstraints); + rewriter.replaceOpWithNewOp(op, resultTy, empty, + op.getValue()); return success(); } }; diff --git a/lib/Dialect/TTIR/Transforms/Generic.cpp b/lib/Dialect/TTIR/Transforms/Generic.cpp index 3bf96f3cd..15064ed34 100644 --- a/lib/Dialect/TTIR/Transforms/Generic.cpp +++ b/lib/Dialect/TTIR/Transforms/Generic.cpp @@ -68,7 +68,7 @@ class TTIRNamedToKernelRewriter : public OpRewritePattern { auto kernel = rewriter.create( op.getLoc(), op.getResultTypes(), kernelName, kernelKind, - op.getInputs(), op.getOutputs(), op.getOperandConstraints()); + op.getInputs(), op.getOutputs()); rewriter.replaceOp(op, kernel); diff --git a/lib/Dialect/TTIR/Transforms/Layout.cpp b/lib/Dialect/TTIR/Transforms/Layout.cpp index c3ccbf1a4..eca974730 100644 --- a/lib/Dialect/TTIR/Transforms/Layout.cpp +++ b/lib/Dialect/TTIR/Transforms/Layout.cpp @@ -174,23 +174,6 @@ createToLayoutOp(PatternRewriter 
&rewriter, Location loc, Value input, ->getResult(0); } -static std::optional -createToLayoutOp(PatternRewriter &rewriter, Location loc, Value input, - OperandConstraint operandConstraint, - MemorySpace defaultMemorySpace, - TensorMemoryLayout defaultDeviceMemoryLayout) { - auto desiredMemorySpace = - getLegalMemorySpace(operandConstraint, defaultMemorySpace); - - auto desiredMemoryLayout = getLegalTensorMemoryLayout( - operandConstraint, desiredMemorySpace, defaultDeviceMemoryLayout); - - bool tiled = - !bitEnumContainsAny(operandConstraint, OperandConstraint::Scalar); - return createToLayoutOp(rewriter, loc, input, desiredMemorySpace, - desiredMemoryLayout, tiled); -} - class TTIRLayoutDPSOperandsRewriter : public OpInterfaceRewritePattern { public: @@ -223,16 +206,12 @@ class TTIRLayoutDPSOperandsRewriter if (mlir::isa(op.getOperation()) && !isResult) { continue; } - auto operandConstraint = - mlir::cast( - mlir::cast(op.getOperation()) - .getOperandConstraints()[operand.getOperandNumber()]) - .getValue(); + Location newLoc = appendInputSuffix(op.getLoc(), operand.getOperandNumber()); auto desiredLayout = - createToLayoutOp(rewriter, newLoc, operand.get(), operandConstraint, - defaultMemorySpace, defaultDeviceMemoryLayout); + createToLayoutOp(rewriter, newLoc, operand.get(), defaultMemorySpace, + defaultDeviceMemoryLayout, true /* isTiled */); if (desiredLayout) { rewriter.modifyOpInPlace(op, [&]() { diff --git a/lib/Dialect/TTNN/Transforms/TTNNLayout.cpp b/lib/Dialect/TTNN/Transforms/TTNNLayout.cpp index 712e12ad0..80b76d6d4 100644 --- a/lib/Dialect/TTNN/Transforms/TTNNLayout.cpp +++ b/lib/Dialect/TTNN/Transforms/TTNNLayout.cpp @@ -248,37 +248,6 @@ createToLayoutOp(PatternRewriter &rewriter, Location loc, Value input, ->getResult(0); } -static std::optional -createToLayoutOp(PatternRewriter &rewriter, Location loc, Value input, - OperandConstraint operandConstraint) { - // Find out which buffer type we want - tt::MemorySpace ttDefaultMemSpace = - utils::toTTMemorySpace(g_defaultMemorySpaceDevice); - tt::MemorySpace desiredMemorySpace = - getLegalMemorySpace(operandConstraint, ttDefaultMemSpace); - BufferType desiredBufferType = utils::toTTNNBufferType(desiredMemorySpace); - - // Find out which memory layout we want - tt::TensorMemoryLayout ttMemoryLayout = - utils::toTTTensorMemoryLayout(g_defaultMemoryLayout); - tt::TensorMemoryLayout desiredMemoryLayout = getLegalTensorMemoryLayout( - operandConstraint, desiredMemorySpace, ttMemoryLayout); - TensorMemoryLayoutAttr ttnnMemoryLayoutAttr; - if (desiredMemoryLayout != tt::TensorMemoryLayout::None) { - TensorMemoryLayout ttnnMemoryLayout = - utils::toTTNNTensorMemoryLayout(desiredMemoryLayout); - ttnnMemoryLayoutAttr = - TensorMemoryLayoutAttr::get(rewriter.getContext(), ttnnMemoryLayout); - } - - // Check if the tensor should be tiled - bool tiled = - !bitEnumContainsAny(operandConstraint, OperandConstraint::Scalar); - - return createToLayoutOp(rewriter, loc, input, desiredBufferType, - ttnnMemoryLayoutAttr, tiled); -} - static bool changeLayoutToHost(DestinationStyleOpInterface &op, OpOperand &operand, PatternRewriter &rewriter) { Location newLoc = appendInputSuffix(op.getLoc(), operand.getOperandNumber()); @@ -334,17 +303,14 @@ class TTNNLayoutDPSOperandsRewriter continue; } - // Read operand constrait for current operand - OperandConstraint operandConstraint = - mlir::cast( - mlir::cast(op.getOperation()) - .getOperandConstraints()[operand.getOperandNumber()]) - .getValue(); Location newLoc = appendInputSuffix(op.getLoc(), 
operand.getOperandNumber()); // Given the operand constraint, create the desired layout for the operand - std::optional desiredLayout = - createToLayoutOp(rewriter, newLoc, operand.get(), operandConstraint); + std::optional desiredLayout = createToLayoutOp( + rewriter, newLoc, operand.get(), g_defaultMemorySpaceDevice, + TensorMemoryLayoutAttr::get(rewriter.getContext(), + g_defaultMemoryLayout), + true /* isTiled */); // If layout changed update the operand if (desiredLayout) { diff --git a/python/test_infra/ttir_builder.py b/python/test_infra/ttir_builder.py index a6a302d4a..9c832d014 100644 --- a/python/test_infra/ttir_builder.py +++ b/python/test_infra/ttir_builder.py @@ -192,33 +192,6 @@ def _override_golden(self, operand: Operand, golden: Golden) -> None: def _get_golden_tensor(self, operand: Operand) -> torch.Tensor: return self._get_golden(operand).tensor - def _get_operand_constraint_attr( - self, - num_operands: int, - operand_constraints: Optional[List[tt.OperandConstraint]] = None, - ) -> tt.ir.OperandConstraintAttr: - """ - Helper method to prepack operand constraints given as a list of enums - to a list of tt.ir.OperandConstraintAttr and wrap that list in an - tt.ir.OperandConstraintAttr. - - If no `operand_constraints` are passed, `tt.OperandConstraint.Any` will - be used for each operand. - """ - operand_constraints = ( - operand_constraints - if operand_constraints is not None - else [tt.OperandConstraint.Any for _ in range(num_operands)] - ) - - return tt.ir.OperandConstraintAttr.get( - self._ctx, - [ - tt.ir.OperandConstraintAttr.get(self._ctx, operand_constraint) - for operand_constraint in operand_constraints - ], - ) - @property def _default_dtype(self) -> Type: return F32Type.get(self._ctx) @@ -303,7 +276,6 @@ def eltwise_proxy( [self._get_type(output)], inputs, [output], - self._get_operand_constraint_attr(3), loc=Location.name(str(id)), ) diff --git a/test/python/smoketest.py b/test/python/smoketest.py index c81ca1023..dfc324e38 100644 --- a/test/python/smoketest.py +++ b/test/python/smoketest.py @@ -16,7 +16,7 @@ %0 = tensor.empty() : tensor<64x128xf32> %1 = tensor.empty() : tensor<64x128xf32> %2 = tensor.empty() : tensor<64x128xf32> - %3 = "ttir.multiply"(%0, %1, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %3 = "ttir.multiply"(%0, %1, %2) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> """ ) # CHECK: %[[C:.*]] = tensor.empty() : tensor<64x128xf32> diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/binary/concat_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/binary/concat_op.mlir index 51cfd214b..055a5fa37 100644 --- a/test/ttmlir/Conversion/StableHLOToTTIR/binary/concat_op.mlir +++ b/test/ttmlir/Conversion/StableHLOToTTIR/binary/concat_op.mlir @@ -6,7 +6,7 @@ module @jit_concat attributes {} { dimension = 1 : i64 } : (tensor<32x32xf32>, tensor<32x64xf32>) -> tensor<32x96xf32> // CHECK: %[[C:.*]] = tensor.empty[[C:.*]] - // CHECK: %[[C:.*]] = "ttir.concat"(%arg0, %arg1, %0) <{dim = 1 : si32, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<32x32xf32>, tensor<32x64xf32>, tensor<32x96xf32>) -> tensor<32x96xf32> + // CHECK: %[[C:.*]] = "ttir.concat"(%arg0, %arg1, %0) <{dim = 1 : si32}> : (tensor<32x32xf32>, tensor<32x64xf32>, tensor<32x96xf32>) -> tensor<32x96xf32> return %0 : tensor<32x96xf32> } @@ -15,7 +15,7 @@ 
module @jit_concat attributes {} { dimension = 0 : i64 } : (tensor<3x2xi64>, tensor<1x2xi64>) -> tensor<4x2xi64> // CHECK: %[[C:.*]] = tensor.empty[[C:.*]] - // CHECK: %[[C:.*]] = "ttir.concat"(%arg0, %arg1, %0) <{dim = 0 : si32, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<3x2xi32>, tensor<1x2xi32>, tensor<4x2xi32>) -> tensor<4x2xi32> + // CHECK: %[[C:.*]] = "ttir.concat"(%arg0, %arg1, %0) <{dim = 0 : si32}> : (tensor<3x2xi32>, tensor<1x2xi32>, tensor<4x2xi32>) -> tensor<4x2xi32> return %0 : tensor<4x2xi64> } @@ -24,7 +24,7 @@ module @jit_concat attributes {} { dimension = 1 : i64 } : (tensor<4x3xf32>, tensor<4x5xf32>) -> tensor<4x8xf32> // CHECK: %[[C:.*]] = tensor.empty[[C:.*]] - // CHECK: %[[C:.*]] = "ttir.concat"(%arg0, %arg1, %0) <{dim = 1 : si32, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<4x3xf32>, tensor<4x5xf32>, tensor<4x8xf32>) -> tensor<4x8xf32> + // CHECK: %[[C:.*]] = "ttir.concat"(%arg0, %arg1, %0) <{dim = 1 : si32}> : (tensor<4x3xf32>, tensor<4x5xf32>, tensor<4x8xf32>) -> tensor<4x8xf32> return %0 : tensor<4x8xf32> } @@ -33,7 +33,7 @@ module @jit_concat attributes {} { dimension = 1 : i64 } : (tensor<128x64xf32>, tensor<128x96xf32>) -> tensor<128x160xf32> // CHECK: %[[C:.*]] = tensor.empty[[C:.*]] - // CHECK: %[[C:.*]] = "ttir.concat"(%arg0, %arg1, %0) <{dim = 1 : si32, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<128x64xf32>, tensor<128x96xf32>, tensor<128x160xf32>) -> tensor<128x160xf32> + // CHECK: %[[C:.*]] = "ttir.concat"(%arg0, %arg1, %0) <{dim = 1 : si32}> : (tensor<128x64xf32>, tensor<128x96xf32>, tensor<128x160xf32>) -> tensor<128x160xf32> return %0 : tensor<128x160xf32> } @@ -42,7 +42,7 @@ module @jit_concat attributes {} { dimension = 1 : i64 } : (tensor<256x512xi64>, tensor<256x256xi64>) -> tensor<256x768xi64> // CHECK: %[[C:.*]] = tensor.empty[[C:.*]] - // CHECK: %[[C:.*]] = "ttir.concat"(%arg0, %arg1, %0) <{dim = 1 : si32, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<256x512xi32>, tensor<256x256xi32>, tensor<256x768xi32>) -> tensor<256x768xi32> + // CHECK: %[[C:.*]] = "ttir.concat"(%arg0, %arg1, %0) <{dim = 1 : si32}> : (tensor<256x512xi32>, tensor<256x256xi32>, tensor<256x768xi32>) -> tensor<256x768xi32> return %0 : tensor<256x768xi64> } @@ -51,7 +51,7 @@ module @jit_concat attributes {} { dimension = 1 : i64 } : (tensor<64x32xf64>, tensor<64x64xf64>) -> tensor<64x96xf64> // CHECK: %[[C:.*]] = tensor.empty[[C:.*]] - // CHECK: %[[C:.*]] = "ttir.concat"(%arg0, %arg1, %0) <{dim = 1 : si32, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x32xf32>, tensor<64x64xf32>, tensor<64x96xf32>) -> tensor<64x96xf32> + // CHECK: %[[C:.*]] = "ttir.concat"(%arg0, %arg1, %0) <{dim = 1 : si32}> : (tensor<64x32xf32>, tensor<64x64xf32>, tensor<64x96xf32>) -> tensor<64x96xf32> return %0 : tensor<64x96xf64> } @@ -60,7 +60,7 @@ module @jit_concat attributes {} { dimension = 0 : i64 } : (tensor<1000x128xi32>, tensor<500x128xi32>) -> tensor<1500x128xi32> // CHECK: %[[C:.*]] = tensor.empty[[C:.*]] - // CHECK: %[[C:.*]] = "ttir.concat"(%arg0, %arg1, %0) <{dim = 0 : si32, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<1000x128xi32>, tensor<500x128xi32>, tensor<1500x128xi32>) -> tensor<1500x128xi32> + // CHECK: %[[C:.*]] = "ttir.concat"(%arg0, %arg1, %0) <{dim = 0 : si32}> : (tensor<1000x128xi32>, tensor<500x128xi32>, 
tensor<1500x128xi32>) -> tensor<1500x128xi32> return %0 : tensor<1500x128xi32> } @@ -69,7 +69,7 @@ module @jit_concat attributes {} { dimension = 3 : i64 } : (tensor<3x2x4x5xf64>, tensor<3x2x4x3xf64>) -> tensor<3x2x4x8xf64> // CHECK: %[[C:.*]] = tensor.empty[[C:.*]] - // CHECK: %[[C:.*]] = "ttir.concat"(%arg0, %arg1, %0) <{dim = 3 : si32, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<3x2x4x5xf32>, tensor<3x2x4x3xf32>, tensor<3x2x4x8xf32>) -> tensor<3x2x4x8xf32> + // CHECK: %[[C:.*]] = "ttir.concat"(%arg0, %arg1, %0) <{dim = 3 : si32}> : (tensor<3x2x4x5xf32>, tensor<3x2x4x3xf32>, tensor<3x2x4x8xf32>) -> tensor<3x2x4x8xf32> return %0 : tensor<3x2x4x8xf64> } @@ -78,7 +78,7 @@ module @jit_concat attributes {} { dimension = 2 : i64 } : (tensor<8x4x6xi32>, tensor<8x4x2xi32>) -> tensor<8x4x8xi32> // CHECK: %[[C:.*]] = tensor.empty[[C:.*]] - // CHECK: %[[C:.*]] = "ttir.concat"(%arg0, %arg1, %0) <{dim = 2 : si32, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<8x4x6xi32>, tensor<8x4x2xi32>, tensor<8x4x8xi32>) -> tensor<8x4x8xi32> + // CHECK: %[[C:.*]] = "ttir.concat"(%arg0, %arg1, %0) <{dim = 2 : si32}> : (tensor<8x4x6xi32>, tensor<8x4x2xi32>, tensor<8x4x8xi32>) -> tensor<8x4x8xi32> return %0 : tensor<8x4x8xi32> } } diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/clamp_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/clamp_op.mlir index 6bd602e27..d46b00e6a 100644 --- a/test/ttmlir/Conversion/StableHLOToTTIR/clamp_op.mlir +++ b/test/ttmlir/Conversion/StableHLOToTTIR/clamp_op.mlir @@ -6,7 +6,7 @@ module @jit_transpose attributes {} { %cst_0 = stablehlo.constant dense<3.000000e+00> : tensor<4xf32> // CHECK: %[[EMPTY:.*]] = tensor.empty() : [[TENSOR:tensor<4xf32>]] // CHECK: "ttir.clamp"(%arg0, %[[EMPTY]]) - // CHECK-SAME: max = 3.000000e+00 : f32, min = 2.000000e+00 : f32, + // CHECK-SAME: max = 3.000000e+00 : f32, min = 2.000000e+00 : f32 // CHECK-SAME: ([[TENSOR]], [[TENSOR]]) -> [[TENSOR]] %0 = stablehlo.clamp %cst, %arg0, %cst_0 : tensor<4xf32> return %0 : tensor<4xf32> diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/exponential_minus_one_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/exponential_minus_one_op.mlir index 179268b76..6a86fed84 100644 --- a/test/ttmlir/Conversion/StableHLOToTTIR/exponential_minus_one_op.mlir +++ b/test/ttmlir/Conversion/StableHLOToTTIR/exponential_minus_one_op.mlir @@ -1,11 +1,10 @@ // REQUIRES: stablehlo // RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s -#any_device = #tt.operand_constraint module @jit_eltwise_expm1 attributes {} { func.func public @test_expm1(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { %0 = stablehlo.exponential_minus_one %arg0 : tensor<13x21x3xf32> // CHECK: [[VAL0:%[0-9]+]] = tensor.empty() : [[TENSOR_SIZE:tensor<[0-9]+x[0-9]+x[0-9]+xf[0-9]+>]] - // CHECK: [[VAL1:%[0-9]+]] = "ttir.expm1"(%arg0, [[VAL0]]) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile]}> : ([[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]] + // CHECK: [[VAL1:%[0-9]+]] = "ttir.expm1"(%arg0, [[VAL0]]) <{operandSegmentSizes = array}> : ([[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]] return %0 : tensor<13x21x3xf32> // CHECK: return [[VAL1]] : [[TENSOR_SIZE]] } diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/log_plus_one_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/log_plus_one_op.mlir index d1d44f3af..feba16b6f 100644 --- a/test/ttmlir/Conversion/StableHLOToTTIR/log_plus_one_op.mlir +++ 
b/test/ttmlir/Conversion/StableHLOToTTIR/log_plus_one_op.mlir @@ -1,11 +1,10 @@ // REQUIRES: stablehlo // RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s -#any_device = #tt.operand_constraint module @jit_eltwise_log_plus_one attributes {} { func.func public @test_log_plus_one(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { %0 = stablehlo.log_plus_one %arg0 : tensor<13x21x3xf32> // CHECK: [[VAL0:%[0-9]+]] = tensor.empty() : [[TENSOR_SIZE:tensor<[0-9]+x[0-9]+x[0-9]+xf[0-9]+>]] - // CHECK: [[VAL1:%[0-9]+]] = "ttir.log1p"(%arg0, [[VAL0]]) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile]}> : ([[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]] + // CHECK: [[VAL1:%[0-9]+]] = "ttir.log1p"(%arg0, [[VAL0]]) <{operandSegmentSizes = array}> : ([[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]] return %0 : tensor<13x21x3xf32> // CHECK: return [[VAL1]] : [[TENSOR_SIZE]] } diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/scatter_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/scatter_op.mlir index 92cd8895f..34ce84855 100644 --- a/test/ttmlir/Conversion/StableHLOToTTIR/scatter_op.mlir +++ b/test/ttmlir/Conversion/StableHLOToTTIR/scatter_op.mlir @@ -1,6 +1,5 @@ // REQUIRES: stablehlo // RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s -#any_device = #tt.operand_constraint module @jit_scatter attributes {} { func.func public @test_scatter(%arg0: tensor<1x3x320x320xf32>, %arg1: tensor<1x1xi64>, %arg2: tensor<1x3x32x32xf32>) -> tensor<1x3x320x320xf32> { // CHECK: [[VAL0:%[0-9]+]] = tensor.empty() : [[TENSOR_SIZE1:tensor<[0-9]+x[0-9]+x[0-9]+x[0-9]+xf[0-9]+>]] @@ -8,7 +7,7 @@ module @jit_scatter attributes {} { ^bb0(%arg3: tensor, %arg4: tensor): stablehlo.return %arg4 : tensor }) : (tensor<1x3x320x320xf32>, tensor<1x1xi64>, tensor<1x3x32x32xf32>) -> tensor<1x3x320x320xf32> - // CHECK: [[VAL1:%[0-9]+]] = "ttir.scatter"(%arg0, %arg1, %arg2, [[VAL0]]) <{index_vector_dim = 1 : i32, indices_are_sorted = false, input_batching_dims = array, inserted_window_dims = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile], scatter_dims_to_operand_dims = array, scatter_indices_batching_dims = array, unique_indices = false, update_window_dims = array} + // CHECK: [[VAL1:%[0-9]+]] = "ttir.scatter"(%arg0, %arg1, %arg2, [[VAL0]]) <{index_vector_dim = 1 : i32, indices_are_sorted = false, input_batching_dims = array, inserted_window_dims = array, scatter_dims_to_operand_dims = array, scatter_indices_batching_dims = array, unique_indices = false, update_window_dims = array} // CHECK: ([[TENSOR_SIZE1]], tensor<1x1xi32>, tensor<1x3x32x32xf32>, [[TENSOR_SIZE1]]) -> tensor<1x3x320x320xf32> return %result : tensor<1x3x320x320xf32> // CHECK: return [[VAL1]] : [[TENSOR_SIZE1]] diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/select_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/select_op.mlir index 458879081..24df823d4 100644 --- a/test/ttmlir/Conversion/StableHLOToTTIR/select_op.mlir +++ b/test/ttmlir/Conversion/StableHLOToTTIR/select_op.mlir @@ -1,13 +1,12 @@ // REQUIRES: stablehlo // RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s -#any_device = #tt.operand_constraint module @jit_eltwise_select attributes {} { func.func public @test_select(%arg0: tensor<13x37xf32>, %arg1: tensor<13x37xf32>) -> tensor<13x37xf32> { %0 = stablehlo.compare EQ, %arg0, %arg1 : (tensor<13x37xf32>, tensor<13x37xf32>) -> tensor<13x37xi1> %1 = stablehlo.select %0, %arg0, %arg1 : (tensor<13x37xi1>, 
tensor<13x37xf32>, tensor<13x37xf32>) -> tensor<13x37xf32> // CHECK: %[[EMPTY:[0-9]+]] = tensor.empty() // CHECK: %[[VAL1:[0-9]+]] = "ttir.eq" - // CHECK: %[[SELECT:[0-9]+]] = "ttir.where"(%[[VAL1:[0-9]+]], %arg0, %arg1, %[[EMPTY:[0-9]+]]) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<13x37xbf16>, tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xf32>) -> tensor<13x37xf32> + // CHECK: %[[SELECT:[0-9]+]] = "ttir.where"(%[[VAL1:[0-9]+]], %arg0, %arg1, %[[EMPTY:[0-9]+]]) <{operandSegmentSizes = array}> : (tensor<13x37xbf16>, tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xf32>) -> tensor<13x37xf32> return %1 : tensor<13x37xf32> } } diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/sign_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/sign_op.mlir index 0bf4a1bca..a0b9d7c05 100644 --- a/test/ttmlir/Conversion/StableHLOToTTIR/sign_op.mlir +++ b/test/ttmlir/Conversion/StableHLOToTTIR/sign_op.mlir @@ -1,11 +1,10 @@ // REQUIRES: stablehlo // RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s -#any_device = #tt.operand_constraint module @jit_eltwise_sign attributes {} { func.func public @test_sign(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { %0 = stablehlo.sign %arg0 : tensor<13x21x3xf32> // CHECK: [[VAL0:%[0-9]+]] = tensor.empty() : [[TENSOR_SIZE:tensor<[0-9]+x[0-9]+x[0-9]+xf[0-9]+>]] - // CHECK: [[VAL1:%[0-9]+]] = "ttir.sign"(%arg0, [[VAL0]]) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile]}> : ([[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]] + // CHECK: [[VAL1:%[0-9]+]] = "ttir.sign"(%arg0, [[VAL0]]) <{operandSegmentSizes = array}> : ([[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]] return %0 : tensor<13x21x3xf32> // CHECK: return [[VAL1]] : [[TENSOR_SIZE]] } diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/unary/ceil_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/unary/ceil_op.mlir index f81c4b37b..48d60ebd0 100644 --- a/test/ttmlir/Conversion/StableHLOToTTIR/unary/ceil_op.mlir +++ b/test/ttmlir/Conversion/StableHLOToTTIR/unary/ceil_op.mlir @@ -4,7 +4,7 @@ module @jit_eltwise_ceil attributes {} { func.func public @test_ceil(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { %0 = stablehlo.ceil %arg0 : tensor<13x21x3xf32> // CHECK: [[VAL0:%[0-9]+]] = tensor.empty() : [[TENSOR_SIZE:tensor<[0-9]+x[0-9]+x[0-9]+xf[0-9]+>]] - // CHECK: [[VAL1:%[0-9]+]] = "ttir.ceil"(%arg0, [[VAL0]]) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile]}> : ([[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]] + // CHECK: [[VAL1:%[0-9]+]] = "ttir.ceil"(%arg0, [[VAL0]]) <{operandSegmentSizes = array}> : ([[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]] return %0 : tensor<13x21x3xf32> } } diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/unary/cosine_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/unary/cosine_op.mlir index fb54f073e..3c4853e6c 100644 --- a/test/ttmlir/Conversion/StableHLOToTTIR/unary/cosine_op.mlir +++ b/test/ttmlir/Conversion/StableHLOToTTIR/unary/cosine_op.mlir @@ -4,7 +4,7 @@ module @jit_eltwise_cosine attributes {} { func.func public @test_cosine(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { %0 = stablehlo.cosine %arg0 : tensor<13x21x3xf32> // CHECK: [[VAL0:%[0-9]+]] = tensor.empty() : [[TENSOR_SIZE:tensor<[0-9]+x[0-9]+x[0-9]+xf[0-9]+>]] - // CHECK: [[VAL1:%[0-9]+]] = "ttir.cos"(%arg0, [[VAL0]]) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, 
#any_device_tile]}> : ([[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]] + // CHECK: [[VAL1:%[0-9]+]] = "ttir.cos"(%arg0, [[VAL0]]) <{operandSegmentSizes = array}> : ([[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]] return %0 : tensor<13x21x3xf32> } } diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/unary/log_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/unary/log_op.mlir index 702bc155d..3057c3a12 100644 --- a/test/ttmlir/Conversion/StableHLOToTTIR/unary/log_op.mlir +++ b/test/ttmlir/Conversion/StableHLOToTTIR/unary/log_op.mlir @@ -4,7 +4,7 @@ module @jit_eltwise_log attributes {} { func.func public @test_log(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { %0 = stablehlo.log %arg0 : tensor<13x21x3xf32> // CHECK: [[VAL0:%[0-9]+]] = tensor.empty() : [[TENSOR_SIZE:tensor<[0-9]+x[0-9]+x[0-9]+xf[0-9]+>]] - // CHECK: [[VAL1:%[0-9]+]] = "ttir.log"(%arg0, [[VAL0]]) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile]}> : ([[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]] + // CHECK: [[VAL1:%[0-9]+]] = "ttir.log"(%arg0, [[VAL0]]) <{operandSegmentSizes = array}> : ([[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]] return %0 : tensor<13x21x3xf32> } } diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/unary/logit_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/unary/logit_op.mlir index 48c64d12d..a0c27b375 100644 --- a/test/ttmlir/Conversion/StableHLOToTTIR/unary/logit_op.mlir +++ b/test/ttmlir/Conversion/StableHLOToTTIR/unary/logit_op.mlir @@ -4,7 +4,7 @@ module @jit_eltwise_logit attributes {} { func.func public @test_logit(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { %0 = stablehlo.logistic %arg0 : tensor<13x21x3xf32> // CHECK: [[VAL0:%[0-9]+]] = tensor.empty() : [[TENSOR_SIZE:tensor<[0-9]+x[0-9]+x[0-9]+xf[0-9]+>]] - // CHECK: [[VAL1:%[0-9]+]] = "ttir.sigmoid"(%arg0, [[VAL0]]) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile]}> : ([[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]] + // CHECK: [[VAL1:%[0-9]+]] = "ttir.sigmoid"(%arg0, [[VAL0]]) <{operandSegmentSizes = array}> : ([[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]] return %0 : tensor<13x21x3xf32> } } diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/unary/sine_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/unary/sine_op.mlir index 24ea37238..be0883c3f 100644 --- a/test/ttmlir/Conversion/StableHLOToTTIR/unary/sine_op.mlir +++ b/test/ttmlir/Conversion/StableHLOToTTIR/unary/sine_op.mlir @@ -4,7 +4,7 @@ module @jit_eltwise_sine attributes {} { func.func public @test_sine(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { %0 = stablehlo.sine %arg0 : tensor<13x21x3xf32> // CHECK: [[VAL0:%[0-9]+]] = tensor.empty() : [[TENSOR_SIZE:tensor<[0-9]+x[0-9]+x[0-9]+xf[0-9]+>]] - // CHECK: [[VAL1:%[0-9]+]] = "ttir.sin"(%arg0, [[VAL0]]) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile]}> : ([[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]] + // CHECK: [[VAL1:%[0-9]+]] = "ttir.sin"(%arg0, [[VAL0]]) <{operandSegmentSizes = array}> : ([[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]] return %0 : tensor<13x21x3xf32> } } diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/unary/tan_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/unary/tan_op.mlir index 77b8f3b8b..a19344466 100644 --- a/test/ttmlir/Conversion/StableHLOToTTIR/unary/tan_op.mlir +++ b/test/ttmlir/Conversion/StableHLOToTTIR/unary/tan_op.mlir @@ -4,7 +4,7 @@ module @jit_eltwise_tan attributes {} { func.func public @test_tan(%arg0: 
tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { %0 = stablehlo.tan %arg0 : tensor<13x21x3xf32> // CHECK: [[VAL0:%[0-9]+]] = tensor.empty() : [[TENSOR_SIZE:tensor<[0-9]+x[0-9]+x[0-9]+xf[0-9]+>]] - // CHECK: [[VAL1:%[0-9]+]] = "ttir.tan"(%arg0, [[VAL0]]) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile]}> : ([[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]] + // CHECK: [[VAL1:%[0-9]+]] = "ttir.tan"(%arg0, [[VAL0]]) <{operandSegmentSizes = array}> : ([[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]] return %0 : tensor<13x21x3xf32> } } diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/unary/tanh_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/unary/tanh_op.mlir index 5d420c43c..b0171e1e6 100644 --- a/test/ttmlir/Conversion/StableHLOToTTIR/unary/tanh_op.mlir +++ b/test/ttmlir/Conversion/StableHLOToTTIR/unary/tanh_op.mlir @@ -4,7 +4,7 @@ module @jit_eltwise_tanh attributes {} { func.func public @test_tanh(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { %0 = stablehlo.tanh %arg0 : tensor<13x21x3xf32> // CHECK: [[VAL0:%[0-9]+]] = tensor.empty() : [[TENSOR_SIZE:tensor<[0-9]+x[0-9]+x[0-9]+xf[0-9]+>]] - // CHECK: [[VAL1:%[0-9]+]] = "ttir.tanh"(%arg0, [[VAL0]]) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile]}> : ([[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]] + // CHECK: [[VAL1:%[0-9]+]] = "ttir.tanh"(%arg0, [[VAL0]]) <{operandSegmentSizes = array}> : ([[TENSOR_SIZE]], [[TENSOR_SIZE]]) -> [[TENSOR_SIZE]] return %0 : tensor<13x21x3xf32> } } diff --git a/test/ttmlir/Dialect/TTIR/Decomposition/select_decomposition_tests.mlir b/test/ttmlir/Dialect/TTIR/Decomposition/select_decomposition_tests.mlir index 8365bbddd..8231cd703 100644 --- a/test/ttmlir/Dialect/TTIR/Decomposition/select_decomposition_tests.mlir +++ b/test/ttmlir/Dialect/TTIR/Decomposition/select_decomposition_tests.mlir @@ -1,11 +1,10 @@ // RUN: ttmlir-opt --ttir-to-ttir-decomposition %s | FileCheck %s -#any_device_tile = #tt.operand_constraint module attributes {} { func.func @select_identity(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> { %0 = tensor.empty() : tensor<4x4xf32> // CHECK: %{{[0-9]+}} = "ttir.slice" - %1 = "ttir.select"(%arg0, %0) <{dim = 1: si32, begin = 0: si32, length = 4: si32, stride = 4: si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : + %1 = "ttir.select"(%arg0, %0) <{dim = 1: si32, begin = 0: si32, length = 4: si32, stride = 4: si32}> : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> return %1 : tensor<4x4xf32> } @@ -18,7 +17,7 @@ module attributes {} { // CHECK: %{{[0-9]+}} = "ttir.slice" // CHECK: %{{[0-9]+}} = "ttir.slice" // CHECK: %{{[0-9]+}} = "ttir.concat" - %1 = "ttir.select"(%arg0, %0) <{dim = -1: si32, begin = 0: si32, length = 4: si32, stride = 16: si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : + %1 = "ttir.select"(%arg0, %0) <{dim = -1: si32, begin = 0: si32, length = 4: si32, stride = 16: si32}> : (tensor<4x2x64x128xf32>, tensor<4x2x64x32xf32>) -> tensor<4x2x64x32xf32> return %1 : tensor<4x2x64x32xf32> diff --git a/test/ttmlir/Dialect/TTIR/clamp/clamp_tests_negative.mlir b/test/ttmlir/Dialect/TTIR/clamp/clamp_tests_negative.mlir index 804a98963..b89a91ec8 100644 --- a/test/ttmlir/Dialect/TTIR/clamp/clamp_tests_negative.mlir +++ b/test/ttmlir/Dialect/TTIR/clamp/clamp_tests_negative.mlir @@ -2,12 +2,11 @@ // Negative test for clamp operation // Verify that the parsing fails if input and output shapes do not match. 
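// For reference, a sketch of the post-change op form this negative test exercises: with the
// operand_constraints attribute removed, a clamp carries only its min/max attributes. The
// function name and the matching 64x128 shapes below are illustrative assumptions only (the
// negative case underneath deliberately mismatches input and output shapes).
//   func.func @clamp_valid(%arg0: tensor<64x128xbf16>) -> tensor<64x128xbf16> {
//     %0 = tensor.empty() : tensor<64x128xbf16>
//     %1 = "ttir.clamp"(%arg0, %0) <{max = 3.000000e+00 : f32, min = 2.000000e+00 : f32}> : (tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16>
//     return %1 : tensor<64x128xbf16>
//   }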
-#any_device_tile = #tt.operand_constraint module attributes {} { func.func @clamp(%arg0: tensor<64x64xbf16>) -> tensor<64x128xbf16> { %0 = tensor.empty() : tensor<64x128xbf16> // CHECK: error: 'ttir.clamp' op input and output must have same shape. - %1 = "ttir.clamp"(%arg0, %0) <{max = 3.000000e+00 : f32, min = 2.000000e+00 : f32, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x64xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16> + %1 = "ttir.clamp"(%arg0, %0) <{max = 3.000000e+00 : f32, min = 2.000000e+00 : f32}> : (tensor<64x64xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16> return %1 : tensor<64x128xbf16> } } diff --git a/test/ttmlir/Dialect/TTIR/constant_as_fill.mlir b/test/ttmlir/Dialect/TTIR/constant_as_fill.mlir index dbe7f079b..4cc8c0d8a 100644 --- a/test/ttmlir/Dialect/TTIR/constant_as_fill.mlir +++ b/test/ttmlir/Dialect/TTIR/constant_as_fill.mlir @@ -1,13 +1,10 @@ // RUN: ttmlir-opt --ttir-constant-as-fill %s | FileCheck %s - -#any_device = #tt.operand_constraint - func.func public @add5(%arg0: tensor<32x32xf32>) -> tensor<32x32xf32> { // CHECK: %[[C:.*]] = tensor.empty[[C:.*]] // CHECK: %[[C:.*]] = "ttir.fill"[[C:.*]] %0 = "ttir.constant"() <{value = dense<5.000000e+00> : tensor<32x32xf32>}> : () -> tensor<32x32xf32> // CHECK: %[[C:.*]] = tensor.empty[[C:.*]] %1 = tensor.empty() : tensor<32x32xf32> - %2 = "ttir.add"(%arg0, %0, %1) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32> + %2 = "ttir.add"(%arg0, %0, %1) <{operandSegmentSizes = array}> : (tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32> return %2 : tensor<32x32xf32> } diff --git a/test/ttmlir/Dialect/TTIR/convolution/convolution_tests_negative.mlir b/test/ttmlir/Dialect/TTIR/convolution/convolution_tests_negative.mlir index afdb92cc0..278bb9f21 100644 --- a/test/ttmlir/Dialect/TTIR/convolution/convolution_tests_negative.mlir +++ b/test/ttmlir/Dialect/TTIR/convolution/convolution_tests_negative.mlir @@ -1,6 +1,4 @@ // RUN: not ttmlir-opt --split-input-file %s 2>&1 | FileCheck %s -#any_device_tile = #tt.operand_constraint - module @jit_convolution_bad_spatial_dimensions { func.func public @test_illegal_convolution(%arg0: tensor<1x3x100x100xbf16>, %arg1: tensor<7x3x3x3xbf16>) -> tensor<1x7x100x100xbf16> { %0 = tensor.empty() : tensor<1x7x100x100xbf16> @@ -20,7 +18,6 @@ module @jit_convolution_bad_spatial_dimensions { >, feature_group_count = 1 : i64, input_dilation = array, - operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile], padding = array, weight_dilation = array, window_reversal = array, @@ -51,7 +48,6 @@ module @jit_convolution_bad_stride_dimensions { >, feature_group_count = 1 : i64, input_dilation = array, - operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile], padding = array, weight_dilation = array, window_reversal = array, @@ -82,7 +78,6 @@ module @jit_convolution_bad_input_tensor { >, feature_group_count = 1 : i64, input_dilation = array, - operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile], padding = array, weight_dilation = array, window_reversal = array, @@ -113,7 +108,6 @@ module @jit_convolution_bad_weight_tensor { >, feature_group_count = 1 : i64, input_dilation = array, - operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile], padding = array, weight_dilation = array, window_reversal = 
array, @@ -144,7 +138,6 @@ module @jit_convolution_bad_bias_tensor { >, feature_group_count = 1 : i64, input_dilation = array, - operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile], padding = array, weight_dilation = array, window_reversal = array, diff --git a/test/ttmlir/Dialect/TTIR/index/index_tests_negative.mlir b/test/ttmlir/Dialect/TTIR/index/index_tests_negative.mlir index 03c2e6faf..9f5d8b04a 100644 --- a/test/ttmlir/Dialect/TTIR/index/index_tests_negative.mlir +++ b/test/ttmlir/Dialect/TTIR/index/index_tests_negative.mlir @@ -2,12 +2,11 @@ // Negative tests for index operation // Verify that the parsing fails if the begins attribute is not a 3D tensor -#any_device_tile = #tt.operand_constraint module attributes {} { func.func @index_negative_invalid_shape(%arg0: tensor) -> tensor<1xbf16> { %0 = tensor.empty() : tensor<1xbf16> // CHECK: error: 'ttir.index' op Input must be at least a 1D tensor - %1 = "ttir.index"(%arg0, %0) <{dim = 0: i32, begin = 0: i32, end = 0: i32, step = 1: i32, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor, tensor<1xbf16>) -> tensor<1xbf16> + %1 = "ttir.index"(%arg0, %0) <{dim = 0: i32, begin = 0: i32, end = 0: i32, step = 1: i32}> : (tensor, tensor<1xbf16>) -> tensor<1xbf16> return %1 : tensor<1xbf16> } } @@ -19,7 +18,7 @@ module attributes {} { func.func @index_negative_invalid_begins(%arg0: tensor<3x128x64xbf16>) -> tensor<3x128x64xbf16> { %0 = tensor.empty() : tensor<3x128x64xbf16> // CHECK: error: 'ttir.index' op Invalid dimension index 3. Input tensor rank is 3 - %1 = "ttir.index"(%arg0, %0) <{dim = 3 : i32, begin = 0: i32, end = 0: i32, step = 1: i32, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<3x128x64xbf16>, tensor<3x128x64xbf16>) -> tensor<3x128x64xbf16> + %1 = "ttir.index"(%arg0, %0) <{dim = 3 : i32, begin = 0: i32, end = 0: i32, step = 1: i32}> : (tensor<3x128x64xbf16>, tensor<3x128x64xbf16>) -> tensor<3x128x64xbf16> return %1 : tensor<3x128x64xbf16> } } @@ -31,7 +30,7 @@ module attributes {} { func.func @index_negative_invalid_output_datatype(%arg0: tensor<3x128x64xbf16>) -> tensor<3x128x32xf32> { %0 = tensor.empty() : tensor<3x128x32xf32> // CHECK: error: 'ttir.index' op Output tensor must have the same element type as the input tensor - %1 = "ttir.index"(%arg0, %0) <{dim = 2 : i32, begin = 0: i32, end = 32: i32, step = 1: i32, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<3x128x64xbf16>, tensor<3x128x32xf32>) -> tensor<3x128x32xf32> + %1 = "ttir.index"(%arg0, %0) <{dim = 2 : i32, begin = 0: i32, end = 32: i32, step = 1: i32}> : (tensor<3x128x64xbf16>, tensor<3x128x32xf32>) -> tensor<3x128x32xf32> return %1 : tensor<3x128x32xf32> } } @@ -43,7 +42,7 @@ module attributes {} { func.func @index_negative_input_output_rank_missmatch(%arg0: tensor<3x128x64xbf16>) -> tensor<3x64x64x1xbf16> { %0 = tensor.empty() : tensor<3x64x64x1xbf16> // CHECK: error: 'ttir.index' op Output tensor must have the same rank as the input tensor - %1 = "ttir.index"(%arg0, %0) <{dim = 1: i32, begin = 0: i32, end = 64: i32, step = 1: i32, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<3x128x64xbf16>, tensor<3x64x64x1xbf16>) -> tensor<3x64x64x1xbf16> + %1 = "ttir.index"(%arg0, %0) <{dim = 1: i32, begin = 0: i32, end = 64: i32, step = 1: i32}> : (tensor<3x128x64xbf16>, tensor<3x64x64x1xbf16>) -> tensor<3x64x64x1xbf16> return %1 : tensor<3x64x64x1xbf16> } } @@ -55,7 +54,7 @@ module attributes {} { func.func @index_negative_invalid_begin_positive(%arg0: 
tensor<10x3x128x64xbf16>) -> tensor<10x1x128x64xbf16> { %0 = tensor.empty() : tensor<10x1x128x64xbf16> // CHECK: error: 'ttir.index' op Invalid begin index for dimension 1. Expected value in range [-3, 3), got 3. Input shape: (10, 3, 128, 64) - %1 = "ttir.index"(%arg0, %0) <{dim = 1: i32, begin = 3: i32, end = 3: i32, step = 1: i32, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<10x3x128x64xbf16>, tensor<10x1x128x64xbf16>) -> tensor<10x1x128x64xbf16> + %1 = "ttir.index"(%arg0, %0) <{dim = 1: i32, begin = 3: i32, end = 3: i32, step = 1: i32}> : (tensor<10x3x128x64xbf16>, tensor<10x1x128x64xbf16>) -> tensor<10x1x128x64xbf16> return %1 : tensor<10x1x128x64xbf16> } } @@ -67,7 +66,7 @@ module attributes {} { func.func @index_negative_invalid_begin_negative(%arg0: tensor<10x3x128x64xbf16>) -> tensor<10x3x64x64xbf16> { %0 = tensor.empty() : tensor<10x3x64x64xbf16> // CHECK: error: 'ttir.index' op Invalid begin index for dimension 2. Expected value in range [-128, 128), got -129. Input shape: (10, 3, 128, 64) - %1 = "ttir.index"(%arg0, %0) <{dim = 2: i32, begin = -129: i32, end = 64: i32, step = 1: i32, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<10x3x128x64xbf16>, tensor<10x3x64x64xbf16>) -> tensor<10x3x64x64xbf16> + %1 = "ttir.index"(%arg0, %0) <{dim = 2: i32, begin = -129: i32, end = 64: i32, step = 1: i32}> : (tensor<10x3x128x64xbf16>, tensor<10x3x64x64xbf16>) -> tensor<10x3x64x64xbf16> return %1 : tensor<10x3x64x64xbf16> } } @@ -79,7 +78,7 @@ module attributes {} { func.func @index_negative_invalid_end_positive(%arg0: tensor<10x3x128x64xbf16>) -> tensor<10x3x128x64xbf16> { %0 = tensor.empty() : tensor<10x3x128x64xbf16> // CHECK: error: 'ttir.index' op Invalid end index for dimension 1. Expected value in range [-3, 3], got 4. Input shape: (10, 3, 128, 64) - %1 = "ttir.index"(%arg0, %0) <{dim = 1: i32, begin = 0: i32, end = 4: i32, step = 1: i32, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<10x3x128x64xbf16>, tensor<10x3x128x64xbf16>) -> tensor<10x3x128x64xbf16> + %1 = "ttir.index"(%arg0, %0) <{dim = 1: i32, begin = 0: i32, end = 4: i32, step = 1: i32}> : (tensor<10x3x128x64xbf16>, tensor<10x3x128x64xbf16>) -> tensor<10x3x128x64xbf16> return %1 : tensor<10x3x128x64xbf16> } } @@ -91,7 +90,7 @@ module attributes {} { func.func @index_negative_invalid_end_negative(%arg0: tensor<10x3x128x64xbf16>) -> tensor<10x3x128x64xbf16> { %0 = tensor.empty() : tensor<10x3x128x64xbf16> // CHECK: error: 'ttir.index' op Invalid end index for dimension 1. Expected value in range [-3, 3], got -4. 
Input shape: (10, 3, 128, 64) - %1 = "ttir.index"(%arg0, %0) <{dim = 1: i32, begin = -1: i32, end = -4: i32, step = -1: i32, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<10x3x128x64xbf16>, tensor<10x3x128x64xbf16>) -> tensor<10x3x128x64xbf16> + %1 = "ttir.index"(%arg0, %0) <{dim = 1: i32, begin = -1: i32, end = -4: i32, step = -1: i32}> : (tensor<10x3x128x64xbf16>, tensor<10x3x128x64xbf16>) -> tensor<10x3x128x64xbf16> return %1 : tensor<10x3x128x64xbf16> } } @@ -103,7 +102,7 @@ module attributes {} { func.func @index_negative_step_is_zero(%arg0: tensor<10x3x128x64xbf16>) -> tensor<10x3x128x64xbf16> { %0 = tensor.empty() : tensor<10x3x128x64xbf16> // CHECK: error: 'ttir.index' op Step value for dimension 1 cannot be zero - %1 = "ttir.index"(%arg0, %0) <{dim = 1: i32, begin = -1: i32, end = -3: i32, step = 0: i32, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<10x3x128x64xbf16>, tensor<10x3x128x64xbf16>) -> tensor<10x3x128x64xbf16> + %1 = "ttir.index"(%arg0, %0) <{dim = 1: i32, begin = -1: i32, end = -3: i32, step = 0: i32}> : (tensor<10x3x128x64xbf16>, tensor<10x3x128x64xbf16>) -> tensor<10x3x128x64xbf16> return %1 : tensor<10x3x128x64xbf16> } } @@ -115,7 +114,7 @@ module attributes {} { func.func @index_negative_begin_greater_than_end_positive_step(%arg0: tensor<10x3x128x64xbf16>) -> tensor<10x3x128x64xbf16> { %0 = tensor.empty() : tensor<10x3x128x64xbf16> // CHECK: error: 'ttir.index' op For positive step, begin index must be less than or equal to end index for dimension 2. Got begin: 2, end: 0, step: 1, input shape: (10, 3, 128, 64) - %1 = "ttir.index"(%arg0, %0) <{dim = 2: i32, begin = 2: i32, end = 0: i32, step = 1: i32, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<10x3x128x64xbf16>, tensor<10x3x128x64xbf16>) -> tensor<10x3x128x64xbf16> + %1 = "ttir.index"(%arg0, %0) <{dim = 2: i32, begin = 2: i32, end = 0: i32, step = 1: i32}> : (tensor<10x3x128x64xbf16>, tensor<10x3x128x64xbf16>) -> tensor<10x3x128x64xbf16> return %1 : tensor<10x3x128x64xbf16> } } @@ -127,7 +126,7 @@ module attributes {} { func.func @index_negative_begin_less_than_end_negative_step(%arg0: tensor<10x3x128x64xbf16>) -> tensor<10x3x128x64xbf16> { %0 = tensor.empty() : tensor<10x3x128x64xbf16> // CHECK: error: 'ttir.index' op For negative step, begin index must be greater than or equal to end index for dimension 3. 
Got begin: 0, end: 64, step: -1, input shape: (10, 3, 128, 64) - %1 = "ttir.index"(%arg0, %0) <{dim = 3: i32, begin = 0: i32, end = 64: i32, step = -1: i32, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<10x3x128x64xbf16>, tensor<10x3x128x64xbf16>) -> tensor<10x3x128x64xbf16> + %1 = "ttir.index"(%arg0, %0) <{dim = 3: i32, begin = 0: i32, end = 64: i32, step = -1: i32}> : (tensor<10x3x128x64xbf16>, tensor<10x3x128x64xbf16>) -> tensor<10x3x128x64xbf16> return %1 : tensor<10x3x128x64xbf16> } } @@ -139,7 +138,7 @@ module attributes {} { func.func @index_negative_invalid_output_shape(%arg0: tensor<10x3x128x64xbf16>) -> tensor<10x3x128x32xbf16> { %0 = tensor.empty() : tensor<10x3x128x32xbf16> // CHECK: error: 'ttir.index' op Mismatch in dimension 3 of the output tensor: expected size 16, but got 32 - %1 = "ttir.index"(%arg0, %0) <{dim = 3: i32, begin = 0: i32, end = 64: i32, step = 4: i32, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<10x3x128x64xbf16>, tensor<10x3x128x32xbf16>) -> tensor<10x3x128x32xbf16> + %1 = "ttir.index"(%arg0, %0) <{dim = 3: i32, begin = 0: i32, end = 64: i32, step = 4: i32}> : (tensor<10x3x128x64xbf16>, tensor<10x3x128x32xbf16>) -> tensor<10x3x128x32xbf16> return %1 : tensor<10x3x128x32xbf16> } } diff --git a/test/ttmlir/Dialect/TTIR/index/index_tests_positive.mlir b/test/ttmlir/Dialect/TTIR/index/index_tests_positive.mlir index f3ccbfeda..4a6f39999 100644 --- a/test/ttmlir/Dialect/TTIR/index/index_tests_positive.mlir +++ b/test/ttmlir/Dialect/TTIR/index/index_tests_positive.mlir @@ -1,59 +1,58 @@ // RUN: ttmlir-opt %s | FileCheck %s -#any_device_tile = #tt.operand_constraint module attributes {} { func.func @index_1d(%arg0: tensor<64xbf16>) -> tensor<32xbf16> { %0 = tensor.empty() : tensor<32xbf16> // CHECK: %[[C:.*]] = "ttir.index"[[C:.*]] - %1 = "ttir.index"(%arg0, %0) <{dim = 0: i32, begin = 0: i32, end = 32: i32, step = 1: i32, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<64xbf16>, tensor<32xbf16>) -> tensor<32xbf16> + %1 = "ttir.index"(%arg0, %0) <{dim = 0: i32, begin = 0: i32, end = 32: i32, step = 1: i32}> : (tensor<64xbf16>, tensor<32xbf16>) -> tensor<32xbf16> return %1 : tensor<32xbf16> } func.func @index_1d_step(%arg0: tensor<64xbf16>) -> tensor<16xbf16> { %0 = tensor.empty() : tensor<16xbf16> // CHECK: %[[C:.*]] = "ttir.index"[[C:.*]] - %1 = "ttir.index"(%arg0, %0) <{dim = 0: i32, begin = 0: i32, end = 32: i32, step = 2: i32, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<64xbf16>, tensor<16xbf16>) -> tensor<16xbf16> + %1 = "ttir.index"(%arg0, %0) <{dim = 0: i32, begin = 0: i32, end = 32: i32, step = 2: i32}> : (tensor<64xbf16>, tensor<16xbf16>) -> tensor<16xbf16> return %1 : tensor<16xbf16> } func.func @index_2d(%arg0: tensor<128x64xbf16>) -> tensor<128x32xbf16> { %0 = tensor.empty() : tensor<128x32xbf16> // CHECK: %[[C:.*]] = "ttir.index"[[C:.*]] - %1 = "ttir.index"(%arg0, %0) <{dim = 1: i32, begin = 0: i32, end = 32: i32, step = 1: i32, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<128x64xbf16>, tensor<128x32xbf16>) -> tensor<128x32xbf16> + %1 = "ttir.index"(%arg0, %0) <{dim = 1: i32, begin = 0: i32, end = 32: i32, step = 1: i32}> : (tensor<128x64xbf16>, tensor<128x32xbf16>) -> tensor<128x32xbf16> return %1 : tensor<128x32xbf16> } func.func @index_2d_step(%arg0: tensor<128x64xbf16>) -> tensor<128x16xbf16> { %0 = tensor.empty() : tensor<128x16xbf16> // CHECK: %[[C:.*]] = "ttir.index"[[C:.*]] - %1 = "ttir.index"(%arg0, %0) <{dim = 1: 
i32, begin = 32: i32, end = 64: i32, step = 2: i32, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<128x64xbf16>, tensor<128x16xbf16>) -> tensor<128x16xbf16> + %1 = "ttir.index"(%arg0, %0) <{dim = 1: i32, begin = 32: i32, end = 64: i32, step = 2: i32}> : (tensor<128x64xbf16>, tensor<128x16xbf16>) -> tensor<128x16xbf16> return %1 : tensor<128x16xbf16> } func.func @index_3d(%arg0: tensor<3x128x64xbf16>) -> tensor<3x128x32xbf16> { %0 = tensor.empty() : tensor<3x128x32xbf16> // CHECK: %[[C:.*]] = "ttir.index"[[C:.*]] - %1 = "ttir.index"(%arg0, %0) <{dim = 2: i32, begin = 0: i32, end = 32: i32, step = 1: i32, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<3x128x64xbf16>, tensor<3x128x32xbf16>) -> tensor<3x128x32xbf16> + %1 = "ttir.index"(%arg0, %0) <{dim = 2: i32, begin = 0: i32, end = 32: i32, step = 1: i32}> : (tensor<3x128x64xbf16>, tensor<3x128x32xbf16>) -> tensor<3x128x32xbf16> return %1 : tensor<3x128x32xbf16> } func.func @index_3d_step(%arg0: tensor<3x128x64xbf16>) -> tensor<3x128x8xbf16> { %0 = tensor.empty() : tensor<3x128x8xbf16> // CHECK: %[[C:.*]] = "ttir.index"[[C:.*]] - %1 = "ttir.index"(%arg0, %0) <{dim = 2: i32, begin = -1: i32, end = 0: i32, step = -8: i32, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<3x128x64xbf16>, tensor<3x128x8xbf16>) -> tensor<3x128x8xbf16> + %1 = "ttir.index"(%arg0, %0) <{dim = 2: i32, begin = -1: i32, end = 0: i32, step = -8: i32}> : (tensor<3x128x64xbf16>, tensor<3x128x8xbf16>) -> tensor<3x128x8xbf16> return %1 : tensor<3x128x8xbf16> } func.func @index_4d(%arg0: tensor<10x3x128x64xbf16>) -> tensor<10x3x128x32xbf16> { %0 = tensor.empty() : tensor<10x3x128x32xbf16> // CHECK: %[[C:.*]] = "ttir.index"[[C:.*]] - %1 = "ttir.index"(%arg0, %0) <{dim = 3: i32, begin = 0: i32, end = 32: i32, step = 1: i32, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<10x3x128x64xbf16>, tensor<10x3x128x32xbf16>) -> tensor<10x3x128x32xbf16> + %1 = "ttir.index"(%arg0, %0) <{dim = 3: i32, begin = 0: i32, end = 32: i32, step = 1: i32}> : (tensor<10x3x128x64xbf16>, tensor<10x3x128x32xbf16>) -> tensor<10x3x128x32xbf16> return %1 : tensor<10x3x128x32xbf16> } func.func @index_4d_step(%arg0: tensor<10x3x128x64xbf16>) -> tensor<10x3x128x24xbf16> { %0 = tensor.empty() : tensor<10x3x128x24xbf16> // CHECK: %[[C:.*]] = "ttir.index"[[C:.*]] - %1 = "ttir.index"(%arg0, %0) <{dim = 3: i32, begin = 0: i32, end = -16: i32, step = 2: i32, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<10x3x128x64xbf16>, tensor<10x3x128x24xbf16>) -> tensor<10x3x128x24xbf16> + %1 = "ttir.index"(%arg0, %0) <{dim = 3: i32, begin = 0: i32, end = -16: i32, step = 2: i32}> : (tensor<10x3x128x64xbf16>, tensor<10x3x128x24xbf16>) -> tensor<10x3x128x24xbf16> return %1 : tensor<10x3x128x24xbf16> } } diff --git a/test/ttmlir/Dialect/TTIR/linear/linear_tests_negative.mlir b/test/ttmlir/Dialect/TTIR/linear/linear_tests_negative.mlir index 522628160..0154deff9 100644 --- a/test/ttmlir/Dialect/TTIR/linear/linear_tests_negative.mlir +++ b/test/ttmlir/Dialect/TTIR/linear/linear_tests_negative.mlir @@ -2,193 +2,176 @@ // Negative tests for linear operation // Verify that the parsing fails if either of operands is a scalar -#any_device_tile = #tt.operand_constraint module { func.func @linear_negative_1d_1d_scalar_a(%arg0: tensor, %arg1: tensor<64xbf16>) -> tensor<1xbf16> { // CHECK: error: 'ttir.linear' op Input A must be at least a 1D tensor %0 = tensor.empty() : tensor<1xbf16> - %1 = "ttir.linear"(%arg0, %arg1, %0) 
<{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor, tensor<64xbf16>, tensor<1xbf16>) -> tensor<1xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %0) : (tensor, tensor<64xbf16>, tensor<1xbf16>) -> tensor<1xbf16> return %1 : tensor<1xbf16> } } // ----- -#any_device_tile = #tt.operand_constraint module { func.func @linear_negative_1d_1d_scalar_b(%arg0: tensor<128xbf16>, %arg1: tensor) -> tensor<1xbf16> { // CHECK: error: 'ttir.linear' op Input B must be at least a 1D tensor %0 = tensor.empty() : tensor<1xbf16> - %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<128xbf16>, tensor, tensor<1xbf16>) -> tensor<1xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %0) : (tensor<128xbf16>, tensor, tensor<1xbf16>) -> tensor<1xbf16> return %1 : tensor<1xbf16> } } // ----- -#any_device_tile = #tt.operand_constraint module { func.func @linear_negative_1d_1d_scalar_bias(%arg0: tensor<128xbf16>, %arg1: tensor<128xbf16>, %bias: tensor) -> tensor<1xbf16> { // CHECK: error: 'ttir.linear' op Bias must be at least a 1D tensor %0 = tensor.empty() : tensor<1xbf16> - %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<128xbf16>, tensor<128xbf16>, tensor, tensor<1xbf16>) -> tensor<1xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) : (tensor<128xbf16>, tensor<128xbf16>, tensor, tensor<1xbf16>) -> tensor<1xbf16> return %1 : tensor<1xbf16> } } // Verifty that the parsing fails if the output is a scalar // ----- -#any_device_tile = #tt.operand_constraint module { func.func @linear_negative_1d_1d_scalar_output(%arg0: tensor<128xbf16>, %arg1: tensor<128xbf16>) -> tensor { // CHECK: error: 'ttir.linear' op Scalar output is not supported, output must be at least a 1D tensor %0 = tensor.empty() : tensor - %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<128xbf16>, tensor<128xbf16>, tensor) -> tensor + %1 = "ttir.linear"(%arg0, %arg1, %0) : (tensor<128xbf16>, tensor<128xbf16>, tensor) -> tensor return %1 : tensor } } // ----- -#any_device_tile = #tt.operand_constraint module { func.func @linear_negative_1d_1d_output_dimension_mismatch(%arg0: tensor<128xbf16>, %arg1: tensor<128xbf16>) -> tensor<2xbf16> { // CHECK: error: 'ttir.linear' op Scalar output must be a 1D tensor of size 1 %0 = tensor.empty() : tensor<2xbf16> - %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<128xbf16>, tensor<128xbf16>, tensor<2xbf16>) -> tensor<2xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %0) : (tensor<128xbf16>, tensor<128xbf16>, tensor<2xbf16>) -> tensor<2xbf16> return %1 : tensor<2xbf16> } } // Inner dimension mismatch tests // ----- -#any_device_tile = #tt.operand_constraint module { func.func @linear_negative_1d_1d_inner_dimension_mismatch(%arg0: tensor<128xbf16>, %arg1: tensor<64xbf16>) -> tensor<1xbf16> { // CHECK: error: 'ttir.linear' op Input A[-1](128) and B[-2](64) must have matching inner dimensions %0 = tensor.empty() : tensor<1xbf16> - %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<128xbf16>, tensor<64xbf16>, tensor<1xbf16>) -> tensor<1xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %0) : (tensor<128xbf16>, tensor<64xbf16>, tensor<1xbf16>) -> tensor<1xbf16> return %1 : tensor<1xbf16> } } // ----- 
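// For reference, a sketch (not one of the split-input-file cases in this test) of a well-formed
// linear op after this change: once operand_constraints is gone, the bias-less form needs no
// attribute dictionary at all. The function name and shapes here are illustrative assumptions.
//   func.func @linear_valid(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64x64xbf16> {
//     %0 = tensor.empty() : tensor<64x64xbf16>
//     %1 = "ttir.linear"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16>
//     return %1 : tensor<64x64xbf16>
//   }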
-#any_device_tile = #tt.operand_constraint module { func.func @linear_negative_1d_2d_inner_dimension_mismatch(%arg0: tensor<64xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64xbf16> { // CHECK: error: 'ttir.linear' op Input A[-1](64) and B[-2](128) must have matching inner dimensions %0 = tensor.empty() : tensor<64xbf16> - %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64xbf16>, tensor<128x64xbf16>, tensor<64xbf16>) -> tensor<64xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %0) : (tensor<64xbf16>, tensor<128x64xbf16>, tensor<64xbf16>) -> tensor<64xbf16> return %1 : tensor<64xbf16> } } // ----- -#any_device_tile = #tt.operand_constraint module { func.func @linear_negative_2d_1d_inner_dimension_mismatch(%arg0: tensor<64x128xbf16>, %arg1: tensor<64xbf16>) -> tensor<64xbf16> { // CHECK: error: 'ttir.linear' op Input A[-1](128) and B[-2](64) must have matching inner dimensions %0 = tensor.empty() : tensor<64xbf16> - %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<64xbf16>, tensor<64xbf16>) -> tensor<64xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<64xbf16>, tensor<64xbf16>) -> tensor<64xbf16> return %1 : tensor<64xbf16> } } // ----- -#any_device_tile = #tt.operand_constraint module { func.func @linear_negative_2d_2d_inner_dimension_mismatch(%arg0: tensor<64x128xbf16>, %arg1: tensor<64x128xbf16>) -> tensor<64x64xbf16> { // CHECK: error: 'ttir.linear' op Input A[-1](128) and B[-2](64) must have matching inner dimensions %0 = tensor.empty() : tensor<64x64xbf16> - %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> return %1 : tensor<64x64xbf16> } } // ----- -#any_device_tile = #tt.operand_constraint module { func.func @linear_negative_nd_nd_inner_dimension_mismatch(%arg0: tensor<7x64x128xbf16>, %arg1: tensor<1x64x128xbf16>) -> tensor<7x64x64xbf16> { // CHECK: error: 'ttir.linear' op Input A[-1](128) and B[-2](64) must have matching inner dimensions %0 = tensor.empty() : tensor<7x64x64xbf16> - %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<7x64x128xbf16>, tensor<1x64x128xbf16>, tensor<7x64x64xbf16>) -> tensor<7x64x64xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %0) : (tensor<7x64x128xbf16>, tensor<1x64x128xbf16>, tensor<7x64x64xbf16>) -> tensor<7x64x64xbf16> return %1 : tensor<7x64x64xbf16> } } // Batch dimension mismatch tests // ----- -#any_device_tile = #tt.operand_constraint module { func.func @linear_negative_nd_nd_same_rank_batch_broadcast_incompatible_1(%arg0: tensor<7x64x128xbf16>, %arg1: tensor<2x128x64xbf16>) -> tensor<7x64x64xbf16> { // CHECK: error: 'ttir.linear' op Batch dimensions of input A(7) and B(2) are not broadcast compatible %0 = tensor.empty() : tensor<7x64x64xbf16> - %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<7x64x128xbf16>, tensor<2x128x64xbf16>, tensor<7x64x64xbf16>) -> tensor<7x64x64xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %0) : (tensor<7x64x128xbf16>, tensor<2x128x64xbf16>, tensor<7x64x64xbf16>) -> tensor<7x64x64xbf16> return %1 : 
tensor<7x64x64xbf16> } } // ----- -#any_device_tile = #tt.operand_constraint module { func.func @linear_negative_nd_nd_same_rank_batch_broadcast_incompatible_2(%arg0: tensor<2x7x64x128xbf16>, %arg1: tensor<7x1x128x64xbf16>) -> tensor<7x7x64x64xbf16> { // CHECK: error: 'ttir.linear' op Batch dimensions of input A(2,7) and B(7,1) are not broadcast compatible %0 = tensor.empty() : tensor<7x64x64xbf16> - %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<2x7x64x128xbf16>, tensor<7x1x128x64xbf16>, tensor<7x64x64xbf16>) -> tensor<7x7x64x64xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %0) : (tensor<2x7x64x128xbf16>, tensor<7x1x128x64xbf16>, tensor<7x64x64xbf16>) -> tensor<7x7x64x64xbf16> return %1 : tensor<7x7x64x64xbf16> } } // ----- -#any_device_tile = #tt.operand_constraint module { func.func @linear_negative_nd_nd_different_rank_batch_broadcast_incompatible(%arg0: tensor<12x2x7x64x128xbf16>, %arg1: tensor<7x1x128x64xbf16>) -> tensor<12x7x7x64x64xbf16> { // CHECK: error: 'ttir.linear' op Batch dimensions of input A(12,2,7) and B(7,1) are not broadcast compatible %0 = tensor.empty() : tensor<12x7x7x64x64xbf16> - %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<12x2x7x64x128xbf16>, tensor<7x1x128x64xbf16>, tensor<12x7x7x64x64xbf16>) -> tensor<12x7x7x64x64xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %0) : (tensor<12x2x7x64x128xbf16>, tensor<7x1x128x64xbf16>, tensor<12x7x7x64x64xbf16>) -> tensor<12x7x7x64x64xbf16> return %1 : tensor<12x7x7x64x64xbf16> } } // Bias shape mismatch tests // ----- -#any_device_tile = #tt.operand_constraint module { func.func @linear_negative_matmul_bias_broadcast_incompatible(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>, %bias: tensor<2x64xbf16>) -> tensor<64x64xbf16> { // CHECK: error: 'ttir.linear' op Bias shape(2,64) is not broadcast compatible with the matmul output shape(64,64) %0 = tensor.empty() : tensor<64x64xbf16> - %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<2x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<2x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> return %1 : tensor<64x64xbf16> } } // ----- -#any_device_tile = #tt.operand_constraint module { func.func @linear_negative_nd_nd_matmul_bias_broadcast_incompatible(%arg0: tensor<3x64x128xbf16>, %arg1: tensor<128x64xbf16>, %bias: tensor<2x64x64xbf16>) -> tensor<3x64x64xbf16> { // CHECK: error: 'ttir.linear' op Bias shape(2,64,64) is not broadcast compatible with the matmul output shape(3,64,64) %0 = tensor.empty() : tensor<3x64x64xbf16> - %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<3x64x128xbf16>, tensor<128x64xbf16>, tensor<2x64x64xbf16>, tensor<3x64x64xbf16>) -> tensor<3x64x64xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) : (tensor<3x64x128xbf16>, tensor<128x64xbf16>, tensor<2x64x64xbf16>, tensor<3x64x64xbf16>) -> tensor<3x64x64xbf16> return %1 : tensor<3x64x64xbf16> } } // Output shape mismatch tests // ----- -#any_device_tile = #tt.operand_constraint module { func.func @linear_negative_2d_2d_output_shape_mismatch(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>) -> 
tensor<64xbf16> { // CHECK: error: 'ttir.linear' op Output shape rank(1) must match the expected output shape rank(2) %0 = tensor.empty() : tensor<64xbf16> - %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64xbf16>) -> tensor<64xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64xbf16>) -> tensor<64xbf16> return %1 : tensor<64xbf16> } } // ----- -#any_device_tile = #tt.operand_constraint module { func.func @linear_negative_2d_2d_output_shape_mismatch(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64x128xbf16> { // CHECK: error: 'ttir.linear' op Output shape dimension[1](128) doesn't match the expected output shape dimension[1](64) %0 = tensor.empty() : tensor<64x128xbf16> - %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16> return %1 : tensor<64x128xbf16> } } diff --git a/test/ttmlir/Dialect/TTIR/matmul/matmul_tests_negative.mlir b/test/ttmlir/Dialect/TTIR/matmul/matmul_tests_negative.mlir index 67cd1af5e..f15379c8f 100644 --- a/test/ttmlir/Dialect/TTIR/matmul/matmul_tests_negative.mlir +++ b/test/ttmlir/Dialect/TTIR/matmul/matmul_tests_negative.mlir @@ -2,12 +2,11 @@ // Negative tests for matmul operation // Verify that the parsing fails if either of operands is a scalar -#any_device_tile = #tt.operand_constraint module attributes {} { func.func @matmul_negative_1d_1d_inner_dimension_missmatch(%arg0: tensor, %arg1: tensor<64xbf16>) -> tensor<1xbf16> { // CHECK: error: 'ttir.matmul' op Input A must be at least a 1D tensor %0 = tensor.empty() : tensor<1xbf16> - %1 = "ttir.matmul"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor, tensor<64xbf16>, tensor<1xbf16>) -> tensor<1xbf16> + %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor, tensor<64xbf16>, tensor<1xbf16>) -> tensor<1xbf16> return %1 : tensor<1xbf16> } } @@ -18,19 +17,19 @@ module attributes {} { func.func @matmul_negative_1d_1d_inner_dimension_missmatch(%arg0: tensor<128xbf16>, %arg1: tensor) -> tensor<1xbf16> { // CHECK: error: 'ttir.matmul' op Input B must be at least a 1D tensor %0 = tensor.empty() : tensor<1xbf16> - %1 = "ttir.matmul"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<128xbf16>, tensor, tensor<1xbf16>) -> tensor<1xbf16> + %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<128xbf16>, tensor, tensor<1xbf16>) -> tensor<1xbf16> return %1 : tensor<1xbf16> } } -// Verifty that the parsing fails if the output is a scalar +// Verify that the parsing fails if the output is a scalar // ----- #any_device_tile = #tt.operand_constraint module attributes {} { func.func @matmul_negative_1d_1d_inner_dimension_missmatch(%arg0: tensor<128xbf16>, %arg1: tensor<128xbf16>) -> tensor { // CHECK: error: 'ttir.matmul' op Scalar output is not supported, output must be at least a 1D tensor %0 = tensor.empty() : tensor - %1 = "ttir.matmul"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<128xbf16>, tensor<128xbf16>, tensor) -> tensor + %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<128xbf16>, 
tensor<128xbf16>, tensor) -> tensor return %1 : tensor } } @@ -41,7 +40,7 @@ module attributes {} { func.func @matmul_negative_1d_1d_inner_dimension_missmatch(%arg0: tensor<128xbf16>, %arg1: tensor<128xbf16>) -> tensor<2xbf16> { // CHECK: error: 'ttir.matmul' op Scalar output must be a 1D tensor of size 1 %0 = tensor.empty() : tensor<2xbf16> - %1 = "ttir.matmul"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<128xbf16>, tensor<128xbf16>, tensor<2xbf16>) -> tensor<2xbf16> + %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<128xbf16>, tensor<128xbf16>, tensor<2xbf16>) -> tensor<2xbf16> return %1 : tensor<2xbf16> } } @@ -53,7 +52,7 @@ module attributes {} { func.func @matmul_negative_1d_1d_inner_dimension_missmatch(%arg0: tensor<128xbf16>, %arg1: tensor<64xbf16>) -> tensor<1xbf16> { // CHECK: error: 'ttir.matmul' op Input A[-1](128) and B[-2](64) must have matching inner dimensions %0 = tensor.empty() : tensor<1xbf16> - %1 = "ttir.matmul"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<128xbf16>, tensor<64xbf16>, tensor<1xbf16>) -> tensor<1xbf16> + %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<128xbf16>, tensor<64xbf16>, tensor<1xbf16>) -> tensor<1xbf16> return %1 : tensor<1xbf16> } } @@ -64,7 +63,7 @@ module attributes {} { func.func @matmul_negative_1d_2d_inner_dimension_missmatch(%arg0: tensor<64xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64xbf16> { // CHECK: error: 'ttir.matmul' op Input A[-1](64) and B[-2](128) must have matching inner dimensions %0 = tensor.empty() : tensor<64xbf16> - %1 = "ttir.matmul"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64xbf16>, tensor<128x64xbf16>, tensor<64xbf16>) -> tensor<64xbf16> + %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<64xbf16>, tensor<128x64xbf16>, tensor<64xbf16>) -> tensor<64xbf16> return %1 : tensor<64xbf16> } } @@ -75,7 +74,7 @@ module attributes {} { func.func @matmul_negative_2d_1d_inner_dimension_missmatch(%arg0: tensor<64x128xbf16>, %arg1: tensor<64xbf16>) -> tensor<64xbf16> { // CHECK: error: 'ttir.matmul' op Input A[-1](128) and B[-2](64) must have matching inner dimensions %0 = tensor.empty() : tensor<64xbf16> - %1 = "ttir.matmul"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<64xbf16>, tensor<64xbf16>) -> tensor<64xbf16> + %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<64xbf16>, tensor<64xbf16>) -> tensor<64xbf16> return %1 : tensor<64xbf16> } } @@ -86,7 +85,7 @@ module attributes {} { func.func @matmul_negative_2d_2d_inner_dimension_missmatch(%arg0: tensor<64x128xbf16>, %arg1: tensor<64x128xbf16>) -> tensor<64x64xbf16> { // CHECK: error: 'ttir.matmul' op Input A[-1](128) and B[-2](64) must have matching inner dimensions %0 = tensor.empty() : tensor<64x64xbf16> - %1 = "ttir.matmul"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> + %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> return %1 : tensor<64x64xbf16> } } @@ -97,7 +96,7 @@ module attributes {} { func.func @matmul_negative_nd_nd_inner_dimension_missmatch(%arg0: tensor<7x64x128xbf16>, %arg1: tensor<1x64x128xbf16>) -> tensor<7x64x64xbf16> { // CHECK: error: 'ttir.matmul' op Input 
A[-1](128) and B[-2](64) must have matching inner dimensions %0 = tensor.empty() : tensor<7x64x64xbf16> - %1 = "ttir.matmul"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<7x64x128xbf16>, tensor<1x64x128xbf16>, tensor<7x64x64xbf16>) -> tensor<7x64x64xbf16> + %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<7x64x128xbf16>, tensor<1x64x128xbf16>, tensor<7x64x64xbf16>) -> tensor<7x64x64xbf16> return %1 : tensor<7x64x64xbf16> } } @@ -109,7 +108,7 @@ module attributes {} { func.func @matmul_negative_nd_nd_same_rank_batch_broadcast_incompatible_1(%arg0: tensor<7x64x128xbf16>, %arg1: tensor<2x128x64xbf16>) -> tensor<7x64x64xbf16> { // CHECK: error: 'ttir.matmul' op Batch dimensions of input A(7) and B(2) are not broadcast compatible %0 = tensor.empty() : tensor<7x64x64xbf16> - %1 = "ttir.matmul"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<7x64x128xbf16>, tensor<2x128x64xbf16>, tensor<7x64x64xbf16>) -> tensor<7x64x64xbf16> + %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<7x64x128xbf16>, tensor<2x128x64xbf16>, tensor<7x64x64xbf16>) -> tensor<7x64x64xbf16> return %1 : tensor<7x64x64xbf16> } } @@ -120,7 +119,7 @@ module attributes {} { func.func @matmul_negative_nd_nd_same_rank_batch_broadcast_incompatible_2(%arg0: tensor<2x7x64x128xbf16>, %arg1: tensor<7x1x128x64xbf16>) -> tensor<7x7x64x64xbf16> { // CHECK: error: 'ttir.matmul' op Batch dimensions of input A(2,7) and B(7,1) are not broadcast compatible %0 = tensor.empty() : tensor<7x64x64xbf16> - %1 = "ttir.matmul"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<2x7x64x128xbf16>, tensor<7x1x128x64xbf16>, tensor<7x64x64xbf16>) -> tensor<7x7x64x64xbf16> + %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<2x7x64x128xbf16>, tensor<7x1x128x64xbf16>, tensor<7x64x64xbf16>) -> tensor<7x7x64x64xbf16> return %1 : tensor<7x7x64x64xbf16> } } @@ -131,7 +130,7 @@ module attributes {} { func.func @matmul_negative_nd_nd_different_rank_batch_broadcast_incompatible(%arg0: tensor<12x2x7x64x128xbf16>, %arg1: tensor<7x1x128x64xbf16>) -> tensor<12x7x7x64x64xbf16> { // CHECK: error: 'ttir.matmul' op Batch dimensions of input A(12,2,7) and B(7,1) are not broadcast compatible %0 = tensor.empty() : tensor<12x7x7x64x64xbf16> - %1 = "ttir.matmul"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<12x2x7x64x128xbf16>, tensor<7x1x128x64xbf16>, tensor<12x7x7x64x64xbf16>) -> tensor<12x7x7x64x64xbf16> + %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<12x2x7x64x128xbf16>, tensor<7x1x128x64xbf16>, tensor<12x7x7x64x64xbf16>) -> tensor<12x7x7x64x64xbf16> return %1 : tensor<12x7x7x64x64xbf16> } } @@ -143,7 +142,7 @@ module attributes {} { func.func @matmul_negative_2d_2d_inner_dimension_missmatch(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64xbf16> { // CHECK: error: 'ttir.matmul' op Output shape rank(1) must match the expected output shape rank(2) %0 = tensor.empty() : tensor<64xbf16> - %1 = "ttir.matmul"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64xbf16>) -> tensor<64xbf16> + %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64xbf16>) -> tensor<64xbf16> return %1 : tensor<64xbf16> } } @@ -154,7 +153,7 @@ module attributes {} { func.func 
@matmul_negative_2d_2d_inner_dimension_missmatch(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64x128xbf16> { // CHECK: error: 'ttir.matmul' op Output shape dimension[1](128) doesn't match the expected output shape dimension[1](64) %0 = tensor.empty() : tensor<64x128xbf16> - %1 = "ttir.matmul"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16> + %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16> return %1 : tensor<64x128xbf16> } } diff --git a/test/ttmlir/Dialect/TTIR/matmul/matmul_tests_positive.mlir b/test/ttmlir/Dialect/TTIR/matmul/matmul_tests_positive.mlir index 3823edbc3..cfc77c0fb 100644 --- a/test/ttmlir/Dialect/TTIR/matmul/matmul_tests_positive.mlir +++ b/test/ttmlir/Dialect/TTIR/matmul/matmul_tests_positive.mlir @@ -1,59 +1,58 @@ // RUN: ttmlir-opt %s | FileCheck %s -#any_device_tile = #tt.operand_constraint module attributes {} { func.func @matmul_1d_1d(%arg0: tensor<128xbf16>, %arg1: tensor<128xbf16>) -> tensor<1xbf16> { %0 = tensor.empty() : tensor<1xbf16> // CHECK: %[[C:.*]] = "ttir.matmul"[[C:.*]] - %1 = "ttir.matmul"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<128xbf16>, tensor<128xbf16>, tensor<1xbf16>) -> tensor<1xbf16> + %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<128xbf16>, tensor<128xbf16>, tensor<1xbf16>) -> tensor<1xbf16> return %1 : tensor<1xbf16> } func.func @matmul_1d_2d(%arg0: tensor<128xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64xbf16> { %0 = tensor.empty() : tensor<64xbf16> // CHECK: %[[C:.*]] = "ttir.matmul"[[C:.*]] - %1 = "ttir.matmul"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<128xbf16>, tensor<128x64xbf16>, tensor<64xbf16>) -> tensor<64xbf16> + %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<128xbf16>, tensor<128x64xbf16>, tensor<64xbf16>) -> tensor<64xbf16> return %1 : tensor<64xbf16> } func.func @matmul_2d_1d(%arg0: tensor<64x128xbf16>, %arg1: tensor<128xbf16>) -> tensor<64xbf16> { %0 = tensor.empty() : tensor<64xbf16> // CHECK: %[[C:.*]] = "ttir.matmul"[[C:.*]] - %1 = "ttir.matmul"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128xbf16>, tensor<64xbf16>) -> tensor<64xbf16> + %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<128xbf16>, tensor<64xbf16>) -> tensor<64xbf16> return %1 : tensor<64xbf16> } func.func @matmul_2d_2d(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64x64xbf16> { %0 = tensor.empty() : tensor<64x64xbf16> // CHECK: %[[C:.*]] = "ttir.matmul"[[C:.*]] - %1 = "ttir.matmul"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> + %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> return %1 : tensor<64x64xbf16> } func.func @matmul_1d_nd(%arg0: tensor<128xbf16>, %arg1: tensor<12x7x128x64xbf16>) -> tensor<12x7x64xbf16> { %0 = tensor.empty() : tensor<12x7x64xbf16> // CHECK: %[[C:.*]] = "ttir.matmul"[[C:.*]] - %1 = "ttir.matmul"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<128xbf16>, 
tensor<12x7x128x64xbf16>, tensor<12x7x64xbf16>) -> tensor<12x7x64xbf16> + %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<128xbf16>, tensor<12x7x128x64xbf16>, tensor<12x7x64xbf16>) -> tensor<12x7x64xbf16> return %1 : tensor<12x7x64xbf16> } func.func @matmul_nd_1d(%arg0: tensor<12x7x128x64xbf16>, %arg1: tensor<64xbf16>) -> tensor<12x7x128xbf16> { %0 = tensor.empty() : tensor<12x7x128xbf16> // CHECK: %[[C:.*]] = "ttir.matmul"[[C:.*]] - %1 = "ttir.matmul"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<12x7x128x64xbf16>, tensor<64xbf16>, tensor<12x7x128xbf16>) -> tensor<12x7x128xbf16> + %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<12x7x128x64xbf16>, tensor<64xbf16>, tensor<12x7x128xbf16>) -> tensor<12x7x128xbf16> return %1 : tensor<12x7x128xbf16> } func.func @matmul_2d_nd(%arg0: tensor<64x128xbf16>, %arg1: tensor<12x7x128x64xbf16>) -> tensor<12x7x64x64xbf16> { %0 = tensor.empty() : tensor<12x7x64x64xbf16> // CHECK: %[[C:.*]] = "ttir.matmul"[[C:.*]] - %1 = "ttir.matmul"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<12x7x128x64xbf16>, tensor<12x7x64x64xbf16>) -> tensor<12x7x64x64xbf16> + %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<12x7x128x64xbf16>, tensor<12x7x64x64xbf16>) -> tensor<12x7x64x64xbf16> return %1 : tensor<12x7x64x64xbf16> } func.func @matmul_nd_2d(%arg0: tensor<12x7x128x64xbf16>, %arg1: tensor<64x128xbf16>) -> tensor<12x7x128x128xbf16> { %0 = tensor.empty() : tensor<12x7x128x128xbf16> // CHECK: %[[C:.*]] = "ttir.matmul"[[C:.*]] - %1 = "ttir.matmul"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<12x7x128x64xbf16>, tensor<64x128xbf16>, tensor<12x7x128x128xbf16>) -> tensor<12x7x128x128xbf16> + %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<12x7x128x64xbf16>, tensor<64x128xbf16>, tensor<12x7x128x128xbf16>) -> tensor<12x7x128x128xbf16> return %1 : tensor<12x7x128x128xbf16> } @@ -61,28 +60,28 @@ module attributes {} { func.func @matmul_nd_nd_same_rank_same_dims(%arg0: tensor<7x64x128xbf16>, %arg1: tensor<7x128x64xbf16>) -> tensor<7x64x64xbf16> { %0 = tensor.empty() : tensor<7x64x64xbf16> // CHECK: %[[C:.*]] = "ttir.matmul"[[C:.*]] - %1 = "ttir.matmul"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<7x64x128xbf16>, tensor<7x128x64xbf16>, tensor<7x64x64xbf16>) -> tensor<7x64x64xbf16> + %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<7x64x128xbf16>, tensor<7x128x64xbf16>, tensor<7x64x64xbf16>) -> tensor<7x64x64xbf16> return %1 : tensor<7x64x64xbf16> } func.func @matmul_nd_nd_same_rank_broadcastable_dims_1(%arg0: tensor<7x64x128xbf16>, %arg1: tensor<1x128x64xbf16>) -> tensor<7x64x64xbf16> { %0 = tensor.empty() : tensor<7x64x64xbf16> // CHECK: %[[C:.*]] = "ttir.matmul"[[C:.*]] - %1 = "ttir.matmul"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<7x64x128xbf16>, tensor<1x128x64xbf16>, tensor<7x64x64xbf16>) -> tensor<7x64x64xbf16> + %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<7x64x128xbf16>, tensor<1x128x64xbf16>, tensor<7x64x64xbf16>) -> tensor<7x64x64xbf16> return %1 : tensor<7x64x64xbf16> } func.func @matmul_nd_nd_same_rank_broadcastable_dims_2(%arg0: tensor<1x7x64x128xbf16>, %arg1: tensor<7x1x128x64xbf16>) -> tensor<7x7x64x64xbf16> { %0 = tensor.empty() : tensor<7x7x64x64xbf16> // CHECK: %[[C:.*]] = "ttir.matmul"[[C:.*]] - %1 = 
"ttir.matmul"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<1x7x64x128xbf16>, tensor<7x1x128x64xbf16>, tensor<7x7x64x64xbf16>) -> tensor<7x7x64x64xbf16> + %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<1x7x64x128xbf16>, tensor<7x1x128x64xbf16>, tensor<7x7x64x64xbf16>) -> tensor<7x7x64x64xbf16> return %1 : tensor<7x7x64x64xbf16> } func.func @matmul_nd_nd_different_rank_broadcastable_dims_2(%arg0: tensor<12x1x7x64x128xbf16>, %arg1: tensor<7x1x128x64xbf16>) -> tensor<12x7x7x64x64xbf16> { %0 = tensor.empty() : tensor<12x7x7x64x64xbf16> // CHECK: %[[C:.*]] = "ttir.matmul"[[C:.*]] - %1 = "ttir.matmul"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<12x1x7x64x128xbf16>, tensor<7x1x128x64xbf16>, tensor<12x7x7x64x64xbf16>) -> tensor<12x7x7x64x64xbf16> + %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<12x1x7x64x128xbf16>, tensor<7x1x128x64xbf16>, tensor<12x7x7x64x64xbf16>) -> tensor<12x7x7x64x64xbf16> return %1 : tensor<12x7x7x64x64xbf16> } } diff --git a/test/ttmlir/Dialect/TTIR/select/select_tests_negative.mlir b/test/ttmlir/Dialect/TTIR/select/select_tests_negative.mlir index f505bfcb7..9ac44f4f2 100644 --- a/test/ttmlir/Dialect/TTIR/select/select_tests_negative.mlir +++ b/test/ttmlir/Dialect/TTIR/select/select_tests_negative.mlir @@ -1,11 +1,10 @@ // RUN: not ttmlir-opt --split-input-file %s 2>&1 | FileCheck %s -#any_device_tile = #tt.operand_constraint module attributes {} { func.func @select_negative_invalid_dim(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> { %0 = tensor.empty() : tensor<4x4xf32> // CHECK: {{.*error.*Invalid dimension}} - %1 = "ttir.select"(%arg0, %0) <{dim = -3: si32, begin = 0: si32, length = 4: si32, stride = 4: si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : + %1 = "ttir.select"(%arg0, %0) <{dim = -3: si32, begin = 0: si32, length = 4: si32, stride = 4: si32}> : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> return %1 : tensor<4x4xf32> } @@ -13,12 +12,11 @@ module attributes {} { // ----- -#any_device_tile = #tt.operand_constraint module attributes {} { func.func @select_negative_invalid_stride(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> { %0 = tensor.empty() : tensor<4x4xf32> // CHECK: {{.*error.*Invalid stride.*}} - %1 = "ttir.select"(%arg0, %0) <{dim = 1: si32, begin = 0: si32, length = 4: si32, stride = 7: si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : + %1 = "ttir.select"(%arg0, %0) <{dim = 1: si32, begin = 0: si32, length = 4: si32, stride = 7: si32}> : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> return %1 : tensor<4x4xf32> } @@ -26,12 +24,11 @@ module attributes {} { // ----- -#any_device_tile = #tt.operand_constraint module attributes {} { func.func @select_negative_invalid_stride_2(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> { %0 = tensor.empty() : tensor<4x4xf32> // CHECK: {{.*error.*Invalid stride.*}} - %1 = "ttir.select"(%arg0, %0) <{dim = 1: si32, begin = 0: si32, length = 4: si32, stride = -1: si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : + %1 = "ttir.select"(%arg0, %0) <{dim = 1: si32, begin = 0: si32, length = 4: si32, stride = -1: si32}> : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> return %1 : tensor<4x4xf32> } @@ -39,12 +36,11 @@ module attributes {} { // ----- -#any_device_tile = #tt.operand_constraint module attributes {} { func.func @select_negative_invalid_begin(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> { %0 = tensor.empty() : tensor<4x4xf32> 
// CHECK: {{.*error.*Invalid begin index.*}} - %1 = "ttir.select"(%arg0, %0) <{dim = 1: si32, begin = -3: si32, length = 4: si32, stride = 1: si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : + %1 = "ttir.select"(%arg0, %0) <{dim = 1: si32, begin = -3: si32, length = 4: si32, stride = 1: si32}> : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> return %1 : tensor<4x4xf32> } @@ -52,12 +48,11 @@ module attributes {} { // ----- -#any_device_tile = #tt.operand_constraint module attributes {} { func.func @select_negative_invalid_begin_2(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> { %0 = tensor.empty() : tensor<4x4xf32> // CHECK: {{.*error.*Invalid begin index.*}} - %1 = "ttir.select"(%arg0, %0) <{dim = 1: si32, begin = 4: si32, length = 4: si32, stride = 1: si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : + %1 = "ttir.select"(%arg0, %0) <{dim = 1: si32, begin = 4: si32, length = 4: si32, stride = 1: si32}> : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> return %1 : tensor<4x4xf32> } @@ -65,12 +60,11 @@ module attributes {} { // ----- -#any_device_tile = #tt.operand_constraint module attributes {} { func.func @select_negative_invalid_length(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> { %0 = tensor.empty() : tensor<4x4xf32> // CHECK: {{.*error.*Invalid length.*}} - %1 = "ttir.select"(%arg0, %0) <{dim = 1: si32, begin = 0: si32, length = 5: si32, stride = 1: si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : + %1 = "ttir.select"(%arg0, %0) <{dim = 1: si32, begin = 0: si32, length = 5: si32, stride = 1: si32}> : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> return %1 : tensor<4x4xf32> } @@ -78,12 +72,11 @@ module attributes {} { // ----- -#any_device_tile = #tt.operand_constraint module attributes {} { func.func @select_negative_invalid_length_2(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> { %0 = tensor.empty() : tensor<4x4xf32> // CHECK: {{.*error.*Invalid length.*}} - %1 = "ttir.select"(%arg0, %0) <{dim = 1: si32, begin = 0: si32, length = 0: si32, stride = 1: si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : + %1 = "ttir.select"(%arg0, %0) <{dim = 1: si32, begin = 0: si32, length = 0: si32, stride = 1: si32}> : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> return %1 : tensor<4x4xf32> } @@ -91,12 +84,11 @@ module attributes {} { // ----- -#any_device_tile = #tt.operand_constraint module attributes {} { func.func @select_negative_invalid_length_3(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> { %0 = tensor.empty() : tensor<4x4xf32> // CHECK: {{.*error.*Invalid length.*}} - %1 = "ttir.select"(%arg0, %0) <{dim = 1: si32, begin = 0: si32, length = 2: si32, stride = 1: si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : + %1 = "ttir.select"(%arg0, %0) <{dim = 1: si32, begin = 0: si32, length = 2: si32, stride = 1: si32}> : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> return %1 : tensor<4x4xf32> } @@ -104,12 +96,11 @@ module attributes {} { // ----- -#any_device_tile = #tt.operand_constraint module attributes {} { func.func @select_negative_invalid_total_size(%arg0: tensor<4x2x64x48xf32>) -> tensor<4x2x4x48xf32> { %0 = tensor.empty() : tensor<4x2x4x48xf32> // CHECK: {{.*error.*Sum of all slices.*}} - %1 = "ttir.select"( %arg0, %0) <{dim = 2: si32, begin = 0: si32, length = 4: si32, stride = 4: si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : + %1 = "ttir.select"( %arg0, %0) <{dim = 2: si32, begin = 0: si32, length = 4: si32, stride = 4: si32}> : 
(tensor<4x2x64x48xf32>, tensor<4x2x4x48xf32>) -> tensor<4x2x4x48xf32> return %1 : tensor<4x2x4x48xf32> } diff --git a/test/ttmlir/Dialect/TTIR/select/select_tests_positive.mlir b/test/ttmlir/Dialect/TTIR/select/select_tests_positive.mlir index b613c85bf..87be5c7cc 100644 --- a/test/ttmlir/Dialect/TTIR/select/select_tests_positive.mlir +++ b/test/ttmlir/Dialect/TTIR/select/select_tests_positive.mlir @@ -1,11 +1,10 @@ // RUN: ttmlir-opt %s | FileCheck %s -#any_device_tile = #tt.operand_constraint module attributes {} { func.func @select_identity(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> { %0 = tensor.empty() : tensor<4x4xf32> // CHECK: %{{[0-9]+}} = "ttir.select" - %1 = "ttir.select"(%arg0, %0) <{dim = 1: si32, begin = 0: si32, length = 4: si32, stride = 4: si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : + %1 = "ttir.select"(%arg0, %0) <{dim = 1: si32, begin = 0: si32, length = 4: si32, stride = 4: si32}> : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> return %1 : tensor<4x4xf32> } @@ -13,7 +12,7 @@ module attributes {} { func.func @select_half(%arg0: tensor<4x4xf32>) -> tensor<4x2xf32> { %0 = tensor.empty() : tensor<4x2xf32> // CHECK: %{{[0-9]+}} = "ttir.select" - %1 = "ttir.select"(%arg0, %0) <{dim = 1: si32, begin = 0: si32, length = 2: si32, stride = 4: si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : + %1 = "ttir.select"(%arg0, %0) <{dim = 1: si32, begin = 0: si32, length = 2: si32, stride = 4: si32}> : (tensor<4x4xf32>, tensor<4x2xf32>) -> tensor<4x2xf32> return %1 : tensor<4x2xf32> } @@ -21,7 +20,7 @@ module attributes {} { func.func @select_single(%arg0: tensor<4x4xf32>) -> tensor<4x1xf32> { %0 = tensor.empty() : tensor<4x1xf32> // CHECK: %{{[0-9]+}} = "ttir.select" - %1 = "ttir.select"(%arg0, %0) <{dim = 1: si32, begin = 3: si32, length = 1: si32, stride = 1: si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : + %1 = "ttir.select"(%arg0, %0) <{dim = 1: si32, begin = 3: si32, length = 1: si32, stride = 1: si32}> : (tensor<4x4xf32>, tensor<4x1xf32>) -> tensor<4x1xf32> return %1 : tensor<4x1xf32> } @@ -29,7 +28,7 @@ module attributes {} { func.func @select_half_2_no_stride(%arg0: tensor<4x4xf32>) -> tensor<4x2xf32> { %0 = tensor.empty() : tensor<4x2xf32> // CHECK: %{{[0-9]+}} = "ttir.select" - %1 = "ttir.select"(%arg0, %0) <{dim = 1: si32, begin = 2: si32, length = 2: si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : + %1 = "ttir.select"(%arg0, %0) <{dim = 1: si32, begin = 2: si32, length = 2: si32}> : (tensor<4x4xf32>, tensor<4x2xf32>) -> tensor<4x2xf32> return %1 : tensor<4x2xf32> } @@ -37,7 +36,7 @@ module attributes {} { func.func @select_neg_dim(%arg0: tensor<10x3x128x64xf32>) -> tensor<10x3x8x64xf32> { %0 = tensor.empty() : tensor<10x3x8x64xf32> // CHECK: %{{[0-9]+}} = "ttir.select" - %1 = "ttir.select"(%arg0, %0) <{dim = -2: si32, begin = 0: si32, length = 2: si32, stride = 32: si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : + %1 = "ttir.select"(%arg0, %0) <{dim = -2: si32, begin = 0: si32, length = 2: si32, stride = 32: si32}> : (tensor<10x3x128x64xf32>, tensor<10x3x8x64xf32>) -> tensor<10x3x8x64xf32> return %1 : tensor<10x3x8x64xf32> } diff --git a/test/ttmlir/Dialect/TTIR/slice/slice_tests_negative.mlir b/test/ttmlir/Dialect/TTIR/slice/slice_tests_negative.mlir index ba640992f..db444258e 100644 --- a/test/ttmlir/Dialect/TTIR/slice/slice_tests_negative.mlir +++ b/test/ttmlir/Dialect/TTIR/slice/slice_tests_negative.mlir @@ -2,12 +2,11 @@ // Negative tests for slice 
operation // Verify that the parsing fails if the begins attribute is not a 3D tensor -#any_device_tile = #tt.operand_constraint module attributes {} { func.func @slice_negative_invalid_shape(%arg0: tensor) -> tensor<1xbf16> { %0 = tensor.empty() : tensor<1xbf16> // CHECK: error: 'ttir.slice' op Input must be at least a 1D tensor - %1 = "ttir.slice"(%arg0, %0) <{begins = [0: i32], ends = [0: i32], step = [1: i32], operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor, tensor<1xbf16>) -> tensor<1xbf16> + %1 = "ttir.slice"(%arg0, %0) <{begins = [0: i32], ends = [0: i32], step = [1: i32]}> : (tensor, tensor<1xbf16>) -> tensor<1xbf16> return %1 : tensor<1xbf16> } } @@ -19,7 +18,7 @@ module attributes {} { func.func @slice_negative_invalid_begins(%arg0: tensor<3x128x64xbf16>) -> tensor<1x64x64xbf16> { %0 = tensor.empty() : tensor<1x64x64xbf16> // CHECK: error: 'ttir.slice' op Begins, ends, and step attributes must have the same number of elements as the input tensor rank - %1 = "ttir.slice"(%arg0, %0) <{begins = [0: i32, 0: i32], ends = [0: i32, 63: i32, 63: i32], step = [1: i32, 1: i32, 1: i32], operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<3x128x64xbf16>, tensor<1x64x64xbf16>) -> tensor<1x64x64xbf16> + %1 = "ttir.slice"(%arg0, %0) <{begins = [0: i32, 0: i32], ends = [0: i32, 63: i32, 63: i32], step = [1: i32, 1: i32, 1: i32]}> : (tensor<3x128x64xbf16>, tensor<1x64x64xbf16>) -> tensor<1x64x64xbf16> return %1 : tensor<1x64x64xbf16> } } @@ -31,7 +30,7 @@ module attributes {} { func.func @slice_negative_invalid_ends(%arg0: tensor<3x128x64xbf16>) -> tensor<1x64x64xbf16> { %0 = tensor.empty() : tensor<1x64x64xbf16> // CHECK: error: 'ttir.slice' op Begins, ends, and step attributes must have the same number of elements as the input tensor rank - %1 = "ttir.slice"(%arg0, %0) <{begins = [0: i32, 0: i32, 0: i32], ends = [0: i32, 63: i32], step = [1: i32, 1: i32, 1: i32], operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<3x128x64xbf16>, tensor<1x64x64xbf16>) -> tensor<1x64x64xbf16> + %1 = "ttir.slice"(%arg0, %0) <{begins = [0: i32, 0: i32, 0: i32], ends = [0: i32, 63: i32], step = [1: i32, 1: i32, 1: i32]}> : (tensor<3x128x64xbf16>, tensor<1x64x64xbf16>) -> tensor<1x64x64xbf16> return %1 : tensor<1x64x64xbf16> } } @@ -43,7 +42,7 @@ module attributes {} { func.func @slice_negative_invalid_step(%arg0: tensor<3x128x64xbf16>) -> tensor<1x64x64xbf16> { %0 = tensor.empty() : tensor<1x64x64xbf16> // CHECK: error: 'ttir.slice' op Begins, ends, and step attributes must have the same number of elements as the input tensor rank - %1 = "ttir.slice"(%arg0, %0) <{begins = [0: i32, 0: i32, 0: i32], ends = [0: i32, 63: i32, 63: i32], step = [1: i32, 1: i32], operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<3x128x64xbf16>, tensor<1x64x64xbf16>) -> tensor<1x64x64xbf16> + %1 = "ttir.slice"(%arg0, %0) <{begins = [0: i32, 0: i32, 0: i32], ends = [0: i32, 63: i32, 63: i32], step = [1: i32, 1: i32]}> : (tensor<3x128x64xbf16>, tensor<1x64x64xbf16>) -> tensor<1x64x64xbf16> return %1 : tensor<1x64x64xbf16> } } @@ -55,7 +54,7 @@ module attributes {} { func.func @slice_negative_invalid_output_datatype(%arg0: tensor<3x128x64xbf16>) -> tensor<1x64x64xf32> { %0 = tensor.empty() : tensor<1x64x64xf32> // CHECK: error: 'ttir.slice' op Output tensor must have the same element type as the input tensor - %1 = "ttir.slice"(%arg0, %0) <{begins = [0: i32, 0: i32, 0: i32], ends = [0: i32, 63: i32, 63: i32], step = [1: i32, 1: i32, 1: i32], operand_constraints 
= [#any_device_tile, #any_device_tile]}> : (tensor<3x128x64xbf16>, tensor<1x64x64xf32>) -> tensor<1x64x64xf32> + %1 = "ttir.slice"(%arg0, %0) <{begins = [0: i32, 0: i32, 0: i32], ends = [0: i32, 63: i32, 63: i32], step = [1: i32, 1: i32, 1: i32]}> : (tensor<3x128x64xbf16>, tensor<1x64x64xf32>) -> tensor<1x64x64xf32> return %1 : tensor<1x64x64xf32> } } @@ -67,7 +66,7 @@ module attributes {} { func.func @slice_negative_input_output_rank_missmatch(%arg0: tensor<3x128x64xbf16>) -> tensor<1x1x64x64xbf16> { %0 = tensor.empty() : tensor<1x1x64x64xbf16> // CHECK: error: 'ttir.slice' op Output tensor must have the same rank as the input tensor - %1 = "ttir.slice"(%arg0, %0) <{begins = [0: i32, 0: i32, 0: i32], ends = [0: i32, 63: i32, 63: i32], step = [1: i32, 1: i32, 1: i32], operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<3x128x64xbf16>, tensor<1x1x64x64xbf16>) -> tensor<1x1x64x64xbf16> + %1 = "ttir.slice"(%arg0, %0) <{begins = [0: i32, 0: i32, 0: i32], ends = [0: i32, 63: i32, 63: i32], step = [1: i32, 1: i32, 1: i32]}> : (tensor<3x128x64xbf16>, tensor<1x1x64x64xbf16>) -> tensor<1x1x64x64xbf16> return %1 : tensor<1x1x64x64xbf16> } } @@ -79,7 +78,7 @@ module attributes {} { func.func @slice_negative_invalid_begin_positive(%arg0: tensor<10x3x128x64xbf16>) -> tensor<4x1x16x8xbf16> { %0 = tensor.empty() : tensor<4x1x16x8xbf16> // CHECK: error: 'ttir.slice' op Invalid begin index for dimension 2. Expected value in range [-128, 128), got 128. Input shape: (10, 3, 128, 64) - %1 = "ttir.slice"(%arg0, %0) <{begins = [0: i32, 0: i32, 128: i32, 32: i32], ends = [10: i32, 3: i32, 128: i32, 64: i32], step = [3: i32, 3: i32, 8: i32, 4: i32], operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<10x3x128x64xbf16>, tensor<4x1x16x8xbf16>) -> tensor<4x1x16x8xbf16> + %1 = "ttir.slice"(%arg0, %0) <{begins = [0: i32, 0: i32, 128: i32, 32: i32], ends = [10: i32, 3: i32, 128: i32, 64: i32], step = [3: i32, 3: i32, 8: i32, 4: i32]}> : (tensor<10x3x128x64xbf16>, tensor<4x1x16x8xbf16>) -> tensor<4x1x16x8xbf16> return %1 : tensor<4x1x16x8xbf16> } } @@ -91,7 +90,7 @@ module attributes {} { func.func @slice_negative_invalid_begin_negative(%arg0: tensor<10x3x128x64xbf16>) -> tensor<4x1x16x8xbf16> { %0 = tensor.empty() : tensor<4x1x16x8xbf16> // CHECK: error: 'ttir.slice' op Invalid begin index for dimension 2. Expected value in range [-128, 128), got -129. Input shape: (10, 3, 128, 64) - %1 = "ttir.slice"(%arg0, %0) <{begins = [0: i32, 0: i32, -129: i32, 32: i32], ends = [10: i32, 3: i32, 128: i32, 64: i32], step = [3: i32, 3: i32, 8: i32, 4: i32], operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<10x3x128x64xbf16>, tensor<4x1x16x8xbf16>) -> tensor<4x1x16x8xbf16> + %1 = "ttir.slice"(%arg0, %0) <{begins = [0: i32, 0: i32, -129: i32, 32: i32], ends = [10: i32, 3: i32, 128: i32, 64: i32], step = [3: i32, 3: i32, 8: i32, 4: i32]}> : (tensor<10x3x128x64xbf16>, tensor<4x1x16x8xbf16>) -> tensor<4x1x16x8xbf16> return %1 : tensor<4x1x16x8xbf16> } } @@ -103,7 +102,7 @@ module attributes {} { func.func @slice_negative_invalid_end_positive(%arg0: tensor<10x3x128x64xbf16>) -> tensor<4x1x16x8xbf16> { %0 = tensor.empty() : tensor<4x1x16x8xbf16> // CHECK: error: 'ttir.slice' op Invalid end index for dimension 3. Expected value in range [-64, 64], got 65. 
Input shape: (10, 3, 128, 64) - %1 = "ttir.slice"(%arg0, %0) <{begins = [0: i32, 0: i32, 0: i32, 32: i32], ends = [10: i32, 3: i32, 128: i32, 65: i32], step = [3: i32, 3: i32, 8: i32, 4: i32], operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<10x3x128x64xbf16>, tensor<4x1x16x8xbf16>) -> tensor<4x1x16x8xbf16> + %1 = "ttir.slice"(%arg0, %0) <{begins = [0: i32, 0: i32, 0: i32, 32: i32], ends = [10: i32, 3: i32, 128: i32, 65: i32], step = [3: i32, 3: i32, 8: i32, 4: i32]}> : (tensor<10x3x128x64xbf16>, tensor<4x1x16x8xbf16>) -> tensor<4x1x16x8xbf16> return %1 : tensor<4x1x16x8xbf16> } } @@ -115,7 +114,7 @@ module attributes {} { func.func @slice_negative_invalid_end_negative(%arg0: tensor<10x3x128x64xbf16>) -> tensor<4x1x16x8xbf16> { %0 = tensor.empty() : tensor<4x1x16x8xbf16> // CHECK: error: 'ttir.slice' op Invalid end index for dimension 3. Expected value in range [-64, 64], got -65. Input shape: (10, 3, 128, 64) - %1 = "ttir.slice"(%arg0, %0) <{begins = [0: i32, 0: i32, 0: i32, 32: i32], ends = [10: i32, 3: i32, 128: i32, -65: i32], step = [3: i32, 3: i32, 8: i32, 4: i32], operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<10x3x128x64xbf16>, tensor<4x1x16x8xbf16>) -> tensor<4x1x16x8xbf16> + %1 = "ttir.slice"(%arg0, %0) <{begins = [0: i32, 0: i32, 0: i32, 32: i32], ends = [10: i32, 3: i32, 128: i32, -65: i32], step = [3: i32, 3: i32, 8: i32, 4: i32]}> : (tensor<10x3x128x64xbf16>, tensor<4x1x16x8xbf16>) -> tensor<4x1x16x8xbf16> return %1 : tensor<4x1x16x8xbf16> } } @@ -127,7 +126,7 @@ module attributes {} { func.func @slice_negative_step_is_zero(%arg0: tensor<10x3x128x64xbf16>) -> tensor<4x1x16x8xbf16> { %0 = tensor.empty() : tensor<4x1x16x8xbf16> // CHECK: error: 'ttir.slice' op Step value for dimension 3 cannot be zero - %1 = "ttir.slice"(%arg0, %0) <{begins = [0: i32, 0: i32, 0: i32, 32: i32], ends = [10: i32, 3: i32, 128: i32, 64: i32], step = [3: i32, 3: i32, 8: i32, 0: i32], operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<10x3x128x64xbf16>, tensor<4x1x16x8xbf16>) -> tensor<4x1x16x8xbf16> + %1 = "ttir.slice"(%arg0, %0) <{begins = [0: i32, 0: i32, 0: i32, 32: i32], ends = [10: i32, 3: i32, 128: i32, 64: i32], step = [3: i32, 3: i32, 8: i32, 0: i32]}> : (tensor<10x3x128x64xbf16>, tensor<4x1x16x8xbf16>) -> tensor<4x1x16x8xbf16> return %1 : tensor<4x1x16x8xbf16> } } @@ -139,7 +138,7 @@ module attributes {} { func.func @slice_negative_begin_greater_than_end_positive_step(%arg0: tensor<10x3x128x64xbf16>) -> tensor<4x1x16x8xbf16> { %0 = tensor.empty() : tensor<4x1x16x8xbf16> // CHECK: error: 'ttir.slice' op For positive step, begin index must be less than or equal to end index for dimension 0. 
Got begin: 9, end: 0, step: 3, input shape: (10, 3, 128, 64) - %1 = "ttir.slice"(%arg0, %0) <{begins = [9: i32, 0: i32, 0: i32, 32: i32], ends = [0: i32, 3: i32, 32: i32, 64: i32], step = [3: i32, 3: i32, 8: i32, 4: i32], operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<10x3x128x64xbf16>, tensor<4x1x16x8xbf16>) -> tensor<4x1x16x8xbf16> + %1 = "ttir.slice"(%arg0, %0) <{begins = [9: i32, 0: i32, 0: i32, 32: i32], ends = [0: i32, 3: i32, 32: i32, 64: i32], step = [3: i32, 3: i32, 8: i32, 4: i32]}> : (tensor<10x3x128x64xbf16>, tensor<4x1x16x8xbf16>) -> tensor<4x1x16x8xbf16> return %1 : tensor<4x1x16x8xbf16> } } @@ -150,7 +149,7 @@ module attributes {} { func.func @slice_negative_begin_greater_than_end_positive_step(%arg0: tensor<10x3x128x64xbf16>) -> tensor<4x1x8x8xbf16> { %0 = tensor.empty() : tensor<4x1x8x8xbf16> // CHECK: error: 'ttir.slice' op For positive step, begin index must be less than or equal to end index for dimension 2. Got begin: 96 (-32), end: 32 (-96), step: 8, input shape: (10, 3, 128, 64) - %1 = "ttir.slice"(%arg0, %0) <{begins = [0: i32, 0: i32, -32: i32, 32: i32], ends = [10: i32, 3: i32, -96: i32, 64: i32], step = [3: i32, 3: i32, 8: i32, 4: i32], operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<10x3x128x64xbf16>, tensor<4x1x8x8xbf16>) -> tensor<4x1x8x8xbf16> + %1 = "ttir.slice"(%arg0, %0) <{begins = [0: i32, 0: i32, -32: i32, 32: i32], ends = [10: i32, 3: i32, -96: i32, 64: i32], step = [3: i32, 3: i32, 8: i32, 4: i32]}> : (tensor<10x3x128x64xbf16>, tensor<4x1x8x8xbf16>) -> tensor<4x1x8x8xbf16> return %1 : tensor<4x1x8x8xbf16> } } @@ -162,7 +161,7 @@ module attributes {} { func.func @slice_negative_begin_less_than_end_negative_step(%arg0: tensor<10x3x128x64xbf16>) -> tensor<4x1x16x8xbf16> { %0 = tensor.empty() : tensor<4x1x16x8xbf16> // CHECK: error: 'ttir.slice' op For negative step, begin index must be greater than or equal to end index for dimension 1. Got begin: 0 (-3), end: 2 (-1), step: -3, input shape: (10, 3, 128, 64) - %1 = "ttir.slice"(%arg0, %0) <{begins = [0: i32, -3: i32, 0: i32, 32: i32], ends = [10: i32, -1: i32, 32: i32, 128: i32], step = [3: i32, -3: i32, 8: i32, 8: i32], operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<10x3x128x64xbf16>, tensor<4x1x16x8xbf16>) -> tensor<4x1x16x8xbf16> + %1 = "ttir.slice"(%arg0, %0) <{begins = [0: i32, -3: i32, 0: i32, 32: i32], ends = [10: i32, -1: i32, 32: i32, 128: i32], step = [3: i32, -3: i32, 8: i32, 8: i32]}> : (tensor<10x3x128x64xbf16>, tensor<4x1x16x8xbf16>) -> tensor<4x1x16x8xbf16> return %1 : tensor<4x1x16x8xbf16> } } @@ -173,7 +172,7 @@ module attributes {} { func.func @slice_negative_begin_less_than_end_negative_step(%arg0: tensor<10x3x128x64xbf16>) -> tensor<5x1x16x8xbf16> { %0 = tensor.empty() : tensor<5x1x16x8xbf16> // CHECK: error: 'ttir.slice' op For negative step, begin index must be greater than or equal to end index for dimension 0. 
Got begin: 0, end: 10, step: -2, input shape: (10, 3, 128, 64) - %1 = "ttir.slice"(%arg0, %0) <{begins = [0: i32, 0: i32, 0: i32, 32: i32], ends = [10: i32, 3: i32, 128: i32, 64: i32], step = [-2: i32, 3: i32, 8: i32, 4: i32], operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<10x3x128x64xbf16>, tensor<5x1x16x8xbf16>) -> tensor<5x1x16x8xbf16> + %1 = "ttir.slice"(%arg0, %0) <{begins = [0: i32, 0: i32, 0: i32, 32: i32], ends = [10: i32, 3: i32, 128: i32, 64: i32], step = [-2: i32, 3: i32, 8: i32, 4: i32]}> : (tensor<10x3x128x64xbf16>, tensor<5x1x16x8xbf16>) -> tensor<5x1x16x8xbf16> return %1 : tensor<5x1x16x8xbf16> } } @@ -185,7 +184,7 @@ module attributes {} { func.func @slice_negative_invalid_output_shape(%arg0: tensor<10x3x128x64xbf16>) -> tensor<4x1x16x16xbf16> { %0 = tensor.empty() : tensor<4x1x16x16xbf16> // CHECK: error: 'ttir.slice' op Mismatch in dimension 3 of the output tensor: expected size 8, but got 16 - %1 = "ttir.slice"(%arg0, %0) <{begins = [0: i32, 0: i32, 0: i32, 32: i32], ends = [10: i32, 3: i32, 128: i32, 64: i32], step = [3: i32, 3: i32, 8: i32, 4: i32], operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<10x3x128x64xbf16>, tensor<4x1x16x16xbf16>) -> tensor<4x1x16x16xbf16> + %1 = "ttir.slice"(%arg0, %0) <{begins = [0: i32, 0: i32, 0: i32, 32: i32], ends = [10: i32, 3: i32, 128: i32, 64: i32], step = [3: i32, 3: i32, 8: i32, 4: i32]}> : (tensor<10x3x128x64xbf16>, tensor<4x1x16x16xbf16>) -> tensor<4x1x16x16xbf16> return %1 : tensor<4x1x16x16xbf16> } } diff --git a/test/ttmlir/Dialect/TTIR/slice/slice_tests_positive.mlir b/test/ttmlir/Dialect/TTIR/slice/slice_tests_positive.mlir index 1ff464b3e..af5fa4cc9 100644 --- a/test/ttmlir/Dialect/TTIR/slice/slice_tests_positive.mlir +++ b/test/ttmlir/Dialect/TTIR/slice/slice_tests_positive.mlir @@ -1,59 +1,58 @@ // RUN: ttmlir-opt %s | FileCheck %s -#any_device_tile = #tt.operand_constraint module attributes {} { func.func @slice_1d(%arg0: tensor<64xbf16>) -> tensor<32xbf16> { %0 = tensor.empty() : tensor<32xbf16> // CHECK: %[[C:.*]] = "ttir.slice"[[C:.*]] - %1 = "ttir.slice"(%arg0, %0) <{begins = [0: i32], ends = [32: i32], step = [1: i32], operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<64xbf16>, tensor<32xbf16>) -> tensor<32xbf16> + %1 = "ttir.slice"(%arg0, %0) <{begins = [0: i32], ends = [32: i32], step = [1: i32]}> : (tensor<64xbf16>, tensor<32xbf16>) -> tensor<32xbf16> return %1 : tensor<32xbf16> } func.func @slice_1d_step(%arg0: tensor<64xbf16>) -> tensor<16xbf16> { %0 = tensor.empty() : tensor<16xbf16> // CHECK: %[[C:.*]] = "ttir.slice"[[C:.*]] - %1 = "ttir.slice"(%arg0, %0) <{begins = [0: i32], ends = [64: i32], step = [4: i32], operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<64xbf16>, tensor<16xbf16>) -> tensor<16xbf16> + %1 = "ttir.slice"(%arg0, %0) <{begins = [0: i32], ends = [64: i32], step = [4: i32]}> : (tensor<64xbf16>, tensor<16xbf16>) -> tensor<16xbf16> return %1 : tensor<16xbf16> } func.func @slice_2d(%arg0: tensor<128x64xbf16>) -> tensor<64x32xbf16> { %0 = tensor.empty() : tensor<64x32xbf16> // CHECK: %[[C:.*]] = "ttir.slice"[[C:.*]] - %1 = "ttir.slice"(%arg0, %0) <{begins = [0: i32, 0: i32], ends = [64: i32, 32: i32], step = [1: i32, 1: i32], operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<128x64xbf16>, tensor<64x32xbf16>) -> tensor<64x32xbf16> + %1 = "ttir.slice"(%arg0, %0) <{begins = [0: i32, 0: i32], ends = [64: i32, 32: i32], step = [1: i32, 1: i32]}> : (tensor<128x64xbf16>, tensor<64x32xbf16>) -> 
tensor<64x32xbf16> return %1 : tensor<64x32xbf16> } func.func @slice_2d_step(%arg0: tensor<128x64xbf16>) -> tensor<32x16xbf16> { %0 = tensor.empty() : tensor<32x16xbf16> // CHECK: %[[C:.*]] = "ttir.slice"[[C:.*]] - %1 = "ttir.slice"(%arg0, %0) <{begins = [64: i32, 0: i32], ends = [128: i32, 64: i32], step = [2: i32, 4: i32], operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<128x64xbf16>, tensor<32x16xbf16>) -> tensor<32x16xbf16> + %1 = "ttir.slice"(%arg0, %0) <{begins = [64: i32, 0: i32], ends = [128: i32, 64: i32], step = [2: i32, 4: i32]}> : (tensor<128x64xbf16>, tensor<32x16xbf16>) -> tensor<32x16xbf16> return %1 : tensor<32x16xbf16> } func.func @slice_3d(%arg0: tensor<3x128x64xbf16>) -> tensor<1x64x64xbf16> { %0 = tensor.empty() : tensor<1x64x64xbf16> // CHECK: %[[C:.*]] = "ttir.slice"[[C:.*]] - %1 = "ttir.slice"(%arg0, %0) <{begins = [0: i32, 0: i32, 0: i32], ends = [1: i32, 64: i32, 64: i32], step = [1: i32, 1: i32, 1: i32], operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<3x128x64xbf16>, tensor<1x64x64xbf16>) -> tensor<1x64x64xbf16> + %1 = "ttir.slice"(%arg0, %0) <{begins = [0: i32, 0: i32, 0: i32], ends = [1: i32, 64: i32, 64: i32], step = [1: i32, 1: i32, 1: i32]}> : (tensor<3x128x64xbf16>, tensor<1x64x64xbf16>) -> tensor<1x64x64xbf16> return %1 : tensor<1x64x64xbf16> } func.func @slice_3d_step(%arg0: tensor<3x128x64xbf16>) -> tensor<2x32x32xbf16> { %0 = tensor.empty() : tensor<2x32x32xbf16> // CHECK: %[[C:.*]] = "ttir.slice"[[C:.*]] - %1 = "ttir.slice"(%arg0, %0) <{begins = [0: i32, 0: i32, 32: i32], ends = [3: i32, 128: i32, 64: i32], step = [2: i32, 4: i32, 1: i32], operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<3x128x64xbf16>, tensor<2x32x32xbf16>) -> tensor<2x32x32xbf16> + %1 = "ttir.slice"(%arg0, %0) <{begins = [0: i32, 0: i32, 32: i32], ends = [3: i32, 128: i32, 64: i32], step = [2: i32, 4: i32, 1: i32]}> : (tensor<3x128x64xbf16>, tensor<2x32x32xbf16>) -> tensor<2x32x32xbf16> return %1 : tensor<2x32x32xbf16> } func.func @slice_4d(%arg0: tensor<10x3x128x64xbf16>) -> tensor<5x3x32x64xbf16> { %0 = tensor.empty() : tensor<5x3x32x64xbf16> // CHECK: %[[C:.*]] = "ttir.slice"[[C:.*]] - %1 = "ttir.slice"(%arg0, %0) <{begins = [3: i32, 0: i32, 32: i32, 0: i32], ends = [8: i32, 3: i32, 64: i32, 64: i32], step = [1: i32, 1: i32, 1: i32, 1: i32], operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<10x3x128x64xbf16>, tensor<5x3x32x64xbf16>) -> tensor<5x3x32x64xbf16> + %1 = "ttir.slice"(%arg0, %0) <{begins = [3: i32, 0: i32, 32: i32, 0: i32], ends = [8: i32, 3: i32, 64: i32, 64: i32], step = [1: i32, 1: i32, 1: i32, 1: i32]}> : (tensor<10x3x128x64xbf16>, tensor<5x3x32x64xbf16>) -> tensor<5x3x32x64xbf16> return %1 : tensor<5x3x32x64xbf16> } func.func @slice_4d_step(%arg0: tensor<10x3x128x64xbf16>) -> tensor<4x1x16x32xbf16> { %0 = tensor.empty() : tensor<4x1x16x32xbf16> // CHECK: %[[C:.*]] = "ttir.slice"[[C:.*]] - %1 = "ttir.slice"(%arg0, %0) <{begins = [0: i32, 2: i32, 0: i32, -64: i32], ends = [10: i32, 0: i32, -1: i32, -1: i32], step = [3: i32, -2: i32, 8: i32, 2: i32], operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<10x3x128x64xbf16>, tensor<4x1x16x32xbf16>) -> tensor<4x1x16x32xbf16> + %1 = "ttir.slice"(%arg0, %0) <{begins = [0: i32, 2: i32, 0: i32, -64: i32], ends = [10: i32, 0: i32, -1: i32, -1: i32], step = [3: i32, -2: i32, 8: i32, 2: i32]}> : (tensor<10x3x128x64xbf16>, tensor<4x1x16x32xbf16>) -> tensor<4x1x16x32xbf16> return %1 : tensor<4x1x16x32xbf16> } } diff --git 
a/test/ttmlir/Dialect/TTIR/test_allocate.mlir b/test/ttmlir/Dialect/TTIR/test_allocate.mlir index 5888cf3f6..4acbda2f9 100644 --- a/test/ttmlir/Dialect/TTIR/test_allocate.mlir +++ b/test/ttmlir/Dialect/TTIR/test_allocate.mlir @@ -1,5 +1,4 @@ // RUN: ttmlir-opt --ttir-load-system-desc --ttir-implicit-device --ttir-allocate %s | FileCheck %s -#any_device = #tt.operand_constraint #l1_ = #tt.memory_space #layout = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<64x128xf32, #l1_>, interleaved> module attributes {} { @@ -7,7 +6,7 @@ module attributes {} { // CHECK: %[[C:.*]] = "ttir.alloc"[[C:.*]] // CHECK-NOT: %[[C:.*]] = tensor.empty() : tensor<64x128xf32> %0 = tensor.empty() : tensor<64x128xf32, #layout> - %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32, #layout>, tensor<64x128xf32, #layout>, tensor<64x128xf32, #layout>) -> tensor<64x128xf32, #layout> + %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32, #layout>, tensor<64x128xf32, #layout>, tensor<64x128xf32, #layout>) -> tensor<64x128xf32, #layout> return %1 : tensor<64x128xf32, #layout> } } diff --git a/test/ttmlir/Dialect/TTIR/test_generic.mlir b/test/ttmlir/Dialect/TTIR/test_generic.mlir index ff50eef4b..0899746e1 100644 --- a/test/ttmlir/Dialect/TTIR/test_generic.mlir +++ b/test/ttmlir/Dialect/TTIR/test_generic.mlir @@ -1,10 +1,9 @@ // RUN: ttmlir-opt --ttir-generic %s | FileCheck %s -#any_device = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> { %0 = tensor.empty() : tensor<64x128xf32> // CHECK: %[[C:.*]] = "ttir.generic"[[C:.*]] - %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } } diff --git a/test/ttmlir/Dialect/TTIR/test_layout.mlir b/test/ttmlir/Dialect/TTIR/test_layout.mlir index 3253f6d23..232c69504 100644 --- a/test/ttmlir/Dialect/TTIR/test_layout.mlir +++ b/test/ttmlir/Dialect/TTIR/test_layout.mlir @@ -1,10 +1,9 @@ // RUN: ttmlir-opt --ttir-load-system-desc --ttir-implicit-device --ttir-layout %s | FileCheck %s -#any_device = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<8x64x128xf32>, %arg1: tensor<8x64x128xf32>) -> tensor<8x64x128xf32> { // CHECK: %[[C:.*]] = tensor.empty() : tensor<8x64x128xf32, #layout> %0 = tensor.empty() : tensor<8x64x128xf32> - %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<8x64x128xf32>, tensor<8x64x128xf32>, tensor<8x64x128xf32>) -> tensor<8x64x128xf32> + %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<8x64x128xf32>, tensor<8x64x128xf32>, tensor<8x64x128xf32>) -> tensor<8x64x128xf32> return %1 : tensor<8x64x128xf32> } diff --git a/test/ttmlir/Dialect/TTIR/test_remove_dead_values_pass.mlir b/test/ttmlir/Dialect/TTIR/test_remove_dead_values_pass.mlir index 8b6df4d0f..88c56039e 100644 --- a/test/ttmlir/Dialect/TTIR/test_remove_dead_values_pass.mlir +++ b/test/ttmlir/Dialect/TTIR/test_remove_dead_values_pass.mlir @@ -1,22 +1,21 @@ // RUN: 
ttmlir-opt --remove-dead-values %s | FileCheck %s -#any_device = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> { %0 = tensor.empty() : tensor<64x128xf32> // CHECK: %[[C:.*]] = "ttir.multiply"[[C:.*]] - %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> %2 = tensor.empty() : tensor<64x128xf32> // CHECK-NOT: %[[C:.*]] = "ttir.add"[[C:.*]] - %3 = "ttir.add"(%arg0, %arg1, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %3 = "ttir.add"(%arg0, %arg1, %2) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> %4 = tensor.empty() : tensor<64x128xf32> // CHECK-NOT: %[[C:.*]] = "ttir.subtract"[[C:.*]] - %5 = "ttir.subtract"(%arg0, %arg1, %4) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %5 = "ttir.subtract"(%arg0, %arg1, %4) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> %6 = tensor.empty() : tensor<64x128xf32> // CHECK-NOT: %[[C:.*]] = "ttir.div"[[C:.*]] - %7 = "ttir.div"(%arg0, %arg1, %6) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %7 = "ttir.div"(%arg0, %arg1, %6) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> %8 = tensor.empty() : tensor<64x128xf32> // CHECK-NOT: %[[C:.*]] = "ttir.eq"[[C:.*]] - %9 = "ttir.eq"(%arg0, %arg1, %8) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %9 = "ttir.eq"(%arg0, %arg1, %8) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } } diff --git a/test/ttmlir/Dialect/TTIR/ttir_broadcastable_negative.mlir b/test/ttmlir/Dialect/TTIR/ttir_broadcastable_negative.mlir index e1454ad0a..002a09921 100644 --- a/test/ttmlir/Dialect/TTIR/ttir_broadcastable_negative.mlir +++ b/test/ttmlir/Dialect/TTIR/ttir_broadcastable_negative.mlir @@ -2,27 +2,27 @@ // Negative tests for Broadcastable interface // CHECK: 'ttir.abs' op Result shape must match operand shapes after broadcasting -#any_device_tile = #tt.operand_constraint + func.func @eltwise_unary(%arg0: tensor<1x64xbf16>) -> tensor<2x64xbf16> { %0 = tensor.empty() : tensor<2x64xbf16> - %1 = "ttir.abs"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<1x64xbf16>, tensor<2x64xbf16>) -> tensor<2x64xbf16> + %1 = "ttir.abs"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<1x64xbf16>, tensor<2x64xbf16>) -> tensor<2x64xbf16> return %1 : tensor<2x64xbf16> } // ----- // CHECK: error: 'ttir.add' op Result shape must match operand 
shapes after broadcasting -#any_device_tile = #tt.operand_constraint + func.func @eltwise_binary(%arg0: tensor<2x3x64xf32>, %arg1: tensor<64xf32>) -> tensor<4x2x3x64xf32> { %0 = tensor.empty() : tensor<4x2x3x64xf32> - %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<2x3x64xf32>, tensor<64xf32>, tensor<4x2x3x64xf32>) -> tensor<4x2x3x64xf32> + %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<2x3x64xf32>, tensor<64xf32>, tensor<4x2x3x64xf32>) -> tensor<4x2x3x64xf32> return %1 : tensor<4x2x3x64xf32> } // ----- // CHECK: error: 'ttir.where' op Result shape must match operand shapes after broadcasting -#any_device_tile = #tt.operand_constraint + func.func @eltwise_ternary(%arg0: tensor<3x64xf32>, %arg1: tensor<1x3x64xf32>, %arg2: tensor<2x1x64xf32>) -> tensor<1x2x3x64xf32> { %0 = tensor.empty() : tensor<1x2x3x64xf32> - %1 = "ttir.where"(%arg0, %arg1, %arg2, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<3x64xf32>, tensor<1x3x64xf32>, tensor<2x1x64xf32>, tensor<1x2x3x64xf32>) -> tensor<1x2x3x64xf32> + %1 = "ttir.where"(%arg0, %arg1, %arg2, %0) <{operandSegmentSizes = array}> : (tensor<3x64xf32>, tensor<1x3x64xf32>, tensor<2x1x64xf32>, tensor<1x2x3x64xf32>) -> tensor<1x2x3x64xf32> return %1 : tensor<1x2x3x64xf32> } diff --git a/test/ttmlir/Dialect/TTIR/ttir_noperands_negative.mlir b/test/ttmlir/Dialect/TTIR/ttir_noperands_negative.mlir index a22dc2837..76f526ef2 100644 --- a/test/ttmlir/Dialect/TTIR/ttir_noperands_negative.mlir +++ b/test/ttmlir/Dialect/TTIR/ttir_noperands_negative.mlir @@ -2,36 +2,32 @@ // Negative tests for NOperands trait // CHECK: error: 'ttir.abs' op expected 2 operands, but found 3 -#any_device_tile = #tt.operand_constraint func.func @eltwise_unary(%arg0: tensor<64x64xbf16>) -> tensor<64x64xbf16> { %0 = tensor.empty() : tensor<64x64xbf16> - %1 = "ttir.abs"(%arg0, %arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x64xbf16>, tensor<64x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> + %1 = "ttir.abs"(%arg0, %arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x64xbf16>, tensor<64x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> return %1 : tensor<64x64xbf16> } // ----- // CHECK: error: 'ttir.add' op expected 3 operands, but found 4 -#any_device_tile = #tt.operand_constraint func.func @eltwise_binary(%arg0: tensor<64x64xf32>, %arg1: tensor<64x64xf32>) -> tensor<64x64xf32> { %0 = tensor.empty() : tensor<64x64xf32> - %1 = "ttir.add"(%arg0, %arg1, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x64xf32>, tensor<64x64xf32>, tensor<64x64xf32>, tensor<64x64xf32>) -> tensor<64x64xf32> + %1 = "ttir.add"(%arg0, %arg1, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<64x64xf32>, tensor<64x64xf32>, tensor<64x64xf32>, tensor<64x64xf32>) -> tensor<64x64xf32> return %1 : tensor<64x64xf32> } // ----- // CHECK: error: 'ttir.add' op expected 3 operands, but found 2 -#any_device_tile = #tt.operand_constraint func.func @eltwise_binary(%arg0: tensor<64x64xf32>) -> tensor<64x64xf32> { %0 = tensor.empty() : tensor<64x64xf32> - %1 = "ttir.add"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<64x64xf32>, 
tensor<64x64xf32>) -> tensor<64x64xf32>
+  %1 = "ttir.add"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x64xf32>, tensor<64x64xf32>) -> tensor<64x64xf32>
   return %1 : tensor<64x64xf32>
 }
 // -----
 // CHECK: error: 'ttir.where' op expected 4 operands, but found 5
-#any_device_tile = #tt.operand_constraint
 func.func @eltwise_ternary(%arg0: tensor<64x64xf32>, %arg1: tensor<64x64xf32>, %arg2: tensor<64x64xf32>) -> tensor<64x64xf32> {
   %0 = tensor.empty() : tensor<64x64xf32>
-  %1 = "ttir.where"(%arg0, %arg1, %arg2, %arg2, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x64xf32>, tensor<64x64xf32>, tensor<64x64xf32>, tensor<64x64xf32>, tensor<64x64xf32>) -> tensor<64x64xf32>
+  %1 = "ttir.where"(%arg0, %arg1, %arg2, %arg2, %0) <{operandSegmentSizes = array}> : (tensor<64x64xf32>, tensor<64x64xf32>, tensor<64x64xf32>, tensor<64x64xf32>, tensor<64x64xf32>) -> tensor<64x64xf32>
   return %1 : tensor<64x64xf32>
 }
diff --git a/test/ttmlir/Dialect/TTNN/arange/arange_tests_negative.mlir b/test/ttmlir/Dialect/TTNN/arange/arange_tests_negative.mlir
index dc3f09fba..184f62c5b 100644
--- a/test/ttmlir/Dialect/TTNN/arange/arange_tests_negative.mlir
+++ b/test/ttmlir/Dialect/TTNN/arange/arange_tests_negative.mlir
@@ -1,12 +1,11 @@
 // RUN: not ttmlir-opt --split-input-file %s 2>&1 | FileCheck %s
 // Negative tests for matmul operation
-#any_device = #tt.operand_constraint
 module attributes {} {
   func.func @forward(%arg0: tensor<1x32x128x128xf32>) -> tensor<1x32x128x128xf32> {
     // CHECK: error: 'ttir.arange' op Output tensor shape must be 16 at dim 1 (since start=0, end=32, step=2), but got 32
     %1 = "ttir.arange"() <{start = 0: si64, end = 32: si64, step = 2: si64, arange_dimension = 1: i64}> : () -> tensor<1x32x128x128xf32>
     %dps = tensor.empty() : tensor<1x32x128x128xf32>
-    %2 = "ttir.multiply"(%arg0, %1, %dps) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x128x128xf32>, tensor<1x32x128x128xf32>, tensor<1x32x128x128xf32>) -> tensor<1x32x128x128xf32>
+    %2 = "ttir.multiply"(%arg0, %1, %dps) <{operandSegmentSizes = array}> : (tensor<1x32x128x128xf32>, tensor<1x32x128x128xf32>, tensor<1x32x128x128xf32>) -> tensor<1x32x128x128xf32>
     return %2 : tensor<1x32x128x128xf32>
   }
 }
diff --git a/test/ttmlir/Dialect/TTNN/arange/arange_tests_positive.mlir b/test/ttmlir/Dialect/TTNN/arange/arange_tests_positive.mlir
index 945b6da5b..18026e583 100644
--- a/test/ttmlir/Dialect/TTNN/arange/arange_tests_positive.mlir
+++ b/test/ttmlir/Dialect/TTNN/arange/arange_tests_positive.mlir
@@ -1,13 +1,12 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
 // UNSUPPORTED: true
 // https://github.com/tenstorrent/tt-mlir/issues/1448
-#any_device = #tt.operand_constraint
 module attributes {} {
   func.func @forward(%arg0: tensor<1x32x128x128xf32>) -> tensor<1x32x128x128xf32> {
     // CHECK: %[[C:.*]] = "ttnn.arange"[[C:.*]]
     %1 = "ttir.arange"() <{start = 0: si64, end = 32: si64, step = 1: si64, arange_dimension = 1: i64}> : () -> tensor<1x32x128x128xf32>
     %dps = tensor.empty() : tensor<1x32x128x128xf32>
-    %2 = "ttir.multiply"(%arg0, %1, %dps) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x128x128xf32>, tensor<1x32x128x128xf32>, tensor<1x32x128x128xf32>) -> tensor<1x32x128x128xf32>
+    %2 = "ttir.multiply"(%arg0, %1, %dps) <{operandSegmentSizes = array}> : (tensor<1x32x128x128xf32>, tensor<1x32x128x128xf32>, tensor<1x32x128x128xf32>) -> tensor<1x32x128x128xf32>
     return %2 : tensor<1x32x128x128xf32>
   }
 }
diff --git a/test/ttmlir/Dialect/TTNN/ccl/all_gather.mlir b/test/ttmlir/Dialect/TTNN/ccl/all_gather.mlir
index a12034794..0789b67c5 100644
--- a/test/ttmlir/Dialect/TTNN/ccl/all_gather.mlir
+++ b/test/ttmlir/Dialect/TTNN/ccl/all_gather.mlir
@@ -1,9 +1,8 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
-#any_device = #tt.operand_constraint
 module attributes {} {
   func.func @forward(%arg0: tensor<1x1x32x32xbf16>) -> tensor<1x1x32x128xbf16> {
     %0 = tensor.empty() : tensor<1x1x32x128xbf16>
-    %1 = "ttir.all_gather"(%arg0, %0) <{dim = 3 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<1x1x32x32xbf16>, tensor<1x1x32x128xbf16>) -> tensor<1x1x32x128xbf16>
+    %1 = "ttir.all_gather"(%arg0, %0) <{dim = 3 : si32}> : (tensor<1x1x32x32xbf16>, tensor<1x1x32x128xbf16>) -> tensor<1x1x32x128xbf16>
     // CHECK: %[[C:.*]] = "ttnn.all_gather"[[C:.*]]
     return %1 : tensor<1x1x32x128xbf16>
   }
diff --git a/test/ttmlir/Dialect/TTNN/ccl/all_gather_negative.mlir b/test/ttmlir/Dialect/TTNN/ccl/all_gather_negative.mlir
index d3f6ac3da..3e5dec812 100644
--- a/test/ttmlir/Dialect/TTNN/ccl/all_gather_negative.mlir
+++ b/test/ttmlir/Dialect/TTNN/ccl/all_gather_negative.mlir
@@ -1,10 +1,9 @@
 // RUN: not ttmlir-opt --ttir-to-ttnn-backend-pipeline %s 2>&1 | FileCheck %s
 // CHECK: error: 'ttir.all_gather' op Invalid dimension for all gather op
-#any_device = #tt.operand_constraint
 module attributes {} {
   func.func @forward(%arg0: tensor<1x1x32x32xbf16>) -> tensor<1x1x32x128xbf16> {
     %0 = tensor.empty() : tensor<1x1x32x128xbf16>
-    %1 = "ttir.all_gather"(%arg0, %0) <{dim = 4 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<1x1x32x32xbf16>, tensor<1x1x32x128xbf16>) -> tensor<1x1x32x128xbf16>
+    %1 = "ttir.all_gather"(%arg0, %0) <{dim = 4 : si32}> : (tensor<1x1x32x32xbf16>, tensor<1x1x32x128xbf16>) -> tensor<1x1x32x128xbf16>
     return %1 : tensor<1x1x32x128xbf16>
   }
 }
diff --git a/test/ttmlir/Dialect/TTNN/ccl/all_reduce.mlir b/test/ttmlir/Dialect/TTNN/ccl/all_reduce.mlir
index c5712eb53..1eaf04df7 100644
--- a/test/ttmlir/Dialect/TTNN/ccl/all_reduce.mlir
+++ b/test/ttmlir/Dialect/TTNN/ccl/all_reduce.mlir
@@ -2,11 +2,10 @@
 // Unit tests for ttnn all_reduce op
 // Verify lowering of ttir all_reduce to ttnn ops
-#any_device_tile = #tt.operand_constraint
 module attributes {} {
   func.func @all_reduce(%arg0: tensor<4096x16384xf32>) -> tensor<4096x16384xf32> {
     %0 = tensor.empty() : tensor<4096x16384xf32>
-    %1 = "ttir.all_reduce"(%arg0, %0) <{channel_handle = 1 : si32, dim = 0 : si32, operand_constraints = [#any_device_tile, #any_device_tile], reduce_type = #tt.reduce_type, replica_groups = dense<[[0, 1, 2, 3], [4, 5, 6, 7]]> : tensor<2x4xi64>, use_global_device_ids}> : (tensor<4096x16384xf32>, tensor<4096x16384xf32>) -> tensor<4096x16384xf32>
+    %1 = "ttir.all_reduce"(%arg0, %0) <{channel_handle = 1 : si32, dim = 0 : si32, reduce_type = #tt.reduce_type, replica_groups = dense<[[0, 1, 2, 3], [4, 5, 6, 7]]> : tensor<2x4xi64>, use_global_device_ids}> : (tensor<4096x16384xf32>, tensor<4096x16384xf32>) -> tensor<4096x16384xf32>
     return %1 : tensor<4096x16384xf32>
   }
 }
@@ -22,7 +21,7 @@ module attributes {} {
   func.func @all_reduce(%arg0: tensor<1x1x4096x16384xf32>) -> tensor<1x1x4096x16384xf32> {
     %0 = tensor.empty() : tensor<1x1x4096x16384xf32>
-    %1 = "ttir.all_reduce"(%arg0, %0) <{channel_handle = 1 : si32, dim = 0 : si32, operand_constraints = [#any_device_tile, #any_device_tile], reduce_type = #tt.reduce_type, replica_groups = dense<[[0, 1, 2, 3], [4, 5, 6, 7]]> : tensor<2x4xi64>, use_global_device_ids}> : (tensor<1x1x4096x16384xf32>, tensor<1x1x4096x16384xf32>) -> tensor<1x1x4096x16384xf32>
+    %1 = "ttir.all_reduce"(%arg0, %0) <{channel_handle = 1 : si32, dim = 0 : si32, reduce_type = #tt.reduce_type, replica_groups = dense<[[0, 1, 2, 3], [4, 5, 6, 7]]> : tensor<2x4xi64>, use_global_device_ids}> : (tensor<1x1x4096x16384xf32>, tensor<1x1x4096x16384xf32>) -> tensor<1x1x4096x16384xf32>
     return %1 : tensor<1x1x4096x16384xf32>
   }
 }
diff --git a/test/ttmlir/Dialect/TTNN/ccl/mesh_shard.mlir b/test/ttmlir/Dialect/TTNN/ccl/mesh_shard.mlir
index e7d20cfa8..2f488b6c2 100644
--- a/test/ttmlir/Dialect/TTNN/ccl/mesh_shard.mlir
+++ b/test/ttmlir/Dialect/TTNN/ccl/mesh_shard.mlir
@@ -1,9 +1,8 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
-#operand_constraint = #tt.operand_constraint
 module attributes {} {
   func.func @forward(%arg0: tensor<8192x784xf32>) -> tensor<4096x196xf32> {
     %0 = tensor.empty() : tensor<4096x196xf32>
-    %1 = "ttir.mesh_shard"(%arg0, %0) <{operand_constraints = [#operand_constraint, #operand_constraint], shard_direction = #tt.shard_direction, shard_shape = #tt.grid<2x4>, shard_type = #tt.shard_type}> : (tensor<8192x784xf32>, tensor<4096x196xf32>) -> tensor<4096x196xf32>
+    %1 = "ttir.mesh_shard"(%arg0, %0) <{shard_direction = #tt.shard_direction, shard_shape = #tt.grid<2x4>, shard_type = #tt.shard_type}> : (tensor<8192x784xf32>, tensor<4096x196xf32>) -> tensor<4096x196xf32>
     // CHECK: %[[C:.*]] = "ttnn.mesh_shard"[[C:.*]]
     return %1 : tensor<4096x196xf32>
   }
diff --git a/test/ttmlir/Dialect/TTNN/concat/concat_dim_oob.mlir b/test/ttmlir/Dialect/TTNN/concat/concat_dim_oob.mlir
index 5b93d0c50..1e5c1fcb6 100644
--- a/test/ttmlir/Dialect/TTNN/concat/concat_dim_oob.mlir
+++ b/test/ttmlir/Dialect/TTNN/concat/concat_dim_oob.mlir
@@ -1,10 +1,9 @@
 // RUN: not ttmlir-opt --ttir-to-ttnn-backend-pipeline %s 2>&1 | FileCheck %s
 // CHECK: error: 'ttir.concat' op Invalid dimension 2 for concatenation.
-#any_device = #tt.operand_constraint
 module attributes {} {
   func.func @forward(%arg0: tensor<32x32xf32>, %arg1: tensor<32x64xf32>) -> tensor<32x96xf32> {
     %0 = tensor.empty() : tensor<32x96xf32>
-    %1 = "ttir.concat"(%arg0, %arg1, %0) <{dim = 2 : si32, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x32xf32>, tensor<32x64xf32>, tensor<32x96xf32>) -> tensor<32x96xf32>
+    %1 = "ttir.concat"(%arg0, %arg1, %0) <{dim = 2 : si32}> : (tensor<32x32xf32>, tensor<32x64xf32>, tensor<32x96xf32>) -> tensor<32x96xf32>
     return %1 : tensor<32x96xf32>
   }
 }
diff --git a/test/ttmlir/Dialect/TTNN/concat/concat_multiple_tensors.mlir b/test/ttmlir/Dialect/TTNN/concat/concat_multiple_tensors.mlir
index 30bf6926b..1695bcd97 100644
--- a/test/ttmlir/Dialect/TTNN/concat/concat_multiple_tensors.mlir
+++ b/test/ttmlir/Dialect/TTNN/concat/concat_multiple_tensors.mlir
@@ -1,5 +1,4 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
-#any_device = #tt.operand_constraint
 module attributes {} {
   func.func @forward() -> tensor<32x224xf32> {
     // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
@@ -11,7 +10,7 @@ module attributes {} {
     // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
     %3 = tensor.empty() : tensor<32x224xf32>
     // CHECK: %[[C:.*]] = "ttnn.concat"[[C:.*]]
-    %4 = "ttir.concat"(%0, %1, %2, %3) <{dim = 1 : si32, operand_constraints = [#any_device, #any_device, #any_device, #any_device]}> : (tensor<32x32xf32>, tensor<32x64xf32>, tensor<32x128xf32>, tensor<32x224xf32>) -> tensor<32x224xf32>
+    %4 = "ttir.concat"(%0, %1, %2, %3) <{dim = 1 : si32}> : (tensor<32x32xf32>, tensor<32x64xf32>, tensor<32x128xf32>, tensor<32x224xf32>) -> tensor<32x224xf32>
     return %4 : tensor<32x224xf32>
   }
 }
diff --git a/test/ttmlir/Dialect/TTNN/concat/concat_negative_dim.mlir b/test/ttmlir/Dialect/TTNN/concat/concat_negative_dim.mlir
index f8a4f2db3..b026f7732 100644
--- a/test/ttmlir/Dialect/TTNN/concat/concat_negative_dim.mlir
+++ b/test/ttmlir/Dialect/TTNN/concat/concat_negative_dim.mlir
@@ -1,11 +1,10 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
-#any_device = #tt.operand_constraint
 module attributes {} {
   func.func @forward(%arg0: tensor<32x32xf32>, %arg1: tensor<32x64xf32>) -> tensor<32x96xf32> {
     // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
     %0 = tensor.empty() : tensor<32x96xf32>
     // CHECK: %[[C:.*]] = "ttnn.concat"[[C:.*]]
-    %1 = "ttir.concat"(%arg0, %arg1, %0) <{dim = -1 : si32, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x32xf32>, tensor<32x64xf32>, tensor<32x96xf32>) -> tensor<32x96xf32>
+    %1 = "ttir.concat"(%arg0, %arg1, %0) <{dim = -1 : si32}> : (tensor<32x32xf32>, tensor<32x64xf32>, tensor<32x96xf32>) -> tensor<32x96xf32>
     return %1 : tensor<32x96xf32>
   }
 }
diff --git a/test/ttmlir/Dialect/TTNN/concat/concat_negative_dim_oob.mlir b/test/ttmlir/Dialect/TTNN/concat/concat_negative_dim_oob.mlir
index 5d3a6fbd6..4aebe9fde 100644
--- a/test/ttmlir/Dialect/TTNN/concat/concat_negative_dim_oob.mlir
+++ b/test/ttmlir/Dialect/TTNN/concat/concat_negative_dim_oob.mlir
@@ -1,10 +1,9 @@
 // RUN: not ttmlir-opt --ttir-to-ttnn-backend-pipeline %s 2>&1 | FileCheck %s
 // CHECK: error: 'ttir.concat' op Invalid dimension -3 for concatenation.
-#any_device = #tt.operand_constraint
 module attributes {} {
   func.func @forward(%arg0: tensor<32x32xf32>, %arg1: tensor<32x64xf32>) -> tensor<32x96xf32> {
     %0 = tensor.empty() : tensor<32x96xf32>
-    %1 = "ttir.concat"(%arg0, %arg1, %0) <{dim = -3 : si32, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x32xf32>, tensor<32x64xf32>, tensor<32x96xf32>) -> tensor<32x96xf32>
+    %1 = "ttir.concat"(%arg0, %arg1, %0) <{dim = -3 : si32}> : (tensor<32x32xf32>, tensor<32x64xf32>, tensor<32x96xf32>) -> tensor<32x96xf32>
     return %1 : tensor<32x96xf32>
   }
 }
diff --git a/test/ttmlir/Dialect/TTNN/concat/simple_concat.mlir b/test/ttmlir/Dialect/TTNN/concat/simple_concat.mlir
index 0199d95ff..1acb8252b 100644
--- a/test/ttmlir/Dialect/TTNN/concat/simple_concat.mlir
+++ b/test/ttmlir/Dialect/TTNN/concat/simple_concat.mlir
@@ -1,11 +1,10 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
-#any_device = #tt.operand_constraint
 module attributes {} {
   func.func @forward(%arg0: tensor<32x32xf32>, %arg1: tensor<32x64xf32>) -> tensor<32x96xf32> {
     // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
     %0 = tensor.empty() : tensor<32x96xf32>
     // CHECK: %[[C:.*]] = "ttnn.concat"[[C:.*]]
-    %1 = "ttir.concat"(%arg0, %arg1, %0) <{dim = 1 : si32, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x32xf32>, tensor<32x64xf32>, tensor<32x96xf32>) -> tensor<32x96xf32>
+    %1 = "ttir.concat"(%arg0, %arg1, %0) <{dim = 1 : si32}> : (tensor<32x32xf32>, tensor<32x64xf32>, tensor<32x96xf32>) -> tensor<32x96xf32>
     return %1 : tensor<32x96xf32>
   }
 }
diff --git a/test/ttmlir/Dialect/TTNN/convolution/complex_conv_channel_first.mlir b/test/ttmlir/Dialect/TTNN/convolution/complex_conv_channel_first.mlir
index 5428c0dec..f4633b7a8 100644
--- a/test/ttmlir/Dialect/TTNN/convolution/complex_conv_channel_first.mlir
+++ b/test/ttmlir/Dialect/TTNN/convolution/complex_conv_channel_first.mlir
@@ -1,5 +1,4 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
-#any_device_tile = #tt.operand_constraint
 module @jit_convolution {
   func.func public @test_NCHW_IOHW_to_NHWC_OIHW_conv2d(%arg0: tensor<1x3x100x100xbf16>, %arg1: tensor<7x3x3x3xbf16>) -> tensor<1x7x100x100xbf16> {
     %0 = tensor.empty() : tensor<1x7x100x100xbf16>
@@ -21,7 +20,6 @@ module @jit_convolution {
       >,
       feature_group_count = 1 : i64,
       input_dilation = array,
-      operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile],
       padding = array,
       weight_dilation = array,
       window_reversal = array,
diff --git a/test/ttmlir/Dialect/TTNN/convolution/simple_conv.mlir b/test/ttmlir/Dialect/TTNN/convolution/simple_conv.mlir
index 5a016c596..46e9334a9 100644
--- a/test/ttmlir/Dialect/TTNN/convolution/simple_conv.mlir
+++ b/test/ttmlir/Dialect/TTNN/convolution/simple_conv.mlir
@@ -1,10 +1,9 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
-#any_device = #tt.operand_constraint
 module attributes {} {
   func.func @forward(%arg0: tensor<1x32x32x64xbf16>, %arg1: tensor<64x64x3x3xbf16>, %arg2: tensor<1x1x1x64xbf16>) -> tensor<1x32x32x64xbf16> {
     %0 = tensor.empty() : tensor<1x32x32x64xbf16>
     // CHECK: %[[C:.*]] = "ttnn.conv2d"[[C:.*]]
-    %1 = "ttir.conv2d"(%arg0, %arg1, %arg2, %0) <{stride_height=1: si32, stride_width=1: si32, dilation_height=1: si32, dilation_width=1: si32, groups=1: si32, padding_left=1: si32, padding_right=1: si32, padding_top=1: si32, padding_bottom=1: si32, is_convtranspose2d=0: si32, output_height_transpose=0: si32, output_width_transpose=0: si32, stride_transpose=0: si32, operand_constraints = [#any_device, #any_device, #any_device, #any_device]}> : (tensor<1x32x32x64xbf16>, tensor<64x64x3x3xbf16>, tensor<1x1x1x64xbf16>, tensor<1x32x32x64xbf16>) -> tensor<1x32x32x64xbf16>
+    %1 = "ttir.conv2d"(%arg0, %arg1, %arg2, %0) <{stride_height=1: si32, stride_width=1: si32, dilation_height=1: si32, dilation_width=1: si32, groups=1: si32, padding_left=1: si32, padding_right=1: si32, padding_top=1: si32, padding_bottom=1: si32, is_convtranspose2d=0: si32, output_height_transpose=0: si32, output_width_transpose=0: si32, stride_transpose=0: si32}> : (tensor<1x32x32x64xbf16>, tensor<64x64x3x3xbf16>, tensor<1x1x1x64xbf16>, tensor<1x32x32x64xbf16>) -> tensor<1x32x32x64xbf16>
     return %1 : tensor<1x32x32x64xbf16>
   }
 }
diff --git a/test/ttmlir/Dialect/TTNN/convolution/simple_conv1d.mlir b/test/ttmlir/Dialect/TTNN/convolution/simple_conv1d.mlir
index 8f75362a0..bd86d0218 100644
--- a/test/ttmlir/Dialect/TTNN/convolution/simple_conv1d.mlir
+++ b/test/ttmlir/Dialect/TTNN/convolution/simple_conv1d.mlir
@@ -1,5 +1,4 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
-#any_device_tile = #tt.operand_constraint
 module {
   func.func @main(%arg0: tensor<1x256x512xf32>, %arg1: tensor<1024x256x1xf32>, %arg2: tensor<1024xf32>) -> tensor<1x1024x512xf32> {
     %0 = tensor.empty() : tensor<1x1024x512xf32>
@@ -10,7 +9,7 @@ module {
     // CHECK: [[VAL4:%[0-9]+]] = "ttnn.reshape"([[VAL3]]) <{shape = [1 : i32, 1 : i32, 512 : i32, 256 : i32]}> : (tensor<[[TENSOR_SHAPE5]], #{{.*}}>) -> tensor<[[TENSOR_SHAPE6:[0-9]+x[0-9]+x[0-9]+x[0-9]+xf32]], #{{.*}}>
     // CHECK: [[VAL5:%[0-9]+]] = "ttnn.conv2d"([[VAL4]], %10, %{{[0-9]+}}, %{{[0-9]+}})
     // CHECK: (tensor<[[TENSOR_SHAPE6]], #{{.*}}>, tensor<1024x256x1x1xf32, #{{.*}}>, tensor<1x1x512x1024xf32, #{{.*}}>, !tt.device<#device>) -> tensor<1x1x512x1024xf32, #{{.*}}>
-    %1 = "ttir.convolution"(%arg0, %arg1, %0) <{batch_group_count = 1 : i64, convolution_layout = #ttir, feature_group_count = 1 : i64, input_dilation = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile], padding = array, weight_dilation = array, window_reversal = array, window_strides = array}> : (tensor<1x256x512xf32>, tensor<1024x256x1xf32>, tensor<1x1024x512xf32>) -> tensor<1x1024x512xf32>
+    %1 = "ttir.convolution"(%arg0, %arg1, %0) <{batch_group_count = 1 : i64, convolution_layout = #ttir, feature_group_count = 1 : i64, input_dilation = array, padding = array, weight_dilation = array, window_reversal = array, window_strides = array}> : (tensor<1x256x512xf32>, tensor<1024x256x1xf32>, tensor<1x1024x512xf32>) -> tensor<1x1024x512xf32>
     // CHECK: return %{{.*}} : tensor<1x1024x512xf32, #ttnn_layout3>
     return %1 : tensor<1x1024x512xf32>
   }
diff --git a/test/ttmlir/Dialect/TTNN/eltwise/binary/logical_and/simple_and.mlir b/test/ttmlir/Dialect/TTNN/eltwise/binary/logical_and/simple_and.mlir
index eca8b639d..e6400a752 100644
--- a/test/ttmlir/Dialect/TTNN/eltwise/binary/logical_and/simple_and.mlir
+++ b/test/ttmlir/Dialect/TTNN/eltwise/binary/logical_and/simple_and.mlir
@@ -1,11 +1,9 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
-
-#any_device = #tt.operand_constraint
 module attributes {} {
   func.func @logical_and(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> {
     %0 = tensor.empty() : tensor<64x128xf32>
     // CHECK: {{.*}} = "ttnn.empty"{{.*}}
-    %1 = "ttir.logical_and"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
+    %1 = "ttir.logical_and"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
     // CHECK: %[[C:.*]] = "ttnn.logical_and"
     // CHECK-SAME: tensor<64x128xf32,
     // CHECK-SAME: tensor<64x128xf32,
diff --git a/test/ttmlir/Dialect/TTNN/eltwise/binary/logical_or/simple_or.mlir b/test/ttmlir/Dialect/TTNN/eltwise/binary/logical_or/simple_or.mlir
index 7a51b7159..bb35140eb 100644
--- a/test/ttmlir/Dialect/TTNN/eltwise/binary/logical_or/simple_or.mlir
+++ b/test/ttmlir/Dialect/TTNN/eltwise/binary/logical_or/simple_or.mlir
@@ -1,11 +1,9 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
-
-#any_device = #tt.operand_constraint
 module attributes {} {
   func.func @logical_or(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> {
     %0 = tensor.empty() : tensor<64x128xf32>
     // CHECK: {{.*}} = "ttnn.empty"{{.*}}
-    %1 = "ttir.logical_or"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
+    %1 = "ttir.logical_or"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
     // CHECK: %[[C:.*]] = "ttnn.logical_or"
     // CHECK-SAME: tensor<64x128xf32,
     // CHECK-SAME: tensor<64x128xf32,
diff --git a/test/ttmlir/Dialect/TTNN/eltwise/binary/logical_xor/simple_xor.mlir b/test/ttmlir/Dialect/TTNN/eltwise/binary/logical_xor/simple_xor.mlir
index f59f49040..302c76d3f 100644
--- a/test/ttmlir/Dialect/TTNN/eltwise/binary/logical_xor/simple_xor.mlir
+++ b/test/ttmlir/Dialect/TTNN/eltwise/binary/logical_xor/simple_xor.mlir
@@ -1,6 +1,5 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
-#any_device = #tt.operand_constraint
 module attributes {} {
   func.func @logical_xor(%arg0: tensor<64x128xbf16>, %arg1: tensor<64x128xbf16>) -> tensor<64x128xbf16> {
     // CHECK: %{{[0-9]+}} = "ttnn.empty"{{.*}} [[TENSOR:tensor<64x128xbf16]]
@@ -10,7 +9,7 @@ module attributes {} {
     // CHECK-SAME: [[TENSOR]]
     // CHECK-SAME: [[TENSOR]]
    // CHECK-SAME: -> [[TENSOR]]
-    %1 = "ttir.logical_xor"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16>
+    %1 = "ttir.logical_xor"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16>
     return %1 : tensor<64x128xbf16>
   }
 }
diff --git a/test/ttmlir/Dialect/TTNN/eltwise/binary/minimum/simple_minimum.mlir b/test/ttmlir/Dialect/TTNN/eltwise/binary/minimum/simple_minimum.mlir
index 8ebdfe0a4..7b3576cb7 100644
--- a/test/ttmlir/Dialect/TTNN/eltwise/binary/minimum/simple_minimum.mlir
+++ b/test/ttmlir/Dialect/TTNN/eltwise/binary/minimum/simple_minimum.mlir
@@ -1,11 +1,10 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
-#any_device = #tt.operand_constraint
 module attributes {} {
   func.func @forward(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> {
     // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
     %0 = tensor.empty() : tensor<64x128xf32>
     // CHECK: %[[C:.*]] = "ttnn.minimum"[[C:.*]]
-    %1 = "ttir.minimum"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
+    %1 = "ttir.minimum"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
     return %1 : tensor<64x128xf32>
   }
 }
diff --git a/test/ttmlir/Dialect/TTNN/eltwise/binary/remainder/simple_remainder.mlir b/test/ttmlir/Dialect/TTNN/eltwise/binary/remainder/simple_remainder.mlir
index 281dccfdd..67d283c07 100644
--- a/test/ttmlir/Dialect/TTNN/eltwise/binary/remainder/simple_remainder.mlir
+++ b/test/ttmlir/Dialect/TTNN/eltwise/binary/remainder/simple_remainder.mlir
@@ -1,10 +1,9 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
-#any_device = #tt.operand_constraint
 module attributes {} {
   func.func @remainder(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>) -> tensor<32x32xf32> {
     %0 = tensor.empty() : tensor<32x32xf32>
     // CHECK: %[[EMPTY:.*]] = "ttnn.empty"{{.*}} -> tensor<32x32xf32, {{.*}}
-    %1 = "ttir.remainder"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32>
+    %1 = "ttir.remainder"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32>
     // CHECK: %[[REM:[0-9]+]] = "ttnn.remainder"({{.*}}, {{.*}}, %[[EMPTY]]){{.*}} -> tensor<32x32xf32, {{.*}}
     return %1 : tensor<32x32xf32>
     // CHECK: return {{.*}} : tensor<32x32xf32, {{.*}}
diff --git a/test/ttmlir/Dialect/TTNN/eltwise/operand_broadcasts.mlir b/test/ttmlir/Dialect/TTNN/eltwise/operand_broadcasts.mlir
index d5c1173b9..9b5df3852 100644
--- a/test/ttmlir/Dialect/TTNN/eltwise/operand_broadcasts.mlir
+++ b/test/ttmlir/Dialect/TTNN/eltwise/operand_broadcasts.mlir
@@ -1,11 +1,10 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
-#any_device = #tt.operand_constraint
 module attributes {} {
   func.func @bcast_one_dim(%arg0: tensor<2x64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<2x64x128xf32> {
     // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
     %0 = tensor.empty() : tensor<2x64x128xf32>
     // CHECK: %[[C:.*]] = "ttnn.multiply"[[C:.*]]
-    %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<2x64x128xf32>, tensor<64x128xf32>, tensor<2x64x128xf32>) -> tensor<2x64x128xf32>
+    %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<2x64x128xf32>, tensor<64x128xf32>, tensor<2x64x128xf32>) -> tensor<2x64x128xf32>
     return %1 : tensor<2x64x128xf32>
   }
@@ -13,7 +12,7 @@ module attributes {} {
     // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
     %0 = tensor.empty() : tensor<17x16x15x14xf32>
     // CHECK: %[[C:.*]] = "ttnn.multiply"[[C:.*]]
-    %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<17x16x15x14xf32>, tensor<15x1xf32>, tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32>
+    %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<17x16x15x14xf32>, tensor<15x1xf32>, tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32>
     return %1 : tensor<17x16x15x14xf32>
   }
diff --git a/test/ttmlir/Dialect/TTNN/eltwise/operand_broadcasts_negative.mlir b/test/ttmlir/Dialect/TTNN/eltwise/operand_broadcasts_negative.mlir
index 1d26cb7f9..d0e89b66c 100644
--- a/test/ttmlir/Dialect/TTNN/eltwise/operand_broadcasts_negative.mlir
+++ b/test/ttmlir/Dialect/TTNN/eltwise/operand_broadcasts_negative.mlir
@@ -1,10 +1,9 @@
 // RUN: not ttmlir-opt --ttir-load-system-desc --ttir-layout --convert-ttir-to-ttnn %s 2>&1 | FileCheck %s
 // CHECK: error: 'ttir.multiply' op Operands are not broadcast compatible
-#any_device = #tt.operand_constraint
 module attributes {} {
   func.func @bcast_one_dim(%arg0: tensor<2x64x128xf32>, %arg1: tensor<4x64x128xf32>) -> tensor<4x64x128xf32> {
     %0 = tensor.empty() : tensor<4x64x128xf32>
-    %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<2x64x128xf32>, tensor<4x64x128xf32>, tensor<4x64x128xf32>) -> tensor<4x64x128xf32>
+    %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<2x64x128xf32>, tensor<4x64x128xf32>, tensor<4x64x128xf32>) -> tensor<4x64x128xf32>
     return %1 : tensor<4x64x128xf32>
   }
 }
diff --git a/test/ttmlir/Dialect/TTNN/eltwise/unary/abs/simple_abs.mlir b/test/ttmlir/Dialect/TTNN/eltwise/unary/abs/simple_abs.mlir
index e1b2862f7..eceb8d058 100644
--- a/test/ttmlir/Dialect/TTNN/eltwise/unary/abs/simple_abs.mlir
+++ b/test/ttmlir/Dialect/TTNN/eltwise/unary/abs/simple_abs.mlir
@@ -1,11 +1,10 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
-#any_device = #tt.operand_constraint
 module attributes {} {
   func.func @forward(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> {
     // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
     %0 = tensor.empty() : tensor<64x128xf32>
     // CHECK: %[[C:.*]] = "ttnn.abs"[[C:.*]]
-    %1 = "ttir.abs"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
+    %1 = "ttir.abs"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
     return %1 : tensor<64x128xf32>
   }
 }
diff --git a/test/ttmlir/Dialect/TTNN/eltwise/unary/cast/simple_cast.mlir b/test/ttmlir/Dialect/TTNN/eltwise/unary/cast/simple_cast.mlir
index eb3575992..e37481265 100644
--- a/test/ttmlir/Dialect/TTNN/eltwise/unary/cast/simple_cast.mlir
+++ b/test/ttmlir/Dialect/TTNN/eltwise/unary/cast/simple_cast.mlir
@@ -1,9 +1,8 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
-#any_device = #tt.operand_constraint
 module attributes {} {
   func.func @forward(%arg0: tensor<64x128xf32>) -> tensor<64x128xbf16> {
     %0 = tensor.empty() : tensor<64x128xbf16>
-    %1 = "ttir.typecast"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xbf16>) -> tensor<64x128xbf16>
+    %1 = "ttir.typecast"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xbf16>) -> tensor<64x128xbf16>
     // CHECK: %[[C:.*]] = "ttnn.typecast"
     // CHECK-SAME: tensor<64x128xf32,
     // CHECK-SAME: tensor<64x128xbf16,
@@ -13,7 +12,7 @@ module attributes {} {
   func.func @cast_fold(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> {
     // CEHCK-LABEL: func.func @cast_fold
     %0 = tensor.empty() : tensor<64x128xf32>
-    %1 = "ttir.typecast"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
+    %1 = "ttir.typecast"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
     // CHECK-NOT: typecast
     // CHECK: return %arg0 : tensor<64x128xf32
     return %1 : tensor<64x128xf32>
@@ -24,9 +23,9 @@ module attributes {} {
     // CHECK-NOT: typecast
     // CHECK: ttnn.add
     %0 = tensor.empty() : tensor<64x128xf32>
-    %1 = "ttir.typecast"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
+    %1 = "ttir.typecast"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
     %2 = tensor.empty() : tensor<64x128xf32>
-    %3 = "ttir.add"(%1, %arg1, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
+    %3 = "ttir.add"(%1, %arg1, %2) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
     return %3 : tensor<64x128xf32>
   }
 }
diff --git a/test/ttmlir/Dialect/TTNN/eltwise/unary/cbrt/simple_cbrt.mlir b/test/ttmlir/Dialect/TTNN/eltwise/unary/cbrt/simple_cbrt.mlir
index bb7254f91..bdb78fed8 100644
--- a/test/ttmlir/Dialect/TTNN/eltwise/unary/cbrt/simple_cbrt.mlir
+++ b/test/ttmlir/Dialect/TTNN/eltwise/unary/cbrt/simple_cbrt.mlir
@@ -1,11 +1,10 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
-#any_device = #tt.operand_constraint
 module attributes {} {
   func.func @forward(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> {
     // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
     %0 = tensor.empty() : tensor<64x128xf32>
     // CHECK: %[[C:.*]] = "ttnn.cbrt"[[C:.*]]
-    %1 = "ttir.cbrt"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
+    %1 = "ttir.cbrt"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
     return %1 : tensor<64x128xf32>
   }
 }
diff --git a/test/ttmlir/Dialect/TTNN/eltwise/unary/ceil/simple_ceil.mlir b/test/ttmlir/Dialect/TTNN/eltwise/unary/ceil/simple_ceil.mlir
index fb90280e3..d0250d5cd 100644
--- a/test/ttmlir/Dialect/TTNN/eltwise/unary/ceil/simple_ceil.mlir
+++ b/test/ttmlir/Dialect/TTNN/eltwise/unary/ceil/simple_ceil.mlir
@@ -1,11 +1,10 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
-#any_device = #tt.operand_constraint
 module attributes {} {
   func.func @forward(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> {
     // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
     %0 = tensor.empty() : tensor<64x128xf32>
     // CHECK: %[[C:.*]] = "ttnn.ceil"[[C:.*]]
-    %1 = "ttir.ceil"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
+    %1 = "ttir.ceil"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
     return %1 : tensor<64x128xf32>
   }
 }
diff --git a/test/ttmlir/Dialect/TTNN/eltwise/unary/cos/simple_cos.mlir b/test/ttmlir/Dialect/TTNN/eltwise/unary/cos/simple_cos.mlir
index 2e53a4f3f..e990aa59c 100644
--- a/test/ttmlir/Dialect/TTNN/eltwise/unary/cos/simple_cos.mlir
+++ b/test/ttmlir/Dialect/TTNN/eltwise/unary/cos/simple_cos.mlir
@@ -1,11 +1,10 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
-#any_device = #tt.operand_constraint
 module attributes {} {
   func.func @forward(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> {
     // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
     %0 = tensor.empty() : tensor<64x128xf32>
     // CHECK: %[[C:.*]] = "ttnn.cos"[[C:.*]]
-    %1 = "ttir.cos"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
+    %1 = "ttir.cos"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
     return %1 : tensor<64x128xf32>
   }
 }
diff --git a/test/ttmlir/Dialect/TTNN/eltwise/unary/expm1/simple_expm1.mlir b/test/ttmlir/Dialect/TTNN/eltwise/unary/expm1/simple_expm1.mlir
index 59a7b2a18..bbcbf5dd6 100644
--- a/test/ttmlir/Dialect/TTNN/eltwise/unary/expm1/simple_expm1.mlir
+++ b/test/ttmlir/Dialect/TTNN/eltwise/unary/expm1/simple_expm1.mlir
@@ -1,10 +1,9 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
-#any_device = #tt.operand_constraint
 module attributes {} {
   func.func @forward(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> {
     %0 = tensor.empty() : tensor<64x128xf32>
     // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) <{dtype = {{.*}}, layout = {{.*}}, memory_config = {{.*}}, <{{.*}}>>, shape = #ttnn.shape<[[TENSOR_SHAPE:[0-9]+x[0-9]+]]>}>
-    %1 = "ttir.expm1"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
+    %1 = "ttir.expm1"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
     // CHECK: %{{[0-9]+}} = "ttnn.expm1"(%{{[0-9]+}}, [[VAL0]]) <{operandSegmentSizes = array}> : (tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}>, tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}) -> tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}>
     return %1 : tensor<64x128xf32>
     // CHECK: return %{{[0-9]+}} : tensor<[[TENSOR_SHAPE]]xf32, {{.*}}>
diff --git a/test/ttmlir/Dialect/TTNN/eltwise/unary/floor/simple_floor.mlir b/test/ttmlir/Dialect/TTNN/eltwise/unary/floor/simple_floor.mlir
index 820e429ec..fd418fbda 100644
--- a/test/ttmlir/Dialect/TTNN/eltwise/unary/floor/simple_floor.mlir
+++ b/test/ttmlir/Dialect/TTNN/eltwise/unary/floor/simple_floor.mlir
@@ -1,5 +1,4 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
-#any_device = #tt.operand_constraint
 module attributes {} {
   func.func @floor(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> {
     // CHECK: %{{[0-9]+}} = "ttnn.empty"
@@ -9,7 +8,7 @@ module attributes {} {
     // CHECK-SAME: [[TENSOR]]
     // CHECK-SAME: [[TENSOR]]
     // CHECK-SAME: -> [[TENSOR]]
-    %1 = "ttir.floor"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
+    %1 = "ttir.floor"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
     return %1 : tensor<64x128xf32>
   }
 }
diff --git a/test/ttmlir/Dialect/TTNN/eltwise/unary/gelu/simple_gelu.mlir b/test/ttmlir/Dialect/TTNN/eltwise/unary/gelu/simple_gelu.mlir
index 0fe3e9c3b..1cec49a35 100644
--- a/test/ttmlir/Dialect/TTNN/eltwise/unary/gelu/simple_gelu.mlir
+++ b/test/ttmlir/Dialect/TTNN/eltwise/unary/gelu/simple_gelu.mlir
@@ -1,5 +1,4 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
-#any_device = #tt.operand_constraint
 module attributes {} {
   func.func @forward(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> {
     // CHECK: "ttnn.empty"
@@ -9,7 +8,7 @@ module attributes {} {
     // CHECK-SAME: tensor<64x128xf32,
     // CHECK-SAME: tensor<64x128xf32,
     // CHECK-SAME: tensor<64x128xf32,
-    %1 = "ttir.gelu"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
+    %1 = "ttir.gelu"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
     return %1 : tensor<64x128xf32>
   }
 }
diff --git a/test/ttmlir/Dialect/TTNN/eltwise/unary/isfinite/simple_isfinite.mlir b/test/ttmlir/Dialect/TTNN/eltwise/unary/isfinite/simple_isfinite.mlir
index 3089da669..7745adf06 100644
--- a/test/ttmlir/Dialect/TTNN/eltwise/unary/isfinite/simple_isfinite.mlir
+++ b/test/ttmlir/Dialect/TTNN/eltwise/unary/isfinite/simple_isfinite.mlir
@@ -1,5 +1,4 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
-#any_device = #tt.operand_constraint
 module attributes {} {
   func.func @is_finite(%arg0: tensor<64x128xbf16>) -> tensor<64x128xbf16> {
     // CHECK: %[[C:.*]] = "ttnn.empty"
@@ -9,7 +8,7 @@ module attributes {} {
     // CHECK-SAME: tensor<64x128xbf16,
     // CHECK-SAME: [[TENSOR]]
     // CHECK-SAME: -> [[TENSOR]]
-    %1 = "ttir.isfinite"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16>
+    %1 = "ttir.isfinite"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16>
     return %1 : tensor<64x128xbf16>
   }
 }
diff --git a/test/ttmlir/Dialect/TTNN/eltwise/unary/leaky_relu/simple_leaky_relu.mlir b/test/ttmlir/Dialect/TTNN/eltwise/unary/leaky_relu/simple_leaky_relu.mlir
index aa372f3c6..93dbb03f0 100644
--- a/test/ttmlir/Dialect/TTNN/eltwise/unary/leaky_relu/simple_leaky_relu.mlir
+++ b/test/ttmlir/Dialect/TTNN/eltwise/unary/leaky_relu/simple_leaky_relu.mlir
@@ -1,5 +1,4 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
-#any_device = #tt.operand_constraint
 module attributes {} {
   func.func @leaky_relu(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> {
     // CHECK: %[[C:.*]] = "ttnn.empty"
@@ -9,7 +8,7 @@ module attributes {} {
     // CHECK-SAME: tensor<64x128xf32,
     // CHECK-SAME: [[TENSOR]]
     // CHECK-SAME: -> [[TENSOR]]
-    %1 = "ttir.leaky_relu"(%arg0, %0) <{parameter = 0.01 : f32, operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
+    %1 = "ttir.leaky_relu"(%arg0, %0) <{parameter = 0.01 : f32, operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
     return %1 : tensor<64x128xf32>
   }
 }
diff --git a/test/ttmlir/Dialect/TTNN/eltwise/unary/log1p/simple_log1p.mlir b/test/ttmlir/Dialect/TTNN/eltwise/unary/log1p/simple_log1p.mlir
index b65aa3c21..4258e639c 100644
--- a/test/ttmlir/Dialect/TTNN/eltwise/unary/log1p/simple_log1p.mlir
+++ b/test/ttmlir/Dialect/TTNN/eltwise/unary/log1p/simple_log1p.mlir
@@ -1,10 +1,9 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
-#any_device = #tt.operand_constraint
 module attributes {} {
   func.func @forward(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> {
     %0 = tensor.empty() : tensor<64x128xf32>
     // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) <{dtype = {{.*}}, layout = {{.*}}, memory_config = {{.*}}, <{{.*}}>>, shape = #ttnn.shape<[[TENSOR_SHAPE:[0-9]+x[0-9]+]]>}>
-    %1 = "ttir.log1p"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
+    %1 = "ttir.log1p"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
     // CHECK: %{{[0-9]+}} = "ttnn.log1p"(%{{[0-9]+}}, [[VAL0]]) <{operandSegmentSizes = array}> : (tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}>, tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}) -> tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}>
     return %1 : tensor<64x128xf32>
     // CHECK: return %{{[0-9]+}} : tensor<[[TENSOR_SHAPE]]xf32, {{.*}}>
diff --git a/test/ttmlir/Dialect/TTNN/eltwise/unary/logical_not/simple_not.mlir b/test/ttmlir/Dialect/TTNN/eltwise/unary/logical_not/simple_not.mlir
index 54375b451..a80dffca8 100644
--- a/test/ttmlir/Dialect/TTNN/eltwise/unary/logical_not/simple_not.mlir
+++ b/test/ttmlir/Dialect/TTNN/eltwise/unary/logical_not/simple_not.mlir
@@ -1,11 +1,9 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
-
-#any_device = #tt.operand_constraint
 module attributes {} {
   func.func @logical_not(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> {
     %0 = tensor.empty() : tensor<64x128xf32>
     // CHECK: {{.*}} = "ttnn.empty"{{.*}}
-    %1 = "ttir.logical_not"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
+    %1 = "ttir.logical_not"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
     // CHECK: %[[C:.*]] = "ttnn.logical_not"
     // CHECK-SAME: tensor<64x128xf32,
     // CHECK-SAME: tensor<64x128xf32,
diff --git a/test/ttmlir/Dialect/TTNN/eltwise/unary/negate/simple_neg.mlir b/test/ttmlir/Dialect/TTNN/eltwise/unary/negate/simple_neg.mlir
index e786434a4..aa63ee6e5 100644
--- a/test/ttmlir/Dialect/TTNN/eltwise/unary/negate/simple_neg.mlir
+++ b/test/ttmlir/Dialect/TTNN/eltwise/unary/negate/simple_neg.mlir
@@ -1,11 +1,10 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
-#any_device = #tt.operand_constraint
 module attributes {} {
   func.func @forward(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> {
     // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
     %0 = tensor.empty() : tensor<64x128xf32>
     // CHECK: %[[C:.*]] = "ttnn.neg"[[C:.*]]
-    %1 = "ttir.neg"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
+    %1 = "ttir.neg"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
     return %1 : tensor<64x128xf32>
   }
 }
diff --git a/test/ttmlir/Dialect/TTNN/eltwise/unary/reciprocal/simple_reciprocal.mlir b/test/ttmlir/Dialect/TTNN/eltwise/unary/reciprocal/simple_reciprocal.mlir
index 0f940c0fc..fd98ade3e 100644
--- a/test/ttmlir/Dialect/TTNN/eltwise/unary/reciprocal/simple_reciprocal.mlir
+++ b/test/ttmlir/Dialect/TTNN/eltwise/unary/reciprocal/simple_reciprocal.mlir
@@ -1,11 +1,10 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
-#any_device = #tt.operand_constraint
 module attributes {} {
   func.func @forward(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> {
     // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
     %0 = tensor.empty() : tensor<64x128xf32>
     // CHECK: %[[C:.*]] = "ttnn.reciprocal"[[C:.*]]
-    %1 = "ttir.reciprocal"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
+    %1 = "ttir.reciprocal"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
     return %1 : tensor<64x128xf32>
   }
 }
diff --git a/test/ttmlir/Dialect/TTNN/eltwise/unary/relu/simple_relu.mlir b/test/ttmlir/Dialect/TTNN/eltwise/unary/relu/simple_relu.mlir
index 1d75b8ee0..d6b46aae6 100644
--- a/test/ttmlir/Dialect/TTNN/eltwise/unary/relu/simple_relu.mlir
+++ b/test/ttmlir/Dialect/TTNN/eltwise/unary/relu/simple_relu.mlir
@@ -1,5 +1,4 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
-#any_device = #tt.operand_constraint
 #l1 = #ttnn.buffer_type
 #system = #ttnn.buffer_type
 #ttnn_layout = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<64x128xf32, #system>>
@@ -10,7 +9,7 @@ module attributes {} {
     // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
     %0 = tensor.empty() : tensor<64x128xf32, #ttnn_layout1>
     // CHECK: %[[C:.*]] = "ttnn.relu"[[C:.*]]
-    %1 = "ttir.relu"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32, #ttnn_layout>, tensor<64x128xf32, #ttnn_layout1>) -> tensor<64x128xf32, #ttnn_layout1>
+    %1 = "ttir.relu"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32, #ttnn_layout>, tensor<64x128xf32, #ttnn_layout1>) -> tensor<64x128xf32, #ttnn_layout1>
     return %1 : tensor<64x128xf32, #ttnn_layout1>
   }
 }
diff --git a/test/ttmlir/Dialect/TTNN/eltwise/unary/rsqrt/simple_rsqrt.mlir b/test/ttmlir/Dialect/TTNN/eltwise/unary/rsqrt/simple_rsqrt.mlir
index f86d6f59e..b7a339d22 100644
--- a/test/ttmlir/Dialect/TTNN/eltwise/unary/rsqrt/simple_rsqrt.mlir
+++ b/test/ttmlir/Dialect/TTNN/eltwise/unary/rsqrt/simple_rsqrt.mlir
@@ -1,11 +1,10 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
-#any_device = #tt.operand_constraint
 module attributes {} {
   func.func @forward(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> {
     // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
     %0 = tensor.empty() : tensor<64x128xf32>
     // CHECK: %[[C:.*]] = "ttnn.rsqrt"[[C:.*]]
-    %1 = "ttir.rsqrt"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
+    %1 = "ttir.rsqrt"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
     return %1 : tensor<64x128xf32>
   }
 }
diff --git a/test/ttmlir/Dialect/TTNN/eltwise/unary/sigmoid/simple_sigmoid.mlir b/test/ttmlir/Dialect/TTNN/eltwise/unary/sigmoid/simple_sigmoid.mlir
index c88457717..d3762db91 100644
--- a/test/ttmlir/Dialect/TTNN/eltwise/unary/sigmoid/simple_sigmoid.mlir
+++ b/test/ttmlir/Dialect/TTNN/eltwise/unary/sigmoid/simple_sigmoid.mlir
@@ -1,11 +1,10 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
-#any_device = #tt.operand_constraint
 module attributes {} {
   func.func @forward(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> {
     // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
     %0 = tensor.empty() : tensor<64x128xf32>
     // CHECK: %[[C:.*]] = "ttnn.sigmoid"[[C:.*]]
-    %1 = "ttir.sigmoid"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
+    %1 = "ttir.sigmoid"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
     return %1 : tensor<64x128xf32>
   }
 }
diff --git a/test/ttmlir/Dialect/TTNN/eltwise/unary/sign/simple_sign.mlir b/test/ttmlir/Dialect/TTNN/eltwise/unary/sign/simple_sign.mlir
index c82547bff..170eb1b53 100644
--- a/test/ttmlir/Dialect/TTNN/eltwise/unary/sign/simple_sign.mlir
+++ b/test/ttmlir/Dialect/TTNN/eltwise/unary/sign/simple_sign.mlir
@@ -1,10 +1,9 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
-#any_device = #tt.operand_constraint
 module attributes {} {
   func.func @forward(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> {
     %0 = tensor.empty() : tensor<64x128xf32>
     // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) <{dtype = {{.*}}, layout = {{.*}}, memory_config = {{.*}}, <{{.*}}>>, shape = #ttnn.shape<[[TENSOR_SHAPE:[0-9]+x[0-9]+]]>}>
-    %1 = "ttir.sign"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
+    %1 = "ttir.sign"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
     // CHECK: %{{[0-9]+}} = "ttnn.sign"(%{{[0-9]+}}, [[VAL0]]) <{operandSegmentSizes = array}> : (tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}>, tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}) -> tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}>
     return %1 : tensor<64x128xf32>
     // CHECK: return %{{[0-9]+}} : tensor<[[TENSOR_SHAPE]]xf32, {{.*}}>
diff --git a/test/ttmlir/Dialect/TTNN/eltwise/unary/sin/simple_sin.mlir b/test/ttmlir/Dialect/TTNN/eltwise/unary/sin/simple_sin.mlir
index dfe4c7a18..a1ebaa368 100644
--- a/test/ttmlir/Dialect/TTNN/eltwise/unary/sin/simple_sin.mlir
+++ b/test/ttmlir/Dialect/TTNN/eltwise/unary/sin/simple_sin.mlir
@@ -1,11 +1,10 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
-#any_device = #tt.operand_constraint
 module attributes {} {
   func.func @forward(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> {
     // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
     %0 = tensor.empty() : tensor<64x128xf32>
     // CHECK: %[[C:.*]] = "ttnn.sin"[[C:.*]]
-    %1 = "ttir.sin"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
+    %1 = "ttir.sin"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
     return %1 : tensor<64x128xf32>
   }
 }
diff --git a/test/ttmlir/Dialect/TTNN/eltwise/unary/sqrt/simple_sqrt.mlir b/test/ttmlir/Dialect/TTNN/eltwise/unary/sqrt/simple_sqrt.mlir
index 3802f00da..bd468bd8e 100644
--- a/test/ttmlir/Dialect/TTNN/eltwise/unary/sqrt/simple_sqrt.mlir
+++ b/test/ttmlir/Dialect/TTNN/eltwise/unary/sqrt/simple_sqrt.mlir
@@ -1,11 +1,10 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
-#any_device = #tt.operand_constraint
 module attributes {} {
   func.func @forward(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> {
     // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
     %0 = tensor.empty() : tensor<64x128xf32>
     // CHECK: %[[C:.*]] = "ttnn.sqrt"[[C:.*]]
-    %1 = "ttir.sqrt"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
+    %1 = "ttir.sqrt"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
     return %1 : tensor<64x128xf32>
   }
 }
diff --git a/test/ttmlir/Dialect/TTNN/eltwise/unary/tan/simple_tan.mlir b/test/ttmlir/Dialect/TTNN/eltwise/unary/tan/simple_tan.mlir
index 8ae9f0bec..72d8e1416 100644
--- a/test/ttmlir/Dialect/TTNN/eltwise/unary/tan/simple_tan.mlir
+++ b/test/ttmlir/Dialect/TTNN/eltwise/unary/tan/simple_tan.mlir
@@ -1,10 +1,9 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
-#any_device = #tt.operand_constraint
 module attributes {} {
   func.func @forward(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> {
     %0 = tensor.empty() : tensor<64x128xf32>
     // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) <{dtype = {{.*}}, layout = {{.*}}, memory_config = {{.*}}, <{{.*}}>>, shape = #ttnn.shape<[[TENSOR_SHAPE:[0-9]+x[0-9]+]]>}>
-    %1 = "ttir.tan"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
+    %1 = "ttir.tan"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
     // CHECK: %{{[0-9]+}} = "ttnn.tan"(%{{[0-9]+}}, [[VAL0]]) <{operandSegmentSizes = array}> : (tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}>, tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}) -> tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}>
     return %1 : tensor<64x128xf32>
     // CHECK: return %{{[0-9]+}} : tensor<[[TENSOR_SHAPE]]xf32, {{.*}}>
diff --git a/test/ttmlir/Dialect/TTNN/eltwise/unary/tanh/simple_tanh.mlir b/test/ttmlir/Dialect/TTNN/eltwise/unary/tanh/simple_tanh.mlir
index 351476448..530b4c79b 100644
--- a/test/ttmlir/Dialect/TTNN/eltwise/unary/tanh/simple_tanh.mlir
+++ b/test/ttmlir/Dialect/TTNN/eltwise/unary/tanh/simple_tanh.mlir
@@ -1,10 +1,9 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
-#any_device = #tt.operand_constraint
 module attributes {} {
   func.func @forward(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> {
     %0 = tensor.empty() : tensor<64x128xf32>
     // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) <{dtype = {{.*}}, layout = {{.*}}, memory_config = {{.*}}, <{{.*}}>>, shape = #ttnn.shape<[[TENSOR_SHAPE:[0-9]+x[0-9]+]]>}>
-    %1 = "ttir.tanh"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
+    %1 = "ttir.tanh"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
     // CHECK: %{{[0-9]+}} = "ttnn.tanh"(%{{[0-9]+}}, [[VAL0]]) <{operandSegmentSizes = array}> : (tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}>, tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}) -> tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}>
     return %1 : tensor<64x128xf32>
     // CHECK: return %{{[0-9]+}} : tensor<[[TENSOR_SHAPE]]xf32, {{.*}}>
diff --git a/test/ttmlir/Dialect/TTNN/embedding/embedding_1d_tensor.mlir b/test/ttmlir/Dialect/TTNN/embedding/embedding_1d_tensor.mlir
index 45318423b..192697ed7 100644
--- a/test/ttmlir/Dialect/TTNN/embedding/embedding_1d_tensor.mlir
+++ b/test/ttmlir/Dialect/TTNN/embedding/embedding_1d_tensor.mlir
@@ -1,10 +1,9 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
-#any_device = #tt.operand_constraint
 module attributes {} {
   func.func @forward(%arg0: tensor<32xbf16>, %arg1: tensor<512x128xbf16>) -> tensor<32x128xbf16> {
     %0 = tensor.empty() : tensor<32x128xbf16>
     // CHECK: %[[C:.*]] = "ttnn.embedding"[[C:.*]]
-    %1 = "ttir.embedding"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32xbf16>, tensor<512x128xbf16>, tensor<32x128xbf16>) -> tensor<32x128xbf16>
+    %1 = "ttir.embedding"(%arg0, %arg1, %0) : (tensor<32xbf16>, tensor<512x128xbf16>, tensor<32x128xbf16>) -> tensor<32x128xbf16>
     return %1 : tensor<32x128xbf16>
   }
 }
diff --git a/test/ttmlir/Dialect/TTNN/embedding/embedding_non_tile.mlir b/test/ttmlir/Dialect/TTNN/embedding/embedding_non_tile.mlir
index 1d2813668..cd039a0fb 100644
--- a/test/ttmlir/Dialect/TTNN/embedding/embedding_non_tile.mlir
+++ b/test/ttmlir/Dialect/TTNN/embedding/embedding_non_tile.mlir
@@ -1,11 +1,10 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
-#any_device = #tt.operand_constraint
 module attributes {} {
   func.func @forward(%arg0: tensor<1x32xbf16>, %arg1: tensor<512x128xbf16>) -> tensor<1x32x128xbf16> {
     // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
     %0 = tensor.empty() : tensor<1x32x128xbf16>
     // CHECK: %[[C:.*]] = "ttnn.embedding"[[C:.*]]
-    %1 = "ttir.embedding"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32xbf16>, tensor<512x128xbf16>, tensor<1x32x128xbf16>) -> tensor<1x32x128xbf16>
+    %1 = "ttir.embedding"(%arg0, %arg1, %0) : (tensor<1x32xbf16>, tensor<512x128xbf16>, tensor<1x32x128xbf16>) -> tensor<1x32x128xbf16>
     return %1 : tensor<1x32x128xbf16>
   }
 }
diff --git a/test/ttmlir/Dialect/TTNN/embedding/gather_to_embedding.mlir b/test/ttmlir/Dialect/TTNN/embedding/gather_to_embedding.mlir
index 6404ee6e9..06df7506a 100644
--- a/test/ttmlir/Dialect/TTNN/embedding/gather_to_embedding.mlir
+++ b/test/ttmlir/Dialect/TTNN/embedding/gather_to_embedding.mlir
@@ -1,5 +1,4 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
-#any_device = #tt.operand_constraint
 module attributes {} {
   func.func @gather_0(%operand: tensor<32000x1024xbf16>, %start_indices: tensor<1x32xi32>) -> tensor<1x32x1024xbf16> {
     // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
@@ -13,8 +12,7 @@ module attributes {} {
         start_index_map = array,
         index_vector_dim = 1 : si64,
         slice_sizes = array,
-        indices_are_sorted = false,
-        operand_constraints = [#any_device, #any_device, #any_device]
+        indices_are_sorted = false
     } : (tensor<32000x1024xbf16>, tensor<1x32xi32>, tensor<1x32x1024xbf16>) -> tensor<1x32x1024xbf16>
     return %1 : tensor<1x32x1024xbf16>
   }
@@ -31,8 +29,7 @@ module attributes {} {
         start_index_map = array,
        index_vector_dim = 2 : si64,
         slice_sizes = array,
-        indices_are_sorted = false,
-        operand_constraints = [#any_device, #any_device, #any_device]
+        indices_are_sorted = false
     }> : (tensor<448x384xbf16>, tensor<1x2x1xi32>, tensor<1x2x384xbf16>) -> tensor<1x2x384xbf16>
     return %1 : tensor<1x2x384xbf16>
   }
@@ -49,8 +46,7 @@ module attributes {} {
         start_index_map = array,
         index_vector_dim = 2 : si64,
         slice_sizes = array,
-        indices_are_sorted = false,
-        operand_constraints = [#any_device, #any_device, #any_device]
+        indices_are_sorted = false
     }> : (tensor<51864x384xbf16>, tensor<1x2xi32>, tensor<1x2x384xbf16>) -> tensor<1x2x384xbf16>
     return %1 : tensor<1x2x384xbf16>
   }
diff --git a/test/ttmlir/Dialect/TTNN/embedding/gather_to_embedding_negative.mlir b/test/ttmlir/Dialect/TTNN/embedding/gather_to_embedding_negative.mlir
index 44ffea73e..3e3904a62 100644
--- a/test/ttmlir/Dialect/TTNN/embedding/gather_to_embedding_negative.mlir
+++ b/test/ttmlir/Dialect/TTNN/embedding/gather_to_embedding_negative.mlir
@@ -3,7 +3,6 @@
 // Verify that the parsing fails if the slice_sizes.size <= 1
 // -----
-#any_device = #tt.operand_constraint
 module attributes {} {
   func.func @negative_slice_sizes_0(%operand: tensor<32000x1024xf32>, %start_indices: tensor<1x32xi32>) -> tensor<1x32x1024xf32> {
     // CHECK: error: failed to legalize operation 'ttir.gather' that was explicitly marked illegal
@@ -16,8 +15,7 @@ module attributes {} {
         start_index_map = array,
         index_vector_dim = 1 : si64,
         slice_sizes = array,
-        indices_are_sorted = false,
-        operand_constraints = [#any_device, #any_device, #any_device]
+        indices_are_sorted = false
     } : (tensor<32000x1024xf32>, tensor<1x32xi32>, tensor<1x32x1024xf32>) -> tensor<1x32x1024xf32>
     return %1 : tensor<1x32x1024xf32>
   }
@@ -25,7 +23,6 @@ module attributes {} {
 // Verify that the parsing fails if the slice_sizes.size != [1, hiddenDim]
 // -----
-#any_device = #tt.operand_constraint
 module attributes {} {
   func.func @negative_slice_sizes_1(%operand: tensor<32000x1024xf32>, %start_indices: tensor<1x32xi32>) -> tensor<1x32x1024xf32> {
     %0 = tensor.empty() : tensor<1x32x1024xf32>
@@ -38,8 +35,7 @@ module attributes {} {
         start_index_map = array,
         index_vector_dim = 1 : si64,
         slice_sizes = array,
-        indices_are_sorted = false,
-        operand_constraints = [#any_device, #any_device, #any_device]
+        indices_are_sorted = false
     } : (tensor<32000x1024xf32>, tensor<1x32xi32>, tensor<1x32x1024xf32>) -> tensor<1x32x1024xf32>
     return %1 : tensor<1x32x1024xf32>
   }
@@ -47,7 +43,6 @@ module attributes {} {
 // Verify that the parsing fails if the offsetDims != [2]
 // -----
-#any_device = #tt.operand_constraint
 module attributes {} {
   func.func @negative_slice_sizes_0(%operand: tensor<32000x1024xf32>, %start_indices: tensor<1x32xi32>) -> tensor<1x32x1024xf32> {
     // CHECK: error: failed to legalize operation 'ttir.gather' that was explicitly marked illegal
@@ -60,8 +55,7 @@ module attributes {} {
         start_index_map = array,
         index_vector_dim = 1 : si64,
         slice_sizes = array,
-        indices_are_sorted = false,
-        operand_constraints = [#any_device, #any_device, #any_device]
+        indices_are_sorted = false
     } : (tensor<32000x1024xf32>, tensor<1x32xi32>, tensor<1x32x1024xf32>) -> tensor<1x32x1024xf32>
     return %1 : tensor<1x32x1024xf32>
   }
@@ -69,7 +63,6 @@ module attributes {} {
 // Verify that the parsing fails if collapsed_slice_dims != [0]
 // -----
-#any_device = #tt.operand_constraint
 module attributes {} {
   func.func @negative_collapsed_slice_dims(%operand: tensor<32000x1024xf32>, %start_indices: tensor<1x32xi32>) -> tensor<1x32x1024xf32> {
     %0 = tensor.empty() : tensor<1x32x1024xf32>
@@ -82,8 +75,7 @@ module attributes {} {
         start_index_map = array,
         index_vector_dim = 1 : si64,
         slice_sizes = array,
-        indices_are_sorted = false,
-        operand_constraints = [#any_device, #any_device, #any_device]
+        indices_are_sorted = false
     } : (tensor<32000x1024xf32>, tensor<1x32xi32>, tensor<1x32x1024xf32>) -> tensor<1x32x1024xf32>
     return %1 : tensor<1x32x1024xf32>
   }
@@ -91,7 +83,6 @@ module attributes {} {
 // Verify that the parsing fails slice_indices != 1 when slice_indices.size == output.size
 // -----
-#any_device = #tt.operand_constraint
 module attributes {} {
   func.func @negative_start_indices(%operand: tensor<448x384xf32>, %start_indices: tensor<1x2x2xi32>) -> tensor<1x2x384xf32> {
     %0 = tensor.empty() : tensor<1x2x384xf32>
@@ -104,8 +95,7 @@ module attributes {} {
         start_index_map = array,
         index_vector_dim = 2 : si64,
         slice_sizes = array,
-        indices_are_sorted = false,
-        operand_constraints = [#any_device, #any_device, #any_device]
+        indices_are_sorted = false
     }> : (tensor<448x384xf32>, tensor<1x2x2xi32>, tensor<1x2x384xf32>) -> tensor<1x2x384xf32>
     return %1 : tensor<1x2x384xf32>
   }
@@ -113,7 +103,6 @@ module attributes {} {
 // Verify that the parsing fails for data type other than bfloat16.
 // -----
-#any_device = #tt.operand_constraint
 module attributes {} {
   func.func @gather_0(%operand: tensor<32000x1024xf32>, %start_indices: tensor<1x32xi32>) -> tensor<1x32x1024xf32> {
     %0 = tensor.empty() : tensor<1x32x1024xf32>
@@ -126,8 +115,7 @@ module attributes {} {
         start_index_map = array,
         index_vector_dim = 1 : si64,
         slice_sizes = array,
-        indices_are_sorted = false,
-        operand_constraints = [#any_device, #any_device, #any_device]
+        indices_are_sorted = false
     } : (tensor<32000x1024xf32>, tensor<1x32xi32>, tensor<1x32x1024xf32>) -> tensor<1x32x1024xf32>
     return %1 : tensor<1x32x1024xf32>
   }
diff --git a/test/ttmlir/Dialect/TTNN/embedding/simple_embedding.mlir b/test/ttmlir/Dialect/TTNN/embedding/simple_embedding.mlir
index e5fb1421c..e55d9a879 100644
--- a/test/ttmlir/Dialect/TTNN/embedding/simple_embedding.mlir
+++ b/test/ttmlir/Dialect/TTNN/embedding/simple_embedding.mlir
@@ -1,10 +1,9 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
-#any_device = #tt.operand_constraint
 module attributes {} {
   func.func @forward(%arg0: tensor<32x32xbf16>, %arg1: tensor<512x128xbf16>) -> tensor<32x32x128xbf16> {
     %0 = tensor.empty() : tensor<32x32x128xbf16>
     // CHECK: %[[C:.*]] = "ttnn.embedding"[[C:.*]]
-    %1 = "ttir.embedding"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x32xbf16>, tensor<512x128xbf16>, tensor<32x32x128xbf16>) -> tensor<32x32x128xbf16>
+    %1 = "ttir.embedding"(%arg0, %arg1, %0) : (tensor<32x32xbf16>, tensor<512x128xbf16>, tensor<32x32x128xbf16>) -> tensor<32x32x128xbf16>
     return %1 : tensor<32x32x128xbf16>
   }
 }
diff --git a/test/ttmlir/Dialect/TTNN/linear/linear_tests_positive.mlir b/test/ttmlir/Dialect/TTNN/linear/linear_tests_positive.mlir
index 0e248623d..ef0a6729e 100644
--- a/test/ttmlir/Dialect/TTNN/linear/linear_tests_positive.mlir
+++ b/test/ttmlir/Dialect/TTNN/linear/linear_tests_positive.mlir
@@ -1,5 +1,4 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
-#any_device_tile = #tt.operand_constraint
 module {
   func.func @linear_1d_1d(%arg0: tensor<128xbf16>, %arg1: tensor<128xbf16>) -> tensor<1xbf16> {
     // CHECK: "ttnn.empty"
@@ -10,7 +9,7 @@ module {
     // CHECK-SAME: tensor<128xbf16
     // CHECK-SAME: tensor<1xbf16
     // CHECK-SAME: tensor<1xbf16
-    %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<128xbf16>, tensor<128xbf16>, tensor<1xbf16>) -> tensor<1xbf16>
+    %1 = "ttir.linear"(%arg0, %arg1, %0) : (tensor<128xbf16>, tensor<128xbf16>, tensor<1xbf16>) -> tensor<1xbf16>
     return %1 : tensor<1xbf16>
   }
@@ -24,7 +23,7 @@ module {
     // CHECK-SAME: tensor<1xbf16
     // CHECK-SAME: tensor<1xbf16
     // CHECK-SAME: tensor<1xbf16
-    %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<128xbf16>, tensor<128xbf16>, tensor<1xbf16>, tensor<1xbf16>) -> tensor<1xbf16>
+    %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) : (tensor<128xbf16>, tensor<128xbf16>, tensor<1xbf16>, tensor<1xbf16>) -> tensor<1xbf16>
     return %1 : tensor<1xbf16>
   }
@@ -38,7 +37,7 @@ module {
     // CHECK-SAME: tensor<128xbf16
     // CHECK-SAME: tensor<128xbf16
     // CHECK-SAME: tensor<128xbf16
-    %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<128xbf16>, tensor<128xbf16>, tensor<128xbf16>, tensor<128xbf16>) -> tensor<128xbf16>
+    %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) : (tensor<128xbf16>, tensor<128xbf16>, tensor<128xbf16>, tensor<128xbf16>) -> tensor<128xbf16>
     return %1 : tensor<128xbf16>
   }
@@ -51,7 +50,7 @@ module {
     // CHECK-SAME: tensor<128xbf16
     // CHECK-SAME: tensor<64xbf16
     // CHECK-SAME: tensor<64xbf16
-    %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128xbf16>, tensor<64xbf16>) -> tensor<64xbf16>
+    %1 = "ttir.linear"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<128xbf16>, tensor<64xbf16>) -> tensor<64xbf16>
     return %1 : tensor<64xbf16>
   }
@@ -64,7 +63,7 @@ module {
     // CHECK-SAME: tensor<128x64xbf16
     // CHECK-SAME: tensor<64x64xbf16
     // CHECK-SAME: tensor<64x64xbf16
-    %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16>
+    %1 = "ttir.linear"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16>
     return %1 : tensor<64x64xbf16>
   }
@@ -78,7 +77,7 @@ module {
     // CHECK-SAME: tensor<64x64xbf16
     // CHECK-SAME: tensor<64x64xbf16
     // CHECK-SAME: tensor<64x64xbf16
-    %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16>
+    %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16>
     return %1 : tensor<64x64xbf16>
   }
@@ -91,7 +90,7 @@ module {
     // CHECK-SAME: tensor<12x7x128x64xbf16
     // CHECK-SAME: tensor<12x7x64xbf16
     // CHECK-SAME: tensor<12x7x64xbf16
-    %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<128xbf16>, tensor<12x7x128x64xbf16>, tensor<12x7x64xbf16>) -> tensor<12x7x64xbf16>
+    %1 = "ttir.linear"(%arg0, %arg1, %0) : (tensor<128xbf16>, tensor<12x7x128x64xbf16>, tensor<12x7x64xbf16>) -> tensor<12x7x64xbf16>
     return %1 : tensor<12x7x64xbf16>
   }
@@ -104,7 +103,7 @@ module {
     // CHECK-SAME: tensor<64xbf16
     // CHECK-SAME: tensor<12x7x128xbf16
     // CHECK-SAME: tensor<12x7x128xbf16
-    %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<12x7x128x64xbf16>, tensor<64xbf16>, tensor<12x7x128xbf16>) -> tensor<12x7x128xbf16>
+    %1 = "ttir.linear"(%arg0, %arg1, %0) : (tensor<12x7x128x64xbf16>, tensor<64xbf16>, tensor<12x7x128xbf16>) -> tensor<12x7x128xbf16>
     return %1 : tensor<12x7x128xbf16>
   }
@@ -117,7 +116,7 @@ module {
     // CHECK-SAME: tensor<12x7x128x64xbf16
     // CHECK-SAME: tensor<12x7x64x64xbf16
     // CHECK-SAME: tensor<12x7x64x64xbf16
-    %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<12x7x128x64xbf16>, tensor<12x7x64x64xbf16>) -> tensor<12x7x64x64xbf16>
+    %1 = "ttir.linear"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<12x7x128x64xbf16>, tensor<12x7x64x64xbf16>) -> tensor<12x7x64x64xbf16>
     return %1 : tensor<12x7x64x64xbf16>
   }
@@ -130,7 +129,7 @@ module {
     // CHECK-SAME: tensor<64x128xbf16
     // CHECK-SAME: tensor<12x7x128x128xbf16
     // CHECK-SAME: tensor<12x7x128x128xbf16
-    %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<12x7x128x64xbf16>, tensor<64x128xbf16>, tensor<12x7x128x128xbf16>) -> tensor<12x7x128x128xbf16>
+    %1 = "ttir.linear"(%arg0, %arg1, %0) : (tensor<12x7x128x64xbf16>, tensor<64x128xbf16>, tensor<12x7x128x128xbf16>) -> tensor<12x7x128x128xbf16>
     return %1 : tensor<12x7x128x128xbf16>
   }
@@ -144,7 +143,7 @@ module {
     // CHECK-SAME: tensor<7x128x64xbf16
     // CHECK-SAME: tensor<7x64x64xbf16
     // CHECK-SAME: tensor<7x64x64xbf16
-    %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<7x64x128xbf16>, tensor<7x128x64xbf16>, tensor<7x64x64xbf16>) -> tensor<7x64x64xbf16>
+    %1 = "ttir.linear"(%arg0, %arg1, %0) : (tensor<7x64x128xbf16>, tensor<7x128x64xbf16>, tensor<7x64x64xbf16>) -> tensor<7x64x64xbf16>
     return %1 : tensor<7x64x64xbf16>
   }
@@ -157,7 +156,7 @@ module {
     // CHECK-SAME: tensor<1x128x64xbf16
     // CHECK-SAME: tensor<7x64x64xbf16
     // CHECK-SAME: tensor<7x64x64xbf16
-    %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<7x64x128xbf16>, tensor<1x128x64xbf16>, tensor<7x64x64xbf16>) -> tensor<7x64x64xbf16>
+    %1 = "ttir.linear"(%arg0, %arg1, %0) : (tensor<7x64x128xbf16>, tensor<1x128x64xbf16>, tensor<7x64x64xbf16>) -> tensor<7x64x64xbf16>
     return %1 : tensor<7x64x64xbf16>
   }
@@ -170,7 +169,7 @@ module {
     // CHECK-SAME: tensor<7x1x128x64xbf16
     // CHECK-SAME: tensor<7x7x64x64xbf16
     // CHECK-SAME: tensor<7x7x64x64xbf16
-    %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<1x7x64x128xbf16>, tensor<7x1x128x64xbf16>, tensor<7x7x64x64xbf16>) -> tensor<7x7x64x64xbf16>
+    %1 = "ttir.linear"(%arg0, %arg1, %0) : (tensor<1x7x64x128xbf16>, tensor<7x1x128x64xbf16>, tensor<7x7x64x64xbf16>) -> tensor<7x7x64x64xbf16>
     return %1 : tensor<7x7x64x64xbf16>
   }
@@ -183,7 +182,7 @@ module {
     // CHECK-SAME: tensor<7x1x128x64xbf16
     // CHECK-SAME: tensor<12x7x7x64x64xbf16
     // CHECK-SAME: tensor<12x7x7x64x64xbf16
-    %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<12x1x7x64x128xbf16>, tensor<7x1x128x64xbf16>, tensor<12x7x7x64x64xbf16>) -> tensor<12x7x7x64x64xbf16>
+    %1 = "ttir.linear"(%arg0, %arg1, %0) : (tensor<12x1x7x64x128xbf16>, tensor<7x1x128x64xbf16>, tensor<12x7x7x64x64xbf16>) -> tensor<12x7x7x64x64xbf16>
     return %1 : tensor<12x7x7x64x64xbf16>
   }
@@ -197,7 +196,7 @@ module {
     // CHECK-SAME: tensor<64xbf16
     // CHECK-SAME: tensor<14x7x32x64xbf16
     // CHECK-SAME: tensor<14x7x32x64xbf16
-    %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<14x7x32x32xbf16>, tensor<14x1x32x64xbf16>, tensor<64xbf16>, tensor<14x7x32x64xbf16>) -> tensor<14x7x32x64xbf16>
+    %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) : (tensor<14x7x32x32xbf16>, tensor<14x1x32x64xbf16>, tensor<64xbf16>, tensor<14x7x32x64xbf16>) -> tensor<14x7x32x64xbf16>
     return %1 : tensor<14x7x32x64xbf16>
   }
@@ -210,7 +209,7 @@ module {
     // CHECK-SAME: tensor<4x3x128x32xbf16
     // CHECK-SAME: tensor<14x4x3x64x32xbf16
     // CHECK-SAME: tensor<14x4x3x64x32xbf16
-    %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<3x64x128xbf16>, tensor<4x3x128x32xbf16>, tensor<14x4x3x64x32xbf16>, tensor<14x4x3x64x32xbf16>) -> tensor<14x4x3x64x32xbf16>
+    %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) : (tensor<3x64x128xbf16>, tensor<4x3x128x32xbf16>, tensor<14x4x3x64x32xbf16>,
tensor<14x4x3x64x32xbf16>) -> tensor<14x4x3x64x32xbf16> return %1 : tensor<14x4x3x64x32xbf16> } } diff --git a/test/ttmlir/Dialect/TTNN/linear/simple_linear.mlir b/test/ttmlir/Dialect/TTNN/linear/simple_linear.mlir index 56728eb52..44165e05d 100644 --- a/test/ttmlir/Dialect/TTNN/linear/simple_linear.mlir +++ b/test/ttmlir/Dialect/TTNN/linear/simple_linear.mlir @@ -1,5 +1,4 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s -#any_device_tile = #tt.operand_constraint module { func.func @simple_linear_without_bias(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64x64xbf16> { @@ -11,7 +10,7 @@ module { // CHECK-SAME: tensor<128x64xbf16 // CHECK-SAME: tensor<64x64xbf16 // CHECK-SAME: tensor<64x64xbf16 - %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> return %1 : tensor<64x64xbf16> } @@ -25,7 +24,7 @@ module { // CHECK-SAME: tensor<64x64xbf16 // CHECK-SAME: tensor<64x64xbf16 // CHECK-SAME: tensor<64x64xbf16 - %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> return %1 : tensor<64x64xbf16> } } diff --git a/test/ttmlir/Dialect/TTNN/matmul/matmul_tests_negative.mlir b/test/ttmlir/Dialect/TTNN/matmul/matmul_tests_negative.mlir index 252aaea9f..7ca7efeec 100644 --- a/test/ttmlir/Dialect/TTNN/matmul/matmul_tests_negative.mlir +++ b/test/ttmlir/Dialect/TTNN/matmul/matmul_tests_negative.mlir @@ -2,7 +2,6 @@ // Negative tests for matmul operation // Verify that the parsing fails if either of operands is a scalar -#any_device_tile = #tt.operand_constraint module attributes {} { func.func @matmul_negative_1d_1d_inner_dimension_missmatch(%arg0: tensor, %arg1: tensor<64xbf16>) -> tensor<1xbf16> { // CHECK: error: 'ttnn.matmul' op Input A must be at least a 1D tensor @@ -13,7 +12,6 @@ module attributes {} { } // ----- -#any_device_tile = #tt.operand_constraint module attributes {} { func.func @matmul_negative_1d_1d_inner_dimension_missmatch(%arg0: tensor<128xbf16>, %arg1: tensor) -> tensor<1xbf16> { // CHECK: error: 'ttnn.matmul' op Input B must be at least a 1D tensor @@ -23,9 +21,8 @@ module attributes {} { } } -// Verifty that the parsing fails if the output is a scalar +// Verify that the parsing fails if the output is a scalar // ----- -#any_device_tile = #tt.operand_constraint module attributes {} { func.func @matmul_negative_1d_1d_inner_dimension_missmatch(%arg0: tensor<128xbf16>, %arg1: tensor<128xbf16>) -> tensor { // CHECK: error: 'ttnn.matmul' op Scalar output is not supported, output must be at least a 1D tensor @@ -36,7 +33,6 @@ module attributes {} { } // ----- -#any_device_tile = #tt.operand_constraint module attributes {} { func.func @matmul_negative_1d_1d_inner_dimension_missmatch(%arg0: tensor<128xbf16>, %arg1: tensor<128xbf16>) -> tensor<2xbf16> { // CHECK: error: 'ttnn.matmul' op Scalar output must be a 1D tensor of size 1 @@ -48,7 +44,6 @@ module attributes {} { // Inner dimension mismatch tests // ----- -#any_device_tile = 
#tt.operand_constraint module attributes {} { func.func @matmul_negative_1d_1d_inner_dimension_missmatch(%arg0: tensor<128xbf16>, %arg1: tensor<64xbf16>) -> tensor<1xbf16> { // CHECK: error: 'ttnn.matmul' op Input A[-1](128) and B[-2](64) must have matching inner dimensions @@ -59,7 +54,6 @@ module attributes {} { } // ----- -#any_device_tile = #tt.operand_constraint module attributes {} { func.func @matmul_negative_1d_2d_inner_dimension_missmatch(%arg0: tensor<64xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64xbf16> { // CHECK: error: 'ttnn.matmul' op Input A[-1](64) and B[-2](128) must have matching inner dimensions @@ -70,7 +64,6 @@ func.func @matmul_negative_1d_2d_inner_dimension_missmatch(%arg0: tensor<64xbf16 } // ----- -#any_device_tile = #tt.operand_constraint module attributes {} { func.func @matmul_negative_2d_1d_inner_dimension_missmatch(%arg0: tensor<64x128xbf16>, %arg1: tensor<64xbf16>) -> tensor<64xbf16> { // CHECK: error: 'ttnn.matmul' op Input A[-1](128) and B[-2](64) must have matching inner dimensions @@ -81,7 +74,6 @@ module attributes {} { } // ----- -#any_device_tile = #tt.operand_constraint module attributes {} { func.func @matmul_negative_2d_2d_inner_dimension_missmatch(%arg0: tensor<64x128xbf16>, %arg1: tensor<64x128xbf16>) -> tensor<64x64xbf16> { // CHECK: error: 'ttnn.matmul' op Input A[-1](128) and B[-2](64) must have matching inner dimensions @@ -92,7 +84,6 @@ module attributes {} { } // ----- -#any_device_tile = #tt.operand_constraint module attributes {} { func.func @matmul_negative_nd_nd_inner_dimension_missmatch(%arg0: tensor<7x64x128xbf16>, %arg1: tensor<1x64x128xbf16>) -> tensor<7x64x64xbf16> { // CHECK: error: 'ttnn.matmul' op Input A[-1](128) and B[-2](64) must have matching inner dimensions @@ -104,7 +95,6 @@ module attributes {} { // Batch dimension mismatch tests // ----- -#any_device_tile = #tt.operand_constraint module attributes {} { func.func @matmul_negative_nd_nd_same_rank_batch_broadcast_incompatible_1(%arg0: tensor<7x64x128xbf16>, %arg1: tensor<2x128x64xbf16>) -> tensor<7x64x64xbf16> { // CHECK: error: 'ttnn.matmul' op Batch dimensions of input A(7) and B(2) are not broadcast compatible @@ -115,7 +105,6 @@ module attributes {} { } // ----- -#any_device_tile = #tt.operand_constraint module attributes {} { func.func @matmul_negative_nd_nd_same_rank_batch_broadcast_incompatible_2(%arg0: tensor<2x7x64x128xbf16>, %arg1: tensor<7x1x128x64xbf16>) -> tensor<7x7x64x64xbf16> { // CHECK: error: 'ttnn.matmul' op Batch dimensions of input A(2,7) and B(7,1) are not broadcast compatible @@ -126,7 +115,6 @@ module attributes {} { } // ----- -#any_device_tile = #tt.operand_constraint module attributes {} { func.func @matmul_negative_nd_nd_different_rank_batch_broadcast_incompatible(%arg0: tensor<12x2x7x64x128xbf16>, %arg1: tensor<7x1x128x64xbf16>) -> tensor<12x7x7x64x64xbf16> { // CHECK: error: 'ttnn.matmul' op Batch dimensions of input A(12,2,7) and B(7,1) are not broadcast compatible @@ -138,7 +126,6 @@ module attributes {} { // Output shape mismatch tests // ----- -#any_device_tile = #tt.operand_constraint module attributes {} { func.func @matmul_negative_2d_2d_inner_dimension_missmatch(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64xbf16> { // CHECK: error: 'ttnn.matmul' op Output shape rank(1) must match the expected output shape rank(2) @@ -149,7 +136,6 @@ module attributes {} { } // ----- -#any_device_tile = #tt.operand_constraint module attributes {} { func.func @matmul_negative_2d_2d_inner_dimension_missmatch(%arg0: 
tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64x128xbf16> { // CHECK: error: 'ttnn.matmul' op Output shape dimension[1](128) doesn't match the expected output shape dimension[1](64) diff --git a/test/ttmlir/Dialect/TTNN/matmul/matmul_tests_positive.mlir b/test/ttmlir/Dialect/TTNN/matmul/matmul_tests_positive.mlir index c1921ce8b..a62e53211 100644 --- a/test/ttmlir/Dialect/TTNN/matmul/matmul_tests_positive.mlir +++ b/test/ttmlir/Dialect/TTNN/matmul/matmul_tests_positive.mlir @@ -1,59 +1,58 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s -#any_device_tile = #tt.operand_constraint module attributes {} { func.func @matmul_1d_1d(%arg0: tensor<128xbf16>, %arg1: tensor<128xbf16>) -> tensor<1xbf16> { %0 = tensor.empty() : tensor<1xbf16> // CHECK: %[[C:.*]] = "ttnn.matmul"[[C:.*]] - %1 = "ttir.matmul"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<128xbf16>, tensor<128xbf16>, tensor<1xbf16>) -> tensor<1xbf16> + %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<128xbf16>, tensor<128xbf16>, tensor<1xbf16>) -> tensor<1xbf16> return %1 : tensor<1xbf16> } func.func @matmul_1d_2d(%arg0: tensor<128xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64xbf16> { %0 = tensor.empty() : tensor<64xbf16> // CHECK: %[[C:.*]] = "ttnn.matmul"[[C:.*]] - %1 = "ttir.matmul"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<128xbf16>, tensor<128x64xbf16>, tensor<64xbf16>) -> tensor<64xbf16> + %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<128xbf16>, tensor<128x64xbf16>, tensor<64xbf16>) -> tensor<64xbf16> return %1 : tensor<64xbf16> } func.func @matmul_2d_1d(%arg0: tensor<64x128xbf16>, %arg1: tensor<128xbf16>) -> tensor<64xbf16> { %0 = tensor.empty() : tensor<64xbf16> // CHECK: %[[C:.*]] = "ttnn.matmul"[[C:.*]] - %1 = "ttir.matmul"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128xbf16>, tensor<64xbf16>) -> tensor<64xbf16> + %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<128xbf16>, tensor<64xbf16>) -> tensor<64xbf16> return %1 : tensor<64xbf16> } func.func @matmul_2d_2d(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64x64xbf16> { %0 = tensor.empty() : tensor<64x64xbf16> // CHECK: %[[C:.*]] = "ttnn.matmul"[[C:.*]] - %1 = "ttir.matmul"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> + %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> return %1 : tensor<64x64xbf16> } func.func @matmul_1d_nd(%arg0: tensor<128xbf16>, %arg1: tensor<12x7x128x64xbf16>) -> tensor<12x7x64xbf16> { %0 = tensor.empty() : tensor<12x7x64xbf16> // CHECK: %[[C:.*]] = "ttnn.matmul"[[C:.*]] - %1 = "ttir.matmul"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<128xbf16>, tensor<12x7x128x64xbf16>, tensor<12x7x64xbf16>) -> tensor<12x7x64xbf16> + %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<128xbf16>, tensor<12x7x128x64xbf16>, tensor<12x7x64xbf16>) -> tensor<12x7x64xbf16> return %1 : tensor<12x7x64xbf16> } func.func @matmul_nd_1d(%arg0: tensor<12x7x128x64xbf16>, %arg1: tensor<64xbf16>) -> tensor<12x7x128xbf16> { %0 = tensor.empty() : tensor<12x7x128xbf16> // CHECK: %[[C:.*]] = "ttnn.matmul"[[C:.*]] - %1 = 
"ttir.matmul"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<12x7x128x64xbf16>, tensor<64xbf16>, tensor<12x7x128xbf16>) -> tensor<12x7x128xbf16> + %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<12x7x128x64xbf16>, tensor<64xbf16>, tensor<12x7x128xbf16>) -> tensor<12x7x128xbf16> return %1 : tensor<12x7x128xbf16> } func.func @matmul_2d_nd(%arg0: tensor<64x128xbf16>, %arg1: tensor<12x7x128x64xbf16>) -> tensor<12x7x64x64xbf16> { %0 = tensor.empty() : tensor<12x7x64x64xbf16> // CHECK: %[[C:.*]] = "ttnn.matmul"[[C:.*]] - %1 = "ttir.matmul"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<12x7x128x64xbf16>, tensor<12x7x64x64xbf16>) -> tensor<12x7x64x64xbf16> + %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<12x7x128x64xbf16>, tensor<12x7x64x64xbf16>) -> tensor<12x7x64x64xbf16> return %1 : tensor<12x7x64x64xbf16> } func.func @matmul_nd_2d(%arg0: tensor<12x7x128x64xbf16>, %arg1: tensor<64x128xbf16>) -> tensor<12x7x128x128xbf16> { %0 = tensor.empty() : tensor<12x7x128x128xbf16> // CHECK: %[[C:.*]] = "ttnn.matmul"[[C:.*]] - %1 = "ttir.matmul"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<12x7x128x64xbf16>, tensor<64x128xbf16>, tensor<12x7x128x128xbf16>) -> tensor<12x7x128x128xbf16> + %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<12x7x128x64xbf16>, tensor<64x128xbf16>, tensor<12x7x128x128xbf16>) -> tensor<12x7x128x128xbf16> return %1 : tensor<12x7x128x128xbf16> } @@ -61,28 +60,28 @@ module attributes {} { func.func @matmul_nd_nd_same_rank_same_dims(%arg0: tensor<7x64x128xbf16>, %arg1: tensor<7x128x64xbf16>) -> tensor<7x64x64xbf16> { %0 = tensor.empty() : tensor<7x64x64xbf16> // CHECK: %[[C:.*]] = "ttnn.matmul"[[C:.*]] - %1 = "ttir.matmul"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<7x64x128xbf16>, tensor<7x128x64xbf16>, tensor<7x64x64xbf16>) -> tensor<7x64x64xbf16> + %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<7x64x128xbf16>, tensor<7x128x64xbf16>, tensor<7x64x64xbf16>) -> tensor<7x64x64xbf16> return %1 : tensor<7x64x64xbf16> } func.func @matmul_nd_nd_same_rank_broadcastable_dims_1(%arg0: tensor<7x64x128xbf16>, %arg1: tensor<1x128x64xbf16>) -> tensor<7x64x64xbf16> { %0 = tensor.empty() : tensor<7x64x64xbf16> // CHECK: %[[C:.*]] = "ttnn.matmul"[[C:.*]] - %1 = "ttir.matmul"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<7x64x128xbf16>, tensor<1x128x64xbf16>, tensor<7x64x64xbf16>) -> tensor<7x64x64xbf16> + %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<7x64x128xbf16>, tensor<1x128x64xbf16>, tensor<7x64x64xbf16>) -> tensor<7x64x64xbf16> return %1 : tensor<7x64x64xbf16> } func.func @matmul_nd_nd_same_rank_broadcastable_dims_2(%arg0: tensor<1x7x64x128xbf16>, %arg1: tensor<7x1x128x64xbf16>) -> tensor<7x7x64x64xbf16> { %0 = tensor.empty() : tensor<7x7x64x64xbf16> // CHECK: %[[C:.*]] = "ttnn.matmul"[[C:.*]] - %1 = "ttir.matmul"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<1x7x64x128xbf16>, tensor<7x1x128x64xbf16>, tensor<7x7x64x64xbf16>) -> tensor<7x7x64x64xbf16> + %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<1x7x64x128xbf16>, tensor<7x1x128x64xbf16>, tensor<7x7x64x64xbf16>) -> tensor<7x7x64x64xbf16> return %1 : tensor<7x7x64x64xbf16> } func.func 
@matmul_nd_nd_different_rank_broadcastable_dims_2(%arg0: tensor<12x1x7x64x128xbf16>, %arg1: tensor<7x1x128x64xbf16>) -> tensor<12x7x7x64x64xbf16> { %0 = tensor.empty() : tensor<12x7x7x64x64xbf16> // CHECK: %[[C:.*]] = "ttnn.matmul"[[C:.*]] - %1 = "ttir.matmul"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<12x1x7x64x128xbf16>, tensor<7x1x128x64xbf16>, tensor<12x7x7x64x64xbf16>) -> tensor<12x7x7x64x64xbf16> + %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<12x1x7x64x128xbf16>, tensor<7x1x128x64xbf16>, tensor<12x7x7x64x64xbf16>) -> tensor<12x7x7x64x64xbf16> return %1 : tensor<12x7x7x64x64xbf16> } } diff --git a/test/ttmlir/Dialect/TTNN/matmul/simple_matmul.mlir b/test/ttmlir/Dialect/TTNN/matmul/simple_matmul.mlir index f82ed8575..87db65078 100644 --- a/test/ttmlir/Dialect/TTNN/matmul/simple_matmul.mlir +++ b/test/ttmlir/Dialect/TTNN/matmul/simple_matmul.mlir @@ -1,11 +1,10 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s -#any_device_tile = #tt.operand_constraint // CHECK: #[[TILED_LAYOUT:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<2x4x!tt.tile<32x32, bf16>, #dram>, > module attributes {} { func.func @forward(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x96xbf16>) -> tensor<64x96xbf16> { %0 = tensor.empty() : tensor<64x96xbf16> // CHECK: %[[C:.*]] = "ttnn.matmul"[[C:.*]] - %1 = "ttir.matmul"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128x96xbf16>, tensor<64x96xbf16>) -> tensor<64x96xbf16> + %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<128x96xbf16>, tensor<64x96xbf16>) -> tensor<64x96xbf16> return %1 : tensor<64x96xbf16> } } diff --git a/test/ttmlir/Dialect/TTNN/multiple_func.mlir b/test/ttmlir/Dialect/TTNN/multiple_func.mlir index f23863e5b..3961fac03 100644 --- a/test/ttmlir/Dialect/TTNN/multiple_func.mlir +++ b/test/ttmlir/Dialect/TTNN/multiple_func.mlir @@ -1,5 +1,4 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s -#any_device = #tt.operand_constraint module attributes {} { func.func @main(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> { // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] @@ -10,7 +9,7 @@ module attributes {} { } func.func private @do_mult(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>, %arg2: tensor<64x128xf32>) -> tensor<64x128xf32> { - %0 = "ttir.multiply"(%arg0, %arg1, %arg2) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %0 = "ttir.multiply"(%arg0, %arg1, %arg2) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %0 : tensor<64x128xf32> } } diff --git a/test/ttmlir/Dialect/TTNN/optimizer/input_layout_loc_override.mlir b/test/ttmlir/Dialect/TTNN/optimizer/input_layout_loc_override.mlir index 4a4575f8d..97892500a 100644 --- a/test/ttmlir/Dialect/TTNN/optimizer/input_layout_loc_override.mlir +++ b/test/ttmlir/Dialect/TTNN/optimizer/input_layout_loc_override.mlir @@ -1,5 +1,4 @@ // RUN: ttmlir-opt --mlir-print-debuginfo --ttir-to-ttnn-backend-pipeline="enable-optimizer=true memory-layout-analysis-enabled=true override-output-layout=matmul_1_in_1_layout=1x1:l1:interleaved:tile:bf16" %s | FileCheck %s -#any_device = #tt.operand_constraint #loc = loc("Matmul":4294967295:0) // CHECK-DAG: #[[LOC_MATMUL_IN0:.*]] = 
loc("matmul_1_in_0_layout"(#loc3)) // CHECK-DAG: #[[LOC_MATMUL_IN1:.*]] = loc("matmul_1_in_1_layout"(#loc3)) @@ -12,7 +11,7 @@ module attributes {} { // CHECK-DAG: %{{.*}} = "ttnn.to_device"{{.*}} loc(#[[LOC_MATMUL_IN0]]) // CHECK-DAG: %{{.*}} = "ttnn.to_device"{{.*}} <{memory_config = #ttnn.memory_config<#l1_, <<4x3>>, >}> : {{.*}} -> tensor<128x96xbf16, #[[IN_1_LAYOUT]]> loc(#[[LOC_MATMUL_IN1]]) // CHECK-DAG: %{{.*}} = "ttnn.matmul"{{.*}} loc(#[[LOC_MATMUL]]) - %1 = "ttir.matmul"(%arg0, %arg1, %0) <{operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xbf16>, tensor<128x96xbf16>, tensor<64x96xbf16>) -> tensor<64x96xbf16> loc(#loc2) + %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<128x96xbf16>, tensor<64x96xbf16>) -> tensor<64x96xbf16> loc(#loc2) return %1 : tensor<64x96xbf16> } loc(#loc) } loc(#loc) diff --git a/test/ttmlir/Dialect/TTNN/optimizer/insert_memreconfig_override.mlir b/test/ttmlir/Dialect/TTNN/optimizer/insert_memreconfig_override.mlir index ec03a6ad5..6989e765f 100644 --- a/test/ttmlir/Dialect/TTNN/optimizer/insert_memreconfig_override.mlir +++ b/test/ttmlir/Dialect/TTNN/optimizer/insert_memreconfig_override.mlir @@ -1,5 +1,4 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="enable-optimizer=true memory-layout-analysis-enabled=true memreconfig-enabled=true insert-memreconfig=add_0_1_2=0 override-output-layout=add_1_2=1x1:dram:interleaved:row_major:f32" %s | FileCheck %s -#any_device = #tt.operand_constraint #loc = loc("test_ops.py:17_0_0":0:0) module attributes {} { func.func @main(%arg0: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg1: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg2: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0)) -> tensor<1x32x32xf32> { @@ -8,13 +7,13 @@ module attributes {} { // CHECK-DAG: #[[LAYOUT_2:.*]] = #ttnn.ttnn_layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), <1x1>, memref<32x32xf32, #dram>, > %0 = tensor.empty() : tensor<1x32x32xf32> loc(#loc5) // CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<1x32x32xf32, #[[LAYOUT_2]]> - %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc5) + %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc5) %2 = tensor.empty() : tensor<1x32x32xf32> loc(#loc6) // CHECK: %[[IDX:.*]] = "ttnn.to_memory_config"{{.*}} -> tensor<1x32x32xf32, #[[LAYOUT_1]]> // CHECK: %{{.*}} = "ttnn.add"(%[[IDX]]{{.*}} - %3 = "ttir.add"(%1, %arg0, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc6) + %3 = "ttir.add"(%1, %arg0, %2) <{operandSegmentSizes = array}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc6) %4 = tensor.empty() : tensor<1x32x32xf32> loc(#loc7) - %5 = "ttir.relu"(%3, %4) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc7) + %5 = "ttir.relu"(%3, %4) <{operandSegmentSizes = array}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc7) return %5 : tensor<1x32x32xf32> loc(#loc4) } loc(#loc) } loc(#loc) diff --git 
a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/all_l1_interleaved_policy.mlir b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/all_l1_interleaved_policy.mlir index 5c34fe854..a895ca25e 100644 --- a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/all_l1_interleaved_policy.mlir +++ b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/all_l1_interleaved_policy.mlir @@ -1,27 +1,26 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="enable-optimizer=true memory-layout-analysis-enabled=true memory-layout-analysis-policy=L1Interleaved" %s | FileCheck %s -#any_device = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x96xbf16>, %arg2: tensor<64x96xbf16>, %arg3: tensor<96x32xbf16>, %arg4: tensor<64x32xbf16>) -> tensor<64x32xbf16> { // CHECK: #[[L1_:.*]] = #ttnn.buffer_type // CHECK: #[[LAYOUT_L1:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <{{.*}}>, memref<{{.*}}, #l1_>, > %0 = tensor.empty() : tensor<64x96xbf16> // CHECK: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<64x96xbf16, #[[LAYOUT_L1]]> - %1 = "ttir.matmul"(%arg0, %arg1, %0) <{operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xbf16>, tensor<128x96xbf16>, tensor<64x96xbf16>) -> tensor<64x96xbf16> + %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<128x96xbf16>, tensor<64x96xbf16>) -> tensor<64x96xbf16> %2 = tensor.empty() : tensor<64x96xbf16> // CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<64x96xbf16, #[[LAYOUT_L1]]> - %3 = "ttir.add"(%1, %arg2, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x96xbf16>, tensor<64x96xbf16>, tensor<64x96xbf16>) -> tensor<64x96xbf16> + %3 = "ttir.add"(%1, %arg2, %2) <{operandSegmentSizes = array}> : (tensor<64x96xbf16>, tensor<64x96xbf16>, tensor<64x96xbf16>) -> tensor<64x96xbf16> %4 = tensor.empty() : tensor<64x96xbf16> // CHECK: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<64x96xbf16, #[[LAYOUT_L1]]> - %5 = "ttir.relu"(%3, %4) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x96xbf16>, tensor<64x96xbf16>) -> tensor<64x96xbf16> + %5 = "ttir.relu"(%3, %4) <{operandSegmentSizes = array}> : (tensor<64x96xbf16>, tensor<64x96xbf16>) -> tensor<64x96xbf16> %6 = tensor.empty() : tensor<64x32xbf16> // CHECK: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<64x32xbf16, #[[LAYOUT_L1]]> - %7 = "ttir.matmul"(%5, %arg3, %6) <{operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x96xbf16>, tensor<96x32xbf16>, tensor<64x32xbf16>) -> tensor<64x32xbf16> + %7 = "ttir.matmul"(%5, %arg3, %6) : (tensor<64x96xbf16>, tensor<96x32xbf16>, tensor<64x32xbf16>) -> tensor<64x32xbf16> %8 = tensor.empty() : tensor<64x32xbf16> // CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<64x32xbf16, #[[LAYOUT_L1]]> - %9 = "ttir.add"(%7, %arg4, %8) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x32xbf16>, tensor<64x32xbf16>, tensor<64x32xbf16>) -> tensor<64x32xbf16> + %9 = "ttir.add"(%7, %arg4, %8) <{operandSegmentSizes = array}> : (tensor<64x32xbf16>, tensor<64x32xbf16>, tensor<64x32xbf16>) -> tensor<64x32xbf16> %10 = tensor.empty() : tensor<64x32xbf16> // CHECK: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<64x32xbf16, #[[LAYOUT_L1]]> - %11 = "ttir.relu"(%9, %10) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x32xbf16>, tensor<64x32xbf16>) -> tensor<64x32xbf16> + %11 = "ttir.relu"(%9, %10) 
<{operandSegmentSizes = array}> : (tensor<64x32xbf16>, tensor<64x32xbf16>) -> tensor<64x32xbf16> return %11 : tensor<64x32xbf16> } } diff --git a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/fork_join.mlir b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/fork_join.mlir index 67c480d8c..d9336db8a 100644 --- a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/fork_join.mlir +++ b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/fork_join.mlir @@ -18,7 +18,6 @@ // the optimizer should choose the one with lower requiredL1Usage. In // this case, [E, C] should be chosen. // -#any_device = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<64x64xbf16>, %arg1: tensor<64x32xbf16>) -> tensor<64x32xbf16> { // CHECK: #[[L1_:.*]] = #ttnn.buffer_type @@ -26,19 +25,19 @@ module attributes {} { // CHECK: #[[LAYOUT_5:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<1x1x!tt.tile<32x32, bf16>, #l1_>, > %0 = tensor.empty() : tensor<64x64xbf16> // CHECK: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<64x64xbf16, #[[LAYOUT_3]]> - %1 = "ttir.relu"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> + %1 = "ttir.relu"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> %2 = tensor.empty() : tensor<64x64xbf16> - %3 = "ttir.relu"(%1, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> + %3 = "ttir.relu"(%1, %2) <{operandSegmentSizes = array}> : (tensor<64x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> %4 = tensor.empty() : tensor<64x32xbf16> - %5 = "ttir.matmul"(%1, %arg1, %4) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x64xbf16>, tensor<64x32xbf16>, tensor<64x32xbf16>) -> tensor<64x32xbf16> + %5 = "ttir.matmul"(%1, %arg1, %4) : (tensor<64x64xbf16>, tensor<64x32xbf16>, tensor<64x32xbf16>) -> tensor<64x32xbf16> %6 = tensor.empty() : tensor<64x32xbf16> - %7 = "ttir.relu"(%5, %6) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x32xbf16>, tensor<64x32xbf16>) -> tensor<64x32xbf16> + %7 = "ttir.relu"(%5, %6) <{operandSegmentSizes = array}> : (tensor<64x32xbf16>, tensor<64x32xbf16>) -> tensor<64x32xbf16> %8 = tensor.empty() : tensor<64x32xbf16> // CHECK: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<64x32xbf16, #[[LAYOUT_5]]> // CHECK: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<64x32xbf16, #[[LAYOUT_5]]> // CHECK: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<64x64xbf16, #[[LAYOUT_5]]> // CHECK: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<64x32xbf16, #[[LAYOUT_5]]> - %9 = "ttir.matmul"(%3, %7, %8) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x64xbf16>, tensor<64x32xbf16>, tensor<64x32xbf16>) -> tensor<64x32xbf16> + %9 = "ttir.matmul"(%3, %7, %8) : (tensor<64x64xbf16>, tensor<64x32xbf16>, tensor<64x32xbf16>) -> tensor<64x32xbf16> return %9 : tensor<64x32xbf16> } } diff --git a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/mnist_l1_interleaved.mlir b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/mnist_l1_interleaved.mlir index f45c11c62..3d437a74f 100644 --- a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/mnist_l1_interleaved.mlir +++ 
b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/mnist_l1_interleaved.mlir @@ -1,27 +1,26 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="enable-optimizer=true memory-layout-analysis-enabled=true memory-layout-analysis-policy=L1Interleaved" %s | FileCheck %s -#any_device = #tt.operand_constraint #loc = loc("MNISTLinear":4294967295:0) module @"tt-forge-graph" attributes {} { func.func @main(%arg0: tensor<1x784xf32> loc("MNISTLinear":4294967295:0), %arg1: tensor<1x10xf32> loc("MNISTLinear":4294967295:0), %arg2: tensor<256x10xf32> loc("MNISTLinear":4294967295:0), %arg3: tensor<1x256xf32> loc("MNISTLinear":4294967295:0), %arg4: tensor<784x256xf32> loc("MNISTLinear":4294967295:0)) -> tensor<1x10xf32> { // CHECK: #[[LAYOUT_L1:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <{{.*}}>, memref<{{.*}}, #l1_>, > %0 = tensor.empty() : tensor<1x256xf32> loc(#loc8) // CHECK: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<1x256xf32, #[[LAYOUT_L1]]> - %1 = "ttir.matmul"(%arg0, %arg4, %0) <{operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x784xf32>, tensor<784x256xf32>, tensor<1x256xf32>) -> tensor<1x256xf32> loc(#loc8) + %1 = "ttir.matmul"(%arg0, %arg4, %0) : (tensor<1x784xf32>, tensor<784x256xf32>, tensor<1x256xf32>) -> tensor<1x256xf32> loc(#loc8) %2 = tensor.empty() : tensor<1x256xf32> loc(#loc9) // CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<1x256xf32, #[[LAYOUT_L1]]> - %3 = "ttir.add"(%1, %arg3, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x256xf32>, tensor<1x256xf32>, tensor<1x256xf32>) -> tensor<1x256xf32> loc(#loc9) + %3 = "ttir.add"(%1, %arg3, %2) <{operandSegmentSizes = array}> : (tensor<1x256xf32>, tensor<1x256xf32>, tensor<1x256xf32>) -> tensor<1x256xf32> loc(#loc9) %4 = tensor.empty() : tensor<1x256xf32> loc(#loc10) // CHECK: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<1x256xf32, #[[LAYOUT_L1]]> - %5 = "ttir.relu"(%3, %4) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<1x256xf32>, tensor<1x256xf32>) -> tensor<1x256xf32> loc(#loc10) + %5 = "ttir.relu"(%3, %4) <{operandSegmentSizes = array}> : (tensor<1x256xf32>, tensor<1x256xf32>) -> tensor<1x256xf32> loc(#loc10) %6 = tensor.empty() : tensor<1x10xf32> loc(#loc11) // CHECK: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<1x10xf32, #[[LAYOUT_L1]]> - %7 = "ttir.matmul"(%5, %arg2, %6) <{operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x256xf32>, tensor<256x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> loc(#loc11) + %7 = "ttir.matmul"(%5, %arg2, %6) : (tensor<1x256xf32>, tensor<256x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> loc(#loc11) %8 = tensor.empty() : tensor<1x10xf32> loc(#loc12) // CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<1x10xf32, #[[LAYOUT_L1]]> - %9 = "ttir.add"(%7, %arg1, %8) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x10xf32>, tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> loc(#loc12) + %9 = "ttir.add"(%7, %arg1, %8) <{operandSegmentSizes = array}> : (tensor<1x10xf32>, tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> loc(#loc12) %10 = tensor.empty() : tensor<1x10xf32> loc(#loc13) // CHECK: %{{.*}} = "ttnn.softmax"{{.*}} -> tensor<1x10xf32, #[[LAYOUT_L1]]> - %11 = "ttir.softmax"(%9, %10) <{dimension = 1 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> loc(#loc13) + %11 = "ttir.softmax"(%9, %10) <{dimension = 1 : si32}> : 
(tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> loc(#loc13) return %11 : tensor<1x10xf32> loc(#loc7) } loc(#loc) } loc(#loc) diff --git a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_ABC_l1_None.mlir b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_ABC_l1_None.mlir index e5a4f3fa6..ecd90f1ab 100644 --- a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_ABC_l1_None.mlir +++ b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_ABC_l1_None.mlir @@ -10,19 +10,18 @@ // => // DRAM: ABC; L1: None // -#any_device = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<8192x8192xbf16>, %arg1: tensor<8192x8192xbf16>, %arg2: tensor<8192x8192xbf16>, %arg3: tensor<8192x8192xbf16>) -> tensor<8192x8192xbf16> { // CHECK-DAG: #[[LAYOUT_2:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<{{.*}}>, #dram>, > %0 = tensor.empty() : tensor<8192x8192xbf16> // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<8192x8192xbf16, #[[LAYOUT_2]]> - %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<8192x8192xbf16>, tensor<8192x8192xbf16>, tensor<8192x8192xbf16>) -> tensor<8192x8192xbf16> + %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<8192x8192xbf16>, tensor<8192x8192xbf16>, tensor<8192x8192xbf16>) -> tensor<8192x8192xbf16> %2 = tensor.empty() : tensor<8192x8192xbf16> // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<8192x8192xbf16, #[[LAYOUT_2]]> - %3 = "ttir.add"(%arg2, %arg3, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<8192x8192xbf16>, tensor<8192x8192xbf16>, tensor<8192x8192xbf16>) -> tensor<8192x8192xbf16> + %3 = "ttir.add"(%arg2, %arg3, %2) <{operandSegmentSizes = array}> : (tensor<8192x8192xbf16>, tensor<8192x8192xbf16>, tensor<8192x8192xbf16>) -> tensor<8192x8192xbf16> %4 = tensor.empty() : tensor<8192x8192xbf16> // CHECK-DAG: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<8192x8192xbf16, #[[LAYOUT_2]]> - %5 = "ttir.matmul"(%1, %3, %4) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<8192x8192xbf16>, tensor<8192x8192xbf16>, tensor<8192x8192xbf16>) -> tensor<8192x8192xbf16> + %5 = "ttir.matmul"(%1, %3, %4) : (tensor<8192x8192xbf16>, tensor<8192x8192xbf16>, tensor<8192x8192xbf16>) -> tensor<8192x8192xbf16> return %5 : tensor<8192x8192xbf16> } } diff --git a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_AB_l1_C.mlir b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_AB_l1_C.mlir index ceca62840..056ded8d3 100644 --- a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_AB_l1_C.mlir +++ b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_AB_l1_C.mlir @@ -10,7 +10,6 @@ // => // DRAM: AB; L1: C // -#any_device = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<5120x4096xbf16>, %arg1: tensor<5120x4096xbf16>, %arg2: tensor<4096x5120xbf16>, %arg3: tensor<4096x5120xbf16>) -> tensor<5120x5120xbf16> { // CHECK: #[[L1_:.*]] = #ttnn.buffer_type @@ -19,13 +18,13 @@ module attributes {} { // CHECK-DAG: #[[LAYOUT_7:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<20x20x!tt.tile<32x32, bf16>, #l1_>, > %0 = tensor.empty() : 
tensor<5120x4096xbf16> // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<5120x4096xbf16, #[[LAYOUT_4]]> - %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<5120x4096xbf16>, tensor<5120x4096xbf16>, tensor<5120x4096xbf16>) -> tensor<5120x4096xbf16> + %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<5120x4096xbf16>, tensor<5120x4096xbf16>, tensor<5120x4096xbf16>) -> tensor<5120x4096xbf16> %2 = tensor.empty() : tensor<4096x5120xbf16> // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<4096x5120xbf16, #[[LAYOUT_6]]> - %3 = "ttir.add"(%arg2, %arg3, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<4096x5120xbf16>, tensor<4096x5120xbf16>, tensor<4096x5120xbf16>) -> tensor<4096x5120xbf16> + %3 = "ttir.add"(%arg2, %arg3, %2) <{operandSegmentSizes = array}> : (tensor<4096x5120xbf16>, tensor<4096x5120xbf16>, tensor<4096x5120xbf16>) -> tensor<4096x5120xbf16> %4 = tensor.empty() : tensor<5120x5120xbf16> // CHECK: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<5120x5120xbf16, #[[LAYOUT_7]]> - %5 = "ttir.matmul"(%1, %3, %4) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<5120x4096xbf16>, tensor<4096x5120xbf16>, tensor<5120x5120xbf16>) -> tensor<5120x5120xbf16> + %5 = "ttir.matmul"(%1, %3, %4) : (tensor<5120x4096xbf16>, tensor<4096x5120xbf16>, tensor<5120x5120xbf16>) -> tensor<5120x5120xbf16> return %5 : tensor<5120x5120xbf16> } } diff --git a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_AC_l1_B.mlir b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_AC_l1_B.mlir index 74675e4e0..caaf3254d 100644 --- a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_AC_l1_B.mlir +++ b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_AC_l1_B.mlir @@ -10,7 +10,6 @@ // => // DRAM: AC; L1: B // -#any_device = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<4096x5120xbf16>, %arg1: tensor<4096x5120xbf16>, %arg2: tensor<5120x5120xbf16>, %arg3: tensor<5120x5120xbf16>) -> tensor<4096x5120xbf16> { // CHECK: #[[L1_:.*]] = #ttnn.buffer_type @@ -18,13 +17,13 @@ module attributes {} { // CHECK-DAG: #[[LAYOUT_5:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<20x20x!tt.tile<32x32, bf16>, #l1_>, > %0 = tensor.empty() : tensor<4096x5120xbf16> // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<4096x5120xbf16, #[[LAYOUT_3]]> - %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<4096x5120xbf16>, tensor<4096x5120xbf16>, tensor<4096x5120xbf16>) -> tensor<4096x5120xbf16> + %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<4096x5120xbf16>, tensor<4096x5120xbf16>, tensor<4096x5120xbf16>) -> tensor<4096x5120xbf16> %2 = tensor.empty() : tensor<5120x5120xbf16> // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<5120x5120xbf16, #[[LAYOUT_5]]> - %3 = "ttir.add"(%arg2, %arg3, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<5120x5120xbf16>, tensor<5120x5120xbf16>, tensor<5120x5120xbf16>) -> tensor<5120x5120xbf16> + %3 = "ttir.add"(%arg2, %arg3, %2) <{operandSegmentSizes = array}> : (tensor<5120x5120xbf16>, tensor<5120x5120xbf16>, tensor<5120x5120xbf16>) -> 
tensor<5120x5120xbf16> %4 = tensor.empty() : tensor<4096x5120xbf16> // CHECK-DAG: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<4096x5120xbf16, #[[LAYOUT_3]]> - %5 = "ttir.matmul"(%1, %3, %4) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<4096x5120xbf16>, tensor<5120x5120xbf16>, tensor<4096x5120xbf16>) -> tensor<4096x5120xbf16> + %5 = "ttir.matmul"(%1, %3, %4) : (tensor<4096x5120xbf16>, tensor<5120x5120xbf16>, tensor<4096x5120xbf16>) -> tensor<4096x5120xbf16> return %5 : tensor<4096x5120xbf16> } } diff --git a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_A_l1_BC.mlir b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_A_l1_BC.mlir index c3cd2740b..63cd3bcaa 100644 --- a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_A_l1_BC.mlir +++ b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_A_l1_BC.mlir @@ -10,7 +10,6 @@ // => // DRAM: A; L1: BC // -#any_device = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<2048x2048xbf16>, %arg1: tensor<2048x2048xbf16>, %arg2: tensor<2048x8192xbf16>, %arg3: tensor<2048x8192xbf16>) -> tensor<2048x8192xbf16> { // CHECK: #[[L1_:.*]] = #ttnn.buffer_type @@ -18,13 +17,13 @@ module attributes {} { // CHECK-DAG: #[[LAYOUT_5:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<8x32x!tt.tile<32x32, bf16>, #l1_>, > %0 = tensor.empty() : tensor<2048x2048xbf16> // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<2048x2048xbf16, #[[LAYOUT_3]]> - %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<2048x2048xbf16>, tensor<2048x2048xbf16>, tensor<2048x2048xbf16>) -> tensor<2048x2048xbf16> + %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<2048x2048xbf16>, tensor<2048x2048xbf16>, tensor<2048x2048xbf16>) -> tensor<2048x2048xbf16> %2 = tensor.empty() : tensor<2048x8192xbf16> // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<2048x8192xbf16, #[[LAYOUT_5]]> - %3 = "ttir.add"(%arg2, %arg3, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<2048x8192xbf16>, tensor<2048x8192xbf16>, tensor<2048x8192xbf16>) -> tensor<2048x8192xbf16> + %3 = "ttir.add"(%arg2, %arg3, %2) <{operandSegmentSizes = array}> : (tensor<2048x8192xbf16>, tensor<2048x8192xbf16>, tensor<2048x8192xbf16>) -> tensor<2048x8192xbf16> %4 = tensor.empty() : tensor<2048x8192xbf16> // CHECK-DAG: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<2048x8192xbf16, #[[LAYOUT_5]]> - %5 = "ttir.matmul"(%1, %3, %4) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<2048x2048xbf16>, tensor<2048x8192xbf16>, tensor<2048x8192xbf16>) -> tensor<2048x8192xbf16> + %5 = "ttir.matmul"(%1, %3, %4) : (tensor<2048x2048xbf16>, tensor<2048x8192xbf16>, tensor<2048x8192xbf16>) -> tensor<2048x8192xbf16> return %5 : tensor<2048x8192xbf16> } } diff --git a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_BC_l1_A.mlir b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_BC_l1_A.mlir index c9cd33f1c..9f12e8b6f 100644 --- a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_BC_l1_A.mlir +++ b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_BC_l1_A.mlir @@ -10,7 +10,6 @@ // => // DRAM: 
BC; L1: A // -#any_device = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<5120x5120xbf16>, %arg1: tensor<5120x5120xbf16>, %arg2: tensor<5120x4096xbf16>, %arg3: tensor<5120x4096xbf16>) -> tensor<5120x4096xbf16> { // CHECK: #[[L1_:.*]] = #ttnn.buffer_type @@ -18,13 +17,13 @@ module attributes {} { // CHECK-DAG: #[[LAYOUT_5:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <{{.*}}>, memref<20x20x!tt.tile<32x32, bf16>, #l1_>, > %0 = tensor.empty() : tensor<5120x5120xbf16> // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<5120x5120xbf16, #[[LAYOUT_5]]> - %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<5120x5120xbf16>, tensor<5120x5120xbf16>, tensor<5120x5120xbf16>) -> tensor<5120x5120xbf16> + %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<5120x5120xbf16>, tensor<5120x5120xbf16>, tensor<5120x5120xbf16>) -> tensor<5120x5120xbf16> %2 = tensor.empty() : tensor<5120x4096xbf16> // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<5120x4096xbf16, #[[LAYOUT_3]]> - %3 = "ttir.add"(%arg2, %arg3, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<5120x4096xbf16>, tensor<5120x4096xbf16>, tensor<5120x4096xbf16>) -> tensor<5120x4096xbf16> + %3 = "ttir.add"(%arg2, %arg3, %2) <{operandSegmentSizes = array}> : (tensor<5120x4096xbf16>, tensor<5120x4096xbf16>, tensor<5120x4096xbf16>) -> tensor<5120x4096xbf16> %4 = tensor.empty() : tensor<5120x4096xbf16> // CHECK-DAG: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<5120x4096xbf16, #[[LAYOUT_3]]> - %5 = "ttir.matmul"(%1, %3, %4) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<5120x5120xbf16>, tensor<5120x4096xbf16>, tensor<5120x4096xbf16>) -> tensor<5120x4096xbf16> + %5 = "ttir.matmul"(%1, %3, %4) : (tensor<5120x5120xbf16>, tensor<5120x4096xbf16>, tensor<5120x4096xbf16>) -> tensor<5120x4096xbf16> return %5 : tensor<5120x4096xbf16> } } diff --git a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_B_l1_AC.mlir b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_B_l1_AC.mlir index 760ea2b8a..c594ca418 100644 --- a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_B_l1_AC.mlir +++ b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_B_l1_AC.mlir @@ -10,7 +10,6 @@ // => // DRAM: B; L1: AC // -#any_device = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<8192x2048xbf16>, %arg1: tensor<8192x2048xbf16>, %arg2: tensor<2048x2048xbf16>, %arg3: tensor<2048x2048xbf16>) -> tensor<8192x2048xbf16> { // CHECK: #[[L1_:.*]] = #ttnn.buffer_type @@ -18,13 +17,13 @@ module attributes {} { // CHECK-DAG: #[[LAYOUT_5:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<32x8x!tt.tile<32x32, bf16>, #l1_>, > %0 = tensor.empty() : tensor<8192x2048xbf16> // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<8192x2048xbf16, #[[LAYOUT_5]]> - %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<8192x2048xbf16>, tensor<8192x2048xbf16>, tensor<8192x2048xbf16>) -> tensor<8192x2048xbf16> + %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<8192x2048xbf16>, tensor<8192x2048xbf16>, tensor<8192x2048xbf16>) -> tensor<8192x2048xbf16> %2 = tensor.empty() : 
tensor<2048x2048xbf16> // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<2048x2048xbf16, #[[LAYOUT_3]]> - %3 = "ttir.add"(%arg2, %arg3, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<2048x2048xbf16>, tensor<2048x2048xbf16>, tensor<2048x2048xbf16>) -> tensor<2048x2048xbf16> + %3 = "ttir.add"(%arg2, %arg3, %2) <{operandSegmentSizes = array}> : (tensor<2048x2048xbf16>, tensor<2048x2048xbf16>, tensor<2048x2048xbf16>) -> tensor<2048x2048xbf16> %4 = tensor.empty() : tensor<8192x2048xbf16> // CHECK-DAG: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<8192x2048xbf16, #[[LAYOUT_5]]> - %5 = "ttir.matmul"(%1, %3, %4) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<8192x2048xbf16>, tensor<2048x2048xbf16>, tensor<8192x2048xbf16>) -> tensor<8192x2048xbf16> + %5 = "ttir.matmul"(%1, %3, %4) : (tensor<8192x2048xbf16>, tensor<2048x2048xbf16>, tensor<8192x2048xbf16>) -> tensor<8192x2048xbf16> return %5 : tensor<8192x2048xbf16> } } diff --git a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_C_l1_AB.mlir b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_C_l1_AB.mlir index 5d95a6204..eb2a51b17 100644 --- a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_C_l1_AB.mlir +++ b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_C_l1_AB.mlir @@ -10,7 +10,6 @@ // => // DRAM: C; L1: AB // -#any_device = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<2048x8192xbf16>, %arg1: tensor<2048x8192xbf16>, %arg2: tensor<8192x2048xbf16>, %arg3: tensor<8192x2048xbf16>) -> tensor<2048x2048xbf16> { // CHECK: #[[L1_:.*]] = #ttnn.buffer_type @@ -19,13 +18,13 @@ module attributes {} { // CHECK-DAG: #[[LAYOUT_7:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<8x8x!tt.tile<32x32, bf16>, #dram>, > %0 = tensor.empty() : tensor<2048x8192xbf16> // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<2048x8192xbf16, #[[LAYOUT_4]]> - %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<2048x8192xbf16>, tensor<2048x8192xbf16>, tensor<2048x8192xbf16>) -> tensor<2048x8192xbf16> + %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<2048x8192xbf16>, tensor<2048x8192xbf16>, tensor<2048x8192xbf16>) -> tensor<2048x8192xbf16> %2 = tensor.empty() : tensor<8192x2048xbf16> // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<8192x2048xbf16, #[[LAYOUT_6]]> - %3 = "ttir.add"(%arg2, %arg3, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<8192x2048xbf16>, tensor<8192x2048xbf16>, tensor<8192x2048xbf16>) -> tensor<8192x2048xbf16> + %3 = "ttir.add"(%arg2, %arg3, %2) <{operandSegmentSizes = array}> : (tensor<8192x2048xbf16>, tensor<8192x2048xbf16>, tensor<8192x2048xbf16>) -> tensor<8192x2048xbf16> %4 = tensor.empty() : tensor<2048x2048xbf16> // CHECK-DAG: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<2048x2048xbf16, #[[LAYOUT_7]]> - %5 = "ttir.matmul"(%1, %3, %4) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<2048x8192xbf16>, tensor<8192x2048xbf16>, tensor<2048x2048xbf16>) -> tensor<2048x2048xbf16> + %5 = "ttir.matmul"(%1, %3, %4) : (tensor<2048x8192xbf16>, tensor<8192x2048xbf16>, tensor<2048x2048xbf16>) -> tensor<2048x2048xbf16> return %5 : 
tensor<2048x2048xbf16> } } diff --git a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_None_l1_ABC.mlir b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_None_l1_ABC.mlir index 75b876dbf..883842694 100644 --- a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_None_l1_ABC.mlir +++ b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_None_l1_ABC.mlir @@ -10,20 +10,19 @@ // => // DRAM: None; L1: ABC // -#any_device = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<32x32xbf16>, %arg1: tensor<32x32xbf16>, %arg2: tensor<32x32xbf16>, %arg3: tensor<32x32xbf16>) -> tensor<32x32xbf16> { // CHECK: #[[L1_:.*]] = #ttnn.buffer_type // CHECK-DAG: #[[LAYOUT_2:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<1x1x!tt.tile<32x32, bf16>, #l1_>, > %0 = tensor.empty() : tensor<32x32xbf16> // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<32x32xbf16, #[[LAYOUT_2]]> - %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16> + %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16> %2 = tensor.empty() : tensor<32x32xbf16> // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<32x32xbf16, #[[LAYOUT_2]]> - %3 = "ttir.add"(%arg2, %arg3, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16> + %3 = "ttir.add"(%arg2, %arg3, %2) <{operandSegmentSizes = array}> : (tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16> %4 = tensor.empty() : tensor<32x32xbf16> // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<32x32xbf16, #[[LAYOUT_2]]> - %5 = "ttir.add"(%1, %3, %4) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16> + %5 = "ttir.add"(%1, %3, %4) <{operandSegmentSizes = array}> : (tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16> return %5 : tensor<32x32xbf16> } } diff --git a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/single_op.mlir b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/single_op.mlir index 482079993..7b8aa0759 100644 --- a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/single_op.mlir +++ b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/single_op.mlir @@ -1,10 +1,9 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="enable-optimizer=true memory-layout-analysis-enabled=true memory-layout-analysis-policy=L1Interleaved" %s | FileCheck %s // UNSUPPORTED: true -#any_device_tile = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<5120x5120xbf16>) -> tensor<5120x5120xbf16> { %0 = tensor.empty() : tensor<5120x5120xbf16> - %1 = "ttir.relu"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<5120x5120xbf16>, tensor<5120x5120xbf16>) -> tensor<5120x5120xbf16> + %1 = "ttir.relu"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<5120x5120xbf16>, tensor<5120x5120xbf16>) -> tensor<5120x5120xbf16> return %1 : tensor<5120x5120xbf16> } } diff --git 
a/test/ttmlir/Dialect/TTNN/optimizer/multiple_add_with_loc.mlir b/test/ttmlir/Dialect/TTNN/optimizer/multiple_add_with_loc.mlir index 8e25f97ca..66e1ec083 100644 --- a/test/ttmlir/Dialect/TTNN/optimizer/multiple_add_with_loc.mlir +++ b/test/ttmlir/Dialect/TTNN/optimizer/multiple_add_with_loc.mlir @@ -1,18 +1,17 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="enable-optimizer=true" %s | FileCheck %s -#any_device = #tt.operand_constraint #loc = loc("test_ops.py:17_0_0":0:0) module attributes {} { func.func @main(%arg0: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg1: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg2: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0)) -> (tensor<1x32x32xf32>, tensor<1x32x32xf32>) { // CHECK: #[[LAYOUT:.*]] = #ttnn.ttnn_layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), <8x8, (d0, d1) -> (0, d0, d1)>, memref<1x1x!tt.tile<32x32, f32>, #dram>, > %0 = tensor.empty() : tensor<1x32x32xf32> loc(#loc5) // CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<1x32x32xf32, #[[LAYOUT]]> - %1 = "ttir.add"(%arg1, %arg2, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc5) + %1 = "ttir.add"(%arg1, %arg2, %0) <{operandSegmentSizes = array}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc5) %2 = tensor.empty() : tensor<1x32x32xf32> loc(#loc6) // CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<1x32x32xf32, #[[LAYOUT]]> - %3 = "ttir.add"(%1, %arg0, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc6) + %3 = "ttir.add"(%1, %arg0, %2) <{operandSegmentSizes = array}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc6) %4 = tensor.empty() : tensor<1x32x32xf32> loc(#loc7) // CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<1x32x32xf32, #[[LAYOUT]]> - %5 = "ttir.add"(%arg2, %arg1, %4) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc7) + %5 = "ttir.add"(%arg2, %arg1, %4) <{operandSegmentSizes = array}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc7) // CHECK: return %[[R0:.*]], %[[R1:.*]] : tensor<1x32x32xf32, #ttnn_layout>, tensor<1x32x32xf32, #ttnn_layout> return %3, %5 : tensor<1x32x32xf32>, tensor<1x32x32xf32> loc(#loc4) } loc(#loc) diff --git a/test/ttmlir/Dialect/TTNN/optimizer/output_layout_override.mlir b/test/ttmlir/Dialect/TTNN/optimizer/output_layout_override.mlir index 79bbae275..91f38d446 100644 --- a/test/ttmlir/Dialect/TTNN/optimizer/output_layout_override.mlir +++ b/test/ttmlir/Dialect/TTNN/optimizer/output_layout_override.mlir @@ -1,5 +1,4 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="enable-optimizer=true override-output-layout=add_1_0=4x4:dram:interleaved:row_major:bf16,add_2_0=4x4:l1:interleaved:tile:f32" %s | FileCheck %s -#any_device = #tt.operand_constraint #loc = loc("test_ops.py:17_0_0":0:0) module attributes {} { func.func @main(%arg0: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg1: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg2: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0)) -> (tensor<1x32x32xf32>, tensor<1x32x32xf32>) { @@ -10,13 +9,13 @@ module 
attributes {} { // CHECK: #[[LAYOUT_3:.*]] = #ttnn.ttnn_layout<{{.*}} #dram>, > %0 = tensor.empty() : tensor<1x32x32xf32> loc(#loc5) // CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<1x32x32xf32, #[[LAYOUT_1]]> - %1 = "ttir.add"(%arg1, %arg2, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc5) + %1 = "ttir.add"(%arg1, %arg2, %0) <{operandSegmentSizes = array}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc5) %2 = tensor.empty() : tensor<1x32x32xf32> loc(#loc6) // CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<1x32x32xf32, #[[LAYOUT_2]]> - %3 = "ttir.add"(%1, %arg0, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc6) + %3 = "ttir.add"(%1, %arg0, %2) <{operandSegmentSizes = array}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc6) %4 = tensor.empty() : tensor<1x32x32xf32> loc(#loc7) // CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<1x32x32xf32, #[[LAYOUT_3]]> - %5 = "ttir.add"(%arg2, %arg1, %4) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc7) + %5 = "ttir.add"(%arg2, %arg1, %4) <{operandSegmentSizes = array}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc7) // CHECK: return %[[R0:.*]], %[[R1:.*]] : tensor<1x32x32xf32, #[[LAYOUT_0]]>, tensor<1x32x32xf32, #[[LAYOUT_0]]> return %3, %5 : tensor<1x32x32xf32>, tensor<1x32x32xf32> loc(#loc4) } loc(#loc) diff --git a/test/ttmlir/Dialect/TTNN/optimizer/partial_output_layout_override.mlir b/test/ttmlir/Dialect/TTNN/optimizer/partial_output_layout_override.mlir index c1e79c7a0..7b57882df 100644 --- a/test/ttmlir/Dialect/TTNN/optimizer/partial_output_layout_override.mlir +++ b/test/ttmlir/Dialect/TTNN/optimizer/partial_output_layout_override.mlir @@ -1,14 +1,13 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="enable-optimizer=true override-output-layout=add_1=row_major" %s | FileCheck %s -#any_device = #tt.operand_constraint #loc = loc("test_ops.py:17_0_0":0:0) module attributes {} { func.func @main(%arg0: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg1: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg2: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0)) -> (tensor<1x32x32xf32>, tensor<1x32x32xf32>) { // CHECK: #[[LAYOUT_1:.*]] = #ttnn.ttnn_layout<{{.*}}memref<4x4xf32{{.*}} %0 = tensor.empty() : tensor<1x32x32xf32> loc(#loc5) // CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<1x32x32xf32, #[[LAYOUT_1]]> - %1 = "ttir.add"(%arg1, %arg2, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc5) + %1 = "ttir.add"(%arg1, %arg2, %0) <{operandSegmentSizes = array}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc5) %2 = tensor.empty() : tensor<1x32x32xf32> loc(#loc6) - %3 = "ttir.add"(%1, %arg0, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> 
loc(#loc6) + %3 = "ttir.add"(%1, %arg0, %2) <{operandSegmentSizes = array}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc6) return %1, %3 : tensor<1x32x32xf32>, tensor<1x32x32xf32> loc(#loc4) } loc(#loc) } loc(#loc) diff --git a/test/ttmlir/Dialect/TTNN/optimizer/sharding_matmul_override_0.mlir b/test/ttmlir/Dialect/TTNN/optimizer/sharding_matmul_override_0.mlir index fbe476cba..e893e5d2c 100644 --- a/test/ttmlir/Dialect/TTNN/optimizer/sharding_matmul_override_0.mlir +++ b/test/ttmlir/Dialect/TTNN/optimizer/sharding_matmul_override_0.mlir @@ -1,14 +1,13 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="enable-optimizer=true memory-layout-analysis-enabled=true max-legal-layouts=0" %s | FileCheck %s -#any_device_tile = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x96xbf16>, %arg2: tensor<96x64xbf16>) -> tensor<64x64xbf16> { // CHECK: #[[LAYOUT_7:ttnn_layout7]] = #ttnn.ttnn_layout<{{.*}}, memref<{{.*}}, #dram>, {{.*}}> %0 = tensor.empty() : tensor<64x96xbf16> // CHECK: {{.*}} = "ttnn.matmul"{{.*}} -> tensor<64x96xbf16, #[[LAYOUT_7]]> - %1 = "ttir.matmul"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128x96xbf16>, tensor<64x96xbf16>) -> tensor<64x96xbf16> + %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<128x96xbf16>, tensor<64x96xbf16>) -> tensor<64x96xbf16> %2 = tensor.empty() : tensor<64x64xbf16> // CHECK: {{.*}} = "ttnn.matmul"{{.*}} -> tensor<64x64xbf16, #[[LAYOUT_7]]> - %3 = "ttir.matmul"(%1, %arg2, %2) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x96xbf16>, tensor<96x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> + %3 = "ttir.matmul"(%1, %arg2, %2) : (tensor<64x96xbf16>, tensor<96x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> return %3 : tensor<64x64xbf16> } } diff --git a/test/ttmlir/Dialect/TTNN/optimizer/sharding_matmul_override_32.mlir b/test/ttmlir/Dialect/TTNN/optimizer/sharding_matmul_override_32.mlir index 8c372be46..aa4616360 100644 --- a/test/ttmlir/Dialect/TTNN/optimizer/sharding_matmul_override_32.mlir +++ b/test/ttmlir/Dialect/TTNN/optimizer/sharding_matmul_override_32.mlir @@ -1,14 +1,13 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="enable-optimizer=true memory-layout-analysis-enabled=true max-legal-layouts=32" %s | FileCheck %s -#any_device_tile = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x96xbf16>, %arg2: tensor<96x64xbf16>) -> tensor<64x64xbf16> { // CHECK: #[[L1_:.*]] = #ttnn.buffer_type // CHECK: #[[LAYOUT_7:ttnn_layout7]] = #ttnn.ttnn_layout<{{.*}}, memref<{{.*}}, #l1_>, {{.*}}> %0 = tensor.empty() : tensor<64x96xbf16> // CHECK: {{.*}} = "ttnn.matmul"{{.*}} -> tensor<64x96xbf16, #[[LAYOUT_7]]> - %1 = "ttir.matmul"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128x96xbf16>, tensor<64x96xbf16>) -> tensor<64x96xbf16> + %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<128x96xbf16>, tensor<64x96xbf16>) -> tensor<64x96xbf16> %2 = tensor.empty() : tensor<64x64xbf16> - %3 = "ttir.matmul"(%1, %arg2, %2) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x96xbf16>, tensor<96x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> + %3 = "ttir.matmul"(%1, %arg2, %2) : 
(tensor<64x96xbf16>, tensor<96x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> return %3 : tensor<64x64xbf16> } } diff --git a/test/ttmlir/Dialect/TTNN/optimizer/ttir_to_ttnn_pipeline.mlir b/test/ttmlir/Dialect/TTNN/optimizer/ttir_to_ttnn_pipeline.mlir index c12fc0771..5d924a919 100644 --- a/test/ttmlir/Dialect/TTNN/optimizer/ttir_to_ttnn_pipeline.mlir +++ b/test/ttmlir/Dialect/TTNN/optimizer/ttir_to_ttnn_pipeline.mlir @@ -1,11 +1,10 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="enable-optimizer=true" %s | FileCheck %s -#any_device = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> { // CHECK: %{{.*}} = "ttnn.empty"{{.*}} %0 = tensor.empty() : tensor<64x128xf32> // CHECK: %{{.*}} = "ttnn.multiply"{{.*}} - %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } } diff --git a/test/ttmlir/Dialect/TTNN/pipelines/ttir_to_emitc_add.mlir b/test/ttmlir/Dialect/TTNN/pipelines/ttir_to_emitc_add.mlir index d59a59ea6..21c665ae2 100644 --- a/test/ttmlir/Dialect/TTNN/pipelines/ttir_to_emitc_add.mlir +++ b/test/ttmlir/Dialect/TTNN/pipelines/ttir_to_emitc_add.mlir @@ -5,10 +5,8 @@ // This test checks that the (TTIR to EmitC pipeline) is equivalent to (TTIR to TTNN pipeline + dialect conversion from TTNN to EmitC). // The `diff` command will return 0 if files are identical, otherwise it will return the diff, which will make `llvm-lit` treat the test as failed. 
-#any_device_tile = #tt.operand_constraint - func.func @add(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> { %0 = tensor.empty() : tensor<64x128xf32> - %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } diff --git a/test/ttmlir/Dialect/TTNN/pooling/complex_pooling.mlir b/test/ttmlir/Dialect/TTNN/pooling/complex_pooling.mlir index cbc188d34..f7d492d5f 100644 --- a/test/ttmlir/Dialect/TTNN/pooling/complex_pooling.mlir +++ b/test/ttmlir/Dialect/TTNN/pooling/complex_pooling.mlir @@ -1,5 +1,4 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s -#any_device = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<1x32x128x128xbf16>, %arg1: tensor<1x32x128x128xbf16>) -> tensor<1x32x64x64xbf16> { %0 = tensor.empty() : tensor<1x32x64x64xbf16> @@ -12,11 +11,10 @@ module attributes {} { window_strides = array, base_dilations = array, window_dilations = array, - padding = array, - operand_constraints = [#any_device, #any_device]}> : (tensor<1x32x128x128xbf16>, tensor<1x32x128x128xbf16>, tensor<1x32x64x64xbf16>, tensor<1x32x64x64xbf16>) -> (tensor<1x32x64x64xbf16>, tensor<1x32x64x64xbf16>) + padding = array}> : (tensor<1x32x128x128xbf16>, tensor<1x32x128x128xbf16>, tensor<1x32x64x64xbf16>, tensor<1x32x64x64xbf16>) -> (tensor<1x32x64x64xbf16>, tensor<1x32x64x64xbf16>) %4 = tensor.empty() : tensor<1x32x64x64xbf16> - %6 = "ttir.add"(%2, %3, %4) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x64x64xbf16>, tensor<1x32x64x64xbf16>, tensor<1x32x64x64xbf16>) -> tensor<1x32x64x64xbf16> + %6 = "ttir.add"(%2, %3, %4) <{operandSegmentSizes = array}> : (tensor<1x32x64x64xbf16>, tensor<1x32x64x64xbf16>, tensor<1x32x64x64xbf16>) -> tensor<1x32x64x64xbf16> return %6 : tensor<1x32x64x64xbf16> } } diff --git a/test/ttmlir/Dialect/TTNN/pooling/simple_maxpool2d.mlir b/test/ttmlir/Dialect/TTNN/pooling/simple_maxpool2d.mlir index 5116b2bfb..dc48662d7 100644 --- a/test/ttmlir/Dialect/TTNN/pooling/simple_maxpool2d.mlir +++ b/test/ttmlir/Dialect/TTNN/pooling/simple_maxpool2d.mlir @@ -1,10 +1,9 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s -#any_device = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<1x128x128x32xbf16>) -> tensor<1x64x64x32xbf16> { %0 = tensor.empty() : tensor<1x64x64x32xbf16> // CHECK: %[[C:.*]] = "ttnn.max_pool2d"[[C:.*]] - %1 = "ttir.max_pool2d"(%arg0, %0) <{kernel_height=2: si32, kernel_width=2: si32, stride_height=2: si32, stride_width=2: si32, dilation_height=1: si32, dilation_width=1: si32, ceil_mode=false, padding_left=0: si32, padding_right=0: si32, padding_top=0: si32, padding_bottom=0: si32, operand_constraints = [#any_device, #any_device]}> : (tensor<1x128x128x32xbf16>, tensor<1x64x64x32xbf16>) -> tensor<1x64x64x32xbf16> + %1 = "ttir.max_pool2d"(%arg0, %0) <{kernel_height=2: si32, kernel_width=2: si32, stride_height=2: si32, stride_width=2: si32, dilation_height=1: si32, dilation_width=1: si32, ceil_mode=false, padding_left=0: si32, padding_right=0: si32, padding_top=0: si32, padding_bottom=0: si32}> : (tensor<1x128x128x32xbf16>, 
tensor<1x64x64x32xbf16>) -> tensor<1x64x64x32xbf16> return %1 : tensor<1x64x64x32xbf16> } } diff --git a/test/ttmlir/Dialect/TTNN/pooling/simple_pooling.mlir b/test/ttmlir/Dialect/TTNN/pooling/simple_pooling.mlir index b002f8db2..2f4b65ce6 100644 --- a/test/ttmlir/Dialect/TTNN/pooling/simple_pooling.mlir +++ b/test/ttmlir/Dialect/TTNN/pooling/simple_pooling.mlir @@ -1,5 +1,4 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s -#any_device = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<1x32x128x128xbf16>) -> tensor<1x32x64x64xbf16> { %0 = tensor.empty() : tensor<1x32x64x64xbf16> @@ -11,8 +10,7 @@ module attributes {} { window_strides = array, base_dilations = array, window_dilations = array, - padding = array, - operand_constraints = [#any_device, #any_device]}> : (tensor<1x32x128x128xbf16>, tensor<1x32x64x64xbf16>) -> tensor<1x32x64x64xbf16> + padding = array}> : (tensor<1x32x128x128xbf16>, tensor<1x32x64x64xbf16>) -> tensor<1x32x64x64xbf16> return %1 : tensor<1x32x64x64xbf16> } } diff --git a/test/ttmlir/Dialect/TTNN/remove_empty_op.mlir b/test/ttmlir/Dialect/TTNN/remove_empty_op.mlir index 9640d91e2..19c9bb9a7 100644 --- a/test/ttmlir/Dialect/TTNN/remove_empty_op.mlir +++ b/test/ttmlir/Dialect/TTNN/remove_empty_op.mlir @@ -1,11 +1,10 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s -#any_device_tile = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<4x2x32x32xbf16>) -> tensor<2x4x32x32xbf16> { // CHECK-NOT: "ttnn.empty" %0 = tensor.empty() : tensor<2x4x32x32xbf16> // CHECK: %[[C:.*]] = "ttnn.reshape"[[C:.*]] - %1 = "ttir.reshape"(%arg0, %0) <{shape = [2: i32, 4: i32, 32: i32, 32: i32] , operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<4x2x32x32xbf16>, tensor<2x4x32x32xbf16>) -> tensor<2x4x32x32xbf16> + %1 = "ttir.reshape"(%arg0, %0) <{shape = [2: i32, 4: i32, 32: i32, 32: i32]}> : (tensor<4x2x32x32xbf16>, tensor<2x4x32x32xbf16>) -> tensor<2x4x32x32xbf16> return %1 : tensor<2x4x32x32xbf16> } } diff --git a/test/ttmlir/Dialect/TTNN/reshape/reshape_folding_test.mlir b/test/ttmlir/Dialect/TTNN/reshape/reshape_folding_test.mlir index c7f4442f0..cc5d67cf2 100644 --- a/test/ttmlir/Dialect/TTNN/reshape/reshape_folding_test.mlir +++ b/test/ttmlir/Dialect/TTNN/reshape/reshape_folding_test.mlir @@ -1,10 +1,9 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s| FileCheck %s -#any_device_tile = #tt.operand_constraint // Tests if we fold when translating from "ttir.reshape" which is called on the two same shapes. 
module @reshape_test { func.func @main(%arg0: tensor<1xi32>) -> (tensor<1xi32> {jax.result_info = ""}) { %0 = tensor.empty() : tensor<1xi32> - %1 = "ttir.reshape"(%arg0, %0) <{operand_constraints = [#any_device_tile, #any_device_tile], shape = [1 : i32]}> : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + %1 = "ttir.reshape"(%arg0, %0) <{shape = [1 : i32]}> : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> // CHECK-NOT: %[[C:.*]] = "ttnn.reshape"[[C:.*]] // CHECK: return %arg0 : tensor<1xi32, #{{.*}}> return %1 : tensor<1xi32> diff --git a/test/ttmlir/Dialect/TTNN/simple_broadcast.mlir b/test/ttmlir/Dialect/TTNN/simple_broadcast.mlir index e7aac7e2e..251924caa 100644 --- a/test/ttmlir/Dialect/TTNN/simple_broadcast.mlir +++ b/test/ttmlir/Dialect/TTNN/simple_broadcast.mlir @@ -1,13 +1,12 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s -#any_device_tile = #tt.operand_constraint module @jit_broadcast attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<1xf32> {mhlo.layout_mode = "default"}, %arg1: tensor<512x512xf32> {mhlo.layout_mode = "default"}) -> (tensor<512x512xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) { // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] // CHECK-NOT: %[[C:.*]] = "ttnn.broadcast"[[C:.*]] %0 = tensor.empty() : tensor<512x512xf32> - %1 = "ttir.broadcast"(%arg0, %0) <{dimension = [1], operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<1xf32>, tensor<512x512xf32>) -> tensor<512x512xf32> + %1 = "ttir.broadcast"(%arg0, %0) <{dimension = [1]}> : (tensor<1xf32>, tensor<512x512xf32>) -> tensor<512x512xf32> %2 = tensor.empty() : tensor<512x512xf32> - %3 = "ttir.maximum"(%1, %arg1, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<512x512xf32>, tensor<512x512xf32>, tensor<512x512xf32>) -> tensor<512x512xf32> + %3 = "ttir.maximum"(%1, %arg1, %2) <{operandSegmentSizes = array}> : (tensor<512x512xf32>, tensor<512x512xf32>, tensor<512x512xf32>) -> tensor<512x512xf32> return %3 : tensor<512x512xf32> } } diff --git a/test/ttmlir/Dialect/TTNN/simple_clamp.mlir b/test/ttmlir/Dialect/TTNN/simple_clamp.mlir index 18da70dd8..272e07175 100644 --- a/test/ttmlir/Dialect/TTNN/simple_clamp.mlir +++ b/test/ttmlir/Dialect/TTNN/simple_clamp.mlir @@ -1,5 +1,4 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s -#any_device_tile = #tt.operand_constraint module attributes {} { func.func @clamp(%arg0: tensor<64x128xbf16>) -> tensor<64x128xbf16> { %0 = tensor.empty() : tensor<64x128xbf16> @@ -8,7 +7,7 @@ module attributes {} { // CHECK: = "ttnn.clamp"(%[[LAYOUT]]) // CHECK-SAME: {max = 3.000000e+00 : f32, min = 2.000000e+00 : f32} // CHECK-SAME: [[TENSOR:tensor<64x128xbf16]], #ttnn_layout{{[0-9]+}}>) -> [[TENSOR]] - %1 = "ttir.clamp"(%arg0, %0) <{max = 3.000000e+00 : f32, min = 2.000000e+00 : f32, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16> + %1 = "ttir.clamp"(%arg0, %0) <{max = 3.000000e+00 : f32, min = 2.000000e+00 : f32}> : (tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16> return %1 : tensor<64x128xbf16> } } diff --git a/test/ttmlir/Dialect/TTNN/simple_compare.mlir b/test/ttmlir/Dialect/TTNN/simple_compare.mlir index 3a0ce12ec..873ae745c 100644 --- a/test/ttmlir/Dialect/TTNN/simple_compare.mlir +++ b/test/ttmlir/Dialect/TTNN/simple_compare.mlir @@ -1,7 +1,4 @@ //
RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s - -#any_device = #tt.operand_constraint - module attributes {} { func.func @equal(%arg0: tensor<13x31xf32>, %arg1: tensor<13x31xf32>) -> tensor<13x31xf32> { // CHECK: %[[C:.*]] = "ttnn.empty" @@ -12,7 +9,7 @@ module attributes {} { // CHECK-SAME: [[TENSOR]] // CHECK-SAME: [[TENSOR]] // CHECK-SAME: -> [[TENSOR]] - %1 = "ttir.eq"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<13x31xf32>, tensor<13x31xf32>, tensor<13x31xf32>) -> tensor<13x31xf32> + %1 = "ttir.eq"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<13x31xf32>, tensor<13x31xf32>, tensor<13x31xf32>) -> tensor<13x31xf32> return %1 : tensor<13x31xf32> } } @@ -27,7 +24,7 @@ module attributes {} { // CHECK-SAME: [[TENSOR]] // CHECK-SAME: [[TENSOR]] // CHECK-SAME: -> [[TENSOR]] - %1 = "ttir.ne"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<13x31xf32>, tensor<13x31xf32>, tensor<13x31xf32>) -> tensor<13x31xf32> + %1 = "ttir.ne"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<13x31xf32>, tensor<13x31xf32>, tensor<13x31xf32>) -> tensor<13x31xf32> return %1 : tensor<13x31xf32> } } @@ -42,7 +39,7 @@ module attributes {} { // CHECK-SAME: [[TENSOR]] // CHECK-SAME: [[TENSOR]] // CHECK-SAME: -> [[TENSOR]] - %1 = "ttir.ge"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<13x31xf32>, tensor<13x31xf32>, tensor<13x31xf32>) -> tensor<13x31xf32> + %1 = "ttir.ge"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<13x31xf32>, tensor<13x31xf32>, tensor<13x31xf32>) -> tensor<13x31xf32> return %1 : tensor<13x31xf32> } } @@ -57,7 +54,7 @@ module attributes {} { // CHECK-SAME: [[TENSOR]] // CHECK-SAME: [[TENSOR]] // CHECK-SAME: -> [[TENSOR]] - %1 = "ttir.gt"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<13x31xf32>, tensor<13x31xf32>, tensor<13x31xf32>) -> tensor<13x31xf32> + %1 = "ttir.gt"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<13x31xf32>, tensor<13x31xf32>, tensor<13x31xf32>) -> tensor<13x31xf32> return %1 : tensor<13x31xf32> } } @@ -72,7 +69,7 @@ module attributes {} { // CHECK-SAME: [[TENSOR]] // CHECK-SAME: [[TENSOR]] // CHECK-SAME: -> [[TENSOR]] - %1 = "ttir.le"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<13x31xf32>, tensor<13x31xf32>, tensor<13x31xf32>) -> tensor<13x31xf32> + %1 = "ttir.le"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<13x31xf32>, tensor<13x31xf32>, tensor<13x31xf32>) -> tensor<13x31xf32> return %1 : tensor<13x31xf32> } } @@ -87,7 +84,7 @@ module attributes {} { // CHECK-SAME: [[TENSOR]] // CHECK-SAME: [[TENSOR]] // CHECK-SAME: -> [[TENSOR]] - %1 = "ttir.lt"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<13x31xf32>, tensor<13x31xf32>, tensor<13x31xf32>) -> tensor<13x31xf32> + %1 = "ttir.lt"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<13x31xf32>, tensor<13x31xf32>, tensor<13x31xf32>) -> tensor<13x31xf32> return %1 : tensor<13x31xf32> } } diff --git a/test/ttmlir/Dialect/TTNN/simple_constant.mlir b/test/ttmlir/Dialect/TTNN/simple_constant.mlir index 017a1baf0..53de9a5ee 100644 --- a/test/ttmlir/Dialect/TTNN/simple_constant.mlir +++ 
b/test/ttmlir/Dialect/TTNN/simple_constant.mlir @@ -1,5 +1,4 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s -#any_device = #tt.operand_constraint module attributes {} { func.func @test_empty_int8() -> tensor<64x128xi8> { %0 = "ttir.constant"() <{value = dense<0> : tensor<64x128xi8>}> : () -> tensor<64x128xi8> diff --git a/test/ttmlir/Dialect/TTNN/simple_div.mlir b/test/ttmlir/Dialect/TTNN/simple_div.mlir index 2dad76003..15d2b4820 100644 --- a/test/ttmlir/Dialect/TTNN/simple_div.mlir +++ b/test/ttmlir/Dialect/TTNN/simple_div.mlir @@ -1,11 +1,10 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s -#any_device = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> { // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] %0 = tensor.empty() : tensor<64x128xf32> // CHECK: %[[C:.*]] = "ttnn.div"[[C:.*]] - %1 = "ttir.div"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.div"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } } diff --git a/test/ttmlir/Dialect/TTNN/simple_get_dimension_size.mlir b/test/ttmlir/Dialect/TTNN/simple_get_dimension_size.mlir index 6b37e89d7..f3bd6dab0 100644 --- a/test/ttmlir/Dialect/TTNN/simple_get_dimension_size.mlir +++ b/test/ttmlir/Dialect/TTNN/simple_get_dimension_size.mlir @@ -1,5 +1,4 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s -#any_device = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<13x21x3xf32>) -> tensor<1xi32> { %0 = "ttir.get_dimension_size"(%arg0) <{dimension = 1 : i32}> : (tensor<13x21x3xf32>) -> tensor<1xi32> diff --git a/test/ttmlir/Dialect/TTNN/simple_max.mlir b/test/ttmlir/Dialect/TTNN/simple_max.mlir index ce791beb4..34a0120b2 100644 --- a/test/ttmlir/Dialect/TTNN/simple_max.mlir +++ b/test/ttmlir/Dialect/TTNN/simple_max.mlir @@ -1,10 +1,9 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s -#any_device = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<512x32xbf16>) -> tensor<512xbf16> { %0 = tensor.empty() : tensor<512xbf16> // CHECK: %[[C:.*]] = "ttnn.max"[[C:.*]] - %1 = "ttir.max"(%arg0, %0) <{dim_arg = [1: i32], keep_dim = false, operand_constraints = [#any_device, #any_device]}> : (tensor<512x32xbf16>, tensor<512xbf16>) -> tensor<512xbf16> + %1 = "ttir.max"(%arg0, %0) <{dim_arg = [1: i32], keep_dim = false}> : (tensor<512x32xbf16>, tensor<512xbf16>) -> tensor<512xbf16> return %1 : tensor<512xbf16> } } diff --git a/test/ttmlir/Dialect/TTNN/simple_maximum.mlir b/test/ttmlir/Dialect/TTNN/simple_maximum.mlir index 2cf8b525a..cd87754fa 100644 --- a/test/ttmlir/Dialect/TTNN/simple_maximum.mlir +++ b/test/ttmlir/Dialect/TTNN/simple_maximum.mlir @@ -1,11 +1,10 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s -#any_device = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> { // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] %0 = tensor.empty() : tensor<64x128xf32> // CHECK: %[[C:.*]] = "ttnn.maximum"[[C:.*]] - %1 = "ttir.maximum"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : 
(tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.maximum"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } } diff --git a/test/ttmlir/Dialect/TTNN/simple_mean.mlir b/test/ttmlir/Dialect/TTNN/simple_mean.mlir index a0fe0523a..efcba0a13 100644 --- a/test/ttmlir/Dialect/TTNN/simple_mean.mlir +++ b/test/ttmlir/Dialect/TTNN/simple_mean.mlir @@ -1,10 +1,9 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s -#any_device = #tt.operand_constraint module { func.func @forward(%arg0: tensor<512x1024xbf16>) -> tensor<512x32xbf16> { %0 = tensor.empty() : tensor<512x32xbf16> // CHECK: %[[C:.*]] = "ttnn.mean"[[C:.*]] - %1 = "ttir.mean"(%arg0, %0) <{dim_arg = [-1: i32], keep_dim = true, operand_constraints = [#any_device, #any_device]}> : (tensor<512x1024xbf16>, tensor<512x32xbf16>) -> tensor<512x32xbf16> + %1 = "ttir.mean"(%arg0, %0) <{dim_arg = [-1: i32], keep_dim = true}> : (tensor<512x1024xbf16>, tensor<512x32xbf16>) -> tensor<512x32xbf16> return %1 : tensor<512x32xbf16> } } diff --git a/test/ttmlir/Dialect/TTNN/simple_multiply.mlir b/test/ttmlir/Dialect/TTNN/simple_multiply.mlir index 8421d9689..795f65efe 100644 --- a/test/ttmlir/Dialect/TTNN/simple_multiply.mlir +++ b/test/ttmlir/Dialect/TTNN/simple_multiply.mlir @@ -1,11 +1,10 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s -#any_device = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> { // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] %0 = tensor.empty() : tensor<64x128xf32> // CHECK: %[[C:.*]] = "ttnn.multiply"[[C:.*]] - %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } } diff --git a/test/ttmlir/Dialect/TTNN/simple_reshape.mlir b/test/ttmlir/Dialect/TTNN/simple_reshape.mlir index 6b7c0edfe..29e651239 100644 --- a/test/ttmlir/Dialect/TTNN/simple_reshape.mlir +++ b/test/ttmlir/Dialect/TTNN/simple_reshape.mlir @@ -1,10 +1,9 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s| FileCheck %s -#any_device_tile = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<4x2x32x32xbf16>) -> tensor<2x4x32x32xbf16> { %0 = tensor.empty() : tensor<2x4x32x32xbf16> // CHECK: %[[C:.*]] = "ttnn.reshape"[[C:.*]] - %1 = "ttir.reshape"(%arg0, %0) <{shape = [2: i32, 4: i32, 32: i32, 32: i32] , operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<4x2x32x32xbf16>, tensor<2x4x32x32xbf16>) -> tensor<2x4x32x32xbf16> + %1 = "ttir.reshape"(%arg0, %0) <{shape = [2: i32, 4: i32, 32: i32, 32: i32]}> : (tensor<4x2x32x32xbf16>, tensor<2x4x32x32xbf16>) -> tensor<2x4x32x32xbf16> return %1 : tensor<2x4x32x32xbf16> } } diff --git a/test/ttmlir/Dialect/TTNN/simple_scatter.mlir b/test/ttmlir/Dialect/TTNN/simple_scatter.mlir index 5991efeab..22ad5c2d0 100644 --- a/test/ttmlir/Dialect/TTNN/simple_scatter.mlir +++ b/test/ttmlir/Dialect/TTNN/simple_scatter.mlir @@ -1,11 +1,10 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s -#any_device_tile = #tt.operand_constraint module 
attributes {} { func.func @forward(%arg0: tensor<1x3x320x320xf32>, %arg1: tensor<1x3x32x32xf32>) -> tensor<1x3x320x320xf32> { %0 = tensor.empty() : tensor<1x3x320x320xf32> %1 = tensor.empty() : tensor<1x1xi32> // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) <{dtype = {{.*}}, layout = {{.*}}, memory_config = {{.*}}, shape = #ttnn.shape<[[TENSOR_SHAPE0:[0-9]+x[0-9]+x[0-9]+x[0-9]+]]>}> : (!tt.device<#device>) -> tensor<[[TENSOR_SHAPE1:[0-9]+x[0-9]+x[0-9]+x[0-9]+xf[0-9]+]], {{.*}}> - %2 = "ttir.scatter"(%arg0, %1, %arg1, %0) <{index_vector_dim = 1 : i32, indices_are_sorted = false, input_batching_dims = array, inserted_window_dims = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile], scatter_dims_to_operand_dims = array, scatter_indices_batching_dims = array, unique_indices = false, update_window_dims = array}> ({ + %2 = "ttir.scatter"(%arg0, %1, %arg1, %0) <{index_vector_dim = 1 : i32, indices_are_sorted = false, input_batching_dims = array, inserted_window_dims = array, scatter_dims_to_operand_dims = array, scatter_indices_batching_dims = array, unique_indices = false, update_window_dims = array}> ({ ^bb0(%arg3: tensor<1xf32>, %arg4: tensor<1xf32>): "ttir.yield"(%arg4) : (tensor<1xf32>) -> () }) : (tensor<1x3x320x320xf32>, tensor<1x1xi32>, tensor<1x3x32x32xf32>, tensor<1x3x320x320xf32>) -> tensor<1x3x320x320xf32> diff --git a/test/ttmlir/Dialect/TTNN/simple_slice.mlir b/test/ttmlir/Dialect/TTNN/simple_slice.mlir index cc4e2063c..d8ff26bc3 100644 --- a/test/ttmlir/Dialect/TTNN/simple_slice.mlir +++ b/test/ttmlir/Dialect/TTNN/simple_slice.mlir @@ -1,10 +1,9 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s| FileCheck %s -#any_device_tile = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<4x32x32xbf16>) -> tensor<2x16x16xbf16> { %0 = tensor.empty() : tensor<2x16x16xbf16> // CHECK: %[[C:.*]] = "ttnn.slice"[[C:.*]] - %1 = "ttir.slice"(%arg0, %0) <{begins = [0: i32, 0: i32, 0: i32], ends = [2: i32, 16: i32, 16: i32], step = [1: i32, 1: i32, 1: i32], operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<4x32x32xbf16>, tensor<2x16x16xbf16>) -> tensor<2x16x16xbf16> + %1 = "ttir.slice"(%arg0, %0) <{begins = [0: i32, 0: i32, 0: i32], ends = [2: i32, 16: i32, 16: i32], step = [1: i32, 1: i32, 1: i32]}> : (tensor<4x32x32xbf16>, tensor<2x16x16xbf16>) -> tensor<2x16x16xbf16> return %1 : tensor<2x16x16xbf16> } } diff --git a/test/ttmlir/Dialect/TTNN/simple_squeeze.mlir b/test/ttmlir/Dialect/TTNN/simple_squeeze.mlir index 34367c473..e8bac061e 100644 --- a/test/ttmlir/Dialect/TTNN/simple_squeeze.mlir +++ b/test/ttmlir/Dialect/TTNN/simple_squeeze.mlir @@ -1,10 +1,9 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s| FileCheck %s -#any_device_tile = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<1x2x1x32x32xbf16>) -> tensor<1x2x32x32xbf16> { %0 = tensor.empty() : tensor<1x2x32x32xbf16> // CHECK: %[[C:.*]] = "ttnn.reshape"[[C:.*]] - %1 = "ttir.squeeze"(%arg0, %0) <{dim = -3 : si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<1x2x1x32x32xbf16>, tensor<1x2x32x32xbf16>) -> tensor<1x2x32x32xbf16> + %1 = "ttir.squeeze"(%arg0, %0) <{dim = -3 : si32}> : (tensor<1x2x1x32x32xbf16>, tensor<1x2x32x32xbf16>) -> tensor<1x2x32x32xbf16> return %1 : tensor<1x2x32x32xbf16> } } diff --git a/test/ttmlir/Dialect/TTNN/simple_subtract.mlir b/test/ttmlir/Dialect/TTNN/simple_subtract.mlir index 9716ac291..f4c69ea40 100644 --- 
a/test/ttmlir/Dialect/TTNN/simple_subtract.mlir +++ b/test/ttmlir/Dialect/TTNN/simple_subtract.mlir @@ -1,11 +1,10 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s -#any_device = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> { // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] %0 = tensor.empty() : tensor<64x128xf32> // CHECK: %[[C:.*]] = "ttnn.subtract"[[C:.*]] - %1 = "ttir.subtract"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.subtract"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } } diff --git a/test/ttmlir/Dialect/TTNN/simple_subtract_to_add.mlir b/test/ttmlir/Dialect/TTNN/simple_subtract_to_add.mlir index 59c4eb901..4703a1fdd 100644 --- a/test/ttmlir/Dialect/TTNN/simple_subtract_to_add.mlir +++ b/test/ttmlir/Dialect/TTNN/simple_subtract_to_add.mlir @@ -1,5 +1,4 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s -#any_device = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<64x128xf32>, %arg1: tensor<1x128xf32>) -> tensor<64x128xf32> { // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] @@ -7,7 +6,7 @@ module attributes {} { // CHECK: %[[C:.*]] = "ttnn.neg"[[C:.*]] // CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] // CHECK-NOT: %[[C:.*]] = "ttnn.subtract"[[C:.*]] - %1 = "ttir.subtract"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<1x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.subtract"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<1x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } } diff --git a/test/ttmlir/Dialect/TTNN/simple_sum.mlir b/test/ttmlir/Dialect/TTNN/simple_sum.mlir index 1b183dee6..2b107b068 100644 --- a/test/ttmlir/Dialect/TTNN/simple_sum.mlir +++ b/test/ttmlir/Dialect/TTNN/simple_sum.mlir @@ -1,10 +1,9 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s -#any_device = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<512x1024xbf16>) -> tensor<512x32xbf16> { %0 = tensor.empty() : tensor<512x32xbf16> // CHECK: %[[C:.*]] = "ttnn.sum"[[C:.*]] - %1 = "ttir.sum"(%arg0, %0) <{dim_arg = [-1: i32], keep_dim = true, operand_constraints = [#any_device, #any_device]}> : (tensor<512x1024xbf16>, tensor<512x32xbf16>) -> tensor<512x32xbf16> + %1 = "ttir.sum"(%arg0, %0) <{dim_arg = [-1: i32], keep_dim = true}> : (tensor<512x1024xbf16>, tensor<512x32xbf16>) -> tensor<512x32xbf16> return %1 : tensor<512x32xbf16> } } diff --git a/test/ttmlir/Dialect/TTNN/simple_unsqueeze.mlir b/test/ttmlir/Dialect/TTNN/simple_unsqueeze.mlir index 2400b6b5e..95daab27e 100644 --- a/test/ttmlir/Dialect/TTNN/simple_unsqueeze.mlir +++ b/test/ttmlir/Dialect/TTNN/simple_unsqueeze.mlir @@ -1,10 +1,9 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s| FileCheck %s -#any_device_tile = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<4x2x32x32xbf16>) -> tensor<4x1x2x32x32xbf16> { %0 = tensor.empty() : tensor<4x1x2x32x32xbf16> // CHECK: %[[C:.*]] = "ttnn.reshape"[[C:.*]] - %1 = "ttir.unsqueeze"(%arg0, %0) <{dim = -4 : si32, 
operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<4x2x32x32xbf16>, tensor<4x1x2x32x32xbf16>) -> tensor<4x1x2x32x32xbf16> + %1 = "ttir.unsqueeze"(%arg0, %0) <{dim = -4 : si32}> : (tensor<4x2x32x32xbf16>, tensor<4x1x2x32x32xbf16>) -> tensor<4x1x2x32x32xbf16> return %1 : tensor<4x1x2x32x32xbf16> } } diff --git a/test/ttmlir/Dialect/TTNN/simple_where.mlir b/test/ttmlir/Dialect/TTNN/simple_where.mlir index a535a4fd9..c75c7f817 100644 --- a/test/ttmlir/Dialect/TTNN/simple_where.mlir +++ b/test/ttmlir/Dialect/TTNN/simple_where.mlir @@ -1,11 +1,10 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s -#any_device_tile = #tt.operand_constraint module @jit_eltwise_where { func.func public @test_where(%arg0: tensor<13x37xf32>, %arg1: tensor<13x37xf32>) -> tensor<13x37xf32> { %0 = tensor.empty() : tensor<13x37xf32> - %1 = "ttir.eq"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xf32>) -> tensor<13x37xf32> + %1 = "ttir.eq"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xf32>) -> tensor<13x37xf32> %2 = tensor.empty() : tensor<13x37xf32> - %3 = "ttir.where"(%1, %arg0, %arg1, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xf32>) -> tensor<13x37xf32> + %3 = "ttir.where"(%1, %arg0, %arg1, %2) <{operandSegmentSizes = array}> : (tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xf32>, tensor<13x37xf32>) -> tensor<13x37xf32> // CHECK: %[[EMPTY:.*]] = "ttnn.empty"{{.*}} // CHECK: %[[VAL1:[0-9]+]] = "ttnn.eq"(%{{[0-9]+}}, %{{[0-9]+}}, %[[EMPTY]]) // CHECK: %{{[0-9]+}} = "ttnn.where"(%[[VAL1]], %{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}}) diff --git a/test/ttmlir/Dialect/TTNN/softmax/simple_softmax.mlir b/test/ttmlir/Dialect/TTNN/softmax/simple_softmax.mlir index ec05a3006..0d7bfc90a 100644 --- a/test/ttmlir/Dialect/TTNN/softmax/simple_softmax.mlir +++ b/test/ttmlir/Dialect/TTNN/softmax/simple_softmax.mlir @@ -1,15 +1,14 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s -#any_device = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<512x1024xbf16>) -> tensor<512x1024xbf16> { %0 = tensor.empty() : tensor<512x1024xbf16> // CHECK: %[[C:.*]] = "ttnn.softmax"[[C:.*]] // Check for positive dimension attribute - %1 = "ttir.softmax"(%arg0, %0) <{dimension = 1 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<512x1024xbf16>, tensor<512x1024xbf16>) -> tensor<512x1024xbf16> + %1 = "ttir.softmax"(%arg0, %0) <{dimension = 1 : si32}> : (tensor<512x1024xbf16>, tensor<512x1024xbf16>) -> tensor<512x1024xbf16> %2 = tensor.empty() : tensor<512x1024xbf16> // CHECK: %[[C:.*]] = "ttnn.softmax"[[C:.*]] // Check for negative dimension attribute - %3 = "ttir.softmax"(%1, %2) <{dimension = -1 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<512x1024xbf16>, tensor<512x1024xbf16>) -> tensor<512x1024xbf16> + %3 = "ttir.softmax"(%1, %2) <{dimension = -1 : si32}> : (tensor<512x1024xbf16>, tensor<512x1024xbf16>) -> tensor<512x1024xbf16> return %3 : tensor<512x1024xbf16> } } diff --git a/test/ttmlir/Dialect/TTNN/softmax/softmax_negative_1.mlir b/test/ttmlir/Dialect/TTNN/softmax/softmax_negative_1.mlir index 8e9b2f083..ce302b139 100644 --- 
a/test/ttmlir/Dialect/TTNN/softmax/softmax_negative_1.mlir +++ b/test/ttmlir/Dialect/TTNN/softmax/softmax_negative_1.mlir @@ -1,10 +1,9 @@ // RUN: not ttmlir-opt --ttir-load-system-desc --ttir-layout --convert-ttir-to-ttnn %s 2>&1 | FileCheck %s // CHECK: error: 'ttir.softmax' op Dimension attribute must be within the bounds of the input tensor -#any_device = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<512x1024xbf16>) -> tensor<512x1024xbf16> { %0 = tensor.empty() : tensor<512x1024xbf16> - %1 = "ttir.softmax"(%arg0, %0) <{dimension = 2 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<512x1024xbf16>, tensor<512x1024xbf16>) -> tensor<512x1024xbf16> + %1 = "ttir.softmax"(%arg0, %0) <{dimension = 2 : si32}> : (tensor<512x1024xbf16>, tensor<512x1024xbf16>) -> tensor<512x1024xbf16> return %1 : tensor<512x1024xbf16> } } diff --git a/test/ttmlir/Dialect/TTNN/softmax/softmax_negative_2.mlir b/test/ttmlir/Dialect/TTNN/softmax/softmax_negative_2.mlir index 43a0a97f3..e42e0335c 100644 --- a/test/ttmlir/Dialect/TTNN/softmax/softmax_negative_2.mlir +++ b/test/ttmlir/Dialect/TTNN/softmax/softmax_negative_2.mlir @@ -1,10 +1,9 @@ // RUN: not ttmlir-opt --ttir-load-system-desc --ttir-layout --convert-ttir-to-ttnn %s 2>&1 | FileCheck %s // CHECK: error: 'ttir.softmax' op Dimension attribute must be within the bounds of the input tensor -#any_device = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<512x1024xbf16>) -> tensor<512x1024xbf16> { %0 = tensor.empty() : tensor<512x1024xbf16> - %1 = "ttir.softmax"(%arg0, %0) <{dimension = -3 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<512x1024xbf16>, tensor<512x1024xbf16>) -> tensor<512x1024xbf16> + %1 = "ttir.softmax"(%arg0, %0) <{dimension = -3 : si32}> : (tensor<512x1024xbf16>, tensor<512x1024xbf16>) -> tensor<512x1024xbf16> return %1 : tensor<512x1024xbf16> } } diff --git a/test/ttmlir/Dialect/TTNN/transpose/simple_transpose.mlir b/test/ttmlir/Dialect/TTNN/transpose/simple_transpose.mlir index fbf377df1..0e495edd7 100644 --- a/test/ttmlir/Dialect/TTNN/transpose/simple_transpose.mlir +++ b/test/ttmlir/Dialect/TTNN/transpose/simple_transpose.mlir @@ -1,10 +1,9 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s -#any_device_tile = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<64x128xbf16>) -> tensor<128x64xbf16> { %0 = tensor.empty() : tensor<128x64xbf16> // CHECK: %[[C:.*]] = "ttnn.transpose"[[C:.*]] - %1 = "ttir.transpose"(%arg0, %0) <{dim0 = 0 : si32, dim1 = 1 : si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128x64xbf16>) -> tensor<128x64xbf16> + %1 = "ttir.transpose"(%arg0, %0) <{dim0 = 0 : si32, dim1 = 1 : si32}> : (tensor<64x128xbf16>, tensor<128x64xbf16>) -> tensor<128x64xbf16> return %1 : tensor<128x64xbf16> } } diff --git a/test/ttmlir/Dialect/TTNN/transpose/simple_transpose_8x16_reverse_dims.mlir b/test/ttmlir/Dialect/TTNN/transpose/simple_transpose_8x16_reverse_dims.mlir index 70640d041..942551805 100644 --- a/test/ttmlir/Dialect/TTNN/transpose/simple_transpose_8x16_reverse_dims.mlir +++ b/test/ttmlir/Dialect/TTNN/transpose/simple_transpose_8x16_reverse_dims.mlir @@ -1,10 +1,9 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s -#any_device_tile = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<64x16xbf16>) -> tensor<16x64xbf16> { %0 = tensor.empty() : tensor<16x64xbf16> // CHECK: %[[C:.*]] = 
"ttnn.transpose"[[C:.*]] - %1 = "ttir.transpose"(%arg0, %0) <{dim0 = 1 : si32, dim1 = 0 : si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<64x16xbf16>, tensor<16x64xbf16>) -> tensor<16x64xbf16> + %1 = "ttir.transpose"(%arg0, %0) <{dim0 = 1 : si32, dim1 = 0 : si32}> : (tensor<64x16xbf16>, tensor<16x64xbf16>) -> tensor<16x64xbf16> return %1 : tensor<16x64xbf16> } } diff --git a/test/ttmlir/Dialect/TTNN/transpose/simple_transpose_8x8.mlir b/test/ttmlir/Dialect/TTNN/transpose/simple_transpose_8x8.mlir index b9cedf226..3cfddd7ee 100644 --- a/test/ttmlir/Dialect/TTNN/transpose/simple_transpose_8x8.mlir +++ b/test/ttmlir/Dialect/TTNN/transpose/simple_transpose_8x8.mlir @@ -1,10 +1,9 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s -#any_device = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<32x32xbf16>) -> tensor<32x32xbf16> { %0 = tensor.empty() : tensor<32x32xbf16> // CHECK: %[[C:.*]] = "ttnn.transpose"[[C:.*]] - %1 = "ttir.transpose"(%arg0, %0) <{dim0 = 0 : si32, dim1 = 1 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16> + %1 = "ttir.transpose"(%arg0, %0) <{dim0 = 0 : si32, dim1 = 1 : si32}> : (tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16> return %1 : tensor<32x32xbf16> } } diff --git a/test/ttmlir/Dialect/TTNN/transpose/simple_transpose_negative_dims.mlir b/test/ttmlir/Dialect/TTNN/transpose/simple_transpose_negative_dims.mlir index 035475bc4..74506b10c 100644 --- a/test/ttmlir/Dialect/TTNN/transpose/simple_transpose_negative_dims.mlir +++ b/test/ttmlir/Dialect/TTNN/transpose/simple_transpose_negative_dims.mlir @@ -1,10 +1,9 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s -#any_device_tile = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<32x32xbf16>) -> tensor<32x32xbf16> { %0 = tensor.empty() : tensor<32x32xbf16> // CHECK: %[[C:.*]] = "ttnn.transpose"[[C:.*]] - %1 = "ttir.transpose"(%arg0, %0) <{dim0 = -1 : si32, dim1 = -2 : si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16> + %1 = "ttir.transpose"(%arg0, %0) <{dim0 = -1 : si32, dim1 = -2 : si32}> : (tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16> return %1 : tensor<32x32xbf16> } } diff --git a/test/ttmlir/Dialect/TTNN/transpose/transpose_twice.mlir b/test/ttmlir/Dialect/TTNN/transpose/transpose_twice.mlir index f18e0e5cb..b78d86b01 100644 --- a/test/ttmlir/Dialect/TTNN/transpose/transpose_twice.mlir +++ b/test/ttmlir/Dialect/TTNN/transpose/transpose_twice.mlir @@ -1,13 +1,12 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s -#any_device_tile = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<1x16x32x64xbf16>) -> tensor<1x32x64x16xbf16> { %0 = tensor.empty() : tensor<1x64x32x16xbf16> // CHECK: %[[C:.*]] = "ttnn.transpose"[[C:.*]] - %1 = "ttir.transpose"(%arg0, %0) <{dim0 = -3 : si32, dim1 = -1 : si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<1x16x32x64xbf16>, tensor<1x64x32x16xbf16>) -> tensor<1x64x32x16xbf16> + %1 = "ttir.transpose"(%arg0, %0) <{dim0 = -3 : si32, dim1 = -1 : si32}> : (tensor<1x16x32x64xbf16>, tensor<1x64x32x16xbf16>) -> tensor<1x64x32x16xbf16> %2 = tensor.empty() : tensor<1x32x64x16xbf16> // CHECK: %[[C:.*]] = "ttnn.transpose - %3 = "ttir.transpose"(%1, %2) <{dim0 = -3 : si32, dim1 = -2 : si32, operand_constraints = 
[#any_device_tile, #any_device_tile]}> : (tensor<1x64x32x16xbf16>, tensor<1x32x64x16xbf16>) -> tensor<1x32x64x16xbf16> + %3 = "ttir.transpose"(%1, %2) <{dim0 = -3 : si32, dim1 = -2 : si32}> : (tensor<1x64x32x16xbf16>, tensor<1x32x64x16xbf16>) -> tensor<1x32x64x16xbf16> return %3 : tensor<1x32x64x16xbf16> } } diff --git a/test/ttmlir/Dialect/TTNN/ttir_to_ttnn_pipeline_custom_opt.mlir b/test/ttmlir/Dialect/TTNN/ttir_to_ttnn_pipeline_custom_opt.mlir index 112a941a8..63d263365 100644 --- a/test/ttmlir/Dialect/TTNN/ttir_to_ttnn_pipeline_custom_opt.mlir +++ b/test/ttmlir/Dialect/TTNN/ttir_to_ttnn_pipeline_custom_opt.mlir @@ -1,12 +1,11 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="enable-optimizer=false" %s | FileCheck %s -#any_device = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> { - // CHECK: #[[LAYOUT_1:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<64x128xf32, #dram>, > + // CHECK: #[[LAYOUT_1:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<2x4x!tt.tile<32x32, f32>, #dram>, > // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] %0 = tensor.empty() : tensor<64x128xf32> // CHECK: %[[C:.*]] = "ttnn.multiply"[[C:.*]] -> tensor<64x128xf32, #[[LAYOUT_1:.*]]> - %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } } diff --git a/test/ttmlir/Silicon/TTMetal/simple_constant.mlir b/test/ttmlir/Silicon/TTMetal/simple_constant.mlir index e7556331c..3f9825031 100644 --- a/test/ttmlir/Silicon/TTMetal/simple_constant.mlir +++ b/test/ttmlir/Silicon/TTMetal/simple_constant.mlir @@ -2,13 +2,11 @@ // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttmetal-to-flatbuffer %t.mlir > %t.ttm -#any_device = #tt.operand_constraint - func.func public @add5(%arg0: tensor<32x32xf32>) -> tensor<32x32xf32> { // CHECK: %[[C:.*]] = "ttmetal.alloc"[[C:.*]] // CHECK: %[[C:.*]] = "ttmetal.host_write"[[C:.*]] %0 = "ttir.constant"() <{value = dense<5.0> : tensor<32x32xf32>}> : () -> tensor<32x32xf32> %1 = tensor.empty() : tensor<32x32xf32> - %2 = "ttir.add"(%arg0, %0, %1) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32> + %2 = "ttir.add"(%arg0, %0, %1) <{operandSegmentSizes = array}> : (tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32> return %2 : tensor<32x32xf32> } diff --git a/test/ttmlir/Silicon/TTMetal/simple_eltwise.mlir b/test/ttmlir/Silicon/TTMetal/simple_eltwise.mlir index 4b1c3c39f..b9b10706b 100644 --- a/test/ttmlir/Silicon/TTMetal/simple_eltwise.mlir +++ b/test/ttmlir/Silicon/TTMetal/simple_eltwise.mlir @@ -1,14 +1,11 @@ // RUN: ttmlir-opt --ttir-to-ttmetal-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttmetal-to-flatbuffer %t.mlir > %t.ttm - -#any_device = #tt.operand_constraint - func.func @multiply(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> { // CHECK: %[[C:.*]] = "ttmetal.alloc"[[C:.*]] %0 = tensor.empty() : tensor<64x128xf32> // CHECK: %[[C:.*]] = 
"ttmetal.dispatch"[[C:.*]] - %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } @@ -16,7 +13,7 @@ func.func @add(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<6 // CHECK: %[[C:.*]] = "ttmetal.alloc"[[C:.*]] %0 = tensor.empty() : tensor<64x128xf32> // CHECK: %[[C:.*]] = "ttmetal.dispatch"[[C:.*]] - %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } @@ -24,7 +21,7 @@ func.func @exp(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { // CHECK: %[[C:.*]] = "ttmetal.alloc"[[C:.*]] %0 = tensor.empty() : tensor<64x128xf32> // CHECK: %[[C:.*]] = "ttmetal.dispatch"[[C:.*]] - %1 = "ttir.exp"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.exp"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } @@ -32,6 +29,6 @@ func.func @div(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<6 // CHECK: %[[C:.*]] = "ttmetal.alloc"[[C:.*]] %0 = tensor.empty() : tensor<64x128xf32> // CHECK: %[[C:.*]] = "ttmetal.dispatch"[[C:.*]] - %1 = "ttir.div"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.div"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } diff --git a/test/ttmlir/Silicon/TTMetal/simple_max.mlir b/test/ttmlir/Silicon/TTMetal/simple_max.mlir index 92bdbe72c..b8dcae064 100644 --- a/test/ttmlir/Silicon/TTMetal/simple_max.mlir +++ b/test/ttmlir/Silicon/TTMetal/simple_max.mlir @@ -2,12 +2,10 @@ // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttmetal-to-flatbuffer %t.mlir > %t.ttm -#any_device = #tt.operand_constraint - func.func @maximum(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> { // CHECK: %[[C:.*]] = "ttmetal.alloc"[[C:.*]] %0 = tensor.empty() : tensor<64x128xf32> // CHECK: %[[C:.*]] = "ttmetal.dispatch"[[C:.*]] - %1 = "ttir.maximum"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.maximum"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } diff --git a/test/ttmlir/Silicon/TTMetal/simple_reduce.mlir b/test/ttmlir/Silicon/TTMetal/simple_reduce.mlir index cdde621c2..a6ab52acb 100644 --- a/test/ttmlir/Silicon/TTMetal/simple_reduce.mlir +++ 
b/test/ttmlir/Silicon/TTMetal/simple_reduce.mlir @@ -1,5 +1,4 @@ // RUN: ttmlir-opt --ttir-to-ttmetal-backend-pipeline="system-desc-path=%system_desc_path%" %s | FileCheck %s -#any_device = #tt.operand_constraint #l1_ = #tt.memory_space #layout1 = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <4x4>, memref<64x96xf32, #l1_>> #layout2 = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <4x1>, memref<64x32xf32, #l1_>> @@ -9,8 +8,7 @@ func.func @reduceW(%arg0: tensor<256x384xf32, #layout1>) -> tensor<256x32xf32, # // CHECK: %[[C:.*]] = "ttmetal.dispatch"[[C:.*]] %1 = "ttir.sum"(%arg0, %0) <{operandSegmentSizes = array, dim_arg = [-1: i32], - keep_dim = true, - operand_constraints = [#any_device, #any_device, #any_device]}> : + keep_dim = true}> : (tensor<256x384xf32, #layout1>, tensor<256x32xf32, #layout2>) -> tensor<256x32xf32, #layout2> return %1 : tensor<256x32xf32, #layout2> } @@ -21,8 +19,7 @@ func.func @reduceH(%arg0: tensor<256x384xf32, #layout1>) -> tensor<32x384xf32, # // CHECK: %[[C:.*]] = "ttmetal.dispatch"[[C:.*]] %1 = "ttir.sum"(%arg0, %0) <{operandSegmentSizes = array, dim_arg = [-2: i32], - keep_dim = true, - operand_constraints = [#any_device, #any_device, #any_device]}> : + keep_dim = true}> : (tensor<256x384xf32, #layout1>, tensor<32x384xf32, #layout3>) -> tensor<32x384xf32, #layout3> return %1 : tensor<32x384xf32, #layout3> } @@ -33,8 +30,7 @@ func.func @reduceWH(%arg0: tensor<256x384xf32, #layout1>) -> tensor<32x32xf32, # // CHECK: %[[C:.*]] = "ttmetal.dispatch"[[C:.*]] %1 = "ttir.sum"(%arg0, %0) <{operandSegmentSizes = array, dim_arg = [-1: i32, -2: i32], - keep_dim = true, - operand_constraints = [#any_device, #any_device, #any_device]}> : + keep_dim = true}> : (tensor<256x384xf32, #layout1>, tensor<32x32xf32, #layout4>) -> tensor<32x32xf32, #layout4> return %1 : tensor<32x32xf32, #layout4> } @@ -44,8 +40,7 @@ func.func @maxReduceWH(%arg0: tensor<256x384xf32, #layout1>) -> tensor<32x32xf32 // CHECK: %[[C:.*]] = "ttmetal.dispatch"[[C:.*]] %1 = "ttir.max" (%arg0, %0) <{operandSegmentSizes = array, dim_arg = [-1: i32, -2: i32], - keep_dim = true, - operand_constraints = [#any_device, #any_device, #any_device]}> : + keep_dim = true}> : (tensor<256x384xf32, #layout1>, tensor<32x32xf32, #layout4>) -> tensor<32x32xf32, #layout4> return %1 : tensor<32x32xf32, #layout4> } diff --git a/test/ttmlir/Silicon/TTMetal/simple_reduce_1x1.mlir b/test/ttmlir/Silicon/TTMetal/simple_reduce_1x1.mlir index 2df51c9e5..2038cfa08 100644 --- a/test/ttmlir/Silicon/TTMetal/simple_reduce_1x1.mlir +++ b/test/ttmlir/Silicon/TTMetal/simple_reduce_1x1.mlir @@ -1,5 +1,4 @@ // RUN: ttmlir-opt --ttir-to-ttmetal-backend-pipeline="system-desc-path=%system_desc_path%" %s | FileCheck %s -#any_device = #tt.operand_constraint #l1_ = #tt.memory_space func.func @reduceW(%arg0: tensor<64x256xf32>) -> tensor<64x32xf32> { @@ -7,8 +6,7 @@ func.func @reduceW(%arg0: tensor<64x256xf32>) -> tensor<64x32xf32> { // CHECK: %[[C:.*]] = "ttmetal.dispatch"[[C:.*]] %1 = "ttir.sum"(%arg0, %0) <{operandSegmentSizes = array, dim_arg = [-1: i32], - keep_dim = true, - operand_constraints = [#any_device, #any_device, #any_device]}> : + keep_dim = true}> : (tensor<64x256xf32>, tensor<64x32xf32>) -> tensor<64x32xf32> return %1 : tensor<64x32xf32> } @@ -18,8 +16,7 @@ func.func @reduceH(%arg0: tensor<256x64xf32>) -> tensor<32x64xf32> { // CHECK: %[[C:.*]] = "ttmetal.dispatch"[[C:.*]] %1 = "ttir.sum"(%arg0, %0) <{operandSegmentSizes = array, dim_arg = [-2: i32], - keep_dim = true, - operand_constraints = [#any_device, #any_device, #any_device]}> 
: + keep_dim = true}> : (tensor<256x64xf32>, tensor<32x64xf32>) -> tensor<32x64xf32> return %1 : tensor<32x64xf32> } @@ -29,8 +26,7 @@ func.func @reduceWH(%arg0: tensor<256x64xf32>) -> tensor<32x32xf32> { // CHECK: %[[C:.*]] = "ttmetal.dispatch"[[C:.*]] %1 = "ttir.sum"(%arg0, %0) <{operandSegmentSizes = array, dim_arg = [-1: i32, -2: i32], - keep_dim = true, - operand_constraints = [#any_device, #any_device, #any_device]}> : + keep_dim = true}> : (tensor<256x64xf32>, tensor<32x32xf32>) -> tensor<32x32xf32> return %1 : tensor<32x32xf32> } diff --git a/test/ttmlir/Silicon/TTNN/arange/simple_device_arange_dim2.mlir b/test/ttmlir/Silicon/TTNN/arange/simple_device_arange_dim2.mlir index f3affc69d..62d4f228c 100644 --- a/test/ttmlir/Silicon/TTNN/arange/simple_device_arange_dim2.mlir +++ b/test/ttmlir/Silicon/TTNN/arange/simple_device_arange_dim2.mlir @@ -3,13 +3,12 @@ // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn // UNSUPPORTED: true // https://github.com/tenstorrent/tt-mlir/issues/1448 -#any_device = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<1x1x32x128xbf16>) -> tensor<1x1x32x128xbf16> { // CHECK: %[[C:.*]] = "ttnn.arange"[[C:.*]] %0 = "ttir.arange"() <{start = 0: si64, end = 64: si64, step = 2: si64, arange_dimension = 2: i64}> : () -> tensor<1x1x32x128xbf16> %1 = tensor.empty() : tensor<1x1x32x128xbf16> - %2 = "ttir.multiply"(%arg0, %0, %1) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x1x32x128xbf16>, tensor<1x1x32x128xbf16>, tensor<1x1x32x128xbf16>) -> tensor<1x1x32x128xbf16> + %2 = "ttir.multiply"(%arg0, %0, %1) <{operandSegmentSizes = array}> : (tensor<1x1x32x128xbf16>, tensor<1x1x32x128xbf16>, tensor<1x1x32x128xbf16>) -> tensor<1x1x32x128xbf16> return %2 : tensor<1x1x32x128xbf16> } } diff --git a/test/ttmlir/Silicon/TTNN/arange/simple_device_arange_dim3.mlir b/test/ttmlir/Silicon/TTNN/arange/simple_device_arange_dim3.mlir index 196e75709..26eb2f1a6 100644 --- a/test/ttmlir/Silicon/TTNN/arange/simple_device_arange_dim3.mlir +++ b/test/ttmlir/Silicon/TTNN/arange/simple_device_arange_dim3.mlir @@ -1,13 +1,12 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<1x1x32x128xbf16>) -> tensor<1x1x32x128xbf16> { // CHECK: %[[C:.*]] = "ttnn.arange"[[C:.*]] %0 = "ttir.arange"() <{start = 0: si64, end = 128: si64, step = 1: si64, arange_dimension = 3: i64}> : () -> tensor<1x1x32x128xbf16> %1 = tensor.empty() : tensor<1x1x32x128xbf16> - %2 = "ttir.multiply"(%arg0, %0, %1) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x1x32x128xbf16>, tensor<1x1x32x128xbf16>, tensor<1x1x32x128xbf16>) -> tensor<1x1x32x128xbf16> + %2 = "ttir.multiply"(%arg0, %0, %1) <{operandSegmentSizes = array}> : (tensor<1x1x32x128xbf16>, tensor<1x1x32x128xbf16>, tensor<1x1x32x128xbf16>) -> tensor<1x1x32x128xbf16> return %2 : tensor<1x1x32x128xbf16> } } diff --git a/test/ttmlir/Silicon/TTNN/ccl/all_gather.mlir b/test/ttmlir/Silicon/TTNN/ccl/all_gather.mlir index edf0a4eaf..9e5972e13 100644 --- a/test/ttmlir/Silicon/TTNN/ccl/all_gather.mlir +++ b/test/ttmlir/Silicon/TTNN/ccl/all_gather.mlir @@ -3,13 +3,10 @@ // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn // UNSUPPORTED: true // REQUIRES: multi-chip 
-#any_device = #tt.operand_constraint -#any_device_tile = #tt.operand_constraint - func.func @forward(%arg0: tensor<1x1x32x32xf32>) -> tensor<1x1x32x128xf32> { // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] %0 = tensor.empty() : tensor<1x1x32x128xf32> // CHECK: %[[C:.*]] = "ttnn.all_gather"[[C:.*]] - %1 = "ttir.all_gather"(%arg0, %0) <{dim = 3 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<1x1x32x32xf32>, tensor<1x1x32x128xf32>) -> tensor<1x1x32x128xf32> + %1 = "ttir.all_gather"(%arg0, %0) <{dim = 3 : si32}> : (tensor<1x1x32x32xf32>, tensor<1x1x32x128xf32>) -> tensor<1x1x32x128xf32> return %1 : tensor<1x1x32x128xf32> } diff --git a/test/ttmlir/Silicon/TTNN/complex_conv_channel_first.mlir b/test/ttmlir/Silicon/TTNN/complex_conv_channel_first.mlir index 8b0b0dec6..ca773e978 100644 --- a/test/ttmlir/Silicon/TTNN/complex_conv_channel_first.mlir +++ b/test/ttmlir/Silicon/TTNN/complex_conv_channel_first.mlir @@ -1,7 +1,6 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device_tile = #tt.operand_constraint module @jit_convolution { func.func public @test_NCHW_IOHW_to_NHWC_OIHW_conv2d(%arg0: tensor<1x3x100x100xbf16>, %arg1: tensor<7x3x3x3xbf16>) -> tensor<1x7x100x100xbf16> { %0 = tensor.empty() : tensor<1x7x100x100xbf16> @@ -23,7 +22,6 @@ module @jit_convolution { >, feature_group_count = 1 : i64, input_dilation = array, - operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile], padding = array, weight_dilation = array, window_reversal = array, diff --git a/test/ttmlir/Silicon/TTNN/deallocate.mlir b/test/ttmlir/Silicon/TTNN/deallocate.mlir index 1e2f0b3c3..cdba16016 100644 --- a/test/ttmlir/Silicon/TTNN/deallocate.mlir +++ b/test/ttmlir/Silicon/TTNN/deallocate.mlir @@ -1,36 +1,35 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device = #tt.operand_constraint #loc = loc("Dealloc":4294967295:0) module @"dealloc_test" attributes {} { func.func @main(%arg0: tensor<1x784xf32> loc("Dealloc":4294967295:0), %arg1: tensor<1x10xf32> loc("Dealloc":4294967295:0), %arg2: tensor<256x10xf32> loc("Dealloc":4294967295:0), %arg3: tensor<1x256xf32> loc("Dealloc":4294967295:0), %arg4: tensor<784x256xf32> loc("Dealloc":4294967295:0)) -> tensor<1x10xf32> { %0 = tensor.empty() : tensor<1x256xf32> loc(#loc8) - %1 = "ttir.matmul"(%arg0, %arg4, %0) <{operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x784xf32>, tensor<784x256xf32>, tensor<1x256xf32>) -> tensor<1x256xf32> loc(#loc8) + %1 = "ttir.matmul"(%arg0, %arg4, %0) : (tensor<1x784xf32>, tensor<784x256xf32>, tensor<1x256xf32>) -> tensor<1x256xf32> loc(#loc8) // CHECK: %{{.+}} = "ttnn.matmul"([[I1:%.+]], [[I2:%.+]], [[O1:%.+]]) {{.+}} -> tensor<1x256xf32, {{.+}}> // CHECK: "ttnn.deallocate"([[I2]]) {{.+}} : (tensor<784x256xf32, {{.+}}) -> () // CHECK: "ttnn.deallocate"([[I1]]) {{.+}} : (tensor<1x784xf32, {{.+}}>) -> () %2 = tensor.empty() : tensor<1x256xf32> loc(#loc9) - %3 = "ttir.add"(%1, %arg3, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x256xf32>, tensor<1x256xf32>, tensor<1x256xf32>) -> tensor<1x256xf32> loc(#loc9) + %3 = "ttir.add"(%1, %arg3, %2) <{operandSegmentSizes = array}> : (tensor<1x256xf32>, 
tensor<1x256xf32>, tensor<1x256xf32>) -> tensor<1x256xf32> loc(#loc9) // CHECK: %{{.+}} = "ttnn.add"([[I1:%.+]], [[I2:%.+]], [[O2:%.+]]) {{.+}} -> tensor<1x256xf32, {{.+}}> // CHECK: "ttnn.deallocate"([[I2]]) {{.+}} : (tensor<1x256xf32, {{.+}}>) -> () // CHECK: "ttnn.deallocate"([[O1]]) {{.+}} : (tensor<1x256xf32, {{.+}}>) -> () %4 = tensor.empty() : tensor<1x256xf32> loc(#loc10) - %5 = "ttir.relu"(%3, %4) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<1x256xf32>, tensor<1x256xf32>) -> tensor<1x256xf32> loc(#loc10) + %5 = "ttir.relu"(%3, %4) <{operandSegmentSizes = array}> : (tensor<1x256xf32>, tensor<1x256xf32>) -> tensor<1x256xf32> loc(#loc10) // CHECK: %{{.+}} = "ttnn.relu"([[I1:%.+]], [[O3:%.+]]) {{.+}} -> tensor<1x256xf32, {{.+}}> // CHECK: "ttnn.deallocate"([[O2]]) {{.+}} : (tensor<1x256xf32, {{.+}}>) -> () %6 = tensor.empty() : tensor<1x10xf32> loc(#loc11) - %7 = "ttir.matmul"(%5, %arg2, %6) <{operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x256xf32>, tensor<256x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> loc(#loc11) + %7 = "ttir.matmul"(%5, %arg2, %6) : (tensor<1x256xf32>, tensor<256x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> loc(#loc11) // CHECK: %{{.+}} = "ttnn.matmul"([[I1:%.+]], [[I2:%.+]], [[O4:%.+]]) {{.+}} -> tensor<1x10xf32, {{.+}}> // CHECK: "ttnn.deallocate"([[I2]]) {{.+}} : (tensor<256x10xf32, {{.+}}>) -> () // CHECK: "ttnn.deallocate"([[O3]]) {{.+}} : (tensor<1x256xf32,{{.+}}>) -> () %8 = tensor.empty() : tensor<1x10xf32> loc(#loc12) - %9 = "ttir.add"(%7, %arg1, %8) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x10xf32>, tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> loc(#loc12) + %9 = "ttir.add"(%7, %arg1, %8) <{operandSegmentSizes = array}> : (tensor<1x10xf32>, tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> loc(#loc12) // CHECK: %{{.+}} = "ttnn.add"([[I1:%.+]], [[I2:%.+]], [[O5:%.+]]) {{.+}} -> tensor<1x10xf32,{{.+}}> // CHECK: "ttnn.deallocate"([[I2]]) {{.+}} : (tensor<1x10xf32, {{.+}}>) -> () // CHECK: "ttnn.deallocate"([[O4]]) {{.+}} : (tensor<1x10xf32, {{.+}}>) -> () %10 = tensor.empty() : tensor<1x10xf32> loc(#loc13) - %11 = "ttir.softmax"(%9, %10) <{dimension = 1 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> loc(#loc13) + %11 = "ttir.softmax"(%9, %10) <{dimension = 1 : si32}> : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> loc(#loc13) return %11 : tensor<1x10xf32> loc(#loc7) } loc(#loc) } loc(#loc) diff --git a/test/ttmlir/Silicon/TTNN/embedding/embedding_1d_tensor.mlir b/test/ttmlir/Silicon/TTNN/embedding/embedding_1d_tensor.mlir index f4850e4f8..96c8609dc 100644 --- a/test/ttmlir/Silicon/TTNN/embedding/embedding_1d_tensor.mlir +++ b/test/ttmlir/Silicon/TTNN/embedding/embedding_1d_tensor.mlir @@ -1,13 +1,12 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<32xbf16>, %arg1: tensor<512x128xbf16>) -> tensor<32x128xbf16> { // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] %0 = tensor.empty() : tensor<32x128xbf16> // CHECK: %[[C:.*]] = "ttnn.embedding"[[C:.*]] - %1 = "ttir.embedding"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, 
#any_device]}> : (tensor<32xbf16>, tensor<512x128xbf16>, tensor<32x128xbf16>) -> tensor<32x128xbf16> + %1 = "ttir.embedding"(%arg0, %arg1, %0) : (tensor<32xbf16>, tensor<512x128xbf16>, tensor<32x128xbf16>) -> tensor<32x128xbf16> return %1 : tensor<32x128xbf16> } } diff --git a/test/ttmlir/Silicon/TTNN/embedding/embedding_backward.mlir b/test/ttmlir/Silicon/TTNN/embedding/embedding_backward.mlir index db0a47c21..7cec2cd5e 100644 --- a/test/ttmlir/Silicon/TTNN/embedding/embedding_backward.mlir +++ b/test/ttmlir/Silicon/TTNN/embedding/embedding_backward.mlir @@ -1,14 +1,12 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device = #tt.operand_constraint module attributes {} { func.func @backward(%arg0: tensor<1x32xbf16>, %arg1: tensor<512x128xbf16>, %arg2: tensor<1x32x128xbf16>) -> tensor<512x128xbf16> { // CHECK: %{{[0-9]+}} = "ttnn.empty" %0 = tensor.empty() : tensor<512x128xbf16> // CHECK: %{{[0-9]+}} = "ttnn.embedding_bw" - %1 = "ttir.embedding_backward"(%arg0, %arg1, %arg2, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device, #any_device]}> : - (tensor<1x32xbf16>, tensor<512x128xbf16>, tensor<1x32x128xbf16>, tensor<512x128xbf16>) -> tensor<512x128xbf16> + %1 = "ttir.embedding_backward"(%arg0, %arg1, %arg2, %0) : (tensor<1x32xbf16>, tensor<512x128xbf16>, tensor<1x32x128xbf16>, tensor<512x128xbf16>) -> tensor<512x128xbf16> return %1 : tensor<512x128xbf16> } } diff --git a/test/ttmlir/Silicon/TTNN/embedding/embedding_non_tile.mlir b/test/ttmlir/Silicon/TTNN/embedding/embedding_non_tile.mlir index c26634771..11397f27a 100644 --- a/test/ttmlir/Silicon/TTNN/embedding/embedding_non_tile.mlir +++ b/test/ttmlir/Silicon/TTNN/embedding/embedding_non_tile.mlir @@ -1,13 +1,12 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<1x32xbf16>, %arg1: tensor<512x128xbf16>) -> tensor<1x32x128xbf16> { // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] %0 = tensor.empty() : tensor<1x32x128xbf16> // CHECK: %[[C:.*]] = "ttnn.embedding"[[C:.*]] - %1 = "ttir.embedding"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32xbf16>, tensor<512x128xbf16>, tensor<1x32x128xbf16>) -> tensor<1x32x128xbf16> + %1 = "ttir.embedding"(%arg0, %arg1, %0) : (tensor<1x32xbf16>, tensor<512x128xbf16>, tensor<1x32x128xbf16>) -> tensor<1x32x128xbf16> return %1 : tensor<1x32x128xbf16> } } diff --git a/test/ttmlir/Silicon/TTNN/embedding/gather_to_embedding.mlir b/test/ttmlir/Silicon/TTNN/embedding/gather_to_embedding.mlir index 52f417e3c..bf654ccb7 100644 --- a/test/ttmlir/Silicon/TTNN/embedding/gather_to_embedding.mlir +++ b/test/ttmlir/Silicon/TTNN/embedding/gather_to_embedding.mlir @@ -2,12 +2,11 @@ // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn // XFAIL: * -#any_device = #tt.operand_constraint module attributes {} { func.func @forward(%operand: tensor<32000x1024xbf16>, %start_indices: tensor<1x32xbf16>) -> tensor<1x32x1024xbf16> { // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] %0 = tensor.empty() : tensor<1x32x1024xbf16> - // CHECK: %[[C:.*]] = 
"ttnn.embedding"(%start_indices, %operand, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32xbf16>, tensor<32000x1024xbf16>, tensor<1x32x1024xbf16>) -> tensor<1x32x1024xbf16> + // CHECK: %[[C:.*]] = "ttnn.embedding"(%start_indices, %operand, %0) <{operandSegmentSizes = array}> : (tensor<1x32xbf16>, tensor<32000x1024xbf16>, tensor<1x32x1024xbf16>) -> tensor<1x32x1024xbf16> %1 = "ttir.gather"(%operand, %start_indices, %0) { offset_dims = array, collapsed_slice_dims = array, @@ -16,8 +15,7 @@ module attributes {} { start_index_map = array, index_vector_dim = 1 : si64, slice_sizes = array, - indices_are_sorted = false, - operand_constraints = [#any_device, #any_device, #any_device] + indices_are_sorted = false } : (tensor<32000x1024xbf16>, tensor<1x32xbf16>, tensor<1x32x1024xbf16>) -> tensor<1x32x1024xbf16> return %1 : tensor<1x32x1024xbf16> } diff --git a/test/ttmlir/Silicon/TTNN/embedding/simple_embedding.mlir b/test/ttmlir/Silicon/TTNN/embedding/simple_embedding.mlir index 343bb5e76..583aa82e0 100644 --- a/test/ttmlir/Silicon/TTNN/embedding/simple_embedding.mlir +++ b/test/ttmlir/Silicon/TTNN/embedding/simple_embedding.mlir @@ -1,13 +1,12 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<32x32xbf16>, %arg1: tensor<512x128xbf16>) -> tensor<32x32x128xbf16> { // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] %0 = tensor.empty() : tensor<32x32x128xbf16> // CHECK: %[[C:.*]] = "ttnn.embedding"[[C:.*]] - %1 = "ttir.embedding"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x32xbf16>, tensor<512x128xbf16>, tensor<32x32x128xbf16>) -> tensor<32x32x128xbf16> + %1 = "ttir.embedding"(%arg0, %arg1, %0) : (tensor<32x32xbf16>, tensor<512x128xbf16>, tensor<32x32x128xbf16>) -> tensor<32x32x128xbf16> return %1 : tensor<32x32x128xbf16> } } diff --git a/test/ttmlir/Silicon/TTNN/emitc/simple_add.mlir b/test/ttmlir/Silicon/TTNN/emitc/simple_add.mlir index 33645730a..951b36061 100644 --- a/test/ttmlir/Silicon/TTNN/emitc/simple_add.mlir +++ b/test/ttmlir/Silicon/TTNN/emitc/simple_add.mlir @@ -1,10 +1,8 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device = #tt.operand_constraint - func.func @add(%arg0: tensor<32x32xbf16>, %arg1: tensor<32x32xbf16>) -> tensor<32x32xbf16> { %0 = tensor.empty() : tensor<32x32xbf16> - %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16> + %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16> return %1 : tensor<32x32xbf16> } diff --git a/test/ttmlir/Silicon/TTNN/emitc/two_fns.mlir b/test/ttmlir/Silicon/TTNN/emitc/two_fns.mlir index 3f304969c..8fc4d2e9b 100644 --- a/test/ttmlir/Silicon/TTNN/emitc/two_fns.mlir +++ b/test/ttmlir/Silicon/TTNN/emitc/two_fns.mlir @@ -1,16 +1,14 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer 
%t.mlir > %t.ttnn -#any_device = #tt.operand_constraint - func.func @add(%arg0: tensor<32x32xbf16>, %arg1: tensor<32x32xbf16>) -> tensor<32x32xbf16> { %0 = tensor.empty() : tensor<32x32xbf16> - %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16> + %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16> return %1 : tensor<32x32xbf16> } func.func @subtract(%arg0: tensor<32x32xbf16>, %arg1: tensor<32x32xbf16>) -> tensor<32x32xbf16> { %0 = tensor.empty() : tensor<32x32xbf16> - %1 = "ttir.subtract"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16> + %1 = "ttir.subtract"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<32x32xbf16>, tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16> return %1 : tensor<32x32xbf16> } diff --git a/test/ttmlir/Silicon/TTNN/kv_cache/fill_cache.mlir b/test/ttmlir/Silicon/TTNN/kv_cache/fill_cache.mlir index 67bf8387b..1a47ecf87 100644 --- a/test/ttmlir/Silicon/TTNN/kv_cache/fill_cache.mlir +++ b/test/ttmlir/Silicon/TTNN/kv_cache/fill_cache.mlir @@ -1,14 +1,13 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device = #tt.operand_constraint module { func.func @forward(%arg0: tensor<1x32x64x512xbf16>, %arg1: tensor<1x32x3x512xbf16>) -> tensor<1x32x64x512xbf16> { // CHECK: "ttnn.fill_cache"[[C:.*]] - %1 = "ttir.fill_cache"(%arg0, %arg1) <{batch_offset = 0: i32, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x64x512xbf16>, tensor<1x32x3x512xbf16>) -> tensor<1x32x64x512xbf16> + %1 = "ttir.fill_cache"(%arg0, %arg1) <{batch_offset = 0: i32}> : (tensor<1x32x64x512xbf16>, tensor<1x32x3x512xbf16>) -> tensor<1x32x64x512xbf16> %cst = "ttir.constant"() <{value = dense<1.000000e+00> : tensor<1x32x64x512xbf16>}> : () -> tensor<1x32x64x512xbf16> %addition_dps = tensor.empty() : tensor<1x32x64x512xbf16> - %2 = "ttir.add"(%1, %cst, %addition_dps) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x64x512xbf16>, tensor<1x32x64x512xbf16>, tensor<1x32x64x512xbf16>) -> tensor<1x32x64x512xbf16> + %2 = "ttir.add"(%1, %cst, %addition_dps) <{operandSegmentSizes = array}> : (tensor<1x32x64x512xbf16>, tensor<1x32x64x512xbf16>, tensor<1x32x64x512xbf16>) -> tensor<1x32x64x512xbf16> return %2 : tensor<1x32x64x512xbf16> } } diff --git a/test/ttmlir/Silicon/TTNN/kv_cache/update_cache.mlir b/test/ttmlir/Silicon/TTNN/kv_cache/update_cache.mlir index 63a08b302..564c030a3 100644 --- a/test/ttmlir/Silicon/TTNN/kv_cache/update_cache.mlir +++ b/test/ttmlir/Silicon/TTNN/kv_cache/update_cache.mlir @@ -1,15 +1,14 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device = #tt.operand_constraint module { func.func @forward(%arg0: tensor<1x32x64x512xbf16>, %arg1: tensor<1x32x1x512xbf16>) -> tensor<1x32x64x512xbf16> { // CHECK: "ttnn.update_cache"[[C:.*]] %update_index = 
"ttir.constant"() <{value = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32> - %1 = "ttir.update_cache"(%arg0, %arg1, %update_index) <{batch_offset = 0: i32, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x64x512xbf16>, tensor<1x32x1x512xbf16>, tensor<1xi32>) -> tensor<1x32x64x512xbf16> + %1 = "ttir.update_cache"(%arg0, %arg1, %update_index) <{batch_offset = 0: i32}> : (tensor<1x32x64x512xbf16>, tensor<1x32x1x512xbf16>, tensor<1xi32>) -> tensor<1x32x64x512xbf16> %cst = "ttir.constant"() <{value = dense<1.000000e+00> : tensor<1x32x64x512xbf16>}> : () -> tensor<1x32x64x512xbf16> %addition_dps = tensor.empty() : tensor<1x32x64x512xbf16> - %2 = "ttir.add"(%1, %cst, %addition_dps) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x64x512xbf16>, tensor<1x32x64x512xbf16>, tensor<1x32x64x512xbf16>) -> tensor<1x32x64x512xbf16> + %2 = "ttir.add"(%1, %cst, %addition_dps) <{operandSegmentSizes = array}> : (tensor<1x32x64x512xbf16>, tensor<1x32x64x512xbf16>, tensor<1x32x64x512xbf16>) -> tensor<1x32x64x512xbf16> return %2 : tensor<1x32x64x512xbf16> } } diff --git a/test/ttmlir/Silicon/TTNN/multi_device.mlir b/test/ttmlir/Silicon/TTNN/multi_device.mlir index 448f67169..c927c0d2b 100644 --- a/test/ttmlir/Silicon/TTNN/multi_device.mlir +++ b/test/ttmlir/Silicon/TTNN/multi_device.mlir @@ -3,13 +3,10 @@ // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn // UNSUPPORTED: true // REQUIRES: multi-chip -#any_device = #tt.operand_constraint -#any_device_tile = #tt.operand_constraint - func.func @multiply(%arg0: tensor<8x64x128xf32>, %arg1: tensor<8x64x128xf32>) -> tensor<8x64x128xf32> { // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] %0 = tensor.empty() : tensor<8x64x128xf32> // CHECK: %[[C:.*]] = "ttnn.multiply"[[C:.*]] - %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<8x64x128xf32>, tensor<8x64x128xf32>, tensor<8x64x128xf32>) -> tensor<8x64x128xf32> + %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<8x64x128xf32>, tensor<8x64x128xf32>, tensor<8x64x128xf32>) -> tensor<8x64x128xf32> return %1 : tensor<8x64x128xf32> } diff --git a/test/ttmlir/Silicon/TTNN/operand_broadcasts.mlir b/test/ttmlir/Silicon/TTNN/operand_broadcasts.mlir index 2a985d65f..1b919ec1d 100644 --- a/test/ttmlir/Silicon/TTNN/operand_broadcasts.mlir +++ b/test/ttmlir/Silicon/TTNN/operand_broadcasts.mlir @@ -1,13 +1,12 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device = #tt.operand_constraint module attributes {} { func.func @bcast_one_dim(%arg0: tensor<2x64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<2x64x128xf32> { // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] %0 = tensor.empty() : tensor<2x64x128xf32> // CHECK: %[[C:.*]] = "ttnn.multiply"[[C:.*]] - %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<2x64x128xf32>, tensor<64x128xf32>, tensor<2x64x128xf32>) -> tensor<2x64x128xf32> + %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<2x64x128xf32>, tensor<64x128xf32>, tensor<2x64x128xf32>) -> tensor<2x64x128xf32> return %1 : tensor<2x64x128xf32> } @@ -15,7 +14,7 @@ module attributes {} { // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] %0 = 
tensor.empty() : tensor<17x16x15x14xf32> // CHECK: %[[C:.*]] = "ttnn.multiply"[[C:.*]] - %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<17x16x15x14xf32>, tensor<15x1xf32>, tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> + %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<17x16x15x14xf32>, tensor<15x1xf32>, tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> return %1 : tensor<17x16x15x14xf32> } diff --git a/test/ttmlir/Silicon/TTNN/optimizer/mnist_sharding.mlir b/test/ttmlir/Silicon/TTNN/optimizer/mnist_sharding.mlir index 3cf9c4581..96798905c 100644 --- a/test/ttmlir/Silicon/TTNN/optimizer/mnist_sharding.mlir +++ b/test/ttmlir/Silicon/TTNN/optimizer/mnist_sharding.mlir @@ -1,7 +1,6 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path% enable-optimizer=true memory-layout-analysis-enabled=true" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device = #tt.operand_constraint #loc = loc("MNISTLinear":4294967295:0) module @"tt-forge-graph" attributes {} { func.func @main(%arg0: tensor<1x784xf32> loc("MNISTLinear":4294967295:0), %arg1: tensor<1x10xf32> loc("MNISTLinear":4294967295:0), %arg2: tensor<256x10xf32> loc("MNISTLinear":4294967295:0), %arg3: tensor<1x256xf32> loc("MNISTLinear":4294967295:0), %arg4: tensor<784x256xf32> loc("MNISTLinear":4294967295:0)) -> tensor<1x10xf32> { @@ -9,21 +8,21 @@ module @"tt-forge-graph" attributes {} { // CHECK-DAG: #[[LAYOUT_11:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<1x1x!tt.tile<32x32, f32>, #l1_>, > %0 = tensor.empty() : tensor<1x256xf32> loc(#loc8) // CHECK: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<1x256xf32, #[[LAYOUT_10]]> - %1 = "ttir.matmul"(%arg0, %arg4, %0) <{operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x784xf32>, tensor<784x256xf32>, tensor<1x256xf32>) -> tensor<1x256xf32> loc(#loc8) + %1 = "ttir.matmul"(%arg0, %arg4, %0) : (tensor<1x784xf32>, tensor<784x256xf32>, tensor<1x256xf32>) -> tensor<1x256xf32> loc(#loc8) %2 = tensor.empty() : tensor<1x256xf32> loc(#loc9) // CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<1x256xf32, #[[LAYOUT_10]]> - %3 = "ttir.add"(%1, %arg3, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x256xf32>, tensor<1x256xf32>, tensor<1x256xf32>) -> tensor<1x256xf32> loc(#loc9) + %3 = "ttir.add"(%1, %arg3, %2) <{operandSegmentSizes = array}> : (tensor<1x256xf32>, tensor<1x256xf32>, tensor<1x256xf32>) -> tensor<1x256xf32> loc(#loc9) %4 = tensor.empty() : tensor<1x256xf32> loc(#loc10) // CHECK: %{{.*}} = "ttnn.relu"{{.*}} -> tensor<1x256xf32, #[[LAYOUT_10]]> - %5 = "ttir.relu"(%3, %4) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<1x256xf32>, tensor<1x256xf32>) -> tensor<1x256xf32> loc(#loc10) + %5 = "ttir.relu"(%3, %4) <{operandSegmentSizes = array}> : (tensor<1x256xf32>, tensor<1x256xf32>) -> tensor<1x256xf32> loc(#loc10) %6 = tensor.empty() : tensor<1x10xf32> loc(#loc11) // CHECK: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<1x10xf32, #[[LAYOUT_11]]> - %7 = "ttir.matmul"(%5, %arg2, %6) <{operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x256xf32>, tensor<256x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> loc(#loc11) + %7 = "ttir.matmul"(%5, %arg2, %6) : (tensor<1x256xf32>, tensor<256x10xf32>, tensor<1x10xf32>) -> 
tensor<1x10xf32> loc(#loc11) %8 = tensor.empty() : tensor<1x10xf32> loc(#loc12) // CHECK: %{{.*}} = "ttnn.add"{{.*}} -> tensor<1x10xf32, #[[LAYOUT_11]]> - %9 = "ttir.add"(%7, %arg1, %8) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x10xf32>, tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> loc(#loc12) + %9 = "ttir.add"(%7, %arg1, %8) <{operandSegmentSizes = array}> : (tensor<1x10xf32>, tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> loc(#loc12) %10 = tensor.empty() : tensor<1x10xf32> loc(#loc13) - %11 = "ttir.softmax"(%9, %10) <{dimension = 1 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> loc(#loc13) + %11 = "ttir.softmax"(%9, %10) <{dimension = 1 : si32}> : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> loc(#loc13) return %11 : tensor<1x10xf32> loc(#loc7) } loc(#loc) } loc(#loc) diff --git a/test/ttmlir/Silicon/TTNN/optimizer/simple_fork_join.mlir b/test/ttmlir/Silicon/TTNN/optimizer/simple_fork_join.mlir index 981c26b49..e323e1024 100644 --- a/test/ttmlir/Silicon/TTNN/optimizer/simple_fork_join.mlir +++ b/test/ttmlir/Silicon/TTNN/optimizer/simple_fork_join.mlir @@ -2,17 +2,16 @@ // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn // UNSUPPORTED: true -#any_device = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<64x128xbf16>, %arg1: tensor<64x128xbf16>, %arg2: tensor<64x128xbf16>, %arg3: tensor<64x128xbf16>) -> tensor<64x128xbf16> { %0 = tensor.empty() : tensor<64x128xbf16> - %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16> + %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16> %2 = tensor.empty() : tensor<64x128xbf16> - %3 = "ttir.add"(%arg2, %arg3, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16> + %3 = "ttir.add"(%arg2, %arg3, %2) <{operandSegmentSizes = array}> : (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16> %4 = tensor.empty() : tensor<64x128xbf16> - %5 = "ttir.add"(%1, %3, %4) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16> + %5 = "ttir.add"(%1, %3, %4) <{operandSegmentSizes = array}> : (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16> %6 = tensor.empty() : tensor<64x128xbf16> - %7 = "ttir.relu"(%5, %6) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16> + %7 = "ttir.relu"(%5, %6) <{operandSegmentSizes = array}> : (tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16> return %7 : tensor<64x128xbf16> } } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/mnist.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/mnist.mlir index 0193ec36b..8d1b6393b 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/mnist.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/mnist.mlir @@ -1,6 +1,5 @@ // RUN: ttmlir-opt 
--ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device = #tt.operand_constraint #loc = loc("MNISTLinear":4294967295:0) module @"tt-forge-graph" attributes {} { func.func @main(%arg0: tensor<1x784xf32> loc("MNISTLinear":4294967295:0), %arg1: tensor<1x10xf32> loc("MNISTLinear":4294967295:0), %arg2: tensor<256x10xf32> loc("MNISTLinear":4294967295:0), %arg3: tensor<1x256xf32> loc("MNISTLinear":4294967295:0), %arg4: tensor<784x256xf32> loc("MNISTLinear":4294967295:0)) -> tensor<1x10xf32> { @@ -8,21 +7,21 @@ module @"tt-forge-graph" attributes {} { // CHECK: #[[LAYOUT_11:.*]] = #tt.metal_layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<1x10xf32, #l1_>, block_sharded> %0 = tensor.empty() : tensor<1x256xf32> loc(#loc8) // CHECK: %[[C:.*]] = "ttnn.matmul"[[C:.*]] -> tensor<1x256xf32, #[[LAYOUT_10]]> - %1 = "ttir.matmul"(%arg0, %arg4, %0) <{operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x784xf32>, tensor<784x256xf32>, tensor<1x256xf32>) -> tensor<1x256xf32> loc(#loc8) + %1 = "ttir.matmul"(%arg0, %arg4, %0) : (tensor<1x784xf32>, tensor<784x256xf32>, tensor<1x256xf32>) -> tensor<1x256xf32> loc(#loc8) %2 = tensor.empty() : tensor<1x256xf32> loc(#loc9) // CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x256xf32, #[[LAYOUT_10]]> - %3 = "ttir.add"(%1, %arg3, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x256xf32>, tensor<1x256xf32>, tensor<1x256xf32>) -> tensor<1x256xf32> loc(#loc9) + %3 = "ttir.add"(%1, %arg3, %2) <{operandSegmentSizes = array}> : (tensor<1x256xf32>, tensor<1x256xf32>, tensor<1x256xf32>) -> tensor<1x256xf32> loc(#loc9) %4 = tensor.empty() : tensor<1x256xf32> loc(#loc10) // CHECK: %[[C:.*]] = "ttnn.relu"[[C:.*]] -> tensor<1x256xf32, #[[LAYOUT_10]]> - %5 = "ttir.relu"(%3, %4) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<1x256xf32>, tensor<1x256xf32>) -> tensor<1x256xf32> loc(#loc10) + %5 = "ttir.relu"(%3, %4) <{operandSegmentSizes = array}> : (tensor<1x256xf32>, tensor<1x256xf32>) -> tensor<1x256xf32> loc(#loc10) %6 = tensor.empty() : tensor<1x10xf32> loc(#loc11) // CHECK: %[[C:.*]] = "ttnn.matmul"[[C:.*]] -> tensor<1x10xf32, #[[LAYOUT_11]]> - %7 = "ttir.matmul"(%5, %arg2, %6) <{operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x256xf32>, tensor<256x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> loc(#loc11) + %7 = "ttir.matmul"(%5, %arg2, %6) : (tensor<1x256xf32>, tensor<256x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> loc(#loc11) %8 = tensor.empty() : tensor<1x10xf32> loc(#loc12) // CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x10xf32, #[[LAYOUT_11]]> - %9 = "ttir.add"(%7, %arg1, %8) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x10xf32>, tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> loc(#loc12) + %9 = "ttir.add"(%7, %arg1, %8) <{operandSegmentSizes = array}> : (tensor<1x10xf32>, tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> loc(#loc12) %10 = tensor.empty() : tensor<1x10xf32> loc(#loc13) - %11 = "ttir.softmax"(%9, %10) <{dimension = 1 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> loc(#loc13) + %11 = "ttir.softmax"(%9, %10) <{dimension = 1 : si32}> : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> loc(#loc13) return %11 : tensor<1x10xf32> loc(#loc7) } 
loc(#loc) } loc(#loc) diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_and.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_and.mlir index b4569ef61..d27968520 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_and.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_and.mlir @@ -1,14 +1,10 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn - -#any_device = #tt.operand_constraint -#any_device_tile = #tt.operand_constraint - func.func @logical_and(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> { %0 = tensor.empty() : tensor<64x128xf32> // CHECK: {{.*}} = "ttnn.empty"{{.*}} - %1 = "ttir.logical_and"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.logical_and"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> // CHECK: %[[C:.*]] = "ttnn.logical_and" // CHECK-SAME: tensor<64x128xf32, // CHECK-SAME: tensor<64x128xf32, diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_ceil.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_ceil.mlir index 2e7f55428..d554baf2e 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_ceil.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_ceil.mlir @@ -1,13 +1,10 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device = #tt.operand_constraint -#any_device_tile = #tt.operand_constraint - func.func @ceil(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { %0 = tensor.empty() : tensor<64x128xf32> // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) // CHECK: %{{[0-9]+}} = "ttnn.ceil"(%{{[0-9]+}}, [[VAL0]]) - %1 = "ttir.ceil"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.ceil"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_clamp.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_clamp.mlir index de90c54b8..44806c22d 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_clamp.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_clamp.mlir @@ -2,9 +2,6 @@ // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device = #tt.operand_constraint -#any_device_tile = #tt.operand_constraint - func.func @clamp(%arg0: tensor<64x128xbf16>) -> tensor<64x128xbf16> { %0 = tensor.empty() : tensor<64x128xbf16> // CHECK: %[[DEVICE:.*]] = "ttnn.to_device"(%arg0, @@ -12,6 +9,6 @@ func.func @clamp(%arg0: tensor<64x128xbf16>) -> tensor<64x128xbf16> { // CHECK: = "ttnn.clamp"(%[[LAYOUT]]) // CHECK-SAME: {max = 3.000000e+00 : f32, min = 2.000000e+00 : f32} // CHECK-SAME: [[TENSOR:tensor<64x128xbf16]], #ttnn_layout{{[0-9]+}}>) -> [[TENSOR]] - %1 = "ttir.clamp"(%arg0, %0) <{max = 3.000000e+00 : f32, min = 2.000000e+00 : f32, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : 
(tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16> + %1 = "ttir.clamp"(%arg0, %0) <{max = 3.000000e+00 : f32, min = 2.000000e+00 : f32}> : (tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16> return %1 : tensor<64x128xbf16> } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_concat.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_concat.mlir index 122364cac..c889afec3 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_concat.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_concat.mlir @@ -1,13 +1,10 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device = #tt.operand_constraint -#any_device_tile = #tt.operand_constraint - func.func @concat(%arg0: tensor<32x32xf32>, %arg1: tensor<32x64xf32>) -> tensor<32x96xf32> { // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] %0 = tensor.empty() : tensor<32x96xf32> // CHECK: %[[C:.*]] = "ttnn.concat"[[C:.*]] - %1 = "ttir.concat"(%arg0, %arg1, %0) <{dim = 1 : si32, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x32xf32>, tensor<32x64xf32>, tensor<32x96xf32>) -> tensor<32x96xf32> + %1 = "ttir.concat"(%arg0, %arg1, %0) <{dim = 1 : si32}> : (tensor<32x32xf32>, tensor<32x64xf32>, tensor<32x96xf32>) -> tensor<32x96xf32> return %1 : tensor<32x96xf32> } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_conv.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_conv.mlir index 543f05763..13708ef16 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_conv.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_conv.mlir @@ -1,12 +1,11 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<1x32x32x64xbf16>, %arg1: tensor<64x64x3x3xbf16>, %arg2: tensor<1x1x1x64xbf16>) -> tensor<1x32x32x64xbf16> { %0 = tensor.empty() : tensor<1x32x32x64xbf16> // CHECK: %[[C:.*]] = "ttnn.conv2d"[[C:.*]] - %1 = "ttir.conv2d"(%arg0, %arg1, %arg2, %0) <{stride_height=1: si32, stride_width=1: si32, dilation_height=1: si32, dilation_width=1: si32, groups=1: si32, padding_left=1: si32, padding_right=1: si32, padding_top=1: si32, padding_bottom=1: si32, is_convtranspose2d=0: si32, output_height_transpose=0: si32, output_width_transpose=0: si32, stride_transpose=0: si32, operand_constraints = [#any_device, #any_device, #any_device, #any_device]}> : (tensor<1x32x32x64xbf16>, tensor<64x64x3x3xbf16>, tensor<1x1x1x64xbf16>, tensor<1x32x32x64xbf16>) -> tensor<1x32x32x64xbf16> + %1 = "ttir.conv2d"(%arg0, %arg1, %arg2, %0) <{stride_height=1: si32, stride_width=1: si32, dilation_height=1: si32, dilation_width=1: si32, groups=1: si32, padding_left=1: si32, padding_right=1: si32, padding_top=1: si32, padding_bottom=1: si32, is_convtranspose2d=0: si32, output_height_transpose=0: si32, output_width_transpose=0: si32, stride_transpose=0: si32}> : (tensor<1x32x32x64xbf16>, tensor<64x64x3x3xbf16>, tensor<1x1x1x64xbf16>, tensor<1x32x32x64xbf16>) -> tensor<1x32x32x64xbf16> return %1 : tensor<1x32x32x64xbf16> } } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_cosine.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_cosine.mlir index ede823439..2596e4a13 100644 --- 
a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_cosine.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_cosine.mlir @@ -1,13 +1,10 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device = #tt.operand_constraint -#any_device_tile = #tt.operand_constraint - func.func @cosine(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { %0 = tensor.empty() : tensor<64x128xf32> // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) // CHECK: %{{[0-9]+}} = "ttnn.cos"(%{{[0-9]+}}, [[VAL0]]) - %1 = "ttir.cos"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.cos"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_div.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_div.mlir index 249c8e314..a6b6a55a4 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_div.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_div.mlir @@ -1,13 +1,10 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device = #tt.operand_constraint -#any_device_tile = #tt.operand_constraint - func.func @div(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> { // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] %0 = tensor.empty() : tensor<64x128xf32> // CHECK: %[[C:.*]] = "ttnn.div"[[C:.*]] - %1 = "ttir.div"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.div"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_embedding.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_embedding.mlir index 343bb5e76..583aa82e0 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_embedding.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_embedding.mlir @@ -1,13 +1,12 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<32x32xbf16>, %arg1: tensor<512x128xbf16>) -> tensor<32x32x128xbf16> { // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] %0 = tensor.empty() : tensor<32x32x128xbf16> // CHECK: %[[C:.*]] = "ttnn.embedding"[[C:.*]] - %1 = "ttir.embedding"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x32xbf16>, tensor<512x128xbf16>, tensor<32x32x128xbf16>) -> tensor<32x32x128xbf16> + %1 = "ttir.embedding"(%arg0, %arg1, %0) : (tensor<32x32xbf16>, tensor<512x128xbf16>, tensor<32x32x128xbf16>) -> tensor<32x32x128xbf16> return %1 : tensor<32x32x128xbf16> } } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_eq.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_eq.mlir 
index 39fdcd6d1..44ff28faf 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_eq.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_eq.mlir @@ -1,10 +1,6 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn - -#any_device = #tt.operand_constraint -#any_device_tile = #tt.operand_constraint - module attributes {} { func.func @equal(%arg0: tensor<13x31xf32>, %arg1: tensor<13x31xf32>) -> tensor<13x31xf32> { // CHECK: %[[C:.*]] = "ttnn.empty @@ -15,7 +11,7 @@ module attributes {} { // CHECK-SAME: [[TENSOR]] // CHECK-SAME: [[TENSOR]] // CHECK-SAME: -> [[TENSOR]] - %1 = "ttir.eq"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<13x31xf32>, tensor<13x31xf32>, tensor<13x31xf32>) -> tensor<13x31xf32> + %1 = "ttir.eq"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<13x31xf32>, tensor<13x31xf32>, tensor<13x31xf32>) -> tensor<13x31xf32> return %1 : tensor<13x31xf32> } } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_expm1.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_expm1.mlir index 27cf6f80e..7d035174c 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_expm1.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_expm1.mlir @@ -1,13 +1,10 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device = #tt.operand_constraint -#any_device_tile = #tt.operand_constraint - func.func @expm1(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { %0 = tensor.empty() : tensor<64x128xf32> // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) <{dtype = {{.*}}, layout = {{.*}}, memory_config = {{.*}}, <{{.*}}>>, shape = #ttnn.shape<[[TENSOR_SHAPE:[0-9]+x[0-9]+]]>}> - %1 = "ttir.expm1"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.expm1"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> // CHECK: %{{[0-9]+}} = "ttnn.expm1"(%{{[0-9]+}}, [[VAL0]]) <{operandSegmentSizes = array}> : (tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}>, tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}) -> tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}> return %1 : tensor<64x128xf32> // CHECK: return %{{[0-9]+}} : tensor<[[TENSOR_SHAPE]]xf32, {{.*}}> diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_floor.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_floor.mlir index fa77817a8..d73927534 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_floor.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_floor.mlir @@ -1,9 +1,6 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device = #tt.operand_constraint -#any_device_tile = #tt.operand_constraint - func.func @floor(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { // CHECK: %{{[0-9]+}} = "ttnn.empty" // CHECK-SAME: [[TENSOR:tensor<64x128xf32,]] @@ -12,6 +9,6 @@ func.func @floor(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { // CHECK-SAME: [[TENSOR]] // CHECK-SAME: [[TENSOR]] // CHECK-SAME: -> [[TENSOR]] - %1 = 
"ttir.floor"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.floor"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_ge.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_ge.mlir index 64e3b16e3..07a6a56f1 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_ge.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_ge.mlir @@ -1,13 +1,10 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device = #tt.operand_constraint -#any_device_tile = #tt.operand_constraint - func.func @ge(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> { // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] %0 = tensor.empty() : tensor<64x128xf32> // CHECK: %[[C:.*]] = "ttnn.ge"[[C:.*]] - %1 = "ttir.ge"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.ge"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_gelu.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_gelu.mlir index 628bb5c37..7e9767e1f 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_gelu.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_gelu.mlir @@ -1,8 +1,6 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device = #tt.operand_constraint -#any_device_tile = #tt.operand_constraint func.func @gelu(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { // CHECK: "ttnn.empty" @@ -12,6 +10,6 @@ func.func @gelu(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { // CHECK-SAME: tensor<64x128xf32, // CHECK-SAME: tensor<64x128xf32, // CHECK-SAME: tensor<64x128xf32, - %1 = "ttir.gelu"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.gelu"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_gt.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_gt.mlir index 835714660..e02ed1e95 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_gt.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_gt.mlir @@ -1,10 +1,6 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn - -#any_device = #tt.operand_constraint -#any_device_tile = #tt.operand_constraint - module attributes {} { func.func @greater_than(%arg0: tensor<13x31xf32>, %arg1: tensor<13x31xf32>) -> tensor<13x31xf32> { // CHECK: %[[C:.*]] = "ttnn.empty @@ -15,7 +11,7 @@ module attributes {} { // CHECK-SAME: [[TENSOR]] // CHECK-SAME: 
[[TENSOR]] // CHECK-SAME: -> [[TENSOR]] - %1 = "ttir.gt"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<13x31xf32>, tensor<13x31xf32>, tensor<13x31xf32>) -> tensor<13x31xf32> + %1 = "ttir.gt"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<13x31xf32>, tensor<13x31xf32>, tensor<13x31xf32>) -> tensor<13x31xf32> return %1 : tensor<13x31xf32> } } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_isfinite.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_isfinite.mlir index f1489a5eb..b8dc64fb7 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_isfinite.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_isfinite.mlir @@ -1,8 +1,6 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device = #tt.operand_constraint -#any_device_tile = #tt.operand_constraint func.func @is_finite(%arg0: tensor<64x128xbf16>) -> tensor<64x128xbf16> { // CHECK: %[[C:.*]] = "ttnn.empty" @@ -12,6 +10,6 @@ func.func @is_finite(%arg0: tensor<64x128xbf16>) -> tensor<64x128xbf16> { // CHECK-SAME: tensor<64x128xbf16, // CHECK-SAME: [[TENSOR]] // CHECK-SAME: -> [[TENSOR]] - %1 = "ttir.isfinite"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16> + %1 = "ttir.isfinite"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16> return %1 : tensor<64x128xbf16> } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_linear.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_linear.mlir index 6da5d3910..ab073ef75 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_linear.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_linear.mlir @@ -2,7 +2,6 @@ // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device_tile = #tt.operand_constraint module { func.func @linear(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>, %bias: tensor<64x64xbf16>) -> tensor<64x64xbf16> { // CHECK: "ttnn.empty" @@ -14,7 +13,7 @@ module { // CHECK-SAME: tensor<64x64xbf16 // CHECK-SAME: tensor<64x64xbf16 // CHECK-SAME: tensor<64x64xbf16 - %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> + %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16> return %1 : tensor<64x64xbf16> } } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_log.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_log.mlir index b3de1bba4..d4a7ed331 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_log.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_log.mlir @@ -1,13 +1,10 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device = #tt.operand_constraint -#any_device_tile = #tt.operand_constraint - func.func @log(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { %0 = tensor.empty() : tensor<64x128xf32> // 
CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) // CHECK: %{{[0-9]+}} = "ttnn.log"(%{{[0-9]+}}, [[VAL0]]) - %1 = "ttir.log"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.log"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_log1p.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_log1p.mlir index 2c32cc817..3d50d3e88 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_log1p.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_log1p.mlir @@ -2,13 +2,10 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device = #tt.operand_constraint -#any_device_tile = #tt.operand_constraint - func.func @log1p(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { %0 = tensor.empty() : tensor<64x128xf32> // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) <{dtype = {{.*}}, layout = {{.*}}, memory_config = {{.*}}, <{{.*}}>>, shape = #ttnn.shape<[[TENSOR_SHAPE:[0-9]+x[0-9]+]]>}> - %1 = "ttir.log1p"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.log1p"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> // CHECK: %{{[0-9]+}} = "ttnn.log1p"(%{{[0-9]+}}, [[VAL0]]) <{operandSegmentSizes = array}> : (tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}>, tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}) -> tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}> return %1 : tensor<64x128xf32> // CHECK: return %{{[0-9]+}} : tensor<[[TENSOR_SHAPE]]xf32, {{.*}}> diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_lt.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_lt.mlir index 1b3bca82c..1f95207ba 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_lt.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_lt.mlir @@ -1,10 +1,6 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn - -#any_device = #tt.operand_constraint -#any_device_tile = #tt.operand_constraint - module attributes {} { func.func @less_than(%arg0: tensor<13x31xf32>, %arg1: tensor<13x31xf32>) -> tensor<13x31xf32> { // CHECK: %[[C:.*]] = "ttnn.empty @@ -15,7 +11,7 @@ module attributes {} { // CHECK-SAME: [[TENSOR]] // CHECK-SAME: [[TENSOR]] // CHECK-SAME: -> [[TENSOR]] - %1 = "ttir.lt"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<13x31xf32>, tensor<13x31xf32>, tensor<13x31xf32>) -> tensor<13x31xf32> + %1 = "ttir.lt"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<13x31xf32>, tensor<13x31xf32>, tensor<13x31xf32>) -> tensor<13x31xf32> return %1 : tensor<13x31xf32> } } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_matmul.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_matmul.mlir index 9c240b0ab..f221001bb 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_matmul.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_matmul.mlir @@ -1,13 +1,12 @@ // RUN: ttmlir-opt 
--ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device_tile = #tt.operand_constraint // CHECK: #[[TILED_LAYOUT:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<2x4x!tt.tile<32x32, bf16>, #dram>, > module attributes {} { func.func @forward(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x96xbf16>) -> tensor<64x96xbf16> { %0 = tensor.empty() : tensor<64x96xbf16> // CHECK: %[[C:.*]] = "ttnn.matmul"[[C:.*]] - %1 = "ttir.matmul"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128x96xbf16>, tensor<64x96xbf16>) -> tensor<64x96xbf16> + %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<128x96xbf16>, tensor<64x96xbf16>) -> tensor<64x96xbf16> return %1 : tensor<64x96xbf16> } } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_max.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_max.mlir index fda141005..1011fad89 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_max.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_max.mlir @@ -1,11 +1,9 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device = #tt.operand_constraint - func.func @max(%arg0: tensor<1x1x512x64xbf16>) -> tensor<1x1x512xbf16> { %0 = tensor.empty() : tensor<1x1x512xbf16> // CHECK: %[[C:.*]] = "ttnn.max"[[C:.*]] - %1 = "ttir.max"(%arg0, %0) <{dim_arg = [-1: i32], keep_dim = true, operand_constraints = [#any_device, #any_device]}> : (tensor<1x1x512x64xbf16>, tensor<1x1x512xbf16>) -> tensor<1x1x512xbf16> + %1 = "ttir.max"(%arg0, %0) <{dim_arg = [-1: i32], keep_dim = true}> : (tensor<1x1x512x64xbf16>, tensor<1x1x512xbf16>) -> tensor<1x1x512xbf16> return %1 : tensor<1x1x512xbf16> } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_maximum.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_maximum.mlir index 3642a5511..3893bc9f0 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_maximum.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_maximum.mlir @@ -1,13 +1,10 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device = #tt.operand_constraint -#any_device_tile = #tt.operand_constraint - func.func @maximum(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> { // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] %0 = tensor.empty() : tensor<64x128xf32> // CHECK: %[[C:.*]] = "ttnn.maximum"[[C:.*]] - %1 = "ttir.maximum"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.maximum"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_maxpool2d.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_maxpool2d.mlir index 4722e9c52..4fdd836dd 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_maxpool2d.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_maxpool2d.mlir @@ -1,12 
+1,11 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<1x128x128x32xbf16>) -> tensor<1x64x64x32xbf16> { %0 = tensor.empty() : tensor<1x64x64x32xbf16> // CHECK: %[[C:.*]] = "ttnn.max_pool2d"[[C:.*]] - %1 = "ttir.max_pool2d"(%arg0, %0) <{kernel_height=2: si32, kernel_width=2: si32, stride_height=2: si32, stride_width=2: si32, dilation_height=1: si32, dilation_width=1: si32, ceil_mode=false, padding_left=0: si32, padding_right=0: si32, padding_top=0: si32, padding_bottom=0: si32, operand_constraints = [#any_device, #any_device]}> : (tensor<1x128x128x32xbf16>, tensor<1x64x64x32xbf16>) -> tensor<1x64x64x32xbf16> + %1 = "ttir.max_pool2d"(%arg0, %0) <{kernel_height=2: si32, kernel_width=2: si32, stride_height=2: si32, stride_width=2: si32, dilation_height=1: si32, dilation_width=1: si32, ceil_mode=false, padding_left=0: si32, padding_right=0: si32, padding_top=0: si32, padding_bottom=0: si32}> : (tensor<1x128x128x32xbf16>, tensor<1x64x64x32xbf16>) -> tensor<1x64x64x32xbf16> return %1 : tensor<1x64x64x32xbf16> } } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_multiply.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_multiply.mlir index 8b53113c2..7991cbc78 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_multiply.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_multiply.mlir @@ -1,13 +1,10 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device = #tt.operand_constraint -#any_device_tile = #tt.operand_constraint - func.func @multiply(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> { // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] %0 = tensor.empty() : tensor<64x128xf32> // CHECK: %[[C:.*]] = "ttnn.multiply"[[C:.*]] - %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_ne.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_ne.mlir index 78e5b1245..300e66226 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_ne.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_ne.mlir @@ -1,10 +1,6 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn - -#any_device = #tt.operand_constraint -#any_device_tile = #tt.operand_constraint - module attributes {} { func.func @not_equal(%arg0: tensor<13x31xf32>, %arg1: tensor<13x31xf32>) -> tensor<13x31xf32> { // CHECK: %[[C:.*]] = "ttnn.empty @@ -15,7 +11,7 @@ module attributes {} { // CHECK-SAME: [[TENSOR]] // CHECK-SAME: [[TENSOR]] // CHECK-SAME: -> [[TENSOR]] - %1 = "ttir.ne"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<13x31xf32>, tensor<13x31xf32>, 
tensor<13x31xf32>) -> tensor<13x31xf32> + %1 = "ttir.ne"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<13x31xf32>, tensor<13x31xf32>, tensor<13x31xf32>) -> tensor<13x31xf32> return %1 : tensor<13x31xf32> } } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_neg.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_neg.mlir index b0aaaa8fd..907541764 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_neg.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_neg.mlir @@ -1,12 +1,9 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device = #tt.operand_constraint -#any_device_tile = #tt.operand_constraint - func.func @negate(%arg0: tensor<32x32xf32>) -> tensor<32x32xf32> { %0 = tensor.empty() : tensor<32x32xf32> // CHECK: %[[C:.*]] = "ttnn.neg"[[C:.*]] - %1 = "ttir.neg"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32> + %1 = "ttir.neg"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32> return %1 : tensor<32x32xf32> } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_not.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_not.mlir index c3429abd7..b9d07674e 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_not.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_not.mlir @@ -1,14 +1,10 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn - -#any_device = #tt.operand_constraint -#any_device_tile = #tt.operand_constraint - func.func @logical_not(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { %0 = tensor.empty() : tensor<64x128xf32> // CHECK: {{.*}} = "ttnn.empty"{{.*}} - %1 = "ttir.logical_not"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.logical_not"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> // CHECK: %[[C:.*]] = "ttnn.logical_not" // CHECK-SAME: tensor<64x128xf32, // CHECK-SAME: tensor<64x128xf32, diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_or.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_or.mlir index 21287a739..e6c7ec555 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_or.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_or.mlir @@ -1,14 +1,10 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn - -#any_device = #tt.operand_constraint -#any_device_tile = #tt.operand_constraint - func.func @logical_or(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> { %0 = tensor.empty() : tensor<64x128xf32> // CHECK: {{.*}} = "ttnn.empty"{{.*}} - %1 = "ttir.logical_or"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.logical_or"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>, 
tensor<64x128xf32>) -> tensor<64x128xf32> // CHECK: %[[C:.*]] = "ttnn.logical_or" // CHECK-SAME: tensor<64x128xf32, // CHECK-SAME: tensor<64x128xf32, diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_reciprocal.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_reciprocal.mlir index 8a5bf39f3..d17444e4b 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_reciprocal.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_reciprocal.mlir @@ -1,13 +1,10 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device = #tt.operand_constraint -#any_device_tile = #tt.operand_constraint - func.func @reciprocal(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] %0 = tensor.empty() : tensor<64x128xf32> // CHECK: %[[C:.*]] = "ttnn.reciprocal"[[C:.*]] - %1 = "ttir.reciprocal"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.reciprocal"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_relu.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_relu.mlir index cec787a3b..0ae23ec15 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_relu.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_relu.mlir @@ -1,13 +1,10 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device = #tt.operand_constraint -#any_device_tile = #tt.operand_constraint - func.func @relu(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] %0 = tensor.empty() : tensor<64x128xf32> // CHECK: %[[C:.*]] = "ttnn.relu"[[C:.*]] - %1 = "ttir.relu"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.relu"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_remainder.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_remainder.mlir index 68375a9e0..e358d663e 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_remainder.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_remainder.mlir @@ -1,13 +1,10 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device = #tt.operand_constraint -#any_device_tile = #tt.operand_constraint - func.func @remainder(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>) -> tensor<32x32xf32> { %0 = tensor.empty() : tensor<32x32xf32> // CHECK: %[[EMPTY:.*]] = "ttnn.empty"{{.*}} -> tensor<32x32xf32, {{.*}} - %1 = "ttir.remainder"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32> + %1 = "ttir.remainder"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : 
(tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32> // CHECK: %[[REM:[0-9]+]] = "ttnn.remainder"({{.*}}, {{.*}}, %[[EMPTY]]){{.*}} -> tensor<32x32xf32, {{.*}} return %1 : tensor<32x32xf32> // CHECK: return {{.*}} : tensor<32x32xf32, {{.*}} diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_rsqrt.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_rsqrt.mlir index 61a5f4055..4c85d11ca 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_rsqrt.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_rsqrt.mlir @@ -1,13 +1,10 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device = #tt.operand_constraint -#any_device_tile = #tt.operand_constraint - func.func @rsqrt(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] %0 = tensor.empty() : tensor<64x128xf32> // CHECK: %[[C:.*]] = "ttnn.rsqrt"[[C:.*]] - %1 = "ttir.rsqrt"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.rsqrt"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_sigmoid.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_sigmoid.mlir index 1084d5321..9583be957 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_sigmoid.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_sigmoid.mlir @@ -1,13 +1,10 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device = #tt.operand_constraint -#any_device_tile = #tt.operand_constraint - func.func @sigmoid(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] %0 = tensor.empty() : tensor<64x128xf32> // CHECK: %[[C:.*]] = "ttnn.sigmoid"[[C:.*]] - %1 = "ttir.sigmoid"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.sigmoid"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_sign.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_sign.mlir index 543a54d3e..26fe2b2d0 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_sign.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_sign.mlir @@ -1,13 +1,10 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device = #tt.operand_constraint -#any_device_tile = #tt.operand_constraint - func.func @sign(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { %0 = tensor.empty() : tensor<64x128xf32> // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) <{dtype = {{.*}}, layout = {{.*}}, memory_config = {{.*}}, <{{.*}}>>, shape = #ttnn.shape<[[TENSOR_SHAPE:[0-9]+x[0-9]+]]>}> - %1 = "ttir.sign"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, 
tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.sign"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> // CHECK: %{{[0-9]+}} = "ttnn.sign"(%{{[0-9]+}}, [[VAL0]]) <{operandSegmentSizes = array}> : (tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}>, tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}) -> tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}> return %1 : tensor<64x128xf32> // CHECK: return %{{[0-9]+}} : tensor<[[TENSOR_SHAPE]]xf32, {{.*}}> diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_sine.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_sine.mlir index 36f71d8e6..61fe517ea 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_sine.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_sine.mlir @@ -1,13 +1,10 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device = #tt.operand_constraint -#any_device_tile = #tt.operand_constraint - func.func @sine(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { %0 = tensor.empty() : tensor<64x128xf32> // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) // CHECK: %{{[0-9]+}} = "ttnn.sin"(%{{[0-9]+}}, [[VAL0]]) - %1 = "ttir.sin"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.sin"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_slice.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_slice.mlir index 8d43cb7dc..e32101781 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_slice.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_slice.mlir @@ -1,12 +1,11 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device_tile = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<4x32x32xbf16>) -> tensor<2x16x16xbf16> { %0 = tensor.empty() : tensor<2x16x16xbf16> // CHECK: %[[C:.*]] = "ttnn.slice"[[C:.*]] - %1 = "ttir.slice"(%arg0, %0) <{begins = [0: i32, 0: i32, 0: i32], ends = [2: i32, 16: i32, 16: i32], step = [1: i32, 1: i32, 1: i32], operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<4x32x32xbf16>, tensor<2x16x16xbf16>) -> tensor<2x16x16xbf16> + %1 = "ttir.slice"(%arg0, %0) <{begins = [0: i32, 0: i32, 0: i32], ends = [2: i32, 16: i32, 16: i32], step = [1: i32, 1: i32, 1: i32]}> : (tensor<4x32x32xbf16>, tensor<2x16x16xbf16>) -> tensor<2x16x16xbf16> return %1 : tensor<2x16x16xbf16> } } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_softmax.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_softmax.mlir index cdf8fae8d..34d430019 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_softmax.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_softmax.mlir @@ -1,17 +1,14 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device = #tt.operand_constraint -#any_device_tile = #tt.operand_constraint - func.func @softmax(%arg0: tensor<512x1024xbf16>) -> 
tensor<512x1024xbf16> { %0 = tensor.empty() : tensor<512x1024xbf16> // CHECK: %[[C:.*]] = "ttnn.softmax"[[C:.*]] // Check for positive dimension attribute - %1 = "ttir.softmax"(%arg0, %0) <{dimension = 1 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<512x1024xbf16>, tensor<512x1024xbf16>) -> tensor<512x1024xbf16> + %1 = "ttir.softmax"(%arg0, %0) <{dimension = 1 : si32}> : (tensor<512x1024xbf16>, tensor<512x1024xbf16>) -> tensor<512x1024xbf16> %2 = tensor.empty() : tensor<512x1024xbf16> // CHECK: %[[C:.*]] = "ttnn.softmax"[[C:.*]] // Check for negative dimension attribute - %3 = "ttir.softmax"(%1, %2) <{dimension = -1 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<512x1024xbf16>, tensor<512x1024xbf16>) -> tensor<512x1024xbf16> + %3 = "ttir.softmax"(%1, %2) <{dimension = -1 : si32}> : (tensor<512x1024xbf16>, tensor<512x1024xbf16>) -> tensor<512x1024xbf16> return %3 : tensor<512x1024xbf16> } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_sqrt.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_sqrt.mlir index eeba82ec7..72e7bb579 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_sqrt.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_sqrt.mlir @@ -1,13 +1,10 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device = #tt.operand_constraint -#any_device_tile = #tt.operand_constraint - func.func @sqrt(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] %0 = tensor.empty() : tensor<64x128xf32> // CHECK: %[[C:.*]] = "ttnn.sqrt"[[C:.*]] - %1 = "ttir.sqrt"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.sqrt"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_subtract.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_subtract.mlir index 75f9b0b7d..679994dc5 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_subtract.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_subtract.mlir @@ -1,13 +1,10 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device = #tt.operand_constraint -#any_device_tile = #tt.operand_constraint - func.func @subtract(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> { // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] %0 = tensor.empty() : tensor<64x128xf32> // CHECK: %[[C:.*]] = "ttnn.subtract"[[C:.*]] - %1 = "ttir.subtract"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.subtract"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_sum.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_sum.mlir index 432264760..f0beb34b2 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_sum.mlir +++ 
b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_sum.mlir @@ -1,11 +1,9 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device = #tt.operand_constraint - func.func @sum(%arg0: tensor<1x1x512x64xbf16>) -> tensor<1x1x512xbf16> { %0 = tensor.empty() : tensor<1x1x512xbf16> // CHECK: %[[C:.*]] = "ttnn.sum"[[C:.*]] - %1 = "ttir.sum"(%arg0, %0) <{dim_arg = [-1: i32], keep_dim = true, operand_constraints = [#any_device, #any_device]}> : (tensor<1x1x512x64xbf16>, tensor<1x1x512xbf16>) -> tensor<1x1x512xbf16> + %1 = "ttir.sum"(%arg0, %0) <{dim_arg = [-1: i32], keep_dim = true}> : (tensor<1x1x512x64xbf16>, tensor<1x1x512xbf16>) -> tensor<1x1x512xbf16> return %1 : tensor<1x1x512xbf16> } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_tan.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_tan.mlir index aa7b97298..47957677b 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_tan.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_tan.mlir @@ -1,13 +1,11 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device = #tt.operand_constraint -#any_device_tile = #tt.operand_constraint func.func @tan(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { %0 = tensor.empty() : tensor<64x128xf32> // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) // CHECK: %{{[0-9]+}} = "ttnn.tan"(%{{[0-9]+}}, [[VAL0]]) - %1 = "ttir.tan"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.tan"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_tanh.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_tanh.mlir index ecb7266c9..4844bd308 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_tanh.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_tanh.mlir @@ -1,13 +1,11 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device = #tt.operand_constraint -#any_device_tile = #tt.operand_constraint func.func @tanh(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { %0 = tensor.empty() : tensor<64x128xf32> // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) // CHECK: %{{[0-9]+}} = "ttnn.tanh"(%{{[0-9]+}}, [[VAL0]]) - %1 = "ttir.tanh"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.tanh"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_transpose.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_transpose.mlir index d9587863c..0f2fc1b98 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_transpose.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_transpose.mlir @@ -1,12 +1,9 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > 
%t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device = #tt.operand_constraint -#any_device_tile = #tt.operand_constraint - func.func @transpose(%arg0: tensor<64x128xbf16>) -> tensor<128x64xbf16> { %0 = tensor.empty() : tensor<128x64xbf16> // CHECK: %[[C:.*]] = "ttnn.transpose"[[C:.*]] - %1 = "ttir.transpose"(%arg0, %0) <{dim0 = 0 : si32, dim1 = 1 : si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128x64xbf16>) -> tensor<128x64xbf16> + %1 = "ttir.transpose"(%arg0, %0) <{dim0 = 0 : si32, dim1 = 1 : si32}> : (tensor<64x128xbf16>, tensor<128x64xbf16>) -> tensor<128x64xbf16> return %1 : tensor<128x64xbf16> } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_typecast.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_typecast.mlir index cb4f2d64f..82666195d 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_typecast.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_typecast.mlir @@ -2,12 +2,11 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device = #tt.operand_constraint func.func @typecast(%arg0: tensor<64x128xf32>) -> tensor<64x128xbf16> { %0 = tensor.empty() : tensor<64x128xbf16> // CHECK: %[[C:.*]] = "ttnn.typecast" // CHECK-SAME: tensor<64x128xf32, // CHECK-SAME: tensor<64x128xbf16, - %1 = "ttir.typecast"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xbf16>) -> tensor<64x128xbf16> + %1 = "ttir.typecast"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xbf16>) -> tensor<64x128xbf16> return %1 : tensor<64x128xbf16> } diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_where.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_where.mlir index 647f94e61..9076f24f4 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_where.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_where.mlir @@ -1,14 +1,12 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device = #tt.operand_constraint -#any_device_tile = #tt.operand_constraint func.func @test_where(%arg0: tensor<13x37xbf16>, %arg1: tensor<13x37xbf16>) -> tensor<13x37xbf16> { %0 = tensor.empty() : tensor<13x37xbf16> - %1 = "ttir.eq"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<13x37xbf16>, tensor<13x37xbf16>, tensor<13x37xbf16>) -> tensor<13x37xbf16> + %1 = "ttir.eq"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<13x37xbf16>, tensor<13x37xbf16>, tensor<13x37xbf16>) -> tensor<13x37xbf16> %2 = tensor.empty() : tensor<13x37xbf16> - %3 = "ttir.where"(%1, %arg0, %arg1, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<13x37xbf16>, tensor<13x37xbf16>, tensor<13x37xbf16>, tensor<13x37xbf16>) -> tensor<13x37xbf16> + %3 = "ttir.where"(%1, %arg0, %arg1, %2) <{operandSegmentSizes = array}> : (tensor<13x37xbf16>, tensor<13x37xbf16>, tensor<13x37xbf16>, tensor<13x37xbf16>) -> tensor<13x37xbf16> // CHECK: %[[EMPTY:.*]] = "ttnn.empty"{{.*}} // CHECK: %[[VAL1:[0-9]+]] = 
"ttnn.eq"(%{{[0-9]+}}, %{{[0-9]+}}, %[[EMPTY]]) // CHECK: %{{[0-9]+}} = "ttnn.where"(%[[VAL1]], %{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}}) diff --git a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_xor.mlir b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_xor.mlir index c47a34cee..d68b72608 100644 --- a/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_xor.mlir +++ b/test/ttmlir/Silicon/TTNN/perf_unit/test_perf_xor.mlir @@ -2,9 +2,6 @@ // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device = #tt.operand_constraint -#any_device_tile = #tt.operand_constraint - func.func @logical_xor(%arg0: tensor<64x128xbf16>, %arg1: tensor<64x128xbf16>) -> tensor<64x128xbf16> { // CHECK: %{{[0-9]+}} = "ttnn.empty"{{.*}} [[TENSOR:tensor<64x128xbf16]] %0 = tensor.empty() : tensor<64x128xbf16> @@ -13,6 +10,6 @@ func.func @logical_xor(%arg0: tensor<64x128xbf16>, %arg1: tensor<64x128xbf16>) - // CHECK-SAME: [[TENSOR]] // CHECK-SAME: [[TENSOR]] // CHECK-SAME: -> [[TENSOR]] - %1 = "ttir.logical_xor"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16> + %1 = "ttir.logical_xor"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16> return %1 : tensor<64x128xbf16> } diff --git a/test/ttmlir/Silicon/TTNN/pooling/complex_pooling.mlir b/test/ttmlir/Silicon/TTNN/pooling/complex_pooling.mlir index 8e44f2a74..7a0be59ac 100644 --- a/test/ttmlir/Silicon/TTNN/pooling/complex_pooling.mlir +++ b/test/ttmlir/Silicon/TTNN/pooling/complex_pooling.mlir @@ -1,7 +1,6 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<1x32x128x128xbf16>, %arg1: tensor<1x32x128x128xbf16>) -> (tensor<1x32x64x64xbf16>, tensor<1x32x64x64xbf16>) { %0 = tensor.empty() : tensor<1x32x64x64xbf16> @@ -14,8 +13,7 @@ module attributes {} { window_strides = array, base_dilations = array, window_dilations = array, - padding = array, - operand_constraints = [#any_device, #any_device]}> : (tensor<1x32x128x128xbf16>, tensor<1x32x128x128xbf16>, tensor<1x32x64x64xbf16>, tensor<1x32x64x64xbf16>) -> (tensor<1x32x64x64xbf16>, tensor<1x32x64x64xbf16>) + padding = array}> : (tensor<1x32x128x128xbf16>, tensor<1x32x128x128xbf16>, tensor<1x32x64x64xbf16>, tensor<1x32x64x64xbf16>) -> (tensor<1x32x64x64xbf16>, tensor<1x32x64x64xbf16>) return %2, %3 : tensor<1x32x64x64xbf16>,tensor<1x32x64x64xbf16> } } diff --git a/test/ttmlir/Silicon/TTNN/pooling/simple_pooling.mlir b/test/ttmlir/Silicon/TTNN/pooling/simple_pooling.mlir index 710daea10..7c4d62660 100644 --- a/test/ttmlir/Silicon/TTNN/pooling/simple_pooling.mlir +++ b/test/ttmlir/Silicon/TTNN/pooling/simple_pooling.mlir @@ -1,7 +1,6 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<1x32x128x128xbf16>) -> tensor<1x32x64x64xbf16> { %0 = tensor.empty() : tensor<1x32x64x64xbf16> @@ -13,8 +12,7 @@ module attributes {} { window_strides = array, 
base_dilations = array, window_dilations = array, - padding = array, - operand_constraints = [#any_device, #any_device]}> : (tensor<1x32x128x128xbf16>, tensor<1x32x64x64xbf16>) -> tensor<1x32x64x64xbf16> + padding = array}> : (tensor<1x32x128x128xbf16>, tensor<1x32x64x64xbf16>) -> tensor<1x32x64x64xbf16> return %1 : tensor<1x32x64x64xbf16> } } diff --git a/test/ttmlir/Silicon/TTNN/sharded/simple_eltwise_sharded.mlir b/test/ttmlir/Silicon/TTNN/sharded/simple_eltwise_sharded.mlir index ff8caa4f7..d74b582ed 100644 --- a/test/ttmlir/Silicon/TTNN/sharded/simple_eltwise_sharded.mlir +++ b/test/ttmlir/Silicon/TTNN/sharded/simple_eltwise_sharded.mlir @@ -1,14 +1,11 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path% enable-optimizer=false" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#l1_block_sharded = #tt.operand_constraint -#l1_height_sharded = #tt.operand_constraint - func.func @subtract(%arg0: tensor<224x64xf32>, %arg1: tensor<224x64xf32>) -> tensor<224x64xf32> { // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] %0 = tensor.empty() : tensor<224x64xf32> // CHECK: %[[C:.*]] = "ttnn.subtract"[[C:.*]] - %1 = "ttir.subtract"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#l1_block_sharded, #l1_block_sharded, #l1_block_sharded]}> : (tensor<224x64xf32>, tensor<224x64xf32>, tensor<224x64xf32>) -> tensor<224x64xf32> + %1 = "ttir.subtract"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<224x64xf32>, tensor<224x64xf32>, tensor<224x64xf32>) -> tensor<224x64xf32> return %1 : tensor<224x64xf32> } @@ -16,7 +13,7 @@ func.func @div(%arg0: tensor<224x64xf32>, %arg1: tensor<224x64xf32>) -> tensor<2 // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] %0 = tensor.empty() : tensor<224x64xf32> // CHECK: %[[C:.*]] = "ttnn.div"[[C:.*]] - %1 = "ttir.div"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#l1_block_sharded, #l1_block_sharded, #l1_block_sharded]}> : (tensor<224x64xf32>, tensor<224x64xf32>, tensor<224x64xf32>) -> tensor<224x64xf32> + %1 = "ttir.div"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<224x64xf32>, tensor<224x64xf32>, tensor<224x64xf32>) -> tensor<224x64xf32> return %1 : tensor<224x64xf32> } @@ -24,7 +21,7 @@ func.func @multiply(%arg0: tensor<224x64xf32>, %arg1: tensor<224x64xf32>) -> ten // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] %0 = tensor.empty() : tensor<224x64xf32> // CHECK: %[[C:.*]] = "ttnn.multiply"[[C:.*]] - %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#l1_block_sharded, #l1_block_sharded, #l1_block_sharded]}> : (tensor<224x64xf32>, tensor<224x64xf32>, tensor<224x64xf32>) -> tensor<224x64xf32> + %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<224x64xf32>, tensor<224x64xf32>, tensor<224x64xf32>) -> tensor<224x64xf32> return %1 : tensor<224x64xf32> } @@ -32,7 +29,7 @@ func.func @relu(%arg0: tensor<224x64xf32>) -> tensor<224x64xf32> { // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] %0 = tensor.empty() : tensor<224x64xf32> // CHECK: %[[C:.*]] = "ttnn.relu"[[C:.*]] - %1 = "ttir.relu"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#l1_block_sharded, #l1_block_sharded]}> : (tensor<224x64xf32>, tensor<224x64xf32>) -> tensor<224x64xf32> + %1 = "ttir.relu"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<224x64xf32>, tensor<224x64xf32>) -> tensor<224x64xf32> return %1 : tensor<224x64xf32> } @@ -40,21 +37,21 @@ func.func @ge(%arg0: 
tensor<224x64xf32>, %arg1: tensor<224x64xf32>) -> tensor<22 // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] %0 = tensor.empty() : tensor<224x64xf32> // CHECK: %[[C:.*]] = "ttnn.ge"[[C:.*]] - %1 = "ttir.ge"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#l1_block_sharded, #l1_block_sharded, #l1_block_sharded]}> : (tensor<224x64xf32>, tensor<224x64xf32>, tensor<224x64xf32>) -> tensor<224x64xf32> + %1 = "ttir.ge"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<224x64xf32>, tensor<224x64xf32>, tensor<224x64xf32>) -> tensor<224x64xf32> return %1 : tensor<224x64xf32> } func.func @reshape(%arg0: tensor<4x2x224x64xbf16>) -> tensor<2x4x224x64xbf16> { %0 = tensor.empty() : tensor<2x4x224x64xbf16> // CHECK: %[[C:.*]] = "ttnn.reshape"[[C:.*]] - %1 = "ttir.reshape"(%arg0, %0) <{shape = [2: i32, 4: i32, 224: i32, 64: i32] , operand_constraints = [#l1_height_sharded, #l1_height_sharded]}> : (tensor<4x2x224x64xbf16>, tensor<2x4x224x64xbf16>) -> tensor<2x4x224x64xbf16> + %1 = "ttir.reshape"(%arg0, %0) <{shape = [2: i32, 4: i32, 224: i32, 64: i32]}> : (tensor<4x2x224x64xbf16>, tensor<2x4x224x64xbf16>) -> tensor<2x4x224x64xbf16> return %1 : tensor<2x4x224x64xbf16> } func.func @squeeze(%arg0: tensor<1x2x1x224x64xbf16>) -> tensor<1x2x224x64xbf16> { %0 = tensor.empty() : tensor<1x2x224x64xbf16> // CHECK: %[[C:.*]] = "ttnn.reshape"[[C:.*]] - %1 = "ttir.squeeze"(%arg0, %0) <{dim = 2 : si32, operand_constraints = [#l1_height_sharded, #l1_height_sharded]}> : (tensor<1x2x1x224x64xbf16>, tensor<1x2x224x64xbf16>) -> tensor<1x2x224x64xbf16> + %1 = "ttir.squeeze"(%arg0, %0) <{dim = 2 : si32}> : (tensor<1x2x1x224x64xbf16>, tensor<1x2x224x64xbf16>) -> tensor<1x2x224x64xbf16> return %1 : tensor<1x2x224x64xbf16> } @@ -62,7 +59,7 @@ func.func @reciprocal(%arg0: tensor<224x64xf32>) -> tensor<224x64xf32> { // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] %0 = tensor.empty() : tensor<224x64xf32> // CHECK: %[[C:.*]] = "ttnn.reciprocal"[[C:.*]] - %1 = "ttir.reciprocal"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#l1_block_sharded, #l1_block_sharded]}> : (tensor<224x64xf32>, tensor<224x64xf32>) -> tensor<224x64xf32> + %1 = "ttir.reciprocal"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<224x64xf32>, tensor<224x64xf32>) -> tensor<224x64xf32> return %1 : tensor<224x64xf32> } @@ -70,7 +67,7 @@ func.func @sigmoid(%arg0: tensor<224x64xf32>) -> tensor<224x64xf32> { // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] %0 = tensor.empty() : tensor<224x64xf32> // CHECK: %[[C:.*]] = "ttnn.sigmoid"[[C:.*]] - %1 = "ttir.sigmoid"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#l1_block_sharded, #l1_block_sharded]}> : (tensor<224x64xf32>, tensor<224x64xf32>) -> tensor<224x64xf32> + %1 = "ttir.sigmoid"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<224x64xf32>, tensor<224x64xf32>) -> tensor<224x64xf32> return %1 : tensor<224x64xf32> } @@ -78,7 +75,7 @@ func.func @sqrt(%arg0: tensor<224x64xf32>) -> tensor<224x64xf32> { // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] %0 = tensor.empty() : tensor<224x64xf32> // CHECK: %[[C:.*]] = "ttnn.sqrt"[[C:.*]] - %1 = "ttir.sqrt"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#l1_block_sharded, #l1_block_sharded]}> : (tensor<224x64xf32>, tensor<224x64xf32>) -> tensor<224x64xf32> + %1 = "ttir.sqrt"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<224x64xf32>, tensor<224x64xf32>) -> tensor<224x64xf32> return %1 : tensor<224x64xf32> } diff --git a/test/ttmlir/Silicon/TTNN/simple_compare.mlir 
b/test/ttmlir/Silicon/TTNN/simple_compare.mlir index f53ba7530..5263c4fe4 100644 --- a/test/ttmlir/Silicon/TTNN/simple_compare.mlir +++ b/test/ttmlir/Silicon/TTNN/simple_compare.mlir @@ -1,10 +1,6 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn - -#any_device = #tt.operand_constraint -#any_device_tile = #tt.operand_constraint - module attributes {} { func.func @equal(%arg0: tensor<13x31xf32>, %arg1: tensor<13x31xf32>) -> tensor<13x31xf32> { // CHECK: %[[C:.*]] = "ttnn.empty @@ -15,7 +11,7 @@ module attributes {} { // CHECK-SAME: [[TENSOR]] // CHECK-SAME: [[TENSOR]] // CHECK-SAME: -> [[TENSOR]] - %1 = "ttir.eq"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<13x31xf32>, tensor<13x31xf32>, tensor<13x31xf32>) -> tensor<13x31xf32> + %1 = "ttir.eq"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<13x31xf32>, tensor<13x31xf32>, tensor<13x31xf32>) -> tensor<13x31xf32> return %1 : tensor<13x31xf32> } @@ -28,7 +24,7 @@ module attributes {} { // CHECK-SAME: [[TENSOR]] // CHECK-SAME: [[TENSOR]] // CHECK-SAME: -> [[TENSOR]] - %1 = "ttir.ne"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<13x31xf32>, tensor<13x31xf32>, tensor<13x31xf32>) -> tensor<13x31xf32> + %1 = "ttir.ne"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<13x31xf32>, tensor<13x31xf32>, tensor<13x31xf32>) -> tensor<13x31xf32> return %1 : tensor<13x31xf32> } @@ -41,7 +37,7 @@ module attributes {} { // CHECK-SAME: [[TENSOR]] // CHECK-SAME: [[TENSOR]] // CHECK-SAME: -> [[TENSOR]] - %1 = "ttir.ge"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<13x31xf32>, tensor<13x31xf32>, tensor<13x31xf32>) -> tensor<13x31xf32> + %1 = "ttir.ge"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<13x31xf32>, tensor<13x31xf32>, tensor<13x31xf32>) -> tensor<13x31xf32> return %1 : tensor<13x31xf32> } @@ -54,7 +50,7 @@ module attributes {} { // CHECK-SAME: [[TENSOR]] // CHECK-SAME: [[TENSOR]] // CHECK-SAME: -> [[TENSOR]] - %1 = "ttir.gt"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<13x31xf32>, tensor<13x31xf32>, tensor<13x31xf32>) -> tensor<13x31xf32> + %1 = "ttir.gt"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<13x31xf32>, tensor<13x31xf32>, tensor<13x31xf32>) -> tensor<13x31xf32> return %1 : tensor<13x31xf32> } @@ -67,7 +63,7 @@ module attributes {} { // CHECK-SAME: [[TENSOR]] // CHECK-SAME: [[TENSOR]] // CHECK-SAME: -> [[TENSOR]] - %1 = "ttir.le"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<13x31xf32>, tensor<13x31xf32>, tensor<13x31xf32>) -> tensor<13x31xf32> + %1 = "ttir.le"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<13x31xf32>, tensor<13x31xf32>, tensor<13x31xf32>) -> tensor<13x31xf32> return %1 : tensor<13x31xf32> } @@ -80,7 +76,7 @@ module attributes {} { // CHECK-SAME: [[TENSOR]] // CHECK-SAME: [[TENSOR]] // CHECK-SAME: -> [[TENSOR]] - %1 = "ttir.lt"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<13x31xf32>, tensor<13x31xf32>, tensor<13x31xf32>) -> tensor<13x31xf32> + %1 = 
"ttir.lt"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<13x31xf32>, tensor<13x31xf32>, tensor<13x31xf32>) -> tensor<13x31xf32> return %1 : tensor<13x31xf32> } } diff --git a/test/ttmlir/Silicon/TTNN/simple_conv.mlir b/test/ttmlir/Silicon/TTNN/simple_conv.mlir index 543f05763..13708ef16 100644 --- a/test/ttmlir/Silicon/TTNN/simple_conv.mlir +++ b/test/ttmlir/Silicon/TTNN/simple_conv.mlir @@ -1,12 +1,11 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<1x32x32x64xbf16>, %arg1: tensor<64x64x3x3xbf16>, %arg2: tensor<1x1x1x64xbf16>) -> tensor<1x32x32x64xbf16> { %0 = tensor.empty() : tensor<1x32x32x64xbf16> // CHECK: %[[C:.*]] = "ttnn.conv2d"[[C:.*]] - %1 = "ttir.conv2d"(%arg0, %arg1, %arg2, %0) <{stride_height=1: si32, stride_width=1: si32, dilation_height=1: si32, dilation_width=1: si32, groups=1: si32, padding_left=1: si32, padding_right=1: si32, padding_top=1: si32, padding_bottom=1: si32, is_convtranspose2d=0: si32, output_height_transpose=0: si32, output_width_transpose=0: si32, stride_transpose=0: si32, operand_constraints = [#any_device, #any_device, #any_device, #any_device]}> : (tensor<1x32x32x64xbf16>, tensor<64x64x3x3xbf16>, tensor<1x1x1x64xbf16>, tensor<1x32x32x64xbf16>) -> tensor<1x32x32x64xbf16> + %1 = "ttir.conv2d"(%arg0, %arg1, %arg2, %0) <{stride_height=1: si32, stride_width=1: si32, dilation_height=1: si32, dilation_width=1: si32, groups=1: si32, padding_left=1: si32, padding_right=1: si32, padding_top=1: si32, padding_bottom=1: si32, is_convtranspose2d=0: si32, output_height_transpose=0: si32, output_width_transpose=0: si32, stride_transpose=0: si32}> : (tensor<1x32x32x64xbf16>, tensor<64x64x3x3xbf16>, tensor<1x1x1x64xbf16>, tensor<1x32x32x64xbf16>) -> tensor<1x32x32x64xbf16> return %1 : tensor<1x32x32x64xbf16> } } diff --git a/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir b/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir index b0fb94cc6..a0452f01f 100644 --- a/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir +++ b/test/ttmlir/Silicon/TTNN/simple_eltwise.mlir @@ -1,14 +1,11 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device = #tt.operand_constraint -#any_device_tile = #tt.operand_constraint - func.func @add(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> { // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] %0 = tensor.empty() : tensor<64x128xf32> // CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] - %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } @@ -16,7 +13,7 @@ func.func @ceil(%arg0: tensor<32x32xf32>) -> tensor<32x32xf32> { %0 = tensor.empty() : tensor<32x32xf32> // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) // CHECK: %{{[0-9]+}} = "ttnn.ceil"(%{{[0-9]+}}, [[VAL0]]) - %1 = "ttir.ceil"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : 
(tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32> + %1 = "ttir.ceil"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32> return %1 : tensor<32x32xf32> } @@ -27,7 +24,7 @@ func.func @clamp(%arg0: tensor<64x128xbf16>) -> tensor<64x128xbf16> { // CHECK: = "ttnn.clamp"(%[[LAYOUT]]) // CHECK-SAME: {max = 3.000000e+00 : f32, min = 2.000000e+00 : f32} // CHECK-SAME: [[TENSOR:tensor<64x128xbf16]], #ttnn_layout{{[0-9]+}}>) -> [[TENSOR]] - %1 = "ttir.clamp"(%arg0, %0) <{max = 3.000000e+00 : f32, min = 2.000000e+00 : f32, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16> + %1 = "ttir.clamp"(%arg0, %0) <{max = 3.000000e+00 : f32, min = 2.000000e+00 : f32}> : (tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16> return %1 : tensor<64x128xbf16> } @@ -35,7 +32,7 @@ func.func @concat(%arg0: tensor<32x32xf32>, %arg1: tensor<32x64xf32>) -> tensor< // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] %0 = tensor.empty() : tensor<32x96xf32> // CHECK: %[[C:.*]] = "ttnn.concat"[[C:.*]] - %1 = "ttir.concat"(%arg0, %arg1, %0) <{dim = 1 : si32, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x32xf32>, tensor<32x64xf32>, tensor<32x96xf32>) -> tensor<32x96xf32> + %1 = "ttir.concat"(%arg0, %arg1, %0) <{dim = 1 : si32}> : (tensor<32x32xf32>, tensor<32x64xf32>, tensor<32x96xf32>) -> tensor<32x96xf32> return %1 : tensor<32x96xf32> } @@ -43,7 +40,7 @@ func.func @cosine(%arg0: tensor<32x32xf32>) -> tensor<32x32xf32> { %0 = tensor.empty() : tensor<32x32xf32> // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) // CHECK: %{{[0-9]+}} = "ttnn.cos"(%{{[0-9]+}}, [[VAL0]]) - %1 = "ttir.cos"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32> + %1 = "ttir.cos"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32> return %1 : tensor<32x32xf32> } @@ -51,7 +48,7 @@ func.func @div(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<6 // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] %0 = tensor.empty() : tensor<64x128xf32> // CHECK: %[[C:.*]] = "ttnn.div"[[C:.*]] - %1 = "ttir.div"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.div"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } @@ -63,7 +60,7 @@ func.func @floor(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { // CHECK-SAME: [[TENSOR]] // CHECK-SAME: [[TENSOR]] // CHECK-SAME: -> [[TENSOR]] - %1 = "ttir.floor"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.floor"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } @@ -75,7 +72,7 @@ func.func @is_finite(%arg0: tensor<64x128xbf16>) -> tensor<64x128xbf16> { // CHECK-SAME: tensor<64x128xbf16, // CHECK-SAME: [[TENSOR]] // CHECK-SAME: -> [[TENSOR]] - %1 = "ttir.isfinite"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xbf16>, 
tensor<64x128xbf16>) -> tensor<64x128xbf16> + %1 = "ttir.isfinite"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16> return %1 : tensor<64x128xbf16> } @@ -88,7 +85,7 @@ func.func @minimum(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tens // CHECK-SAME: [[TENSOR]] // CHECK-SAME: [[TENSOR]] // CHECK-SAME: -> [[TENSOR]] - %1 = "ttir.minimum"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.minimum"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } @@ -96,7 +93,7 @@ func.func @ge(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64 // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] %0 = tensor.empty() : tensor<64x128xf32> // CHECK: %[[C:.*]] = "ttnn.ge"[[C:.*]] - %1 = "ttir.ge"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.ge"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } @@ -104,7 +101,7 @@ func.func @maximum(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tens // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] %0 = tensor.empty() : tensor<64x128xf32> // CHECK: %[[C:.*]] = "ttnn.maximum"[[C:.*]] - %1 = "ttir.maximum"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.maximum"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } @@ -112,14 +109,14 @@ func.func @multiply(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> ten // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] %0 = tensor.empty() : tensor<64x128xf32> // CHECK: %[[C:.*]] = "ttnn.multiply"[[C:.*]] - %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } func.func @negate(%arg0: tensor<32x32xf32>) -> tensor<32x32xf32> { %0 = tensor.empty() : tensor<32x32xf32> // CHECK: %[[C:.*]] = "ttnn.neg"[[C:.*]] - %1 = "ttir.neg"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32> + %1 = "ttir.neg"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32> return %1 : tensor<32x32xf32> } @@ -127,7 +124,7 @@ func.func @reciprocal(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] %0 = tensor.empty() : tensor<64x128xf32> // CHECK: %[[C:.*]] = "ttnn.reciprocal"[[C:.*]] - %1 = "ttir.reciprocal"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, 
tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.reciprocal"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } @@ -135,7 +132,7 @@ func.func @relu(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] %0 = tensor.empty() : tensor<64x128xf32> // CHECK: %[[C:.*]] = "ttnn.relu"[[C:.*]] - %1 = "ttir.relu"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.relu"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } @@ -143,21 +140,21 @@ func.func @leaky_relu(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { // CHECK: %[[C:.*]] = "ttnn.empty" %0 = tensor.empty() : tensor<64x128xf32> // CHECK: %[[C:.*]] = "ttnn.leaky_relu" - %1 = "ttir.leaky_relu"(%arg0, %0) <{parameter = 0.01 : f32, operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.leaky_relu"(%arg0, %0) <{parameter = 0.01 : f32, operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } func.func @reshape(%arg0: tensor<4x2x32x32xbf16>) -> tensor<2x4x32x32xbf16> { %0 = tensor.empty() : tensor<2x4x32x32xbf16> // CHECK: %[[C:.*]] = "ttnn.reshape"[[C:.*]] - %1 = "ttir.reshape"(%arg0, %0) <{shape = [2: i32, 4: i32, 32: i32, 32: i32] , operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<4x2x32x32xbf16>, tensor<2x4x32x32xbf16>) -> tensor<2x4x32x32xbf16> + %1 = "ttir.reshape"(%arg0, %0) <{shape = [2: i32, 4: i32, 32: i32, 32: i32]}> : (tensor<4x2x32x32xbf16>, tensor<2x4x32x32xbf16>) -> tensor<2x4x32x32xbf16> return %1 : tensor<2x4x32x32xbf16> } func.func @squeeze(%arg0: tensor<1x2x1x32x32xbf16>) -> tensor<1x2x32x32xbf16> { %0 = tensor.empty() : tensor<1x2x32x32xbf16> // CHECK: %[[C:.*]] = "ttnn.reshape"[[C:.*]] - %1 = "ttir.squeeze"(%arg0, %0) <{dim = 2 : si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<1x2x1x32x32xbf16>, tensor<1x2x32x32xbf16>) -> tensor<1x2x32x32xbf16> + %1 = "ttir.squeeze"(%arg0, %0) <{dim = 2 : si32}> : (tensor<1x2x1x32x32xbf16>, tensor<1x2x32x32xbf16>) -> tensor<1x2x32x32xbf16> return %1 : tensor<1x2x32x32xbf16> } @@ -165,7 +162,7 @@ func.func @subtract(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> ten // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] %0 = tensor.empty() : tensor<64x128xf32> // CHECK: %[[C:.*]] = "ttnn.subtract"[[C:.*]] - %1 = "ttir.subtract"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.subtract"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } @@ -173,7 +170,7 @@ func.func @rsqrt(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] %0 = tensor.empty() : tensor<64x128xf32> // CHECK: %[[C:.*]] = "ttnn.rsqrt"[[C:.*]] - %1 = "ttir.rsqrt"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.rsqrt"(%arg0, %0) 
<{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } @@ -181,7 +178,7 @@ func.func @sigmoid(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] %0 = tensor.empty() : tensor<64x128xf32> // CHECK: %[[C:.*]] = "ttnn.sigmoid"[[C:.*]] - %1 = "ttir.sigmoid"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.sigmoid"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } @@ -189,7 +186,7 @@ func.func @sqrt(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] %0 = tensor.empty() : tensor<64x128xf32> // CHECK: %[[C:.*]] = "ttnn.sqrt"[[C:.*]] - %1 = "ttir.sqrt"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.sqrt"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } @@ -197,7 +194,7 @@ func.func @sine(%arg0: tensor<32x32xf32>) -> tensor<32x32xf32> { %0 = tensor.empty() : tensor<32x32xf32> // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) // CHECK: %{{[0-9]+}} = "ttnn.sin"(%{{[0-9]+}}, [[VAL0]]) - %1 = "ttir.sin"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32> + %1 = "ttir.sin"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32> return %1 : tensor<32x32xf32> } @@ -205,11 +202,11 @@ func.func @softmax(%arg0: tensor<512x1024xbf16>) -> tensor<512x1024xbf16> { %0 = tensor.empty() : tensor<512x1024xbf16> // CHECK: %[[C:.*]] = "ttnn.softmax"[[C:.*]] // Check for positive dimension attribute - %1 = "ttir.softmax"(%arg0, %0) <{dimension = 1 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<512x1024xbf16>, tensor<512x1024xbf16>) -> tensor<512x1024xbf16> + %1 = "ttir.softmax"(%arg0, %0) <{dimension = 1 : si32}> : (tensor<512x1024xbf16>, tensor<512x1024xbf16>) -> tensor<512x1024xbf16> %2 = tensor.empty() : tensor<512x1024xbf16> // CHECK: %[[C:.*]] = "ttnn.softmax"[[C:.*]] // Check for negative dimension attribute - %3 = "ttir.softmax"(%1, %2) <{dimension = -1 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<512x1024xbf16>, tensor<512x1024xbf16>) -> tensor<512x1024xbf16> + %3 = "ttir.softmax"(%1, %2) <{dimension = -1 : si32}> : (tensor<512x1024xbf16>, tensor<512x1024xbf16>) -> tensor<512x1024xbf16> return %3 : tensor<512x1024xbf16> } @@ -217,7 +214,7 @@ func.func @cbrt(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] %0 = tensor.empty() : tensor<64x128xf32> // CHECK: %[[C:.*]] = "ttnn.cbrt"[[C:.*]] - %1 = "ttir.cbrt"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.cbrt"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } @@ -226,7 +223,7 @@ func.func @typecast(%arg0: tensor<64x128xf32>) -> tensor<64x128xbf16> { // CHECK: %[[C:.*]] = "ttnn.typecast" // CHECK-SAME: tensor<64x128xf32, // 
CHECK-SAME: tensor<64x128xbf16, - %1 = "ttir.typecast"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xbf16>) -> tensor<64x128xbf16> + %1 = "ttir.typecast"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xbf16>) -> tensor<64x128xbf16> return %1 : tensor<64x128xbf16> } @@ -234,14 +231,14 @@ func.func @log(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<6 // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]] %0 = tensor.empty() : tensor<64x128xf32> // CHECK: %[[C:.*]] = "ttnn.log"[[C:.*]] - %1 = "ttir.log"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.log"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } func.func @log1p(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { %0 = tensor.empty() : tensor<64x128xf32> // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) <{dtype = {{.*}}, layout = {{.*}}, memory_config = {{.*}}, <{{.*}}>>, shape = #ttnn.shape<[[TENSOR_SHAPE:[0-9]+x[0-9]+]]>}> - %1 = "ttir.log1p"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.log1p"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> // CHECK: %{{[0-9]+}} = "ttnn.log1p"(%{{[0-9]+}}, [[VAL0]]) <{operandSegmentSizes = array}> : (tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}>, tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}) -> tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}> return %1 : tensor<64x128xf32> // CHECK: return %{{[0-9]+}} : tensor<[[TENSOR_SHAPE]]xf32, {{.*}}> @@ -250,7 +247,7 @@ func.func @log1p(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { func.func @expm1(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { %0 = tensor.empty() : tensor<64x128xf32> // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) <{dtype = {{.*}}, layout = {{.*}}, memory_config = {{.*}}, <{{.*}}>>, shape = #ttnn.shape<[[TENSOR_SHAPE:[0-9]+x[0-9]+]]>}> - %1 = "ttir.expm1"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.expm1"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> // CHECK: %{{[0-9]+}} = "ttnn.expm1"(%{{[0-9]+}}, [[VAL0]]) <{operandSegmentSizes = array}> : (tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}>, tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}) -> tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}> return %1 : tensor<64x128xf32> // CHECK: return %{{[0-9]+}} : tensor<[[TENSOR_SHAPE]]xf32, {{.*}}> @@ -259,7 +256,7 @@ func.func @expm1(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { func.func @sign(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { %0 = tensor.empty() : tensor<64x128xf32> // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) <{dtype = {{.*}}, layout = {{.*}}, memory_config = {{.*}}, <{{.*}}>>, shape = #ttnn.shape<[[TENSOR_SHAPE:[0-9]+x[0-9]+]]>}> - %1 = "ttir.sign"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.sign"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> // 
CHECK: %{{[0-9]+}} = "ttnn.sign"(%{{[0-9]+}}, [[VAL0]]) <{operandSegmentSizes = array}> : (tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}>, tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}) -> tensor<[[TENSOR_SHAPE]]x{{.*}}, {{.*}}> return %1 : tensor<64x128xf32> // CHECK: return %{{[0-9]+}} : tensor<[[TENSOR_SHAPE]]xf32, {{.*}}> @@ -268,7 +265,7 @@ func.func @sign(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { func.func @remainder(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>) -> tensor<32x32xf32> { %0 = tensor.empty() : tensor<32x32xf32> // CHECK: %[[EMPTY:.*]] = "ttnn.empty"{{.*}} -> tensor<32x32xf32, {{.*}} - %1 = "ttir.remainder"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32> + %1 = "ttir.remainder"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32> // CHECK: %[[REM:[0-9]+]] = "ttnn.remainder"({{.*}}, {{.*}}, %[[EMPTY]]){{.*}} -> tensor<32x32xf32, {{.*}} return %1 : tensor<32x32xf32> // CHECK: return {{.*}} : tensor<32x32xf32, {{.*}} @@ -283,9 +280,9 @@ func.func @get_dimension_size(%arg0: tensor<13x21x3xf32>) -> tensor<1xi32> { func.func @test_where(%arg0: tensor<13x37xbf16>, %arg1: tensor<13x37xbf16>) -> tensor<13x37xbf16> { %0 = tensor.empty() : tensor<13x37xbf16> - %1 = "ttir.eq"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<13x37xbf16>, tensor<13x37xbf16>, tensor<13x37xbf16>) -> tensor<13x37xbf16> + %1 = "ttir.eq"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<13x37xbf16>, tensor<13x37xbf16>, tensor<13x37xbf16>) -> tensor<13x37xbf16> %2 = tensor.empty() : tensor<13x37xbf16> - %3 = "ttir.where"(%1, %arg0, %arg1, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<13x37xbf16>, tensor<13x37xbf16>, tensor<13x37xbf16>, tensor<13x37xbf16>) -> tensor<13x37xbf16> + %3 = "ttir.where"(%1, %arg0, %arg1, %2) <{operandSegmentSizes = array}> : (tensor<13x37xbf16>, tensor<13x37xbf16>, tensor<13x37xbf16>, tensor<13x37xbf16>) -> tensor<13x37xbf16> // CHECK: %[[EMPTY:.*]] = "ttnn.empty"{{.*}} // CHECK: %[[VAL1:[0-9]+]] = "ttnn.eq"(%{{[0-9]+}}, %{{[0-9]+}}, %[[EMPTY]]) // CHECK: %{{[0-9]+}} = "ttnn.where"(%[[VAL1]], %{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}}) @@ -300,7 +297,7 @@ func.func @gelu(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> { // CHECK-SAME: tensor<64x128xf32, // CHECK-SAME: tensor<64x128xf32, // CHECK-SAME: tensor<64x128xf32, - %1 = "ttir.gelu"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> + %1 = "ttir.gelu"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32> return %1 : tensor<64x128xf32> } @@ -308,7 +305,7 @@ func.func @tan(%arg0: tensor<64x128xbf16>) -> tensor<64x128xbf16> { %0 = tensor.empty() : tensor<64x128xbf16> // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) // CHECK: %{{[0-9]+}} = "ttnn.tan"(%{{[0-9]+}}, [[VAL0]]) - %1 = "ttir.tan"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16> + %1 = "ttir.tan"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xbf16>, tensor<64x128xbf16>) -> 
tensor<64x128xbf16> return %1 : tensor<64x128xbf16> } @@ -316,20 +313,20 @@ func.func @tanh(%arg0: tensor<64x128xbf16>) -> tensor<64x128xbf16> { %0 = tensor.empty() : tensor<64x128xbf16> // CHECK: [[VAL0:%[0-9]+]] = "ttnn.empty"(%{{[0-9]+}}) // CHECK: %{{[0-9]+}} = "ttnn.tanh"(%{{[0-9]+}}, [[VAL0]]) - %1 = "ttir.tanh"(%arg0, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16> + %1 = "ttir.tanh"(%arg0, %0) <{operandSegmentSizes = array}> : (tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16> return %1 : tensor<64x128xbf16> } func.func @addint32(%arg0: tensor<64x128xi32>, %arg1: tensor<64x128xi32>) -> tensor<64x128xi32> { %0 = tensor.empty() : tensor<64x128xi32> - %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xi32>, tensor<64x128xi32>, tensor<64x128xi32>) -> tensor<64x128xi32> + %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<64x128xi32>, tensor<64x128xi32>, tensor<64x128xi32>) -> tensor<64x128xi32> return %1 : tensor<64x128xi32> } func.func @scatter(%arg0: tensor<1x3x320x320xf32>, %arg1: tensor<1x3x32x32xf32>) -> tensor<1x3x320x320xf32> { %0 = tensor.empty() : tensor<1x3x320x320xf32> %1 = tensor.empty() : tensor<1x1xi32> - %2 = "ttir.scatter"(%arg0, %1, %arg1, %0) <{index_vector_dim = 1 : i32, indices_are_sorted = false, input_batching_dims = array, inserted_window_dims = array, operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile], scatter_dims_to_operand_dims = array, scatter_indices_batching_dims = array, unique_indices = false, update_window_dims = array}> ({ + %2 = "ttir.scatter"(%arg0, %1, %arg1, %0) <{index_vector_dim = 1 : i32, indices_are_sorted = false, input_batching_dims = array, inserted_window_dims = array, scatter_dims_to_operand_dims = array, scatter_indices_batching_dims = array, unique_indices = false, update_window_dims = array}> ({ ^bb0(%arg3: tensor<1xf32>, %arg4: tensor<1xf32>): "ttir.yield"(%arg4) : (tensor<1xf32>) -> () }) : (tensor<1x3x320x320xf32>, tensor<1x1xi32>, tensor<1x3x32x32xf32>, tensor<1x3x320x320xf32>) -> tensor<1x3x320x320xf32> diff --git a/test/ttmlir/Silicon/TTNN/simple_index.mlir b/test/ttmlir/Silicon/TTNN/simple_index.mlir index 6e5ead92a..fcd163dff 100644 --- a/test/ttmlir/Silicon/TTNN/simple_index.mlir +++ b/test/ttmlir/Silicon/TTNN/simple_index.mlir @@ -1,12 +1,11 @@ // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir // RUN: FileCheck %s --input-file=%t.mlir // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn -#any_device_tile = #tt.operand_constraint module attributes {} { func.func @forward(%arg0: tensor<4x32x32xbf16>) -> tensor<4x32x16xbf16> { %0 = tensor.empty() : tensor<4x32x16xbf16> // CHECK: %[[C:.*]] = "ttnn.slice"[[C:.*]] - %1 = "ttir.index"(%arg0, %0) <{dim = 2: i32, begin = 0: i32, end = 32: i32, step = 2: i32, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<4x32x32xbf16>, tensor<4x32x16xbf16>) -> tensor<4x32x16xbf16> + %1 = "ttir.index"(%arg0, %0) <{dim = 2: i32, begin = 0: i32, end = 32: i32, step = 2: i32}> : (tensor<4x32x32xbf16>, tensor<4x32x16xbf16>) -> tensor<4x32x16xbf16> return %1 : tensor<4x32x16xbf16> } } diff --git a/test/ttmlir/Silicon/TTNN/simple_linear.mlir b/test/ttmlir/Silicon/TTNN/simple_linear.mlir index f53de38cf..b65bf99db 100644 --- 
a/test/ttmlir/Silicon/TTNN/simple_linear.mlir
+++ b/test/ttmlir/Silicon/TTNN/simple_linear.mlir
@@ -2,7 +2,6 @@
 // RUN: FileCheck %s --input-file=%t.mlir
 // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
-#any_device_tile = #tt.operand_constraint<any_device_tile>
 module {
   func.func @simple_linear_without_bias(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>) -> tensor<64x64xbf16> {
     // CHECK: "ttnn.empty"
@@ -13,7 +12,7 @@ module {
     // CHECK-SAME: tensor<128x64xbf16
     // CHECK-SAME: tensor<64x64xbf16
     // CHECK-SAME: tensor<64x64xbf16
-    %1 = "ttir.linear"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16>
+    %1 = "ttir.linear"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16>
     return %1 : tensor<64x64xbf16>
   }
@@ -27,7 +26,7 @@ module {
     // CHECK-SAME: tensor<64x64xbf16
     // CHECK-SAME: tensor<64x64xbf16
     // CHECK-SAME: tensor<64x64xbf16
-    %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16>
+    %1 = "ttir.linear"(%arg0, %arg1, %bias, %0) : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16>
     return %1 : tensor<64x64xbf16>
   }
 }
diff --git a/test/ttmlir/Silicon/TTNN/simple_logical.mlir b/test/ttmlir/Silicon/TTNN/simple_logical.mlir
index e5d68f5ec..558f815c7 100644
--- a/test/ttmlir/Silicon/TTNN/simple_logical.mlir
+++ b/test/ttmlir/Silicon/TTNN/simple_logical.mlir
@@ -1,15 +1,11 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
 // RUN: FileCheck %s --input-file=%t.mlir
 // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
-
-#any_device = #tt.operand_constraint<any_device>
-#any_device_tile = #tt.operand_constraint<any_device_tile>
-
 module attributes {} {
   func.func @logical_and(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> {
     %0 = tensor.empty() : tensor<64x128xf32>
     // CHECK: {{.*}} = "ttnn.empty"{{.*}}
-    %1 = "ttir.logical_and"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
+    %1 = "ttir.logical_and"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
     // CHECK: %[[C:.*]] = "ttnn.logical_and"
     // CHECK-SAME: tensor<64x128xf32,
     // CHECK-SAME: tensor<64x128xf32,
@@ -20,7 +16,7 @@ module attributes {} {
   func.func @logical_not(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> {
     %0 = tensor.empty() : tensor<64x128xf32>
     // CHECK: {{.*}} = "ttnn.empty"{{.*}}
-    %1 = "ttir.logical_not"(%arg0, %0) <{operandSegmentSizes = array<i32: 1, 1>, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
+    %1 = "ttir.logical_not"(%arg0, %0) <{operandSegmentSizes = array<i32: 1, 1>}> : (tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
     // CHECK: %[[C:.*]] = "ttnn.logical_not"
     // CHECK-SAME: tensor<64x128xf32,
     // CHECK-SAME: tensor<64x128xf32,
@@ -30,7 +26,7 @@ module attributes {} {
   func.func @logical_or(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> {
     %0 = tensor.empty() : tensor<64x128xf32>
     // CHECK: {{.*}} = "ttnn.empty"{{.*}}
-    %1 = "ttir.logical_or"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
+    %1 = "ttir.logical_or"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
     // CHECK: %[[C:.*]] = "ttnn.logical_or"
     // CHECK-SAME: tensor<64x128xf32,
     // CHECK-SAME: tensor<64x128xf32,
@@ -46,7 +42,7 @@ module attributes {} {
     // CHECK-SAME: [[TENSOR]]
     // CHECK-SAME: [[TENSOR]]
     // CHECK-SAME: -> [[TENSOR]]
-    %1 = "ttir.logical_xor"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16>
+    %1 = "ttir.logical_xor"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x128xbf16>) -> tensor<64x128xbf16>
     return %1 : tensor<64x128xbf16>
   }
 }
diff --git a/test/ttmlir/Silicon/TTNN/simple_matmul.mlir b/test/ttmlir/Silicon/TTNN/simple_matmul.mlir
index 9c240b0ab..f221001bb 100644
--- a/test/ttmlir/Silicon/TTNN/simple_matmul.mlir
+++ b/test/ttmlir/Silicon/TTNN/simple_matmul.mlir
@@ -1,13 +1,12 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
 // RUN: FileCheck %s --input-file=%t.mlir
 // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
-#any_device_tile = #tt.operand_constraint<any_device_tile>
 // CHECK: #[[TILED_LAYOUT:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<2x4x!tt.tile<32x32, bf16>, #dram>, <interleaved>>
 module attributes {} {
   func.func @forward(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x96xbf16>) -> tensor<64x96xbf16> {
     %0 = tensor.empty() : tensor<64x96xbf16>
     // CHECK: %[[C:.*]] = "ttnn.matmul"[[C:.*]]
-    %1 = "ttir.matmul"(%arg0, %arg1, %0) <{operand_constraints = [#any_device_tile, #any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128x96xbf16>, tensor<64x96xbf16>) -> tensor<64x96xbf16>
+    %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<128x96xbf16>, tensor<64x96xbf16>) -> tensor<64x96xbf16>
     return %1 : tensor<64x96xbf16>
   }
 }
diff --git a/test/ttmlir/Silicon/TTNN/simple_maxpool2d.mlir b/test/ttmlir/Silicon/TTNN/simple_maxpool2d.mlir
index 4722e9c52..4fdd836dd 100644
--- a/test/ttmlir/Silicon/TTNN/simple_maxpool2d.mlir
+++ b/test/ttmlir/Silicon/TTNN/simple_maxpool2d.mlir
@@ -1,12 +1,11 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
 // RUN: FileCheck %s --input-file=%t.mlir
 // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
-#any_device = #tt.operand_constraint<any_device>
 module attributes {} {
   func.func @forward(%arg0: tensor<1x128x128x32xbf16>) -> tensor<1x64x64x32xbf16> {
     %0 = tensor.empty() : tensor<1x64x64x32xbf16>
     // CHECK: %[[C:.*]] = "ttnn.max_pool2d"[[C:.*]]
-    %1 = "ttir.max_pool2d"(%arg0, %0) <{kernel_height=2: si32, kernel_width=2: si32, stride_height=2: si32, stride_width=2: si32, dilation_height=1: si32, dilation_width=1: si32, ceil_mode=false, padding_left=0: si32, padding_right=0: si32, padding_top=0: si32, padding_bottom=0: si32, operand_constraints = [#any_device, #any_device]}> : (tensor<1x128x128x32xbf16>, tensor<1x64x64x32xbf16>) -> tensor<1x64x64x32xbf16>
+    %1 = "ttir.max_pool2d"(%arg0, %0) <{kernel_height=2: si32, kernel_width=2: si32, stride_height=2: si32, stride_width=2: si32, dilation_height=1: si32, dilation_width=1: si32, ceil_mode=false, padding_left=0: si32, padding_right=0: si32, padding_top=0: si32, padding_bottom=0: si32}> : (tensor<1x128x128x32xbf16>, tensor<1x64x64x32xbf16>) -> tensor<1x64x64x32xbf16>
     return %1 : tensor<1x64x64x32xbf16>
   }
 }
diff --git a/test/ttmlir/Silicon/TTNN/simple_mean.mlir b/test/ttmlir/Silicon/TTNN/simple_mean.mlir
index f8ca09f6c..0a3250936 100644
--- a/test/ttmlir/Silicon/TTNN/simple_mean.mlir
+++ b/test/ttmlir/Silicon/TTNN/simple_mean.mlir
@@ -2,12 +2,11 @@
 // RUN: FileCheck %s --input-file=%t.mlir
 // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
 // UNSUPPORTED: true
-#any_device = #tt.operand_constraint<any_device>
 module {
   func.func @forward(%arg0: tensor<512x1024xbf16>) -> tensor<512x32xbf16> {
     %0 = tensor.empty() : tensor<512x32xbf16>
     // CHECK: %[[C:.*]] = "ttnn.mean"[[C:.*]]
-    %1 = "ttir.mean"(%arg0, %0) <{dim_arg = [-1: i32], keep_dim = true, operand_constraints = [#any_device, #any_device]}> : (tensor<512x1024xbf16>, tensor<512x32xbf16>) -> tensor<512x32xbf16>
+    %1 = "ttir.mean"(%arg0, %0) <{dim_arg = [-1: i32], keep_dim = true}> : (tensor<512x1024xbf16>, tensor<512x32xbf16>) -> tensor<512x32xbf16>
     return %1 : tensor<512x32xbf16>
   }
 }
diff --git a/test/ttmlir/Silicon/TTNN/simple_reductions.mlir b/test/ttmlir/Silicon/TTNN/simple_reductions.mlir
index 908a2c67f..28eaf47fa 100644
--- a/test/ttmlir/Silicon/TTNN/simple_reductions.mlir
+++ b/test/ttmlir/Silicon/TTNN/simple_reductions.mlir
@@ -1,52 +1,50 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
 // RUN: FileCheck %s --input-file=%t.mlir
 // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
-#any_device = #tt.operand_constraint<any_device>
-
 func.func @sum(%arg0: tensor<1x1x512x64xbf16>) -> tensor<1x1x512xbf16> {
   %0 = tensor.empty() : tensor<1x1x512xbf16>
   // CHECK: %[[C:.*]] = "ttnn.sum"[[C:.*]]
-  %1 = "ttir.sum"(%arg0, %0) <{dim_arg = [-1: i32], keep_dim = true, operand_constraints = [#any_device, #any_device]}> : (tensor<1x1x512x64xbf16>, tensor<1x1x512xbf16>) -> tensor<1x1x512xbf16>
+  %1 = "ttir.sum"(%arg0, %0) <{dim_arg = [-1: i32], keep_dim = true}> : (tensor<1x1x512x64xbf16>, tensor<1x1x512xbf16>) -> tensor<1x1x512xbf16>
   return %1 : tensor<1x1x512xbf16>
 }
 
 func.func @sum_last_2_dims(%arg0: tensor<1x32x512x64xbf16>) -> tensor<1x32xbf16> {
   %0 = tensor.empty() : tensor<1x32xbf16>
   // CHECK: %[[C:.*]] = "ttnn.sum"[[C:.*]]
-  %1 = "ttir.sum"(%arg0, %0) <{dim_arg = [-1: i32, -2: i32], keep_dim = true, operand_constraints = [#any_device, #any_device]}> : (tensor<1x32x512x64xbf16>, tensor<1x32xbf16>) -> tensor<1x32xbf16>
+  %1 = "ttir.sum"(%arg0, %0) <{dim_arg = [-1: i32, -2: i32], keep_dim = true}> : (tensor<1x32x512x64xbf16>, tensor<1x32xbf16>) -> tensor<1x32xbf16>
   return %1 : tensor<1x32xbf16>
 }
 
 func.func @sum_first_dim(%arg0: tensor<64x10xf32>) -> tensor<1x10xf32> {
   %0 = tensor.empty() : tensor<1x10xf32>
-  %1 = "ttir.sum"(%arg0, %0) <{dim_arg = [-2 : i32], keep_dim = true, operand_constraints = [#any_device, #any_device]}> : (tensor<64x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
+  %1 = "ttir.sum"(%arg0, %0) <{dim_arg = [-2 : i32], keep_dim = true}> : (tensor<64x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
   return %1: tensor<1x10xf32>
 }
 
 func.func @mean(%arg0: tensor<1x1x512x64xbf16>) -> tensor<1x1x512xbf16> {
   %0 = tensor.empty() : tensor<1x1x512xbf16>
   // CHECK: %[[C:.*]] = "ttnn.mean"[[C:.*]]
-  %1 = "ttir.mean"(%arg0, %0) <{dim_arg = [-1: i32], keep_dim = true, operand_constraints = [#any_device, #any_device]}> : (tensor<1x1x512x64xbf16>, tensor<1x1x512xbf16>) -> tensor<1x1x512xbf16>
+  %1 = "ttir.mean"(%arg0, %0) <{dim_arg = [-1: i32], keep_dim = true}> : (tensor<1x1x512x64xbf16>, tensor<1x1x512xbf16>) -> tensor<1x1x512xbf16>
   return %1 : tensor<1x1x512xbf16>
 }
 
 func.func @mean_last_2_dims(%arg0: tensor<1x32x512x64xbf16>) -> tensor<1x32xbf16> {
   %0 = tensor.empty() : tensor<1x32xbf16>
   // CHECK: %[[C:.*]] = "ttnn.mean"[[C:.*]]
-  %1 = "ttir.mean"(%arg0, %0) <{dim_arg = [-1: i32, -2: i32], keep_dim = true, operand_constraints = [#any_device, #any_device]}> : (tensor<1x32x512x64xbf16>, tensor<1x32xbf16>) -> tensor<1x32xbf16>
+  %1 = "ttir.mean"(%arg0, %0) <{dim_arg = [-1: i32, -2: i32], keep_dim = true}> : (tensor<1x32x512x64xbf16>, tensor<1x32xbf16>) -> tensor<1x32xbf16>
   return %1 : tensor<1x32xbf16>
 }
 
 func.func @max(%arg0: tensor<1x1x512x64xbf16>) -> tensor<1x1x512xbf16> {
   %0 = tensor.empty() : tensor<1x1x512xbf16>
   // CHECK: %[[C:.*]] = "ttnn.max"[[C:.*]]
-  %1 = "ttir.max"(%arg0, %0) <{dim_arg = [-1: i32], keep_dim = true, operand_constraints = [#any_device, #any_device]}> : (tensor<1x1x512x64xbf16>, tensor<1x1x512xbf16>) -> tensor<1x1x512xbf16>
+  %1 = "ttir.max"(%arg0, %0) <{dim_arg = [-1: i32], keep_dim = true}> : (tensor<1x1x512x64xbf16>, tensor<1x1x512xbf16>) -> tensor<1x1x512xbf16>
   return %1 : tensor<1x1x512xbf16>
 }
 
 func.func @max_last_2_dims(%arg0: tensor<1x32x512x64xbf16>) -> tensor<1x32xbf16> {
   %0 = tensor.empty() : tensor<1x32xbf16>
   // CHECK: %[[C:.*]] = "ttnn.max"[[C:.*]]
-  %1 = "ttir.max"(%arg0, %0) <{dim_arg = [-1: i32, -2: i32], keep_dim = true, operand_constraints = [#any_device, #any_device]}> : (tensor<1x32x512x64xbf16>, tensor<1x32xbf16>) -> tensor<1x32xbf16>
+  %1 = "ttir.max"(%arg0, %0) <{dim_arg = [-1: i32, -2: i32], keep_dim = true}> : (tensor<1x32x512x64xbf16>, tensor<1x32xbf16>) -> tensor<1x32xbf16>
   return %1 : tensor<1x32xbf16>
 }
diff --git a/test/ttmlir/Silicon/TTNN/simple_slice.mlir b/test/ttmlir/Silicon/TTNN/simple_slice.mlir
index 8d43cb7dc..e32101781 100644
--- a/test/ttmlir/Silicon/TTNN/simple_slice.mlir
+++ b/test/ttmlir/Silicon/TTNN/simple_slice.mlir
@@ -1,12 +1,11 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
 // RUN: FileCheck %s --input-file=%t.mlir
 // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
-#any_device_tile = #tt.operand_constraint<any_device_tile>
 module attributes {} {
   func.func @forward(%arg0: tensor<4x32x32xbf16>) -> tensor<2x16x16xbf16> {
     %0 = tensor.empty() : tensor<2x16x16xbf16>
     // CHECK: %[[C:.*]] = "ttnn.slice"[[C:.*]]
-    %1 = "ttir.slice"(%arg0, %0) <{begins = [0: i32, 0: i32, 0: i32], ends = [2: i32, 16: i32, 16: i32], step = [1: i32, 1: i32, 1: i32], operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<4x32x32xbf16>, tensor<2x16x16xbf16>) -> tensor<2x16x16xbf16>
+    %1 = "ttir.slice"(%arg0, %0) <{begins = [0: i32, 0: i32, 0: i32], ends = [2: i32, 16: i32, 16: i32], step = [1: i32, 1: i32, 1: i32]}> : (tensor<4x32x32xbf16>, tensor<2x16x16xbf16>) -> tensor<2x16x16xbf16>
     return %1 : tensor<2x16x16xbf16>
   }
 }
diff --git a/test/ttmlir/Silicon/TTNN/simple_typecast.mlir b/test/ttmlir/Silicon/TTNN/simple_typecast.mlir
index cb4f2d64f..82666195d 100644
--- a/test/ttmlir/Silicon/TTNN/simple_typecast.mlir
+++ b/test/ttmlir/Silicon/TTNN/simple_typecast.mlir
@@ -2,12 +2,11 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
 // RUN: FileCheck %s --input-file=%t.mlir
 // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
-#any_device = #tt.operand_constraint<any_device>
 func.func @typecast(%arg0: tensor<64x128xf32>) -> tensor<64x128xbf16> {
   %0 = tensor.empty() : tensor<64x128xbf16>
   // CHECK: %[[C:.*]] = "ttnn.typecast"
   // CHECK-SAME: tensor<64x128xf32,
   // CHECK-SAME: tensor<64x128xbf16,
-  %1 = "ttir.typecast"(%arg0, %0) <{operandSegmentSizes = array<i32: 1, 1>, operand_constraints = [#any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xbf16>) -> tensor<64x128xbf16>
+  %1 = "ttir.typecast"(%arg0, %0) <{operandSegmentSizes = array<i32: 1, 1>}> : (tensor<64x128xf32>, tensor<64x128xbf16>) -> tensor<64x128xbf16>
   return %1 : tensor<64x128xbf16>
 }
diff --git a/test/ttmlir/Silicon/TTNN/transpose.mlir b/test/ttmlir/Silicon/TTNN/transpose.mlir
index 184b6b807..b9805dd3c 100644
--- a/test/ttmlir/Silicon/TTNN/transpose.mlir
+++ b/test/ttmlir/Silicon/TTNN/transpose.mlir
@@ -1,33 +1,30 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" %s > %t.mlir
 // RUN: FileCheck %s --input-file=%t.mlir
 // RUN: ttmlir-translate --ttnn-to-flatbuffer %t.mlir > %t.ttnn
-#any_device = #tt.operand_constraint<any_device>
-#any_device_tile = #tt.operand_constraint<any_device_tile>
-
 func.func @transpose(%arg0: tensor<64x128xbf16>) -> tensor<128x64xbf16> {
   %0 = tensor.empty() : tensor<128x64xbf16>
   // CHECK: %[[C:.*]] = "ttnn.transpose"[[C:.*]]
-  %1 = "ttir.transpose"(%arg0, %0) <{dim0 = 0 : si32, dim1 = 1 : si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<64x128xbf16>, tensor<128x64xbf16>) -> tensor<128x64xbf16>
+  %1 = "ttir.transpose"(%arg0, %0) <{dim0 = 0 : si32, dim1 = 1 : si32}> : (tensor<64x128xbf16>, tensor<128x64xbf16>) -> tensor<128x64xbf16>
  return %1 : tensor<128x64xbf16>
 }
 
 func.func @transpose_8x8(%arg0: tensor<32x32xbf16>) -> tensor<32x32xbf16> {
   %0 = tensor.empty() : tensor<32x32xbf16>
   // CHECK: %[[C:.*]] = "ttnn.transpose"[[C:.*]]
-  %1 = "ttir.transpose"(%arg0, %0) <{dim0 = 0 : si32, dim1 = 1 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16>
+  %1 = "ttir.transpose"(%arg0, %0) <{dim0 = 0 : si32, dim1 = 1 : si32}> : (tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16>
   return %1 : tensor<32x32xbf16>
 }
 
 func.func @transpose_8x16_reverse_dims(%arg0: tensor<64x16xbf16>) -> tensor<16x64xbf16> {
   %0 = tensor.empty() : tensor<16x64xbf16>
   // CHECK: %[[C:.*]] = "ttnn.transpose"[[C:.*]]
-  %1 = "ttir.transpose"(%arg0, %0) <{dim0 = 1 : si32, dim1 = 0 : si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<64x16xbf16>, tensor<16x64xbf16>) -> tensor<16x64xbf16>
+  %1 = "ttir.transpose"(%arg0, %0) <{dim0 = 1 : si32, dim1 = 0 : si32}> : (tensor<64x16xbf16>, tensor<16x64xbf16>) -> tensor<16x64xbf16>
   return %1 : tensor<16x64xbf16>
 }
 
 func.func @transpose_negative_dims(%arg0: tensor<32x32xbf16>) -> tensor<32x32xbf16> {
   %0 = tensor.empty() : tensor<32x32xbf16>
   // CHECK: %[[C:.*]] = "ttnn.transpose"[[C:.*]]
-  %1 = "ttir.transpose"(%arg0, %0) <{dim0 = -1 : si32, dim1 = -2 : si32, operand_constraints = [#any_device_tile, #any_device_tile]}> : (tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16>
+  %1 = "ttir.transpose"(%arg0, %0) <{dim0 = -1 : si32, dim1 = -2 : si32}> : (tensor<32x32xbf16>, tensor<32x32xbf16>) -> tensor<32x32xbf16>
   return %1 : tensor<32x32xbf16>
 }
diff --git a/test/ttmlir/Translate/TTNN/1d_tensor.mlir b/test/ttmlir/Translate/TTNN/1d_tensor.mlir
index 695812737..5752be5ce 100644
--- a/test/ttmlir/Translate/TTNN/1d_tensor.mlir
+++ b/test/ttmlir/Translate/TTNN/1d_tensor.mlir
@@ -1,8 +1,6 @@
 // RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | ttmlir-translate --ttnn-to-flatbuffer
-#any_device = #tt.operand_constraint<any_device>
-
 func.func @embedding_1d_tensor(%arg0: tensor<32xf32>, %arg1: tensor<512x128xf32>) -> tensor<32x128xf32> {
   %0 = tensor.empty() : tensor<32x128xf32>
-  %1 = "ttir.embedding"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32xf32>, tensor<512x128xf32>, tensor<32x128xf32>) -> tensor<32x128xf32>
+  %1 = "ttir.embedding"(%arg0, %arg1, %0) : (tensor<32xf32>, tensor<512x128xf32>, tensor<32x128xf32>) -> tensor<32x128xf32>
   return %1 : tensor<32x128xf32>
 }
diff --git a/test/unittests/TestScheduler/TestScheduler.cpp b/test/unittests/TestScheduler/TestScheduler.cpp
index 7d44bb079..39a7fac50 100644
--- a/test/unittests/TestScheduler/TestScheduler.cpp
+++ b/test/unittests/TestScheduler/TestScheduler.cpp
@@ -76,17 +76,6 @@ class SchedulerBase : public ::testing::Test {
         builder.getUnknownLoc(), getTensorShape(), builder.getF32Type());
   }
 
-  llvm::SmallVector<mlir::Attribute> createOperandConstraints() {
-    llvm::SmallVector<mlir::Attribute> operand_constraints;
-    mlir::Attribute operand_constraint_attribute =
-        builder.getAttr(
-            mlir::tt::OperandConstraint::AnyDevice);
-    operand_constraints.push_back(operand_constraint_attribute);
-    operand_constraints.push_back(operand_constraint_attribute);
-    operand_constraints.push_back(operand_constraint_attribute);
-    return operand_constraints;
-  }
-
   mlir::func::FuncOp createFuncOp() {
     mlir::SmallVector input;
     input.push_back(getTensorType());
@@ -119,11 +108,9 @@ TEST_F(SchedulerBase, FixedSchedule) {
   mlir::Value lhs = func.getBody().getBlocks().front().getArgument(0);
   mlir::Value rhs = func.getBody().getBlocks().front().getArgument(1);
-  mlir::ArrayAttr attrs = builder.getArrayAttr(createOperandConstraints());
-
   // First operation has arg1 and arg2 and %0 as dps operand
-  ttir::TTIROp op = builder.create(builder.getUnknownLoc(), lhs, rhs, dest, attrs);
+  ttir::TTIROp op =
+      builder.create(builder.getUnknownLoc(), lhs, rhs, dest);
 
   // Create a chain of operations by using the result of the previous operation
   llvm::SmallVector operands = {rhs,
@@ -137,8 +124,7 @@ TEST_F(SchedulerBase, FixedSchedule) {
     mlir::Value lhs = operands[operands.size() - 2];
     mlir::Value rhs = operands[operands.size() - 1];
     dest = createEmptyTensor();
-    op = builder.create(builder.getUnknownLoc(), lhs, rhs, dest,
-                        attrs);
+    op = builder.create(builder.getUnknownLoc(), lhs, rhs, dest);
     operands.push_back(op.getOperation()->getResult(0));
     ops.push_back(op);
   }
@@ -172,11 +158,9 @@ TEST_F(SchedulerBase, SingleOp) {
   mlir::Value lhs = func.getBody().getBlocks().front().getArgument(0);
   mlir::Value rhs = func.getBody().getBlocks().front().getArgument(1);
-  mlir::ArrayAttr attrs = builder.getArrayAttr(createOperandConstraints());
-
   // First operation has arg1 and arg2 and %0 as dps operand
-  ttir::TTIROp op = builder.create(builder.getUnknownLoc(), lhs, rhs, dest, attrs);
+  ttir::TTIROp op =
+      builder.create(builder.getUnknownLoc(), lhs, rhs, dest);
 
   mlir::tt::scheduler::Scheduler scheduler(&func);
   ASSERT_TRUE(scheduler.hasUnscheduledOps());
@@ -199,9 +183,8 @@ TEST_F(SchedulerBase, VerifyFork) {
   mlir::Value dest = createEmptyTensor();
   mlir::Value lhs = func.getBody().getBlocks().front().getArgument(0);
   mlir::Value rhs = func.getBody().getBlocks().front().getArgument(1);
-  mlir::ArrayAttr attrs = builder.getArrayAttr(createOperandConstraints());
-  ttir::TTIROp op = builder.create(builder.getUnknownLoc(), lhs, rhs, dest, attrs);
+  ttir::TTIROp op =
builder.create(builder.getUnknownLoc(), lhs, rhs, dest); std::vector ops; ops.push_back(op); @@ -212,12 +195,10 @@ TEST_F(SchedulerBase, VerifyFork) { // Create the second operation which works on the result of the first // operation and arg1 dest = createEmptyTensor(); - op = builder.create(builder.getUnknownLoc(), lhs, rhs, dest, - attrs); + op = builder.create(builder.getUnknownLoc(), lhs, rhs, dest); ops.push_back(op); dest = createEmptyTensor(); - op = builder.create(builder.getUnknownLoc(), lhs, rhs, dest, - attrs); + op = builder.create(builder.getUnknownLoc(), lhs, rhs, dest); ops.push_back(op); // Create the third operation which works on the result of the second and @@ -225,8 +206,7 @@ TEST_F(SchedulerBase, VerifyFork) { lhs = ops[ops.size() - 2].getOperation()->getResult(0); rhs = ops[ops.size() - 1].getOperation()->getResult(0); dest = createEmptyTensor(); - op = builder.create(builder.getUnknownLoc(), lhs, rhs, dest, - attrs); + op = builder.create(builder.getUnknownLoc(), lhs, rhs, dest); ops.push_back(op); mlir::tt::scheduler::Scheduler scheduler(&func); diff --git a/tools/explorer/test/models/forward_and_backward.mlir b/tools/explorer/test/models/forward_and_backward.mlir index 3f0b8f781..e205bcf2b 100644 --- a/tools/explorer/test/models/forward_and_backward.mlir +++ b/tools/explorer/test/models/forward_and_backward.mlir @@ -1,30 +1,30 @@ module @SimpleModel attributes {} { func.func @forward(%arg0: tensor<1x784xf32> {ttir.name = "input_1"}, %arg1: tensor<10x784xf32> {ttir.name = "linear.weight"}, %arg2: tensor<10xf32> {ttir.name = "linear.bias"}) -> (tensor<1x10xf32> {ttir.name = "SimpleModel_472.output_softmax_1495"}) { %0 = tensor.empty() : tensor<784x10xf32> - %1 = "ttir.transpose"(%arg1, %0) <{dim0 = -2 : si32, dim1 = -1 : si32, operand_constraints = [#tt.operand_constraint, #tt.operand_constraint]}> : (tensor<10x784xf32>, tensor<784x10xf32>) -> tensor<784x10xf32> + %1 = "ttir.transpose"(%arg1, %0) <{dim0 = -2 : si32, dim1 = -1 : si32}> : (tensor<10x784xf32>, tensor<784x10xf32>) -> tensor<784x10xf32> %2 = tensor.empty() : tensor<1x10xf32> - %3 = "ttir.matmul"(%arg0, %1, %2) <{operand_constraints = [#tt.operand_constraint, #tt.operand_constraint, #tt.operand_constraint]}> : (tensor<1x784xf32>, tensor<784x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> + %3 = "ttir.matmul"(%arg0, %1, %2) : (tensor<1x784xf32>, tensor<784x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> %4 = tensor.empty() : tensor<1x10xf32> - %5 = "ttir.add"(%3, %arg2, %4) <{operandSegmentSizes = array, operand_constraints = [#tt.operand_constraint, #tt.operand_constraint, #tt.operand_constraint]}> : (tensor<1x10xf32>, tensor<10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> + %5 = "ttir.add"(%3, %arg2, %4) <{operandSegmentSizes = array}> : (tensor<1x10xf32>, tensor<10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> %6 = tensor.empty() : tensor<1x10xf32> - %7 = "ttir.softmax"(%5, %6) <{dimension = -1 : si32, operand_constraints = [#tt.operand_constraint, #tt.operand_constraint]}> : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> + %7 = "ttir.softmax"(%5, %6) <{dimension = -1 : si32}> : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> return %7 : tensor<1x10xf32> } func.func @backward(%arg0: tensor<1x10xf32> {ttir.name = "loss_SimpleModel_472.output_softmax_1495"}, %arg1: tensor<1x10xf32> {ttir.name = "SimpleModel_472.output_softmax_1495"}, %arg2: tensor<1x784xf32> {ttir.name = "input_1"}) -> (tensor<1x10xf32> {ttir.name = "grad_acc_linear.bias_grad_accumulator"}, tensor<10x784xf32> {ttir.name = 
"grad_acc_linear.weight_grad_accumulator"}) { %0 = tensor.empty() : tensor<1x10xf32> - %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array, operand_constraints = [#tt.operand_constraint, #tt.operand_constraint, #tt.operand_constraint]}> : (tensor<1x10xf32>, tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> + %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<1x10xf32>, tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> %2 = tensor.empty() : tensor<1x1xf32> - %3 = "ttir.sum"(%1, %2) <{keep_dim = true, operand_constraints = [#tt.operand_constraint, #tt.operand_constraint]}> : (tensor<1x10xf32>, tensor<1x1xf32>) -> tensor<1x1xf32> + %3 = "ttir.sum"(%1, %2) <{keep_dim = true}> : (tensor<1x10xf32>, tensor<1x1xf32>) -> tensor<1x1xf32> %4 = tensor.empty() : tensor<1x10xf32> - %5 = "ttir.subtract"(%arg0, %3, %4) <{operandSegmentSizes = array, operand_constraints = [#tt.operand_constraint, #tt.operand_constraint, #tt.operand_constraint]}> : (tensor<1x10xf32>, tensor<1x1xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> + %5 = "ttir.subtract"(%arg0, %3, %4) <{operandSegmentSizes = array}> : (tensor<1x10xf32>, tensor<1x1xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> %6 = tensor.empty() : tensor<1x10xf32> - %7 = "ttir.multiply"(%5, %arg1, %6) <{operandSegmentSizes = array, operand_constraints = [#tt.operand_constraint, #tt.operand_constraint, #tt.operand_constraint, #tt.operand_constraint]}> : (tensor<1x10xf32>, tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> + %7 = "ttir.multiply"(%5, %arg1, %6) <{operandSegmentSizes = array}> : (tensor<1x10xf32>, tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32> %8 = tensor.empty() : tensor<784x1xf32> - %9 = "ttir.transpose"(%arg2, %8) <{dim0 = -2 : si32, dim1 = -1 : si32, operand_constraints = [#tt.operand_constraint, #tt.operand_constraint]}> : (tensor<1x784xf32>, tensor<784x1xf32>) -> tensor<784x1xf32> + %9 = "ttir.transpose"(%arg2, %8) <{dim0 = -2 : si32, dim1 = -1 : si32}> : (tensor<1x784xf32>, tensor<784x1xf32>) -> tensor<784x1xf32> %10 = tensor.empty() : tensor<784x10xf32> - %11 = "ttir.matmul"(%9, %7, %10) <{operand_constraints = [#tt.operand_constraint, #tt.operand_constraint, #tt.operand_constraint]}> : (tensor<784x1xf32>, tensor<1x10xf32>, tensor<784x10xf32>) -> tensor<784x10xf32> + %11 = "ttir.matmul"(%9, %7, %10) : (tensor<784x1xf32>, tensor<1x10xf32>, tensor<784x10xf32>) -> tensor<784x10xf32> %12 = tensor.empty() : tensor<10x784xf32> - %13 = "ttir.transpose"(%11, %12) <{dim0 = -2 : si32, dim1 = -1 : si32, operand_constraints = [#tt.operand_constraint, #tt.operand_constraint]}> : (tensor<784x10xf32>, tensor<10x784xf32>) -> tensor<10x784xf32> + %13 = "ttir.transpose"(%11, %12) <{dim0 = -2 : si32, dim1 = -1 : si32}> : (tensor<784x10xf32>, tensor<10x784xf32>) -> tensor<10x784xf32> return %7, %13 : tensor<1x10xf32>, tensor<10x784xf32> } } diff --git a/tools/explorer/test/models/linear_autoencoder.mlir b/tools/explorer/test/models/linear_autoencoder.mlir index 8d7defc53..d8af25bbf 100644 --- a/tools/explorer/test/models/linear_autoencoder.mlir +++ b/tools/explorer/test/models/linear_autoencoder.mlir @@ -1,49 +1,49 @@ module @LinearAE attributes {} { func.func @forward(%arg0: tensor<1x784xf32> {ttir.name = "input_1"}, %arg1: tensor<784x128xf32> {ttir.name = "encoder_lin1.weight"}, %arg2: tensor<128xf32> {ttir.name = "encoder_lin1.bias"}, %arg3: tensor<128x64xf32> {ttir.name = "encoder_lin2.weight"}, %arg4: tensor<64xf32> {ttir.name = "encoder_lin2.bias"}, %arg5: tensor<64x12xf32> {ttir.name = 
"encoder_lin3.weight"}, %arg6: tensor<12xf32> {ttir.name = "encoder_lin3.bias"}, %arg7: tensor<12x3xf32> {ttir.name = "encoder_lin4.weight"}, %arg8: tensor<3xf32> {ttir.name = "encoder_lin4.bias"}, %arg9: tensor<3x12xf32> {ttir.name = "decoder_lin1.weight"}, %arg10: tensor<12xf32> {ttir.name = "decoder_lin1.bias"}, %arg11: tensor<12x64xf32> {ttir.name = "decoder_lin2.weight"}, %arg12: tensor<64xf32> {ttir.name = "decoder_lin2.bias"}, %arg13: tensor<64x128xf32> {ttir.name = "decoder_lin3.weight"}, %arg14: tensor<128xf32> {ttir.name = "decoder_lin3.bias"}, %arg15: tensor<128x784xf32> {ttir.name = "decoder_lin4.weight"}, %arg16: tensor<784xf32> {ttir.name = "decoder_lin4.bias"}) -> (tensor<1x784xf32> {ttir.name = "LinearAE.output_add_29"}) { %0 = tensor.empty() : tensor<1x128xf32> - %1 = "ttir.matmul"(%arg0, %arg1, %0) <{operand_constraints = [#tt.operand_constraint, #tt.operand_constraint, #tt.operand_constraint]}> : (tensor<1x784xf32>, tensor<784x128xf32>, tensor<1x128xf32>) -> tensor<1x128xf32> + %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<1x784xf32>, tensor<784x128xf32>, tensor<1x128xf32>) -> tensor<1x128xf32> %2 = tensor.empty() : tensor<1x128xf32> - %3 = "ttir.add"(%1, %arg2, %2) <{operandSegmentSizes = array, operand_constraints = [#tt.operand_constraint, #tt.operand_constraint, #tt.operand_constraint]}> : (tensor<1x128xf32>, tensor<128xf32>, tensor<1x128xf32>) -> tensor<1x128xf32> + %3 = "ttir.add"(%1, %arg2, %2) <{operandSegmentSizes = array}> : (tensor<1x128xf32>, tensor<128xf32>, tensor<1x128xf32>) -> tensor<1x128xf32> %4 = tensor.empty() : tensor<1x128xf32> - %5 = "ttir.relu"(%3, %4) <{operandSegmentSizes = array, operand_constraints = [#tt.operand_constraint, #tt.operand_constraint]}> : (tensor<1x128xf32>, tensor<1x128xf32>) -> tensor<1x128xf32> + %5 = "ttir.relu"(%3, %4) <{operandSegmentSizes = array}> : (tensor<1x128xf32>, tensor<1x128xf32>) -> tensor<1x128xf32> %6 = tensor.empty() : tensor<1x64xf32> - %7 = "ttir.matmul"(%5, %arg3, %6) <{operand_constraints = [#tt.operand_constraint, #tt.operand_constraint, #tt.operand_constraint]}> : (tensor<1x128xf32>, tensor<128x64xf32>, tensor<1x64xf32>) -> tensor<1x64xf32> + %7 = "ttir.matmul"(%5, %arg3, %6) : (tensor<1x128xf32>, tensor<128x64xf32>, tensor<1x64xf32>) -> tensor<1x64xf32> %8 = tensor.empty() : tensor<1x64xf32> - %9 = "ttir.add"(%7, %arg4, %8) <{operandSegmentSizes = array, operand_constraints = [#tt.operand_constraint, #tt.operand_constraint, #tt.operand_constraint]}> : (tensor<1x64xf32>, tensor<64xf32>, tensor<1x64xf32>) -> tensor<1x64xf32> + %9 = "ttir.add"(%7, %arg4, %8) <{operandSegmentSizes = array}> : (tensor<1x64xf32>, tensor<64xf32>, tensor<1x64xf32>) -> tensor<1x64xf32> %10 = tensor.empty() : tensor<1x64xf32> - %11 = "ttir.relu"(%9, %10) <{operandSegmentSizes = array, operand_constraints = [#tt.operand_constraint, #tt.operand_constraint]}> : (tensor<1x64xf32>, tensor<1x64xf32>) -> tensor<1x64xf32> + %11 = "ttir.relu"(%9, %10) <{operandSegmentSizes = array}> : (tensor<1x64xf32>, tensor<1x64xf32>) -> tensor<1x64xf32> %12 = tensor.empty() : tensor<1x12xf32> - %13 = "ttir.matmul"(%11, %arg5, %12) <{operand_constraints = [#tt.operand_constraint, #tt.operand_constraint, #tt.operand_constraint]}> : (tensor<1x64xf32>, tensor<64x12xf32>, tensor<1x12xf32>) -> tensor<1x12xf32> + %13 = "ttir.matmul"(%11, %arg5, %12) : (tensor<1x64xf32>, tensor<64x12xf32>, tensor<1x12xf32>) -> tensor<1x12xf32> %14 = tensor.empty() : tensor<1x12xf32> - %15 = "ttir.add"(%13, %arg6, %14) <{operandSegmentSizes = array, operand_constraints = 
[#tt.operand_constraint, #tt.operand_constraint, #tt.operand_constraint]}> : (tensor<1x12xf32>, tensor<12xf32>, tensor<1x12xf32>) -> tensor<1x12xf32> + %15 = "ttir.add"(%13, %arg6, %14) <{operandSegmentSizes = array}> : (tensor<1x12xf32>, tensor<12xf32>, tensor<1x12xf32>) -> tensor<1x12xf32> %16 = tensor.empty() : tensor<1x12xf32> - %17 = "ttir.relu"(%15, %16) <{operandSegmentSizes = array, operand_constraints = [#tt.operand_constraint, #tt.operand_constraint]}> : (tensor<1x12xf32>, tensor<1x12xf32>) -> tensor<1x12xf32> + %17 = "ttir.relu"(%15, %16) <{operandSegmentSizes = array}> : (tensor<1x12xf32>, tensor<1x12xf32>) -> tensor<1x12xf32> %18 = tensor.empty() : tensor<1x3xf32> - %19 = "ttir.matmul"(%17, %arg7, %18) <{operand_constraints = [#tt.operand_constraint, #tt.operand_constraint, #tt.operand_constraint]}> : (tensor<1x12xf32>, tensor<12x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> + %19 = "ttir.matmul"(%17, %arg7, %18) : (tensor<1x12xf32>, tensor<12x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> %20 = tensor.empty() : tensor<1x3xf32> - %21 = "ttir.add"(%19, %arg8, %20) <{operandSegmentSizes = array, operand_constraints = [#tt.operand_constraint, #tt.operand_constraint, #tt.operand_constraint]}> : (tensor<1x3xf32>, tensor<3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> + %21 = "ttir.add"(%19, %arg8, %20) <{operandSegmentSizes = array}> : (tensor<1x3xf32>, tensor<3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> %22 = tensor.empty() : tensor<1x12xf32> - %23 = "ttir.matmul"(%21, %arg9, %22) <{operand_constraints = [#tt.operand_constraint, #tt.operand_constraint, #tt.operand_constraint]}> : (tensor<1x3xf32>, tensor<3x12xf32>, tensor<1x12xf32>) -> tensor<1x12xf32> + %23 = "ttir.matmul"(%21, %arg9, %22) : (tensor<1x3xf32>, tensor<3x12xf32>, tensor<1x12xf32>) -> tensor<1x12xf32> %24 = tensor.empty() : tensor<1x12xf32> - %25 = "ttir.add"(%23, %arg10, %24) <{operandSegmentSizes = array, operand_constraints = [#tt.operand_constraint, #tt.operand_constraint, #tt.operand_constraint]}> : (tensor<1x12xf32>, tensor<12xf32>, tensor<1x12xf32>) -> tensor<1x12xf32> + %25 = "ttir.add"(%23, %arg10, %24) <{operandSegmentSizes = array}> : (tensor<1x12xf32>, tensor<12xf32>, tensor<1x12xf32>) -> tensor<1x12xf32> %26 = tensor.empty() : tensor<1x12xf32> - %27 = "ttir.relu"(%25, %26) <{operandSegmentSizes = array, operand_constraints = [#tt.operand_constraint, #tt.operand_constraint]}> : (tensor<1x12xf32>, tensor<1x12xf32>) -> tensor<1x12xf32> + %27 = "ttir.relu"(%25, %26) <{operandSegmentSizes = array}> : (tensor<1x12xf32>, tensor<1x12xf32>) -> tensor<1x12xf32> %28 = tensor.empty() : tensor<1x64xf32> - %29 = "ttir.matmul"(%27, %arg11, %28) <{operand_constraints = [#tt.operand_constraint, #tt.operand_constraint, #tt.operand_constraint]}> : (tensor<1x12xf32>, tensor<12x64xf32>, tensor<1x64xf32>) -> tensor<1x64xf32> + %29 = "ttir.matmul"(%27, %arg11, %28) : (tensor<1x12xf32>, tensor<12x64xf32>, tensor<1x64xf32>) -> tensor<1x64xf32> %30 = tensor.empty() : tensor<1x64xf32> - %31 = "ttir.add"(%29, %arg12, %30) <{operandSegmentSizes = array, operand_constraints = [#tt.operand_constraint, #tt.operand_constraint, #tt.operand_constraint]}> : (tensor<1x64xf32>, tensor<64xf32>, tensor<1x64xf32>) -> tensor<1x64xf32> + %31 = "ttir.add"(%29, %arg12, %30) <{operandSegmentSizes = array}> : (tensor<1x64xf32>, tensor<64xf32>, tensor<1x64xf32>) -> tensor<1x64xf32> %32 = tensor.empty() : tensor<1x64xf32> - %33 = "ttir.relu"(%31, %32) <{operandSegmentSizes = array, operand_constraints = [#tt.operand_constraint, #tt.operand_constraint]}> : 
(tensor<1x64xf32>, tensor<1x64xf32>) -> tensor<1x64xf32> + %33 = "ttir.relu"(%31, %32) <{operandSegmentSizes = array}> : (tensor<1x64xf32>, tensor<1x64xf32>) -> tensor<1x64xf32> %34 = tensor.empty() : tensor<1x128xf32> - %35 = "ttir.matmul"(%33, %arg13, %34) <{operand_constraints = [#tt.operand_constraint, #tt.operand_constraint, #tt.operand_constraint]}> : (tensor<1x64xf32>, tensor<64x128xf32>, tensor<1x128xf32>) -> tensor<1x128xf32> + %35 = "ttir.matmul"(%33, %arg13, %34) : (tensor<1x64xf32>, tensor<64x128xf32>, tensor<1x128xf32>) -> tensor<1x128xf32> %36 = tensor.empty() : tensor<1x128xf32> - %37 = "ttir.add"(%35, %arg14, %36) <{operandSegmentSizes = array, operand_constraints = [#tt.operand_constraint, #tt.operand_constraint, #tt.operand_constraint]}> : (tensor<1x128xf32>, tensor<128xf32>, tensor<1x128xf32>) -> tensor<1x128xf32> + %37 = "ttir.add"(%35, %arg14, %36) <{operandSegmentSizes = array}> : (tensor<1x128xf32>, tensor<128xf32>, tensor<1x128xf32>) -> tensor<1x128xf32> %38 = tensor.empty() : tensor<1x128xf32> - %39 = "ttir.relu"(%37, %38) <{operandSegmentSizes = array, operand_constraints = [#tt.operand_constraint, #tt.operand_constraint]}> : (tensor<1x128xf32>, tensor<1x128xf32>) -> tensor<1x128xf32> + %39 = "ttir.relu"(%37, %38) <{operandSegmentSizes = array}> : (tensor<1x128xf32>, tensor<1x128xf32>) -> tensor<1x128xf32> %40 = tensor.empty() : tensor<1x784xf32> - %41 = "ttir.matmul"(%39, %arg15, %40) <{operand_constraints = [#tt.operand_constraint, #tt.operand_constraint, #tt.operand_constraint]}> : (tensor<1x128xf32>, tensor<128x784xf32>, tensor<1x784xf32>) -> tensor<1x784xf32> + %41 = "ttir.matmul"(%39, %arg15, %40) : (tensor<1x128xf32>, tensor<128x784xf32>, tensor<1x784xf32>) -> tensor<1x784xf32> %42 = tensor.empty() : tensor<1x784xf32> - %43 = "ttir.add"(%41, %arg16, %42) <{operandSegmentSizes = array, operand_constraints = [#tt.operand_constraint, #tt.operand_constraint, #tt.operand_constraint]}> : (tensor<1x784xf32>, tensor<784xf32>, tensor<1x784xf32>) -> tensor<1x784xf32> + %43 = "ttir.add"(%41, %arg16, %42) <{operandSegmentSizes = array}> : (tensor<1x784xf32>, tensor<784xf32>, tensor<1x784xf32>) -> tensor<1x784xf32> return %43 : tensor<1x784xf32> } } diff --git a/tools/explorer/test/models/open_llama_3b_single_layer.mlir b/tools/explorer/test/models/open_llama_3b_single_layer.mlir index 5e17dc39e..97731870b 100644 --- a/tools/explorer/test/models/open_llama_3b_single_layer.mlir +++ b/tools/explorer/test/models/open_llama_3b_single_layer.mlir @@ -1,170 +1,169 @@ -#any_device = #tt.operand_constraint #loc = loc("LlamaForCausalLM":0:0) #system_desc = #tt.system_desc<[{role = host, target_triple = "x86_64-pc-linux-gnu"}], [{arch = , grid = 8x8, l1_size = 1499136, num_dram_channels = 12, dram_channel_size = 1073741824, noc_l1_address_align_bytes = 16, pcie_address_align_bytes = 32, noc_dram_address_align_bytes = 32, l1_unreserved_base = 1024, erisc_l1_unreserved_base = 1024, dram_unreserved_base = 1024, dram_unreserved_end = 1073741824, physical_cores = {worker = [ 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 1x0, 1x1, 1x2, 1x3, 1x4, 1x5, 1x6, 1x7, 2x0, 2x1, 2x2, 2x3, 2x4, 2x5, 2x6, 2x7, 3x0, 3x1, 3x2, 3x3, 3x4, 3x5, 3x6, 3x7, 4x0, 4x1, 4x2, 4x3, 4x4, 4x5, 4x6, 4x7, 5x0, 5x1, 5x2, 5x3, 5x4, 5x5, 5x6, 5x7, 6x0, 6x1, 6x2, 6x3, 6x4, 6x5, 6x6, 6x7, 7x0, 7x1, 7x2, 7x3, 7x4, 7x5, 7x6, 7x7] dram = [ 8x0, 9x0, 10x0, 8x1, 9x1, 10x1, 8x2, 9x2, 10x2, 8x3, 9x3, 10x3]}, supported_data_types = [, , , , , , , , , , , ], supported_tile_sizes = [ 4x16, 16x16, 32x16, 4x32, 16x32, 32x32], num_cbs = 
32}], [0], [3 : i32], [ 0x0x0x0]> module @LlamaForCausalLM attributes {tt.system_desc = #system_desc} { func.func @forward(%arg0: tensor<1x12xi32> {ttir.name = "input_1"} loc("LlamaForCausalLM":0:0), %arg1: tensor<1xf32> {ttir.name = "input_1_add_4"} loc("LlamaForCausalLM":0:0), %arg2: tensor<1x12x50xf32> {ttir.name = "input_0_unsqueeze_14"} loc("LlamaForCausalLM":0:0), %arg3: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_24.1"} loc("LlamaForCausalLM":0:0), %arg4: tensor<1xf32> {ttir.name = "input_1_multiply_25"} loc("LlamaForCausalLM":0:0), %arg5: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_26.1"} loc("LlamaForCausalLM":0:0), %arg6: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_38.1"} loc("LlamaForCausalLM":0:0), %arg7: tensor<1xf32> {ttir.name = "input_1_multiply_39"} loc("LlamaForCausalLM":0:0), %arg8: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_40.1"} loc("LlamaForCausalLM":0:0), %arg9: tensor<1xf32> {ttir.name = "input_1_multiply_48"} loc("LlamaForCausalLM":0:0), %arg10: tensor<1x1x12x12xf32> {ttir.name = "input_1_add_49"} loc("LlamaForCausalLM":0:0), %arg11: tensor<1xf32> {ttir.name = "input_1_add_70"} loc("LlamaForCausalLM":0:0), %arg12: tensor<1xf32> {ttir.name = "input_1_add_90"} loc("LlamaForCausalLM":0:0), %arg13: tensor<1x12x50xf32> {ttir.name = "input_0_unsqueeze_100"} loc("LlamaForCausalLM":0:0), %arg14: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_110.1"} loc("LlamaForCausalLM":0:0), %arg15: tensor<1xf32> {ttir.name = "input_1_multiply_111"} loc("LlamaForCausalLM":0:0), %arg16: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_112.1"} loc("LlamaForCausalLM":0:0), %arg17: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_124.1"} loc("LlamaForCausalLM":0:0), %arg18: tensor<1xf32> {ttir.name = "input_1_multiply_125"} loc("LlamaForCausalLM":0:0), %arg19: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_126.1"} loc("LlamaForCausalLM":0:0), %arg20: tensor<1xf32> {ttir.name = "input_1_multiply_134"} loc("LlamaForCausalLM":0:0), %arg21: tensor<1x1x12x12xf32> {ttir.name = "input_1_add_135"} loc("LlamaForCausalLM":0:0), %arg22: tensor<1xf32> {ttir.name = "input_1_add_156"} loc("LlamaForCausalLM":0:0), %arg23: tensor<1xf32> {ttir.name = "input_1_add_176"} loc("LlamaForCausalLM":0:0), %arg24: tensor<1x12x50xf32> {ttir.name = "input_0_unsqueeze_186"} loc("LlamaForCausalLM":0:0), %arg25: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_196.1"} loc("LlamaForCausalLM":0:0), %arg26: tensor<1xf32> {ttir.name = "input_1_multiply_197"} loc("LlamaForCausalLM":0:0), %arg27: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_198.1"} loc("LlamaForCausalLM":0:0), %arg28: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_210.1"} loc("LlamaForCausalLM":0:0), %arg29: tensor<1xf32> {ttir.name = "input_1_multiply_211"} loc("LlamaForCausalLM":0:0), %arg30: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_212.1"} loc("LlamaForCausalLM":0:0), %arg31: tensor<1xf32> {ttir.name = "input_1_multiply_220"} loc("LlamaForCausalLM":0:0), %arg32: tensor<1x1x12x12xf32> {ttir.name = "input_1_add_221"} loc("LlamaForCausalLM":0:0), %arg33: tensor<1xf32> {ttir.name = "input_1_add_242"} loc("LlamaForCausalLM":0:0), %arg34: tensor<1xf32> {ttir.name = "input_1_add_262"} loc("LlamaForCausalLM":0:0), %arg35: tensor<1x12x50xf32> {ttir.name = "input_0_unsqueeze_272"} loc("LlamaForCausalLM":0:0), %arg36: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_282.1"} 
loc("LlamaForCausalLM":0:0), %arg37: tensor<1xf32> {ttir.name = "input_1_multiply_283"} loc("LlamaForCausalLM":0:0), %arg38: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_284.1"} loc("LlamaForCausalLM":0:0), %arg39: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_296.1"} loc("LlamaForCausalLM":0:0), %arg40: tensor<1xf32> {ttir.name = "input_1_multiply_297"} loc("LlamaForCausalLM":0:0), %arg41: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_298.1"} loc("LlamaForCausalLM":0:0), %arg42: tensor<1xf32> {ttir.name = "input_1_multiply_306"} loc("LlamaForCausalLM":0:0), %arg43: tensor<1x1x12x12xf32> {ttir.name = "input_1_add_307"} loc("LlamaForCausalLM":0:0), %arg44: tensor<1xf32> {ttir.name = "input_1_add_328"} loc("LlamaForCausalLM":0:0), %arg45: tensor<1xf32> {ttir.name = "input_1_add_348"} loc("LlamaForCausalLM":0:0), %arg46: tensor<1x12x50xf32> {ttir.name = "input_0_unsqueeze_358"} loc("LlamaForCausalLM":0:0), %arg47: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_368.1"} loc("LlamaForCausalLM":0:0), %arg48: tensor<1xf32> {ttir.name = "input_1_multiply_369"} loc("LlamaForCausalLM":0:0), %arg49: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_370.1"} loc("LlamaForCausalLM":0:0), %arg50: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_382.1"} loc("LlamaForCausalLM":0:0), %arg51: tensor<1xf32> {ttir.name = "input_1_multiply_383"} loc("LlamaForCausalLM":0:0), %arg52: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_384.1"} loc("LlamaForCausalLM":0:0), %arg53: tensor<1xf32> {ttir.name = "input_1_multiply_392"} loc("LlamaForCausalLM":0:0), %arg54: tensor<1x1x12x12xf32> {ttir.name = "input_1_add_393"} loc("LlamaForCausalLM":0:0), %arg55: tensor<1xf32> {ttir.name = "input_1_add_414"} loc("LlamaForCausalLM":0:0), %arg56: tensor<1xf32> {ttir.name = "input_1_add_434"} loc("LlamaForCausalLM":0:0), %arg57: tensor<1x12x50xf32> {ttir.name = "input_0_unsqueeze_444"} loc("LlamaForCausalLM":0:0), %arg58: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_454.1"} loc("LlamaForCausalLM":0:0), %arg59: tensor<1xf32> {ttir.name = "input_1_multiply_455"} loc("LlamaForCausalLM":0:0), %arg60: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_456.1"} loc("LlamaForCausalLM":0:0), %arg61: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_468.1"} loc("LlamaForCausalLM":0:0), %arg62: tensor<1xf32> {ttir.name = "input_1_multiply_469"} loc("LlamaForCausalLM":0:0), %arg63: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_470.1"} loc("LlamaForCausalLM":0:0), %arg64: tensor<1xf32> {ttir.name = "input_1_multiply_478"} loc("LlamaForCausalLM":0:0), %arg65: tensor<1x1x12x12xf32> {ttir.name = "input_1_add_479"} loc("LlamaForCausalLM":0:0), %arg66: tensor<1xf32> {ttir.name = "input_1_add_500"} loc("LlamaForCausalLM":0:0), %arg67: tensor<1xf32> {ttir.name = "input_1_add_520"} loc("LlamaForCausalLM":0:0), %arg68: tensor<1x12x50xf32> {ttir.name = "input_0_unsqueeze_530"} loc("LlamaForCausalLM":0:0), %arg69: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_540.1"} loc("LlamaForCausalLM":0:0), %arg70: tensor<1xf32> {ttir.name = "input_1_multiply_541"} loc("LlamaForCausalLM":0:0), %arg71: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_542.1"} loc("LlamaForCausalLM":0:0), %arg72: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_554.1"} loc("LlamaForCausalLM":0:0), %arg73: tensor<1xf32> {ttir.name = "input_1_multiply_555"} loc("LlamaForCausalLM":0:0), %arg74: tensor<1x32x50x100xf32> 
{ttir.name = "dc.input_tensor.index_556.1"} loc("LlamaForCausalLM":0:0), %arg75: tensor<1xf32> {ttir.name = "input_1_multiply_564"} loc("LlamaForCausalLM":0:0), %arg76: tensor<1x1x12x12xf32> {ttir.name = "input_1_add_565"} loc("LlamaForCausalLM":0:0), %arg77: tensor<1xf32> {ttir.name = "input_1_add_586"} loc("LlamaForCausalLM":0:0), %arg78: tensor<1xf32> {ttir.name = "input_1_add_606"} loc("LlamaForCausalLM":0:0), %arg79: tensor<1x12x50xf32> {ttir.name = "input_0_unsqueeze_616"} loc("LlamaForCausalLM":0:0), %arg80: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_626.1"} loc("LlamaForCausalLM":0:0), %arg81: tensor<1xf32> {ttir.name = "input_1_multiply_627"} loc("LlamaForCausalLM":0:0), %arg82: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_628.1"} loc("LlamaForCausalLM":0:0), %arg83: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_640.1"} loc("LlamaForCausalLM":0:0), %arg84: tensor<1xf32> {ttir.name = "input_1_multiply_641"} loc("LlamaForCausalLM":0:0), %arg85: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_642.1"} loc("LlamaForCausalLM":0:0), %arg86: tensor<1xf32> {ttir.name = "input_1_multiply_650"} loc("LlamaForCausalLM":0:0), %arg87: tensor<1x1x12x12xf32> {ttir.name = "input_1_add_651"} loc("LlamaForCausalLM":0:0), %arg88: tensor<1xf32> {ttir.name = "input_1_add_672"} loc("LlamaForCausalLM":0:0), %arg89: tensor<1xf32> {ttir.name = "input_1_add_692"} loc("LlamaForCausalLM":0:0), %arg90: tensor<1x12x50xf32> {ttir.name = "input_0_unsqueeze_702"} loc("LlamaForCausalLM":0:0), %arg91: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_712.1"} loc("LlamaForCausalLM":0:0), %arg92: tensor<1xf32> {ttir.name = "input_1_multiply_713"} loc("LlamaForCausalLM":0:0), %arg93: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_714.1"} loc("LlamaForCausalLM":0:0), %arg94: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_726.1"} loc("LlamaForCausalLM":0:0), %arg95: tensor<1xf32> {ttir.name = "input_1_multiply_727"} loc("LlamaForCausalLM":0:0), %arg96: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_728.1"} loc("LlamaForCausalLM":0:0), %arg97: tensor<1xf32> {ttir.name = "input_1_multiply_736"} loc("LlamaForCausalLM":0:0), %arg98: tensor<1x1x12x12xf32> {ttir.name = "input_1_add_737"} loc("LlamaForCausalLM":0:0), %arg99: tensor<1xf32> {ttir.name = "input_1_add_758"} loc("LlamaForCausalLM":0:0), %arg100: tensor<1xf32> {ttir.name = "input_1_add_778"} loc("LlamaForCausalLM":0:0), %arg101: tensor<1x12x50xf32> {ttir.name = "input_0_unsqueeze_788"} loc("LlamaForCausalLM":0:0), %arg102: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_798.1"} loc("LlamaForCausalLM":0:0), %arg103: tensor<1xf32> {ttir.name = "input_1_multiply_799"} loc("LlamaForCausalLM":0:0), %arg104: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_800.1"} loc("LlamaForCausalLM":0:0), %arg105: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_812.1"} loc("LlamaForCausalLM":0:0), %arg106: tensor<1xf32> {ttir.name = "input_1_multiply_813"} loc("LlamaForCausalLM":0:0), %arg107: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_814.1"} loc("LlamaForCausalLM":0:0), %arg108: tensor<1xf32> {ttir.name = "input_1_multiply_822"} loc("LlamaForCausalLM":0:0), %arg109: tensor<1x1x12x12xf32> {ttir.name = "input_1_add_823"} loc("LlamaForCausalLM":0:0), %arg110: tensor<1xf32> {ttir.name = "input_1_add_844"} loc("LlamaForCausalLM":0:0), %arg111: tensor<1xf32> {ttir.name = "input_1_add_864"} loc("LlamaForCausalLM":0:0), %arg112: 
tensor<1x12x50xf32> {ttir.name = "input_0_unsqueeze_874"} loc("LlamaForCausalLM":0:0), %arg113: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_884.1"} loc("LlamaForCausalLM":0:0), %arg114: tensor<1xf32> {ttir.name = "input_1_multiply_885"} loc("LlamaForCausalLM":0:0), %arg115: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_886.1"} loc("LlamaForCausalLM":0:0), %arg116: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_898.1"} loc("LlamaForCausalLM":0:0), %arg117: tensor<1xf32> {ttir.name = "input_1_multiply_899"} loc("LlamaForCausalLM":0:0), %arg118: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_900.1"} loc("LlamaForCausalLM":0:0), %arg119: tensor<1xf32> {ttir.name = "input_1_multiply_908"} loc("LlamaForCausalLM":0:0), %arg120: tensor<1x1x12x12xf32> {ttir.name = "input_1_add_909"} loc("LlamaForCausalLM":0:0), %arg121: tensor<1xf32> {ttir.name = "input_1_add_930"} loc("LlamaForCausalLM":0:0), %arg122: tensor<1xf32> {ttir.name = "input_1_add_950"} loc("LlamaForCausalLM":0:0), %arg123: tensor<1x12x50xf32> {ttir.name = "input_0_unsqueeze_960"} loc("LlamaForCausalLM":0:0), %arg124: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_970.1"} loc("LlamaForCausalLM":0:0), %arg125: tensor<1xf32> {ttir.name = "input_1_multiply_971"} loc("LlamaForCausalLM":0:0), %arg126: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_972.1"} loc("LlamaForCausalLM":0:0), %arg127: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_984.1"} loc("LlamaForCausalLM":0:0), %arg128: tensor<1xf32> {ttir.name = "input_1_multiply_985"} loc("LlamaForCausalLM":0:0), %arg129: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_986.1"} loc("LlamaForCausalLM":0:0), %arg130: tensor<1xf32> {ttir.name = "input_1_multiply_994"} loc("LlamaForCausalLM":0:0), %arg131: tensor<1x1x12x12xf32> {ttir.name = "input_1_add_995"} loc("LlamaForCausalLM":0:0), %arg132: tensor<1xf32> {ttir.name = "input_1_add_1016"} loc("LlamaForCausalLM":0:0), %arg133: tensor<1xf32> {ttir.name = "input_1_add_1036"} loc("LlamaForCausalLM":0:0), %arg134: tensor<1x12x50xf32> {ttir.name = "input_0_unsqueeze_1046"} loc("LlamaForCausalLM":0:0), %arg135: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_1056.1"} loc("LlamaForCausalLM":0:0), %arg136: tensor<1xf32> {ttir.name = "input_1_multiply_1057"} loc("LlamaForCausalLM":0:0), %arg137: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_1058.1"} loc("LlamaForCausalLM":0:0), %arg138: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_1070.1"} loc("LlamaForCausalLM":0:0), %arg139: tensor<1xf32> {ttir.name = "input_1_multiply_1071"} loc("LlamaForCausalLM":0:0), %arg140: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_1072.1"} loc("LlamaForCausalLM":0:0), %arg141: tensor<1xf32> {ttir.name = "input_1_multiply_1080"} loc("LlamaForCausalLM":0:0), %arg142: tensor<1x1x12x12xf32> {ttir.name = "input_1_add_1081"} loc("LlamaForCausalLM":0:0), %arg143: tensor<1xf32> {ttir.name = "input_1_add_1102"} loc("LlamaForCausalLM":0:0), %arg144: tensor<1xf32> {ttir.name = "input_1_add_1122"} loc("LlamaForCausalLM":0:0), %arg145: tensor<1x12x50xf32> {ttir.name = "input_0_unsqueeze_1132"} loc("LlamaForCausalLM":0:0), %arg146: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_1142.1"} loc("LlamaForCausalLM":0:0), %arg147: tensor<1xf32> {ttir.name = "input_1_multiply_1143"} loc("LlamaForCausalLM":0:0), %arg148: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_1144.1"} loc("LlamaForCausalLM":0:0), %arg149: 
tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_1156.1"} loc("LlamaForCausalLM":0:0), %arg150: tensor<1xf32> {ttir.name = "input_1_multiply_1157"} loc("LlamaForCausalLM":0:0), %arg151: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_1158.1"} loc("LlamaForCausalLM":0:0), %arg152: tensor<1xf32> {ttir.name = "input_1_multiply_1166"} loc("LlamaForCausalLM":0:0), %arg153: tensor<1x1x12x12xf32> {ttir.name = "input_1_add_1167"} loc("LlamaForCausalLM":0:0), %arg154: tensor<1xf32> {ttir.name = "input_1_add_1188"} loc("LlamaForCausalLM":0:0), %arg155: tensor<1xf32> {ttir.name = "input_1_add_1208"} loc("LlamaForCausalLM":0:0), %arg156: tensor<1x12x50xf32> {ttir.name = "input_0_unsqueeze_1218"} loc("LlamaForCausalLM":0:0), %arg157: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_1228.1"} loc("LlamaForCausalLM":0:0), %arg158: tensor<1xf32> {ttir.name = "input_1_multiply_1229"} loc("LlamaForCausalLM":0:0), %arg159: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_1230.1"} loc("LlamaForCausalLM":0:0), %arg160: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_1242.1"} loc("LlamaForCausalLM":0:0), %arg161: tensor<1xf32> {ttir.name = "input_1_multiply_1243"} loc("LlamaForCausalLM":0:0), %arg162: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_1244.1"} loc("LlamaForCausalLM":0:0), %arg163: tensor<1xf32> {ttir.name = "input_1_multiply_1252"} loc("LlamaForCausalLM":0:0), %arg164: tensor<1x1x12x12xf32> {ttir.name = "input_1_add_1253"} loc("LlamaForCausalLM":0:0), %arg165: tensor<1xf32> {ttir.name = "input_1_add_1274"} loc("LlamaForCausalLM":0:0), %arg166: tensor<1xf32> {ttir.name = "input_1_add_1294"} loc("LlamaForCausalLM":0:0), %arg167: tensor<1x12x50xf32> {ttir.name = "input_0_unsqueeze_1304"} loc("LlamaForCausalLM":0:0), %arg168: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_1314.1"} loc("LlamaForCausalLM":0:0), %arg169: tensor<1xf32> {ttir.name = "input_1_multiply_1315"} loc("LlamaForCausalLM":0:0), %arg170: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_1316.1"} loc("LlamaForCausalLM":0:0), %arg171: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_1328.1"} loc("LlamaForCausalLM":0:0), %arg172: tensor<1xf32> {ttir.name = "input_1_multiply_1329"} loc("LlamaForCausalLM":0:0), %arg173: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_1330.1"} loc("LlamaForCausalLM":0:0), %arg174: tensor<1xf32> {ttir.name = "input_1_multiply_1338"} loc("LlamaForCausalLM":0:0), %arg175: tensor<1x1x12x12xf32> {ttir.name = "input_1_add_1339"} loc("LlamaForCausalLM":0:0), %arg176: tensor<1xf32> {ttir.name = "input_1_add_1360"} loc("LlamaForCausalLM":0:0), %arg177: tensor<1xf32> {ttir.name = "input_1_add_1380"} loc("LlamaForCausalLM":0:0), %arg178: tensor<1x12x50xf32> {ttir.name = "input_0_unsqueeze_1390"} loc("LlamaForCausalLM":0:0), %arg179: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_1400.1"} loc("LlamaForCausalLM":0:0), %arg180: tensor<1xf32> {ttir.name = "input_1_multiply_1401"} loc("LlamaForCausalLM":0:0), %arg181: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_1402.1"} loc("LlamaForCausalLM":0:0), %arg182: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_1414.1"} loc("LlamaForCausalLM":0:0), %arg183: tensor<1xf32> {ttir.name = "input_1_multiply_1415"} loc("LlamaForCausalLM":0:0), %arg184: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_1416.1"} loc("LlamaForCausalLM":0:0), %arg185: tensor<1xf32> {ttir.name = "input_1_multiply_1424"} 
loc("LlamaForCausalLM":0:0), %arg186: tensor<1x1x12x12xf32> {ttir.name = "input_1_add_1425"} loc("LlamaForCausalLM":0:0), %arg187: tensor<1xf32> {ttir.name = "input_1_add_1446"} loc("LlamaForCausalLM":0:0), %arg188: tensor<1xf32> {ttir.name = "input_1_add_1466"} loc("LlamaForCausalLM":0:0), %arg189: tensor<1x12x50xf32> {ttir.name = "input_0_unsqueeze_1476"} loc("LlamaForCausalLM":0:0), %arg190: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_1486.1"} loc("LlamaForCausalLM":0:0), %arg191: tensor<1xf32> {ttir.name = "input_1_multiply_1487"} loc("LlamaForCausalLM":0:0), %arg192: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_1488.1"} loc("LlamaForCausalLM":0:0), %arg193: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_1500.1"} loc("LlamaForCausalLM":0:0), %arg194: tensor<1xf32> {ttir.name = "input_1_multiply_1501"} loc("LlamaForCausalLM":0:0), %arg195: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_1502.1"} loc("LlamaForCausalLM":0:0), %arg196: tensor<1xf32> {ttir.name = "input_1_multiply_1510"} loc("LlamaForCausalLM":0:0), %arg197: tensor<1x1x12x12xf32> {ttir.name = "input_1_add_1511"} loc("LlamaForCausalLM":0:0), %arg198: tensor<1xf32> {ttir.name = "input_1_add_1532"} loc("LlamaForCausalLM":0:0), %arg199: tensor<1xf32> {ttir.name = "input_1_add_1552"} loc("LlamaForCausalLM":0:0), %arg200: tensor<1x12x50xf32> {ttir.name = "input_0_unsqueeze_1562"} loc("LlamaForCausalLM":0:0), %arg201: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_1572.1"} loc("LlamaForCausalLM":0:0), %arg202: tensor<1xf32> {ttir.name = "input_1_multiply_1573"} loc("LlamaForCausalLM":0:0), %arg203: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_1574.1"} loc("LlamaForCausalLM":0:0), %arg204: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_1586.1"} loc("LlamaForCausalLM":0:0), %arg205: tensor<1xf32> {ttir.name = "input_1_multiply_1587"} loc("LlamaForCausalLM":0:0), %arg206: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_1588.1"} loc("LlamaForCausalLM":0:0), %arg207: tensor<1xf32> {ttir.name = "input_1_multiply_1596"} loc("LlamaForCausalLM":0:0), %arg208: tensor<1x1x12x12xf32> {ttir.name = "input_1_add_1597"} loc("LlamaForCausalLM":0:0), %arg209: tensor<1xf32> {ttir.name = "input_1_add_1618"} loc("LlamaForCausalLM":0:0), %arg210: tensor<1xf32> {ttir.name = "input_1_add_1638"} loc("LlamaForCausalLM":0:0), %arg211: tensor<1x12x50xf32> {ttir.name = "input_0_unsqueeze_1648"} loc("LlamaForCausalLM":0:0), %arg212: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_1658.1"} loc("LlamaForCausalLM":0:0), %arg213: tensor<1xf32> {ttir.name = "input_1_multiply_1659"} loc("LlamaForCausalLM":0:0), %arg214: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_1660.1"} loc("LlamaForCausalLM":0:0), %arg215: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_1672.1"} loc("LlamaForCausalLM":0:0), %arg216: tensor<1xf32> {ttir.name = "input_1_multiply_1673"} loc("LlamaForCausalLM":0:0), %arg217: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_1674.1"} loc("LlamaForCausalLM":0:0), %arg218: tensor<1xf32> {ttir.name = "input_1_multiply_1682"} loc("LlamaForCausalLM":0:0), %arg219: tensor<1x1x12x12xf32> {ttir.name = "input_1_add_1683"} loc("LlamaForCausalLM":0:0), %arg220: tensor<1xf32> {ttir.name = "input_1_add_1704"} loc("LlamaForCausalLM":0:0), %arg221: tensor<1xf32> {ttir.name = "input_1_add_1724"} loc("LlamaForCausalLM":0:0), %arg222: tensor<1x12x50xf32> {ttir.name = "input_0_unsqueeze_1734"} 
loc("LlamaForCausalLM":0:0), %arg223: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_1744.1"} loc("LlamaForCausalLM":0:0), %arg224: tensor<1xf32> {ttir.name = "input_1_multiply_1745"} loc("LlamaForCausalLM":0:0), %arg225: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_1746.1"} loc("LlamaForCausalLM":0:0), %arg226: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_1758.1"} loc("LlamaForCausalLM":0:0), %arg227: tensor<1xf32> {ttir.name = "input_1_multiply_1759"} loc("LlamaForCausalLM":0:0), %arg228: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_1760.1"} loc("LlamaForCausalLM":0:0), %arg229: tensor<1xf32> {ttir.name = "input_1_multiply_1768"} loc("LlamaForCausalLM":0:0), %arg230: tensor<1x1x12x12xf32> {ttir.name = "input_1_add_1769"} loc("LlamaForCausalLM":0:0), %arg231: tensor<1xf32> {ttir.name = "input_1_add_1790"} loc("LlamaForCausalLM":0:0), %arg232: tensor<1xf32> {ttir.name = "input_1_add_1810"} loc("LlamaForCausalLM":0:0), %arg233: tensor<1x12x50xf32> {ttir.name = "input_0_unsqueeze_1820"} loc("LlamaForCausalLM":0:0), %arg234: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_1830.1"} loc("LlamaForCausalLM":0:0), %arg235: tensor<1xf32> {ttir.name = "input_1_multiply_1831"} loc("LlamaForCausalLM":0:0), %arg236: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_1832.1"} loc("LlamaForCausalLM":0:0), %arg237: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_1844.1"} loc("LlamaForCausalLM":0:0), %arg238: tensor<1xf32> {ttir.name = "input_1_multiply_1845"} loc("LlamaForCausalLM":0:0), %arg239: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_1846.1"} loc("LlamaForCausalLM":0:0), %arg240: tensor<1xf32> {ttir.name = "input_1_multiply_1854"} loc("LlamaForCausalLM":0:0), %arg241: tensor<1x1x12x12xf32> {ttir.name = "input_1_add_1855"} loc("LlamaForCausalLM":0:0), %arg242: tensor<1xf32> {ttir.name = "input_1_add_1876"} loc("LlamaForCausalLM":0:0), %arg243: tensor<1xf32> {ttir.name = "input_1_add_1896"} loc("LlamaForCausalLM":0:0), %arg244: tensor<1x12x50xf32> {ttir.name = "input_0_unsqueeze_1906"} loc("LlamaForCausalLM":0:0), %arg245: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_1916.1"} loc("LlamaForCausalLM":0:0), %arg246: tensor<1xf32> {ttir.name = "input_1_multiply_1917"} loc("LlamaForCausalLM":0:0), %arg247: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_1918.1"} loc("LlamaForCausalLM":0:0), %arg248: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_1930.1"} loc("LlamaForCausalLM":0:0), %arg249: tensor<1xf32> {ttir.name = "input_1_multiply_1931"} loc("LlamaForCausalLM":0:0), %arg250: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_1932.1"} loc("LlamaForCausalLM":0:0), %arg251: tensor<1xf32> {ttir.name = "input_1_multiply_1940"} loc("LlamaForCausalLM":0:0), %arg252: tensor<1x1x12x12xf32> {ttir.name = "input_1_add_1941"} loc("LlamaForCausalLM":0:0), %arg253: tensor<1xf32> {ttir.name = "input_1_add_1962"} loc("LlamaForCausalLM":0:0), %arg254: tensor<1xf32> {ttir.name = "input_1_add_1982"} loc("LlamaForCausalLM":0:0), %arg255: tensor<1x12x50xf32> {ttir.name = "input_0_unsqueeze_1992"} loc("LlamaForCausalLM":0:0), %arg256: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_2002.1"} loc("LlamaForCausalLM":0:0), %arg257: tensor<1xf32> {ttir.name = "input_1_multiply_2003"} loc("LlamaForCausalLM":0:0), %arg258: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_2004.1"} loc("LlamaForCausalLM":0:0), %arg259: tensor<1x32x50x100xf32> {ttir.name = 
"dc.input_tensor.index_2016.1"} loc("LlamaForCausalLM":0:0), %arg260: tensor<1xf32> {ttir.name = "input_1_multiply_2017"} loc("LlamaForCausalLM":0:0), %arg261: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_2018.1"} loc("LlamaForCausalLM":0:0), %arg262: tensor<1xf32> {ttir.name = "input_1_multiply_2026"} loc("LlamaForCausalLM":0:0), %arg263: tensor<1x1x12x12xf32> {ttir.name = "input_1_add_2027"} loc("LlamaForCausalLM":0:0), %arg264: tensor<1xf32> {ttir.name = "input_1_add_2048"} loc("LlamaForCausalLM":0:0), %arg265: tensor<1xf32> {ttir.name = "input_1_add_2068"} loc("LlamaForCausalLM":0:0), %arg266: tensor<1x12x50xf32> {ttir.name = "input_0_unsqueeze_2078"} loc("LlamaForCausalLM":0:0), %arg267: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_2088.1"} loc("LlamaForCausalLM":0:0), %arg268: tensor<1xf32> {ttir.name = "input_1_multiply_2089"} loc("LlamaForCausalLM":0:0), %arg269: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_2090.1"} loc("LlamaForCausalLM":0:0), %arg270: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_2102.1"} loc("LlamaForCausalLM":0:0), %arg271: tensor<1xf32> {ttir.name = "input_1_multiply_2103"} loc("LlamaForCausalLM":0:0), %arg272: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_2104.1"} loc("LlamaForCausalLM":0:0), %arg273: tensor<1xf32> {ttir.name = "input_1_multiply_2112"} loc("LlamaForCausalLM":0:0), %arg274: tensor<1x1x12x12xf32> {ttir.name = "input_1_add_2113"} loc("LlamaForCausalLM":0:0), %arg275: tensor<1xf32> {ttir.name = "input_1_add_2134"} loc("LlamaForCausalLM":0:0), %arg276: tensor<1xf32> {ttir.name = "input_1_add_2154"} loc("LlamaForCausalLM":0:0), %arg277: tensor<1x12x50xf32> {ttir.name = "input_0_unsqueeze_2164"} loc("LlamaForCausalLM":0:0), %arg278: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_2174.1"} loc("LlamaForCausalLM":0:0), %arg279: tensor<1xf32> {ttir.name = "input_1_multiply_2175"} loc("LlamaForCausalLM":0:0), %arg280: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_2176.1"} loc("LlamaForCausalLM":0:0), %arg281: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_2188.1"} loc("LlamaForCausalLM":0:0), %arg282: tensor<1xf32> {ttir.name = "input_1_multiply_2189"} loc("LlamaForCausalLM":0:0), %arg283: tensor<1x32x50x100xf32> {ttir.name = "dc.input_tensor.index_2190.1"} loc("LlamaForCausalLM":0:0), %arg284: tensor<1xf32> {ttir.name = "input_1_multiply_2198"} loc("LlamaForCausalLM":0:0), %arg285: tensor<1x1x12x12xf32> {ttir.name = "input_1_add_2199"} loc("LlamaForCausalLM":0:0), %arg286: tensor<1xf32> {ttir.name = "input_1_add_2220"} loc("LlamaForCausalLM":0:0), %arg287: tensor<1xf32> {ttir.name = "input_1_add_2240"} loc("LlamaForCausalLM":0:0), %arg288: tensor<3200xf32> {ttir.name = "model.norm.weight"} loc("LlamaForCausalLM":0:0), %arg289: tensor<32000x3200xf32> {ttir.name = "model.embed_tokens.weight"} loc("LlamaForCausalLM":0:0), %arg290: tensor<3200xf32> {ttir.name = "model.layers.0.input_layernorm.weight"} loc("LlamaForCausalLM":0:0), %arg291: tensor<3200x3200xf32> {ttir.name = "model.layers.0.self_attn.q_proj.weight"} loc("LlamaForCausalLM":0:0), %arg292: tensor<3200x3200xf32> {ttir.name = "model.layers.0.self_attn.k_proj.weight"} loc("LlamaForCausalLM":0:0), %arg293: tensor<3200x3200xf32> {ttir.name = "model.layers.0.self_attn.v_proj.weight"} loc("LlamaForCausalLM":0:0), %arg294: tensor<3200x3200xf32> {ttir.name = "model.layers.0.self_attn.o_proj.weight"} loc("LlamaForCausalLM":0:0), %arg295: tensor<3200xf32> {ttir.name = 
"model.layers.0.post_attention_layernorm.weight"} loc("LlamaForCausalLM":0:0), %arg296: tensor<3200x8640xf32> {ttir.name = "model.layers.0.mlp.gate_proj.weight"} loc("LlamaForCausalLM":0:0), %arg297: tensor<3200x8640xf32> {ttir.name = "model.layers.0.mlp.up_proj.weight"} loc("LlamaForCausalLM":0:0), %arg298: tensor<8640x3200xf32> {ttir.name = "model.layers.0.mlp.down_proj.weight"} loc("LlamaForCausalLM":0:0), %arg299: tensor<3200xf32> {ttir.name = "model.layers.1.input_layernorm.weight"} loc("LlamaForCausalLM":0:0), %arg300: tensor<3200x3200xf32> {ttir.name = "model.layers.1.self_attn.q_proj.weight"} loc("LlamaForCausalLM":0:0), %arg301: tensor<3200x3200xf32> {ttir.name = "model.layers.1.self_attn.k_proj.weight"} loc("LlamaForCausalLM":0:0), %arg302: tensor<3200x3200xf32> {ttir.name = "model.layers.1.self_attn.v_proj.weight"} loc("LlamaForCausalLM":0:0), %arg303: tensor<3200x3200xf32> {ttir.name = "model.layers.1.self_attn.o_proj.weight"} loc("LlamaForCausalLM":0:0), %arg304: tensor<3200xf32> {ttir.name = "model.layers.1.post_attention_layernorm.weight"} loc("LlamaForCausalLM":0:0), %arg305: tensor<3200x8640xf32> {ttir.name = "model.layers.1.mlp.gate_proj.weight"} loc("LlamaForCausalLM":0:0), %arg306: tensor<3200x8640xf32> {ttir.name = "model.layers.1.mlp.up_proj.weight"} loc("LlamaForCausalLM":0:0), %arg307: tensor<8640x3200xf32> {ttir.name = "model.layers.1.mlp.down_proj.weight"} loc("LlamaForCausalLM":0:0), %arg308: tensor<3200xf32> {ttir.name = "model.layers.2.input_layernorm.weight"} loc("LlamaForCausalLM":0:0), %arg309: tensor<3200x3200xf32> {ttir.name = "model.layers.2.self_attn.q_proj.weight"} loc("LlamaForCausalLM":0:0), %arg310: tensor<3200x3200xf32> {ttir.name = "model.layers.2.self_attn.k_proj.weight"} loc("LlamaForCausalLM":0:0), %arg311: tensor<3200x3200xf32> {ttir.name = "model.layers.2.self_attn.v_proj.weight"} loc("LlamaForCausalLM":0:0), %arg312: tensor<3200x3200xf32> {ttir.name = "model.layers.2.self_attn.o_proj.weight"} loc("LlamaForCausalLM":0:0), %arg313: tensor<3200xf32> {ttir.name = "model.layers.2.post_attention_layernorm.weight"} loc("LlamaForCausalLM":0:0), %arg314: tensor<3200x8640xf32> {ttir.name = "model.layers.2.mlp.gate_proj.weight"} loc("LlamaForCausalLM":0:0), %arg315: tensor<3200x8640xf32> {ttir.name = "model.layers.2.mlp.up_proj.weight"} loc("LlamaForCausalLM":0:0), %arg316: tensor<8640x3200xf32> {ttir.name = "model.layers.2.mlp.down_proj.weight"} loc("LlamaForCausalLM":0:0), %arg317: tensor<3200xf32> {ttir.name = "model.layers.3.input_layernorm.weight"} loc("LlamaForCausalLM":0:0), %arg318: tensor<3200x3200xf32> {ttir.name = "model.layers.3.self_attn.q_proj.weight"} loc("LlamaForCausalLM":0:0), %arg319: tensor<3200x3200xf32> {ttir.name = "model.layers.3.self_attn.k_proj.weight"} loc("LlamaForCausalLM":0:0), %arg320: tensor<3200x3200xf32> {ttir.name = "model.layers.3.self_attn.v_proj.weight"} loc("LlamaForCausalLM":0:0), %arg321: tensor<3200x3200xf32> {ttir.name = "model.layers.3.self_attn.o_proj.weight"} loc("LlamaForCausalLM":0:0), %arg322: tensor<3200xf32> {ttir.name = "model.layers.3.post_attention_layernorm.weight"} loc("LlamaForCausalLM":0:0), %arg323: tensor<3200x8640xf32> {ttir.name = "model.layers.3.mlp.gate_proj.weight"} loc("LlamaForCausalLM":0:0), %arg324: tensor<3200x8640xf32> {ttir.name = "model.layers.3.mlp.up_proj.weight"} loc("LlamaForCausalLM":0:0), %arg325: tensor<8640x3200xf32> {ttir.name = "model.layers.3.mlp.down_proj.weight"} loc("LlamaForCausalLM":0:0), %arg326: tensor<3200xf32> {ttir.name = 
"model.layers.4.input_layernorm.weight"} loc("LlamaForCausalLM":0:0), %arg327: tensor<3200x3200xf32> {ttir.name = "model.layers.4.self_attn.q_proj.weight"} loc("LlamaForCausalLM":0:0), %arg328: tensor<3200x3200xf32> {ttir.name = "model.layers.4.self_attn.k_proj.weight"} loc("LlamaForCausalLM":0:0), %arg329: tensor<3200x3200xf32> {ttir.name = "model.layers.4.self_attn.v_proj.weight"} loc("LlamaForCausalLM":0:0), %arg330: tensor<3200x3200xf32> {ttir.name = "model.layers.4.self_attn.o_proj.weight"} loc("LlamaForCausalLM":0:0), %arg331: tensor<3200xf32> {ttir.name = "model.layers.4.post_attention_layernorm.weight"} loc("LlamaForCausalLM":0:0), %arg332: tensor<3200x8640xf32> {ttir.name = "model.layers.4.mlp.gate_proj.weight"} loc("LlamaForCausalLM":0:0), %arg333: tensor<3200x8640xf32> {ttir.name = "model.layers.4.mlp.up_proj.weight"} loc("LlamaForCausalLM":0:0), %arg334: tensor<8640x3200xf32> {ttir.name = "model.layers.4.mlp.down_proj.weight"} loc("LlamaForCausalLM":0:0), %arg335: tensor<3200xf32> {ttir.name = "model.layers.5.input_layernorm.weight"} loc("LlamaForCausalLM":0:0), %arg336: tensor<3200x3200xf32> {ttir.name = "model.layers.5.self_attn.q_proj.weight"} loc("LlamaForCausalLM":0:0), %arg337: tensor<3200x3200xf32> {ttir.name = "model.layers.5.self_attn.k_proj.weight"} loc("LlamaForCausalLM":0:0), %arg338: tensor<3200x3200xf32> {ttir.name = "model.layers.5.self_attn.v_proj.weight"} loc("LlamaForCausalLM":0:0), %arg339: tensor<3200x3200xf32> {ttir.name = "model.layers.5.self_attn.o_proj.weight"} loc("LlamaForCausalLM":0:0), %arg340: tensor<3200xf32> {ttir.name = "model.layers.5.post_attention_layernorm.weight"} loc("LlamaForCausalLM":0:0), %arg341: tensor<3200x8640xf32> {ttir.name = "model.layers.5.mlp.gate_proj.weight"} loc("LlamaForCausalLM":0:0), %arg342: tensor<3200x8640xf32> {ttir.name = "model.layers.5.mlp.up_proj.weight"} loc("LlamaForCausalLM":0:0), %arg343: tensor<8640x3200xf32> {ttir.name = "model.layers.5.mlp.down_proj.weight"} loc("LlamaForCausalLM":0:0), %arg344: tensor<3200xf32> {ttir.name = "model.layers.6.input_layernorm.weight"} loc("LlamaForCausalLM":0:0), %arg345: tensor<3200x3200xf32> {ttir.name = "model.layers.6.self_attn.q_proj.weight"} loc("LlamaForCausalLM":0:0), %arg346: tensor<3200x3200xf32> {ttir.name = "model.layers.6.self_attn.k_proj.weight"} loc("LlamaForCausalLM":0:0), %arg347: tensor<3200x3200xf32> {ttir.name = "model.layers.6.self_attn.v_proj.weight"} loc("LlamaForCausalLM":0:0), %arg348: tensor<3200x3200xf32> {ttir.name = "model.layers.6.self_attn.o_proj.weight"} loc("LlamaForCausalLM":0:0), %arg349: tensor<3200xf32> {ttir.name = "model.layers.6.post_attention_layernorm.weight"} loc("LlamaForCausalLM":0:0), %arg350: tensor<3200x8640xf32> {ttir.name = "model.layers.6.mlp.gate_proj.weight"} loc("LlamaForCausalLM":0:0), %arg351: tensor<3200x8640xf32> {ttir.name = "model.layers.6.mlp.up_proj.weight"} loc("LlamaForCausalLM":0:0), %arg352: tensor<8640x3200xf32> {ttir.name = "model.layers.6.mlp.down_proj.weight"} loc("LlamaForCausalLM":0:0), %arg353: tensor<3200xf32> {ttir.name = "model.layers.7.input_layernorm.weight"} loc("LlamaForCausalLM":0:0), %arg354: tensor<3200x3200xf32> {ttir.name = "model.layers.7.self_attn.q_proj.weight"} loc("LlamaForCausalLM":0:0), %arg355: tensor<3200x3200xf32> {ttir.name = "model.layers.7.self_attn.k_proj.weight"} loc("LlamaForCausalLM":0:0), %arg356: tensor<3200x3200xf32> {ttir.name = "model.layers.7.self_attn.v_proj.weight"} loc("LlamaForCausalLM":0:0), %arg357: tensor<3200x3200xf32> {ttir.name = 
"model.layers.7.self_attn.o_proj.weight"} loc("LlamaForCausalLM":0:0), %arg358: tensor<3200xf32> {ttir.name = "model.layers.7.post_attention_layernorm.weight"} loc("LlamaForCausalLM":0:0), %arg359: tensor<3200x8640xf32> {ttir.name = "model.layers.7.mlp.gate_proj.weight"} loc("LlamaForCausalLM":0:0), %arg360: tensor<3200x8640xf32> {ttir.name = "model.layers.7.mlp.up_proj.weight"} loc("LlamaForCausalLM":0:0), %arg361: tensor<8640x3200xf32> {ttir.name = "model.layers.7.mlp.down_proj.weight"} loc("LlamaForCausalLM":0:0), %arg362: tensor<3200xf32> {ttir.name = "model.layers.8.input_layernorm.weight"} loc("LlamaForCausalLM":0:0), %arg363: tensor<3200x3200xf32> {ttir.name = "model.layers.8.self_attn.q_proj.weight"} loc("LlamaForCausalLM":0:0), %arg364: tensor<3200x3200xf32> {ttir.name = "model.layers.8.self_attn.k_proj.weight"} loc("LlamaForCausalLM":0:0), %arg365: tensor<3200x3200xf32> {ttir.name = "model.layers.8.self_attn.v_proj.weight"} loc("LlamaForCausalLM":0:0), %arg366: tensor<3200x3200xf32> {ttir.name = "model.layers.8.self_attn.o_proj.weight"} loc("LlamaForCausalLM":0:0), %arg367: tensor<3200xf32> {ttir.name = "model.layers.8.post_attention_layernorm.weight"} loc("LlamaForCausalLM":0:0), %arg368: tensor<3200x8640xf32> {ttir.name = "model.layers.8.mlp.gate_proj.weight"} loc("LlamaForCausalLM":0:0), %arg369: tensor<3200x8640xf32> {ttir.name = "model.layers.8.mlp.up_proj.weight"} loc("LlamaForCausalLM":0:0), %arg370: tensor<8640x3200xf32> {ttir.name = "model.layers.8.mlp.down_proj.weight"} loc("LlamaForCausalLM":0:0), %arg371: tensor<3200xf32> {ttir.name = "model.layers.9.input_layernorm.weight"} loc("LlamaForCausalLM":0:0), %arg372: tensor<3200x3200xf32> {ttir.name = "model.layers.9.self_attn.q_proj.weight"} loc("LlamaForCausalLM":0:0), %arg373: tensor<3200x3200xf32> {ttir.name = "model.layers.9.self_attn.k_proj.weight"} loc("LlamaForCausalLM":0:0), %arg374: tensor<3200x3200xf32> {ttir.name = "model.layers.9.self_attn.v_proj.weight"} loc("LlamaForCausalLM":0:0), %arg375: tensor<3200x3200xf32> {ttir.name = "model.layers.9.self_attn.o_proj.weight"} loc("LlamaForCausalLM":0:0), %arg376: tensor<3200xf32> {ttir.name = "model.layers.9.post_attention_layernorm.weight"} loc("LlamaForCausalLM":0:0), %arg377: tensor<3200x8640xf32> {ttir.name = "model.layers.9.mlp.gate_proj.weight"} loc("LlamaForCausalLM":0:0), %arg378: tensor<3200x8640xf32> {ttir.name = "model.layers.9.mlp.up_proj.weight"} loc("LlamaForCausalLM":0:0), %arg379: tensor<8640x3200xf32> {ttir.name = "model.layers.9.mlp.down_proj.weight"} loc("LlamaForCausalLM":0:0), %arg380: tensor<3200xf32> {ttir.name = "model.layers.10.input_layernorm.weight"} loc("LlamaForCausalLM":0:0), %arg381: tensor<3200x3200xf32> {ttir.name = "model.layers.10.self_attn.q_proj.weight"} loc("LlamaForCausalLM":0:0), %arg382: tensor<3200x3200xf32> {ttir.name = "model.layers.10.self_attn.k_proj.weight"} loc("LlamaForCausalLM":0:0), %arg383: tensor<3200x3200xf32> {ttir.name = "model.layers.10.self_attn.v_proj.weight"} loc("LlamaForCausalLM":0:0), %arg384: tensor<3200x3200xf32> {ttir.name = "model.layers.10.self_attn.o_proj.weight"} loc("LlamaForCausalLM":0:0), %arg385: tensor<3200xf32> {ttir.name = "model.layers.10.post_attention_layernorm.weight"} loc("LlamaForCausalLM":0:0), %arg386: tensor<3200x8640xf32> {ttir.name = "model.layers.10.mlp.gate_proj.weight"} loc("LlamaForCausalLM":0:0), %arg387: tensor<3200x8640xf32> {ttir.name = "model.layers.10.mlp.up_proj.weight"} loc("LlamaForCausalLM":0:0), %arg388: tensor<8640x3200xf32> {ttir.name = 
"model.layers.10.mlp.down_proj.weight"} loc("LlamaForCausalLM":0:0), %arg389: tensor<3200xf32> {ttir.name = "model.layers.11.input_layernorm.weight"} loc("LlamaForCausalLM":0:0), %arg390: tensor<3200x3200xf32> {ttir.name = "model.layers.11.self_attn.q_proj.weight"} loc("LlamaForCausalLM":0:0), %arg391: tensor<3200x3200xf32> {ttir.name = "model.layers.11.self_attn.k_proj.weight"} loc("LlamaForCausalLM":0:0), %arg392: tensor<3200x3200xf32> {ttir.name = "model.layers.11.self_attn.v_proj.weight"} loc("LlamaForCausalLM":0:0), %arg393: tensor<3200x3200xf32> {ttir.name = "model.layers.11.self_attn.o_proj.weight"} loc("LlamaForCausalLM":0:0), %arg394: tensor<3200xf32> {ttir.name = "model.layers.11.post_attention_layernorm.weight"} loc("LlamaForCausalLM":0:0), %arg395: tensor<3200x8640xf32> {ttir.name = "model.layers.11.mlp.gate_proj.weight"} loc("LlamaForCausalLM":0:0), %arg396: tensor<3200x8640xf32> {ttir.name = "model.layers.11.mlp.up_proj.weight"} loc("LlamaForCausalLM":0:0), %arg397: tensor<8640x3200xf32> {ttir.name = "model.layers.11.mlp.down_proj.weight"} loc("LlamaForCausalLM":0:0), %arg398: tensor<3200xf32> {ttir.name = "model.layers.12.input_layernorm.weight"} loc("LlamaForCausalLM":0:0), %arg399: tensor<3200x3200xf32> {ttir.name = "model.layers.12.self_attn.q_proj.weight"} loc("LlamaForCausalLM":0:0), %arg400: tensor<3200x3200xf32> {ttir.name = "model.layers.12.self_attn.k_proj.weight"} loc("LlamaForCausalLM":0:0), %arg401: tensor<3200x3200xf32> {ttir.name = "model.layers.12.self_attn.v_proj.weight"} loc("LlamaForCausalLM":0:0), %arg402: tensor<3200x3200xf32> {ttir.name = "model.layers.12.self_attn.o_proj.weight"} loc("LlamaForCausalLM":0:0), %arg403: tensor<3200xf32> {ttir.name = "model.layers.12.post_attention_layernorm.weight"} loc("LlamaForCausalLM":0:0), %arg404: tensor<3200x8640xf32> {ttir.name = "model.layers.12.mlp.gate_proj.weight"} loc("LlamaForCausalLM":0:0), %arg405: tensor<3200x8640xf32> {ttir.name = "model.layers.12.mlp.up_proj.weight"} loc("LlamaForCausalLM":0:0), %arg406: tensor<8640x3200xf32> {ttir.name = "model.layers.12.mlp.down_proj.weight"} loc("LlamaForCausalLM":0:0), %arg407: tensor<3200xf32> {ttir.name = "model.layers.13.input_layernorm.weight"} loc("LlamaForCausalLM":0:0), %arg408: tensor<3200x3200xf32> {ttir.name = "model.layers.13.self_attn.q_proj.weight"} loc("LlamaForCausalLM":0:0), %arg409: tensor<3200x3200xf32> {ttir.name = "model.layers.13.self_attn.k_proj.weight"} loc("LlamaForCausalLM":0:0), %arg410: tensor<3200x3200xf32> {ttir.name = "model.layers.13.self_attn.v_proj.weight"} loc("LlamaForCausalLM":0:0), %arg411: tensor<3200x3200xf32> {ttir.name = "model.layers.13.self_attn.o_proj.weight"} loc("LlamaForCausalLM":0:0), %arg412: tensor<3200xf32> {ttir.name = "model.layers.13.post_attention_layernorm.weight"} loc("LlamaForCausalLM":0:0), %arg413: tensor<3200x8640xf32> {ttir.name = "model.layers.13.mlp.gate_proj.weight"} loc("LlamaForCausalLM":0:0), %arg414: tensor<3200x8640xf32> {ttir.name = "model.layers.13.mlp.up_proj.weight"} loc("LlamaForCausalLM":0:0), %arg415: tensor<8640x3200xf32> {ttir.name = "model.layers.13.mlp.down_proj.weight"} loc("LlamaForCausalLM":0:0), %arg416: tensor<3200xf32> {ttir.name = "model.layers.14.input_layernorm.weight"} loc("LlamaForCausalLM":0:0), %arg417: tensor<3200x3200xf32> {ttir.name = "model.layers.14.self_attn.q_proj.weight"} loc("LlamaForCausalLM":0:0), %arg418: tensor<3200x3200xf32> {ttir.name = "model.layers.14.self_attn.k_proj.weight"} loc("LlamaForCausalLM":0:0), %arg419: tensor<3200x3200xf32> {ttir.name = 
"model.layers.14.self_attn.v_proj.weight"} loc("LlamaForCausalLM":0:0), %arg420: tensor<3200x3200xf32> {ttir.name = "model.layers.14.self_attn.o_proj.weight"} loc("LlamaForCausalLM":0:0), %arg421: tensor<3200xf32> {ttir.name = "model.layers.14.post_attention_layernorm.weight"} loc("LlamaForCausalLM":0:0), %arg422: tensor<3200x8640xf32> {ttir.name = "model.layers.14.mlp.gate_proj.weight"} loc("LlamaForCausalLM":0:0), %arg423: tensor<3200x8640xf32> {ttir.name = "model.layers.14.mlp.up_proj.weight"} loc("LlamaForCausalLM":0:0), %arg424: tensor<8640x3200xf32> {ttir.name = "model.layers.14.mlp.down_proj.weight"} loc("LlamaForCausalLM":0:0), %arg425: tensor<3200xf32> {ttir.name = "model.layers.15.input_layernorm.weight"} loc("LlamaForCausalLM":0:0), %arg426: tensor<3200x3200xf32> {ttir.name = "model.layers.15.self_attn.q_proj.weight"} loc("LlamaForCausalLM":0:0), %arg427: tensor<3200x3200xf32> {ttir.name = "model.layers.15.self_attn.k_proj.weight"} loc("LlamaForCausalLM":0:0), %arg428: tensor<3200x3200xf32> {ttir.name = "model.layers.15.self_attn.v_proj.weight"} loc("LlamaForCausalLM":0:0), %arg429: tensor<3200x3200xf32> {ttir.name = "model.layers.15.self_attn.o_proj.weight"} loc("LlamaForCausalLM":0:0), %arg430: tensor<3200xf32> {ttir.name = "model.layers.15.post_attention_layernorm.weight"} loc("LlamaForCausalLM":0:0), %arg431: tensor<3200x8640xf32> {ttir.name = "model.layers.15.mlp.gate_proj.weight"} loc("LlamaForCausalLM":0:0), %arg432: tensor<3200x8640xf32> {ttir.name = "model.layers.15.mlp.up_proj.weight"} loc("LlamaForCausalLM":0:0), %arg433: tensor<8640x3200xf32> {ttir.name = "model.layers.15.mlp.down_proj.weight"} loc("LlamaForCausalLM":0:0), %arg434: tensor<3200xf32> {ttir.name = "model.layers.16.input_layernorm.weight"} loc("LlamaForCausalLM":0:0), %arg435: tensor<3200x3200xf32> {ttir.name = "model.layers.16.self_attn.q_proj.weight"} loc("LlamaForCausalLM":0:0), %arg436: tensor<3200x3200xf32> {ttir.name = "model.layers.16.self_attn.k_proj.weight"} loc("LlamaForCausalLM":0:0), %arg437: tensor<3200x3200xf32> {ttir.name = "model.layers.16.self_attn.v_proj.weight"} loc("LlamaForCausalLM":0:0), %arg438: tensor<3200x3200xf32> {ttir.name = "model.layers.16.self_attn.o_proj.weight"} loc("LlamaForCausalLM":0:0), %arg439: tensor<3200xf32> {ttir.name = "model.layers.16.post_attention_layernorm.weight"} loc("LlamaForCausalLM":0:0), %arg440: tensor<3200x8640xf32> {ttir.name = "model.layers.16.mlp.gate_proj.weight"} loc("LlamaForCausalLM":0:0), %arg441: tensor<3200x8640xf32> {ttir.name = "model.layers.16.mlp.up_proj.weight"} loc("LlamaForCausalLM":0:0), %arg442: tensor<8640x3200xf32> {ttir.name = "model.layers.16.mlp.down_proj.weight"} loc("LlamaForCausalLM":0:0), %arg443: tensor<3200xf32> {ttir.name = "model.layers.17.input_layernorm.weight"} loc("LlamaForCausalLM":0:0), %arg444: tensor<3200x3200xf32> {ttir.name = "model.layers.17.self_attn.q_proj.weight"} loc("LlamaForCausalLM":0:0), %arg445: tensor<3200x3200xf32> {ttir.name = "model.layers.17.self_attn.k_proj.weight"} loc("LlamaForCausalLM":0:0), %arg446: tensor<3200x3200xf32> {ttir.name = "model.layers.17.self_attn.v_proj.weight"} loc("LlamaForCausalLM":0:0), %arg447: tensor<3200x3200xf32> {ttir.name = "model.layers.17.self_attn.o_proj.weight"} loc("LlamaForCausalLM":0:0), %arg448: tensor<3200xf32> {ttir.name = "model.layers.17.post_attention_layernorm.weight"} loc("LlamaForCausalLM":0:0), %arg449: tensor<3200x8640xf32> {ttir.name = "model.layers.17.mlp.gate_proj.weight"} loc("LlamaForCausalLM":0:0), %arg450: tensor<3200x8640xf32> {ttir.name = 
"model.layers.17.mlp.up_proj.weight"} loc("LlamaForCausalLM":0:0), %arg451: tensor<8640x3200xf32> {ttir.name = "model.layers.17.mlp.down_proj.weight"} loc("LlamaForCausalLM":0:0), %arg452: tensor<3200xf32> {ttir.name = "model.layers.18.input_layernorm.weight"} loc("LlamaForCausalLM":0:0), %arg453: tensor<3200x3200xf32> {ttir.name = "model.layers.18.self_attn.q_proj.weight"} loc("LlamaForCausalLM":0:0), %arg454: tensor<3200x3200xf32> {ttir.name = "model.layers.18.self_attn.k_proj.weight"} loc("LlamaForCausalLM":0:0), %arg455: tensor<3200x3200xf32> {ttir.name = "model.layers.18.self_attn.v_proj.weight"} loc("LlamaForCausalLM":0:0), %arg456: tensor<3200x3200xf32> {ttir.name = "model.layers.18.self_attn.o_proj.weight"} loc("LlamaForCausalLM":0:0), %arg457: tensor<3200xf32> {ttir.name = "model.layers.18.post_attention_layernorm.weight"} loc("LlamaForCausalLM":0:0), %arg458: tensor<3200x8640xf32> {ttir.name = "model.layers.18.mlp.gate_proj.weight"} loc("LlamaForCausalLM":0:0), %arg459: tensor<3200x8640xf32> {ttir.name = "model.layers.18.mlp.up_proj.weight"} loc("LlamaForCausalLM":0:0), %arg460: tensor<8640x3200xf32> {ttir.name = "model.layers.18.mlp.down_proj.weight"} loc("LlamaForCausalLM":0:0), %arg461: tensor<3200xf32> {ttir.name = "model.layers.19.input_layernorm.weight"} loc("LlamaForCausalLM":0:0), %arg462: tensor<3200x3200xf32> {ttir.name = "model.layers.19.self_attn.q_proj.weight"} loc("LlamaForCausalLM":0:0), %arg463: tensor<3200x3200xf32> {ttir.name = "model.layers.19.self_attn.k_proj.weight"} loc("LlamaForCausalLM":0:0), %arg464: tensor<3200x3200xf32> {ttir.name = "model.layers.19.self_attn.v_proj.weight"} loc("LlamaForCausalLM":0:0), %arg465: tensor<3200x3200xf32> {ttir.name = "model.layers.19.self_attn.o_proj.weight"} loc("LlamaForCausalLM":0:0), %arg466: tensor<3200xf32> {ttir.name = "model.layers.19.post_attention_layernorm.weight"} loc("LlamaForCausalLM":0:0), %arg467: tensor<3200x8640xf32> {ttir.name = "model.layers.19.mlp.gate_proj.weight"} loc("LlamaForCausalLM":0:0), %arg468: tensor<3200x8640xf32> {ttir.name = "model.layers.19.mlp.up_proj.weight"} loc("LlamaForCausalLM":0:0), %arg469: tensor<8640x3200xf32> {ttir.name = "model.layers.19.mlp.down_proj.weight"} loc("LlamaForCausalLM":0:0), %arg470: tensor<3200xf32> {ttir.name = "model.layers.20.input_layernorm.weight"} loc("LlamaForCausalLM":0:0), %arg471: tensor<3200x3200xf32> {ttir.name = "model.layers.20.self_attn.q_proj.weight"} loc("LlamaForCausalLM":0:0), %arg472: tensor<3200x3200xf32> {ttir.name = "model.layers.20.self_attn.k_proj.weight"} loc("LlamaForCausalLM":0:0), %arg473: tensor<3200x3200xf32> {ttir.name = "model.layers.20.self_attn.v_proj.weight"} loc("LlamaForCausalLM":0:0), %arg474: tensor<3200x3200xf32> {ttir.name = "model.layers.20.self_attn.o_proj.weight"} loc("LlamaForCausalLM":0:0), %arg475: tensor<3200xf32> {ttir.name = "model.layers.20.post_attention_layernorm.weight"} loc("LlamaForCausalLM":0:0), %arg476: tensor<3200x8640xf32> {ttir.name = "model.layers.20.mlp.gate_proj.weight"} loc("LlamaForCausalLM":0:0), %arg477: tensor<3200x8640xf32> {ttir.name = "model.layers.20.mlp.up_proj.weight"} loc("LlamaForCausalLM":0:0), %arg478: tensor<8640x3200xf32> {ttir.name = "model.layers.20.mlp.down_proj.weight"} loc("LlamaForCausalLM":0:0), %arg479: tensor<3200xf32> {ttir.name = "model.layers.21.input_layernorm.weight"} loc("LlamaForCausalLM":0:0), %arg480: tensor<3200x3200xf32> {ttir.name = "model.layers.21.self_attn.q_proj.weight"} loc("LlamaForCausalLM":0:0), %arg481: tensor<3200x3200xf32> {ttir.name = 
"model.layers.21.self_attn.k_proj.weight"} loc("LlamaForCausalLM":0:0), %arg482: tensor<3200x3200xf32> {ttir.name = "model.layers.21.self_attn.v_proj.weight"} loc("LlamaForCausalLM":0:0), %arg483: tensor<3200x3200xf32> {ttir.name = "model.layers.21.self_attn.o_proj.weight"} loc("LlamaForCausalLM":0:0), %arg484: tensor<3200xf32> {ttir.name = "model.layers.21.post_attention_layernorm.weight"} loc("LlamaForCausalLM":0:0), %arg485: tensor<3200x8640xf32> {ttir.name = "model.layers.21.mlp.gate_proj.weight"} loc("LlamaForCausalLM":0:0), %arg486: tensor<3200x8640xf32> {ttir.name = "model.layers.21.mlp.up_proj.weight"} loc("LlamaForCausalLM":0:0), %arg487: tensor<8640x3200xf32> {ttir.name = "model.layers.21.mlp.down_proj.weight"} loc("LlamaForCausalLM":0:0), %arg488: tensor<3200xf32> {ttir.name = "model.layers.22.input_layernorm.weight"} loc("LlamaForCausalLM":0:0), %arg489: tensor<3200x3200xf32> {ttir.name = "model.layers.22.self_attn.q_proj.weight"} loc("LlamaForCausalLM":0:0), %arg490: tensor<3200x3200xf32> {ttir.name = "model.layers.22.self_attn.k_proj.weight"} loc("LlamaForCausalLM":0:0), %arg491: tensor<3200x3200xf32> {ttir.name = "model.layers.22.self_attn.v_proj.weight"} loc("LlamaForCausalLM":0:0), %arg492: tensor<3200x3200xf32> {ttir.name = "model.layers.22.self_attn.o_proj.weight"} loc("LlamaForCausalLM":0:0), %arg493: tensor<3200xf32> {ttir.name = "model.layers.22.post_attention_layernorm.weight"} loc("LlamaForCausalLM":0:0), %arg494: tensor<3200x8640xf32> {ttir.name = "model.layers.22.mlp.gate_proj.weight"} loc("LlamaForCausalLM":0:0), %arg495: tensor<3200x8640xf32> {ttir.name = "model.layers.22.mlp.up_proj.weight"} loc("LlamaForCausalLM":0:0), %arg496: tensor<8640x3200xf32> {ttir.name = "model.layers.22.mlp.down_proj.weight"} loc("LlamaForCausalLM":0:0), %arg497: tensor<3200xf32> {ttir.name = "model.layers.23.input_layernorm.weight"} loc("LlamaForCausalLM":0:0), %arg498: tensor<3200x3200xf32> {ttir.name = "model.layers.23.self_attn.q_proj.weight"} loc("LlamaForCausalLM":0:0), %arg499: tensor<3200x3200xf32> {ttir.name = "model.layers.23.self_attn.k_proj.weight"} loc("LlamaForCausalLM":0:0), %arg500: tensor<3200x3200xf32> {ttir.name = "model.layers.23.self_attn.v_proj.weight"} loc("LlamaForCausalLM":0:0), %arg501: tensor<3200x3200xf32> {ttir.name = "model.layers.23.self_attn.o_proj.weight"} loc("LlamaForCausalLM":0:0), %arg502: tensor<3200xf32> {ttir.name = "model.layers.23.post_attention_layernorm.weight"} loc("LlamaForCausalLM":0:0), %arg503: tensor<3200x8640xf32> {ttir.name = "model.layers.23.mlp.gate_proj.weight"} loc("LlamaForCausalLM":0:0), %arg504: tensor<3200x8640xf32> {ttir.name = "model.layers.23.mlp.up_proj.weight"} loc("LlamaForCausalLM":0:0), %arg505: tensor<8640x3200xf32> {ttir.name = "model.layers.23.mlp.down_proj.weight"} loc("LlamaForCausalLM":0:0), %arg506: tensor<3200xf32> {ttir.name = "model.layers.24.input_layernorm.weight"} loc("LlamaForCausalLM":0:0), %arg507: tensor<3200x3200xf32> {ttir.name = "model.layers.24.self_attn.q_proj.weight"} loc("LlamaForCausalLM":0:0), %arg508: tensor<3200x3200xf32> {ttir.name = "model.layers.24.self_attn.k_proj.weight"} loc("LlamaForCausalLM":0:0), %arg509: tensor<3200x3200xf32> {ttir.name = "model.layers.24.self_attn.v_proj.weight"} loc("LlamaForCausalLM":0:0), %arg510: tensor<3200x3200xf32> {ttir.name = "model.layers.24.self_attn.o_proj.weight"} loc("LlamaForCausalLM":0:0), %arg511: tensor<3200xf32> {ttir.name = "model.layers.24.post_attention_layernorm.weight"} loc("LlamaForCausalLM":0:0), %arg512: tensor<3200x8640xf32> {ttir.name 
= "model.layers.24.mlp.gate_proj.weight"} loc("LlamaForCausalLM":0:0), %arg513: tensor<3200x8640xf32> {ttir.name = "model.layers.24.mlp.up_proj.weight"} loc("LlamaForCausalLM":0:0), %arg514: tensor<8640x3200xf32> {ttir.name = "model.layers.24.mlp.down_proj.weight"} loc("LlamaForCausalLM":0:0), %arg515: tensor<3200xf32> {ttir.name = "model.layers.25.input_layernorm.weight"} loc("LlamaForCausalLM":0:0), %arg516: tensor<3200x3200xf32> {ttir.name = "model.layers.25.self_attn.q_proj.weight"} loc("LlamaForCausalLM":0:0), %arg517: tensor<3200x3200xf32> {ttir.name = "model.layers.25.self_attn.k_proj.weight"} loc("LlamaForCausalLM":0:0), %arg518: tensor<3200x3200xf32> {ttir.name = "model.layers.25.self_attn.v_proj.weight"} loc("LlamaForCausalLM":0:0), %arg519: tensor<3200x3200xf32> {ttir.name = "model.layers.25.self_attn.o_proj.weight"} loc("LlamaForCausalLM":0:0), %arg520: tensor<3200xf32> {ttir.name = "model.layers.25.post_attention_layernorm.weight"} loc("LlamaForCausalLM":0:0), %arg521: tensor<3200x8640xf32> {ttir.name = "model.layers.25.mlp.gate_proj.weight"} loc("LlamaForCausalLM":0:0), %arg522: tensor<3200x8640xf32> {ttir.name = "model.layers.25.mlp.up_proj.weight"} loc("LlamaForCausalLM":0:0), %arg523: tensor<8640x3200xf32> {ttir.name = "model.layers.25.mlp.down_proj.weight"} loc("LlamaForCausalLM":0:0), %arg524: tensor<3200x32000xf32> {ttir.name = "lm_head.weight"} loc("LlamaForCausalLM":0:0)) -> (tensor<1x12x3200xf32> {ttir.name = "LlamaForCausalLM.output_matmul_2246"}) {
 %0 = tensor.empty() : tensor<1x12x3200xf32> loc(#loc2091)
- %1 = "ttir.embedding"(%arg0, %arg289, %0) <{operand_constraints = [#any_device, #any_device, #any_device, #any_device, #any_device, #any_device]}> : (tensor<1x12xi32>, tensor<32000x3200xf32>, tensor<1x12x3200xf32>) -> tensor<1x12x3200xf32> loc(#loc2091)
+ %1 = "ttir.embedding"(%arg0, %arg289, %0) : (tensor<1x12xi32>, tensor<32000x3200xf32>, tensor<1x12x3200xf32>) -> tensor<1x12x3200xf32> loc(#loc2091)
 %2 = tensor.empty() : tensor<1x12x3200xf32> loc(#loc2092)
- %3 = "ttir.multiply"(%1, %1, %2) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x12x3200xf32>, tensor<1x12x3200xf32>, tensor<1x12x3200xf32>) -> tensor<1x12x3200xf32> loc(#loc2092)
+ %3 = "ttir.multiply"(%1, %1, %2) <{operandSegmentSizes = array}> : (tensor<1x12x3200xf32>, tensor<1x12x3200xf32>, tensor<1x12x3200xf32>) -> tensor<1x12x3200xf32> loc(#loc2092)
 %4 = tensor.empty() : tensor<1x12x1xf32> loc(#loc2093)
- %5 = "ttir.mean"(%3, %4) <{dim_arg = [-1 : i32], keep_dim = true, operand_constraints = [#any_device, #any_device]}> : (tensor<1x12x3200xf32>, tensor<1x12x1xf32>) -> tensor<1x12x1xf32> loc(#loc2093)
+ %5 = "ttir.mean"(%3, %4) <{dim_arg = [-1 : i32], keep_dim = true}> : (tensor<1x12x3200xf32>, tensor<1x12x1xf32>) -> tensor<1x12x1xf32> loc(#loc2093)
 %6 = tensor.empty() : tensor<1x12x1xf32> loc(#loc2094)
- %7 = "ttir.add"(%5, %arg1, %6) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x12x1xf32>, tensor<1xf32>, tensor<1x12x1xf32>) -> tensor<1x12x1xf32> loc(#loc2094)
+ %7 = "ttir.add"(%5, %arg1, %6) <{operandSegmentSizes = array}> : (tensor<1x12x1xf32>, tensor<1xf32>, tensor<1x12x1xf32>) -> tensor<1x12x1xf32> loc(#loc2094)
 %8 = tensor.empty() : tensor<1x12x1xf32> loc(#loc2095)
- %9 = "ttir.sqrt"(%7, %8) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<1x12x1xf32>, tensor<1x12x1xf32>) -> tensor<1x12x1xf32> loc(#loc2095)
+ %9 = "ttir.sqrt"(%7, %8) <{operandSegmentSizes = array}> : (tensor<1x12x1xf32>, tensor<1x12x1xf32>) -> tensor<1x12x1xf32> loc(#loc2095)
 %10 = tensor.empty() : tensor<1x12x1xf32> loc(#loc2096)
- %11 = "ttir.reciprocal"(%9, %10) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<1x12x1xf32>, tensor<1x12x1xf32>) -> tensor<1x12x1xf32> loc(#loc2096)
+ %11 = "ttir.reciprocal"(%9, %10) <{operandSegmentSizes = array}> : (tensor<1x12x1xf32>, tensor<1x12x1xf32>) -> tensor<1x12x1xf32> loc(#loc2096)
 %12 = tensor.empty() : tensor<1x12x3200xf32> loc(#loc2097)
- %13 = "ttir.multiply"(%1, %11, %12) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x12x3200xf32>, tensor<1x12x1xf32>, tensor<1x12x3200xf32>) -> tensor<1x12x3200xf32> loc(#loc2097)
+ %13 = "ttir.multiply"(%1, %11, %12) <{operandSegmentSizes = array}> : (tensor<1x12x3200xf32>, tensor<1x12x1xf32>, tensor<1x12x3200xf32>) -> tensor<1x12x3200xf32> loc(#loc2097)
 %14 = tensor.empty() : tensor<1x12x3200xf32> loc(#loc2098)
- %15 = "ttir.multiply"(%arg290, %13, %14) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<3200xf32>, tensor<1x12x3200xf32>, tensor<1x12x3200xf32>) -> tensor<1x12x3200xf32> loc(#loc2098)
+ %15 = "ttir.multiply"(%arg290, %13, %14) <{operandSegmentSizes = array}> : (tensor<3200xf32>, tensor<1x12x3200xf32>, tensor<1x12x3200xf32>) -> tensor<1x12x3200xf32> loc(#loc2098)
 %16 = tensor.empty() : tensor<12x3200xf32> loc(#loc2099)
- %17 = "ttir.squeeze"(%15, %16) <{dim = 0 : si32, operand_constraints = [#any_device, #any_device, #any_device, #any_device]}> : (tensor<1x12x3200xf32>, tensor<12x3200xf32>) -> tensor<12x3200xf32> loc(#loc2099)
+ %17 = "ttir.squeeze"(%15, %16) <{dim = 0 : si32}> : (tensor<1x12x3200xf32>, tensor<12x3200xf32>) -> tensor<12x3200xf32> loc(#loc2099)
 %18 = tensor.empty() : tensor<12x3200xf32> loc(#loc2100)
- %19 = "ttir.matmul"(%17, %arg291, %18) <{operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<12x3200xf32>, tensor<3200x3200xf32>, tensor<12x3200xf32>) -> tensor<12x3200xf32> loc(#loc2100)
+ %19 = "ttir.matmul"(%17, %arg291, %18) : (tensor<12x3200xf32>, tensor<3200x3200xf32>, tensor<12x3200xf32>) -> tensor<12x3200xf32> loc(#loc2100)
 %20 = tensor.empty() : tensor<1x12x32x100xf32> loc(#loc2101)
- %21 = "ttir.reshape"(%19, %20) <{operand_constraints = [#any_device, #any_device], shape = [1 : i32, 12 : i32, 32 : i32, 100 : i32]}> : (tensor<12x3200xf32>, tensor<1x12x32x100xf32>) -> tensor<1x12x32x100xf32> loc(#loc2101)
+ %21 = "ttir.reshape"(%19, %20) <{shape = [1 : i32, 12 : i32, 32 : i32, 100 : i32]}> : (tensor<12x3200xf32>, tensor<1x12x32x100xf32>) -> tensor<1x12x32x100xf32> loc(#loc2101)
 %22 = tensor.empty() : tensor<1x32x12x100xf32> loc(#loc2102)
- %23 = "ttir.transpose"(%21, %22) <{dim0 = -3 : si32, dim1 = -2 : si32, operand_constraints = [#any_device, #any_device, #any_device, #any_device]}> : (tensor<1x12x32x100xf32>, tensor<1x32x12x100xf32>) -> tensor<1x32x12x100xf32> loc(#loc2102)
+ %23 = "ttir.transpose"(%21, %22) <{dim0 = -3 : si32, dim1 = -2 : si32}> : (tensor<1x12x32x100xf32>, tensor<1x32x12x100xf32>) -> tensor<1x32x12x100xf32> loc(#loc2102)
 %24 = tensor.empty() : tensor<1x12x100xf32> loc(#loc2103)
- %25 = "ttir.concat"(%arg2, %arg2, %24) <{dim = -1 : si32, operand_constraints = [#any_device, #any_device, #any_device, #any_device]}> : (tensor<1x12x50xf32>, tensor<1x12x50xf32>, tensor<1x12x100xf32>) -> tensor<1x12x100xf32> loc(#loc2103)
+ %25 = "ttir.concat"(%arg2, %arg2, %24) <{dim = -1 : si32}> : (tensor<1x12x50xf32>, tensor<1x12x50xf32>, tensor<1x12x100xf32>) -> tensor<1x12x100xf32> loc(#loc2103)
 %26 = tensor.empty() : tensor<1x12x100xf32> loc(#loc2104)
- %27 = "ttir.sin"(%25, %26) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<1x12x100xf32>, tensor<1x12x100xf32>) -> tensor<1x12x100xf32> loc(#loc2104)
+ %27 = "ttir.sin"(%25, %26) <{operandSegmentSizes = array}> : (tensor<1x12x100xf32>, tensor<1x12x100xf32>) -> tensor<1x12x100xf32> loc(#loc2104)
 %28 = tensor.empty() : tensor<1x1x12x100xf32> loc(#loc2105)
- %29 = "ttir.unsqueeze"(%27, %28) <{dim = 1 : si32, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x12x100xf32>, tensor<1x1x12x100xf32>) -> tensor<1x1x12x100xf32> loc(#loc2105)
+ %29 = "ttir.unsqueeze"(%27, %28) <{dim = 1 : si32}> : (tensor<1x12x100xf32>, tensor<1x1x12x100xf32>) -> tensor<1x1x12x100xf32> loc(#loc2105)
 %30 = tensor.empty() : tensor<1x32x12x100xf32> loc(#loc2106)
- %31 = "ttir.multiply"(%23, %29, %30) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x12x100xf32>, tensor<1x1x12x100xf32>, tensor<1x32x12x100xf32>) -> tensor<1x32x12x100xf32> loc(#loc2106)
+ %31 = "ttir.multiply"(%23, %29, %30) <{operandSegmentSizes = array}> : (tensor<1x32x12x100xf32>, tensor<1x1x12x100xf32>, tensor<1x32x12x100xf32>) -> tensor<1x32x12x100xf32> loc(#loc2106)
 %32 = tensor.empty() : tensor<1x32x100x12xf32> loc(#loc2107)
- %33 = "ttir.transpose"(%23, %32) <{dim0 = -2 : si32, dim1 = -1 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<1x32x12x100xf32>, tensor<1x32x100x12xf32>) -> tensor<1x32x100x12xf32> loc(#loc2107)
+ %33 = "ttir.transpose"(%23, %32) <{dim0 = -2 : si32, dim1 = -1 : si32}> : (tensor<1x32x12x100xf32>, tensor<1x32x100x12xf32>) -> tensor<1x32x100x12xf32> loc(#loc2107)
 %34 = tensor.empty() : tensor<1x32x50x12xf32> loc(#loc2108)
- %35 = "ttir.matmul"(%arg3, %33, %34) <{operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x50x100xf32>, tensor<1x32x100x12xf32>, tensor<1x32x50x12xf32>) -> tensor<1x32x50x12xf32> loc(#loc2108)
+ %35 = "ttir.matmul"(%arg3, %33, %34) : (tensor<1x32x50x100xf32>, tensor<1x32x100x12xf32>, tensor<1x32x50x12xf32>) -> tensor<1x32x50x12xf32> loc(#loc2108)
 %36 = tensor.empty() : tensor<1x32x12x50xf32> loc(#loc2109)
- %37 = "ttir.transpose"(%35, %36) <{dim0 = -2 : si32, dim1 = -1 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<1x32x50x12xf32>, tensor<1x32x12x50xf32>) -> tensor<1x32x12x50xf32> loc(#loc2109)
+ %37 = "ttir.transpose"(%35, %36) <{dim0 = -2 : si32, dim1 = -1 : si32}> : (tensor<1x32x50x12xf32>, tensor<1x32x12x50xf32>) -> tensor<1x32x12x50xf32> loc(#loc2109)
 %38 = tensor.empty() : tensor<1x32x12x50xf32> loc(#loc2110)
- %39 = "ttir.multiply"(%37, %arg4, %38) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x12x50xf32>, tensor<1xf32>, tensor<1x32x12x50xf32>) -> tensor<1x32x12x50xf32> loc(#loc2110)
+ %39 = "ttir.multiply"(%37, %arg4, %38) <{operandSegmentSizes = array}> : (tensor<1x32x12x50xf32>, tensor<1xf32>, tensor<1x32x12x50xf32>) -> tensor<1x32x12x50xf32> loc(#loc2110)
 %40 = tensor.empty() : tensor<1x32x100x12xf32> loc(#loc2111)
- %41 = "ttir.transpose"(%23, %40) <{dim0 = -2 : si32, dim1 = -1 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<1x32x12x100xf32>, tensor<1x32x100x12xf32>) -> tensor<1x32x100x12xf32> loc(#loc2111)
+ %41 = "ttir.transpose"(%23, %40) <{dim0 = -2 : si32, dim1 = -1 : si32}> : (tensor<1x32x12x100xf32>, tensor<1x32x100x12xf32>) -> tensor<1x32x100x12xf32> loc(#loc2111)
 %42 = tensor.empty() : tensor<1x32x50x12xf32> loc(#loc2112)
- %43 = "ttir.matmul"(%arg5, %41, %42) <{operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x50x100xf32>, tensor<1x32x100x12xf32>, tensor<1x32x50x12xf32>) -> tensor<1x32x50x12xf32> loc(#loc2112)
+ %43 = "ttir.matmul"(%arg5, %41, %42) : (tensor<1x32x50x100xf32>, tensor<1x32x100x12xf32>, tensor<1x32x50x12xf32>) -> tensor<1x32x50x12xf32> loc(#loc2112)
 %44 = tensor.empty() : tensor<1x32x12x50xf32> loc(#loc2113)
- %45 = "ttir.transpose"(%43, %44) <{dim0 = -2 : si32, dim1 = -1 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<1x32x50x12xf32>, tensor<1x32x12x50xf32>) -> tensor<1x32x12x50xf32> loc(#loc2113)
+ %45 = "ttir.transpose"(%43, %44) <{dim0 = -2 : si32, dim1 = -1 : si32}> : (tensor<1x32x50x12xf32>, tensor<1x32x12x50xf32>) -> tensor<1x32x12x50xf32> loc(#loc2113)
 %46 = tensor.empty() : tensor<1x32x12x100xf32> loc(#loc2114)
- %47 = "ttir.concat"(%39, %45, %46) <{dim = -1 : si32, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x12x50xf32>, tensor<1x32x12x50xf32>, tensor<1x32x12x100xf32>) -> tensor<1x32x12x100xf32> loc(#loc2114)
+ %47 = "ttir.concat"(%39, %45, %46) <{dim = -1 : si32}> : (tensor<1x32x12x50xf32>, tensor<1x32x12x50xf32>, tensor<1x32x12x100xf32>) -> tensor<1x32x12x100xf32> loc(#loc2114)
 %48 = tensor.empty() : tensor<1x12x100xf32> loc(#loc2115)
- %49 = "ttir.cos"(%25, %48) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<1x12x100xf32>, tensor<1x12x100xf32>) -> tensor<1x12x100xf32> loc(#loc2115)
+ %49 = "ttir.cos"(%25, %48) <{operandSegmentSizes = array}> : (tensor<1x12x100xf32>, tensor<1x12x100xf32>) -> tensor<1x12x100xf32> loc(#loc2115)
 %50 = tensor.empty() : tensor<1x1x12x100xf32> loc(#loc2116)
- %51 = "ttir.unsqueeze"(%49, %50) <{dim = 1 : si32, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x12x100xf32>, tensor<1x1x12x100xf32>) -> tensor<1x1x12x100xf32> loc(#loc2116)
+ %51 = "ttir.unsqueeze"(%49, %50) <{dim = 1 : si32}> : (tensor<1x12x100xf32>, tensor<1x1x12x100xf32>) -> tensor<1x1x12x100xf32> loc(#loc2116)
 %52 = tensor.empty() : tensor<1x32x12x100xf32> loc(#loc2117)
- %53 = "ttir.multiply"(%47, %51, %52) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x12x100xf32>, tensor<1x1x12x100xf32>, tensor<1x32x12x100xf32>) -> tensor<1x32x12x100xf32> loc(#loc2117)
+ %53 = "ttir.multiply"(%47, %51, %52) <{operandSegmentSizes = array}> : (tensor<1x32x12x100xf32>, tensor<1x1x12x100xf32>, tensor<1x32x12x100xf32>) -> tensor<1x32x12x100xf32> loc(#loc2117)
 %54 = tensor.empty() : tensor<1x32x12x100xf32> loc(#loc2118)
- %55 = "ttir.add"(%31, %53, %54) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x12x100xf32>, tensor<1x32x12x100xf32>, tensor<1x32x12x100xf32>) -> tensor<1x32x12x100xf32> loc(#loc2118)
+ %55 = "ttir.add"(%31, %53, %54) <{operandSegmentSizes = array}> : (tensor<1x32x12x100xf32>, tensor<1x32x12x100xf32>, tensor<1x32x12x100xf32>) -> tensor<1x32x12x100xf32> loc(#loc2118)
 %56 = tensor.empty() : tensor<32x12x100xf32> loc(#loc2119)
- %57 = "ttir.squeeze"(%55, %56) <{dim = 0 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<1x32x12x100xf32>, tensor<32x12x100xf32>) -> tensor<32x12x100xf32> loc(#loc2119)
+ %57 = "ttir.squeeze"(%55, %56) <{dim = 0 : si32}> : (tensor<1x32x12x100xf32>, tensor<32x12x100xf32>) -> tensor<32x12x100xf32> loc(#loc2119)
 %58 = tensor.empty() : tensor<12x3200xf32> loc(#loc2120)
- %59 = "ttir.matmul"(%17, %arg292, %58) <{operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<12x3200xf32>, tensor<3200x3200xf32>, tensor<12x3200xf32>) -> tensor<12x3200xf32> loc(#loc2120)
+ %59 = "ttir.matmul"(%17, %arg292, %58) : (tensor<12x3200xf32>, tensor<3200x3200xf32>, tensor<12x3200xf32>) -> tensor<12x3200xf32> loc(#loc2120)
 %60 = tensor.empty() : tensor<1x12x32x100xf32> loc(#loc2121)
- %61 = "ttir.reshape"(%59, %60) <{operand_constraints = [#any_device, #any_device], shape = [1 : i32, 12 : i32, 32 : i32, 100 : i32]}> : (tensor<12x3200xf32>, tensor<1x12x32x100xf32>) -> tensor<1x12x32x100xf32> loc(#loc2121)
+ %61 = "ttir.reshape"(%59, %60) <{shape = [1 : i32, 12 : i32, 32 : i32, 100 : i32]}> : (tensor<12x3200xf32>, tensor<1x12x32x100xf32>) -> tensor<1x12x32x100xf32> loc(#loc2121)
 %62 = tensor.empty() : tensor<1x32x12x100xf32> loc(#loc2122)
- %63 = "ttir.transpose"(%61, %62) <{dim0 = -3 : si32, dim1 = -2 : si32, operand_constraints = [#any_device, #any_device, #any_device, #any_device]}> : (tensor<1x12x32x100xf32>, tensor<1x32x12x100xf32>) -> tensor<1x32x12x100xf32> loc(#loc2122)
+ %63 = "ttir.transpose"(%61, %62) <{dim0 = -3 : si32, dim1 = -2 : si32}> : (tensor<1x12x32x100xf32>, tensor<1x32x12x100xf32>) -> tensor<1x32x12x100xf32> loc(#loc2122)
 %64 = tensor.empty() : tensor<1x32x12x100xf32> loc(#loc2123)
- %65 = "ttir.multiply"(%63, %29, %64) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x12x100xf32>, tensor<1x1x12x100xf32>, tensor<1x32x12x100xf32>) -> tensor<1x32x12x100xf32> loc(#loc2123)
+ %65 = "ttir.multiply"(%63, %29, %64) <{operandSegmentSizes = array}> : (tensor<1x32x12x100xf32>, tensor<1x1x12x100xf32>, tensor<1x32x12x100xf32>) -> tensor<1x32x12x100xf32> loc(#loc2123)
 %66 = tensor.empty() : tensor<1x32x100x12xf32> loc(#loc2124)
- %67 = "ttir.transpose"(%63, %66) <{dim0 = -2 : si32, dim1 = -1 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<1x32x12x100xf32>, tensor<1x32x100x12xf32>) -> tensor<1x32x100x12xf32> loc(#loc2124)
+ %67 = "ttir.transpose"(%63, %66) <{dim0 = -2 : si32, dim1 = -1 : si32}> : (tensor<1x32x12x100xf32>, tensor<1x32x100x12xf32>) -> tensor<1x32x100x12xf32> loc(#loc2124)
 %68 = tensor.empty() : tensor<1x32x50x12xf32> loc(#loc2125)
- %69 = "ttir.matmul"(%arg6, %67, %68) <{operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x50x100xf32>, tensor<1x32x100x12xf32>, tensor<1x32x50x12xf32>) -> tensor<1x32x50x12xf32> loc(#loc2125)
+ %69 = "ttir.matmul"(%arg6, %67, %68) : (tensor<1x32x50x100xf32>, tensor<1x32x100x12xf32>, tensor<1x32x50x12xf32>) -> tensor<1x32x50x12xf32> loc(#loc2125)
 %70 = tensor.empty() : tensor<1x32x12x50xf32> loc(#loc2126)
- %71 = "ttir.transpose"(%69, %70) <{dim0 = -2 : si32, dim1 = -1 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<1x32x50x12xf32>, tensor<1x32x12x50xf32>) -> tensor<1x32x12x50xf32> loc(#loc2126)
+ %71 = "ttir.transpose"(%69, %70) <{dim0 = -2 : si32, dim1 = -1 : si32}> : (tensor<1x32x50x12xf32>, tensor<1x32x12x50xf32>) -> tensor<1x32x12x50xf32> loc(#loc2126)
 %72 = tensor.empty() : tensor<1x32x12x50xf32> loc(#loc2127)
- %73 = "ttir.multiply"(%71, %arg7, %72) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> :
(tensor<1x32x12x50xf32>, tensor<1xf32>, tensor<1x32x12x50xf32>) -> tensor<1x32x12x50xf32> loc(#loc2127) + %73 = "ttir.multiply"(%71, %arg7, %72) <{operandSegmentSizes = array}> : (tensor<1x32x12x50xf32>, tensor<1xf32>, tensor<1x32x12x50xf32>) -> tensor<1x32x12x50xf32> loc(#loc2127) %74 = tensor.empty() : tensor<1x32x100x12xf32> loc(#loc2128) - %75 = "ttir.transpose"(%63, %74) <{dim0 = -2 : si32, dim1 = -1 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<1x32x12x100xf32>, tensor<1x32x100x12xf32>) -> tensor<1x32x100x12xf32> loc(#loc2128) + %75 = "ttir.transpose"(%63, %74) <{dim0 = -2 : si32, dim1 = -1 : si32}> : (tensor<1x32x12x100xf32>, tensor<1x32x100x12xf32>) -> tensor<1x32x100x12xf32> loc(#loc2128) %76 = tensor.empty() : tensor<1x32x50x12xf32> loc(#loc2129) - %77 = "ttir.matmul"(%arg8, %75, %76) <{operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x50x100xf32>, tensor<1x32x100x12xf32>, tensor<1x32x50x12xf32>) -> tensor<1x32x50x12xf32> loc(#loc2129) + %77 = "ttir.matmul"(%arg8, %75, %76) : (tensor<1x32x50x100xf32>, tensor<1x32x100x12xf32>, tensor<1x32x50x12xf32>) -> tensor<1x32x50x12xf32> loc(#loc2129) %78 = tensor.empty() : tensor<1x32x12x50xf32> loc(#loc2130) - %79 = "ttir.transpose"(%77, %78) <{dim0 = -2 : si32, dim1 = -1 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<1x32x50x12xf32>, tensor<1x32x12x50xf32>) -> tensor<1x32x12x50xf32> loc(#loc2130) + %79 = "ttir.transpose"(%77, %78) <{dim0 = -2 : si32, dim1 = -1 : si32}> : (tensor<1x32x50x12xf32>, tensor<1x32x12x50xf32>) -> tensor<1x32x12x50xf32> loc(#loc2130) %80 = tensor.empty() : tensor<1x32x12x100xf32> loc(#loc2131) - %81 = "ttir.concat"(%73, %79, %80) <{dim = -1 : si32, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x12x50xf32>, tensor<1x32x12x50xf32>, tensor<1x32x12x100xf32>) -> tensor<1x32x12x100xf32> loc(#loc2131) + %81 = "ttir.concat"(%73, %79, %80) <{dim = -1 : si32}> : (tensor<1x32x12x50xf32>, tensor<1x32x12x50xf32>, tensor<1x32x12x100xf32>) -> tensor<1x32x12x100xf32> loc(#loc2131) %82 = tensor.empty() : tensor<1x32x12x100xf32> loc(#loc2132) - %83 = "ttir.multiply"(%81, %51, %82) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x12x100xf32>, tensor<1x1x12x100xf32>, tensor<1x32x12x100xf32>) -> tensor<1x32x12x100xf32> loc(#loc2132) + %83 = "ttir.multiply"(%81, %51, %82) <{operandSegmentSizes = array}> : (tensor<1x32x12x100xf32>, tensor<1x1x12x100xf32>, tensor<1x32x12x100xf32>) -> tensor<1x32x12x100xf32> loc(#loc2132) %84 = tensor.empty() : tensor<1x32x12x100xf32> loc(#loc2133) - %85 = "ttir.add"(%65, %83, %84) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x12x100xf32>, tensor<1x32x12x100xf32>, tensor<1x32x12x100xf32>) -> tensor<1x32x12x100xf32> loc(#loc2133) + %85 = "ttir.add"(%65, %83, %84) <{operandSegmentSizes = array}> : (tensor<1x32x12x100xf32>, tensor<1x32x12x100xf32>, tensor<1x32x12x100xf32>) -> tensor<1x32x12x100xf32> loc(#loc2133) %86 = tensor.empty() : tensor<32x12x100xf32> loc(#loc2134) - %87 = "ttir.squeeze"(%85, %86) <{dim = 0 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<1x32x12x100xf32>, tensor<32x12x100xf32>) -> tensor<32x12x100xf32> loc(#loc2134) + %87 = "ttir.squeeze"(%85, %86) <{dim = 0 : si32}> : (tensor<1x32x12x100xf32>, tensor<32x12x100xf32>) -> tensor<32x12x100xf32> loc(#loc2134) %88 = tensor.empty() : tensor<32x100x12xf32> loc(#loc2135) - %89 = 
"ttir.transpose"(%87, %88) <{dim0 = -2 : si32, dim1 = -1 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<32x12x100xf32>, tensor<32x100x12xf32>) -> tensor<32x100x12xf32> loc(#loc2135) + %89 = "ttir.transpose"(%87, %88) <{dim0 = -2 : si32, dim1 = -1 : si32}> : (tensor<32x12x100xf32>, tensor<32x100x12xf32>) -> tensor<32x100x12xf32> loc(#loc2135) %90 = tensor.empty() : tensor<32x12x12xf32> loc(#loc2136) - %91 = "ttir.matmul"(%57, %89, %90) <{operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x12x100xf32>, tensor<32x100x12xf32>, tensor<32x12x12xf32>) -> tensor<32x12x12xf32> loc(#loc2136) + %91 = "ttir.matmul"(%57, %89, %90) : (tensor<32x12x100xf32>, tensor<32x100x12xf32>, tensor<32x12x12xf32>) -> tensor<32x12x12xf32> loc(#loc2136) %92 = tensor.empty() : tensor<1x32x12x12xf32> loc(#loc2137) - %93 = "ttir.unsqueeze"(%91, %92) <{dim = 0 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<32x12x12xf32>, tensor<1x32x12x12xf32>) -> tensor<1x32x12x12xf32> loc(#loc2137) + %93 = "ttir.unsqueeze"(%91, %92) <{dim = 0 : si32}> : (tensor<32x12x12xf32>, tensor<1x32x12x12xf32>) -> tensor<1x32x12x12xf32> loc(#loc2137) %94 = tensor.empty() : tensor<1x32x12x12xf32> loc(#loc2138) - %95 = "ttir.multiply"(%93, %arg9, %94) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x12x12xf32>, tensor<1xf32>, tensor<1x32x12x12xf32>) -> tensor<1x32x12x12xf32> loc(#loc2138) + %95 = "ttir.multiply"(%93, %arg9, %94) <{operandSegmentSizes = array}> : (tensor<1x32x12x12xf32>, tensor<1xf32>, tensor<1x32x12x12xf32>) -> tensor<1x32x12x12xf32> loc(#loc2138) %96 = tensor.empty() : tensor<1x32x12x12xf32> loc(#loc2139) - %97 = "ttir.add"(%95, %arg10, %96) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x12x12xf32>, tensor<1x1x12x12xf32>, tensor<1x32x12x12xf32>) -> tensor<1x32x12x12xf32> loc(#loc2139) + %97 = "ttir.add"(%95, %arg10, %96) <{operandSegmentSizes = array}> : (tensor<1x32x12x12xf32>, tensor<1x1x12x12xf32>, tensor<1x32x12x12xf32>) -> tensor<1x32x12x12xf32> loc(#loc2139) %98 = tensor.empty() : tensor<1x32x12x12xf32> loc(#loc2140) - %99 = "ttir.softmax"(%97, %98) <{dimension = -1 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<1x32x12x12xf32>, tensor<1x32x12x12xf32>) -> tensor<1x32x12x12xf32> loc(#loc2140) + %99 = "ttir.softmax"(%97, %98) <{dimension = -1 : si32}> : (tensor<1x32x12x12xf32>, tensor<1x32x12x12xf32>) -> tensor<1x32x12x12xf32> loc(#loc2140) %100 = tensor.empty() : tensor<32x12x12xf32> loc(#loc2141) - %101 = "ttir.squeeze"(%99, %100) <{dim = 0 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<1x32x12x12xf32>, tensor<32x12x12xf32>) -> tensor<32x12x12xf32> loc(#loc2141) + %101 = "ttir.squeeze"(%99, %100) <{dim = 0 : si32}> : (tensor<1x32x12x12xf32>, tensor<32x12x12xf32>) -> tensor<32x12x12xf32> loc(#loc2141) %102 = tensor.empty() : tensor<12x3200xf32> loc(#loc2142) - %103 = "ttir.matmul"(%17, %arg293, %102) <{operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<12x3200xf32>, tensor<3200x3200xf32>, tensor<12x3200xf32>) -> tensor<12x3200xf32> loc(#loc2142) + %103 = "ttir.matmul"(%17, %arg293, %102) : (tensor<12x3200xf32>, tensor<3200x3200xf32>, tensor<12x3200xf32>) -> tensor<12x3200xf32> loc(#loc2142) %104 = tensor.empty() : tensor<1x12x32x100xf32> loc(#loc2143) - %105 = "ttir.reshape"(%103, %104) <{operand_constraints = [#any_device, #any_device], shape = [1 : i32, 12 : i32, 32 : 
i32, 100 : i32]}> : (tensor<12x3200xf32>, tensor<1x12x32x100xf32>) -> tensor<1x12x32x100xf32> loc(#loc2143) + %105 = "ttir.reshape"(%103, %104) <{shape = [1 : i32, 12 : i32, 32 : i32, 100 : i32]}> : (tensor<12x3200xf32>, tensor<1x12x32x100xf32>) -> tensor<1x12x32x100xf32> loc(#loc2143) %106 = tensor.empty() : tensor<1x32x12x100xf32> loc(#loc2144) - %107 = "ttir.transpose"(%105, %106) <{dim0 = -3 : si32, dim1 = -2 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<1x12x32x100xf32>, tensor<1x32x12x100xf32>) -> tensor<1x32x12x100xf32> loc(#loc2144) + %107 = "ttir.transpose"(%105, %106) <{dim0 = -3 : si32, dim1 = -2 : si32}> : (tensor<1x12x32x100xf32>, tensor<1x32x12x100xf32>) -> tensor<1x32x12x100xf32> loc(#loc2144) %108 = tensor.empty() : tensor<1x32x100x12xf32> loc(#loc2145) - %109 = "ttir.transpose"(%107, %108) <{dim0 = -2 : si32, dim1 = -1 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<1x32x12x100xf32>, tensor<1x32x100x12xf32>) -> tensor<1x32x100x12xf32> loc(#loc2145) + %109 = "ttir.transpose"(%107, %108) <{dim0 = -2 : si32, dim1 = -1 : si32}> : (tensor<1x32x12x100xf32>, tensor<1x32x100x12xf32>) -> tensor<1x32x100x12xf32> loc(#loc2145) %110 = tensor.empty() : tensor<32x100x12xf32> loc(#loc2146) - %111 = "ttir.squeeze"(%109, %110) <{dim = 0 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<1x32x100x12xf32>, tensor<32x100x12xf32>) -> tensor<32x100x12xf32> loc(#loc2146) + %111 = "ttir.squeeze"(%109, %110) <{dim = 0 : si32}> : (tensor<1x32x100x12xf32>, tensor<32x100x12xf32>) -> tensor<32x100x12xf32> loc(#loc2146) %112 = tensor.empty() : tensor<32x12x100xf32> loc(#loc2147) - %113 = "ttir.transpose"(%111, %112) <{dim0 = -2 : si32, dim1 = -1 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<32x100x12xf32>, tensor<32x12x100xf32>) -> tensor<32x12x100xf32> loc(#loc2147) + %113 = "ttir.transpose"(%111, %112) <{dim0 = -2 : si32, dim1 = -1 : si32}> : (tensor<32x100x12xf32>, tensor<32x12x100xf32>) -> tensor<32x12x100xf32> loc(#loc2147) %114 = tensor.empty() : tensor<32x12x100xf32> loc(#loc2148) - %115 = "ttir.matmul"(%101, %113, %114) <{operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32x12x12xf32>, tensor<32x12x100xf32>, tensor<32x12x100xf32>) -> tensor<32x12x100xf32> loc(#loc2148) + %115 = "ttir.matmul"(%101, %113, %114) : (tensor<32x12x12xf32>, tensor<32x12x100xf32>, tensor<32x12x100xf32>) -> tensor<32x12x100xf32> loc(#loc2148) %116 = tensor.empty() : tensor<1x32x12x100xf32> loc(#loc2149) - %117 = "ttir.unsqueeze"(%115, %116) <{dim = 0 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<32x12x100xf32>, tensor<1x32x12x100xf32>) -> tensor<1x32x12x100xf32> loc(#loc2149) + %117 = "ttir.unsqueeze"(%115, %116) <{dim = 0 : si32}> : (tensor<32x12x100xf32>, tensor<1x32x12x100xf32>) -> tensor<1x32x12x100xf32> loc(#loc2149) %118 = tensor.empty() : tensor<1x12x32x100xf32> loc(#loc2150) - %119 = "ttir.transpose"(%117, %118) <{dim0 = -3 : si32, dim1 = -2 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<1x32x12x100xf32>, tensor<1x12x32x100xf32>) -> tensor<1x12x32x100xf32> loc(#loc2150) + %119 = "ttir.transpose"(%117, %118) <{dim0 = -3 : si32, dim1 = -2 : si32}> : (tensor<1x32x12x100xf32>, tensor<1x12x32x100xf32>) -> tensor<1x12x32x100xf32> loc(#loc2150) %120 = tensor.empty() : tensor<12x3200xf32> loc(#loc2151) - %121 = "ttir.reshape"(%119, %120) <{operand_constraints = [#any_device, #any_device], shape = [12 : i32, 3200 : i32]}> : (tensor<1x12x32x100xf32>, tensor<12x3200xf32>) 
-> tensor<12x3200xf32> loc(#loc2151) + %121 = "ttir.reshape"(%119, %120) <{shape = [12 : i32, 3200 : i32]}> : (tensor<1x12x32x100xf32>, tensor<12x3200xf32>) -> tensor<12x3200xf32> loc(#loc2151) %122 = tensor.empty() : tensor<12x3200xf32> loc(#loc2152) - %123 = "ttir.matmul"(%121, %arg294, %122) <{operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<12x3200xf32>, tensor<3200x3200xf32>, tensor<12x3200xf32>) -> tensor<12x3200xf32> loc(#loc2152) + %123 = "ttir.matmul"(%121, %arg294, %122) : (tensor<12x3200xf32>, tensor<3200x3200xf32>, tensor<12x3200xf32>) -> tensor<12x3200xf32> loc(#loc2152) %124 = tensor.empty() : tensor<1x12x3200xf32> loc(#loc2153) - %125 = "ttir.unsqueeze"(%123, %124) <{dim = 0 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<12x3200xf32>, tensor<1x12x3200xf32>) -> tensor<1x12x3200xf32> loc(#loc2153) + %125 = "ttir.unsqueeze"(%123, %124) <{dim = 0 : si32}> : (tensor<12x3200xf32>, tensor<1x12x3200xf32>) -> tensor<1x12x3200xf32> loc(#loc2153) %126 = tensor.empty() : tensor<1x12x3200xf32> loc(#loc2154) - %127 = "ttir.add"(%1, %125, %126) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device, #any_device, #any_device, #any_device]}> : (tensor<1x12x3200xf32>, tensor<1x12x3200xf32>, tensor<1x12x3200xf32>) -> tensor<1x12x3200xf32> loc(#loc2154) + %127 = "ttir.add"(%1, %125, %126) <{operandSegmentSizes = array}> : (tensor<1x12x3200xf32>, tensor<1x12x3200xf32>, tensor<1x12x3200xf32>) -> tensor<1x12x3200xf32> loc(#loc2154) %128 = tensor.empty() : tensor<1x12x3200xf32> loc(#loc2155) - %129 = "ttir.multiply"(%127, %127, %128) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x12x3200xf32>, tensor<1x12x3200xf32>, tensor<1x12x3200xf32>) -> tensor<1x12x3200xf32> loc(#loc2155) + %129 = "ttir.multiply"(%127, %127, %128) <{operandSegmentSizes = array}> : (tensor<1x12x3200xf32>, tensor<1x12x3200xf32>, tensor<1x12x3200xf32>) -> tensor<1x12x3200xf32> loc(#loc2155) %130 = tensor.empty() : tensor<1x12x1xf32> loc(#loc2156) - %131 = "ttir.mean"(%129, %130) <{dim_arg = [-1 : i32], keep_dim = true, operand_constraints = [#any_device, #any_device]}> : (tensor<1x12x3200xf32>, tensor<1x12x1xf32>) -> tensor<1x12x1xf32> loc(#loc2156) + %131 = "ttir.mean"(%129, %130) <{dim_arg = [-1 : i32], keep_dim = true}> : (tensor<1x12x3200xf32>, tensor<1x12x1xf32>) -> tensor<1x12x1xf32> loc(#loc2156) %132 = tensor.empty() : tensor<1x12x1xf32> loc(#loc2157) - %133 = "ttir.add"(%131, %arg11, %132) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x12x1xf32>, tensor<1xf32>, tensor<1x12x1xf32>) -> tensor<1x12x1xf32> loc(#loc2157) + %133 = "ttir.add"(%131, %arg11, %132) <{operandSegmentSizes = array}> : (tensor<1x12x1xf32>, tensor<1xf32>, tensor<1x12x1xf32>) -> tensor<1x12x1xf32> loc(#loc2157) %134 = tensor.empty() : tensor<1x12x1xf32> loc(#loc2158) - %135 = "ttir.sqrt"(%133, %134) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<1x12x1xf32>, tensor<1x12x1xf32>) -> tensor<1x12x1xf32> loc(#loc2158) + %135 = "ttir.sqrt"(%133, %134) <{operandSegmentSizes = array}> : (tensor<1x12x1xf32>, tensor<1x12x1xf32>) -> tensor<1x12x1xf32> loc(#loc2158) %136 = tensor.empty() : tensor<1x12x1xf32> loc(#loc2159) - %137 = "ttir.reciprocal"(%135, %136) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<1x12x1xf32>, tensor<1x12x1xf32>) -> tensor<1x12x1xf32> loc(#loc2159) + 
%137 = "ttir.reciprocal"(%135, %136) <{operandSegmentSizes = array}> : (tensor<1x12x1xf32>, tensor<1x12x1xf32>) -> tensor<1x12x1xf32> loc(#loc2159) %138 = tensor.empty() : tensor<1x12x3200xf32> loc(#loc2160) - %139 = "ttir.multiply"(%127, %137, %138) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x12x3200xf32>, tensor<1x12x1xf32>, tensor<1x12x3200xf32>) -> tensor<1x12x3200xf32> loc(#loc2160) + %139 = "ttir.multiply"(%127, %137, %138) <{operandSegmentSizes = array}> : (tensor<1x12x3200xf32>, tensor<1x12x1xf32>, tensor<1x12x3200xf32>) -> tensor<1x12x3200xf32> loc(#loc2160) %140 = tensor.empty() : tensor<1x12x3200xf32> loc(#loc2161) - %141 = "ttir.multiply"(%arg295, %139, %140) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<3200xf32>, tensor<1x12x3200xf32>, tensor<1x12x3200xf32>) -> tensor<1x12x3200xf32> loc(#loc2161) + %141 = "ttir.multiply"(%arg295, %139, %140) <{operandSegmentSizes = array}> : (tensor<3200xf32>, tensor<1x12x3200xf32>, tensor<1x12x3200xf32>) -> tensor<1x12x3200xf32> loc(#loc2161) %142 = tensor.empty() : tensor<12x3200xf32> loc(#loc2162) - %143 = "ttir.squeeze"(%141, %142) <{dim = 0 : si32, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x12x3200xf32>, tensor<12x3200xf32>) -> tensor<12x3200xf32> loc(#loc2162) + %143 = "ttir.squeeze"(%141, %142) <{dim = 0 : si32}> : (tensor<1x12x3200xf32>, tensor<12x3200xf32>) -> tensor<12x3200xf32> loc(#loc2162) %144 = tensor.empty() : tensor<12x8640xf32> loc(#loc2163) - %145 = "ttir.matmul"(%143, %arg296, %144) <{operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<12x3200xf32>, tensor<3200x8640xf32>, tensor<12x8640xf32>) -> tensor<12x8640xf32> loc(#loc2163) + %145 = "ttir.matmul"(%143, %arg296, %144) : (tensor<12x3200xf32>, tensor<3200x8640xf32>, tensor<12x8640xf32>) -> tensor<12x8640xf32> loc(#loc2163) %146 = tensor.empty() : tensor<1x12x8640xf32> loc(#loc2164) - %147 = "ttir.unsqueeze"(%145, %146) <{dim = 0 : si32, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<12x8640xf32>, tensor<1x12x8640xf32>) -> tensor<1x12x8640xf32> loc(#loc2164) + %147 = "ttir.unsqueeze"(%145, %146) <{dim = 0 : si32}> : (tensor<12x8640xf32>, tensor<1x12x8640xf32>) -> tensor<1x12x8640xf32> loc(#loc2164) %148 = tensor.empty() : tensor<1x12x8640xf32> loc(#loc2165) - %149 = "ttir.sigmoid"(%147, %148) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device]}> : (tensor<1x12x8640xf32>, tensor<1x12x8640xf32>) -> tensor<1x12x8640xf32> loc(#loc2165) + %149 = "ttir.sigmoid"(%147, %148) <{operandSegmentSizes = array}> : (tensor<1x12x8640xf32>, tensor<1x12x8640xf32>) -> tensor<1x12x8640xf32> loc(#loc2165) %150 = tensor.empty() : tensor<1x12x8640xf32> loc(#loc2166) - %151 = "ttir.multiply"(%147, %149, %150) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x12x8640xf32>, tensor<1x12x8640xf32>, tensor<1x12x8640xf32>) -> tensor<1x12x8640xf32> loc(#loc2166) + %151 = "ttir.multiply"(%147, %149, %150) <{operandSegmentSizes = array}> : (tensor<1x12x8640xf32>, tensor<1x12x8640xf32>, tensor<1x12x8640xf32>) -> tensor<1x12x8640xf32> loc(#loc2166) %152 = tensor.empty() : tensor<12x8640xf32> loc(#loc2167) - %153 = "ttir.matmul"(%143, %arg297, %152) <{operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<12x3200xf32>, tensor<3200x8640xf32>, tensor<12x8640xf32>) -> tensor<12x8640xf32> loc(#loc2167) + %153 
= "ttir.matmul"(%143, %arg297, %152) : (tensor<12x3200xf32>, tensor<3200x8640xf32>, tensor<12x8640xf32>) -> tensor<12x8640xf32> loc(#loc2167) %154 = tensor.empty() : tensor<1x12x8640xf32> loc(#loc2168) - %155 = "ttir.unsqueeze"(%153, %154) <{dim = 0 : si32, operand_constraints = [#any_device, #any_device]}> : (tensor<12x8640xf32>, tensor<1x12x8640xf32>) -> tensor<1x12x8640xf32> loc(#loc2168) + %155 = "ttir.unsqueeze"(%153, %154) <{dim = 0 : si32}> : (tensor<12x8640xf32>, tensor<1x12x8640xf32>) -> tensor<1x12x8640xf32> loc(#loc2168) %156 = tensor.empty() : tensor<1x12x8640xf32> loc(#loc2169) - %157 = "ttir.multiply"(%151, %155, %156) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x12x8640xf32>, tensor<1x12x8640xf32>, tensor<1x12x8640xf32>) -> tensor<1x12x8640xf32> loc(#loc2169) + %157 = "ttir.multiply"(%151, %155, %156) <{operandSegmentSizes = array}> : (tensor<1x12x8640xf32>, tensor<1x12x8640xf32>, tensor<1x12x8640xf32>) -> tensor<1x12x8640xf32> loc(#loc2169) %158 = tensor.empty() : tensor<1x12x3200xf32> loc(#loc2170) - %159 = "ttir.matmul"(%157, %arg298, %158) <{operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x12x8640xf32>, tensor<8640x3200xf32>, tensor<1x12x3200xf32>) -> tensor<1x12x3200xf32> loc(#loc2170) + %159 = "ttir.matmul"(%157, %arg298, %158) : (tensor<1x12x8640xf32>, tensor<8640x3200xf32>, tensor<1x12x3200xf32>) -> tensor<1x12x3200xf32> loc(#loc2170) %160 = tensor.empty() : tensor<1x12x3200xf32> loc(#loc2171) - %161 = "ttir.add"(%127, %159, %160) <{operandSegmentSizes = array, operand_constraints = [#any_device, #any_device, #any_device, #any_device, #any_device, #any_device]}> : (tensor<1x12x3200xf32>, tensor<1x12x3200xf32>, tensor<1x12x3200xf32>) -> tensor<1x12x3200xf32> loc(#loc2171) + %161 = "ttir.add"(%127, %159, %160) <{operandSegmentSizes = array}> : (tensor<1x12x3200xf32>, tensor<1x12x3200xf32>, tensor<1x12x3200xf32>) -> tensor<1x12x3200xf32> loc(#loc2171) return %161 : tensor<1x12x3200xf32> loc(#loc2090) } loc(#loc) } loc(#loc)