Skip to content

Commit

Permalink
Make ReshapeOp return MHLO_AnyTensor instead of MHLO_StaticShapeTensor.
Browse files Browse the repository at this point in the history
Note that this only removes the TableGen-generated MLIR verification of the
return value. ReshapeOp::verify will still check the validity/compatibility of
the input/output types.

PiperOrigin-RevId: 692274546
  • Loading branch information
matthiaskramm authored and Google-ML-Automation committed Nov 1, 2024
1 parent 6618e6b commit a41205a
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 1 deletion.
2 changes: 1 addition & 1 deletion xla/mlir_hlo/mhlo/IR/hlo_ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -2842,7 +2842,7 @@ def MHLO_ReshapeOp: MHLO_Op<"reshape",

let arguments = (ins MHLO_AnyTensor:$operand);

let results = (outs MHLO_StaticShapeTensor);
let results = (outs MHLO_AnyTensor);
let hasFolder = 1;
let hasCanonicalizer = 1;
let hasVerifier = 1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ bool hasPrivateFeaturesNotInStablehlo(HloOpTy hloOp) {
mhlo::XlaRngGetAndUpdateStateOp>(hloOp.getOperation())) {
return true;
}
if (auto reshape = dyn_cast<mhlo::ReshapeOp>(hloOp.getOperation())) {
// Only mhlo and HLO support dynamic reshape results. stablehlo doesn't
// (yet), at the time of this writing.
auto t = dyn_cast<RankedTensorType>(reshape.getResult().getType());
if (t && !t.hasStaticShape()) return true;
}
if constexpr (std::is_same<HloOpTy, mhlo::ConvolutionOp>::value) {
// StableHLO convolution doesn't support "unknown" dimensions.
// This is an esoteric feature of MHLO convolutions, and it's different
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2143,6 +2143,15 @@ func.func @op_fusion(%arg0: tensor<f32>) -> tensor<f32> {

// -----

// Negative test: an mhlo.reshape whose result type is dynamic (bounded) must
// fail legalization to StableHLO, since StableHLO's reshape does not (yet)
// support dynamic result shapes. The expected-error directive below must stay
// immediately above the op for the @+1 offset to resolve.
func.func @reshape_with_dynamic_size_convert(%arg0: tensor<?x1xi64, #mhlo.type_extensions<bounds = [7, ?]>>) -> tensor<?xi64, #mhlo.type_extensions<bounds = [7]>> {
// expected-error@+1 {{failed to legalize operation 'mhlo.reshape' that was explicitly marked illegal}}
%0 = "mhlo.reshape"(%arg0) : (tensor<?x1xi64, #mhlo.type_extensions<bounds = [7, ?]>>)
-> tensor<?xi64, #mhlo.type_extensions<bounds = [7]>>
return %0 : tensor<?xi64, #mhlo.type_extensions<bounds = [7]>>
}

// -----

func.func @op_stochastic_convert(%arg0: tensor<f32>, %arg1: tensor<ui32>) -> tensor<i8> {
// expected-error@+1 {{failed to legalize operation 'mhlo.stochastic_convert' that was explicitly marked illegal}}
%0 = "mhlo.stochastic_convert"(%arg0, %arg1) : (tensor<f32>, tensor<ui32>) -> tensor<i8>
Expand Down
12 changes: 12 additions & 0 deletions xla/mlir_hlo/tests/Dialect/mhlo/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2920,6 +2920,18 @@ func.func @reshape_invalid_shapes(%operand: tensor<2x4xf32>) -> tensor<3x3xf32>

// -----

// Positive test: after relaxing ReshapeOp's result constraint from
// MHLO_StaticShapeTensor to MHLO_AnyTensor, the op verifier must accept a
// reshape whose result has a dynamic (bounded) dimension. The dynamic operand
// is produced via set_dimension_size on a static constant.
// CHECK-LABEL: func @reshape_can_have_dynamic_dimensions
func.func @reshape_can_have_dynamic_dimensions() -> tensor<?xi64, #mhlo.type_extensions<bounds = [7]>> {
%0 = "mhlo.constant"() {value = dense<[[1],[2],[3],[4],[5],[6],[7]]> : tensor<7x1xi64>} : () -> tensor<7x1xi64>
%size = builtin.unrealized_conversion_cast to tensor<i32>
%1 = "mhlo.set_dimension_size"(%0, %size) <{dimension = 0 : i64}> : (tensor<7x1xi64>, tensor<i32>) -> tensor<?x1xi64, #mhlo.type_extensions<bounds = [7, ?]>>
%2 = "mhlo.reshape"(%1) : (tensor<?x1xi64, #mhlo.type_extensions<bounds = [7, ?]>>)
-> tensor<?xi64, #mhlo.type_extensions<bounds = [7]>>
return %2 : tensor<?xi64, #mhlo.type_extensions<bounds = [7]>>
}

// -----

// CHECK-LABEL: func @reverse
func.func @reverse(%operand: tensor<3x2xi32>) -> tensor<3x2xi32> {
%0 = "mhlo.reverse"(%operand) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1355,6 +1355,13 @@ func.func @op_reshape(%arg0: tensor<16xf32>) -> tensor<4x4xf32> {
func.return %0 : tensor<4x4xf32>
}

// Checks that a stablehlo.reshape with a bounded-dynamic operand (but a static
// result type) converts to mhlo.reshape, preserving both types.
// CHECK-LABEL: "op_reshape_dynamic"
func.func @op_reshape_dynamic(%arg0: tensor<?x1xi64, #mhlo.type_extensions<bounds = [7, ?]>>) -> tensor<7xi64> {
  //  CHECK: "mhlo.reshape"({{.*}}) : (tensor<?x1xi64, #mhlo.type_extensions<bounds = [7, ?]>>) -> tensor<7xi64>
  %0 = "stablehlo.reshape"(%arg0) : (tensor<?x1xi64, #mhlo.type_extensions<bounds = [7, ?]>>) -> tensor<7xi64>
  return %0 : tensor<7xi64>
}

// CHECK-LABEL: "op_return"
func.func @op_return(%arg0: tensor<i32>, %arg1: tensor<f32>) -> tensor<f32> {
// CHECK: "mhlo.case"([[ARG0:%arg[0-9]+]]) ({
Expand Down

0 comments on commit a41205a

Please sign in to comment.