-
Notifications
You must be signed in to change notification settings - Fork 12.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR][ROCDL] Remove patterns for ops supported as intrinsics in the AMDGPU backend #102971
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-gpu Author: Jan Leyonberg (jsjodin) Changes: This patch removes patterns for a few operations, which allows the mathToLLVM conversion to convert the operations into LLVM intrinsics instead, since they are supported directly by the AMDGPU backend. Full diff: https://github.com/llvm/llvm-project/pull/102971.diff 3 Files Affected:
diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
index 7de6971ba2ee72..fd4eab0e10d67e 100644
--- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
+++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -48,18 +48,20 @@ static void populateOpPatterns(LLVMTypeConverter &converter,
void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns) {
// Handled by mathToLLVM: math::AbsIOp
+ // Handled by mathToLLVM: math::AbsFOp
// Handled by mathToLLVM: math::CopySignOp
// Handled by mathToLLVM: math::CountLeadingZerosOp
// Handled by mathToLLVM: math::CountTrailingZerosOp
// Handled by mathToLLVM: math::CtPopOp
+ // Handled by mathToLLVM: math::ExpOp
// Handled by mathToLLVM: math::FmaOp
+ // Handled by mathToLLVM: math::LogOp
// FIXME: math::IPowIOp
// FIXME: math::FPowIOp
// Handled by mathToLLVM: math::RoundEvenOp
// Handled by mathToLLVM: math::RoundOp
+ // Handled by mathToLLVM: math::SqrtOp
// Handled by mathToLLVM: math::TruncOp
- populateOpPatterns<math::AbsFOp>(converter, patterns, "__ocml_fabs_f32",
- "__ocml_fabs_f64");
populateOpPatterns<math::AcosOp>(converter, patterns, "__ocml_acos_f32",
"__ocml_acos_f64");
populateOpPatterns<math::AcoshOp>(converter, patterns, "__ocml_acosh_f32",
@@ -84,16 +86,12 @@ void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
"__ocml_cosh_f64");
populateOpPatterns<math::SinhOp>(converter, patterns, "__ocml_sinh_f32",
"__ocml_sinh_f64");
- populateOpPatterns<math::ExpOp>(converter, patterns, "__ocml_exp_f32",
- "__ocml_exp_f64");
populateOpPatterns<math::Exp2Op>(converter, patterns, "__ocml_exp2_f32",
"__ocml_exp2_f64");
populateOpPatterns<math::ExpM1Op>(converter, patterns, "__ocml_expm1_f32",
"__ocml_expm1_f64");
populateOpPatterns<math::FloorOp>(converter, patterns, "__ocml_floor_f32",
"__ocml_floor_f64");
- populateOpPatterns<math::LogOp>(converter, patterns, "__ocml_log_f32",
- "__ocml_log_f64");
populateOpPatterns<math::Log10Op>(converter, patterns, "__ocml_log10_f32",
"__ocml_log10_f64");
populateOpPatterns<math::Log1pOp>(converter, patterns, "__ocml_log1p_f32",
@@ -106,8 +104,6 @@ void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
"__ocml_rsqrt_f64");
populateOpPatterns<math::SinOp>(converter, patterns, "__ocml_sin_f32",
"__ocml_sin_f64");
- populateOpPatterns<math::SqrtOp>(converter, patterns, "__ocml_sqrt_f32",
- "__ocml_sqrt_f64");
populateOpPatterns<math::TanhOp>(converter, patterns, "__ocml_tanh_f32",
"__ocml_tanh_f64");
populateOpPatterns<math::TanOp>(converter, patterns, "__ocml_tan_f32",
diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
index bf49a42a115775..4f1f26e8794d9e 100644
--- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
@@ -131,21 +131,6 @@ gpu.module @test_module {
// -----
-gpu.module @test_module {
- // CHECK: llvm.func @__ocml_fabs_f32(f32) -> f32
- // CHECK: llvm.func @__ocml_fabs_f64(f64) -> f64
- // CHECK-LABEL: func @gpu_fabs
- func.func @gpu_fabs(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
- %result32 = math.absf %arg_f32 : f32
- // CHECK: llvm.call @__ocml_fabs_f32(%{{.*}}) : (f32) -> f32
- %result64 = math.absf %arg_f64 : f64
- // CHECK: llvm.call @__ocml_fabs_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
- }
-}
-
-// -----
-
gpu.module @test_module {
// CHECK: llvm.func @__ocml_cbrt_f32(f32) -> f32
// CHECK: llvm.func @__ocml_cbrt_f64(f64) -> f64
@@ -206,23 +191,6 @@ gpu.module @test_module {
// -----
-gpu.module @test_module {
- // CHECK: llvm.func @__ocml_exp_f32(f32) -> f32
- // CHECK: llvm.func @__ocml_exp_f64(f64) -> f64
- // CHECK-LABEL: func @gpu_exp
- func.func @gpu_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
- %exp_f32 = math.exp %arg_f32 : f32
- // CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
- %result32 = math.exp %exp_f32 : f32
- // CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
- %result64 = math.exp %arg_f64 : f64
- // CHECK: llvm.call @__ocml_exp_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
- }
-}
-
-// -----
-
gpu.module @test_module {
// CHECK: llvm.func @__ocml_exp2_f32(f32) -> f32
// CHECK: llvm.func @__ocml_exp2_f64(f64) -> f64
@@ -239,21 +207,20 @@ gpu.module @test_module {
}
// -----
-
// Test that we handled properly operation with SymbolTable other than module op
gpu.module @test_module {
"test.symbol_scope"() ({
// CHECK: test.symbol_scope
- // CHECK: llvm.func @__ocml_exp_f32(f32) -> f32
- // CHECK: llvm.func @__ocml_exp_f64(f64) -> f64
- // CHECK-LABEL: func @gpu_exp
- func.func @gpu_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
- %exp_f32 = math.exp %arg_f32 : f32
- // CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
- %result32 = math.exp %exp_f32 : f32
- // CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
- %result64 = math.exp %arg_f64 : f64
- // CHECK: llvm.call @__ocml_exp_f64(%{{.*}}) : (f64) -> f64
+ // CHECK: llvm.func @__ocml_sin_f32(f32) -> f32
+ // CHECK: llvm.func @__ocml_sin_f64(f64) -> f64
+ // CHECK-LABEL: func @gpu_sin
+ func.func @gpu_sin(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %sin_f32 = math.sin %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
+ %result32 = math.sin %sin_f32 : f32
+ // CHECK: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.sin %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_sin_f64(%{{.*}}) : (f64) -> f64
func.return %result32, %result64 : f32, f64
}
"test.finish" () : () -> ()
@@ -279,21 +246,6 @@ gpu.module @test_module {
// -----
-gpu.module @test_module {
- // CHECK: llvm.func @__ocml_log_f32(f32) -> f32
- // CHECK: llvm.func @__ocml_log_f64(f64) -> f64
- // CHECK-LABEL: func @gpu_log
- func.func @gpu_log(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
- %result32 = math.log %arg_f32 : f32
- // CHECK: llvm.call @__ocml_log_f32(%{{.*}}) : (f32) -> f32
- %result64 = math.log %arg_f64 : f64
- // CHECK: llvm.call @__ocml_log_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
- }
-}
-
-// -----
-
gpu.module @test_module {
// CHECK: llvm.func @__ocml_log1p_f32(f32) -> f32
// CHECK: llvm.func @__ocml_log1p_f64(f64) -> f64
@@ -359,26 +311,6 @@ gpu.module @test_module {
// -----
-gpu.module @test_module {
- // CHECK: llvm.func @__ocml_sqrt_f32(f32) -> f32
- // CHECK: llvm.func @__ocml_sqrt_f64(f64) -> f64
- // CHECK-LABEL: func @gpu_sqrt
- func.func @gpu_sqrt(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64)
- -> (f16, f32, f64) {
- %result16 = math.sqrt %arg_f16 : f16
- // CHECK: llvm.fpext %{{.*}} : f16 to f32
- // CHECK-NEXT: llvm.call @__ocml_sqrt_f32(%{{.*}}) : (f32) -> f32
- // CHECK-NEXT: llvm.fptrunc %{{.*}} : f32 to f16
- %result32 = math.sqrt %arg_f32 : f32
- // CHECK: llvm.call @__ocml_sqrt_f32(%{{.*}}) : (f32) -> f32
- %result64 = math.sqrt %arg_f64 : f64
- // CHECK: llvm.call @__ocml_sqrt_f64(%{{.*}}) : (f64) -> f64
- func.return %result16, %result32, %result64 : f16, f32, f64
- }
-}
-
-// -----
-
gpu.module @test_module {
// CHECK: llvm.func @__ocml_tan_f32(f32) -> f32
// CHECK: llvm.func @__ocml_tan_f64(f64) -> f64
@@ -472,15 +404,15 @@ gpu.module @test_module {
gpu.module @test_module {
// CHECK-LABEL: func @gpu_unroll
func.func @gpu_unroll(%arg0 : vector<4xf32>) -> vector<4xf32> {
- %result = math.exp %arg0 : vector<4xf32>
+ %result = math.sin %arg0 : vector<4xf32>
// CHECK: %[[V0:.+]] = llvm.mlir.undef : vector<4xf32>
- // CHECK: %[[CL:.+]] = llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
+ // CHECK: %[[CL:.+]] = llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
// CHECK: %[[V1:.+]] = llvm.insertelement %[[CL]], %[[V0]]
- // CHECK: %[[CL:.+]] = llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
+ // CHECK: %[[CL:.+]] = llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
// CHECK: %[[V2:.+]] = llvm.insertelement %[[CL]], %[[V1]]
- // CHECK: %[[CL:.+]] = llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
+ // CHECK: %[[CL:.+]] = llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
// CHECK: %[[V3:.+]] = llvm.insertelement %[[CL]], %[[V2]]
- // CHECK: %[[CL:.+]] = llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
+ // CHECK: %[[CL:.+]] = llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32
// CHECK: %[[V4:.+]] = llvm.insertelement %[[CL]], %[[V3]]
// CHECK: return %[[V4]]
func.return %result : vector<4xf32>
@@ -526,9 +458,9 @@ gpu.module @test_module {
gpu.module @module {
// CHECK-LABEL: @spirv_exp
-// CHECK: llvm.call @__ocml_exp_f32
+// CHECK: llvm.call @__ocml_sin_f32
spirv.func @spirv_exp(%arg0: vector<4xf32>) -> vector<4xf32> "None" {
- %0 = math.exp %arg0 : vector<4xf32>
+ %0 = math.sin %arg0 : vector<4xf32>
spirv.ReturnValue %0 : vector<4xf32>
}
}
diff --git a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
index a406ec45a7f109..9a05a94f9f1ac7 100644
--- a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
+++ b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
@@ -15,21 +15,6 @@ module @test_module {
// -----
-module @test_module {
- // CHECK: llvm.func @__ocml_fabs_f32(f32) -> f32
- // CHECK: llvm.func @__ocml_fabs_f64(f64) -> f64
- // CHECK-LABEL: func @math_absf
- func.func @math_absf(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
- %result32 = math.absf %arg_f32 : f32
- // CHECK: llvm.call @__ocml_fabs_f32(%{{.*}}) : (f32) -> f32
- %result64 = math.absf %arg_f64 : f64
- // CHECK: llvm.call @__ocml_fabs_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
- }
-}
-
-// -----
-
module @test_module {
// CHECK: llvm.func @__ocml_acos_f32(f32) -> f32
// CHECK: llvm.func @__ocml_acos_f64(f64) -> f64
@@ -210,21 +195,6 @@ module @test_module {
// -----
-module @test_module {
- // CHECK: llvm.func @__ocml_exp_f32(f32) -> f32
- // CHECK: llvm.func @__ocml_exp_f64(f64) -> f64
- // CHECK-LABEL: func @math_exp
- func.func @math_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
- %result32 = math.exp %arg_f32 : f32
- // CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32
- %result64 = math.exp %arg_f64 : f64
- // CHECK: llvm.call @__ocml_exp_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
- }
-}
-
-// -----
-
module @test_module {
// CHECK: llvm.func @__ocml_exp2_f32(f32) -> f32
// CHECK: llvm.func @__ocml_exp2_f64(f64) -> f64
@@ -270,21 +240,6 @@ module @test_module {
// -----
-module @test_module {
- // CHECK: llvm.func @__ocml_log_f32(f32) -> f32
- // CHECK: llvm.func @__ocml_log_f64(f64) -> f64
- // CHECK-LABEL: func @math_log
- func.func @math_log(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
- %result32 = math.log %arg_f32 : f32
- // CHECK: llvm.call @__ocml_log_f32(%{{.*}}) : (f32) -> f32
- %result64 = math.log %arg_f64 : f64
- // CHECK: llvm.call @__ocml_log_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
- }
-}
-
-// -----
-
module @test_module {
// CHECK: llvm.func @__ocml_log10_f32(f32) -> f32
// CHECK: llvm.func @__ocml_log10_f64(f64) -> f64
@@ -360,21 +315,6 @@ module @test_module {
// -----
-module @test_module {
- // CHECK: llvm.func @__ocml_sqrt_f32(f32) -> f32
- // CHECK: llvm.func @__ocml_sqrt_f64(f64) -> f64
- // CHECK-LABEL: func @math_sqrt
- func.func @math_sqrt(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
- %result32 = math.sqrt %arg_f32 : f32
- // CHECK: llvm.call @__ocml_sqrt_f32(%{{.*}}) : (f32) -> f32
- %result64 = math.sqrt %arg_f64 : f64
- // CHECK: llvm.call @__ocml_sqrt_f64(%{{.*}}) : (f64) -> f64
- func.return %result32, %result64 : f32, f64
- }
-}
-
-// -----
-
module @test_module {
// CHECK: llvm.func @__ocml_tanh_f32(f32) -> f32
// CHECK: llvm.func @__ocml_tanh_f64(f64) -> f64
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is fine, but I want to double-check that nothing'll go wrong with double-precision exp
and log
populateOpPatterns<math::Exp2Op>(converter, patterns, "__ocml_exp2_f32",
                                 "__ocml_exp2_f64");
populateOpPatterns<math::ExpM1Op>(converter, patterns, "__ocml_expm1_f32",
                                  "__ocml_expm1_f64");
populateOpPatterns<math::FloorOp>(converter, patterns, "__ocml_floor_f32",
                                  "__ocml_floor_f64");
populateOpPatterns<math::LogOp>(converter, patterns, "__ocml_log_f32",
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same note re f64 log
@@ -84,16 +86,12 @@ void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
                                  "__ocml_cosh_f64");
populateOpPatterns<math::SinhOp>(converter, patterns, "__ocml_sinh_f32",
                                 "__ocml_sinh_f64");
populateOpPatterns<math::ExpOp>(converter, patterns, "__ocml_exp_f32",
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Having disassembled OCML, the double-precision exp isn't actually a direct wrapper around the relevant intrinsic, but I figure that's probably fine
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Correct. We only directly handle the f32 (and I think f16) versions. The f64 versions of the hard operations do not work. We do directly handle llvm.sqrt.f64 as an exception.
Also, none of the trig functions are directly handled (correctly). We do codegen the f32 versions but probably shouldn't
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So what do we do in the case of f32 being handled but f64 not, should we still just call ocml for both or modify the lowering to handle just one of them?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Modify the lowering to handle just one. The operation + type should be treated like different operations, so emit the working f32 intrinsics and the calls for the nonworking f64
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, I put the lowering for f64 back for log and exp and added the tests back for them as well.
…AMDGPU backend This patch removes patterns for a few operations, which allows the mathToLLVM conversion to convert the operations into LLVM intrinsics instead, since they are supported directly by the AMDGPU backend.
8fe4b0e
to
58f0fc6
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Approved
(If it's possible, could you get the half-precision functions hooked up in a later PR?)
Sure, that should be pretty easy to do.
LLVM::FAbsOp and LLVM::SqrtOp are legal after #102971
LLVM::FAbsOp and LLVM::SqrtOp are legal after llvm#102971
LLVM::FAbsOp and LLVM::SqrtOp are legal after llvm#102971
This patch now breaks lowering for bf16 ExpOp or LogOp. For bf16, we upcast to f32, and now we have no lowering for f32 anymore.
@krzysz00 XLA is triggering this bug that bf16 ExpOp and LogOp cannot be lowered anymore: openxla/xla#19700
@akuegel You'll want to run Alternatively, I'll take a follow-up patch to lower bf16
That is to say, there is no bf16
Yes, and that is also done in this pass, but due to this PR it says there is no lowering pattern for f32. Reverting this patch would fix it.
Let me expand a bit. OpToFuncCallLowering (which is the shared logic with the NVVM conversion) upcasts bf16 to f32 before checking whether there is a device function available. Previously, this worked fine: when upcasting ExpOp (or LogOp) from bf16 to f32, there was a corresponding f32 function specified. Now, this doesn't work anymore, as there is no corresponding f32 function specified. Users of this pass now need to upcast bf16 ops earlier. Is this really what we want? Let me repeat my earlier question: does this patch actually make a difference in the final lowering of math::ExpOp for f32? If there are intrinsics, shouldn't the other pattern also eventually lower to them? And if not, wouldn't the right fix be to change the OCML functions to make use of the intrinsics? @krzysz00 @jsjodin @arsenm maybe one of you can answer this?
This patch does change the lowering to use the compiler intrinsics. The OCML functions in question were, for f32, redundant wrappers for the intrinsics that the compiler team wants to remove. For bf16 exp and log, I'd almost think that an LLVM patch is in order.
For XLA, there is now openxla/xla#19913 which should fix the issue by upcasting log and exp ops with bf16 type early.
@arsenm and other folks on the compiler side - would there be any reason not to expand exp and log on bfloats during SelectionDAG/GISel? |
No. In principle all the math intrinsics should be legalized for all types |
Ok, so maybe I'm misreading the above comments and we need an MLIR-side patch to make sure that the bf16 |
This patch removes patterns for a few operations which allows mathToLLVM conversion to convert the operations into LLVM intrinsics instead since they are supported directly by the AMDGPU backend.