[MLIR][ROCDL] Remove patterns for ops supported as intrinsics in the AMDGPU backend #102971

Merged · 5 commits · Sep 4, 2024
14 changes: 6 additions & 8 deletions mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -48,18 +48,20 @@ static void populateOpPatterns(LLVMTypeConverter &converter,
void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns) {
// Handled by mathToLLVM: math::AbsIOp
+ // Handled by mathToLLVM: math::AbsFOp
// Handled by mathToLLVM: math::CopySignOp
// Handled by mathToLLVM: math::CountLeadingZerosOp
// Handled by mathToLLVM: math::CountTrailingZerosOp
// Handled by mathToLLVM: math::CtPopOp
+ // Handled by mathToLLVM: math::ExpOp (32-bit only)
// Handled by mathToLLVM: math::FmaOp
+ // Handled by mathToLLVM: math::LogOp (32-bit only)
// FIXME: math::IPowIOp
// FIXME: math::FPowIOp
// Handled by mathToLLVM: math::RoundEvenOp
// Handled by mathToLLVM: math::RoundOp
+ // Handled by mathToLLVM: math::SqrtOp
// Handled by mathToLLVM: math::TruncOp
- populateOpPatterns<math::AbsFOp>(converter, patterns, "__ocml_fabs_f32",
- "__ocml_fabs_f64");
populateOpPatterns<math::AcosOp>(converter, patterns, "__ocml_acos_f32",
"__ocml_acos_f64");
populateOpPatterns<math::AcoshOp>(converter, patterns, "__ocml_acosh_f32",
@@ -84,16 +86,14 @@ void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
"__ocml_cosh_f64");
populateOpPatterns<math::SinhOp>(converter, patterns, "__ocml_sinh_f32",
"__ocml_sinh_f64");
- populateOpPatterns<math::ExpOp>(converter, patterns, "__ocml_exp_f32",
- "__ocml_exp_f64");
+ populateOpPatterns<math::ExpOp>(converter, patterns, "", "__ocml_exp_f64");

Review thread on this line:

Contributor: Having disassembled OCML, the double-precision exp isn't actually a direct wrapper around the relevant intrinsic, but I figure that's probably fine.

arsenm (Contributor), Aug 13, 2024: Correct. We only directly handle the f32 (and I think f16) versions. The f64 versions of the hard operations do not work. We do directly handle llvm.sqrt.f64 as an exception. Also, none of the trig functions are directly handled (correctly). We do codegen the f32 versions but probably shouldn't.

Author: So what do we do in the case where f32 is handled but f64 is not? Should we still just call OCML for both, or modify the lowering to handle just one of them?

Contributor: Modify the lowering to handle just one. The operation + type should be treated like different operations, so emit the working f32 intrinsics and the calls for the nonworking f64.

Author: Okay, I put the lowering for f64 back for log and exp and added the tests back for them as well.
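To make the agreed-upon split concrete, here is a hedged sketch of what the two widths should lower to (the exact IR spellings are assumptions based on the generic MathToLLVM intrinsic lowering, not output taken from this PR):

```mlir
// Input: the same op at two widths.
func.func @exp_example(%x32 : f32, %x64 : f64) -> (f32, f64) {
  %a = math.exp %x32 : f32
  %b = math.exp %x64 : f64
  func.return %a, %b : f32, f64
}

// Expected lowering (sketch):
//   f32: handled by MathToLLVM, e.g. %a -> llvm.intr.exp(%x32) : (f32) -> f32,
//        which the AMDGPU backend then selects directly.
//   f64: no working backend lowering, so MathToROCDL emits
//        %b -> llvm.call @__ocml_exp_f64(%x64) : (f64) -> f64.
```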
populateOpPatterns<math::Exp2Op>(converter, patterns, "__ocml_exp2_f32",
"__ocml_exp2_f64");
populateOpPatterns<math::ExpM1Op>(converter, patterns, "__ocml_expm1_f32",
"__ocml_expm1_f64");
populateOpPatterns<math::FloorOp>(converter, patterns, "__ocml_floor_f32",
"__ocml_floor_f64");
- populateOpPatterns<math::LogOp>(converter, patterns, "__ocml_log_f32",
- "__ocml_log_f64");
+ populateOpPatterns<math::LogOp>(converter, patterns, "", "__ocml_log_f64");

Review comment on this line (Contributor): Same note re f64 log.
populateOpPatterns<math::Log10Op>(converter, patterns, "__ocml_log10_f32",
"__ocml_log10_f64");
populateOpPatterns<math::Log1pOp>(converter, patterns, "__ocml_log1p_f32",
@@ -106,8 +106,6 @@ void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
"__ocml_rsqrt_f64");
populateOpPatterns<math::SinOp>(converter, patterns, "__ocml_sin_f32",
"__ocml_sin_f64");
- populateOpPatterns<math::SqrtOp>(converter, patterns, "__ocml_sqrt_f32",
- "__ocml_sqrt_f64");
populateOpPatterns<math::TanhOp>(converter, patterns, "__ocml_tanh_f32",
"__ocml_tanh_f64");
populateOpPatterns<math::TanOp>(converter, patterns, "__ocml_tan_f32",
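Why passing an empty string for the f32 name disables only that width: OpToFuncCallLowering appears to bail out when the selected function name is empty, leaving the op to the intrinsic-based MathToLLVM patterns. A simplified sketch of that mechanism (a paraphrased assumption, not the verbatim code in mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h):

```cpp
// Sketch only: the real pattern also casts operands, handles f16 promotion,
// and declares the llvm.func on demand.
template <typename OpTy>
struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<OpTy> {
  LogicalResult
  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Pick the OCML wrapper name by element width.
    StringRef funcName = op.getType().isF32() ? f32Func : f64Func;
    // An empty name means "no wrapper registered for this width": decline,
    // so the op falls through to MathToLLVM's intrinsic lowering instead.
    if (funcName.empty())
      return failure();
    // ... rewrite `op` into an llvm.call to `funcName` ...
    return success();
  }
  StringRef f32Func, f64Func;
};
```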
86 changes: 21 additions & 65 deletions mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
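For reference, the lit tests touched below (both gpu-to-rocdl.mlir and math-to-rocdl.mlir) are driven by RUN lines at the top of each file, along these lines (hedged: the exact pass options in the upstream files may differ):

```mlir
// RUN: mlir-opt %s -convert-gpu-to-rocdl -split-input-file | FileCheck %s
```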
@@ -131,21 +131,6 @@ gpu.module @test_module {

// -----

- gpu.module @test_module {
- // CHECK: llvm.func @__ocml_fabs_f32(f32) -> f32
- // CHECK: llvm.func @__ocml_fabs_f64(f64) -> f64
- // CHECK-LABEL: func @gpu_fabs
- func.func @gpu_fabs(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
- %result32 = math.absf %arg_f32 : f32
- // CHECK: llvm.call @__ocml_fabs_f32(%{{.*}}) : (f32) -> f32
- %result64 = math.absf %arg_f64 : f64
- // CHECK: llvm.call @__ocml_fabs_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
- }
- }

// -----

gpu.module @test_module {
// CHECK: llvm.func @__ocml_cbrt_f32(f32) -> f32
// CHECK: llvm.func @__ocml_cbrt_f64(f64) -> f64
@@ -207,17 +192,12 @@ gpu.module @test_module {
// -----

gpu.module @test_module {
- // CHECK: llvm.func @__ocml_exp_f32(f32) -> f32
// CHECK: llvm.func @__ocml_exp_f64(f64) -> f64
// CHECK-LABEL: func @gpu_exp
- func.func @gpu_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
- %exp_f32 = math.exp %arg_f32 : f32
- // CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
- %result32 = math.exp %exp_f32 : f32
- // CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
+ func.func @gpu_exp(%arg_f64 : f64) -> (f64) {
%result64 = math.exp %arg_f64 : f64
// CHECK: llvm.call @__ocml_exp_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ func.return %result64 : f64
}
}

@@ -239,21 +219,20 @@
}

// -----

// Test that we properly handle an operation with a SymbolTable other than the module op
gpu.module @test_module {
"test.symbol_scope"() ({
// CHECK: test.symbol_scope
- // CHECK: llvm.func @__ocml_exp_f32(f32) -> f32
- // CHECK: llvm.func @__ocml_exp_f64(f64) -> f64
- // CHECK-LABEL: func @gpu_exp
- func.func @gpu_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
- %exp_f32 = math.exp %arg_f32 : f32
- // CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
- %result32 = math.exp %exp_f32 : f32
- // CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
- %result64 = math.exp %arg_f64 : f64
- // CHECK: llvm.call @__ocml_exp_f64(%{{.*}}) : (f64) -> f64
+ // CHECK: llvm.func @__ocml_sin_f32(f32) -> f32
+ // CHECK: llvm.func @__ocml_sin_f64(f64) -> f64
+ // CHECK-LABEL: func @gpu_sin
+ func.func @gpu_sin(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %sin_f32 = math.sin %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
+ %result32 = math.sin %sin_f32 : f32
+ // CHECK: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.sin %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_sin_f64(%{{.*}}) : (f64) -> f64
func.return %result32, %result64 : f32, f64
}
"test.finish" () : () -> ()
@@ -280,15 +259,12 @@
// -----

gpu.module @test_module {
- // CHECK: llvm.func @__ocml_log_f32(f32) -> f32
// CHECK: llvm.func @__ocml_log_f64(f64) -> f64
// CHECK-LABEL: func @gpu_log
- func.func @gpu_log(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
- %result32 = math.log %arg_f32 : f32
- // CHECK: llvm.call @__ocml_log_f32(%{{.*}}) : (f32) -> f32
+ func.func @gpu_log(%arg_f64 : f64) -> (f64) {
%result64 = math.log %arg_f64 : f64
// CHECK: llvm.call @__ocml_log_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ func.return %result64 : f64
}
}

@@ -359,26 +335,6 @@ gpu.module @test_module {

// -----

- gpu.module @test_module {
- // CHECK: llvm.func @__ocml_sqrt_f32(f32) -> f32
- // CHECK: llvm.func @__ocml_sqrt_f64(f64) -> f64
- // CHECK-LABEL: func @gpu_sqrt
- func.func @gpu_sqrt(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64)
- -> (f16, f32, f64) {
- %result16 = math.sqrt %arg_f16 : f16
- // CHECK: llvm.fpext %{{.*}} : f16 to f32
- // CHECK-NEXT: llvm.call @__ocml_sqrt_f32(%{{.*}}) : (f32) -> f32
- // CHECK-NEXT: llvm.fptrunc %{{.*}} : f32 to f16
- %result32 = math.sqrt %arg_f32 : f32
- // CHECK: llvm.call @__ocml_sqrt_f32(%{{.*}}) : (f32) -> f32
- %result64 = math.sqrt %arg_f64 : f64
- // CHECK: llvm.call @__ocml_sqrt_f64(%{{.*}}) : (f64) -> f64
- func.return %result16, %result32, %result64 : f16, f32, f64
- }
- }

// -----
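The deleted sqrt test matches the reviewer note above that sqrt is the exception among the hard ops: the backend handles even llvm.sqrt.f64 directly, so MathToROCDL no longer needs OCML wrappers for it. A hedged sketch of where math.sqrt should end up now (assumed MathToLLVM behavior, not CHECK output from this PR):

```mlir
// math.sqrt %x : f32  presumably becomes  llvm.intr.sqrt(%x) : (f32) -> f32
// math.sqrt %x : f64  presumably becomes  llvm.intr.sqrt(%x) : (f64) -> f64
```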

gpu.module @test_module {
// CHECK: llvm.func @__ocml_tan_f32(f32) -> f32
// CHECK: llvm.func @__ocml_tan_f64(f64) -> f64
@@ -472,15 +428,15 @@ gpu.module @test_module {
gpu.module @test_module {
// CHECK-LABEL: func @gpu_unroll
func.func @gpu_unroll(%arg0 : vector<4xf32>) -> vector<4xf32> {
- %result = math.exp %arg0 : vector<4xf32>
+ %result = math.sin %arg0 : vector<4xf32>
// CHECK: %[[V0:.+]] = llvm.mlir.undef : vector<4xf32>
- // CHECK: %[[CL:.+]] = llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
+ // CHECK: %[[CL:.+]] = llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
// CHECK: %[[V1:.+]] = llvm.insertelement %[[CL]], %[[V0]]
- // CHECK: %[[CL:.+]] = llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
+ // CHECK: %[[CL:.+]] = llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
// CHECK: %[[V2:.+]] = llvm.insertelement %[[CL]], %[[V1]]
- // CHECK: %[[CL:.+]] = llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
+ // CHECK: %[[CL:.+]] = llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
// CHECK: %[[V3:.+]] = llvm.insertelement %[[CL]], %[[V2]]
- // CHECK: %[[CL:.+]] = llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
+ // CHECK: %[[CL:.+]] = llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
// CHECK: %[[V4:.+]] = llvm.insertelement %[[CL]], %[[V3]]
// CHECK: return %[[V4]]
func.return %result : vector<4xf32>
@@ -526,9 +482,9 @@ gpu.module @test_module {

gpu.module @module {
// CHECK-LABEL: @spirv_exp
- // CHECK: llvm.call @__ocml_exp_f32
+ // CHECK: llvm.call @__ocml_sin_f32
spirv.func @spirv_exp(%arg0: vector<4xf32>) -> vector<4xf32> "None" {
- %0 = math.exp %arg0 : vector<4xf32>
+ %0 = math.sin %arg0 : vector<4xf32>
spirv.ReturnValue %0 : vector<4xf32>
}
}
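Note: the exp-to-sin swaps in the unrolling and SPIR-V tests above presumably keep those tests exercising an OCML libcall at f32, since math.exp on f32 no longer produces one after this change.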
44 changes: 4 additions & 40 deletions mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
@@ -15,21 +15,6 @@ module @test_module {

// -----

- module @test_module {
- // CHECK: llvm.func @__ocml_fabs_f32(f32) -> f32
- // CHECK: llvm.func @__ocml_fabs_f64(f64) -> f64
- // CHECK-LABEL: func @math_absf
- func.func @math_absf(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
- %result32 = math.absf %arg_f32 : f32
- // CHECK: llvm.call @__ocml_fabs_f32(%{{.*}}) : (f32) -> f32
- %result64 = math.absf %arg_f64 : f64
- // CHECK: llvm.call @__ocml_fabs_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
- }
- }

// -----

module @test_module {
// CHECK: llvm.func @__ocml_acos_f32(f32) -> f32
// CHECK: llvm.func @__ocml_acos_f64(f64) -> f64
@@ -211,15 +196,12 @@ module @test_module {
// -----

module @test_module {
- // CHECK: llvm.func @__ocml_exp_f32(f32) -> f32
// CHECK: llvm.func @__ocml_exp_f64(f64) -> f64
// CHECK-LABEL: func @math_exp
- func.func @math_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
- %result32 = math.exp %arg_f32 : f32
- // CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
+ func.func @math_exp(%arg_f64 : f64) -> (f64) {
%result64 = math.exp %arg_f64 : f64
// CHECK: llvm.call @__ocml_exp_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ func.return %result64 : f64
}
}

@@ -271,15 +253,12 @@ module @test_module {
// -----

module @test_module {
- // CHECK: llvm.func @__ocml_log_f32(f32) -> f32
// CHECK: llvm.func @__ocml_log_f64(f64) -> f64
// CHECK-LABEL: func @math_log
- func.func @math_log(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
- %result32 = math.log %arg_f32 : f32
- // CHECK: llvm.call @__ocml_log_f32(%{{.*}}) : (f32) -> f32
+ func.func @math_log(%arg_f64 : f64) -> (f64) {
%result64 = math.log %arg_f64 : f64
// CHECK: llvm.call @__ocml_log_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
+ func.return %result64 : f64
}
}

@@ -360,21 +339,6 @@ module @test_module {

// -----

- module @test_module {
- // CHECK: llvm.func @__ocml_sqrt_f32(f32) -> f32
- // CHECK: llvm.func @__ocml_sqrt_f64(f64) -> f64
- // CHECK-LABEL: func @math_sqrt
- func.func @math_sqrt(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
- %result32 = math.sqrt %arg_f32 : f32
- // CHECK: llvm.call @__ocml_sqrt_f32(%{{.*}}) : (f32) -> f32
- %result64 = math.sqrt %arg_f64 : f64
- // CHECK: llvm.call @__ocml_sqrt_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
- }
- }

// -----

module @test_module {
// CHECK: llvm.func @__ocml_tanh_f32(f32) -> f32
// CHECK: llvm.func @__ocml_tanh_f64(f64) -> f64