Skip to content

Commit

Permalink
Add fast math flag translation for OpenCL std lib (#2762) (#2840)
Browse files Browse the repository at this point in the history
Such possibility was added in SPIR-V 1.6.
This patch also introduces limited translation of nofpclass LLVM parameter attribute.

Signed-off-by: Sidorov, Dmitry <[email protected]>
  • Loading branch information
MrSidims authored Nov 8, 2024
1 parent 90a9764 commit 1fc553b
Show file tree
Hide file tree
Showing 4 changed files with 153 additions and 4 deletions.
3 changes: 3 additions & 0 deletions lib/SPIRV/SPIRVBuiltinHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,9 @@ Value *BuiltinCallMutator::doConversion() {
NewCall->copyMetadata(*CI);
NewCall->setAttributes(CallAttrs);
NewCall->setTailCall(CI->isTailCall());
if (isa<FPMathOperator>(CI))
NewCall->setFastMathFlags(CI->getFastMathFlags());

if (CI->hasFnAttr("fpbuiltin-max-error")) {
auto Attr = CI->getFnAttr("fpbuiltin-max-error");
NewCall->addFnAttr(Attr);
Expand Down
7 changes: 5 additions & 2 deletions lib/SPIRV/SPIRVReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2478,8 +2478,11 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
case OpExtInst: {
auto *ExtInst = static_cast<SPIRVExtInst *>(BV);
switch (ExtInst->getExtSetKind()) {
case SPIRVEIS_OpenCL:
return mapValue(BV, transOCLBuiltinFromExtInst(ExtInst, BB));
case SPIRVEIS_OpenCL: {
auto *V = mapValue(BV, transOCLBuiltinFromExtInst(ExtInst, BB));
applyFPFastMathModeDecorations(BV, static_cast<Instruction *>(V));
return V;
}
case SPIRVEIS_Debug:
case SPIRVEIS_OpenCL_DebugInfo_100:
case SPIRVEIS_NonSemantic_Shader_DebugInfo_100:
Expand Down
49 changes: 47 additions & 2 deletions lib/SPIRV/SPIRVWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3022,7 +3022,8 @@ bool LLVMToSPIRVBase::transDecoration(Value *V, SPIRVValue *BV) {
if (Opcode == Instruction::FAdd || Opcode == Instruction::FSub ||
Opcode == Instruction::FMul || Opcode == Instruction::FDiv ||
Opcode == Instruction::FRem ||
((Opcode == Instruction::FNeg || Opcode == Instruction::FCmp) &&
((Opcode == Instruction::FNeg || Opcode == Instruction::FCmp ||
BV->isExtInst()) &&
BM->isAllowedToUseVersion(VersionNumber::SPIRV_1_6))) {
FastMathFlags FMF = BVF->getFastMathFlags();
SPIRVWord M{0};
Expand All @@ -3049,8 +3050,52 @@ bool LLVMToSPIRVBase::transDecoration(Value *V, SPIRVValue *BV) {
}
}
}
if (M != 0)
// Handle nofpclass attribute. Nothing to do if fast math flag is already
// set.
if ((BV->isExtInst() &&
static_cast<SPIRVExtInst *>(BV)->getExtSetKind() ==
SPIRVEIS_OpenCL) &&
BM->isAllowedToUseVersion(VersionNumber::SPIRV_1_6) &&
!(M & FPFastMathModeFastMask)) {
auto *F = cast<CallInst>(V)->getCalledFunction();
auto FAttrs = F->getAttributes();
AttributeSet RetAttrs = FAttrs.getRetAttrs();
if (RetAttrs.hasAttribute(Attribute::NoFPClass)) {
FPClassTest RetTest =
RetAttrs.getAttribute(Attribute::NoFPClass).getNoFPClass();
AttributeSet RetAttrs = FAttrs.getRetAttrs();
// Only Nan and Inf tests are representable in SPIR-V now.
bool ToAddNoNan = RetTest & fcNan;
bool ToAddNoInf = RetTest & fcInf;
if (ToAddNoNan || ToAddNoInf) {
const auto *FT = F->getFunctionType();
const size_t NumParams = FT->getNumParams();
for (size_t I = 0; I != NumParams; ++I) {
if (!FT->getParamType(I)->isFloatTy())
continue;
if (!F->hasParamAttribute(I, Attribute::NoFPClass)) {
ToAddNoNan = false;
ToAddNoInf = false;
break;
}
FPClassTest ArgTest =
FAttrs.getParamAttr(I, Attribute::NoFPClass).getNoFPClass();
ToAddNoNan = ToAddNoNan && static_cast<bool>(ArgTest & fcNan);
ToAddNoInf = ToAddNoInf && static_cast<bool>(ArgTest & fcInf);
}
}
if (ToAddNoNan)
M |= FPFastMathModeNotNaNMask;
if (ToAddNoInf)
M |= FPFastMathModeNotInfMask;
}
}
if (M != 0) {
BV->setFPFastMathMode(M);
if (Opcode == Instruction::FNeg || Opcode == Instruction::FCmp ||
BV->isExtInst())
BM->setMinSPIRVVersion(VersionNumber::SPIRV_1_6);
}
}
}
if (Instruction *Inst = dyn_cast<Instruction>(V)) {
Expand Down
98 changes: 98 additions & 0 deletions test/transcoding/fast-math-opencl-builtins.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
; RUN: llvm-as %s -o %t.bc
; RUN: llvm-spirv -spirv-text %t.bc -o - | FileCheck %s --check-prefix=CHECK-SPIRV
; RUN: llvm-spirv %t.bc -o %t.spv
; RUN: spirv-val %t.spv
; RUN: llvm-spirv -r %t.spv -o - | llvm-dis -o - | FileCheck %s --check-prefix=CHECK-LLVM-OCL
; RUN: llvm-spirv -r --spirv-target-env=SPV-IR %t.spv -o - | llvm-dis -o - | FileCheck %s --check-prefix=CHECK-LLVM-SPV

; RUN: llvm-spirv -spirv-text --spirv-max-version=1.5 %t.bc -o - | FileCheck %s --check-prefix=CHECK-SPIRV-NEG

; CHECK-SPIRV: Decorate [[#FPDec1:]] FPFastMathMode 3
; CHECK-SPIRV: Decorate [[#FPDec2:]] FPFastMathMode 2
; CHECK-SPIRV: Decorate [[#FPDec3:]] FPFastMathMode 3
; CHECK-SPIRV: Decorate [[#FPDec4:]] FPFastMathMode 16
; CHECK-SPIRV: ExtInst [[#]] [[#FPDec1]] [[#]] fmax [[#]] [[#]]
; CHECK-SPIRV: ExtInst [[#]] [[#FPDec2]] [[#]] fmin [[#]] [[#]]
; CHECK-SPIRV: ExtInst [[#]] [[#FPDec3]] [[#]] ldexp [[#]] [[#]]
; CHECK-SPIRV: ExtInst [[#]] [[#FPDec4]] [[#]] fmax [[#]] [[#]]

; CHECK-SPIRV-NEG-NOT: Decorate [[#]] FPFastMathMode [[#]]

; CHECK-LLVM-OCL: call nnan ninf spir_func float @_Z4fmaxff(float %[[#]], float %[[#]])
; CHECK-LLVM-OCL: call ninf spir_func float @_Z4fminff(float %[[#]], float %[[#]])
; CHECK-LLVM-OCL: call nnan ninf spir_func float @_Z5ldexpfi(float %[[#]], i32 %[[#]])
; CHECK-LLVM-OCL: call fast spir_func float @_Z4fmaxff(float %[[#]], float %[[#]])

; CHECK-LLVM-SPV: call nnan ninf spir_func float @_Z16__spirv_ocl_fmaxff(float %[[#]], float %[[#]])
; CHECK-LLVM-SPV: call ninf spir_func float @_Z16__spirv_ocl_fminff(float %[[#]], float %[[#]])
; CHECK-LLVM-SPV: call nnan ninf spir_func float @_Z17__spirv_ocl_ldexpfi(float %[[#]], i32 %[[#]])
; CHECK-LLVM-SPV: call fast spir_func float @_Z16__spirv_ocl_fmaxff(float %[[#]], float %[[#]])

; ModuleID = 'test.bc'
source_filename = "test.cpp"
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
target triple = "spir64-unknown-unknown"

@__spirv_BuiltInGlobalInvocationId = external dso_local local_unnamed_addr addrspace(1) constant <3 x i64>, align 32

declare dso_local spir_func noundef nofpclass(nan inf) float @_Z16__spirv_ocl_fmaxff(float noundef nofpclass(nan inf), float noundef nofpclass(nan inf)) local_unnamed_addr

declare dso_local spir_func noundef nofpclass(nan inf) float @_Z16__spirv_ocl_fminff(float noundef nofpclass(inf), float noundef nofpclass(nan inf)) local_unnamed_addr

declare dso_local spir_func noundef nofpclass(nan inf) float @_Z17__spirv_ocl_ldexpfi(float noundef nofpclass(nan inf), i32 noundef)

define weak_odr dso_local spir_kernel void @nofpclass_all(ptr addrspace(1) noundef align 4 %_arg_data, ptr addrspace(1) noundef align 4 %_arg_dat1, ptr addrspace(1) noundef align 4 %_arg_dat2) local_unnamed_addr {
entry:
%0 = load i64, ptr addrspace(1) @__spirv_BuiltInGlobalInvocationId, align 32
%arrayidx.i = getelementptr inbounds float, ptr addrspace(1) %_arg_data, i64 %0
%arrayidx3.i = getelementptr inbounds float, ptr addrspace(1) %_arg_dat1, i64 %0
%cmp.i = icmp ult i64 %0, 2147483648
%arrayidx5.i = getelementptr inbounds float, ptr addrspace(1) %_arg_dat2, i64 %0
%1 = load float, ptr addrspace(1) %arrayidx3.i, align 4
%2 = load float, ptr addrspace(1) %arrayidx5.i, align 4
%call.i.i = tail call spir_func noundef nofpclass(nan inf) float @_Z16__spirv_ocl_fmaxff(float noundef nofpclass(nan inf) %1, float noundef nofpclass(nan inf) %2)
store float %call.i.i, ptr addrspace(1) %arrayidx.i, align 4
ret void
}

define weak_odr dso_local spir_kernel void @nofpclass_part(ptr addrspace(1) noundef align 4 %_arg_data, ptr addrspace(1) noundef align 4 %_arg_dat1, ptr addrspace(1) noundef align 4 %_arg_dat2) local_unnamed_addr {
entry:
%0 = load i64, ptr addrspace(1) @__spirv_BuiltInGlobalInvocationId, align 32
%arrayidx.i = getelementptr inbounds float, ptr addrspace(1) %_arg_data, i64 %0
%arrayidx3.i = getelementptr inbounds float, ptr addrspace(1) %_arg_dat1, i64 %0
%cmp.i = icmp ult i64 %0, 2147483648
%arrayidx5.i = getelementptr inbounds float, ptr addrspace(1) %_arg_dat2, i64 %0
%1 = load float, ptr addrspace(1) %arrayidx3.i, align 4
%2 = load float, ptr addrspace(1) %arrayidx5.i, align 4
%call.i.i = tail call spir_func noundef nofpclass(nan inf) float @_Z16__spirv_ocl_fminff(float noundef nofpclass(inf) %1, float noundef nofpclass(nan inf) %2)
store float %call.i.i, ptr addrspace(1) %arrayidx.i, align 4
ret void
}

define weak_odr dso_local spir_kernel void @nofpclass_int(ptr addrspace(1) noundef align 4 %_arg_data, ptr addrspace(1) noundef align 4 %_arg_dat1, ptr addrspace(1) noundef align 4 %_arg_dat2) local_unnamed_addr {
entry:
%0 = load i64, ptr addrspace(1) @__spirv_BuiltInGlobalInvocationId, align 32
%arrayidx.i = getelementptr inbounds float, ptr addrspace(1) %_arg_data, i64 %0
%arrayidx3.i = getelementptr inbounds float, ptr addrspace(1) %_arg_dat1, i64 %0
%cmp.i = icmp ult i64 %0, 2147483648
%arrayidx5.i = getelementptr inbounds i32, ptr addrspace(1) %_arg_dat2, i64 %0
%1 = load float, ptr addrspace(1) %arrayidx3.i, align 4
%2 = load i32, ptr addrspace(1) %arrayidx5.i, align 4
%call.i.i = tail call spir_func noundef nofpclass(nan inf) float @_Z17__spirv_ocl_ldexpfi(float noundef nofpclass(inf) %1, i32 noundef %2)
store float %call.i.i, ptr addrspace(1) %arrayidx.i, align 4
ret void
}

define weak_odr dso_local spir_kernel void @nofpclass_fast(ptr addrspace(1) noundef align 4 %_arg_data, ptr addrspace(1) noundef align 4 %_arg_dat1, ptr addrspace(1) noundef align 4 %_arg_dat2) local_unnamed_addr {
entry:
%0 = load i64, ptr addrspace(1) @__spirv_BuiltInGlobalInvocationId, align 32
%arrayidx.i = getelementptr inbounds float, ptr addrspace(1) %_arg_data, i64 %0
%arrayidx3.i = getelementptr inbounds float, ptr addrspace(1) %_arg_dat1, i64 %0
%cmp.i = icmp ult i64 %0, 2147483648
%arrayidx5.i = getelementptr inbounds float, ptr addrspace(1) %_arg_dat2, i64 %0
%1 = load float, ptr addrspace(1) %arrayidx3.i, align 4
%2 = load float, ptr addrspace(1) %arrayidx5.i, align 4
%call.i.i = tail call fast spir_func noundef nofpclass(nan inf) float @_Z16__spirv_ocl_fmaxff(float noundef nofpclass(nan inf) %1, float noundef nofpclass(nan inf) %2)
store float %call.i.i, ptr addrspace(1) %arrayidx.i, align 4
ret void
}

0 comments on commit 1fc553b

Please sign in to comment.