Skip to content

Commit

Permalink
[StableHLO] Add shape refinement callback to specify additional patterns.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 704350283
  • Loading branch information
GleasonK authored and Google-ML-Automation committed Dec 13, 2024
1 parent b119483 commit 5442b82
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 25 deletions.
105 changes: 105 additions & 0 deletions third_party/stablehlo/temporary.patch
Original file line number Diff line number Diff line change
@@ -1 +1,106 @@
diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp
--- stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp
+++ stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp
@@ -369,6 +369,10 @@
// Which correlates to <func, sym_int_values, arg_types>
class RefineShapeState {
public:
+ RefineShapeState(
+ std::optional<AdditionalShapeRefinementPatternsFn> additionalPatternsFn)
+ : additionalPatternsFn(additionalPatternsFn) {}
+
enum class RefinementState {
NOT_ALREADY_REFINED,
ALREADY_REFINED,
@@ -431,7 +435,14 @@
});
}

+ void addAdditionalPatterns(RewritePatternSet& patterns) {
+ if (additionalPatternsFn.has_value())
+ additionalPatternsFn.value()(&patterns);
+ }
+
private:
+ std::optional<AdditionalShapeRefinementPatternsFn> additionalPatternsFn;
+
// Maps refined functions to the refinement context: the values of dimension
// arguments and the types of non-global-constant arguments. A function is
// added here when we start refining it.
@@ -1001,7 +1012,7 @@
LogicalResult applyShapeRefinementPatterns(func::FuncOp func,
RefineShapeState& state) {
MLIRContext* context = func.getContext();
- RewritePatternSet patterns(context);
+ RewritePatternSet patterns(func->getContext());
GreedyRewriteConfig config;

// The algorithm behind this pass consists of a single traversal of the
@@ -1019,6 +1030,9 @@
populateStablehloRefineShapesPatterns(&patterns, context);
patterns.add<RefineCallOpPattern>(context, state);

+ // Populate additional patterns for StableHLO extensions.
+ state.addAdditionalPatterns(patterns);
+
// The folding patterns implement partial evaluation of shape computations
// which is a critical part of implementing type refinement for ops like
// dynamic_broadcast_in_dim, dynamic_iota and dynamic_reshape whose shape
@@ -1103,14 +1117,22 @@

// Start with empty state, and no dim args / token args.
MLIRContext* context = func.getContext();
- RefineShapeState state;
- RefinementKey key(func, 0, {}, llvm::to_vector(func.getArgumentTypes()));
- if (failed(refineFunction(*context, state, key)))
- return signalPassFailure();
+ if (failed(refineEntryFunction(*context, func))) return signalPassFailure();
}
};

} // namespace
+
+LogicalResult refineEntryFunction(
+ MLIRContext& context, func::FuncOp func,
+ std::optional<AdditionalShapeRefinementPatternsFn> additionalPatternsFn) {
+ // Start with empty state, and no dim args / token args.
+ RefineShapeState state(additionalPatternsFn);
+ RefinementKey key(func, 0, {}, llvm::to_vector(func.getArgumentTypes()));
+ if (failed(refineFunction(context, state, key)))
+ return func.emitError("Failed to refine entry function");
+ return success();
+}

func::FuncOp getStablehloRefineShapesTarget(ModuleOp module) {
// Only one function per module is supported at the moment to avoid the need
diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.h b/stablehlo/stablehlo/transforms/StablehloRefineShapes.h
--- stablehlo/stablehlo/transforms/StablehloRefineShapes.h
+++ stablehlo/stablehlo/transforms/StablehloRefineShapes.h
@@ -16,7 +16,6 @@
#ifndef STABLEHLO_TRANSFORMS_STABLEHLO_REFINE_SHAPES_H
#define STABLEHLO_TRANSFORMS_STABLEHLO_REFINE_SHAPES_H

-#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Operation.h"
@@ -101,6 +100,18 @@
return refineReturnShape(rewriter, op, shape);
}

+// Entrypoint for any pass adding extensibility to the StableHLO shape
+// refinement pass. If program is inlined before shape refinement,
+// populateShapeRefinementPatterns can be safely used, but if shape refinement
+// needs to operate on programs with functions and calls, then
+// additionalPatterns will need to be populated and passed in.
+using AdditionalShapeRefinementPatternsFn =
+ std::function<void(RewritePatternSet*)>;
+LogicalResult refineEntryFunction(
+ MLIRContext& context, func::FuncOp func,
+ std::optional<AdditionalShapeRefinementPatternsFn> additionalPatternsFn =
+ std::nullopt);
+
// Custom call used to buffer operands for shape refinement
// This is a temporary artifact that is introduced by StablehloRefineArguments
// and is washed away during StablehloRefineShapes.

40 changes: 15 additions & 25 deletions xla/mlir_hlo/stablehlo_ext/transforms/stablehlo_refine_shapes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@ limitations under the License.
==============================================================================*/

#include <cstdint>
#include <functional>

#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Support/LogicalResult.h"
Expand Down Expand Up @@ -138,32 +140,20 @@ struct StablehloRefineShapesPass
auto func = stablehlo::getStablehloRefineShapesTarget(getOperation());
if (!func) return signalPassFailure();

// The algorithm behind this pass consists of a single traversal of the
// function. This is sufficient because we only support one function per
// program at the moment.
// TODO(#1048): Find out why .maxIterations = 1 no longer works.
// There have been recent refactors to applyPatternsAndFoldGreedily
// upstream, and that might be the reason.
GreedyRewriteConfig config;
config.useTopDownTraversal = true;
config.enableRegionSimplification = GreedySimplifyRegionLevel::Aggressive;
config.maxIterations = 3;
config.maxNumRewrites = GreedyRewriteConfig::kNoLimit;
config.strictMode = GreedyRewriteStrictness::AnyOp;

RewritePatternSet patterns(&getContext());
stablehlo::populateStablehloRefineShapesPatterns(&patterns, &getContext());
stablehlo::populateStablehloShapeFolderPatterns(&patterns, &getContext());
patterns.add<RefineDynamicReduceWindowOpPattern>(&getContext());
patterns.add<RefineDynamicRngBitGeneratorOpPattern>(&getContext());
patterns.add<RefineDynamicTopKOpPattern>(&getContext());
if (failed(
applyPatternsAndFoldGreedily(func, std::move(patterns), config))) {
func.emitError()
<< "Greedy rewriter in StablehloRefineShapes does not converge after "
<< config.maxIterations << " iterations.";
// Start with empty state, and no dim args / token args.
MLIRContext* context = func.getContext();

// Populate additional patterns for StableHLO extensions.
std::function<void(RewritePatternSet*)> additionalPatternsFn =
[&](RewritePatternSet* patterns) {
patterns->add<RefineDynamicReduceWindowOpPattern>(context);
patterns->add<RefineDynamicRngBitGeneratorOpPattern>(context);
patterns->add<RefineDynamicTopKOpPattern>(context);
};

if (failed(stablehlo::refineEntryFunction(*context, func,
additionalPatternsFn)))
return signalPassFailure();
}
}
};

Expand Down
20 changes: 20 additions & 0 deletions xla/mlir_hlo/tests/stablehlo_ext/stablehlo_refine_shapes.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,23 @@ func.func @refine_dynamic_top_k(%arg0: tensor<16xf32>) -> (tensor<?xf32>, tensor
%1:2 = stablehlo.custom_call @stablehlo.dynamic_top_k(%arg0, %k) : (tensor<16xf32>, tensor<ui64>) -> (tensor<?xf32>, tensor<?xi32>)
return %1#0, %1#1 : tensor<?xf32>, tensor<?xi32>
}

// -----

// CHECK-LABEL: module @refine_call
module @refine_call {
// CHECK: func.func @main{{.*}}-> (tensor<4xf32>, tensor<4xi32>)
func.func @main(%arg1: tensor<16xf32>) -> (tensor<?xf32>, tensor<?xi32>) {
%0 = stablehlo.bitcast_convert %arg1 : (tensor<16xf32>) -> tensor<?xf32>
// CHECK: refine_call_callee{{.*}}-> (tensor<4xf32>, tensor<4xi32>)
%2:2 = call @refine_call_callee(%0) : (tensor<?xf32>) -> (tensor<?xf32>, tensor<?xi32>)
return %2#0, %2#1 : tensor<?xf32>, tensor<?xi32>
}
// CHECK: refine_call_callee(%arg0: tensor<16xf32>) -> (tensor<4xf32>, tensor<4xi32>)
func.func @refine_call_callee(%arg0: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xi32>) {
// CHECK: stablehlo.dynamic_top_k{{.*}} -> (tensor<4xf32>, tensor<4xi32>)
%k = stablehlo.constant dense<4> : tensor<ui64>
%1:2 = stablehlo.custom_call @stablehlo.dynamic_top_k(%arg0, %k) : (tensor<?xf32>, tensor<ui64>) -> (tensor<?xf32>, tensor<?xi32>)
return %1#0, %1#1 : tensor<?xf32>, tensor<?xi32>
}
}

0 comments on commit 5442b82

Please sign in to comment.