Skip to content

Commit

Permalink
moved calculateLogicalShardShape into TTNN & added calculateLogicalShardShapeForInterleaved
Browse files Browse the repository at this point in the history
  • Loading branch information
fbajraktariTT committed Dec 16, 2024
1 parent b5b140d commit d4d33fe
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 4 deletions.
2 changes: 2 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,8 @@ def TTNN_TTNNLayoutAttr: TTNN_Attr<"TTNNLayout", "ttnn_layout"> {
DataType getDataType() const;
uint64_t getElementSizeBytes() const;
int64_t getTensorSizeInBytes(ArrayRef<int64_t> tensorShape, ::mlir::tt::DeviceAttr device) const;
static llvm::SmallVector<int64_t> calculateLogicalShardShapeForSharding(ArrayRef<int64_t> tensorShape, mlir::AffineMap linear, GridAttr grid);
static llvm::SmallVector<int64_t> calculateLogicalShardShapeForInterleaved(ArrayRef<int64_t> tensorShape, Type elementType, mlir::AffineMap linear, GridAttr grid);
llvm::SmallVector<int64_t> getStride(ArrayRef<int64_t> logicalShape) const;
llvm::SmallVector<int64_t> getShardShape() const;
llvm::SmallVector<int64_t> getScalarShardShape() const;
Expand Down
50 changes: 46 additions & 4 deletions lib/Dialect/TTNN/IR/TTNNOpsAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,10 @@

#include <numeric>

#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"
#include "ttmlir/Dialect/TTNN/Utils/Utils.h"
#include "ttmlir/Utils.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir::tt::ttnn;

Expand Down Expand Up @@ -68,6 +65,50 @@ bool TTNNLayoutAttr::hasInterleavedDRAMTensorMemoryLayout() const {
(getMemLayout().getValue() == TensorMemoryLayout::Interleaved);
}

// Calculate the logical shape of the shard.
// Shard is defined as a piece of the tensor that is mapped to a single grid
// core. It is assumed that the TensorMemoryLayout is sharded.
// Calculate the logical shape of the shard.
// Shard is defined as a piece of the tensor that is mapped to a single grid
// core. It is assumed that the TensorMemoryLayout is sharded.
llvm::SmallVector<int64_t>
TTNNLayoutAttr::calculateLogicalShardShapeForSharding(
    ArrayRef<int64_t> tensorShape, mlir::AffineMap linear, GridAttr grid) {
  assert(linear.getNumResults() == grid.getShape().size());
  // Map the tensor shape through the linear layout to its physical form.
  mlir::SmallVector<std::int64_t> physicalShape =
      ttmlir::utils::evalShape(linear, tensorShape);
  llvm::SmallVector<int64_t> shards(linear.getNumResults());
  for (size_t dim = 0; dim < shards.size(); ++dim) {
    // Ceiling division: every core gets an equal slice of this dimension;
    // the last core's slice may be partially padded.
    const int64_t gridDim = grid.getShape()[dim];
    shards[dim] = (physicalShape[dim] + gridDim - 1) / gridDim;
  }
  return shards;
}

// Calculate the logical shape of the shard.
// Shard is defined as a piece of the tensor that is mapped to a single grid
// core. It is assumed that the TensorMemoryLayout is interleaved.
// Calculate the logical shape of the shard.
// Shard is defined as a piece of the tensor that is mapped to a single grid
// core. It is assumed that the TensorMemoryLayout is interleaved: tiles are
// distributed evenly across all grid cores, so each core's shard is a single
// row of ceil(totalTiles / totalCores) tiles, converted back to scalar units.
llvm::SmallVector<int64_t>
TTNNLayoutAttr::calculateLogicalShardShapeForInterleaved(
    ArrayRef<int64_t> tensorShape, mlir::Type elementType,
    mlir::AffineMap linear, mlir::tt::GridAttr grid) {
  assert(linear.getNumResults() == grid.getShape().size());
  // Interleaved layout is only defined for tilized tensors.
  assert(mlir::isa<mlir::tt::TileType>(elementType));
  mlir::tt::TileType tileType = mlir::cast<mlir::tt::TileType>(elementType);

  mlir::SmallVector<std::int64_t> logicalShape =
      ttmlir::utils::evalShape(linear, tensorShape);
  mlir::SmallVector<std::int64_t> logicalTiledShape =
      tileType.getTiledShape(logicalShape);

  // std::accumulate deduces its accumulator type from the init value, so the
  // init must be int64_t (not the int literal 1) — otherwise every partial
  // product is truncated through 32-bit int before reaching the result.
  uint64_t numOfTiles =
      std::accumulate(logicalTiledShape.begin(), logicalTiledShape.end(),
                      std::int64_t{1}, std::multiplies<std::int64_t>());
  uint64_t numOfGridUnits =
      std::accumulate(grid.getShape().begin(), grid.getShape().end(),
                      std::int64_t{1}, std::multiplies<std::int64_t>());

  // Ceiling division: each core holds at most one extra tile when the total
  // tile count does not divide the core count evenly.
  mlir::SmallVector<std::int64_t> shardShape;
  shardShape.push_back(1);
  shardShape.push_back((numOfTiles + numOfGridUnits - 1) / numOfGridUnits);
  return tileType.getScalarShape(shardShape);
}

// Get stride given tensor logical shape
llvm::SmallVector<int64_t>
TTNNLayoutAttr::getStride(ArrayRef<int64_t> logicalShape) const {
Expand Down Expand Up @@ -460,7 +501,8 @@ TTNNLayoutAttr TTNNLayoutAttr::get(
// Calculate shard shape by evaluating the linear map with last element
// of the tensor shape and dividing it by the grid shape
mlir::SmallVector<int64_t, 4> shardShape =
calculateLogicalShardShape(tensorShape, linear, grid);
TTNNLayoutAttr::calculateLogicalShardShapeForSharding(tensorShape, linear,
grid);
// Build memref type with the given parameters
MemRefType memRefType = buildMemRef<BufferType, BufferTypeAttr>(
context, shardShape, elementType, bufferType);
Expand Down

0 comments on commit d4d33fe

Please sign in to comment.