Skip to content

Commit

Permalink
llvm: partial workaround failure around https://github.com/llvm/llvm-…
Browse files Browse the repository at this point in the history
…project/issues/102868\#issuecomment-2284935755 module inlining breaks machine instruction fusion
  • Loading branch information
mratsim committed Aug 12, 2024
1 parent 08b8671 commit a76cfd8
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 60 deletions.
88 changes: 31 additions & 57 deletions constantine/math_compiler/impl_fields_sat.nim
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ import

const SectionName = "ctt.fields"

proc finalSubMayOverflow*(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, M, carry: ValueRef) =
proc finalSubMayOverflow*(asy: Assembler_LLVM, fd: FieldDescriptor, rr, a, MM, carry: ValueRef) =
## If a >= Modulus: r <- a-M
## else: r <- a
##
Expand All @@ -85,40 +85,29 @@ proc finalSubMayOverflow*(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, M, car
## To be used when the final substraction can
## also overflow the limbs (a 2^256 order of magnitude modulus stored in n words of total max size 2^256)

let name = "_finalsub_mayo_u" & $fd.w & "x" & $fd.numWords
asy.llvmInternalFnDef(
name, SectionName,
asy.void_t, toTypes([r, a, M, carry]),
{kHot, kInline}):
let r = asy.asArray(rr, fd.fieldTy)
let M = asy.load2(fd.intBufTy, MM, "M")

let (rr, aa, MM, carry) = llvmParams
let noCarry = asy.br.`not`(carry, "notcarry")

let r = asy.asArray(rr, fd.fieldTy)
let M = asy.load2(fd.intBufTy, MM, "M")
# let aPtr = asy.asLlvmIntPtr(aa, fd.intBufTy) # Pointers are opaque in LLVM now
let a = asy.load2(fd.intBufTy, aa, "a")
# Now substract the modulus, and test a < M
# (underflow) with the last borrow.
# On x86 at least, LLVM can fuse sub and icmp into sub-with-borrow
# if this is inline the caller https://github.com/llvm/llvm-project/issues/102868
let a_minus_M = asy.br.sub(a, M, "a_minus_M")
let borrow = asy.br.icmp(kULT, a, M, "borrow")

# Now substract the modulus, and test a < M
# (underflow) with the last borrow.
# On x86 at least, LLVM can fuse sub and icmp into sub-with-borrow
let a_minus_M = asy.br.sub(a, M, "a_minus_M")
let borrow = asy.br.icmp(kULT, a, M, "borrow")

# Cases:
# No carry after a+b, no borrow after a-M -> return a-M
# carry after a+b, will borrow after a-M (last bit lost) -> return a-M
# carry after a+b, no borrow after a-M -> return a-M
# No carry after a+b, borrow after a-M -> return a
let notBorrow = asy.br.`not`(borrow, "notborrow")
let ctl = asy.br.`or`(carry, notBorrow, "needSub")
let t = asy.br.select(ctl, a_minus_M, a)

asy.store(r, t)
asy.br.retVoid()
# Cases:
# No carry after a+b, no borrow after a-M -> return a-M
# carry after a+b, will borrow after a-M (last bit lost) -> return a-M
# carry after a+b, no borrow after a-M -> return a-M
# No carry after a+b, borrow after a-M -> return a
let ctl = asy.br.`or`(noCarry, borrow, "in_range")
let t = asy.br.select(ctl, a, a_minus_M)

asy.callFn(name, [r, a, M, carry])
asy.store(r, t)

proc finalSubNoOverflow*(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, M: ValueRef) =
proc finalSubNoOverflow*(asy: Assembler_LLVM, fd: FieldDescriptor, rr, a, MM: ValueRef) =
## If a >= Modulus: r <- a-M
## else: r <- a
##
Expand All @@ -128,32 +117,19 @@ proc finalSubNoOverflow*(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, M: Valu
## To be used when the modulus does not use the full bitwidth of the storing words
## (say using 255 bits for the modulus out of 256 available in words)

let name = "_finalsub_noo_u" & $fd.w & "x" & $fd.numWords
asy.llvmInternalFnDef(
name, SectionName,
asy.void_t, toTypes([r, a, M]),
{kHot, kInline}):

let (rr, aa, MM) = llvmParams

let r = asy.asArray(rr, fd.fieldTy)
let M = asy.load2(fd.intBufTy, MM, "M")
# Pointers are opaque in LLVM now
let a = asy.load2(fd.intBufTy, aa, "a")

# Now substract the modulus, and test a < M
# (underflow) with the last borrow
# On x86 at least, LLVM can fuse sub and icmp into sub-with-borrow
let a_minus_M = asy.br.sub(a, M, "a_minus_M")
let borrow = asy.br.icmp(kULT, a, M, "borrow")
let r = asy.asArray(rr, fd.fieldTy)
let M = asy.load2(fd.intBufTy, MM, "M")

# If it underflows here a was smaller than the modulus, which is what we want
let t = asy.br.select(borrow, a, a_minus_M)
# Now substract the modulus, and test a < M
# (underflow) with the last borrow
# On x86 at least, LLVM can fuse sub and icmp into sub-with-borrow
let a_minus_M = asy.br.sub(a, M, "a_minus_M")
let borrow = asy.br.icmp(kULT, a, M, "borrow")

asy.store(r, t)
asy.br.retVoid()
# If it underflows here a was smaller than the modulus, which is what we want
let t = asy.br.select(borrow, a, a_minus_M)

asy.callFn(name, [r, a, M])
asy.store(r, t)

proc modadd*(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, b, M: ValueRef) =
## Generate an optimized modular addition kernel
Expand All @@ -174,14 +150,12 @@ proc modadd*(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, b, M: ValueRef) =
let b = asy.load2(fd.intBufTy, bb, "b")

let apb = asy.br.add(a, b, "a_plus_b")
let t = asy.makeArray(fd.fieldTy)
asy.store(t, apb)

if fd.spareBits >= 1:
asy.finalSubNoOverflow(fd, r, t.buf, M)
asy.finalSubNoOverflow(fd, r, apb, M)
else:
let carry = asy.br.icmp(kUlt, apb, b, "overflow")
asy.finalSubMayOverflow(fd, r, t.buf, M, carry)
asy.finalSubMayOverflow(fd, r, apb, M, carry)

asy.br.retVoid()

Expand Down
2 changes: 1 addition & 1 deletion constantine/math_compiler/pub_fields.nim
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ proc genFpAdd*(asy: Assembler_LLVM, fd: FieldDescriptor): string =
## and return the corresponding name to call it

let name = fd.name & "_add"
asy.llvmPublicFnDef(name, fd.name, asy.void_t, [fd.fieldTy, fd.fieldTy, fd.fieldTy]):
asy.llvmPublicFnDef(name, "ctt." & fd.name, asy.void_t, [fd.fieldTy, fd.fieldTy, fd.fieldTy]):
let M = asy.getModulusPtr(fd)

let (r, a, b) = llvmParams
Expand Down
32 changes: 30 additions & 2 deletions research/codegen/x86_poc.nim
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ proc t_field_add() =
discard asy.genFpAdd(fd)

echo "========================================="
echo "LLVM IR\n"
echo "LLVM IR unoptimized\n"

echo asy.module
echo "========================================="
Expand All @@ -71,11 +71,33 @@ proc t_field_add() =
codeModel = CodeModelDefault
)

# Due to https://github.com/llvm/llvm-project/issues/102868
# We want to reproduce the codegen from llc.cpp
# However we can't reproduce the code from either
# - LLVM16 https://github.com/llvm/llvm-project/blob/llvmorg-16.0.6/llvm/tools/llc/llc.cpp
# need legacy PassManagerRef and the PassManagerBuilder that interfaces between the
# legacy PssManagerRef and new PassBuilder has been deleted in LLVM17
#
# - and contrary to what is claimed in https://llvm.org/docs/NewPassManager.html#id2
# the C API of PassBuilderRef is ghost town.
#
# So we reproduce the optimization passes from
# https://reviews.llvm.org/D145835

let pbo = createPassBuilderOptions()
pbo.setMergeFunctions()
let err = asy.module.runPasses(
"default<O2>",
# "default<O2>,memcpyopt,sroa,mem2reg,function-attrs,inline,gvn,dse,aggressive-instcombine,adce",
"function(require<targetir>,require<targetlibinfo>,require<inliner-size-estimator>,require<memdep>,require<da>)" &
",function(aa-eval)" &
",always-inline,hotcoldsplit,inferattrs,instrprof,recompute-globalsaa" &
",cgscc(argpromotion,function-attrs)" &
# ",require<inline-advisor>,partial-inliner,called-value-propagation" &
# ",scc-oz-module-inliner,inline-wrapper,module-inline" & # Buggy optimization
",function(verify,loop-mssa(loop-reduce),mergeicmps,expand-memcmp,instsimplify)" &
",function(lower-constant-intrinsics,consthoist,partially-inline-libcalls,ee-instrument<post-inline>,scalarize-masked-mem-intrin,verify)" &
",memcpyopt,sroa,dse,aggressive-instcombine,gvn,ipsccp,deadargelim,adce" &
"",
machine,
pbo
)
Expand All @@ -87,6 +109,12 @@ proc t_field_add() =
errMsg.dispose()
quit 1

echo "========================================="
echo "LLVM IR optimized\n"

echo asy.module
echo "========================================="

echo "========================================="
echo "Assembly\n"

Expand Down

0 comments on commit a76cfd8

Please sign in to comment.