diff --git a/xla/mlir_hlo/mhlo/IR/hlo_ops.td b/xla/mlir_hlo/mhlo/IR/hlo_ops.td
index f33701336bbb64..4eb95ef326a659 100644
--- a/xla/mlir_hlo/mhlo/IR/hlo_ops.td
+++ b/xla/mlir_hlo/mhlo/IR/hlo_ops.td
@@ -2877,7 +2877,7 @@ def MHLO_ReshapeOp: MHLO_Op<"reshape",
 
   let arguments = (ins MHLO_AnyTensor:$operand);
 
-  let results = (outs MHLO_StaticShapeTensor);
+  let results = (outs MHLO_AnyTensor);
   let hasFolder = 1;
   let hasCanonicalizer = 1;
   let hasVerifier = 1;
diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir
index 70a27eabd67856..92b59bda4c1c05 100644
--- a/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir
+++ b/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir
@@ -2143,6 +2143,15 @@ func.func @op_fusion(%arg0: tensor<f32>) -> tensor<f32> {
 
 // -----
 
+func.func @reshape_with_dynamic_size_convert(%arg0: tensor<?x1xi64, #mhlo.type_extensions<bounds = [7, -1]>>) -> tensor<?xi64, #mhlo.type_extensions<bounds = [7]>> {
+  // expected-error@+1 {{'stablehlo.reshape' op result #0 must be statically shaped tensor}}
+  %0 = "mhlo.reshape"(%arg0) : (tensor<?x1xi64, #mhlo.type_extensions<bounds = [7, -1]>>)
+      -> tensor<?xi64, #mhlo.type_extensions<bounds = [7]>>
+  return %0 : tensor<?xi64, #mhlo.type_extensions<bounds = [7]>>
+}
+
+// -----
+
 func.func @op_stochastic_convert(%arg0: tensor<f64>, %arg1: tensor<ui64>) -> tensor<si32> {
   // expected-error@+1 {{failed to legalize operation 'mhlo.stochastic_convert' that was explicitly marked illegal}}
   %0 = "mhlo.stochastic_convert"(%arg0, %arg1) : (tensor<f64>, tensor<ui64>) -> tensor<si32>
diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir
index dc913645068104..d07a178c6c4e7f 100644
--- a/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir
+++ b/xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir
@@ -2920,6 +2920,18 @@ func.func @reshape_invalid_shapes(%operand: tensor<2x4xf32>) -> tensor<3x3xf32>
 
 // -----
 
+// CHECK-LABEL: func @reshape_can_have_dynamic_dimensions
+func.func @reshape_can_have_dynamic_dimensions() -> tensor<?xi64, #mhlo.type_extensions<bounds = [7]>> {
+  %0 = "mhlo.constant"() {value = dense<[[1],[2],[3],[4],[5],[6],[7]]> : tensor<7x1xi64>} : () -> tensor<7x1xi64>
+  %size = builtin.unrealized_conversion_cast to tensor<i32>
+  %1 = "mhlo.set_dimension_size"(%0, %size) <{dimension = 0 : i64}> : (tensor<7x1xi64>, tensor<i32>) -> tensor<?x1xi64, #mhlo.type_extensions<bounds = [7, -1]>>
+  %2 = "mhlo.reshape"(%1) : (tensor<?x1xi64, #mhlo.type_extensions<bounds = [7, -1]>>)
+      -> tensor<?xi64, #mhlo.type_extensions<bounds = [7]>>
+  return %2 : tensor<?xi64, #mhlo.type_extensions<bounds = [7]>>
+}
+
+// -----
+
 // CHECK-LABEL: func @reverse
 func.func @reverse(%operand: tensor<3x2xi32>) -> tensor<3x2xi32> {
   %0 = "mhlo.reverse"(%operand) {
diff --git a/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir b/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir
index c8687cfe3ff0da..fdf12a56cefb08 100644
--- a/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir
+++ b/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir
@@ -1355,6 +1355,13 @@ func.func @op_reshape(%arg0: tensor<16xf32>) -> tensor<4x4xf32> {
   func.return %0 : tensor<4x4xf32>
 }
 
+// CHECK-LABEL: "op_reshape_dynamic"
+func.func @op_reshape_dynamic(%arg0: tensor<?xi64, #stablehlo.type_extensions<bounds = [7]>>) -> tensor<7xi64> {
+  // CHECK: "mhlo.reshape"({{.*}}) : (tensor<?xi64, #mhlo.type_extensions<bounds = [7]>>) -> tensor<7xi64>
+  %0 = "stablehlo.reshape"(%arg0) : (tensor<?xi64, #stablehlo.type_extensions<bounds = [7]>>) -> tensor<7xi64>
+  return %0 : tensor<7xi64>
+}
+
 // CHECK-LABEL: "op_return"
 func.func @op_return(%arg0: tensor<i32>, %arg1: tensor<f32>) -> tensor<f32> {
   // CHECK: "mhlo.case"([[ARG0:%arg[0-9]+]]) ({