Skip to content

Commit

Permalink
Replacing LayoutAttr with TensorConfigAttr (#1231)
Browse files Browse the repository at this point in the history
* Moving layout to TTNN

* Addressing comments and merging with main.

* Fixing build

* Align attribute name with IR name
  • Loading branch information
mtopalovicTT authored Nov 17, 2024
1 parent 14d87b7 commit 5c1f2a9
Show file tree
Hide file tree
Showing 68 changed files with 1,576 additions and 536 deletions.
24 changes: 24 additions & 0 deletions include/ttmlir/Dialect/TT/IR/TTOpsTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,4 +123,28 @@ SystemDescAttr getCurrentScopeSystemDesc(Operation *op);
DeviceAttr getCurrentScopeDevice(Operation *op);
} // namespace mlir::tt

mlir::AffineMap collapsedLinearAffineMap(
::mlir::MLIRContext *context, ::llvm::ArrayRef<int64_t> shape,
::llvm::ArrayRef<int64_t> gridShape,
::llvm::ArrayRef<std::pair<std::int64_t, std::int64_t>> collapseIntervals);

mlir::SmallVector<std::int64_t>
calculateLogicalShardShape(mlir::ArrayRef<int64_t> tensorShape,
mlir::AffineMap linear, mlir::tt::GridAttr grid);

/// Builds a MemRefType describing one shard of a tensor.
///
/// \param context     MLIR context used to construct the type and attributes.
/// \param shardShape  Scalar (element-wise) shape of a single shard.
/// \param elementType Element type of the memref; may be a tt::TileType.
/// \param memorySpace Memory-space value wrapped into a TAttr attribute and
///                    attached to the memref.
/// \returns A MemRefType with an identity layout map over the (possibly
///          tiled) shard shape.
template <typename T, typename TAttr>
mlir::MemRefType buildMemRef(::mlir::MLIRContext *context,
                             ::llvm::ArrayRef<int64_t> shardShape,
                             ::mlir::Type elementType, T memorySpace) {
  ::llvm::SmallVector<int64_t> scalarShardShape(shardShape);
  // For tiled element types, convert the scalar shard shape into units of
  // tiles so the memref extents match the physical storage granularity.
  // (Single dyn_cast instead of the isa + cast pair: one type check.)
  if (auto tileType = mlir::dyn_cast<mlir::tt::TileType>(elementType)) {
    scalarShardShape = tileType.getTiledShape(scalarShardShape);
  }
  return mlir::MemRefType::get(
      scalarShardShape, elementType,
      mlir::AffineMap::getMultiDimIdentityMap(scalarShardShape.size(), context),
      TAttr::get(context, memorySpace));
}

#endif
95 changes: 95 additions & 0 deletions include/ttmlir/Dialect/TT/Utils/OperandConstraints.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_DIALECT_TT_UTILS_OPERANDCONSTRAINTS_H
#define TTMLIR_DIALECT_TT_UTILS_OPERANDCONSTRAINTS_H

#include "ttmlir/Dialect/TT/IR/TT.h"
#include "ttmlir/Dialect/TTIR/IR/TTIROps.h"

namespace mlir::tt {

/// Translates a memory space into the operand-constraint flag that permits
/// placing an operand in that space.
inline OperandConstraint
memorySpaceAsOperandConstraint(MemorySpace memorySpace) {
  // Exhaustive switch (no default) so the compiler flags any newly added
  // MemorySpace enumerator that is not handled here.
  switch (memorySpace) {
  case MemorySpace::DeviceDRAM:
    return OperandConstraint::DRAM;
  case MemorySpace::DeviceL1:
    return OperandConstraint::L1;
  case MemorySpace::System:
  case MemorySpace::SystemMMIO:
    // Both host-side spaces share the single System constraint bit.
    return OperandConstraint::System;
  }
}

/// Translates a tensor memory layout into the operand-constraint flag that
/// permits that layout. The mapping is one-to-one; the switch is exhaustive
/// (no default) so the compiler warns when a new layout enumerator appears.
inline OperandConstraint
memoryLayoutAsOperandConstraint(TensorMemoryLayout memoryLayout) {
  switch (memoryLayout) {
  case TensorMemoryLayout::None:
    return OperandConstraint::None;
  case TensorMemoryLayout::Interleaved:
    return OperandConstraint::Interleaved;
  case TensorMemoryLayout::SingleBank:
    return OperandConstraint::SingleBank;
  case TensorMemoryLayout::HeightSharded:
    return OperandConstraint::HeightSharded;
  case TensorMemoryLayout::WidthSharded:
    return OperandConstraint::WidthSharded;
  case TensorMemoryLayout::BlockSharded:
    return OperandConstraint::BlockSharded;
  }
}

/// Picks a memory space allowed by \p operandConstraint, preferring
/// \p defaultMemorySpace, then DRAM, then L1, and finally falling back to
/// host (System) memory when the constraint permits none of the device
/// spaces.
inline MemorySpace getLegalMemorySpace(OperandConstraint operandConstraint,
                                       MemorySpace defaultMemorySpace) {
  const bool defaultAllowed = bitEnumContainsAny(
      operandConstraint, memorySpaceAsOperandConstraint(defaultMemorySpace));
  if (defaultAllowed) {
    return defaultMemorySpace;
  }
  const bool dramAllowed =
      bitEnumContainsAny(operandConstraint, OperandConstraint::DRAM);
  const bool l1Allowed =
      bitEnumContainsAny(operandConstraint, OperandConstraint::L1);
  return dramAllowed ? MemorySpace::DeviceDRAM
         : l1Allowed ? MemorySpace::DeviceL1
                     : MemorySpace::System;
}

/// Picks a tensor memory layout allowed by \p operandConstraint for
/// \p targetMemorySpace, preferring \p defaultDeviceMemLayout and otherwise
/// falling back through the candidate layouts in a fixed priority order.
/// Returns TensorMemoryLayout::None for host memory or when no layout is
/// permitted.
inline TensorMemoryLayout
getLegalTensorMemoryLayout(OperandConstraint operandConstraint,
                           MemorySpace targetMemorySpace,
                           TensorMemoryLayout defaultDeviceMemLayout) {
  if (defaultDeviceMemLayout == TensorMemoryLayout::None) {
    return TensorMemoryLayout::None;
  }

  // Tensors in host memory carry no device memory layout.
  if (isSystemMemorySpace(targetMemorySpace)) {
    return TensorMemoryLayout::None;
  }

  assert(isDeviceMemorySpace(targetMemorySpace));
  if (bitEnumContainsAny(operandConstraint, memoryLayoutAsOperandConstraint(
                                                defaultDeviceMemLayout))) {
    return defaultDeviceMemLayout;
  }

  // Fixed fallback table, scanned in declared priority order. A static
  // constexpr array avoids the per-call heap allocations of the previous
  // std::map (which also iterated in key order — i.e. by enum bit value —
  // rather than this explicit priority order).
  struct ConstraintLayoutPair {
    OperandConstraint constraint;
    TensorMemoryLayout layout;
  };
  static constexpr ConstraintLayoutPair validLayouts[] = {
      {OperandConstraint::Interleaved, TensorMemoryLayout::Interleaved},
      {OperandConstraint::SingleBank, TensorMemoryLayout::SingleBank},
      {OperandConstraint::HeightSharded, TensorMemoryLayout::HeightSharded},
      {OperandConstraint::WidthSharded, TensorMemoryLayout::WidthSharded},
      {OperandConstraint::BlockSharded, TensorMemoryLayout::BlockSharded}};

  for (const auto &entry : validLayouts) {
    if (bitEnumContainsAny(operandConstraint, entry.constraint)) {
      return entry.layout;
    }
  }

  return TensorMemoryLayout::None;
}

} // namespace mlir::tt

#endif // TTMLIR_DIALECT_TT_UTILS_OPERANDCONSTRAINTS_H
3 changes: 2 additions & 1 deletion include/ttmlir/Dialect/TTNN/Analysis/DFShardingPolicy.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "ttmlir/Dialect/TTNN/Analysis/L1ChainConfig.h"
#include "ttmlir/Dialect/TTNN/Analysis/MemoryLayoutAnalysisPolicy.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"

namespace mlir::tt::ttnn {

Expand All @@ -23,7 +24,7 @@ class DFShardingPolicy : public MemoryLayoutAnalysisPolicy {
public:
DFShardingPolicy(
Operation *rootOp, std::vector<L1ChainConfig> &l1ChainConfigs,
const llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>>
const llvm::DenseMap<Operation *, std::vector<TTNNLayoutAttr>>
&legalLayouts,
llvm::DenseMap<func::FuncOp, llvm::SmallVector<Operation *>> &schedule,
unsigned usableL1CacheSize)
Expand Down
6 changes: 3 additions & 3 deletions include/ttmlir/Dialect/TTNN/Analysis/L1ChainConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ struct OpL1MemSpec {

// Layout of the output tensor of the op.
//
tt::LayoutAttr layout;
TTNNLayoutAttr layout;
};

// Enum to track the state of the L1 chain.
Expand All @@ -47,14 +47,14 @@ class L1ChainConfig {
L1ChainConfig() : opL1MemSpecs(), state() {}

ShardSolver resolveWithSolver(
const llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>>
const llvm::DenseMap<Operation *, std::vector<TTNNLayoutAttr>>
&legalLayouts,
unsigned usableL1CacheSize,
const std::unordered_set<Edge> &overrideReshardEdges);
void resolve();
void build();
void
complete(const llvm::DenseMap<Operation *, tt::LayoutAttr> &selectedOpLayout,
complete(const llvm::DenseMap<Operation *, TTNNLayoutAttr> &selectedOpLayout,
std::unordered_set<Edge> &memReconfigEdges);

bool isEmpty() { return opL1MemSpecs.empty(); }
Expand Down
2 changes: 1 addition & 1 deletion include/ttmlir/Dialect/TTNN/Analysis/L1InterleavedPolicy.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class L1InterleavedPolicy : public MemoryLayoutAnalysisPolicy {
public:
L1InterleavedPolicy(
Operation *rootOp, std::vector<L1ChainConfig> &l1ChainConfigs,
const llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>>
const llvm::DenseMap<Operation *, std::vector<TTNNLayoutAttr>>
&legalLayouts,
llvm::DenseMap<func::FuncOp, llvm::SmallVector<Operation *>> &schedule,
unsigned usableL1CacheSize)
Expand Down
2 changes: 1 addition & 1 deletion include/ttmlir/Dialect/TTNN/Analysis/LegalGridAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ struct LegalGridAnalysisInput {
};

class LegalGridAnalysis
: public TTNNAnalysis<LegalGridAnalysisInput, std::vector<tt::LayoutAttr>> {
: public TTNNAnalysis<LegalGridAnalysisInput, std::vector<TTNNLayoutAttr>> {
private:
void analysisImplementation() override;
bool applyOverrides() override;
Expand Down
8 changes: 4 additions & 4 deletions include/ttmlir/Dialect/TTNN/Analysis/MemoryLayoutAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@
namespace mlir::tt::ttnn {

struct MemoryLayoutAnalysisInput {
llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>> legalLayouts;
llvm::DenseMap<Operation *, std::vector<TTNNLayoutAttr>> legalLayouts;
unsigned usableL1CacheSize = 0;
std::unordered_set<Edge> overrideReshardEdges;
MemoryLayoutAnalysisPolicyType policy;

MemoryLayoutAnalysisInput() : legalLayouts() {}

MemoryLayoutAnalysisInput(
const llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>>
const llvm::DenseMap<Operation *, std::vector<TTNNLayoutAttr>>
&legalLayouts,
unsigned usableL1CacheSize,
const std::unordered_set<Edge> &overrideReshardEdges,
Expand All @@ -40,15 +40,15 @@ struct MemoryLayoutAnalysisInput {
};

struct MemoryLayoutAnalysisResult {
llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>> legalLayouts;
llvm::DenseMap<Operation *, std::vector<TTNNLayoutAttr>> legalLayouts;
std::unordered_set<Edge> memReconfigEdges;
llvm::DenseMap<func::FuncOp, llvm::SmallVector<Operation *>> schedule;

MemoryLayoutAnalysisResult()
: legalLayouts(), memReconfigEdges(), schedule() {}

MemoryLayoutAnalysisResult(
const llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>>
const llvm::DenseMap<Operation *, std::vector<TTNNLayoutAttr>>
&legalLayouts,
const std::unordered_set<Edge> &memReconfigEdges)
: legalLayouts(legalLayouts), memReconfigEdges(memReconfigEdges) {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class MemoryLayoutAnalysisPolicy {
protected:
Operation *rootOp;
std::vector<L1ChainConfig> *l1ChainConfigs;
llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>> legalLayouts;
llvm::DenseMap<Operation *, std::vector<TTNNLayoutAttr>> legalLayouts;
llvm::DenseMap<func::FuncOp, llvm::SmallVector<Operation *>> *schedule;
unsigned usableL1CacheSize = 0;

Expand All @@ -23,7 +23,7 @@ class MemoryLayoutAnalysisPolicy {

MemoryLayoutAnalysisPolicy(
Operation *rootOp, std::vector<L1ChainConfig> &l1ChainConfigs,
const llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>>
const llvm::DenseMap<Operation *, std::vector<TTNNLayoutAttr>>
&legalLayouts,
llvm::DenseMap<func::FuncOp, llvm::SmallVector<Operation *>> &schedule,
unsigned usableL1CacheSize)
Expand Down
10 changes: 5 additions & 5 deletions include/ttmlir/Dialect/TTNN/Analysis/OpConfigAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,23 @@
#ifndef TTMLIR_DIALECT_TTNN_ANALYSIS_OPCONFIGANALYSIS_H
#define TTMLIR_DIALECT_TTNN_ANALYSIS_OPCONFIGANALYSIS_H

#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TTNN/Analysis/TTNNAnalysis.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"

namespace mlir::tt::ttnn {

struct OpConfigAnalysisInput {
llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>> legalGrids;
llvm::DenseMap<Operation *, std::vector<TTNNLayoutAttr>> legalGrids;

OpConfigAnalysisInput() : legalGrids() {}

OpConfigAnalysisInput(
const llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>>
const llvm::DenseMap<Operation *, std::vector<TTNNLayoutAttr>>
&&legalGrids)
: legalGrids(std::move(legalGrids)) {}

OpConfigAnalysisInput(
const llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>>
const llvm::DenseMap<Operation *, std::vector<TTNNLayoutAttr>>
&legalGrids)
: legalGrids(legalGrids) {}

Expand All @@ -38,7 +38,7 @@ struct OpConfigAnalysisInput {
//
class OpConfigAnalysis
: public TTNNAnalysis<OpConfigAnalysisInput,
llvm::DenseMap<Operation *, tt::LayoutAttr>> {
llvm::DenseMap<Operation *, TTNNLayoutAttr>> {

private:
void analysisImplementation() override;
Expand Down
35 changes: 18 additions & 17 deletions include/ttmlir/Dialect/TTNN/Analysis/ShardSolver.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TTNN/Analysis/Edge.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"
#include <algorithm>
#include <bitset>
#include <unordered_map>
Expand All @@ -18,11 +19,11 @@ namespace mlir::tt::ttnn {
struct OpL1MemSpec;

struct ShardSolverSolution {
llvm::DenseMap<Operation *, tt::LayoutAttr> selectedOpLayout;
llvm::DenseMap<Operation *, TTNNLayoutAttr> selectedOpLayout;
std::unordered_set<Edge> memReconfigEdges;

ShardSolverSolution(
const llvm::DenseMap<Operation *, tt::LayoutAttr> &selectedOpLayout,
const llvm::DenseMap<Operation *, TTNNLayoutAttr> &selectedOpLayout,
const std::unordered_set<Edge> &memReconfigEdges)
: selectedOpLayout(selectedOpLayout), memReconfigEdges(memReconfigEdges) {
}
Expand All @@ -43,7 +44,7 @@ class ShardSolver {
struct RemainingLayoutAttrs {
class Iterator {
std::uint64_t i = 0;
std::vector<tt::LayoutAttr> const *p = nullptr;
std::vector<TTNNLayoutAttr> const *p = nullptr;
Bitset mask = 0;

private:
Expand All @@ -62,12 +63,12 @@ class ShardSolver {

public:
using iterator_category = std::input_iterator_tag;
using value_type = const tt::LayoutAttr;
using difference_type = const tt::LayoutAttr;
using pointer = const tt::LayoutAttr *;
using reference = const tt::LayoutAttr &;
using value_type = const TTNNLayoutAttr;
using difference_type = const TTNNLayoutAttr;
using pointer = const TTNNLayoutAttr *;
using reference = const TTNNLayoutAttr &;

Iterator(std::vector<tt::LayoutAttr> const *p, const Bitset &mask,
Iterator(std::vector<TTNNLayoutAttr> const *p, const Bitset &mask,
std::uint64_t i = 0)
: i(i), p(p), mask(mask) {
nextValid();
Expand All @@ -94,7 +95,7 @@ class ShardSolver {
std::uint64_t index() const { return i; }
};

RemainingLayoutAttrs(std::vector<tt::LayoutAttr> const &p,
RemainingLayoutAttrs(std::vector<TTNNLayoutAttr> const &p,
const Bitset &mask)
: p(&p), mask(mask) {}

Expand All @@ -104,7 +105,7 @@ class ShardSolver {
}
size_t size() const { return mask.count(); }

std::vector<tt::LayoutAttr> const *p = nullptr;
std::vector<TTNNLayoutAttr> const *p = nullptr;
Bitset mask = 0;
};

Expand Down Expand Up @@ -252,7 +253,7 @@ class ShardSolver {
Paths paths;
};

const std::vector<tt::LayoutAttr> &
const std::vector<TTNNLayoutAttr> &
getLegalLayouts(Operation *operation) const;
void reset();

Expand All @@ -276,25 +277,25 @@ class ShardSolver {

void preprocessFirstOp();
bool checkShardCompatible(Operation *producerOp,
tt::LayoutAttr const &producerLayout,
TTNNLayoutAttr const &producerLayout,
Operation *consumerOp,
tt::LayoutAttr const &consumerLayout) const;
TTNNLayoutAttr const &consumerLayout) const;

public:
ShardSolver(const llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>>
ShardSolver(const llvm::DenseMap<Operation *, std::vector<TTNNLayoutAttr>>
&legalLayouts,
const std::vector<OpL1MemSpec> &shardSpecs,
const llvm::DenseSet<Operation *> &shardedOps,
const unsigned usableL1CacheSize,
const std::unordered_set<Edge> &overrideReshardEdges);
RemainingLayoutAttrs at(Operation *operation) const;
void set(Operation *operation, tt::LayoutAttr const &layout);
void set(Operation *operation, TTNNLayoutAttr const &layout);
static bool supportsInterleavedInputShardedOutput(Operation *op);
llvm::DenseMap<Operation *, SmallVector<float, 64>> produceMaxCoreUsage();
ShardSolverSolution finish() const;

private:
const llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>> *legalLayouts;
const llvm::DenseMap<Operation *, std::vector<TTNNLayoutAttr>> *legalLayouts;
const std::vector<OpL1MemSpec> *shardSpecs;
const llvm::DenseSet<Operation *> *shardedOps;
unsigned usableL1CacheSize;
Expand All @@ -307,7 +308,7 @@ class ShardSolver {
std::unordered_map<Edge, PathSetId> pathSetIds;
std::unordered_map<Operation *, BitsetId> bitsetIds;

llvm::DenseMap<Operation *, tt::LayoutAttr> selectedOpLayout;
llvm::DenseMap<Operation *, TTNNLayoutAttr> selectedOpLayout;
std::unordered_set<Edge> memReconfigEdges;
};

Expand Down
Loading

0 comments on commit 5c1f2a9

Please sign in to comment.