Added calculation of logical shard shape for L1 Interleaved tensors
fbajraktariTT committed Dec 16, 2024
1 parent d4d33fe commit 3a99432
Showing 11 changed files with 56 additions and 58 deletions.
32 changes: 10 additions & 22 deletions lib/Dialect/TTNN/Analysis/L1InterleavedPolicy.cpp
@@ -8,24 +8,14 @@

namespace mlir::tt::ttnn {

uint64_t getOpOutputL1Usage(Operation *op, TTNNLayoutAttr opLayout,
DeviceAttr &deviceAttr) {
uint64_t getOpOutputL1Usage(TTNNLayoutAttr opLayout) {
// In case the opLayout is not in L1 memory space, L1 memory usage is 0.
//
if (opLayout.hasDRAMBufferType()) {
return 0;
}

// L1 memory usage of the ops without output tensors cannot be calculated.
// So far, this is only false for ttnn.get_device op.
//
assert(mlir::isa<RankedTensorType>(op->getResult(0).getType()));
llvm::ArrayRef<int64_t> opOutputTensorShape =
mlir::cast<RankedTensorType>(op->getResult(0).getType()).getShape();

uint64_t opL1OutputUsage =
opLayout.getTensorSizeInBytes(opOutputTensorShape, deviceAttr);
return opL1OutputUsage;
return opLayout.getShardSizeInBytes();
}
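With the logical shard shape now computed per layout, the helper no longer needs the op or a DeviceAttr: for an L1 interleaved tensor the per-core allocation is exactly one shard, so its byte size is the op's L1 footprint. The sketch below is only an approximation of what TTNNLayoutAttr::getShardSizeInBytes presumably computes (tile count times tile size); the helper name and the omission of alignment and padding details are assumptions, not the actual implementation.

#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Illustrative approximation: bytes held by one core for a tiled shard shape.
uint64_t approxShardSizeInBytes(const std::vector<int64_t> &shardShapeInTiles,
                                uint64_t tileSizeInBytes) {
  uint64_t numTiles = std::accumulate(shardShapeInTiles.begin(),
                                      shardShapeInTiles.end(), uint64_t{1},
                                      std::multiplies<uint64_t>());
  // e.g. an interleaved shard of <1x400> bf16 32x32 tiles:
  // 400 * 2048 bytes = 819200 bytes of L1 per core.
  return numTiles * tileSizeInBytes;
}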

L1InterleavedPolicy::OpConfig L1InterleavedPolicy::getGreedyConfig(
@@ -149,7 +139,6 @@ L1InterleavedPolicy::OpConfig L1InterleavedPolicy::getGreedyConfig(
void L1InterleavedPolicy::run() {
for (Operation &funcOp : rootOp->getRegion(0).getOps()) {
func::FuncOp func = dyn_cast<func::FuncOp>(funcOp);
DeviceAttr deviceAttr = getCurrentScopeDevice(func);

// Start the policy.
//
Expand Down Expand Up @@ -186,7 +175,7 @@ void L1InterleavedPolicy::run() {
if (op->hasOneUse() && hasL1BufferType(op)) {
L1Usage l1Usage;
l1Usage.outputL1Usage =
getOpOutputL1Usage(op, getL1InterleavedLayout(op), deviceAttr);
getOpOutputL1Usage(getL1InterleavedLayout(op));
l1Usage.requiredL1Usage = 0;
opsL1Usage[op] = l1Usage;
}
@@ -211,8 +200,7 @@
//
if (operandOpLayout.hasInterleavedL1TensorMemoryLayout()) {
L1Usage l1Usage;
l1Usage.outputL1Usage =
getOpOutputL1Usage(operandOp, operandOpLayout, deviceAttr);
l1Usage.outputL1Usage = getOpOutputL1Usage(operandOpLayout);
l1Usage.requiredL1Usage = OpMemSpecMap[operandOp].requiredL1Usage;
opsL1Usage[operandOp] = l1Usage;
}
@@ -271,14 +259,14 @@ void L1InterleavedPolicy::run() {
std::max(intermediateRequiredL1Usage,
intermediateL1Usage +
OpMemSpecMap[operandOp].requiredL1Usage);
intermediateL1Usage += getOpOutputL1Usage(
operandOp, OpMemSpecMap[operandOp].layout, deviceAttr);
intermediateL1Usage +=
getOpOutputL1Usage(OpMemSpecMap[operandOp].layout);
}
}
OpMemSpecMap[op].requiredL1Usage = std::max(
intermediateRequiredL1Usage,
intermediateL1Usage +
getOpOutputL1Usage(op, OpMemSpecMap[op].layout, deviceAttr));
OpMemSpecMap[op].requiredL1Usage =
std::max(intermediateRequiredL1Usage,
intermediateL1Usage +
getOpOutputL1Usage(OpMemSpecMap[op].layout));
}
}
}
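The accumulation in this hunk is easier to follow with concrete numbers; the values below are invented purely for illustration.

// Walk-through of the requiredL1Usage computation with made-up numbers:
//   operand A: outputL1Usage = 100, requiredL1Usage = 300
//   operand B: outputL1Usage = 200, requiredL1Usage = 250
// Visiting A then B (starting from intermediateRequiredL1Usage = 0,
// intermediateL1Usage = 0):
//   after A: intermediateRequiredL1Usage = max(0, 0 + 300) = 300,
//            intermediateL1Usage = 100
//   after B: intermediateRequiredL1Usage = max(300, 100 + 250) = 350,
//            intermediateL1Usage = 300
// With the op's own output usage at 150:
//   OpMemSpecMap[op].requiredL1Usage = max(350, 300 + 150) = 450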
25 changes: 13 additions & 12 deletions lib/Dialect/TTNN/Analysis/LegalLayoutAnalysis.cpp
@@ -228,16 +228,20 @@ void LegalLayoutAnalysis::analysisImplementation() {
TensorMemoryLayoutAttr::get(op->getContext(),
TensorMemoryLayout::Interleaved)));

// L1 Interleaved (same as above).
analysisResult.push_back(TTNNLayoutAttr::get(
op->getContext(), tensorShape, elementType, BufferType::L1,
analysisInput.maxGrid,
TensorMemoryLayoutAttr::get(op->getContext(),
TensorMemoryLayout::Interleaved)));
// L1 Interleaved - It must be tiled.
if (elementType == tileElementType) {
analysisResult.push_back(TTNNLayoutAttr::get(
op->getContext(), tensorShape, elementType, BufferType::L1,
analysisInput.maxGrid,
TensorMemoryLayoutAttr::get(op->getContext(),
TensorMemoryLayout::Interleaved)));
}
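Restricting the L1 interleaved candidate to tiled element types lines up with the new shard-shape calculation in TTNNOpsAttrs.cpp further down, which counts 32x32 tiles per core, and with the unit-test change that wraps f32 in tt.tile. An equivalent guard, written here only as an illustrative sketch, would be:

// if (mlir::isa<mlir::tt::TileType>(elementType)) {
//   ... push the L1 interleaved candidate ...
// }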

// L1 Sharded
TTNNLayoutAttr shardedBase =
layout.withBufferType(op->getContext(), BufferType::L1)
.withMemoryLayout(op->getContext(),
TensorMemoryLayout::BlockSharded)
.withElementType(op->getContext(), elementType);

assert(analysisInput.maxGrid.getShape().size() == 2 &&
@@ -246,12 +250,9 @@
for (int width = 1; width <= analysisInput.maxGrid.getShape()[0]; ++width) {
for (int height = 1; height <= analysisInput.maxGrid.getShape()[1];
++height) {
shardedResults.push_back(
shardedBase
.withGrid(op->getContext(), tensorType,
GridAttr::get(op->getContext(), {width, height}))
.withMemoryLayout(op->getContext(),
TensorMemoryLayout::BlockSharded));
shardedResults.push_back(shardedBase.withGrid(
op->getContext(), tensorType,
GridAttr::get(op->getContext(), {width, height})));
}
}

21 changes: 15 additions & 6 deletions lib/Dialect/TTNN/IR/TTNNOpsAttrs.cpp
@@ -494,15 +494,24 @@ TTNNLayoutAttr TTNNLayoutAttr::get(
Type elementType, BufferType bufferType, GridAttr grid,
TensorMemoryLayoutAttr memLayoutAttr,
ArrayRef<std::pair<std::int64_t, std::int64_t>> collapseIntervals) {

// Construct a new affine map which will be used to map from logical
// space to physical space
// space to physical space.
AffineMap linear = collapsedLinearAffineMap(
context, tensorShape, grid.getShape(), collapseIntervals);
// Calculate shard shape by evaluating the linear map with last element
// of the tensor shape and dividing it by the grid shape
mlir::SmallVector<int64_t, 4> shardShape =
TTNNLayoutAttr::calculateLogicalShardShapeForSharding(tensorShape, linear,
grid);

// Calculate shard shape by evaluating the linear map.
//
mlir::SmallVector<int64_t> shardShape;
if (bufferType == BufferType::L1 &&
memLayoutAttr.getValue() == TensorMemoryLayout::Interleaved) {
shardShape = TTNNLayoutAttr::calculateLogicalShardShapeForInterleaved(
tensorShape, elementType, linear, grid);
} else {
shardShape = TTNNLayoutAttr::calculateLogicalShardShapeForSharding(
tensorShape, linear, grid);
}

// Build memref type with the given parameters
MemRefType memRefType = buildMemRef<BufferType, BufferTypeAttr>(
context, shardShape, elementType, bufferType);
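The new interleaved branch is the heart of this commit. Below is a hedged sketch of the tile arithmetic it implies; the real calculateLogicalShardShapeForInterleaved also takes the element type and the linear affine map into account, and how the result is scaled into the memref (tiles versus elements) is an implementation detail not reproduced here. The function name and signature are illustrative only.

#include <cstdint>
#include <vector>

// Sketch: per-core logical shard shape, in tiles, of an L1 interleaved tensor.
std::vector<int64_t> interleavedShardShapeInTiles(
    const std::vector<int64_t> &tensorShape, // e.g. {5120, 5120}
    int64_t tileDim,                         // 32 for tt.tile<32x32, ...>
    int64_t numCores) {                      // e.g. 8 * 8 = 64
  int64_t totalTiles = 1;
  for (int64_t dim : tensorShape)
    totalTiles *= (dim + tileDim - 1) / tileDim; // ceil-divide into tiles
  // Interleaved tensors are spread tile-by-tile across the grid, so each core
  // holds at most ceil(totalTiles / numCores) tiles, modeled as one row.
  int64_t tilesPerCore = (totalTiles + numCores - 1) / numCores;
  return {1, tilesPerCore};
}
// {5120, 5120} -> 160 * 160 = 25600 tiles, / 64 cores -> <1x400>
// {128, 96}    -> 4 * 3 = 12 tiles, on a 1x1 grid     -> <1x12>

This matches the expected layouts in the updated lit tests below (memref<1x400x...>, memref<1x12x...>), which previously used the sharded-style per-core division (memref<20x20x...>, memref<4x3x...>).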
@@ -3,13 +3,13 @@
// CHECK-DAG: #[[LOC_MATMUL_IN0:.*]] = loc("matmul_1_in_0_layout"(#loc3))
// CHECK-DAG: #[[LOC_MATMUL_IN1:.*]] = loc("matmul_1_in_1_layout"(#loc3))
// CHECK-DAG: #[[LOC_MATMUL:.*]] = loc("matmul_1"(#loc3))
// CHECK-DAG: #[[IN_1_LAYOUT:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<4x3x!tt.tile<32x32, bf16>, #l1_>, <interleaved>>
// CHECK-DAG: #[[IN_1_LAYOUT:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <1x1>, memref<1x12x!tt.tile<32x32, bf16>, #l1_>, <interleaved>>

module attributes {} {
func.func @forward(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x96xbf16>) -> tensor<64x96xbf16> {
%0 = tensor.empty() : tensor<64x96xbf16> loc(#loc2)
// CHECK-DAG: %{{.*}} = "ttnn.to_device"{{.*}} loc(#[[LOC_MATMUL_IN0]])
// CHECK-DAG: %{{.*}} = "ttnn.to_device"{{.*}} <{memory_config = #ttnn.memory_config<#l1_, <<4x3>>, <interleaved>>}> : {{.*}} -> tensor<128x96xbf16, #[[IN_1_LAYOUT]]> loc(#[[LOC_MATMUL_IN1]])
// CHECK-DAG: %{{.*}} = "ttnn.to_device"{{.*}} <{memory_config = #ttnn.memory_config<#l1_, <<1x12>>, <interleaved>>}> : {{.*}} -> tensor<128x96xbf16, #[[IN_1_LAYOUT]]> loc(#[[LOC_MATMUL_IN1]])
// CHECK-DAG: %{{.*}} = "ttnn.matmul"{{.*}} loc(#[[LOC_MATMUL]])
%1 = "ttir.matmul"(%arg0, %arg1, %0) : (tensor<64x128xbf16>, tensor<128x96xbf16>, tensor<64x96xbf16>) -> tensor<64x96xbf16> loc(#loc2)
return %1 : tensor<64x96xbf16>
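The updated CHECK lines in this matmul test follow directly from the interleaved calculation sketched above; a short worked derivation, in the test's own comment style and under that assumption:

// tensor<128x96xbf16> = (128/32) x (96/32) = 4 x 3 = 12 tiles of 32x32
// interleaved over a 1x1 grid -> all 12 tiles on the single core
// -> memref<1x12x!tt.tile<32x32, bf16>, #l1_> instead of the former sharded-style 4x3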
@@ -15,7 +15,7 @@ module attributes {} {
// CHECK: #[[L1_:.*]] = #ttnn.buffer_type<l1>
// CHECK-DAG: #[[LAYOUT_4:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<20x16x!tt.tile<32x32, bf16>, #dram>, <interleaved>>
// CHECK-DAG: #[[LAYOUT_6:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<16x20x!tt.tile<32x32, bf16>, #dram>, <interleaved>>
// CHECK-DAG: #[[LAYOUT_7:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<20x20x!tt.tile<32x32, bf16>, #l1_>, <interleaved>>
// CHECK-DAG: #[[LAYOUT_7:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<1x400x!tt.tile<32x32, bf16>, #l1_>, <interleaved>>
%0 = tensor.empty() : tensor<5120x4096xbf16>
// CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<5120x4096xbf16, #[[LAYOUT_4]]>
%1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<5120x4096xbf16>, tensor<5120x4096xbf16>, tensor<5120x4096xbf16>) -> tensor<5120x4096xbf16>
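The same arithmetic explains the 8x8-grid tests in this and the following files, where the L1 interleaved layouts changed from 20x20, 8x32, or 32x8 tiles per core to 1x400 or 1x256:

// old sharded-style shard of 20x20 tiles per core on an 8x8 grid
//   -> 64 cores * 400 tiles = 25600 tiles in total
// new interleaved shard: ceil(25600 / 64) = 400 tiles per core
//   -> memref<1x400x!tt.tile<32x32, bf16>, #l1_>
// likewise 8x32 (or 32x8) = 256 tiles per core keeps its count, but as <1x256>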
@@ -14,7 +14,7 @@ module attributes {} {
func.func @forward(%arg0: tensor<4096x5120xbf16>, %arg1: tensor<4096x5120xbf16>, %arg2: tensor<5120x5120xbf16>, %arg3: tensor<5120x5120xbf16>) -> tensor<4096x5120xbf16> {
// CHECK: #[[L1_:.*]] = #ttnn.buffer_type<l1>
// CHECK-DAG: #[[LAYOUT_3:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<16x20x!tt.tile<32x32, bf16>, #dram>, <interleaved>>
// CHECK-DAG: #[[LAYOUT_5:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<20x20x!tt.tile<32x32, bf16>, #l1_>, <interleaved>>
// CHECK-DAG: #[[LAYOUT_5:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<1x400x!tt.tile<32x32, bf16>, #l1_>, <interleaved>>
%0 = tensor.empty() : tensor<4096x5120xbf16>
// CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<4096x5120xbf16, #[[LAYOUT_3]]>
%1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<4096x5120xbf16>, tensor<4096x5120xbf16>, tensor<4096x5120xbf16>) -> tensor<4096x5120xbf16>
@@ -14,7 +14,7 @@ module attributes {} {
func.func @forward(%arg0: tensor<2048x2048xbf16>, %arg1: tensor<2048x2048xbf16>, %arg2: tensor<2048x8192xbf16>, %arg3: tensor<2048x8192xbf16>) -> tensor<2048x8192xbf16> {
// CHECK: #[[L1_:.*]] = #ttnn.buffer_type<l1>
// CHECK-DAG: #[[LAYOUT_3:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<8x8x!tt.tile<32x32, bf16>, #dram>, <interleaved>>
// CHECK-DAG: #[[LAYOUT_5:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<8x32x!tt.tile<32x32, bf16>, #l1_>, <interleaved>>
// CHECK-DAG: #[[LAYOUT_5:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<1x256x!tt.tile<32x32, bf16>, #l1_>, <interleaved>>
%0 = tensor.empty() : tensor<2048x2048xbf16>
// CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<2048x2048xbf16, #[[LAYOUT_3]]>
%1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<2048x2048xbf16>, tensor<2048x2048xbf16>, tensor<2048x2048xbf16>) -> tensor<2048x2048xbf16>
@@ -14,7 +14,7 @@ module attributes {} {
func.func @forward(%arg0: tensor<5120x5120xbf16>, %arg1: tensor<5120x5120xbf16>, %arg2: tensor<5120x4096xbf16>, %arg3: tensor<5120x4096xbf16>) -> tensor<5120x4096xbf16> {
// CHECK: #[[L1_:.*]] = #ttnn.buffer_type<l1>
// CHECK-DAG: #[[LAYOUT_3:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <{{.*}}>, memref<20x16x!tt.tile<32x32, bf16>, #dram>, <interleaved>>
// CHECK-DAG: #[[LAYOUT_5:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <{{.*}}>, memref<20x20x!tt.tile<32x32, bf16>, #l1_>, <interleaved>>
// CHECK-DAG: #[[LAYOUT_5:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <{{.*}}>, memref<1x400x!tt.tile<32x32, bf16>, #l1_>, <interleaved>>
%0 = tensor.empty() : tensor<5120x5120xbf16>
// CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<5120x5120xbf16, #[[LAYOUT_5]]>
%1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<5120x5120xbf16>, tensor<5120x5120xbf16>, tensor<5120x5120xbf16>) -> tensor<5120x5120xbf16>
@@ -14,7 +14,7 @@ module attributes {} {
func.func @forward(%arg0: tensor<8192x2048xbf16>, %arg1: tensor<8192x2048xbf16>, %arg2: tensor<2048x2048xbf16>, %arg3: tensor<2048x2048xbf16>) -> tensor<8192x2048xbf16> {
// CHECK: #[[L1_:.*]] = #ttnn.buffer_type<l1>
// CHECK-DAG: #[[LAYOUT_3:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<8x8x!tt.tile<32x32, bf16>, #dram>, <interleaved>>
// CHECK-DAG: #[[LAYOUT_5:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<32x8x!tt.tile<32x32, bf16>, #l1_>, <interleaved>>
// CHECK-DAG: #[[LAYOUT_5:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<1x256x!tt.tile<32x32, bf16>, #l1_>, <interleaved>>
%0 = tensor.empty() : tensor<8192x2048xbf16>
// CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<8192x2048xbf16, #[[LAYOUT_5]]>
%1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<8192x2048xbf16>, tensor<8192x2048xbf16>, tensor<8192x2048xbf16>) -> tensor<8192x2048xbf16>
@@ -13,17 +13,16 @@
module attributes {} {
func.func @forward(%arg0: tensor<2048x8192xbf16>, %arg1: tensor<2048x8192xbf16>, %arg2: tensor<8192x2048xbf16>, %arg3: tensor<8192x2048xbf16>) -> tensor<2048x2048xbf16> {
// CHECK: #[[L1_:.*]] = #ttnn.buffer_type<l1>
// CHECK-DAG: #[[LAYOUT_4:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<8x32x!tt.tile<32x32, bf16>, #l1_>, <interleaved>>
// CHECK-DAG: #[[LAYOUT_6:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<32x8x!tt.tile<32x32, bf16>, #l1_>, <interleaved>>
// CHECK-DAG: #[[LAYOUT_7:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<8x8x!tt.tile<32x32, bf16>, #dram>, <interleaved>>
// CHECK-DAG: #[[LAYOUT_4:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<1x256x!tt.tile<32x32, bf16>, #l1_>, <interleaved>>
// CHECK-DAG: #[[LAYOUT_6:.*]] = #ttnn.ttnn_layout<(d0, d1) -> (d0, d1), <8x8, (d0, d1) -> (0, d0, d1)>, memref<8x8x!tt.tile<32x32, bf16>, #dram>, <interleaved>>
%0 = tensor.empty() : tensor<2048x8192xbf16>
// CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<2048x8192xbf16, #[[LAYOUT_4]]>
%1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<2048x8192xbf16>, tensor<2048x8192xbf16>, tensor<2048x8192xbf16>) -> tensor<2048x8192xbf16>
%2 = tensor.empty() : tensor<8192x2048xbf16>
// CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<8192x2048xbf16, #[[LAYOUT_6]]>
// CHECK-DAG: %{{.*}} = "ttnn.add"{{.*}} -> tensor<8192x2048xbf16, #[[LAYOUT_4]]>
%3 = "ttir.add"(%arg2, %arg3, %2) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<8192x2048xbf16>, tensor<8192x2048xbf16>, tensor<8192x2048xbf16>) -> tensor<8192x2048xbf16>
%4 = tensor.empty() : tensor<2048x2048xbf16>
// CHECK-DAG: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<2048x2048xbf16, #[[LAYOUT_7]]>
// CHECK-DAG: %{{.*}} = "ttnn.matmul"{{.*}} -> tensor<2048x2048xbf16, #[[LAYOUT_6]]>
%5 = "ttir.matmul"(%1, %3, %4) : (tensor<2048x8192xbf16>, tensor<8192x2048xbf16>, tensor<2048x2048xbf16>) -> tensor<2048x2048xbf16>
return %5 : tensor<2048x2048xbf16>
}
13 changes: 7 additions & 6 deletions test/unittests/Optimizer/TestL1InterleavedPolicy.cpp
@@ -13,6 +13,7 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"

#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TTNN/IR/TTNN.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h"

@@ -88,14 +89,14 @@ class L1InterleavedPolicyBase : public ::testing::Test {
TensorMemoryLayoutAttr::get(&context, tensorMemoryLayout);
if (legalLayouts.find(op) == legalLayouts.end()) {
legalLayouts[op] = std::vector<TTNNLayoutAttr>{TTNNLayoutAttr::get(
&context, getTensorRankedType().getShape(), builder.getF32Type(),
memorySpace, mlir::tt::GridAttr::get(&context, {8, 8}),
tensorMemoryLayoutAttr)};
&context, getTensorRankedType().getShape(),
mlir::tt::TileType::get(&context, builder.getF32Type()), memorySpace,
mlir::tt::GridAttr::get(&context, {8, 8}), tensorMemoryLayoutAttr)};
} else {
legalLayouts[op].push_back(TTNNLayoutAttr::get(
&context, getTensorRankedType().getShape(), builder.getF32Type(),
memorySpace, mlir::tt::GridAttr::get(&context, {8, 8}),
tensorMemoryLayoutAttr));
&context, getTensorRankedType().getShape(),
mlir::tt::TileType::get(&context, builder.getF32Type()), memorySpace,
mlir::tt::GridAttr::get(&context, {8, 8}), tensorMemoryLayoutAttr));
}
}
