From a41205ac40f36624ddfc75285a29c536472e5d81 Mon Sep 17 00:00:00 2001 From: Matthias Kramm Date: Fri, 1 Nov 2024 13:26:11 -0700 Subject: [PATCH] Make ReshapeOp return MHLO_AnyTensor instead of MHLO_StaticShapeTensor. Note that this only removes the TableGen generated MLIR verification of the return value. ReshapeOp::verify will still check the validity/compatiblity of the input/output types. PiperOrigin-RevId: 692274546 --- xla/mlir_hlo/mhlo/IR/hlo_ops.td | 2 +- .../hlo_legalize_to_stablehlo.cc | 6 ++++++ .../Dialect/mhlo/hlo-legalize-to-stablehlo.mlir | 9 +++++++++ xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir | 12 ++++++++++++ .../Dialect/mhlo/stablehlo-legalize-to-hlo.mlir | 7 +++++++ 5 files changed, 35 insertions(+), 1 deletion(-) diff --git a/xla/mlir_hlo/mhlo/IR/hlo_ops.td b/xla/mlir_hlo/mhlo/IR/hlo_ops.td index f4dbec25383855..ca34a0512739d2 100644 --- a/xla/mlir_hlo/mhlo/IR/hlo_ops.td +++ b/xla/mlir_hlo/mhlo/IR/hlo_ops.td @@ -2842,7 +2842,7 @@ def MHLO_ReshapeOp: MHLO_Op<"reshape", let arguments = (ins MHLO_AnyTensor:$operand); - let results = (outs MHLO_StaticShapeTensor); + let results = (outs MHLO_AnyTensor); let hasFolder = 1; let hasCanonicalizer = 1; let hasVerifier = 1; diff --git a/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc b/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc index 3931db159ec81c..a6d6ee0299dd4c 100644 --- a/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc +++ b/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc @@ -59,6 +59,12 @@ bool hasPrivateFeaturesNotInStablehlo(HloOpTy hloOp) { mhlo::XlaRngGetAndUpdateStateOp>(hloOp.getOperation())) { return true; } + if (auto reshape = dyn_cast(hloOp.getOperation())) { + // Only mhlo and HLO support dynamic reshape results. stablehlo doesn't + // (yet), at the time of this writing. + auto t = dyn_cast(reshape.getResult().getType()); + if (t && !t.hasStaticShape()) return true; + } if constexpr (std::is_same::value) { // StableHLO convolution doesn't support "unknown" dimensions. // This is an esoteric feature of MHLO convolutions, and it's different diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir index 70a27eabd67856..f92ef789c819d5 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir @@ -2143,6 +2143,15 @@ func.func @op_fusion(%arg0: tensor) -> tensor { // ----- +func.func @reshape_with_dynamic_size_convert(%arg0: tensor>) -> tensor> { + // expected-error@+1 {{failed to legalize operation 'mhlo.reshape' that was explicitly marked illegal}} + %0 = "mhlo.reshape"(%arg0) : (tensor>) + -> tensor> + return %0 : tensor> +} + +// ----- + func.func @op_stochastic_convert(%arg0: tensor, %arg1: tensor) -> tensor { // expected-error@+1 {{failed to legalize operation 'mhlo.stochastic_convert' that was explicitly marked illegal}} %0 = "mhlo.stochastic_convert"(%arg0, %arg1) : (tensor, tensor) -> tensor diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir index 72a8248f2526cb..edd003902f4de1 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir @@ -2920,6 +2920,18 @@ func.func @reshape_invalid_shapes(%operand: tensor<2x4xf32>) -> tensor<3x3xf32> // ----- +// CHECK-LABEL: func @reshape_can_have_dynamic_dimensions +func.func @reshape_can_have_dynamic_dimensions() -> tensor> { + %0 = "mhlo.constant"() {value = dense<[[1],[2],[3],[4],[5],[6],[7]]> : tensor<7x1xi64>} : () -> tensor<7x1xi64> + %size = builtin.unrealized_conversion_cast to tensor + %1 = "mhlo.set_dimension_size"(%0, %size) <{dimension = 0 : i64}> : (tensor<7x1xi64>, tensor) -> tensor> + %2 = "mhlo.reshape"(%1) : (tensor>) + -> tensor> + return %2 : tensor> +} + +// ----- + // CHECK-LABEL: func @reverse func.func @reverse(%operand: tensor<3x2xi32>) -> tensor<3x2xi32> { %0 = "mhlo.reverse"(%operand) { diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir index c8687cfe3ff0da..fdf12a56cefb08 100644 --- a/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir +++ b/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir @@ -1355,6 +1355,13 @@ func.func @op_reshape(%arg0: tensor<16xf32>) -> tensor<4x4xf32> { func.return %0 : tensor<4x4xf32> } +// CHECK-LABEL: "op_reshape_dynamic" +func.func @op_reshape_dynamic(%arg0: tensor>) -> tensor<7xi64> { + // CHECK: "mhlo.reshape"({{.*}}) : (tensor>) -> tensor<7xi64> + %0 = "stablehlo.reshape"(%arg0) : (tensor>) -> tensor<7xi64> + return %0 : tensor<7xi64> +} + // CHECK-LABEL: "op_return" func.func @op_return(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK: "mhlo.case"([[ARG0:%arg[0-9]+]]) ({