-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Optimizer] L1 Interleaved policy that solves simple fork-joins (#1501)
This PR introduces a new MemoryLayoutAnalysis policy as an alternative to the GreedyL1Interleaved policy, with the goal of solving simple fork-joins. A fork-join is considered simple if its execution does not require a DRAM spill. This policy is designed to always solve simple fork-joins. However, if a DRAM spill is necessary, this policy will not produce a globally optimal solution.
- Loading branch information
1 parent
526919d
commit 05a831e
Showing
37 changed files
with
542 additions
and
93 deletions.
There are no files selected for viewing
76 changes: 76 additions & 0 deletions
76
include/ttmlir/Dialect/TTNN/Analysis/BFInterleavedPolicy.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#ifndef TTMLIR_DIALECT_TTNN_ANALYSIS_BFINTERLEAVEDPOLICY_H | ||
#define TTMLIR_DIALECT_TTNN_ANALYSIS_BFINTERLEAVEDPOLICY_H | ||
|
||
#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" | ||
#include "ttmlir/Dialect/TTNN/Analysis/MemoryLayoutAnalysisPolicy.h" | ||
#include <cstdint> | ||
|
||
namespace mlir::tt::ttnn { | ||
|
||
// The goal of this policy is to always solve simple fork-joins if that is | ||
// possible. Fork-join is considered to be simple if there is no need for DRAM | ||
// spill in its execution. Furthermore, if DRAM spill is necessary, this policy | ||
// will not produce globally optimal solution. | ||
// | ||
class BFInterleavedPolicy : public MemoryLayoutAnalysisPolicy { | ||
public: | ||
// In order to keep track of the L1 memory usage, we have to know two things | ||
// for each op: | ||
// 1. The L1 memory usage of each op's output tensor. | ||
// 2. The number of op's users currently relying on the op's output tensor. | ||
// This is important for fork ops where the output tensor is used by | ||
// multiple other ops. | ||
// | ||
struct OpL1MemUsage { | ||
uint64_t l1MemUsagePerUser; | ||
uint64_t numOfUnscheduledUsers; | ||
}; | ||
|
||
public: | ||
BFInterleavedPolicy( | ||
Operation *rootOp, std::vector<L1ChainConfig> &l1ChainConfigs, | ||
const llvm::DenseMap<Operation *, std::vector<TTNNLayoutAttr>> | ||
&legalLayouts, | ||
llvm::DenseMap<func::FuncOp, llvm::SmallVector<Operation *>> &schedule, | ||
unsigned usableL1CacheSize) | ||
: MemoryLayoutAnalysisPolicy(rootOp, l1ChainConfigs, legalLayouts, | ||
schedule, usableL1CacheSize) {} | ||
|
||
void run() final; | ||
|
||
private: | ||
// Check if the op is analyzable. Op is analyzable if it has at least one | ||
// legal layout. | ||
bool isAnalyzable(Operation *op); | ||
|
||
// Iterate over all operands of the op that satisfy the analyzability | ||
// criterium defined by the isAnalyzable method. This is an abstraction | ||
// for the boilerplate code used in different places within the policy. | ||
// | ||
void walkOnAnalyzableOperands(Operation *op, | ||
function_ref<void(Operation *)> callback); | ||
|
||
// Fetch op's DRAM layout from legalLayouts. | ||
bool hasDRAMBufferType(Operation *op); | ||
TTNNLayoutAttr getDRAMLayout(Operation *op); | ||
|
||
// Fetch op's L1 Interleaved layout from legalLayouts. | ||
bool hasL1BufferType(Operation *op); | ||
TTNNLayoutAttr getL1InterleavedLayout(Operation *op); | ||
|
||
size_t getAvailableL1CacheSize() const { | ||
// Figure out this const based on exec data, but will be replaced | ||
// with API. | ||
// | ||
constexpr float tensorL1UsageCap = 0.75; | ||
return tensorL1UsageCap * usableL1CacheSize; | ||
} | ||
}; | ||
|
||
} // namespace mlir::tt::ttnn | ||
|
||
#endif // TTMLIR_DIALECT_TTNN_ANALYSIS_BFINTERLEAVEDPOLICY_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.