[Optimizer] L1 Interleaved policy that solves simple fork-joins (#1501)
This PR introduces a new MemoryLayoutAnalysis policy as an alternative to
the GreedyL1Interleaved policy, with the goal of solving simple fork-joins.
A fork-join is considered simple if its execution requires no DRAM spill.
This policy guarantees that simple fork-joins are always solved. However,
if a DRAM spill is necessary, the policy does not produce a globally
optimal solution.
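
To make the "simple" criterion concrete, here is a minimal sketch, not part of this PR: SketchOp, fitsInL1, and all field names are illustrative assumptions. It checks whether a fixed schedule can keep every op's output resident in L1 interleaved memory until its last user has run, without the live total ever exceeding the L1 budget; when that holds, no DRAM spill is needed.

// Illustrative sketch only (not from this PR). A fork-join is "simple" when
// some schedule keeps the live L1 footprint of all op outputs within budget.
#include <cstdint>
#include <unordered_map>
#include <vector>

struct SketchOp {
  uint64_t outputL1Bytes;            // L1 size of this op's output tensor
  std::vector<SketchOp *> operands;  // producers whose outputs this op reads
  unsigned numUsers;                 // ops that consume this op's output
};

bool fitsInL1(const std::vector<SketchOp *> &schedule, uint64_t budgetBytes) {
  uint64_t liveBytes = 0;
  std::unordered_map<SketchOp *, unsigned> unscheduledUsers;
  for (SketchOp *op : schedule) {
    liveBytes += op->outputL1Bytes;   // output becomes live once the op runs
    unscheduledUsers[op] = op->numUsers;
    if (liveBytes > budgetBytes)
      return false;                   // would require a DRAM spill
    // An operand can be released from L1 once its last user has executed.
    for (SketchOp *producer : op->operands)
      if (unscheduledUsers.count(producer) &&
          --unscheduledUsers[producer] == 0)
        liveBytes -= producer->outputL1Bytes;
  }
  return true;
}

For a fork op, numUsers is greater than one, so its output stays counted in liveBytes until every branch of the join has consumed it; this is the same bookkeeping the new policy's OpL1MemUsage struct performs below.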
fbajraktariTT authored Dec 17, 2024
1 parent 526919d commit 05a831e
Showing 37 changed files with 542 additions and 93 deletions.
76 changes: 76 additions & 0 deletions include/ttmlir/Dialect/TTNN/Analysis/BFInterleavedPolicy.h
@@ -0,0 +1,76 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_DIALECT_TTNN_ANALYSIS_BFINTERLEAVEDPOLICY_H
#define TTMLIR_DIALECT_TTNN_ANALYSIS_BFINTERLEAVEDPOLICY_H

#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TTNN/Analysis/MemoryLayoutAnalysisPolicy.h"
#include <cstdint>

namespace mlir::tt::ttnn {

// The goal of this policy is to always solve simple fork-joins whenever
// possible. A fork-join is considered simple if its execution requires no
// DRAM spill. However, if a DRAM spill is necessary, this policy does not
// produce a globally optimal solution.
//
class BFInterleavedPolicy : public MemoryLayoutAnalysisPolicy {
public:
// In order to keep track of L1 memory usage, we have to know two things
// for each op:
// 1. The L1 memory usage of the op's output tensor.
// 2. The number of the op's users that still rely on its output tensor.
// This matters for fork ops, whose output tensor is used by multiple
// other ops.
//
struct OpL1MemUsage {
uint64_t l1MemUsagePerUser;
uint64_t numOfUnscheduledUsers;
};

public:
BFInterleavedPolicy(
Operation *rootOp, std::vector<L1ChainConfig> &l1ChainConfigs,
const llvm::DenseMap<Operation *, std::vector<TTNNLayoutAttr>>
&legalLayouts,
llvm::DenseMap<func::FuncOp, llvm::SmallVector<Operation *>> &schedule,
unsigned usableL1CacheSize)
: MemoryLayoutAnalysisPolicy(rootOp, l1ChainConfigs, legalLayouts,
schedule, usableL1CacheSize) {}

void run() final;

private:
// Check if the op is analyzable. An op is analyzable if it has at least
// one legal layout.
bool isAnalyzable(Operation *op);

// Iterate over all operands of the op that satisfy the analyzability
// criterion defined by the isAnalyzable method. This is an abstraction
// for the boilerplate code used in different places within the policy.
//
void walkOnAnalyzableOperands(Operation *op,
function_ref<void(Operation *)> callback);

// Fetch op's DRAM layout from legalLayouts.
bool hasDRAMBufferType(Operation *op);
TTNNLayoutAttr getDRAMLayout(Operation *op);

// Fetch op's L1 Interleaved layout from legalLayouts.
bool hasL1BufferType(Operation *op);
TTNNLayoutAttr getL1InterleavedLayout(Operation *op);

size_t getAvailableL1CacheSize() const {
// This constant was determined empirically from execution data; it will
// be replaced by an API query.
//
constexpr float tensorL1UsageCap = 0.75;
return tensorL1UsageCap * usableL1CacheSize;
}
};

} // namespace mlir::tt::ttnn

#endif // TTMLIR_DIALECT_TTNN_ANALYSIS_BFINTERLEAVEDPOLICY_H
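
For context, a hedged sketch of how the helpers declared above might combine when picking an op's output layout. pickLayoutSketch and its parameters are assumptions; the committed logic lives in BFInterleavedPolicy.cpp, which is not shown in this excerpt, and utils::getOpOutputL1Usage is the helper declared in TTNN/Utils/Utils.h later in this diff.

// Illustrative only, assuming namespace mlir::tt::ttnn; not the committed
// implementation. Prefer the legal L1 interleaved layout while the op's
// output still fits under the usable L1 budget, otherwise fall back to DRAM.
static TTNNLayoutAttr pickLayoutSketch(Operation *op,
                                       std::optional<TTNNLayoutAttr> l1Layout,
                                       TTNNLayoutAttr dramLayout,
                                       uint64_t currentL1UsageBytes,
                                       uint64_t availableL1Bytes,
                                       DeviceAttr deviceAttr) {
  if (l1Layout) {
    uint64_t outBytes = utils::getOpOutputL1Usage(op, *l1Layout, deviceAttr);
    if (currentL1UsageBytes + outBytes <= availableL1Bytes)
      return *l1Layout; // output stays interleaved in L1
  }
  return dramLayout;    // no legal L1 layout, or it would overflow L1
}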
include/ttmlir/Dialect/TTNN/Analysis/GreedyL1InterleavedPolicy.h (renamed from L1InterleavedPolicy.h)
@@ -2,8 +2,8 @@
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_DIALECT_TTNN_ANALYSIS_L1INTERLEAVEDPOLICY_H
#define TTMLIR_DIALECT_TTNN_ANALYSIS_L1INTERLEAVEDPOLICY_H
#ifndef TTMLIR_DIALECT_TTNN_ANALYSIS_GREEDYL1INTERLEAVEDPOLICY_H
#define TTMLIR_DIALECT_TTNN_ANALYSIS_GREEDYL1INTERLEAVEDPOLICY_H

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "ttmlir/Dialect/TTNN/Analysis/L1ChainConfig.h"
@@ -12,7 +12,7 @@

namespace mlir::tt::ttnn {

class L1InterleavedPolicy : public MemoryLayoutAnalysisPolicy {
class GreedyL1InterleavedPolicy : public MemoryLayoutAnalysisPolicy {
public:
struct OpMemSpec {
TTNNLayoutAttr layout;
@@ -46,7 +46,7 @@ class L1InterleavedPolicy : public MemoryLayoutAnalysisPolicy {
};

public:
L1InterleavedPolicy(
GreedyL1InterleavedPolicy(
Operation *rootOp, std::vector<L1ChainConfig> &l1ChainConfigs,
const llvm::DenseMap<Operation *, std::vector<TTNNLayoutAttr>>
&legalLayouts,
@@ -124,4 +124,4 @@ class L1InterleavedPolicy : public MemoryLayoutAnalysisPolicy {

} // namespace mlir::tt::ttnn

#endif // TTMLIR_DIALECT_TTNN_ANALYSIS_L1INTERLEAVEDPOLICY_H
#endif // TTMLIR_DIALECT_TTNN_ANALYSIS_GREEDYL1INTERLEAVEDPOLICY_H
8 changes: 5 additions & 3 deletions include/ttmlir/Dialect/TTNN/Analysis/L1ChainConfig.h
@@ -5,9 +5,7 @@
#ifndef TTMLIR_DIALECT_TTNN_ANALYSIS_L1CHAINCONFIG_H
#define TTMLIR_DIALECT_TTNN_ANALYSIS_L1CHAINCONFIG_H

#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TTNN/Analysis/ShardSolver.h"
#include <unordered_set>

namespace mlir::tt::ttnn {

@@ -19,7 +17,7 @@ struct OpL1MemSpec {
// Tensor split factor for the output tensor of the op(working with a partial
// tensor).
//
uint tensorSplitFactor;
uint tensorSplitFactor = 1;

// Layout of the output tensor of the op.
//
@@ -56,6 +54,7 @@ class L1ChainConfig {
void
complete(const llvm::DenseMap<Operation *, TTNNLayoutAttr> &selectedOpLayout,
std::unordered_set<Edge> &memReconfigEdges);
void complete();

bool isEmpty() { return opL1MemSpecs.empty(); }
void addOpL1MemSpec(OpL1MemSpec spec) {
@@ -70,6 +69,9 @@ class L1ChainConfig {
const std::unordered_set<Edge> &getMemReconfigEdges() const {
return memReconfigEdges;
}

uint64_t size() const { return opL1MemSpecs.size(); }
void merge(L1ChainConfig &other);
};

} // namespace mlir::tt::ttnn
include/ttmlir/Dialect/TTNN/Analysis/MemoryLayoutAnalysisPolicy.h
@@ -17,6 +17,7 @@ class MemoryLayoutAnalysisPolicy {
llvm::DenseMap<Operation *, std::vector<TTNNLayoutAttr>> legalLayouts;
llvm::DenseMap<func::FuncOp, llvm::SmallVector<Operation *>> *schedule;
unsigned usableL1CacheSize = 0;
DeviceAttr deviceAttr;

public:
virtual ~MemoryLayoutAnalysisPolicy() {};
19 changes: 14 additions & 5 deletions include/ttmlir/Dialect/TTNN/Utils/MemoryLayoutAnalysisParams.h
@@ -10,7 +10,11 @@

namespace mlir::tt {

enum class MemoryLayoutAnalysisPolicyType { DFSharding, L1Interleaved };
enum class MemoryLayoutAnalysisPolicyType {
DFSharding,
GreedyL1Interleaved,
BFInterleaved
};

struct MemoryLayoutAnalysisPolicyTypeParser
: public llvm::cl::parser<MemoryLayoutAnalysisPolicyType> {
@@ -22,8 +26,10 @@ struct MemoryLayoutAnalysisPolicyTypeParser
llvm::StringRef arg, MemoryLayoutAnalysisPolicyType &value) {
value = llvm::StringSwitch<MemoryLayoutAnalysisPolicyType>(arg)
.Case("DFSharding", MemoryLayoutAnalysisPolicyType::DFSharding)
.Case("L1Interleaved",
MemoryLayoutAnalysisPolicyType::L1Interleaved);
.Case("GreedyL1Interleaved",
MemoryLayoutAnalysisPolicyType::GreedyL1Interleaved)
.Case("BFInterleaved",
MemoryLayoutAnalysisPolicyType::BFInterleaved);
return false;
}

@@ -33,8 +39,11 @@ struct MemoryLayoutAnalysisPolicyTypeParser
case MemoryLayoutAnalysisPolicyType::DFSharding:
res += "DFSharding";
break;
case MemoryLayoutAnalysisPolicyType::L1Interleaved:
res += "L1Interleaved";
case MemoryLayoutAnalysisPolicyType::GreedyL1Interleaved:
res += "GreedyL1Interleaved";
break;
case MemoryLayoutAnalysisPolicyType::BFInterleaved:
res += "BFInterleaved";
break;
}
return res;
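
As an illustration of how the parsed value might be consumed, a dispatch could look like the sketch below. createPolicySketch is hypothetical, the actual factory is not part of this diff, and the GreedyL1InterleavedPolicy constructor is assumed to mirror the BFInterleavedPolicy one shown above.

// Hypothetical dispatch sketch, assuming namespace mlir::tt::ttnn; the real
// selection logic is not shown in this excerpt.
std::unique_ptr<MemoryLayoutAnalysisPolicy> createPolicySketch(
    MemoryLayoutAnalysisPolicyType type, Operation *rootOp,
    std::vector<L1ChainConfig> &l1ChainConfigs,
    const llvm::DenseMap<Operation *, std::vector<TTNNLayoutAttr>>
        &legalLayouts,
    llvm::DenseMap<func::FuncOp, llvm::SmallVector<Operation *>> &schedule,
    unsigned usableL1CacheSize) {
  switch (type) {
  case MemoryLayoutAnalysisPolicyType::BFInterleaved:
    return std::make_unique<BFInterleavedPolicy>(
        rootOp, l1ChainConfigs, legalLayouts, schedule, usableL1CacheSize);
  case MemoryLayoutAnalysisPolicyType::GreedyL1Interleaved:
    return std::make_unique<GreedyL1InterleavedPolicy>(
        rootOp, l1ChainConfigs, legalLayouts, schedule, usableL1CacheSize);
  case MemoryLayoutAnalysisPolicyType::DFSharding:
    break; // DFShardingPolicy is not shown in this excerpt.
  }
  return nullptr;
}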
6 changes: 6 additions & 0 deletions include/ttmlir/Dialect/TTNN/Utils/Utils.h
@@ -43,6 +43,12 @@ RankedTensorType
createRankedTensorTypeWithEncoding(RankedTensorType tensorType,
ttnn::TTNNLayoutAttr encoding);

// Return the L1 memory usage of the output tensor of the given op.
// Used within L1 interleaved policies.
//
uint64_t getOpOutputL1Usage(Operation *op, TTNNLayoutAttr opLayout,
DeviceAttr &deviceAttr);

} // namespace mlir::tt::ttnn::utils

#endif // TTMLIR_DIALECT_TTNN_UTILS_UTILS_H