From 14d87b724f5614657054416e19fe2ea1576bb884 Mon Sep 17 00:00:00 2001
From: Nikola Obradovic <132568163+nobradovictt@users.noreply.github.com>
Date: Sat, 16 Nov 2024 10:15:36 +0100
Subject: [PATCH] [Optimizer] Globally optimal shard config picker (numcores).
 (#1251)
---
 .github/CODEOWNERS                            |   1 +
 .../Dialect/TTNN/Analysis/ShardSolver.h       |  21 +-
 .../TTNN/Analysis/DFShardingPolicy.cpp        |  29 +-
 lib/Dialect/TTNN/Analysis/ShardSolver.cpp     | 144 ++++++++--
 test/unittests/CMakeLists.txt                 |   1 +
 test/unittests/Optimizer/CMakeLists.txt       |  10 +
 test/unittests/Optimizer/TestShardSolver.cpp  | 248 ++++++++++++++++++
 7 files changed, 412 insertions(+), 42 deletions(-)
 create mode 100644 test/unittests/Optimizer/CMakeLists.txt
 create mode 100644 test/unittests/Optimizer/TestShardSolver.cpp

diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
index 3c77ac170..6d3a68709 100644
--- a/.github/CODEOWNERS
+++ b/.github/CODEOWNERS
@@ -21,4 +21,5 @@
 /python/ @nsmithtt @tapspatel @odjuricicTT @nobradovictt @vprajapati-tt
 /runtime/ @jnie-TT @kmabeeTT @AleksKnezevic @pilkicTT
 /runtime/tools/ @tapspatel @nsmithtt
+/test/unittests/Optimizer @nobradovictt @odjuricicTT
 /tools/explorer/ @odjuricicTT @nobradovictt @vprajapati-tt
diff --git a/include/ttmlir/Dialect/TTNN/Analysis/ShardSolver.h b/include/ttmlir/Dialect/TTNN/Analysis/ShardSolver.h
index 81ce61bd2..1178a578e 100644
--- a/include/ttmlir/Dialect/TTNN/Analysis/ShardSolver.h
+++ b/include/ttmlir/Dialect/TTNN/Analysis/ShardSolver.h
@@ -89,6 +89,9 @@ class ShardSolver {
     }
     bool operator!=(Iterator other) const { return not(*this == other); }
     reference operator*() const { return (*p)[i]; }
+    pointer operator->() const { return get(); }
+    pointer get() const { return &(*p)[i]; }
+    std::uint64_t index() const { return i; }
   };
 
   RemainingLayoutAttrs(std::vector<tt::LayoutAttr> const &p,
@@ -105,8 +108,6 @@ class ShardSolver {
     }
     Bitset mask = 0;
   };
-  ShardSolverSolution const finish();
-
 private:
   static Bitset bitset(std::uint64_t bit) {
     Bitset b;
@@ -241,8 +242,8 @@ class ShardSolver {
 
     Operation *getProducerOp() const { return producerOperation; }
     Operation *getConsumerOp() const { return consumerOperation; }
+    const Paths &getPaths() const { return paths; }
 
-  private:
   private:
     BitsetId producerSetId = -1;
     BitsetId consumerSetId = -1;
@@ -280,15 +281,17 @@ class ShardSolver {
                             tt::LayoutAttr const &consumerLayout) const;
 
 public:
-  ShardSolver(
-      const llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>> &legalLayouts,
-      const std::vector<OpL1MemSpec> &shardSpecs,
-      const llvm::DenseSet<Operation *> &shardedOps,
-      const unsigned usableL1CacheSize,
-      const std::unordered_set<Edge> &overrideReshardEdges);
+  ShardSolver(const llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>>
+                  &legalLayouts,
+              const std::vector<OpL1MemSpec> &shardSpecs,
+              const llvm::DenseSet<Operation *> &shardedOps,
+              const unsigned usableL1CacheSize,
+              const std::unordered_set<Edge> &overrideReshardEdges);
   RemainingLayoutAttrs at(Operation *operation) const;
   void set(Operation *operation, tt::LayoutAttr const &layout);
   static bool supportsInterleavedInputShardedOutput(Operation *op);
+  llvm::DenseMap<Operation *, SmallVector<float, 64>> produceMaxCoreUsage();
+  ShardSolverSolution finish() const;
 
 private:
   const llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>> *legalLayouts;
diff --git a/lib/Dialect/TTNN/Analysis/DFShardingPolicy.cpp b/lib/Dialect/TTNN/Analysis/DFShardingPolicy.cpp
index 1f2e23de2..11e3fcdbf 100644
--- a/lib/Dialect/TTNN/Analysis/DFShardingPolicy.cpp
+++ b/lib/Dialect/TTNN/Analysis/DFShardingPolicy.cpp
@@ -206,23 +206,24 @@ void DFShardingPolicy::run() {
 
 void DFShardingPolicy::pickOpShardLayouts(ShardSolver &shardSolver,
                                           const L1ChainConfig &l1ChainConfig) {
-  // TODO(nobradovic)
-  // Simple picker for now, choose the highest grid size for each op, prefer
-  // width and height sharding over block sharding.
-  //
+  llvm::DenseMap<Operation *, SmallVector<float, 64>> accMaxCoreUsage =
+      shardSolver.produceMaxCoreUsage();
   for (const auto &shardSpec : l1ChainConfig.getOpL1MemSpecs()) {
     Operation *op = shardSpec.op;
     ShardSolver::RemainingLayoutAttrs validLayouts = shardSolver.at(op);
-    const tt::LayoutAttr *selectedLayout = &(*validLayouts.begin());
-    for (const tt::LayoutAttr &layout : validLayouts) {
-
-      if (layout.getGrid().getGridVolume() >
-          selectedLayout->getGrid().getGridVolume()) {
-        selectedLayout = &layout;
-      } else if (layout.getGrid().getGridVolume() ==
-                 selectedLayout->getGrid().getGridVolume()) {
-        if (layout.getMemLayout() != tt::TensorMemoryLayout::BlockSharded) {
-          selectedLayout = &layout;
+    const tt::LayoutAttr *selectedLayout = validLayouts.begin().get();
+    float maxCoreUsage = 0;
+    for (auto layoutIterator = validLayouts.begin();
+         layoutIterator != validLayouts.end(); ++layoutIterator) {
+      if (accMaxCoreUsage[op][layoutIterator.index()] > maxCoreUsage) {
+        maxCoreUsage = accMaxCoreUsage[op][layoutIterator.index()];
+        selectedLayout = layoutIterator.get();
+      } else if (accMaxCoreUsage[op][layoutIterator.index()] == maxCoreUsage) {
+        // If we have a tie, prefer a layout that is not BlockSharded.
+        //
+        if (layoutIterator->getMemLayout() !=
+            tt::TensorMemoryLayout::BlockSharded) {
+          selectedLayout = layoutIterator.get();
        }
      }
    }
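The picker above is just an argmax over the accumulated core usage, with a tie-break that avoids block sharding. Below is a standalone sketch of the same decision rule with the MLIR types stripped out; the CandidateLayout struct and pickLayout name are illustrative and not part of the patch.

#include <cstddef>
#include <vector>

// Simplified stand-in for one legal layout candidate of an op.
struct CandidateLayout {
  float accMaxCoreUsage; // accumulated core usage if this layout is picked
  bool isBlockSharded;   // tie-break input: prefer width/height sharding
};

// Highest accumulated core usage wins; on a tie, a non-block-sharded
// candidate overrides the current pick, mirroring pickOpShardLayouts above.
std::size_t pickLayout(const std::vector<CandidateLayout> &candidates) {
  std::size_t selected = 0;
  float maxCoreUsage = 0;
  for (std::size_t i = 0; i < candidates.size(); ++i) {
    if (candidates[i].accMaxCoreUsage > maxCoreUsage) {
      maxCoreUsage = candidates[i].accMaxCoreUsage;
      selected = i;
    } else if (candidates[i].accMaxCoreUsage == maxCoreUsage &&
               !candidates[i].isBlockSharded) {
      selected = i;
    }
  }
  return selected;
}

For example, with candidates {{20, true}, {20, false}, {14, false}} this returns index 1: the two 20-core layouts tie and the non-block-sharded one is preferred.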
diff --git a/lib/Dialect/TTNN/Analysis/ShardSolver.cpp b/lib/Dialect/TTNN/Analysis/ShardSolver.cpp
index 580921332..c0c541fc1 100644
--- a/lib/Dialect/TTNN/Analysis/ShardSolver.cpp
+++ b/lib/Dialect/TTNN/Analysis/ShardSolver.cpp
@@ -530,26 +530,31 @@ bool ShardSolver::checkShardCompatible(
 
   // assert(producerLayout.hasShardedL1TensorMemoryLayout() &&
   // consumerLayout.hasShardedL1TensorMemoryLayout());
-  RankedTensorType producerTensorType =
-      mlir::cast<RankedTensorType>(producerOp->getResult(0).getType());
-  uint64_t producerL1OutputUsage = deviceAttr.getLayoutSizeBytes(
-      producerTensorType.getShape(), producerLayout,
-      producerLayout.getMemorySpace());
-
-  RankedTensorType consumerTensorType =
-      mlir::cast<RankedTensorType>(consumerOp->getResult(0).getType());
-  uint64_t consumerL1OutputUsage = deviceAttr.getLayoutSizeBytes(
-      consumerTensorType.getShape(), consumerLayout,
-      consumerLayout.getMemorySpace());
-  // Figure out this const based on exec data, but will be replaced
-  // with API.
+
+  // Perform L1 usage check only if deviceAttr is available.
   //
-  constexpr float tensorL1UsageCap = 0.8;
-  bool l1UsageValid = (producerL1OutputUsage + consumerL1OutputUsage) <
-                      tensorL1UsageCap * usableL1CacheSize;
+  if (deviceAttr) {
+    RankedTensorType producerTensorType =
+        mlir::cast<RankedTensorType>(producerOp->getResult(0).getType());
+    uint64_t producerL1OutputUsage = deviceAttr.getLayoutSizeBytes(
+        producerTensorType.getShape(), producerLayout,
+        producerLayout.getMemorySpace());
+
+    RankedTensorType consumerTensorType =
+        mlir::cast<RankedTensorType>(consumerOp->getResult(0).getType());
+    uint64_t consumerL1OutputUsage = deviceAttr.getLayoutSizeBytes(
+        consumerTensorType.getShape(), consumerLayout,
+        consumerLayout.getMemorySpace());
+    // Figure out this const based on exec data, but will be replaced
+    // with API.
+    //
+    constexpr float tensorL1UsageCap = 0.8;
+    bool l1UsageValid = (producerL1OutputUsage + consumerL1OutputUsage) <
+                        tensorL1UsageCap * usableL1CacheSize;
 
-  if (!l1UsageValid) {
-    return false;
+    if (!l1UsageValid) {
+      return false;
+    }
   }
 
   // Shard compat assumption. Try to keep same shard layout.
@@ -561,9 +566,110 @@ bool ShardSolver::checkShardCompatible(
   return true;
 }
 
+// Preprocess ShardSolver search space to make a helper structure which links
+// op layout choices to global max core usage.
+// Example:
+// Let's assume a simple case where layouts at the same index are compatible
+// for the input graph provided below. Tuples represent layout core usage
+// (Layout0GridVolume, Layout1GridVolume, Layout2GridVolume).
+//
+//    Op0 ----- (4, 8, 2)
+//     |
+//    Op1 ----- (8, 4, 2)
+//    /  \
+//   /    \
+//  Op2    Op3 -- (4, 4, 2) (4, 4, 2)
+//   \    /
+//    \  /
+//    Op4 ----- (2, 1, 1)
+//     |
+//    Op5 ----- (2, 1, 1)
+//
+// Here is how the structure looks after preprocessing is complete:
+//
+//    Op0 ----- (24, 22, 10)
+//     |
+//    Op1 ----- (20, 14, 8)
+//    /  \
+//   /    \
+//  Op2    Op3 -- (6, 5, 3) (6, 5, 3)
+//   \    /
+//    \  /
+//    Op4 ----- (4, 2, 2)
+//     |
+//    Op5 ----- (2, 1, 1)
+//
+// Global max of 24 core usage is achieved by selecting layout[0] for each Op.
+//
+// Returns a map of op to a vector of max core usage for each layout.
+llvm::DenseMap<Operation *, SmallVector<float, 64>>
+ShardSolver::produceMaxCoreUsage() {
+  using Paths = llvm::SmallVector<Path, 16>;
+  llvm::DenseMap<Operation *, SmallVector<float, 64>> accCoreUsage(
+      shardedOps->size());
+
+  // Start from the tail of the chain and build up the max core usage
+  // (scheduling backwards).
+  //
+  for (auto shardSpec = shardSpecs->rbegin(); shardSpec != shardSpecs->rend();
+       ++shardSpec) {
+    Operation *op = shardSpec->op;
+    std::vector<tt::LayoutAttr> const &layouts = getLegalLayouts(op);
+    assert(!layouts.empty());
+
+    // Find the layout that leads to the max core usage.
+    // Start with grid volume of current op.
+    //
+    for (size_t i = 0; i < layouts.size(); ++i) {
+      tt::LayoutAttr const &layout = layouts[i];
+      uint64_t coreUsage = layout.getGrid().getGridVolume();
+      accCoreUsage[op].push_back(coreUsage);
+    }
+
+    // Add core usage of current op users via live path connections.
+    //
+    SmallVector<PathSet *> userPathSets = getUserPathSetsPts(op);
+    for (size_t i = 0; i < userPathSets.size(); ++i) {
+      ShardSolver::PathSet *pathSet = userPathSets[i];
+      const Paths &paths = pathSet->getPaths();
+      SmallVector<uint64_t> maxCoreUsage(layouts.size(), 0);
+      Operation *consumerOp = pathSet->getConsumerOp();
+      size_t consumerInChainOperandSize =
+          getOperandPathSetsPts(consumerOp).size();
+      uint64_t consumerCoreUsage = 0;
+      for (auto const &path : paths) {
+        assert(bitsets[bitsetIds[op]].test(path.producerId));
+        assert(bitsets[bitsetIds[consumerOp]].test(path.consumerId));
+        consumerCoreUsage = accCoreUsage[consumerOp][path.consumerId];
+        if (consumerCoreUsage > maxCoreUsage[path.producerId]) {
+          maxCoreUsage[path.producerId] = consumerCoreUsage;
+        }
+      }
+
+      for (size_t i = 0; i < layouts.size(); ++i) {
+        // Add max core usage of consumer ops to current op layout.
+        // We divide by consumerInChainOperandSize to normalize the core usage
+        // based on the forking factor (so that cores are not counted more than
+        // once).
+        //
+        // Incorrect results will be produced in case the chain consists of
+        // joins without previous forks, i.e. a chain having multiple input
+        // ops. In that case the total sum of used cores would be a sum of
+        // maxCoreUsage generated by all input ops. This is currently not
+        // needed for making a decision on layout choice for maximizing core
+        // usage.
+        //
+        accCoreUsage[op][i] += static_cast<float>(maxCoreUsage[i]) /
+                               static_cast<float>(consumerInChainOperandSize);
+      }
+    }
+  }
+
+  return accCoreUsage;
+}
+
+// Returns ShardSolverSolution.
 //
-ShardSolverSolution const ShardSolver::finish() {
+ShardSolverSolution ShardSolver::finish() const {
   assert(selectedOpLayout.size() == shardedOps->size());
   return ShardSolverSolution(selectedOpLayout, memReconfigEdges);
 }
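To make the accumulation concrete, here is a self-contained sketch that replays the example from the comment above under its same-index-compatibility assumption and reproduces the (24, 22, 10) result for Op0. The op numbering, the users/operands tables, and the hard-coded grid volumes are illustrative only and not part of the patch.

#include <cassert>
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  // Per-op grid volumes for three candidate layouts (Op0..Op5), taken from
  // the example graph in the comment above.
  std::vector<std::vector<float>> acc = {
      {4, 8, 2}, {8, 4, 2}, {4, 4, 2}, {4, 4, 2}, {2, 1, 1}, {2, 1, 1}};
  // users[op]: consumers of op inside the chain.
  std::vector<std::vector<int>> users = {{1}, {2, 3}, {4}, {4}, {5}, {}};
  // operands[op]: number of in-chain producers of op (the normalization
  // factor consumerInChainOperandSize).
  std::vector<int> operands = {0, 1, 1, 1, 2, 1};

  // Walk the chain backwards. Because layouts at the same index are assumed
  // compatible, the max over live paths collapses to the consumer's value at
  // the same index.
  for (int op = 4; op >= 0; --op) {
    for (int user : users[op]) {
      for (std::size_t i = 0; i < acc[op].size(); ++i) {
        acc[op][i] += acc[user][i] / static_cast<float>(operands[user]);
      }
    }
  }

  assert(acc[0][0] == 24 && acc[0][1] == 22 && acc[0][2] == 10);
  std::printf("Op0 accumulated usage: %.0f %.0f %.0f\n", acc[0][0], acc[0][1],
              acc[0][2]);
  return 0;
}

The division by operands[user] is the forking normalization described above: Op4's accumulated usage is split between its two producers Op2 and Op3 so that its cores are not counted twice.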
diff --git a/test/unittests/CMakeLists.txt b/test/unittests/CMakeLists.txt
index bfb40aad7..a66e00a43 100644
--- a/test/unittests/CMakeLists.txt
+++ b/test/unittests/CMakeLists.txt
@@ -6,3 +6,4 @@ function(add_mlir_unittest test_dirname)
 endfunction()
 
 add_subdirectory(TestScheduler)
+add_subdirectory(Optimizer)
diff --git a/test/unittests/Optimizer/CMakeLists.txt b/test/unittests/Optimizer/CMakeLists.txt
new file mode 100644
index 000000000..681d78ff0
--- /dev/null
+++ b/test/unittests/Optimizer/CMakeLists.txt
@@ -0,0 +1,10 @@
+add_mlir_unittest(OptimizerTests
+  TestShardSolver.cpp
+)
+
+target_link_libraries(OptimizerTests
+  PRIVATE
+  MLIR
+  MLIRTTDialect
+  MLIRTTNNPipelines
+)
diff --git a/test/unittests/Optimizer/TestShardSolver.cpp b/test/unittests/Optimizer/TestShardSolver.cpp
new file mode 100644
index 000000000..ac2454558
--- /dev/null
+++ b/test/unittests/Optimizer/TestShardSolver.cpp
@@ -0,0 +1,248 @@
+// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include <gtest/gtest.h>
+
+#include "mlir/IR/Value.h"
+#include "mlir/IR/ValueRange.h"
+#include "llvm/ADT/SmallVector.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/MLIRContext.h"
+
+#include "ttmlir/Dialect/TTNN/IR/TTNN.h"
+#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h"
+
+#include "ttmlir/Dialect/TTNN/Analysis/L1ChainConfig.h"
+
+using namespace mlir::tt::ttnn;
+
+constexpr int TensorDimX = 128;
+constexpr int TensorDimY = 128;
+
+class ShardSolverBase : public ::testing::Test {
+public:
+  mlir::MLIRContext context;
+  mlir::OwningOpRef<mlir::ModuleOp> module;
+  mlir::OpBuilder builder = mlir::OpBuilder(&context);
+  mlir::func::FuncOp func;
+
+  void SetUp() override {
+    context.loadDialect();
+    module = mlir::ModuleOp::create(builder.getUnknownLoc());
+    builder.setInsertionPointToStart(&module->getBodyRegion().front());
+    createFuncOp();
+  }
+
+  llvm::SmallVector<int64_t> getTensorShape() {
+    return {TensorDimX, TensorDimY};
+  }
+
+  mlir::RankedTensorType getTensorRankedType() {
+    return mlir::RankedTensorType::get(getTensorShape(), builder.getF32Type());
+  }
+
+  mlir::Value createEmptyTensor() {
+    ShapeAttr shapeAttr = ShapeAttr::get(&context, getTensorShape());
+    return builder.create(builder.getUnknownLoc(), getTensorRankedType(),
+                          nullptr, shapeAttr, nullptr, nullptr, nullptr);
+  }
+
+  mlir::func::FuncOp createFuncOp() {
+    mlir::SmallVector<mlir::Type> input;
+    input.push_back(getTensorRankedType());
+
+    mlir::SmallVector<mlir::Type> output;
+    output.push_back(getTensorRankedType());
+
+    auto funcType = builder.getType<mlir::FunctionType>(
+        mlir::TypeRange(input), mlir::TypeRange(output));
+    func = builder.create<mlir::func::FuncOp>(builder.getUnknownLoc(), "test",
+                                              funcType);
+
+    mlir::Block *block = func.addEntryBlock();
+    block->addArgument(getTensorRankedType(), builder.getUnknownLoc());
+    block->addArgument(getTensorRankedType(), builder.getUnknownLoc());
+
+    builder.setInsertionPointToStart(block);
+
+    return func;
+  }
+
+  void
+  prepareOpForShardSolver(mlir::Operation *op,
+                          std::vector<OpL1MemSpec> &opL1MemSpecs,
+                          llvm::DenseSet<mlir::Operation *> &l1ChainedOps) {
+    OpL1MemSpec opL1MemSpec;
+    opL1MemSpec.op = op;
+    opL1MemSpecs.push_back(opL1MemSpec);
+    l1ChainedOps.insert(op);
+  }
+
+  void addLayoutForOp(
+      mlir::Operation *op,
+      llvm::DenseMap<mlir::Operation *, std::vector<mlir::tt::LayoutAttr>>
+          &legalLayouts,
+      mlir::tt::MemorySpace memorySpace,
+      mlir::tt::TensorMemoryLayout tensorMemoryLayout, int gridWidth,
+      int gridHeight) {
+    if (legalLayouts.find(op) == legalLayouts.end()) {
+      legalLayouts[op] =
+          std::vector<mlir::tt::LayoutAttr>{mlir::tt::LayoutAttr::get(
+              &context, getTensorRankedType(), memorySpace,
+              mlir::tt::GridAttr::get(&context, {gridWidth, gridHeight}),
+              builder.getF32Type(), tensorMemoryLayout)};
+    } else {
+      legalLayouts[op].push_back(mlir::tt::LayoutAttr::get(
+          &context, getTensorRankedType(), memorySpace,
+          mlir::tt::GridAttr::get(&context, {gridWidth, gridHeight}),
+          builder.getF32Type(), tensorMemoryLayout));
+    }
+  }
+
+  void TearDown() override {}
+};
+
+// Validate that ShardSolver can produce correct max core usage for a shard
+// chain, with the total accumulated in the first op.
+//
+//    Op0 ----- (4, 8, 4)
+//     |
+//    Op1 ----- (8, 4, 4)
+//    /  \
+//   /    \
+//  Op2    Op3 -- (4, 4, 1) (4, 4, 1)
+//   \    /
+//    \  /
+//    Op4 ----- (2, 1, 1)
+//     |
+//    Op5 ----- (2, 1, 1)
+//
+// Verification target:
+//
+//    Op0 ----- (24, 22, 12)
+//     |
+//    Op1 ----- (20, 14, 8)
+//    /  \
+//   /    \
+//  Op2    Op3 -- (6, 5, 3) (6, 5, 3)
+//   \    /
+//    \  /
+//    Op4 ----- (4, 2, 2)
+//     |
+//    Op5 ----- (2, 1, 1)
+//
+TEST_F(ShardSolverBase, VerifyProduceMaxCoreUsage) {
+  llvm::DenseMap<mlir::Operation *, std::vector<mlir::tt::LayoutAttr>>
+      legalLayouts;
+  std::vector<OpL1MemSpec> opL1MemSpecs;
+  llvm::DenseSet<mlir::Operation *> l1ChainedOps;
+  constexpr unsigned usableL1CacheSize = 1024 * 1024;
+  std::unordered_set<Edge> overrideReshardEdges;
+
+  mlir::Value dest = createEmptyTensor();
+  mlir::Value lhs = func.getBody().getBlocks().front().getArgument(0);
+  mlir::Value rhs = func.getBody().getBlocks().front().getArgument(1);
+  mlir::Operation *op =
+      builder.create(builder.getUnknownLoc(), lhs, rhs, dest);
+  mlir::Operation *firstOp = op;
+
+  prepareOpForShardSolver(op, opL1MemSpecs, l1ChainedOps);
+  addLayoutForOp(op, legalLayouts, mlir::tt::MemorySpace::DeviceL1,
+                 mlir::tt::TensorMemoryLayout::WidthSharded, 1, 4);
+  addLayoutForOp(op, legalLayouts, mlir::tt::MemorySpace::DeviceL1,
+                 mlir::tt::TensorMemoryLayout::HeightSharded, 8, 1);
+  addLayoutForOp(op, legalLayouts, mlir::tt::MemorySpace::DeviceL1,
+                 mlir::tt::TensorMemoryLayout::BlockSharded, 2, 2);
+
+  rhs = op->getResult(0);
+  dest = createEmptyTensor();
+  op = builder.create(builder.getUnknownLoc(), rhs, dest);
+  prepareOpForShardSolver(op, opL1MemSpecs, l1ChainedOps);
+  addLayoutForOp(op, legalLayouts, mlir::tt::MemorySpace::DeviceL1,
+                 mlir::tt::TensorMemoryLayout::WidthSharded, 1, 8);
+  addLayoutForOp(op, legalLayouts, mlir::tt::MemorySpace::DeviceL1,
+                 mlir::tt::TensorMemoryLayout::HeightSharded, 4, 1);
+  addLayoutForOp(op, legalLayouts, mlir::tt::MemorySpace::DeviceL1,
+                 mlir::tt::TensorMemoryLayout::BlockSharded, 2, 2);
+
+  lhs = func.getBody().getBlocks().front().getArgument(0);
+  rhs = op->getResult(0);
+
+  dest = createEmptyTensor();
+  op = builder.create(builder.getUnknownLoc(), lhs, rhs, dest);
+  prepareOpForShardSolver(op, opL1MemSpecs, l1ChainedOps);
+  addLayoutForOp(op, legalLayouts, mlir::tt::MemorySpace::DeviceL1,
+                 mlir::tt::TensorMemoryLayout::WidthSharded, 1, 4);
+  addLayoutForOp(op, legalLayouts, mlir::tt::MemorySpace::DeviceL1,
+                 mlir::tt::TensorMemoryLayout::HeightSharded, 4, 1);
+  addLayoutForOp(op, legalLayouts, mlir::tt::MemorySpace::DeviceL1,
+                 mlir::tt::TensorMemoryLayout::BlockSharded, 1, 1);
+
+  dest = createEmptyTensor();
+  op = builder.create(builder.getUnknownLoc(), lhs, rhs, dest);
+  prepareOpForShardSolver(op, opL1MemSpecs, l1ChainedOps);
+  addLayoutForOp(op, legalLayouts, mlir::tt::MemorySpace::DeviceL1,
+                 mlir::tt::TensorMemoryLayout::WidthSharded, 1, 4);
+  addLayoutForOp(op, legalLayouts, mlir::tt::MemorySpace::DeviceL1,
+                 mlir::tt::TensorMemoryLayout::HeightSharded, 4, 1);
+  addLayoutForOp(op, legalLayouts, mlir::tt::MemorySpace::DeviceL1,
+                 mlir::tt::TensorMemoryLayout::BlockSharded, 1, 1);
+
+  lhs = opL1MemSpecs[opL1MemSpecs.size() - 2].op->getResult(0);
+  rhs = opL1MemSpecs[opL1MemSpecs.size() - 1].op->getResult(0);
+  dest = createEmptyTensor();
+  op = builder.create(builder.getUnknownLoc(), lhs, rhs, dest);
+  prepareOpForShardSolver(op, opL1MemSpecs, l1ChainedOps);
+  addLayoutForOp(op, legalLayouts, mlir::tt::MemorySpace::DeviceL1,
+                 mlir::tt::TensorMemoryLayout::WidthSharded, 1, 2);
+  addLayoutForOp(op, legalLayouts, mlir::tt::MemorySpace::DeviceL1,
+                 mlir::tt::TensorMemoryLayout::HeightSharded, 1, 1);
+  addLayoutForOp(op, legalLayouts, mlir::tt::MemorySpace::DeviceL1,
+                 mlir::tt::TensorMemoryLayout::BlockSharded, 1, 1);
+
+  rhs = op->getResult(0);
+  dest = createEmptyTensor();
+  op = builder.create(builder.getUnknownLoc(), rhs, dest);
+  prepareOpForShardSolver(op, opL1MemSpecs, l1ChainedOps);
+  addLayoutForOp(op, legalLayouts, mlir::tt::MemorySpace::DeviceL1,
+                 mlir::tt::TensorMemoryLayout::WidthSharded, 1, 2);
+  addLayoutForOp(op, legalLayouts, mlir::tt::MemorySpace::DeviceL1,
+                 mlir::tt::TensorMemoryLayout::HeightSharded, 1, 1);
+  addLayoutForOp(op, legalLayouts, mlir::tt::MemorySpace::DeviceL1,
+                 mlir::tt::TensorMemoryLayout::BlockSharded, 1, 1);
+
+  ShardSolver shardSolver(legalLayouts, opL1MemSpecs, l1ChainedOps,
+                          usableL1CacheSize, overrideReshardEdges);
+
+  llvm::DenseMap<mlir::Operation *, llvm::SmallVector<float, 64>>
+      accMaxCoreUsage = shardSolver.produceMaxCoreUsage();
+
+  ASSERT_EQ(accMaxCoreUsage[firstOp][0], 24);
+  ASSERT_EQ(accMaxCoreUsage[firstOp][1], 22);
+  ASSERT_EQ(accMaxCoreUsage[firstOp][2], 12);
+
+  // Set layouts for all ops in ShardSolver and validate that their total core
+  // usage matches the expected values. Picking the legal layout at index 0 for
+  // all ops should lead to accMaxCoreUsage[firstOp][0] total core usage.
+  //
+  for (auto &opL1MemSpec : opL1MemSpecs) {
+    ShardSolver::RemainingLayoutAttrs validLayouts =
+        shardSolver.at(opL1MemSpec.op);
+    const mlir::tt::LayoutAttr *selectedLayout = validLayouts.begin().get();
+    shardSolver.set(opL1MemSpec.op, *selectedLayout);
+  }
+
+  llvm::DenseMap<mlir::Operation *, mlir::tt::LayoutAttr> selectedOpLayout =
+      shardSolver.finish().selectedOpLayout;
+  float totalCoreUsage = 0;
+  for (const auto &opLayout : selectedOpLayout) {
+    totalCoreUsage += opLayout.second.getGrid().getGridVolume();
+  }
+
+  ASSERT_EQ(totalCoreUsage, accMaxCoreUsage[firstOp][0]);
+}
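As a cross-check of the final assertion: if the solver leaves the index-0 layout as the first remaining layout for every op, as the comment above assumes, the picked grids are 1x4, 1x8, 1x4, 1x4, 1x2 and 1x2, so the summed grid volume is 4 + 8 + 4 + 4 + 2 + 2 = 24, which matches accMaxCoreUsage[firstOp][0] from the verification target.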