From 5442b82469754c02d0e3d65adf6757d1b948b3f6 Mon Sep 17 00:00:00 2001 From: Kevin Gleason Date: Mon, 9 Dec 2024 11:07:03 -0800 Subject: [PATCH] [StableHLO] Add shape refinement callback to specify additional patterns. PiperOrigin-RevId: 704350283 --- third_party/stablehlo/temporary.patch | 105 ++++++++++++++++++ .../transforms/stablehlo_refine_shapes.cpp | 40 +++---- .../stablehlo_refine_shapes.mlir | 20 ++++ 3 files changed, 140 insertions(+), 25 deletions(-) diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch index 8b137891791fe..963e2d044883c 100755 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -1 +1,106 @@ +diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp +--- stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp ++++ stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp +@@ -369,6 +369,10 @@ + // Which correlates to + class RefineShapeState { + public: ++ RefineShapeState( ++ std::optional additionalPatternsFn) ++ : additionalPatternsFn(additionalPatternsFn) {} ++ + enum class RefinementState { + NOT_ALREADY_REFINED, + ALREADY_REFINED, +@@ -431,7 +435,14 @@ + }); + } + ++ void addAdditionalPatterns(RewritePatternSet& patterns) { ++ if (additionalPatternsFn.has_value()) ++ additionalPatternsFn.value()(&patterns); ++ } ++ + private: ++ std::optional additionalPatternsFn; ++ + // Maps refined functions to the refinement context: the values of dimension + // arguments and the types of non-global-constant arguments. A function is + // added here when we start refining it. +@@ -1001,7 +1012,7 @@ + LogicalResult applyShapeRefinementPatterns(func::FuncOp func, + RefineShapeState& state) { + MLIRContext* context = func.getContext(); +- RewritePatternSet patterns(context); ++ RewritePatternSet patterns(func->getContext()); + GreedyRewriteConfig config; + + // The algorithm behind this pass consists of a single traversal of the +@@ -1019,6 +1030,9 @@ + populateStablehloRefineShapesPatterns(&patterns, context); + patterns.add(context, state); + ++ // Populate additional patterns for StableHLO extensions. ++ state.addAdditionalPatterns(patterns); ++ + // The folding patterns implement partial evaluation of shape computations + // which is a critical part of implementing type refinement for ops like + // dynamic_broadcast_in_dim, dynamic_iota and dynamic_reshape whose shape +@@ -1103,14 +1117,22 @@ + + // Start with empty state, and no dim args / token args. + MLIRContext* context = func.getContext(); +- RefineShapeState state; +- RefinementKey key(func, 0, {}, llvm::to_vector(func.getArgumentTypes())); +- if (failed(refineFunction(*context, state, key))) +- return signalPassFailure(); ++ if (failed(refineEntryFunction(*context, func))) return signalPassFailure(); + } + }; + + } // namespace ++ ++LogicalResult refineEntryFunction( ++ MLIRContext& context, func::FuncOp func, ++ std::optional additionalPatternsFn) { ++ // Start with empty state, and no dim args / token args. ++ RefineShapeState state(additionalPatternsFn); ++ RefinementKey key(func, 0, {}, llvm::to_vector(func.getArgumentTypes())); ++ if (failed(refineFunction(context, state, key))) ++ return func.emitError("Failed to refine entry function"); ++ return success(); ++} + + func::FuncOp getStablehloRefineShapesTarget(ModuleOp module) { + // Only one function per module is supported at the moment to avoid the need +diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.h b/stablehlo/stablehlo/transforms/StablehloRefineShapes.h +--- stablehlo/stablehlo/transforms/StablehloRefineShapes.h ++++ stablehlo/stablehlo/transforms/StablehloRefineShapes.h +@@ -16,7 +16,6 @@ + #ifndef STABLEHLO_TRANSFORMS_STABLEHLO_REFINE_SHAPES_H + #define STABLEHLO_TRANSFORMS_STABLEHLO_REFINE_SHAPES_H + +-#include "llvm/ADT/SmallVector.h" + #include "mlir/Dialect/Func/IR/FuncOps.h" + #include "mlir/IR/BuiltinOps.h" + #include "mlir/IR/Operation.h" +@@ -101,6 +100,18 @@ + return refineReturnShape(rewriter, op, shape); + } + ++// Entrypoint for any pass adding extensibility to the StableHLO shape ++// refinement pass. If program is inlined before shape refinement, ++// populateShapeRefinementPatterns can be safely used, but if shape refinement ++// needs to operate on programs with functions and calls, then ++// additionalPatterns will need to be populated and passed in. ++using AdditionalShapeRefinementPatternsFn = ++ std::function; ++LogicalResult refineEntryFunction( ++ MLIRContext& context, func::FuncOp func, ++ std::optional additionalPatternsFn = ++ std::nullopt); ++ + // Custom call used to buffer operands for shape refinement + // This is a temporary artifact that is introduced by StablehloRefineArguments + // and is washed away during StablehloRefineShapes. diff --git a/xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_refine_shapes.cpp b/xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_refine_shapes.cpp index 7f630f0e11eea..37effdeadd65a 100644 --- a/xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_refine_shapes.cpp +++ b/xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_refine_shapes.cpp @@ -13,9 +13,11 @@ limitations under the License. ==============================================================================*/ #include +#include #include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/MLIRContext.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Support/LogicalResult.h" @@ -138,32 +140,20 @@ struct StablehloRefineShapesPass auto func = stablehlo::getStablehloRefineShapesTarget(getOperation()); if (!func) return signalPassFailure(); - // The algorithm behind this pass consists of a single traversal of the - // function. This is sufficient because we only support one function per - // program at the moment. - // TODO(#1048): Find out why .maxIterations = 1 no longer works. - // There have been recent refactors to applyPatternsAndFoldGreedily - // upstream, and that might be the reason. - GreedyRewriteConfig config; - config.useTopDownTraversal = true; - config.enableRegionSimplification = GreedySimplifyRegionLevel::Aggressive; - config.maxIterations = 3; - config.maxNumRewrites = GreedyRewriteConfig::kNoLimit; - config.strictMode = GreedyRewriteStrictness::AnyOp; - - RewritePatternSet patterns(&getContext()); - stablehlo::populateStablehloRefineShapesPatterns(&patterns, &getContext()); - stablehlo::populateStablehloShapeFolderPatterns(&patterns, &getContext()); - patterns.add(&getContext()); - patterns.add(&getContext()); - patterns.add(&getContext()); - if (failed( - applyPatternsAndFoldGreedily(func, std::move(patterns), config))) { - func.emitError() - << "Greedy rewriter in StablehloRefineShapes does not converge after " - << config.maxIterations << " iterations."; + // Start with empty state, and no dim args / token args. + MLIRContext* context = func.getContext(); + + // Populate additional patterns for StableHLO extensions. + std::function additionalPatternsFn = + [&](RewritePatternSet* patterns) { + patterns->add(context); + patterns->add(context); + patterns->add(context); + }; + + if (failed(stablehlo::refineEntryFunction(*context, func, + additionalPatternsFn))) return signalPassFailure(); - } } }; diff --git a/xla/mlir_hlo/tests/stablehlo_ext/stablehlo_refine_shapes.mlir b/xla/mlir_hlo/tests/stablehlo_ext/stablehlo_refine_shapes.mlir index 85d3c97dcaf58..63560cf04a3e3 100644 --- a/xla/mlir_hlo/tests/stablehlo_ext/stablehlo_refine_shapes.mlir +++ b/xla/mlir_hlo/tests/stablehlo_ext/stablehlo_refine_shapes.mlir @@ -40,3 +40,23 @@ func.func @refine_dynamic_top_k(%arg0: tensor<16xf32>) -> (tensor, tensor %1:2 = stablehlo.custom_call @stablehlo.dynamic_top_k(%arg0, %k) : (tensor<16xf32>, tensor) -> (tensor, tensor) return %1#0, %1#1 : tensor, tensor } + +// ----- + +// CHECK-LABEL: module @refine_call +module @refine_call { + // CHECK: func.func @main{{.*}}-> (tensor<4xf32>, tensor<4xi32>) + func.func @main(%arg1: tensor<16xf32>) -> (tensor, tensor) { + %0 = stablehlo.bitcast_convert %arg1 : (tensor<16xf32>) -> tensor + // CHECK: refine_call_callee{{.*}}-> (tensor<4xf32>, tensor<4xi32>) + %2:2 = call @refine_call_callee(%0) : (tensor) -> (tensor, tensor) + return %2#0, %2#1 : tensor, tensor + } + // CHECK: refine_call_callee(%arg0: tensor<16xf32>) -> (tensor<4xf32>, tensor<4xi32>) + func.func @refine_call_callee(%arg0: tensor) -> (tensor, tensor) { + // CHECK: stablehlo.dynamic_top_k{{.*}} -> (tensor<4xf32>, tensor<4xi32>) + %k = stablehlo.constant dense<4> : tensor + %1:2 = stablehlo.custom_call @stablehlo.dynamic_top_k(%arg0, %k) : (tensor, tensor) -> (tensor, tensor) + return %1#0, %1#1 : tensor, tensor + } +}