Skip to content

Commit

Permalink
Use updated operands during stableHLO to TTIR conversion (#1319)
Browse files Browse the repository at this point in the history
Tensor types with unsupported data types are converted during lowering (e.g. 2x2xi64 -> 2x2xi32),
so the type-converted operands provided by the adaptor must be used for correct materialization.
  • Loading branch information
mmanzoorTT authored Nov 19, 2024
1 parent 2976a24 commit 94896f1
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 13 deletions.
4 changes: 2 additions & 2 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ class StableHLOToTTIRReshapeOpConversionPattern
ArrayAttr new_shape_attr = rewriter.getI32ArrayAttr(new_shape_i32);
rewriter.replaceOpWithNewOp<mlir::tt::ttir::ReshapeOp>(
srcOp, getTypeConverter()->convertType(outputTensor.getType()),
srcOp->getOperand(0), outputTensor, new_shape_attr,
adaptor.getOperand(), outputTensor, new_shape_attr,
rewriter.getArrayAttr(
SmallVector<Attribute>(adaptor.getOperands().size() + 1,
rewriter.getAttr<OperandConstraintAttr>(
Expand Down Expand Up @@ -327,7 +327,7 @@ class StableHLOToTTIRGetDimensionSizeOpConversionPattern
intType, static_cast<int32_t>(srcOp.getDimension()));

rewriter.replaceOpWithNewOp<mlir::tt::ttir::GetDimensionSizeOp>(
srcOp, outputType, srcOp.getOperand(), dimension_attr);
srcOp, outputType, adaptor.getOperand(), dimension_attr);

return success();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,15 @@
module @jit_get_dimension_size attributes {} {
func.func public @test_get_dimension_size(%arg0: tensor<13x21x3xf32>) -> tensor<i32> {
%0 = stablehlo.get_dimension_size %arg0, dim = 1 : (tensor<13x21x3xf32>) -> tensor<i32>
// CHECK: [[VAL:%[0-9]+]] = "ttir.get_dimension_size"(%arg0) <{dimension = 1 : i32}> : (tensor<{{[0-9]+}}x{{[0-9]+}}x{{[0-9]+}}xf32>) -> tensor<1xi32>
// CHECK: [[VAL:%[0-9]+]] = "ttir.get_dimension_size"(%arg0) <{dimension = 1 : i32}> : (tensor<13x21x3xf32>) -> tensor<1xi32>
return %0 : tensor<i32>
// CHECK: return [[VAL]] : tensor<1xi32>
}

func.func public @test_get_dimension_size_f64(%arg0: tensor<64x64xf64>) -> tensor<i32> {
// CHECK: [[VAL:%[0-9]+]] = "ttir.get_dimension_size"(%arg0) <{dimension = 1 : i32}> : (tensor<64x64xf32>) -> tensor<1xi32>
%0 = stablehlo.get_dimension_size %arg0, dim = 1 : (tensor<64x64xf64>) -> tensor<i32>
// CHECK: return [[VAL]] : tensor<1xi32>
return %0 : tensor<i32>
}
}
10 changes: 0 additions & 10 deletions test/ttmlir/Conversion/StableHLOToTTIR/rehsape_op.mlir

This file was deleted.

36 changes: 36 additions & 0 deletions test/ttmlir/Conversion/StableHLOToTTIR/reshape_op.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// REQUIRES: stablehlo
// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s
// Regression tests for lowering stablehlo.reshape to ttir.reshape.
// Each function pairs a stablehlo.reshape with FileCheck directives that pin
// the expected ttir output, including element-type conversion of operands
// for data types unsupported by TTIR (i64 -> i32, i1 -> bf16).
module @jit_module_reshape attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
// Baseline: f32 reshape, element type unchanged; only the shape attribute
// and the destination-passing tensor.empty are introduced by the lowering.
func.func public @test_reshape(%arg0: tensor<1x64x64x64xf32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}) -> (tensor<1x1x4096x64xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
// CHECK-LABEL: func.func public @test_reshape
// CHECK: %[[EMPTY:[0-9]+]] = tensor.empty
// CHECK: %[[VAL:[0-9]+]] = "ttir.reshape"(%arg0, %[[EMPTY]])
// CHECK-SAME: shape = [1 : i32, 1 : i32, 4096 : i32, 64 : i32]
// CHECK-SAME: (tensor<1x64x64x64xf32>, tensor<1x1x4096x64xf32>) -> tensor<1x1x4096x64xf32>
%0 = stablehlo.reshape %arg0 : (tensor<1x64x64x64xf32>) -> tensor<1x1x4096x64xf32>
// CHECK: return %[[VAL]]
return %0 : tensor<1x1x4096x64xf32>
}

// i64 is unsupported: the CHECK-SAME lines require the lowered op to take
// the type-converted i32 operand (exercising the adaptor-operand fix),
// not the original i64 value.
func.func public @test_reshape_i64(%arg0: tensor<1x1xi64>) -> tensor<1xi64> {
// CHECK-LABEL: func.func public @test_reshape_i64
// CHECK: %[[EMPTY:[0-9]+]] = tensor.empty
// CHECK: %[[VAL:[0-9]+]] = "ttir.reshape"(%arg0, %[[EMPTY]])
// CHECK-SAME: shape = [1 : i32]
// CHECK-SAME: (tensor<1x1xi32>, tensor<1xi32>) -> tensor<1xi32>
%0 = stablehlo.reshape %arg0 : (tensor<1x1xi64>) -> tensor<1xi64>
// CHECK: return %[[VAL]]
return %0 : tensor<1xi64>
}

// i1 is likewise unsupported; the lowering is expected to materialize the
// reshape on the bf16-converted operand per the CHECK-SAME type signature.
func.func public @test_reshape_i1(%arg0: tensor<2x7xi1>) -> tensor<7x2xi1> {
// CHECK-LABEL: func.func public @test_reshape_i1
// CHECK: %[[EMPTY:[0-9]+]] = tensor.empty
// CHECK: %[[VAL:[0-9]+]] = "ttir.reshape"(%arg0, %[[EMPTY]])
// CHECK-SAME: shape = [7 : i32, 2 : i32]
// CHECK-SAME: (tensor<2x7xbf16>, tensor<7x2xbf16>) -> tensor<7x2xbf16>
%0 = stablehlo.reshape %arg0 : (tensor<2x7xi1>) -> tensor<7x2xi1>
// CHECK: return %[[VAL]]
return %0 : tensor<7x2xi1>
}
}
9 changes: 9 additions & 0 deletions test/ttmlir/Silicon/StableHLO/get_dimension_size_op.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,13 @@ module @jit_get_dimension_size attributes {} {
%0 = stablehlo.get_dimension_size %arg0, dim = 1 : (tensor<64x128xf32>) -> tensor<i32>
return %0 : tensor<i32>
}

func.func public @test_get_dimension_size_f64(%arg0: tensor<64x128xf64>) -> tensor<i32> {
// CHECK-LABEL: func.func public @test_get_dimension_size_f64
// CHECK: ttnn.full
// CHECK-SAME: {fillValue = 1.280000e+02 : f32}
// CHECK-SAME: -> tensor<1xi32
%0 = stablehlo.get_dimension_size %arg0, dim = 1 : (tensor<64x128xf64>) -> tensor<i32>
return %0 : tensor<i32>
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,24 @@ module @jit_module_reshape attributes {mhlo.num_partitions = 1 : i32, mhlo.num_r
%0 = stablehlo.reshape %arg0 : (tensor<1x64x64x64xf32>) -> tensor<1x1x4096x64xf32>
return %0 : tensor<1x1x4096x64xf32>
}

func.func public @test_reshape_i64(%arg0: tensor<1x1x1xi64>) -> tensor<1x1xi64> {
// CHECK-LABEL: func.func public @test_reshape_i64
// CHECK: ttnn.reshape
// CHECK-SAME: {shape = [1 : i32, 1 : i32]}
// CHECK-SAME: tensor<1x1x1xi32,
// CHECK-SAME: -> tensor<1x1xi32,
%0 = stablehlo.reshape %arg0 : (tensor<1x1x1xi64>) -> tensor<1x1xi64>
return %0 : tensor<1x1xi64>
}

func.func public @test_reshape_i1(%arg0: tensor<1x1x2x7xi1>) -> tensor<1x1x7x2xi1> {
// CHECK-LABEL: func.func public @test_reshape_i1
// CHECK: ttnn.reshape
// CHECK-SAME: {shape = [1 : i32, 1 : i32, 7 : i32, 2 : i32]}
// CHECK-SAME: tensor<1x1x2x7xbf16,
// CHECK-SAME: -> tensor<1x1x7x2xbf16,
%0 = stablehlo.reshape %arg0 : (tensor<1x1x2x7xi1>) -> tensor<1x1x7x2xi1>
return %0 : tensor<1x1x7x2xi1>
}
}

0 comments on commit 94896f1

Please sign in to comment.