From 576426bccfb9a2c90f2abaa405995738d4a79403 Mon Sep 17 00:00:00 2001
From: Thomas Raoux
Date: Sat, 21 Sep 2024 09:35:58 -0700
Subject: [PATCH] [BACKEND] Switch back to use llvm.load for shared memory
 load (#4776)

When we don't have predicates, we can use llvm.load. Using inline asm
for i8 types can cause inefficient code generation in LLVM due to the
interaction with the DAG legalizer.
---
 test/Conversion/tritongpu_to_llvm.mlir       | 62 ++++---------
 .../lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp | 43 ++++++++++---
 2 files changed, 46 insertions(+), 59 deletions(-)

diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir
index 3ce43e71d2a7..b60a73f80c8c 100644
--- a/test/Conversion/tritongpu_to_llvm.mlir
+++ b/test/Conversion/tritongpu_to_llvm.mlir
@@ -709,39 +709,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
   // CHECK-LABEL: convert_layout_blocked_blocked
   tt.func @convert_layout_blocked_blocked(%arg0: tensor<16x16xf32, #blocked0>) {
     // CHECK: llvm.mlir.addressof @global_smem
-    // CHECK: llvm.inline_asm
-    // CHECK: st.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: st.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: st.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: st.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: st.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: st.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: st.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: st.shared
-    // CHECK: nvvm.barrier0
-    // CHECK: llvm.inline_asm
-    // CHECK: ld.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: ld.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: ld.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: ld.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: ld.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: ld.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: ld.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: ld.shared
+    // CHECK-COUNT-8: llvm.inline_asm {{.*}} st.shared
+    // CHECK: nvvm.barrier0
+    // CHECK-COUNT-8: llvm.load
     %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf32, #blocked0> -> tensor<16x16xf32, #blocked1>
     tt.return
   }
@@ -761,10 +731,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
     // CHECK: llvm.inline_asm
     // CHECK: st.shared
     // CHECK: nvvm.barrier0
-    // CHECK: llvm.inline_asm
-    // CHECK: ld.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: ld.shared
+    // CHECK: llvm.load
+    // CHECK: llvm.load
     %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf32, #blocked0> -> tensor<16x16xf32, #blocked1>
     tt.return
   }
@@ -782,18 +750,14 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
     // CHECK: llvm.inline_asm
     // CHECK: st.shared
    // CHECK: nvvm.barrier0
-    // CHECK: llvm.inline_asm
-    // CHECK: ld.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: ld.shared
+    // CHECK: llvm.load
+    // CHECK: llvm.load
     // CHECK: nvvm.barrier0
     // CHECK: llvm.inline_asm
     // CHECK: st.shared
     // CHECK: nvvm.barrier0
-    // CHECK: llvm.inline_asm
-    // CHECK: ld.shared
-    // CHECK: llvm.inline_asm
-    // CHECK: ld.shared
+    // CHECK: llvm.load
+    // CHECK: llvm.load
     %0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf32, #blocked0> -> tensor<16x16xf32, #blocked1>
     tt.return
   }
@@ -851,7 +815,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
     // CHECK: llvm.inline_asm
     // CHECK-SAME: st.shared
     // CHECK: nvvm.barrier0
-    // CHECK: ld.shared
+    // CHECK: llvm.load
     %0 = triton_gpu.convert_layout %arg0 : tensor<32x16xf32, #mma> -> tensor<32x16xf32, #blocked0>
     tt.return
   }
@@ -891,7 +855,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : tt.func @convert_layout_mmav3_transpose(%arg0: tensor<128x256xf8E5M2, #mma>) { // CHECK-COUNT-128: st.shared.b8 // CHECK: nvvm.barrier0 - // CHECK-COUNT-8: ld.shared.v4.b32 + // CHECK-COUNT-8: llvm.load {{.*}} -> vector<4xi32> %0 = triton_gpu.convert_layout %arg0 : tensor<128x256xf8E5M2, #mma> -> tensor<128x256xf8E5M2, #blocked> tt.return } @@ -920,7 +884,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { // CHECK-LABEL: convert_blocked1d_to_slice0 tt.func @convert_blocked1d_to_slice0(%src:tensor<32xi32, #blocked0>) { - // CHECK: inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@${{.*}} ld.shared.v4.b32 + // CHECK: llvm.load {{.*}} -> vector<4xi32> %cvt = triton_gpu.convert_layout %src : tensor<32xi32, #blocked0> -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> tt.return } @@ -933,7 +897,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} { // CHECK-LABEL: convert_blocked1d_to_slice1 tt.func @convert_blocked1d_to_slice1(%src:tensor<32xi32, #blocked0>) { - // CHECK-COUNT-8: inline_asm{{.*}}ld.shared.b32 + // CHECK-COUNT-8: llvm.load {{.*}} -> i32 %cvt = triton_gpu.convert_layout %src : tensor<32xi32, #blocked0> -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> tt.return } diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp index 30d4639cbf41..5813b9679ef0 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp @@ -290,6 +290,13 @@ static std::string getConstraintForBitwidth(unsigned bitwidth) { } } +static bool isConstantTruePred(Value pred) { + if (auto constOp = pred.getDefiningOp()) { + return cast(constOp.getValue()).getInt() != 0; + } + return false; +} + void TargetInfo::storeDShared(RewriterBase &rewriter, Location loc, Value ptr, std::optional ctaId, Value val, Value pred) const { @@ -501,16 +508,32 @@ Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr, .v(vec, /*predicate=*/vec > 1) .b(elemBitwidth); - std::string elemConstraint = "=" + getConstraintForBitwidth(elemBitwidth); - auto *outOpr = vec == 1 ? builder.newOperand(elemConstraint) - : builder.newListOperand(vec, elemConstraint); - ld(outOpr, builder.newAddrOperand(ptr, "r")).predicate(pred, "b"); - - Type resultTy = - vec == 1 ? Type(int_ty(elemBitwidth)) - : Type(struct_ty(SmallVector(vec, int_ty(elemBitwidth)))); - Value load = builder.launch(rewriter, loc, resultTy, /*hasSideEffects=*/true); - + Value load; + if (isConstantTruePred(pred)) { + Type resultTy = vec == 1 ? Type(int_ty(elemBitwidth)) + : Type(vec_ty(int_ty(elemBitwidth), vec)); + load = load(resultTy, ptr); + if (vec > 1) { + Type structTy = struct_ty(SmallVector(vec, int_ty(elemBitwidth))); + Value structValue = undef(structTy); + for (int i = 0; i < vec; i++) { + structValue = insert_val(structTy, structValue, + extract_element(load, i32_val(i)), i); + } + load = structValue; + } + } else { + std::string elemConstraint = "=" + getConstraintForBitwidth(elemBitwidth); + auto *outOpr = vec == 1 ? 
+                            : builder.newListOperand(vec, elemConstraint);
+    ld(outOpr, builder.newAddrOperand(ptr, "r")).predicate(pred, "b");
+
+    Type resultTy =
+        vec == 1
+            ? Type(int_ty(elemBitwidth))
+            : Type(struct_ty(SmallVector<Type>(vec, int_ty(elemBitwidth))));
+    load = builder.launch(rewriter, loc, resultTy, /*hasSideEffects=*/true);
+  }
   SmallVector<Value> resultVals = unpackLLElements(loc, load, rewriter);
   return packLLVector(loc, resultVals, rewriter);
 }
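
Note (editor, not part of the patch): a minimal hand-written sketch, in the
same MLIR LLVM-dialect syntax the tests above check for, of what loadDShared
now emits on the unpredicated path for vec = 4 and elemBitwidth = 32. SSA
names are illustrative; the real IR comes from the load / extract_element /
insert_val helpers in the new code.

    // Plain vector load from shared memory (addrspace 3), visible to LLVM's
    // optimizer, unlike the old opaque ld.shared inline asm.
    %v  = llvm.load %ptr : !llvm.ptr<3> -> vector<4xi32>
    // Repack into the struct type the inline-asm path returned, so
    // unpackLLElements sees the same shape as before.
    %u  = llvm.mlir.undef : !llvm.struct<(i32, i32, i32, i32)>
    %c0 = llvm.mlir.constant(0 : i32) : i32
    %e0 = llvm.extractelement %v[%c0 : i32] : vector<4xi32>
    %s0 = llvm.insertvalue %e0, %u[0] : !llvm.struct<(i32, i32, i32, i32)>
    // ...lanes 1 through 3 repeat the extractelement/insertvalue pair.

Since the predicate is a compile-time true constant on this path, no @-guard
is needed; loads with a non-constant predicate keep the predicated inline-asm
ld.shared fallback.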