diff --git a/lib/Dialect/TTNN/Analysis/L1InterleavedPolicy.cpp b/lib/Dialect/TTNN/Analysis/L1InterleavedPolicy.cpp
index 69a07af168..f5723d1f7d 100644
--- a/lib/Dialect/TTNN/Analysis/L1InterleavedPolicy.cpp
+++ b/lib/Dialect/TTNN/Analysis/L1InterleavedPolicy.cpp
@@ -8,24 +8,14 @@
 namespace mlir::tt::ttnn {
 
-uint64_t getOpOutputL1Usage(Operation *op, TTNNLayoutAttr opLayout,
-                            DeviceAttr &deviceAttr) {
+uint64_t getOpOutputL1Usage(TTNNLayoutAttr opLayout) {
   // In case the opLayout is not in L1 memory space, L1 memory usage is 0.
   //
   if (opLayout.hasDRAMBufferType()) {
     return 0;
   }
 
-  // L1 memory usage of the ops without output tensors cannot be calculated.
-  // So far, this is only false for ttnn.get_device op.
-  //
-  assert(mlir::isa<RankedTensorType>(op->getResult(0).getType()));
-  llvm::ArrayRef<int64_t> opOutputTensorShape =
-      mlir::cast<RankedTensorType>(op->getResult(0).getType()).getShape();
-
-  uint64_t opL1OutputUsage =
-      opLayout.getTensorSizeInBytes(opOutputTensorShape, deviceAttr);
-  return opL1OutputUsage;
+  return opLayout.getShardSizeInBytes();
 }
 
 L1InterleavedPolicy::OpConfig L1InterleavedPolicy::getGreedyConfig(
@@ -149,7 +139,6 @@ L1InterleavedPolicy::OpConfig L1InterleavedPolicy::getGreedyConfig(
 void L1InterleavedPolicy::run() {
   for (Operation &funcOp : rootOp->getRegion(0).getOps()) {
     func::FuncOp func = dyn_cast<func::FuncOp>(funcOp);
-    DeviceAttr deviceAttr = getCurrentScopeDevice(func);
 
     // Start the policy.
     //
@@ -186,7 +175,7 @@ void L1InterleavedPolicy::run() {
       if (op->hasOneUse() && hasL1BufferType(op)) {
         L1Usage l1Usage;
         l1Usage.outputL1Usage =
-            getOpOutputL1Usage(op, getL1InterleavedLayout(op), deviceAttr);
+            getOpOutputL1Usage(getL1InterleavedLayout(op));
         l1Usage.requiredL1Usage = 0;
         opsL1Usage[op] = l1Usage;
       }
@@ -211,8 +200,7 @@ void L1InterleavedPolicy::run() {
         //
         if (operandOpLayout.hasInterleavedL1TensorMemoryLayout()) {
           L1Usage l1Usage;
-          l1Usage.outputL1Usage =
-              getOpOutputL1Usage(operandOp, operandOpLayout, deviceAttr);
+          l1Usage.outputL1Usage = getOpOutputL1Usage(operandOpLayout);
           l1Usage.requiredL1Usage = OpMemSpecMap[operandOp].requiredL1Usage;
           opsL1Usage[operandOp] = l1Usage;
         }
@@ -271,14 +259,14 @@ void L1InterleavedPolicy::run() {
               std::max(intermediateRequiredL1Usage,
                        intermediateL1Usage +
                            OpMemSpecMap[operandOp].requiredL1Usage);
-          intermediateL1Usage += getOpOutputL1Usage(
-              operandOp, OpMemSpecMap[operandOp].layout, deviceAttr);
+          intermediateL1Usage +=
+              getOpOutputL1Usage(OpMemSpecMap[operandOp].layout);
         }
       }
-      OpMemSpecMap[op].requiredL1Usage = std::max(
-          intermediateRequiredL1Usage,
-          intermediateL1Usage +
-              getOpOutputL1Usage(op, OpMemSpecMap[op].layout, deviceAttr));
+      OpMemSpecMap[op].requiredL1Usage =
+          std::max(intermediateRequiredL1Usage,
+                   intermediateL1Usage +
+                       getOpOutputL1Usage(OpMemSpecMap[op].layout));
     }
   }
 }
diff --git a/lib/Dialect/TTNN/Analysis/LegalLayoutAnalysis.cpp b/lib/Dialect/TTNN/Analysis/LegalLayoutAnalysis.cpp
index 3f4ef25ab2..432a3f959f 100644
--- a/lib/Dialect/TTNN/Analysis/LegalLayoutAnalysis.cpp
+++ b/lib/Dialect/TTNN/Analysis/LegalLayoutAnalysis.cpp
@@ -228,16 +228,20 @@ void LegalLayoutAnalysis::analysisImplementation() {
       TensorMemoryLayoutAttr::get(op->getContext(),
                                   TensorMemoryLayout::Interleaved)));
 
-  // L1 Interleaved (same as above).
-  analysisResult.push_back(TTNNLayoutAttr::get(
-      op->getContext(), tensorShape, elementType, BufferType::L1,
-      analysisInput.maxGrid,
-      TensorMemoryLayoutAttr::get(op->getContext(),
-                                  TensorMemoryLayout::Interleaved)));
+  // L1 Interleaved - It must be tiled.
+  if (elementType == tileElementType) {
+    analysisResult.push_back(TTNNLayoutAttr::get(
+        op->getContext(), tensorShape, elementType, BufferType::L1,
+        analysisInput.maxGrid,
+        TensorMemoryLayoutAttr::get(op->getContext(),
+                                    TensorMemoryLayout::Interleaved)));
+  }
 
   // L1 Sharded
   TTNNLayoutAttr shardedBase =
       layout.withBufferType(op->getContext(), BufferType::L1)
+          .withMemoryLayout(op->getContext(),
+                            TensorMemoryLayout::BlockSharded)
           .withElementType(op->getContext(), elementType);
 
   assert(analysisInput.maxGrid.getShape().size() == 2 &&
@@ -246,12 +250,9 @@ void LegalLayoutAnalysis::analysisImplementation() {
   for (int width = 1; width <= analysisInput.maxGrid.getShape()[0]; ++width) {
     for (int height = 1; height <= analysisInput.maxGrid.getShape()[1];
          ++height) {
-      shardedResults.push_back(
-          shardedBase
-              .withGrid(op->getContext(), tensorType,
-                        GridAttr::get(op->getContext(), {width, height}))
-              .withMemoryLayout(op->getContext(),
-                                TensorMemoryLayout::BlockSharded));
+      shardedResults.push_back(shardedBase.withGrid(
+          op->getContext(), tensorType,
+          GridAttr::get(op->getContext(), {width, height})));
     }
   }
diff --git a/lib/Dialect/TTNN/IR/TTNNOpsAttrs.cpp b/lib/Dialect/TTNN/IR/TTNNOpsAttrs.cpp
index d2b8b9c49d..a217dd4f05 100644
--- a/lib/Dialect/TTNN/IR/TTNNOpsAttrs.cpp
+++ b/lib/Dialect/TTNN/IR/TTNNOpsAttrs.cpp
@@ -494,15 +494,24 @@ TTNNLayoutAttr TTNNLayoutAttr::get(
     Type elementType, BufferType bufferType, GridAttr grid,
     TensorMemoryLayoutAttr memLayoutAttr,
     ArrayRef<std::pair<std::int64_t, std::int64_t>> collapseIntervals) {
+
   // Construct a new affine map which will be used to map from logical
-  // space to physical space
+  // space to physical space.
   AffineMap linear = collapsedLinearAffineMap(
       context, tensorShape, grid.getShape(), collapseIntervals);
-  // Calculate shard shape by evaluating the linear map with last element
-  // of the tensor shape and dividing it by the grid shape
-  mlir::SmallVector<int64_t> shardShape =
-      TTNNLayoutAttr::calculateLogicalShardShapeForSharding(tensorShape, linear,
-                                                            grid);
+
+  // Calculate shard shape by evaluating the linear map.
+  //
+  mlir::SmallVector<int64_t> shardShape;
+  if (bufferType == BufferType::L1 &&
+      memLayoutAttr.getValue() == TensorMemoryLayout::Interleaved) {
+    shardShape = TTNNLayoutAttr::calculateLogicalShardShapeForInterleaved(
+        tensorShape, elementType, linear, grid);
+  } else {
+    shardShape = TTNNLayoutAttr::calculateLogicalShardShapeForSharding(
+        tensorShape, linear, grid);
+  }
+
   // Build memref type with the given parameters
   MemRefType memRefType = buildMemRef(
       context, shardShape, elementType, bufferType);
diff --git a/test/ttmlir/Dialect/TTNN/optimizer/input_layout_loc_override.mlir b/test/ttmlir/Dialect/TTNN/optimizer/input_layout_loc_override.mlir
index 97892500aa..16b0eb1b53 100644
--- a/test/ttmlir/Dialect/TTNN/optimizer/input_layout_loc_override.mlir
+++ b/test/ttmlir/Dialect/TTNN/optimizer/input_layout_loc_override.mlir
@@ -3,13 +3,13 @@
 // CHECK-DAG: #[[LOC_MATMUL_IN0:.*]] = loc("matmul_1_in_0_layout"(#loc3))
 // CHECK-DAG: #[[LOC_MATMUL_IN1:.*]] = loc("matmul_1_in_1_layout"(#loc3))
 // CHECK-DAG: #[[LOC_MATMUL:.*]] = loc("matmul_1"(#loc3))
-// CHECK-DAG: #[[IN_1_LAYOUT:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<4x3x!tt.tile<32x32, bf16>, #l1_>, <interleaved>>
+// CHECK-DAG: #[[IN_1_LAYOUT:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<1x12x!tt.tile<32x32, bf16>, #l1_>, <interleaved>>
 module attributes {} {
   func.func @forward(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x96xbf16>) -> tensor<64x96xbf16> {
     %0 = tensor.empty() : tensor<64x96xbf16> loc(#loc2)
     // CHECK-DAG: %{{.*}} = "ttnn.to_device"{{.*}} loc(#[[LOC_MATMUL_IN0]])
-    // CHECK-DAG: %{{.*}} = "ttnn.to_device"{{.*}} <{memory_config = #ttnn.memory_config<#l1_, <<4x3>>, <interleaved>>}> : {{.*}} -> tensor<128x96xbf16, #[[IN_1_LAYOUT]]> loc(#[[LOC_MATMUL_IN1]])
+    // CHECK-DAG: %{{.*}} = "ttnn.to_device"{{.*}} <{memory_config = #ttnn.memory_config<#l1_, <<1x12>>, <interleaved>>}> : {{.*}} -> tensor<128x96xbf16, #[[IN_1_LAYOUT]]> loc(#[[LOC_MATMUL_IN1]])
     // CHECK-DAG: %{{.*}} = "ttnn.matmul"{{.*}} loc(#[[LOC_MATMUL]])
     %1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<128x96xbf16>, tensor<64x96xbf16>) -> tensor<64x96xbf16> loc(#loc2)
     return %1 : tensor<64x96xbf16>
diff --git a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_AB_l1_C.mlir b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_AB_l1_C.mlir
index 056ded8d35..a6a8af54cc 100644
--- a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_AB_l1_C.mlir
+++ b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_AB_l1_C.mlir
@@ -15,7 +15,7 @@ module attributes {} {
     // CHECK: #[[L1_:.*]] = #ttnn.buffer_type
     // CHECK-DAG: #[[LAYOUT_4:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<20x16x!tt.tile<32x32, bf16>, #dram>, <interleaved>>
     // CHECK-DAG: #[[LAYOUT_6:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<16x20x!tt.tile<32x32, bf16>, #dram>, <interleaved>>
-    // CHECK-DAG: #[[LAYOUT_7:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<20x20x!tt.tile<32x32, bf16>, #l1_>, <interleaved>>
+    // CHECK-DAG: #[[LAYOUT_7:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<1x400x!tt.tile<32x32, bf16>, #l1_>, <interleaved>>
     %0 = tensor.empty() : tensor<5120x4096xbf16>
     // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<5120x4096xbf16, #[[LAYOUT_4]]>
     %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<5120x4096xbf16>, tensor<5120x4096xbf16>, tensor<5120x4096xbf16>) -> tensor<5120x4096xbf16>
diff --git a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_AC_l1_B.mlir b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_AC_l1_B.mlir
index caaf3254d8..fc39fb5d60 100644
--- a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_AC_l1_B.mlir
+++ b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_AC_l1_B.mlir
@@ -14,7 +14,7 @@ module attributes {} {
   func.func @forward(%arg0: tensor<4096x5120xbf16>, %arg1: tensor<4096x5120xbf16>, %arg2: tensor<5120x5120xbf16>, %arg3: tensor<5120x5120xbf16>) -> tensor<4096x5120xbf16> {
     // CHECK: #[[L1_:.*]] = #ttnn.buffer_type
     // CHECK-DAG: #[[LAYOUT_3:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<16x20x!tt.tile<32x32, bf16>, #dram>, <interleaved>>
-    // CHECK-DAG: #[[LAYOUT_5:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<20x20x!tt.tile<32x32, bf16>, #l1_>, <interleaved>>
+    // CHECK-DAG: #[[LAYOUT_5:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<1x400x!tt.tile<32x32, bf16>, #l1_>, <interleaved>>
     %0 = tensor.empty() : tensor<4096x5120xbf16>
     // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<4096x5120xbf16, #[[LAYOUT_3]]>
     %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<4096x5120xbf16>, tensor<4096x5120xbf16>, tensor<4096x5120xbf16>) -> tensor<4096x5120xbf16>
diff --git a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_A_l1_BC.mlir b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_A_l1_BC.mlir
index 63cd3bcaa2..270bd4a122 100644
--- a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_A_l1_BC.mlir
+++ b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_A_l1_BC.mlir
@@ -14,7 +14,7 @@ module attributes {} {
   func.func @forward(%arg0: tensor<2048x2048xbf16>, %arg1: tensor<2048x2048xbf16>, %arg2: tensor<2048x8192xbf16>, %arg3: tensor<2048x8192xbf16>) -> tensor<2048x8192xbf16> {
     // CHECK: #[[L1_:.*]] = #ttnn.buffer_type
     // CHECK-DAG: #[[LAYOUT_3:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<8x8x!tt.tile<32x32, bf16>, #dram>, <interleaved>>
-    // CHECK-DAG: #[[LAYOUT_5:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<8x32x!tt.tile<32x32, bf16>, #l1_>, <interleaved>>
+    // CHECK-DAG: #[[LAYOUT_5:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<1x256x!tt.tile<32x32, bf16>, #l1_>, <interleaved>>
     %0 = tensor.empty() : tensor<2048x2048xbf16>
     // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<2048x2048xbf16, #[[LAYOUT_3]]>
     %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<2048x2048xbf16>, tensor<2048x2048xbf16>, tensor<2048x2048xbf16>) -> tensor<2048x2048xbf16>
diff --git a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_BC_l1_A.mlir b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_BC_l1_A.mlir
index 9f12e8b6f6..e37883a313 100644
--- a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_BC_l1_A.mlir
+++ b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_BC_l1_A.mlir
@@ -14,7 +14,7 @@ module attributes {} {
   func.func @forward(%arg0: tensor<5120x5120xbf16>, %arg1: tensor<5120x5120xbf16>, %arg2: tensor<5120x4096xbf16>, %arg3: tensor<5120x4096xbf16>) -> tensor<5120x4096xbf16> {
     // CHECK: #[[L1_:.*]] = #ttnn.buffer_type
     // CHECK-DAG: #[[LAYOUT_3:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <{{.*}}>, memref<20x16x!tt.tile<32x32, bf16>, #dram>, <interleaved>>
-    // CHECK-DAG: #[[LAYOUT_5:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <{{.*}}>, memref<20x20x!tt.tile<32x32, bf16>, #l1_>, <interleaved>>
+    // CHECK-DAG: #[[LAYOUT_5:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <{{.*}}>, memref<1x400x!tt.tile<32x32, bf16>, #l1_>, <interleaved>>
     %0 = tensor.empty() : tensor<5120x5120xbf16>
     // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<5120x5120xbf16, #[[LAYOUT_5]]>
     %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<5120x5120xbf16>, tensor<5120x5120xbf16>, tensor<5120x5120xbf16>) -> tensor<5120x5120xbf16>
diff --git a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_B_l1_AC.mlir b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_B_l1_AC.mlir
index c594ca4182..2f3df7293b 100644
--- a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_B_l1_AC.mlir
+++ b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_B_l1_AC.mlir
@@ -14,7 +14,7 @@ module attributes {} {
   func.func @forward(%arg0: tensor<8192x2048xbf16>, %arg1: tensor<8192x2048xbf16>, %arg2: tensor<2048x2048xbf16>, %arg3: tensor<2048x2048xbf16>) -> tensor<8192x2048xbf16> {
     // CHECK: #[[L1_:.*]] = #ttnn.buffer_type
     // CHECK-DAG: #[[LAYOUT_3:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<8x8x!tt.tile<32x32, bf16>, #dram>, <interleaved>>
-    // CHECK-DAG: #[[LAYOUT_5:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<32x8x!tt.tile<32x32, bf16>, #l1_>, <interleaved>>
+    // CHECK-DAG: #[[LAYOUT_5:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<1x256x!tt.tile<32x32, bf16>, #l1_>, <interleaved>>
     %0 = tensor.empty() : tensor<8192x2048xbf16>
     // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<8192x2048xbf16, #[[LAYOUT_5]]>
     %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<8192x2048xbf16>, tensor<8192x2048xbf16>, tensor<8192x2048xbf16>) -> tensor<8192x2048xbf16>
diff --git a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_C_l1_AB.mlir b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_C_l1_AB.mlir
index eb2a51b174..77bda95204 100644
--- a/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_C_l1_AB.mlir
+++ b/test/ttmlir/Dialect/TTNN/optimizer/l1_interleaved_policy/simple_join_tests/dram_C_l1_AB.mlir
@@ -13,17 +13,16 @@ module attributes {} {
   func.func @forward(%arg0: tensor<2048x8192xbf16>, %arg1: tensor<2048x8192xbf16>, %arg2: tensor<8192x2048xbf16>, %arg3: tensor<8192x2048xbf16>) -> tensor<2048x2048xbf16> {
     // CHECK: #[[L1_:.*]] = #ttnn.buffer_type
-    // CHECK-DAG: #[[LAYOUT_4:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<8x32x!tt.tile<32x32, bf16>, #l1_>, <interleaved>>
-    // CHECK-DAG: #[[LAYOUT_6:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<32x8x!tt.tile<32x32, bf16>, #l1_>, <interleaved>>
-    // CHECK-DAG: #[[LAYOUT_7:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<8x8x!tt.tile<32x32, bf16>, #dram>, <interleaved>>
+    // CHECK-DAG: #[[LAYOUT_4:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<1x256x!tt.tile<32x32, bf16>, #l1_>, <interleaved>>
+    // CHECK-DAG: #[[LAYOUT_6:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<8x8x!tt.tile<32x32, bf16>, #dram>, <interleaved>>
     %0 = tensor.empty() : tensor<2048x8192xbf16>
     // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<2048x8192xbf16, #[[LAYOUT_4]]>
     %1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array}> : (tensor<2048x8192xbf16>, tensor<2048x8192xbf16>, tensor<2048x8192xbf16>) -> tensor<2048x8192xbf16>
     %2 = tensor.empty() : tensor<8192x2048xbf16>
-    // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<8192x2048xbf16, #[[LAYOUT_6]]>
+    // CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<8192x2048xbf16, #[[LAYOUT_4]]>
     %3 = "ttir.add"(%arg2, %arg3, %2) <{operandSegmentSizes = array}> : (tensor<8192x2048xbf16>, tensor<8192x2048xbf16>, tensor<8192x2048xbf16>) -> tensor<8192x2048xbf16>
     %4 = tensor.empty() : tensor<2048x2048xbf16>
-    // CHECK-DAG: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<2048x2048xbf16, #[[LAYOUT_7]]>
+    // CHECK-DAG: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<2048x2048xbf16, #[[LAYOUT_6]]>
     %5 = "ttir.matmul"(%1, %3, %4) : (tensor<2048x8192xbf16>, tensor<8192x2048xbf16>, tensor<2048x2048xbf16>) -> tensor<2048x2048xbf16>
     return %5 : tensor<2048x2048xbf16>
   }
diff --git a/test/unittests/Optimizer/TestL1InterleavedPolicy.cpp b/test/unittests/Optimizer/TestL1InterleavedPolicy.cpp
index b09b65245d..9d026a3e48 100644
--- a/test/unittests/Optimizer/TestL1InterleavedPolicy.cpp
+++ b/test/unittests/Optimizer/TestL1InterleavedPolicy.cpp
@@ -13,6 +13,7 @@
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/MLIRContext.h"
 
+#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
 #include "ttmlir/Dialect/TTNN/IR/TTNN.h"
 #include "ttmlir/Dialect/TTNN/IR/TTNNOps.h"
@@ -88,14 +89,14 @@ class L1InterleavedPolicyBase : public ::testing::Test {
         TensorMemoryLayoutAttr::get(&context, tensorMemoryLayout);
     if (legalLayouts.find(op) == legalLayouts.end()) {
       legalLayouts[op] = std::vector{TTNNLayoutAttr::get(
-          &context, getTensorRankedType().getShape(), builder.getF32Type(),
-          memorySpace, mlir::tt::GridAttr::get(&context, {8, 8}),
-          tensorMemoryLayoutAttr)};
+          &context, getTensorRankedType().getShape(),
+          mlir::tt::TileType::get(&context, builder.getF32Type()), memorySpace,
+          mlir::tt::GridAttr::get(&context, {8, 8}), tensorMemoryLayoutAttr)};
     } else {
       legalLayouts[op].push_back(TTNNLayoutAttr::get(
-          &context, getTensorRankedType().getShape(), builder.getF32Type(),
-          memorySpace, mlir::tt::GridAttr::get(&context, {8, 8}),
-          tensorMemoryLayoutAttr));
+          &context, getTensorRankedType().getShape(),
+          mlir::tt::TileType::get(&context, builder.getF32Type()), memorySpace,
+          mlir::tt::GridAttr::get(&context, {8, 8}), tensorMemoryLayoutAttr));
     }
   }
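A minimal standalone sketch of the per-core arithmetic the new interleaved shard-shape path appears to perform, inferred only from the updated CHECK expectations above; the helper name, the ceiling divisions, and the fixed 32x32 tile are illustrative assumptions, not the actual calculateLogicalShardShapeForInterleaved implementation:

#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper: for an L1 interleaved layout, every 32x32 tile of the
// tensor is distributed round-robin across the core grid, so each core holds
// a 1 x N run of tiles where N = ceil(totalTiles / numCores).
std::vector<int64_t> interleavedShardShapeInTiles(
    const std::vector<int64_t> &tensorShape, // logical shape, e.g. {5120, 5120}
    const std::vector<int64_t> &gridShape) { // core grid, e.g. {8, 8}
  constexpr int64_t kTileDim = 32;
  int64_t totalTiles = 1;
  for (int64_t dim : tensorShape) {
    totalTiles *= (dim + kTileDim - 1) / kTileDim; // tiles along each dim
  }
  int64_t numCores = 1;
  for (int64_t dim : gridShape) {
    numCores *= dim;
  }
  int64_t tilesPerCore = (totalTiles + numCores - 1) / numCores;
  return {1, tilesPerCore};
}

int main() {
  // Matches the updated CHECK lines: 5120x5120 bf16 on an 8x8 grid is
  // 160x160 = 25600 tiles, i.e. 400 tiles per core -> memref<1x400x!tt.tile<...>>.
  assert((interleavedShardShapeInTiles({5120, 5120}, {8, 8}) ==
          std::vector<int64_t>{1, 400}));
  // 128x96 bf16 on a 1x1 grid is 4x3 = 12 tiles -> memref<1x12x!tt.tile<...>>.
  assert((interleavedShardShapeInTiles({128, 96}, {1, 1}) ==
          std::vector<int64_t>{1, 12}));
  return 0;
}

Under this model, getShardSizeInBytes() would reduce to the per-core tile count multiplied by the tile footprint in bytes, which lines up with L1InterleavedPolicy no longer needing the output tensor shape or a DeviceAttr to estimate L1 usage.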