Skip to content

Commit

Permalink
[Backport to 14] Add fast math flag translation for OpenCL std lib (#…
Browse files Browse the repository at this point in the history
…2845)

Such possibility was added in SPIR-V 1.6.
Note, this backport doesn't add handling of nofpclass attribute, as it doesn't present in LLVM yet.

Signed-off-by: Sidorov, Dmitry <[email protected]>
  • Loading branch information
MrSidims authored Nov 8, 2024
1 parent 3d182dc commit c1686e0
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 4 deletions.
7 changes: 5 additions & 2 deletions lib/SPIRV/SPIRVReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2460,8 +2460,11 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
case OpExtInst: {
auto *ExtInst = static_cast<SPIRVExtInst *>(BV);
switch (ExtInst->getExtSetKind()) {
case SPIRVEIS_OpenCL:
return mapValue(BV, transOCLBuiltinFromExtInst(ExtInst, BB));
case SPIRVEIS_OpenCL: {
auto *V = mapValue(BV, transOCLBuiltinFromExtInst(ExtInst, BB));
applyFPFastMathModeDecorations(BV, static_cast<Instruction *>(V));
return V;
}
case SPIRVEIS_Debug:
case SPIRVEIS_OpenCL_DebugInfo_100:
case SPIRVEIS_NonSemantic_Shader_DebugInfo_100:
Expand Down
3 changes: 3 additions & 0 deletions lib/SPIRV/SPIRVUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -787,6 +787,9 @@ CallInst *mutateCallInst(
NewCI->copyMetadata(*CI);
NewCI->setAttributes(CI->getAttributes());
NewCI->setTailCall(CI->isTailCall());
if (isa<FPMathOperator>(CI))
NewCI->setFastMathFlags(CI->getFastMathFlags());

if (CI->hasFnAttr("fpbuiltin-max-error")) {
auto Attr = CI->getFnAttr("fpbuiltin-max-error");
NewCI->addFnAttr(Attr);
Expand Down
10 changes: 8 additions & 2 deletions lib/SPIRV/SPIRVWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2715,7 +2715,8 @@ bool LLVMToSPIRVBase::transDecoration(Value *V, SPIRVValue *BV) {
if (Opcode == Instruction::FAdd || Opcode == Instruction::FSub ||
Opcode == Instruction::FMul || Opcode == Instruction::FDiv ||
Opcode == Instruction::FRem ||
((Opcode == Instruction::FNeg || Opcode == Instruction::FCmp) &&
((Opcode == Instruction::FNeg || Opcode == Instruction::FCmp ||
BV->isExtInst()) &&
BM->isAllowedToUseVersion(VersionNumber::SPIRV_1_6))) {
FastMathFlags FMF = BVF->getFastMathFlags();
SPIRVWord M{0};
Expand All @@ -2742,8 +2743,13 @@ bool LLVMToSPIRVBase::transDecoration(Value *V, SPIRVValue *BV) {
}
}
}
if (M != 0)
if (M != 0) {
BV->setFPFastMathMode(M);
if (Opcode == Instruction::FNeg || Opcode == Instruction::FCmp ||
BV->isExtInst())
BM->setMinSPIRVVersion(
static_cast<SPIRVWord>(VersionNumber::SPIRV_1_6));
}
}
}
if (Instruction *Inst = dyn_cast<Instruction>(V)) {
Expand Down
41 changes: 41 additions & 0 deletions test/transcoding/fast-math-opencl-builtins.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
; RUN: llvm-as %s -o %t.bc
; RUN: llvm-spirv -spirv-text %t.bc -o - | FileCheck %s --check-prefix=CHECK-SPIRV
; RUN: llvm-spirv %t.bc -o %t.spv
; RUN: spirv-val %t.spv
; RUN: llvm-spirv -r %t.spv -o - | llvm-dis -o - | FileCheck %s --check-prefix=CHECK-LLVM-OCL
; RUN: llvm-spirv -r --spirv-target-env=SPV-IR %t.spv -o - | llvm-dis -o - | FileCheck %s --check-prefix=CHECK-LLVM-SPV

; RUN: llvm-spirv -spirv-text --spirv-max-version=1.5 %t.bc -o - | FileCheck %s --check-prefix=CHECK-SPIRV-NEG

; CHECK-SPIRV: Decorate [[#FPDec4:]] FPFastMathMode 16
; CHECK-SPIRV: ExtInst [[#]] [[#FPDec4]] [[#]] fmax [[#]] [[#]]

; CHECK-SPIRV-NEG-NOT: Decorate [[#]] FPFastMathMode [[#]]

; CHECK-LLVM-OCL: call fast spir_func float @_Z4fmaxff(float %[[#]], float %[[#]])

; CHECK-LLVM-SPV: call fast spir_func float @_Z16__spirv_ocl_fmaxff(float %[[#]], float %[[#]])

; ModuleID = 'test.bc'
source_filename = "test.cpp"
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
target triple = "spir64-unknown-unknown"

@__spirv_BuiltInGlobalInvocationId = external dso_local local_unnamed_addr addrspace(1) constant <3 x i64>, align 32

declare dso_local spir_func noundef float @_Z16__spirv_ocl_fmaxff(float noundef, float noundef) local_unnamed_addr

define weak_odr dso_local spir_kernel void @nofpclass_fast(float addrspace(1)* noundef align 4 %_arg_data, float addrspace(1)* noundef align 4 %_arg_dat1, float addrspace(1)* noundef align 4 %_arg_dat2) local_unnamed_addr {
entry:
%0 = load <3 x i64>, <3 x i64> addrspace(1)* @__spirv_BuiltInGlobalInvocationId, align 32
%elem = extractelement <3 x i64> %0, i32 2
%arrayidx.i = getelementptr inbounds float, float addrspace(1)* %_arg_data, i64 %elem
%arrayidx3.i = getelementptr inbounds float, float addrspace(1)* %_arg_dat1, i64 %elem
%cmp.i = icmp ult i64 %elem, 2147483648
%arrayidx5.i = getelementptr inbounds float, float addrspace(1)* %_arg_dat2, i64 %elem
%1 = load float, float addrspace(1)* %arrayidx3.i, align 4
%2 = load float, float addrspace(1)* %arrayidx5.i, align 4
%call.i.i = tail call fast spir_func noundef float @_Z16__spirv_ocl_fmaxff(float noundef %1, float noundef %2)
store float %call.i.i, float addrspace(1)* %arrayidx.i, align 4
ret void
}

0 comments on commit c1686e0

Please sign in to comment.