-
Notifications
You must be signed in to change notification settings - Fork 451
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[StableHLO] Add shape refinement callback to specify additional patte…
…rns. PiperOrigin-RevId: 704350283
- Loading branch information
1 parent
b119483
commit 08493be
Showing
3 changed files
with
140 additions
and
25 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,106 @@ | ||
diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp | ||
--- stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp | ||
+++ stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp | ||
@@ -369,6 +369,10 @@ | ||
// Which correlates to <func, sym_int_values, arg_types> | ||
class RefineShapeState { | ||
public: | ||
+ RefineShapeState( | ||
+ std::optional<AdditionalShapeRefinementPatternsFn> additionalPatternsFn) | ||
+ : additionalPatternsFn(additionalPatternsFn) {} | ||
+ | ||
enum class RefinementState { | ||
NOT_ALREADY_REFINED, | ||
ALREADY_REFINED, | ||
@@ -431,7 +435,14 @@ | ||
}); | ||
} | ||
|
||
+ void addAdditionalPatterns(RewritePatternSet& patterns) { | ||
+ if (additionalPatternsFn.has_value()) | ||
+ additionalPatternsFn.value()(&patterns); | ||
+ } | ||
+ | ||
private: | ||
+ std::optional<AdditionalShapeRefinementPatternsFn> additionalPatternsFn; | ||
+ | ||
// Maps refined functions to the refinement context: the values of dimension | ||
// arguments and the types of non-global-constant arguments. A function is | ||
// added here when we start refining it. | ||
@@ -1001,7 +1012,7 @@ | ||
LogicalResult applyShapeRefinementPatterns(func::FuncOp func, | ||
RefineShapeState& state) { | ||
MLIRContext* context = func.getContext(); | ||
- RewritePatternSet patterns(context); | ||
+ RewritePatternSet patterns(func->getContext()); | ||
GreedyRewriteConfig config; | ||
|
||
// The algorithm behind this pass consists of a single traversal of the | ||
@@ -1019,6 +1030,9 @@ | ||
populateStablehloRefineShapesPatterns(&patterns, context); | ||
patterns.add<RefineCallOpPattern>(context, state); | ||
|
||
+ // Populate additional patterns for StableHLO extensions. | ||
+ state.addAdditionalPatterns(patterns); | ||
+ | ||
// The folding patterns implement partial evaluation of shape computations | ||
// which is a critical part of implementing type refinement for ops like | ||
// dynamic_broadcast_in_dim, dynamic_iota and dynamic_reshape whose shape | ||
@@ -1103,14 +1117,22 @@ | ||
|
||
// Start with empty state, and no dim args / token args. | ||
MLIRContext* context = func.getContext(); | ||
- RefineShapeState state; | ||
- RefinementKey key(func, 0, {}, llvm::to_vector(func.getArgumentTypes())); | ||
- if (failed(refineFunction(*context, state, key))) | ||
- return signalPassFailure(); | ||
+ if (failed(refineEntryFunction(*context, func))) return signalPassFailure(); | ||
} | ||
}; | ||
|
||
} // namespace | ||
+ | ||
+LogicalResult refineEntryFunction( | ||
+ MLIRContext& context, func::FuncOp func, | ||
+ std::optional<AdditionalShapeRefinementPatternsFn> additionalPatternsFn) { | ||
+ // Start with empty state, and no dim args / token args. | ||
+ RefineShapeState state(additionalPatternsFn); | ||
+ RefinementKey key(func, 0, {}, llvm::to_vector(func.getArgumentTypes())); | ||
+ if (failed(refineFunction(context, state, key))) | ||
+ return func.emitError("Failed to refine entry function"); | ||
+ return success(); | ||
+} | ||
|
||
func::FuncOp getStablehloRefineShapesTarget(ModuleOp module) { | ||
// Only one function per module is supported at the moment to avoid the need | ||
diff --ruN a/stablehlo/stablehlo/transforms/StablehloRefineShapes.h b/stablehlo/stablehlo/transforms/StablehloRefineShapes.h | ||
--- stablehlo/stablehlo/transforms/StablehloRefineShapes.h | ||
+++ stablehlo/stablehlo/transforms/StablehloRefineShapes.h | ||
@@ -16,7 +16,6 @@ | ||
#ifndef STABLEHLO_TRANSFORMS_STABLEHLO_REFINE_SHAPES_H | ||
#define STABLEHLO_TRANSFORMS_STABLEHLO_REFINE_SHAPES_H | ||
|
||
-#include "llvm/ADT/SmallVector.h" | ||
#include "mlir/Dialect/Func/IR/FuncOps.h" | ||
#include "mlir/IR/BuiltinOps.h" | ||
#include "mlir/IR/Operation.h" | ||
@@ -101,6 +100,18 @@ | ||
return refineReturnShape(rewriter, op, shape); | ||
} | ||
|
||
+// Entrypoint for any pass adding extensibility to the StableHLO shape | ||
+// refinement pass. If program is inlined before shape refinement, | ||
+// populateShapeRefinementPatterns can be safely used, but if shape refinement | ||
+// needs to operate on programs with functions and calls, then | ||
+// additionalPatterns will need to be populated and passed in. | ||
+using AdditionalShapeRefinementPatternsFn = | ||
+ std::function<void(RewritePatternSet*)>; | ||
+LogicalResult refineEntryFunction( | ||
+ MLIRContext& context, func::FuncOp func, | ||
+ std::optional<AdditionalShapeRefinementPatternsFn> additionalPatternsFn = | ||
+ std::nullopt); | ||
+ | ||
// Custom call used to buffer operands for shape refinement | ||
// This is a temporary artifact that is introduced by StablehloRefineArguments | ||
// and is washed away during StablehloRefineShapes. | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters