fix PR comments
fbajraktariTT committed Dec 18, 2024
1 parent 32f02ff · commit 0e1fc7a
Showing 1 changed file with 29 additions and 14 deletions.

lib/Dialect/TTNN/IR/TTNNOpsAttrs.cpp — 43 changes: 29 additions & 14 deletions
@@ -66,39 +66,55 @@ bool TTNNLayoutAttr::hasInterleavedDRAMTensorMemoryLayout() const {
 }

 // Calculate the logical shape of the shard.
+//
 // Shard is defined as a piece of the tensor that is mapped to a single grid
-// core. It is assumed that the TensorMemoryLayout is sharded.
+// core. This function returns the shard shape for tensors with BLOCK SHARDED
+// tensor memory layout.
+//
+// All examples assume that the tensor is mapped to an 8x8 grid.
+// Example: tensor<32x32xbf16> -> {4, 4}
+// Example: tensor<65x65xbf16> -> {9, 9}
+//
+// return The logical shard shape in case of block sharded tensor memory layout.
 llvm::SmallVector<int64_t>
 TTNNLayoutAttr::calculateLogicalShardShapeForSharding(
     ArrayRef<int64_t> tensorShape, mlir::AffineMap linear, GridAttr grid) {
   assert(linear.getNumResults() == grid.getShape().size());
-  mlir::SmallVector<std::int64_t> logicalShape =
+  mlir::SmallVector<std::int64_t> physicalShape =
       ttmlir::utils::evalShape(linear, tensorShape);
   mlir::SmallVector<std::int64_t> shardShape(linear.getNumResults());
-  for (unsigned i = 0; i < linear.getNumResults(); ++i) {
+  for (size_t i = 0; i < linear.getNumResults(); ++i) {
     shardShape[i] =
-        (logicalShape[i] + grid.getShape()[i] - 1) / grid.getShape()[i];
+        (physicalShape[i] + grid.getShape()[i] - 1) / grid.getShape()[i];
   }
   return shardShape;
 }

 // Calculate the logical shape of the shard.
+//
 // Shard is defined as a piece of the tensor that is mapped to a single grid
-// core. It is assumed that the TensorMemoryLayout is interleaved and buffer
-// type is L1.
+// core. This function returns the shard shape for tensors with INTERLEAVED
+// tensor memory layout.
+//
+// All examples assume that the tensor is mapped to an 8x8 grid.
+// Example: tensor<1x1024xbf16> ( -> 32 tiles ) -> {1, 1}
+// Example: tensor<512x512xbf16> ( -> 256 tiles ) -> {1, 4}
+// Example: tensor<32x2049xbf16> ( -> 65 tiles ) -> {1, 2}
+//
+// return The logical shard shape in case of interleaved tensor memory layout.
 llvm::SmallVector<int64_t>
 TTNNLayoutAttr::calculateLogicalShardShapeForL1Interleaved(
     ArrayRef<int64_t> tensorShape, mlir::Type elementType,
     mlir::AffineMap linear, mlir::tt::GridAttr grid) {
   assert(linear.getNumResults() == grid.getShape().size());
   assert(mlir::isa<mlir::tt::TileType>(elementType));

-  mlir::SmallVector<std::int64_t> logicalShape =
+  mlir::SmallVector<std::int64_t> physicalShape =
       ttmlir::utils::evalShape(linear, tensorShape);
-  mlir::SmallVector<std::int64_t> logicalTiledShape =
-      mlir::cast<mlir::tt::TileType>(elementType).getTiledShape(logicalShape);
+  mlir::SmallVector<std::int64_t> physicalTiledShape =
+      mlir::cast<mlir::tt::TileType>(elementType).getTiledShape(physicalShape);
   uint64_t numOfTiles =
-      std::accumulate(logicalTiledShape.begin(), logicalTiledShape.end(), 1,
+      std::accumulate(physicalTiledShape.begin(), physicalTiledShape.end(), 1,
                       std::multiplies<std::int64_t>());
   uint64_t numOfGridUnits =
       std::accumulate(grid.getShape().begin(), grid.getShape().end(), 1,
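[Editor's note: the remainder of this hunk is collapsed in the diff view.]

As a concrete cross-check of the two helpers changed above, here is a small standalone sketch (not part of this commit) that reproduces both calculations with plain std::vector in place of the MLIR types. The 32x32 tile size for bf16 and the {1, tilesPerCore} layout of the interleaved shard are assumptions inferred from the comment examples, since the tail of calculateLogicalShardShapeForL1Interleaved is collapsed here.

// Standalone sketch of the two shard-shape calculations (assumptions: 32x32
// tiles, interleaved shards laid out as {1, tilesPerCore} in tiles).
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

static int64_t ceilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }

// Block sharded: divide each physical dimension across the grid, rounding
// up, as in calculateLogicalShardShapeForSharding.
std::vector<int64_t> blockShardedShape(const std::vector<int64_t> &physicalShape,
                                       const std::vector<int64_t> &gridShape) {
  std::vector<int64_t> shard(physicalShape.size());
  for (size_t i = 0; i < physicalShape.size(); ++i)
    shard[i] = ceilDiv(physicalShape[i], gridShape[i]);
  return shard;
}

// Interleaved: count tiles, spread them over every grid core, and give each
// core a 1xN strip of ceil(numTiles / numCores) tiles.
std::vector<int64_t> interleavedShardShape(const std::vector<int64_t> &physicalShape,
                                           const std::vector<int64_t> &gridShape,
                                           int64_t tileDim = 32) {
  int64_t numTiles = 1;
  for (int64_t dim : physicalShape)
    numTiles *= ceilDiv(dim, tileDim); // tiled shape, dimension by dimension
  int64_t numCores = std::accumulate(gridShape.begin(), gridShape.end(),
                                     int64_t{1}, std::multiplies<int64_t>());
  return {1, ceilDiv(numTiles, numCores)};
}

int main() {
  const std::vector<int64_t> grid = {8, 8};
  // Block sharded examples from the comment: 32x32 -> {4, 4}, 65x65 -> {9, 9}.
  for (const auto &t : {std::vector<int64_t>{32, 32}, std::vector<int64_t>{65, 65}}) {
    auto s = blockShardedShape(t, grid);
    std::cout << t[0] << "x" << t[1] << " -> {" << s[0] << ", " << s[1] << "}\n";
  }
  // Interleaved examples: 1x1024 -> {1, 1}, 512x512 -> {1, 4}, 32x2049 -> {1, 2}.
  for (const auto &t : {std::vector<int64_t>{1, 1024}, std::vector<int64_t>{512, 512},
                        std::vector<int64_t>{32, 2049}}) {
    auto s = interleavedShardShape(t, grid);
    std::cout << t[0] << "x" << t[1] << " -> {" << s[0] << ", " << s[1] << "}\n";
  }
  return 0;
}

Run as-is, this prints the same shard shapes as the examples in the comments above.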
@@ -325,13 +341,13 @@ mlir::AffineMap TTNNLayoutAttr::replaceMemoryMapSymbolsWithShardShape(
"shard rank");

SmallVector<AffineExpr> symReplacements;
for (unsigned i = 0; i < physicalMemoryMap.getNumSymbols(); ++i) {
for (size_t i = 0; i < physicalMemoryMap.getNumSymbols(); ++i) {
symReplacements.push_back(
getAffineConstantExpr(shardShape[i], getContext()));
}

SmallVector<AffineExpr> dimReplacements;
for (unsigned i = 0; i < physicalMemoryMap.getNumDims(); ++i) {
for (size_t i = 0; i < physicalMemoryMap.getNumDims(); ++i) {
dimReplacements.push_back(getAffineDimExpr(i, getContext()));
}

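The two loops above only build the replacement lists; the lines that consume them are collapsed in this view. Below is a hedged sketch of the presumable completion using MLIR's AffineMap::replaceDimsAndSymbols (a real MLIR API): every shard-shape symbol is pinned to a constant while dims map to themselves. The free-function form and the name pinShardShape are illustrative, not the commit's code.

// Hedged sketch (not the commit's code): how symReplacements /
// dimReplacements are presumably consumed after the loops above.
#include <cstdint>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"

mlir::AffineMap pinShardShape(mlir::AffineMap physicalMemoryMap,
                              llvm::ArrayRef<int64_t> shardShape,
                              mlir::MLIRContext *ctx) {
  llvm::SmallVector<mlir::AffineExpr> symReplacements;
  for (size_t i = 0; i < physicalMemoryMap.getNumSymbols(); ++i)
    symReplacements.push_back(mlir::getAffineConstantExpr(shardShape[i], ctx));

  llvm::SmallVector<mlir::AffineExpr> dimReplacements;
  for (size_t i = 0; i < physicalMemoryMap.getNumDims(); ++i)
    dimReplacements.push_back(mlir::getAffineDimExpr(i, ctx));

  // Each symbol s_i becomes the constant shardShape[i]; each dim d_i stays
  // d_i. The result is a memory map with the shard shape baked in.
  return physicalMemoryMap.replaceDimsAndSymbols(
      dimReplacements, symReplacements, physicalMemoryMap.getNumDims(),
      /*numResultSyms=*/0);
}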
@@ -501,8 +517,7 @@ TTNNLayoutAttr TTNNLayoutAttr::get(
   AffineMap linear = collapsedLinearAffineMap(
       context, tensorShape, grid.getShape(), collapseIntervals);

-  // Calculate shard shape by evaluating the linear map.
-  //
+  // Calculate shard shape
   mlir::SmallVector<int64_t> shardShape;
   if (bufferType == BufferType::L1 &&
       memLayoutAttr.getValue() == TensorMemoryLayout::Interleaved) {
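For orientation, the branch shown above presumably dispatches between the two helpers changed earlier in this file. The else branch is collapsed in this view, so the completion below is an assumption, written with the names from the surrounding hunk:

// Assumed completion of the dispatch above (not the commit's code; the else
// branch is collapsed in this diff view).
mlir::SmallVector<int64_t> shardShape;
if (bufferType == BufferType::L1 &&
    memLayoutAttr.getValue() == TensorMemoryLayout::Interleaved) {
  // L1 interleaved: shard by total tile count spread across the whole grid.
  shardShape = calculateLogicalShardShapeForL1Interleaved(
      tensorShape, elementType, linear, grid);
} else {
  // Presumably all other (sharded) layouts divide each dimension by the grid.
  shardShape = calculateLogicalShardShapeForSharding(tensorShape, linear, grid);
}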