diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td
index e483b07bf2..1069f6341b 100644
--- a/include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td
+++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td
@@ -164,6 +164,8 @@ def TTNN_TTNNLayoutAttr: TTNN_Attr<"TTNNLayout", "ttnn_layout"> {
     DataType getDataType() const;
     uint64_t getElementSizeBytes() const;
     int64_t getTensorSizeInBytes(ArrayRef<int64_t> tensorShape, ::mlir::tt::DeviceAttr device) const;
+    static llvm::SmallVector<int64_t> calculateLogicalShardShapeForSharding(ArrayRef<int64_t> tensorShape, mlir::AffineMap linear, GridAttr grid);
+    static llvm::SmallVector<int64_t> calculateLogicalShardShapeForInterleaved(ArrayRef<int64_t> tensorShape, Type elementType, mlir::AffineMap linear, GridAttr grid);
     llvm::SmallVector<int64_t> getStride(ArrayRef<int64_t> logicalShape) const;
     llvm::SmallVector<int64_t> getShardShape() const;
     llvm::SmallVector<int64_t> getScalarShardShape() const;
diff --git a/lib/Dialect/TTNN/IR/TTNNOpsAttrs.cpp b/lib/Dialect/TTNN/IR/TTNNOpsAttrs.cpp
index dbefdf3473..d2b8b9c49d 100644
--- a/lib/Dialect/TTNN/IR/TTNNOpsAttrs.cpp
+++ b/lib/Dialect/TTNN/IR/TTNNOpsAttrs.cpp
@@ -4,13 +4,10 @@
 
 #include <numeric>
 
-#include "mlir/IR/Builders.h"
-#include "mlir/IR/DialectImplementation.h"
 #include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
 #include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"
 #include "ttmlir/Dialect/TTNN/Utils/Utils.h"
 #include "ttmlir/Utils.h"
-#include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir::tt::ttnn;
 
@@ -68,6 +65,50 @@ bool TTNNLayoutAttr::hasInterleavedDRAMTensorMemoryLayout() const {
          (getMemLayout().getValue() == TensorMemoryLayout::Interleaved);
 }
 
+// Calculate the logical shape of the shard.
+// A shard is the piece of the tensor that is mapped to a single grid core.
+// It is assumed that the TensorMemoryLayout is sharded.
+llvm::SmallVector<int64_t>
+TTNNLayoutAttr::calculateLogicalShardShapeForSharding(
+    ArrayRef<int64_t> tensorShape, mlir::AffineMap linear, GridAttr grid) {
+  assert(linear.getNumResults() == grid.getShape().size());
+  mlir::SmallVector<int64_t> logicalShape =
+      ttmlir::utils::evalShape(linear, tensorShape);
+  mlir::SmallVector<int64_t> shardShape(linear.getNumResults());
+  for (unsigned i = 0; i < linear.getNumResults(); ++i) {
+    shardShape[i] =
+        (logicalShape[i] + grid.getShape()[i] - 1) / grid.getShape()[i];
+  }
+  return shardShape;
+}
+
+// Calculate the logical shape of the shard.
+// A shard is the piece of the tensor that is mapped to a single grid core.
+// It is assumed that the TensorMemoryLayout is interleaved.
+llvm::SmallVector<int64_t>
+TTNNLayoutAttr::calculateLogicalShardShapeForInterleaved(
+    ArrayRef<int64_t> tensorShape, mlir::Type elementType,
+    mlir::AffineMap linear, mlir::tt::GridAttr grid) {
+  assert(linear.getNumResults() == grid.getShape().size());
+  assert(mlir::isa<mlir::tt::TileType>(elementType));
+
+  mlir::SmallVector<int64_t> logicalShape =
+      ttmlir::utils::evalShape(linear, tensorShape);
+  mlir::SmallVector<int64_t> logicalTiledShape =
+      mlir::cast<mlir::tt::TileType>(elementType).getTiledShape(logicalShape);
+  uint64_t numOfTiles =
+      std::accumulate(logicalTiledShape.begin(), logicalTiledShape.end(), 1,
+                      std::multiplies<int64_t>());
+  uint64_t numOfGridUnits =
+      std::accumulate(grid.getShape().begin(), grid.getShape().end(), 1,
+                      std::multiplies<int64_t>());
+
+  mlir::SmallVector<int64_t> shardShape;
+  shardShape.push_back(1);
+  shardShape.push_back((numOfTiles + numOfGridUnits - 1) / numOfGridUnits);
+  return mlir::cast<mlir::tt::TileType>(elementType).getScalarShape(shardShape);
+}
+
 // Get stride given tensor logical shape
 llvm::SmallVector<int64_t>
 TTNNLayoutAttr::getStride(ArrayRef<int64_t> logicalShape) const {
@@ -460,7 +501,8 @@ TTNNLayoutAttr TTNNLayoutAttr::get(
   // Calculate shard shape by evaluating the linear map with last element
   // of the tensor shape and dividing it by the grid shape
   mlir::SmallVector<int64_t> shardShape =
-      calculateLogicalShardShape(tensorShape, linear, grid);
+      TTNNLayoutAttr::calculateLogicalShardShapeForSharding(tensorShape, linear,
+                                                            grid);
   // Build memref type with the given parameters
   MemRefType memRefType = buildMemRef<BufferType, BufferTypeAttr>(
       context, shardShape, elementType, bufferType);
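
For context, a minimal standalone sketch of what the two new helpers compute, written in plain C++ rather than the MLIR/TTNN types used in the patch. It assumes an identity linear map (so the logical shape equals the tensor shape) and 32x32 tiles; the names ceilDiv, shardShapeForSharding, and shardShapeForInterleaved are illustrative only.

#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Ceiling division, the rounding used by both shard-shape calculations.
static int64_t ceilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }

// Sharded layout: each grid dimension splits the matching logical dimension,
// rounding up, so every core receives one rectangular shard.
std::vector<int64_t>
shardShapeForSharding(const std::vector<int64_t> &logicalShape,
                      const std::vector<int64_t> &grid) {
  assert(logicalShape.size() == grid.size());
  std::vector<int64_t> shard(logicalShape.size());
  for (size_t i = 0; i < logicalShape.size(); ++i)
    shard[i] = ceilDiv(logicalShape[i], grid[i]);
  return shard;
}

// Interleaved layout: count tiles over the whole logical shape, spread them
// evenly (rounding up) across all cores, and report the per-core shard as a
// single row of tiles converted back to scalar units.
std::vector<int64_t>
shardShapeForInterleaved(const std::vector<int64_t> &logicalShape,
                         const std::vector<int64_t> &grid,
                         int64_t tileH = 32, int64_t tileW = 32) {
  int64_t numTiles =
      ceilDiv(logicalShape[0], tileH) * ceilDiv(logicalShape[1], tileW);
  int64_t numCores = 1;
  for (int64_t g : grid)
    numCores *= g;
  int64_t tilesPerCore = ceilDiv(numTiles, numCores);
  // Scalar shape of a [1, tilesPerCore] tile shard.
  return {tileH, tilesPerCore * tileW};
}

int main() {
  // 128x128 tensor on a 2x4 grid:
  //   sharded     -> ceil(128/2) x ceil(128/4) = 64x32 per core
  //   interleaved -> 16 tiles / 8 cores = 2 tiles per core -> 32x64 scalars
  auto s = shardShapeForSharding({128, 128}, {2, 4});
  auto i = shardShapeForInterleaved({128, 128}, {2, 4});
  std::cout << s[0] << "x" << s[1] << " and " << i[0] << "x" << i[1] << "\n";
  return 0;
}

Under these assumptions the sharded path yields a 64x32 shard per core, while the interleaved path yields a 32x64 shard (one row of two 32x32 tiles), mirroring the ceil-divide logic in calculateLogicalShardShapeForSharding and calculateLogicalShardShapeForInterleaved.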