diff --git a/lib/Dialect/TTNN/IR/TTNNOpsAttrs.cpp b/lib/Dialect/TTNN/IR/TTNNOpsAttrs.cpp
index 8e748dce1..c7bf769dd 100644
--- a/lib/Dialect/TTNN/IR/TTNNOpsAttrs.cpp
+++ b/lib/Dialect/TTNN/IR/TTNNOpsAttrs.cpp
@@ -66,26 +66,42 @@ bool TTNNLayoutAttr::hasInterleavedDRAMTensorMemoryLayout() const {
 }
 
 // Calculate the logical shape of the shard.
+//
 // Shard is defined as a piece of the tensor that is mapped to a single grid
-// core. It is assumed that the TensorMemoryLayout is sharded.
+// core. This function returns the shard shape for tensors with BLOCK SHARDED
+// tensor memory layout.
+//
+// All examples assume that the tensor is mapped to an 8x8 grid.
+// Example: tensor<32x32xbf16> -> {4, 4}
+// Example: tensor<65x65xbf16> -> {9, 9}
+//
+// Returns the logical shard shape for block sharded tensor memory layout.
 llvm::SmallVector<int64_t>
 TTNNLayoutAttr::calculateLogicalShardShapeForSharding(
     ArrayRef<int64_t> tensorShape, mlir::AffineMap linear, GridAttr grid) {
   assert(linear.getNumResults() == grid.getShape().size());
-  mlir::SmallVector<int64_t> logicalShape =
+  mlir::SmallVector<int64_t> physicalShape =
       ttmlir::utils::evalShape(linear, tensorShape);
   mlir::SmallVector<int64_t> shardShape(linear.getNumResults());
-  for (unsigned i = 0; i < linear.getNumResults(); ++i) {
+  for (size_t i = 0; i < linear.getNumResults(); ++i) {
     shardShape[i] =
-        (logicalShape[i] + grid.getShape()[i] - 1) / grid.getShape()[i];
+        (physicalShape[i] + grid.getShape()[i] - 1) / grid.getShape()[i];
   }
   return shardShape;
 }
 
 // Calculate the logical shape of the shard.
+//
 // Shard is defined as a piece of the tensor that is mapped to a single grid
-// core. It is assumed that the TensorMemoryLayout is interleaved and buffer
-// type is L1.
+// core. This function returns the shard shape for tensors with INTERLEAVED
+// tensor memory layout.
+//
+// All examples assume that the tensor is mapped to an 8x8 grid.
+// Example: tensor<1x1024xbf16> ( -> 32 tiles ) -> {1, 1}
+// Example: tensor<512x512xbf16> ( -> 256 tiles ) -> {1, 4}
+// Example: tensor<32x2049xbf16> ( -> 65 tiles ) -> {1, 2}
+//
+// Returns the logical shard shape for interleaved tensor memory layout.
 llvm::SmallVector<int64_t>
 TTNNLayoutAttr::calculateLogicalShardShapeForL1Interleaved(
     ArrayRef<int64_t> tensorShape, mlir::Type elementType,
@@ -93,12 +109,12 @@ TTNNLayoutAttr::calculateLogicalShardShapeForL1Interleaved(
   assert(linear.getNumResults() == grid.getShape().size());
   assert(mlir::isa<TileType>(elementType));
 
-  mlir::SmallVector<int64_t> logicalShape =
+  mlir::SmallVector<int64_t> physicalShape =
       ttmlir::utils::evalShape(linear, tensorShape);
-  mlir::SmallVector<int64_t> logicalTiledShape =
-      mlir::cast<TileType>(elementType).getTiledShape(logicalShape);
+  mlir::SmallVector<int64_t> physicalTiledShape =
+      mlir::cast<TileType>(elementType).getTiledShape(physicalShape);
   uint64_t numOfTiles =
-      std::accumulate(logicalTiledShape.begin(), logicalTiledShape.end(), 1,
+      std::accumulate(physicalTiledShape.begin(), physicalTiledShape.end(), 1,
                       std::multiplies<int64_t>());
   uint64_t numOfGridUnits =
       std::accumulate(grid.getShape().begin(), grid.getShape().end(), 1,
@@ -325,13 +341,13 @@ mlir::AffineMap TTNNLayoutAttr::replaceMemoryMapSymbolsWithShardShape(
          "shard rank");
 
   SmallVector<AffineExpr> symReplacements;
-  for (unsigned i = 0; i < physicalMemoryMap.getNumSymbols(); ++i) {
+  for (size_t i = 0; i < physicalMemoryMap.getNumSymbols(); ++i) {
     symReplacements.push_back(
         getAffineConstantExpr(shardShape[i], getContext()));
   }
 
   SmallVector<AffineExpr> dimReplacements;
-  for (unsigned i = 0; i < physicalMemoryMap.getNumDims(); ++i) {
+  for (size_t i = 0; i < physicalMemoryMap.getNumDims(); ++i) {
     dimReplacements.push_back(getAffineDimExpr(i, getContext()));
   }
 
@@ -501,8 +517,7 @@ TTNNLayoutAttr TTNNLayoutAttr::get(
   AffineMap linear = collapsedLinearAffineMap(
       context, tensorShape, grid.getShape(), collapseIntervals);
 
-  // Calculate shard shape by evaluating the linear map.
-  //
+  // Calculate the shard shape.
   mlir::SmallVector<int64_t> shardShape;
   if (bufferType == BufferType::L1 &&
       memLayoutAttr.getValue() == TensorMemoryLayout::Interleaved) {
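
Note on the arithmetic in the comments above: both shard-shape paths reduce to plain ceil-division. The sketch below (hypothetical, not part of the patch; standalone C++ with no MLIR dependencies) reproduces the comment examples for an 8x8 grid and 32x32 bf16 tiles. Names such as ceilDiv, blockShardedShardShape, and l1InterleavedShardShape are illustrative only, and the sketch feeds the raw tensor shape straight into the math, whereas the real functions first evaluate the physical shape through the linear affine map (the identity collapse for these 2-D examples).

// Standalone sketch of the shard-shape arithmetic documented above.
// Assumes an 8x8 grid and 32x32 bf16 tiles; not tt-mlir API.
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <iostream>
#include <vector>

namespace {

constexpr int64_t kGridDim = 8;  // 8x8 grid, as in the comment examples
constexpr int64_t kTileDim = 32; // bf16 tiles are 32x32

int64_t ceilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }

// Block sharded: ceil-divide each physical dimension by the grid dimension,
// mirroring calculateLogicalShardShapeForSharding.
std::vector<int64_t> blockShardedShardShape(const std::vector<int64_t> &shape) {
  std::vector<int64_t> shard(shape.size());
  for (size_t i = 0; i < shape.size(); ++i) {
    shard[i] = ceilDiv(shape[i], kGridDim);
  }
  return shard;
}

// L1 interleaved: count tiles, spread them across all grid units, round up,
// mirroring calculateLogicalShardShapeForL1Interleaved. The result is
// {1, tilesPerGridUnit}, matching the comment examples.
std::vector<int64_t> l1InterleavedShardShape(const std::vector<int64_t> &shape) {
  int64_t numOfTiles = 1;
  for (int64_t dim : shape) {
    numOfTiles *= ceilDiv(dim, kTileDim);
  }
  const int64_t numOfGridUnits = kGridDim * kGridDim;
  return {1, ceilDiv(numOfTiles, numOfGridUnits)};
}

void print(const std::vector<int64_t> &shape, const std::vector<int64_t> &shard) {
  std::cout << shape[0] << "x" << shape[1] << " -> {" << shard[0] << ", "
            << shard[1] << "}\n";
}

} // namespace

int main() {
  // Block sharded examples: 32x32 -> {4, 4}; 65x65 -> {9, 9}
  for (const std::vector<int64_t> &shape :
       {std::vector<int64_t>{32, 32}, std::vector<int64_t>{65, 65}}) {
    print(shape, blockShardedShardShape(shape));
  }
  // L1 interleaved examples: 1x1024 -> {1, 1}; 512x512 -> {1, 4};
  // 32x2049 -> {1, 2}
  for (const std::vector<int64_t> &shape :
       {std::vector<int64_t>{1, 1024}, std::vector<int64_t>{512, 512},
        std::vector<int64_t>{32, 2049}}) {
    print(shape, l1InterleavedShardShape(shape));
  }
  return 0;
}

Compiled and run, this prints 32x32 -> {4, 4}, 65x65 -> {9, 9}, 1x1024 -> {1, 1}, 512x512 -> {1, 4}, and 32x2049 -> {1, 2}, matching the examples in the new comments. Because both paths round up, the last grid unit may hold a partially filled shard (the 65x65 case) or a partially used tile slot (the 65th tile in the 32x2049 case).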