diff --git a/lib/Dialect/TTNN/IR/TTNNOpsAttrs.cpp b/lib/Dialect/TTNN/IR/TTNNOpsAttrs.cpp index 43c5984ed..dbefdf347 100644 --- a/lib/Dialect/TTNN/IR/TTNNOpsAttrs.cpp +++ b/lib/Dialect/TTNN/IR/TTNNOpsAttrs.cpp @@ -157,12 +157,12 @@ mlir::tt::DataType TTNNLayoutAttr::getDataType() const { return elementTypeToDataType(elementType); } -// Gets the size of shard in bytes +// Get the size of the element in bytes // -// This function returns the size of the shard in bytes. -// Size is calculated by multiplying shard shape with element size. +// This function returns the size of a single tensor element in bytes. +// Distinction is made between scalar types and TileType. // -// return The size of the shard in bytes. +// return The size of the element in bytes. uint64_t TTNNLayoutAttr::getElementSizeBytes() const { mlir::Type elementType = getElementType(); if (isTiled()) { @@ -177,7 +177,7 @@ uint64_t TTNNLayoutAttr::getElementSizeBytes() const { // Return the shape of the shard. // Example: memref<2x2x!tt.tile<32x32xf32>> -> { 2, 2 } // Example: memref<128x128xf32> -> { 128, 128 } -// Example: memref<2x3!tt.tile<32x32xf32>> -> { 2, 3 } +// Example: memref<2x3x!tt.tile<32x32xf32>> -> { 2, 3 } // // return The shape of the shard. llvm::SmallVector TTNNLayoutAttr::getShardShape() const {