diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp index 7de6971ba2ee72..b3b4d81e7ffa5b 100644 --- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp +++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp @@ -48,18 +48,20 @@ static void populateOpPatterns(LLVMTypeConverter &converter, void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns) { // Handled by mathToLLVM: math::AbsIOp + // Handled by mathToLLVM: math::AbsFIOp // Handled by mathToLLVM: math::CopySignOp // Handled by mathToLLVM: math::CountLeadingZerosOp // Handled by mathToLLVM: math::CountTrailingZerosOp // Handled by mathToLLVM: math::CgPopOp + // Handled by mathToLLVM: math::ExpOp (32-bit only) // Handled by mathToLLVM: math::FmaOp + // Handled by mathToLLVM: math::LogOp (32-bit only) // FIXME: math::IPowIOp // FIXME: math::FPowIOp // Handled by mathToLLVM: math::RoundEvenOp // Handled by mathToLLVM: math::RoundOp + // Handled by mathToLLVM: math::SqrtOp // Handled by mathToLLVM: math::TruncOp - populateOpPatterns(converter, patterns, "__ocml_fabs_f32", - "__ocml_fabs_f64"); populateOpPatterns(converter, patterns, "__ocml_acos_f32", "__ocml_acos_f64"); populateOpPatterns(converter, patterns, "__ocml_acosh_f32", @@ -84,16 +86,14 @@ void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter, "__ocml_cosh_f64"); populateOpPatterns(converter, patterns, "__ocml_sinh_f32", "__ocml_sinh_f64"); - populateOpPatterns(converter, patterns, "__ocml_exp_f32", - "__ocml_exp_f64"); + populateOpPatterns(converter, patterns, "", "__ocml_exp_f64"); populateOpPatterns(converter, patterns, "__ocml_exp2_f32", "__ocml_exp2_f64"); populateOpPatterns(converter, patterns, "__ocml_expm1_f32", "__ocml_expm1_f64"); populateOpPatterns(converter, patterns, "__ocml_floor_f32", "__ocml_floor_f64"); - populateOpPatterns(converter, patterns, "__ocml_log_f32", - "__ocml_log_f64"); + populateOpPatterns(converter, patterns, "", "__ocml_log_f64"); populateOpPatterns(converter, patterns, "__ocml_log10_f32", "__ocml_log10_f64"); populateOpPatterns(converter, patterns, "__ocml_log1p_f32", @@ -106,8 +106,6 @@ void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter, "__ocml_rsqrt_f64"); populateOpPatterns(converter, patterns, "__ocml_sin_f32", "__ocml_sin_f64"); - populateOpPatterns(converter, patterns, "__ocml_sqrt_f32", - "__ocml_sqrt_f64"); populateOpPatterns(converter, patterns, "__ocml_tanh_f32", "__ocml_tanh_f64"); populateOpPatterns(converter, patterns, "__ocml_tan_f32", diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir index bf49a42a115775..b6fb08522ae1f3 100644 --- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir +++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir @@ -131,21 +131,6 @@ gpu.module @test_module { // ----- -gpu.module @test_module { - // CHECK: llvm.func @__ocml_fabs_f32(f32) -> f32 - // CHECK: llvm.func @__ocml_fabs_f64(f64) -> f64 - // CHECK-LABEL: func @gpu_fabs - func.func @gpu_fabs(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { - %result32 = math.absf %arg_f32 : f32 - // CHECK: llvm.call @__ocml_fabs_f32(%{{.*}}) : (f32) -> f32 - %result64 = math.absf %arg_f64 : f64 - // CHECK: llvm.call @__ocml_fabs_f64(%{{.*}}) : (f64) -> f64 - func.return %result32, %result64 : f32, f64 - } -} - -// ----- - gpu.module @test_module { // CHECK: llvm.func @__ocml_cbrt_f32(f32) -> f32 // CHECK: llvm.func @__ocml_cbrt_f64(f64) -> f64 @@ -207,17 +192,12 @@ gpu.module @test_module { // ----- gpu.module @test_module { - // CHECK: llvm.func @__ocml_exp_f32(f32) -> f32 // CHECK: llvm.func @__ocml_exp_f64(f64) -> f64 // CHECK-LABEL: func @gpu_exp - func.func @gpu_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { - %exp_f32 = math.exp %arg_f32 : f32 - // CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32 - %result32 = math.exp %exp_f32 : f32 - // CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32 + func.func @gpu_exp(%arg_f64 : f64) -> (f64) { %result64 = math.exp %arg_f64 : f64 // CHECK: llvm.call @__ocml_exp_f64(%{{.*}}) : (f64) -> f64 - func.return %result32, %result64 : f32, f64 + func.return %result64 : f64 } } @@ -239,21 +219,20 @@ gpu.module @test_module { } // ----- - // Test that we handled properly operation with SymbolTable other than module op gpu.module @test_module { "test.symbol_scope"() ({ // CHECK: test.symbol_scope - // CHECK: llvm.func @__ocml_exp_f32(f32) -> f32 - // CHECK: llvm.func @__ocml_exp_f64(f64) -> f64 - // CHECK-LABEL: func @gpu_exp - func.func @gpu_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { - %exp_f32 = math.exp %arg_f32 : f32 - // CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32 - %result32 = math.exp %exp_f32 : f32 - // CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32 - %result64 = math.exp %arg_f64 : f64 - // CHECK: llvm.call @__ocml_exp_f64(%{{.*}}) : (f64) -> f64 + // CHECK: llvm.func @__ocml_sin_f32(f32) -> f32 + // CHECK: llvm.func @__ocml_sin_f64(f64) -> f64 + // CHECK-LABEL: func @gpu_sin + func.func @gpu_sin(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + %sin_f32 = math.sin %arg_f32 : f32 + // CHECK: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32 + %result32 = math.sin %sin_f32 : f32 + // CHECK: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32 + %result64 = math.sin %arg_f64 : f64 + // CHECK: llvm.call @__ocml_sin_f64(%{{.*}}) : (f64) -> f64 func.return %result32, %result64 : f32, f64 } "test.finish" () : () -> () @@ -280,15 +259,12 @@ gpu.module @test_module { // ----- gpu.module @test_module { - // CHECK: llvm.func @__ocml_log_f32(f32) -> f32 // CHECK: llvm.func @__ocml_log_f64(f64) -> f64 // CHECK-LABEL: func @gpu_log - func.func @gpu_log(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { - %result32 = math.log %arg_f32 : f32 - // CHECK: llvm.call @__ocml_log_f32(%{{.*}}) : (f32) -> f32 + func.func @gpu_log(%arg_f64 : f64) -> (f64) { %result64 = math.log %arg_f64 : f64 // CHECK: llvm.call @__ocml_log_f64(%{{.*}}) : (f64) -> f64 - func.return %result32, %result64 : f32, f64 + func.return %result64 : f64 } } @@ -359,26 +335,6 @@ gpu.module @test_module { // ----- -gpu.module @test_module { - // CHECK: llvm.func @__ocml_sqrt_f32(f32) -> f32 - // CHECK: llvm.func @__ocml_sqrt_f64(f64) -> f64 - // CHECK-LABEL: func @gpu_sqrt - func.func @gpu_sqrt(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) - -> (f16, f32, f64) { - %result16 = math.sqrt %arg_f16 : f16 - // CHECK: llvm.fpext %{{.*}} : f16 to f32 - // CHECK-NEXT: llvm.call @__ocml_sqrt_f32(%{{.*}}) : (f32) -> f32 - // CHECK-NEXT: llvm.fptrunc %{{.*}} : f32 to f16 - %result32 = math.sqrt %arg_f32 : f32 - // CHECK: llvm.call @__ocml_sqrt_f32(%{{.*}}) : (f32) -> f32 - %result64 = math.sqrt %arg_f64 : f64 - // CHECK: llvm.call @__ocml_sqrt_f64(%{{.*}}) : (f64) -> f64 - func.return %result16, %result32, %result64 : f16, f32, f64 - } -} - -// ----- - gpu.module @test_module { // CHECK: llvm.func @__ocml_tan_f32(f32) -> f32 // CHECK: llvm.func @__ocml_tan_f64(f64) -> f64 @@ -472,15 +428,15 @@ gpu.module @test_module { gpu.module @test_module { // CHECK-LABEL: func @gpu_unroll func.func @gpu_unroll(%arg0 : vector<4xf32>) -> vector<4xf32> { - %result = math.exp %arg0 : vector<4xf32> + %result = math.sin %arg0 : vector<4xf32> // CHECK: %[[V0:.+]] = llvm.mlir.undef : vector<4xf32> - // CHECK: %[[CL:.+]] = llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32 + // CHECK: %[[CL:.+]] = llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32 // CHECK: %[[V1:.+]] = llvm.insertelement %[[CL]], %[[V0]] - // CHECK: %[[CL:.+]] = llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32 + // CHECK: %[[CL:.+]] = llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32 // CHECK: %[[V2:.+]] = llvm.insertelement %[[CL]], %[[V1]] - // CHECK: %[[CL:.+]] = llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32 + // CHECK: %[[CL:.+]] = llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32 // CHECK: %[[V3:.+]] = llvm.insertelement %[[CL]], %[[V2]] - // CHECK: %[[CL:.+]] = llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32 + // CHECK: %[[CL:.+]] = llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32 // CHECK: %[[V4:.+]] = llvm.insertelement %[[CL]], %[[V3]] // CHECK: return %[[V4]] func.return %result : vector<4xf32> @@ -526,9 +482,9 @@ gpu.module @test_module { gpu.module @module { // CHECK-LABEL: @spirv_exp -// CHECK: llvm.call @__ocml_exp_f32 +// CHECK: llvm.call @__ocml_sin_f32 spirv.func @spirv_exp(%arg0: vector<4xf32>) -> vector<4xf32> "None" { - %0 = math.exp %arg0 : vector<4xf32> + %0 = math.sin %arg0 : vector<4xf32> spirv.ReturnValue %0 : vector<4xf32> } } diff --git a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir index a406ec45a7f109..19d89e03a7f483 100644 --- a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir +++ b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir @@ -15,21 +15,6 @@ module @test_module { // ----- -module @test_module { - // CHECK: llvm.func @__ocml_fabs_f32(f32) -> f32 - // CHECK: llvm.func @__ocml_fabs_f64(f64) -> f64 - // CHECK-LABEL: func @math_absf - func.func @math_absf(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { - %result32 = math.absf %arg_f32 : f32 - // CHECK: llvm.call @__ocml_fabs_f32(%{{.*}}) : (f32) -> f32 - %result64 = math.absf %arg_f64 : f64 - // CHECK: llvm.call @__ocml_fabs_f64(%{{.*}}) : (f64) -> f64 - func.return %result32, %result64 : f32, f64 - } -} - -// ----- - module @test_module { // CHECK: llvm.func @__ocml_acos_f32(f32) -> f32 // CHECK: llvm.func @__ocml_acos_f64(f64) -> f64 @@ -211,15 +196,12 @@ module @test_module { // ----- module @test_module { - // CHECK: llvm.func @__ocml_exp_f32(f32) -> f32 // CHECK: llvm.func @__ocml_exp_f64(f64) -> f64 // CHECK-LABEL: func @math_exp - func.func @math_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { - %result32 = math.exp %arg_f32 : f32 - // CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32 + func.func @math_exp(%arg_f64 : f64) -> (f64) { %result64 = math.exp %arg_f64 : f64 // CHECK: llvm.call @__ocml_exp_f64(%{{.*}}) : (f64) -> f64 - func.return %result32, %result64 : f32, f64 + func.return %result64 : f64 } } @@ -271,15 +253,12 @@ module @test_module { // ----- module @test_module { - // CHECK: llvm.func @__ocml_log_f32(f32) -> f32 // CHECK: llvm.func @__ocml_log_f64(f64) -> f64 // CHECK-LABEL: func @math_log - func.func @math_log(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { - %result32 = math.log %arg_f32 : f32 - // CHECK: llvm.call @__ocml_log_f32(%{{.*}}) : (f32) -> f32 + func.func @math_log(%arg_f64 : f64) -> (f64) { %result64 = math.log %arg_f64 : f64 // CHECK: llvm.call @__ocml_log_f64(%{{.*}}) : (f64) -> f64 - func.return %result32, %result64 : f32, f64 + func.return %result64 : f64 } } @@ -360,21 +339,6 @@ module @test_module { // ----- -module @test_module { - // CHECK: llvm.func @__ocml_sqrt_f32(f32) -> f32 - // CHECK: llvm.func @__ocml_sqrt_f64(f64) -> f64 - // CHECK-LABEL: func @math_sqrt - func.func @math_sqrt(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { - %result32 = math.sqrt %arg_f32 : f32 - // CHECK: llvm.call @__ocml_sqrt_f32(%{{.*}}) : (f32) -> f32 - %result64 = math.sqrt %arg_f64 : f64 - // CHECK: llvm.call @__ocml_sqrt_f64(%{{.*}}) : (f64) -> f64 - func.return %result32, %result64 : f32, f64 - } -} - -// ----- - module @test_module { // CHECK: llvm.func @__ocml_tanh_f32(f32) -> f32 // CHECK: llvm.func @__ocml_tanh_f64(f64) -> f64