Skip to content

Commit

Permalink
[Optimizer] Globally optimal shard config picker(numcores). (#1251)
Browse files Browse the repository at this point in the history
  • Loading branch information
nobradovictt authored Nov 16, 2024
1 parent 52275fa commit 14d87b7
Show file tree
Hide file tree
Showing 7 changed files with 412 additions and 42 deletions.
1 change: 1 addition & 0 deletions .github/CODEOWNERS
Validating CODEOWNERS rules …
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,5 @@
/python/ @nsmithtt @tapspatel @odjuricicTT @nobradovictt @vprajapati-tt
/runtime/ @jnie-TT @kmabeeTT @AleksKnezevic @pilkicTT
/runtime/tools/ @tapspatel @nsmithtt
/test/unittests/Optimizer @nobradovictt @odjuricicTT
/tools/explorer/ @odjuricicTT @nobradovictt @vprajapati-tt
21 changes: 12 additions & 9 deletions include/ttmlir/Dialect/TTNN/Analysis/ShardSolver.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ class ShardSolver {
}
bool operator!=(Iterator other) const { return not(*this == other); }
reference operator*() const { return (*p)[i]; }
pointer operator->() const { return get(); }
pointer get() const { return &(*p)[i]; }
std::uint64_t index() const { return i; }
};

RemainingLayoutAttrs(std::vector<tt::LayoutAttr> const &p,
Expand All @@ -105,8 +108,6 @@ class ShardSolver {
Bitset mask = 0;
};

ShardSolverSolution const finish();

private:
static Bitset bitset(std::uint64_t bit) {
Bitset b;
Expand Down Expand Up @@ -241,8 +242,8 @@ class ShardSolver {

Operation *getProducerOp() const { return producerOperation; }
Operation *getConsumerOp() const { return consumerOperation; }
const Paths &getPaths() const { return paths; }

private:
private:
BitsetId producerSetId = -1;
BitsetId consumerSetId = -1;
Expand Down Expand Up @@ -280,15 +281,17 @@ class ShardSolver {
tt::LayoutAttr const &consumerLayout) const;

public:
ShardSolver(
const llvm::DenseMap<Operation *, std::vector<LayoutAttr>> &legalLayouts,
const std::vector<OpL1MemSpec> &shardSpecs,
const llvm::DenseSet<Operation *> &shardedOps,
const unsigned usableL1CacheSize,
const std::unordered_set<Edge> &overrideReshardEdges);
ShardSolver(const llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>>
&legalLayouts,
const std::vector<OpL1MemSpec> &shardSpecs,
const llvm::DenseSet<Operation *> &shardedOps,
const unsigned usableL1CacheSize,
const std::unordered_set<Edge> &overrideReshardEdges);
RemainingLayoutAttrs at(Operation *operation) const;
void set(Operation *operation, tt::LayoutAttr const &layout);
static bool supportsInterleavedInputShardedOutput(Operation *op);
llvm::DenseMap<Operation *, SmallVector<float, 64>> produceMaxCoreUsage();
ShardSolverSolution finish() const;

private:
const llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>> *legalLayouts;
Expand Down
29 changes: 15 additions & 14 deletions lib/Dialect/TTNN/Analysis/DFShardingPolicy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,23 +206,24 @@ void DFShardingPolicy::run() {

void DFShardingPolicy::pickOpShardLayouts(ShardSolver &shardSolver,
const L1ChainConfig &l1ChainConfig) {
// TODO(nobradovic)
// Simple picker for now, choose the highest grid size for each op, prefer
// width and height sharding over block sharding.
//
llvm::DenseMap<Operation *, SmallVector<float, 64>> accMaxCoreUsage =
shardSolver.produceMaxCoreUsage();
for (const auto &shardSpec : l1ChainConfig.getOpL1MemSpecs()) {
Operation *op = shardSpec.op;
ShardSolver::RemainingLayoutAttrs validLayouts = shardSolver.at(op);
const tt::LayoutAttr *selectedLayout = &(*validLayouts.begin());
for (const tt::LayoutAttr &layout : validLayouts) {

if (layout.getGrid().getGridVolume() >
selectedLayout->getGrid().getGridVolume()) {
selectedLayout = &layout;
} else if (layout.getGrid().getGridVolume() ==
selectedLayout->getGrid().getGridVolume()) {
if (layout.getMemLayout() != tt::TensorMemoryLayout::BlockSharded) {
selectedLayout = &layout;
const tt::LayoutAttr *selectedLayout = validLayouts.begin().get();
float maxCoreUsage = 0;
for (auto layoutIterator = validLayouts.begin();
layoutIterator != validLayouts.end(); ++layoutIterator) {
if (accMaxCoreUsage[op][layoutIterator.index()] > maxCoreUsage) {
maxCoreUsage = accMaxCoreUsage[op][layoutIterator.index()];
selectedLayout = layoutIterator.get();
} else if (accMaxCoreUsage[op][layoutIterator.index()] == maxCoreUsage) {
// If we have a tie, prefer layout that is not BlockSharded.
//
if (layoutIterator->getMemLayout() !=
tt::TensorMemoryLayout::BlockSharded) {
selectedLayout = layoutIterator.get();
}
}
}
Expand Down
144 changes: 125 additions & 19 deletions lib/Dialect/TTNN/Analysis/ShardSolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -530,26 +530,31 @@ bool ShardSolver::checkShardCompatible(
//
assert(producerLayout.hasShardedL1TensorMemoryLayout() &&
consumerLayout.hasShardedL1TensorMemoryLayout());
RankedTensorType producerTensorType =
mlir::cast<RankedTensorType>(producerOp->getResult(0).getType());
uint64_t producerL1OutputUsage = deviceAttr.getLayoutSizeBytes(
producerTensorType.getShape(), producerLayout,
producerLayout.getMemorySpace());

RankedTensorType consumerTensorType =
mlir::cast<RankedTensorType>(consumerOp->getResult(0).getType());
uint64_t consumerL1OutputUsage = deviceAttr.getLayoutSizeBytes(
consumerTensorType.getShape(), consumerLayout,
consumerLayout.getMemorySpace());
// Figure out this const based on exec data, but will be replaced
// with API.

// Perform L1 usage check only if deviceAttr is available.
//
constexpr float tensorL1UsageCap = 0.8;
bool l1UsageValid = (producerL1OutputUsage + consumerL1OutputUsage) <
tensorL1UsageCap * usableL1CacheSize;
if (deviceAttr) {
RankedTensorType producerTensorType =
mlir::cast<RankedTensorType>(producerOp->getResult(0).getType());
uint64_t producerL1OutputUsage = deviceAttr.getLayoutSizeBytes(
producerTensorType.getShape(), producerLayout,
producerLayout.getMemorySpace());

RankedTensorType consumerTensorType =
mlir::cast<RankedTensorType>(consumerOp->getResult(0).getType());
uint64_t consumerL1OutputUsage = deviceAttr.getLayoutSizeBytes(
consumerTensorType.getShape(), consumerLayout,
consumerLayout.getMemorySpace());
// Figure out this const based on exec data, but will be replaced
// with API.
//
constexpr float tensorL1UsageCap = 0.8;
bool l1UsageValid = (producerL1OutputUsage + consumerL1OutputUsage) <
tensorL1UsageCap * usableL1CacheSize;

if (!l1UsageValid) {
return false;
if (!l1UsageValid) {
return false;
}
}

// Shard compat assumption. Try to keep same shard layout.
Expand All @@ -561,9 +566,110 @@ bool ShardSolver::checkShardCompatible(
return true;
}

// Preprocess ShardSolver search space to make a helper structure which links op
// layout choices to global max core usage.
// Example:
// Let's assume a simple case where layouts at the same index are compatible
// for the input graph provided below. Tuples represent layout core
// usage (Layout0GridVolume, Layout1GridVolume, Layout2GridVolume).
//
// Op0 ----- (4, 8, 2)
//  |
// Op1 ----- (8, 4, 2)
//    / \
//   /   \
//  Op2  Op3 -- (4, 4, 2) (4, 4, 2)
//   \   /
//    \ /
//    Op4 ----- (2, 1, 1)
//     |
//    Op5 ----- (2, 1, 1)
//
// Here is how the structure looks after preprocessing is complete:
//
// Op0 ----- (24, 22, 10)
//  |
// Op1 ----- (20, 14, 8)
//    / \
//   /   \
//  Op2  Op3 -- (6, 5, 3) (6, 5, 3)
//   \   /
//    \ /
//    Op4 ----- (4, 2, 2)
//     |
//    Op5 ----- (2, 1, 1)
//
// Global max of 24 core usage is achieved by selecting layout[0] for each Op.
//
// Returns map of op to vector of max core usage for each layout.
llvm::DenseMap<Operation *, SmallVector<float, 64>>
ShardSolver::produceMaxCoreUsage() {
  using Paths = llvm::SmallVector<Path, 16>;
  // One accumulator entry per sharded op; accCoreUsage[op][i] is the best
  // total core usage reachable if op picks its i-th legal layout.
  llvm::DenseMap<Operation *, SmallVector<float, 64>> accCoreUsage(
      shardedOps->size());

  // Start from the tail of the chain and build up the max core usage
  // (i.e. schedule backwards), so each op sees its consumers' totals.
  //
  for (auto shardSpec = shardSpecs->rbegin(); shardSpec != shardSpecs->rend();
       ++shardSpec) {
    Operation *op = shardSpec->op;
    std::vector<tt::LayoutAttr> const &layouts = getLegalLayouts(op);
    assert(!layouts.empty());

    // Seed each layout's score with the grid volume (core count) of the
    // current op itself.
    //
    for (size_t i = 0; i < layouts.size(); ++i) {
      tt::LayoutAttr const &layout = layouts[i];
      uint64_t coreUsage = layout.getGrid().getGridVolume();
      accCoreUsage[op].push_back(coreUsage);
    }

    // Add core usage of current op users via live path connections.
    //
    SmallVector<ShardSolver::PathSet *> userPathSets = getUserPathSetsPts(op);
    for (size_t i = 0; i < userPathSets.size(); ++i) {
      ShardSolver::PathSet *pathSet = userPathSets[i];
      const Paths &paths = pathSet->getPaths();
      // Best consumer-side total per producer layout; indexed by producerId.
      SmallVector<uint64_t, 64> maxCoreUsage(layouts.size(), 0);
      Operation *consumerOp = pathSet->getConsumerOp();
      // Number of in-chain operands feeding the consumer; used below to
      // normalize so the consumer's cores are not counted once per operand.
      size_t consumerInChainOperandSize =
          getOperandPathSetsPts(consumerOp).size();
      uint64_t consumerCoreUsage = 0;
      for (auto const &path : paths) {
        // Both endpoints of a live path must still be selectable layouts.
        assert(bitsets[bitsetIds[op]].test(path.producerId));
        assert(bitsets[bitsetIds[consumerOp]].test(path.consumerId));
        consumerCoreUsage = accCoreUsage[consumerOp][path.consumerId];
        if (consumerCoreUsage > maxCoreUsage[path.producerId]) {
          maxCoreUsage[path.producerId] = consumerCoreUsage;
        }
      }

      for (size_t i = 0; i < layouts.size(); ++i) {
        // Add max core usage of consumer ops to current op layout.
        // We divide by consumerInChainOperandSize to normalize the core usage
        // based on forking factor (so that cores are not counted more than
        // once).
        //
        // Incorrect results will be produced in case the chain consists of
        // joins without previous forks, i.e. a chain having multiple input
        // ops. In that case the total sum of used cores would be a sum of
        // maxCoreUsage generated by all input ops. This is currently not
        // needed for making a decision on layout choice for maximizing core
        // usage.
        //
        accCoreUsage[op][i] += static_cast<float>(maxCoreUsage[i]) /
                               static_cast<float>(consumerInChainOperandSize);
      }
    }
  }

  return accCoreUsage;
}

// Returns ShardSolverSolution.
//
// Requires that a layout has been selected (via set()) for every sharded op;
// the assert below enforces this invariant. The solution is returned by value
// and the method is const, so the solver state is left untouched.
ShardSolverSolution ShardSolver::finish() const {
  assert(selectedOpLayout.size() == shardedOps->size());
  return ShardSolverSolution(selectedOpLayout, memReconfigEdges);
}
Expand Down
1 change: 1 addition & 0 deletions test/unittests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ function(add_mlir_unittest test_dirname)
endfunction()

add_subdirectory(TestScheduler)
add_subdirectory(Optimizer)
10 changes: 10 additions & 0 deletions test/unittests/Optimizer/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Unit tests for the TTNN Optimizer (ShardSolver etc.).
add_mlir_unittest(OptimizerTests
  TestShardSolver.cpp
)

# NOTE(review): linking the umbrella MLIR target pulls in all of MLIR;
# confirm whether the narrower per-component MLIR targets would suffice.
target_link_libraries(OptimizerTests
  PRIVATE
  MLIR
  MLIRTTDialect
  MLIRTTNNPipelines
)
Loading

0 comments on commit 14d87b7

Please sign in to comment.