[BACKEND] Switch back to use llvm.load for shared memory load (#4776)
When we don't have predicates, we can use llvm.load. Using inline asm for
i8 types can cause inefficient code generation in LLVM due to the
interaction with the DAG legalizer.
ThomasRaoux authored Sep 21, 2024
1 parent 3a647f0 commit 576426b
Showing 2 changed files with 46 additions and 59 deletions.
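
To illustrate the change described in the commit message, here is a minimal, standalone C++ sketch of the selection logic. It is not the actual MLIR-based implementation (that lives in the TargetInfo.cpp hunk below); `PredValue`, `lowerSharedLoad`, and the returned strings are hypothetical stand-ins that only model the choice between a plain `llvm.load` and the predicated `ld.shared` inline asm.

```cpp
#include <iostream>
#include <optional>
#include <string>

// Hypothetical stand-in for an MLIR i1 predicate: either a known
// compile-time constant or an opaque runtime value.
struct PredValue {
  std::optional<bool> constantValue; // set when the predicate is a constant
};

// Mirrors the new isConstantTruePred() helper: only a predicate that is
// statically known to be true lets us drop the predicated inline asm.
static bool isConstantTruePred(const PredValue &pred) {
  return pred.constantValue.has_value() && *pred.constantValue;
}

// Toy lowering decision: a plain load is visible to LLVM's optimizer,
// whereas predicated PTX inline asm is opaque to the DAG legalizer.
static std::string lowerSharedLoad(const PredValue &pred, unsigned vec,
                                   unsigned elemBitwidth) {
  if (isConstantTruePred(pred)) {
    // Unconditional case: emit llvm.load, as a vector load when vec > 1.
    return vec == 1 ? "llvm.load -> i" + std::to_string(elemBitwidth)
                    : "llvm.load -> vector<" + std::to_string(vec) + "xi" +
                          std::to_string(elemBitwidth) + ">";
  }
  // Predicated case: keep the "@$p ld.shared ..." inline asm form.
  return "llvm.inline_asm \"@$p ld.shared ...\"";
}

int main() {
  // Constant-true predicate: vectorized llvm.load, matching the new CHECK lines.
  std::cout << lowerSharedLoad({true}, 4, 32) << "\n";
  // Runtime predicate: the inline asm path is unchanged.
  std::cout << lowerSharedLoad({std::nullopt}, 4, 32) << "\n";
}
```

In the actual patch, the predicated branch keeps the PTXBuilder-based inline asm, while the constant-true branch emits an LLVM load and, for vector loads, repacks the result into the struct form expected by unpackLLElements, as shown in the TargetInfo.cpp hunk below.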
62 changes: 13 additions & 49 deletions test/Conversion/tritongpu_to_llvm.mlir
@@ -709,39 +709,9 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
// CHECK-LABEL: convert_layout_blocked_blocked
tt.func @convert_layout_blocked_blocked(%arg0: tensor<16x16xf32, #blocked0>) {
// CHECK: llvm.mlir.addressof @global_smem
// CHECK: llvm.inline_asm
// CHECK: st.shared
// CHECK: llvm.inline_asm
// CHECK: st.shared
// CHECK: llvm.inline_asm
// CHECK: st.shared
// CHECK: llvm.inline_asm
// CHECK: st.shared
// CHECK: llvm.inline_asm
// CHECK: st.shared
// CHECK: llvm.inline_asm
// CHECK: st.shared
// CHECK: llvm.inline_asm
// CHECK: st.shared
// CHECK: llvm.inline_asm
// CHECK: st.shared
// CHECK: nvvm.barrier0
// CHECK: llvm.inline_asm
// CHECK: ld.shared
// CHECK: llvm.inline_asm
// CHECK: ld.shared
// CHECK: llvm.inline_asm
// CHECK: ld.shared
// CHECK: llvm.inline_asm
// CHECK: ld.shared
// CHECK: llvm.inline_asm
// CHECK: ld.shared
// CHECK: llvm.inline_asm
// CHECK: ld.shared
// CHECK: llvm.inline_asm
// CHECK: ld.shared
// CHECK: llvm.inline_asm
// CHECK: ld.shared
// CHECK-COUNT-8: llvm.inline_asm {{.*}} st.shared
// CHECK: nvvm.barrier0
// CHECK-COUNT-8: llvm.load
%0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf32, #blocked0> -> tensor<16x16xf32, #blocked1>
tt.return
}
@@ -761,10 +731,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
// CHECK: llvm.inline_asm
// CHECK: st.shared
// CHECK: nvvm.barrier0
// CHECK: llvm.inline_asm
// CHECK: ld.shared
// CHECK: llvm.inline_asm
// CHECK: ld.shared
// CHECK: llvm.load
// CHECK: llvm.load
%0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf32, #blocked0> -> tensor<16x16xf32, #blocked1>
tt.return
}
@@ -782,18 +750,14 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
// CHECK: llvm.inline_asm
// CHECK: st.shared
// CHECK: nvvm.barrier0
// CHECK: llvm.inline_asm
// CHECK: ld.shared
// CHECK: llvm.inline_asm
// CHECK: ld.shared
// CHECK: llvm.load
// CHECK: llvm.load
// CHECK: nvvm.barrier0
// CHECK: llvm.inline_asm
// CHECK: st.shared
// CHECK: nvvm.barrier0
// CHECK: llvm.inline_asm
// CHECK: ld.shared
// CHECK: llvm.inline_asm
// CHECK: ld.shared
// CHECK: llvm.load
// CHECK: llvm.load
%0 = triton_gpu.convert_layout %arg0 : tensor<16x16xf32, #blocked0> -> tensor<16x16xf32, #blocked1>
tt.return
}
@@ -851,7 +815,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK: llvm.inline_asm
// CHECK-SAME: st.shared
// CHECK: nvvm.barrier0
// CHECK: ld.shared
// CHECK: llvm.load
%0 = triton_gpu.convert_layout %arg0 : tensor<32x16xf32, #mma> -> tensor<32x16xf32, #blocked0>
tt.return
}
@@ -891,7 +855,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
tt.func @convert_layout_mmav3_transpose(%arg0: tensor<128x256xf8E5M2, #mma>) {
// CHECK-COUNT-128: st.shared.b8
// CHECK: nvvm.barrier0
// CHECK-COUNT-8: ld.shared.v4.b32
// CHECK-COUNT-8: llvm.load {{.*}} -> vector<4xi32>
%0 = triton_gpu.convert_layout %arg0 : tensor<128x256xf8E5M2, #mma> -> tensor<128x256xf8E5M2, #blocked>
tt.return
}
@@ -920,7 +884,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: convert_blocked1d_to_slice0
tt.func @convert_blocked1d_to_slice0(%src:tensor<32xi32, #blocked0>) {
// CHECK: inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@${{.*}} ld.shared.v4.b32
// CHECK: llvm.load {{.*}} -> vector<4xi32>
%cvt = triton_gpu.convert_layout %src : tensor<32xi32, #blocked0> -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
tt.return
}
@@ -933,7 +897,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: convert_blocked1d_to_slice1
tt.func @convert_blocked1d_to_slice1(%src:tensor<32xi32, #blocked0>) {
// CHECK-COUNT-8: inline_asm{{.*}}ld.shared.b32
// CHECK-COUNT-8: llvm.load {{.*}} -> i32
%cvt = triton_gpu.convert_layout %src : tensor<32xi32, #blocked0> -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
tt.return
}
43 changes: 33 additions & 10 deletions third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp
@@ -290,6 +290,13 @@ static std::string getConstraintForBitwidth(unsigned bitwidth) {
}
}

static bool isConstantTruePred(Value pred) {
if (auto constOp = pred.getDefiningOp<LLVM::ConstantOp>()) {
return cast<IntegerAttr>(constOp.getValue()).getInt() != 0;
}
return false;
}

void TargetInfo::storeDShared(RewriterBase &rewriter, Location loc, Value ptr,
std::optional<Value> ctaId, Value val,
Value pred) const {
@@ -501,16 +508,32 @@ Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr,
.v(vec, /*predicate=*/vec > 1)
.b(elemBitwidth);

std::string elemConstraint = "=" + getConstraintForBitwidth(elemBitwidth);
auto *outOpr = vec == 1 ? builder.newOperand(elemConstraint)
: builder.newListOperand(vec, elemConstraint);
ld(outOpr, builder.newAddrOperand(ptr, "r")).predicate(pred, "b");

Type resultTy =
vec == 1 ? Type(int_ty(elemBitwidth))
: Type(struct_ty(SmallVector<Type>(vec, int_ty(elemBitwidth))));
Value load = builder.launch(rewriter, loc, resultTy, /*hasSideEffects=*/true);

Value load;
if (isConstantTruePred(pred)) {
Type resultTy = vec == 1 ? Type(int_ty(elemBitwidth))
: Type(vec_ty(int_ty(elemBitwidth), vec));
load = load(resultTy, ptr);
if (vec > 1) {
Type structTy = struct_ty(SmallVector<Type>(vec, int_ty(elemBitwidth)));
Value structValue = undef(structTy);
for (int i = 0; i < vec; i++) {
structValue = insert_val(structTy, structValue,
extract_element(load, i32_val(i)), i);
}
load = structValue;
}
} else {
std::string elemConstraint = "=" + getConstraintForBitwidth(elemBitwidth);
auto *outOpr = vec == 1 ? builder.newOperand(elemConstraint)
: builder.newListOperand(vec, elemConstraint);
ld(outOpr, builder.newAddrOperand(ptr, "r")).predicate(pred, "b");

Type resultTy =
vec == 1
? Type(int_ty(elemBitwidth))
: Type(struct_ty(SmallVector<Type>(vec, int_ty(elemBitwidth))));
load = builder.launch(rewriter, loc, resultTy, /*hasSideEffects=*/true);
}
SmallVector<Value> resultVals = unpackLLElements(loc, load, rewriter);
return packLLVector(loc, resultVals, rewriter);
}