Skip to content

Commit

Permalink
[Backport to 15] Add fast math flag translation for OpenCL std lib (#…
Browse files Browse the repository at this point in the history
…2844)

Such possibility was added in SPIR-V 1.6.
Note, this backport doesn't add handling of nofpclass attribute, as it
doesn't present in LLVM yet.

Signed-off-by: Sidorov, Dmitry <[email protected]>
  • Loading branch information
MrSidims authored Nov 8, 2024
1 parent 0f42473 commit f279524
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 4 deletions.
7 changes: 5 additions & 2 deletions lib/SPIRV/SPIRVReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2510,8 +2510,11 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
case OpExtInst: {
auto *ExtInst = static_cast<SPIRVExtInst *>(BV);
switch (ExtInst->getExtSetKind()) {
case SPIRVEIS_OpenCL:
return mapValue(BV, transOCLBuiltinFromExtInst(ExtInst, BB));
case SPIRVEIS_OpenCL: {
auto *V = mapValue(BV, transOCLBuiltinFromExtInst(ExtInst, BB));
applyFPFastMathModeDecorations(BV, static_cast<Instruction *>(V));
return V;
}
case SPIRVEIS_Debug:
case SPIRVEIS_OpenCL_DebugInfo_100:
case SPIRVEIS_NonSemantic_Shader_DebugInfo_100:
Expand Down
3 changes: 3 additions & 0 deletions lib/SPIRV/SPIRVUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -813,6 +813,9 @@ CallInst *mutateCallInst(
NewCI->copyMetadata(*CI);
NewCI->setAttributes(CI->getAttributes());
NewCI->setTailCall(CI->isTailCall());
if (isa<FPMathOperator>(CI))
NewCI->setFastMathFlags(CI->getFastMathFlags());

if (CI->hasFnAttr("fpbuiltin-max-error")) {
auto Attr = CI->getFnAttr("fpbuiltin-max-error");
NewCI->addFnAttr(Attr);
Expand Down
9 changes: 7 additions & 2 deletions lib/SPIRV/SPIRVWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2729,7 +2729,8 @@ bool LLVMToSPIRVBase::transDecoration(Value *V, SPIRVValue *BV) {
if (Opcode == Instruction::FAdd || Opcode == Instruction::FSub ||
Opcode == Instruction::FMul || Opcode == Instruction::FDiv ||
Opcode == Instruction::FRem ||
((Opcode == Instruction::FNeg || Opcode == Instruction::FCmp) &&
((Opcode == Instruction::FNeg || Opcode == Instruction::FCmp ||
BV->isExtInst()) &&
BM->isAllowedToUseVersion(VersionNumber::SPIRV_1_6))) {
FastMathFlags FMF = BVF->getFastMathFlags();
SPIRVWord M{0};
Expand All @@ -2756,8 +2757,12 @@ bool LLVMToSPIRVBase::transDecoration(Value *V, SPIRVValue *BV) {
}
}
}
if (M != 0)
if (M != 0) {
BV->setFPFastMathMode(M);
if (Opcode == Instruction::FNeg || Opcode == Instruction::FCmp ||
BV->isExtInst())
BM->setMinSPIRVVersion(VersionNumber::SPIRV_1_6);
}
}
}
if (Instruction *Inst = dyn_cast<Instruction>(V)) {
Expand Down
40 changes: 40 additions & 0 deletions test/transcoding/fast-math-opencl-builtins.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
; RUN: llvm-as %s -o %t.bc
; RUN: llvm-spirv -spirv-text %t.bc -o - | FileCheck %s --check-prefix=CHECK-SPIRV
; RUN: llvm-spirv %t.bc -o %t.spv
; RUN: spirv-val %t.spv
; RUN: llvm-spirv -r %t.spv -o - | llvm-dis -o - | FileCheck %s --check-prefix=CHECK-LLVM-OCL
; RUN: llvm-spirv -r --spirv-target-env=SPV-IR %t.spv -o - | llvm-dis -o - | FileCheck %s --check-prefix=CHECK-LLVM-SPV

; RUN: llvm-spirv -spirv-text --spirv-max-version=1.5 %t.bc -o - | FileCheck %s --check-prefix=CHECK-SPIRV-NEG

; CHECK-SPIRV: Decorate [[#FPDec4:]] FPFastMathMode 16
; CHECK-SPIRV: ExtInst [[#]] [[#FPDec4]] [[#]] fmax [[#]] [[#]]

; CHECK-SPIRV-NEG-NOT: Decorate [[#]] FPFastMathMode [[#]]

; CHECK-LLVM-OCL: call fast spir_func float @_Z4fmaxff(float %[[#]], float %[[#]])

; CHECK-LLVM-SPV: call fast spir_func float @_Z16__spirv_ocl_fmaxff(float %[[#]], float %[[#]])

; ModuleID = 'test.bc'
source_filename = "test.cpp"
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
target triple = "spir64-unknown-unknown"

@__spirv_BuiltInGlobalInvocationId = external dso_local local_unnamed_addr addrspace(1) constant <3 x i64>, align 32

declare dso_local spir_func noundef float @_Z16__spirv_ocl_fmaxff(float noundef, float noundef) local_unnamed_addr

define weak_odr dso_local spir_kernel void @nofpclass_fast(ptr addrspace(1) noundef align 4 %_arg_data, ptr addrspace(1) noundef align 4 %_arg_dat1, ptr addrspace(1) noundef align 4 %_arg_dat2) local_unnamed_addr {
entry:
%0 = load i64, ptr addrspace(1) @__spirv_BuiltInGlobalInvocationId, align 32
%arrayidx.i = getelementptr inbounds float, ptr addrspace(1) %_arg_data, i64 %0
%arrayidx3.i = getelementptr inbounds float, ptr addrspace(1) %_arg_dat1, i64 %0
%cmp.i = icmp ult i64 %0, 2147483648
%arrayidx5.i = getelementptr inbounds float, ptr addrspace(1) %_arg_dat2, i64 %0
%1 = load float, ptr addrspace(1) %arrayidx3.i, align 4
%2 = load float, ptr addrspace(1) %arrayidx5.i, align 4
%call.i.i = tail call fast spir_func noundef float @_Z16__spirv_ocl_fmaxff(float noundef %1, float noundef %2)
store float %call.i.i, ptr addrspace(1) %arrayidx.i, align 4
ret void
}

0 comments on commit f279524

Please sign in to comment.