Skip to content

Commit

Permalink
llvm: some more tentatives to generate optimal field addition code in…
Browse files Browse the repository at this point in the history
… pure LLVM IR
  • Loading branch information
mratsim committed Aug 14, 2024
1 parent 569e029 commit 73955d4
Showing 1 changed file with 62 additions and 46 deletions.
108 changes: 62 additions & 46 deletions constantine/math_compiler/impl_fields_sat.nim
Original file line number Diff line number Diff line change
Expand Up @@ -78,49 +78,6 @@ import

const SectionName = "ctt.fields"

proc finalSubMayOverflow*(asy: Assembler_LLVM, fd: FieldDescriptor, rr, a, M, carry: ValueRef) =
## If a >= Modulus: r <- a-M
## else: r <- a
##
## This is constant-time straightline code.
## Due to warp divergence, the overhead of doing comparison with shortcutting might not be worth it on GPU.
##
## To be used when the final substraction can
## also overflow the limbs (a 2^256 order of magnitude modulus stored in n words of total max size 2^256)

# Mask: contains 0xFFFF or 0x0000
let (_, mask) = asy.br.subborrow(fd.zero, fd.zero, carry)

# Now substract the modulus, and test a < M
# (underflow) with the last borrow
let (borrow, a_minus_M) = asy.br.llvm_sub_overflow(a, M)

# If it underflows here, it means that it was
# smaller than the modulus and we don't need `a-M`
let (ctl, _) = asy.br.subborrow(mask, fd.zero, borrow)

let t = asy.br.select(ctl, a, a_minus_M)
asy.store(rr, t)

proc finalSubNoOverflow*(asy: Assembler_LLVM, fd: FieldDescriptor, rr, a, M: ValueRef) =
## If a >= Modulus: r <- a-M
## else: r <- a
##
## This is constant-time straightline code.
## Due to warp divergence, the overhead of doing comparison with shortcutting might not be worth it on GPU.
##
## To be used when the modulus does not use the full bitwidth of the storing words
## (say using 255 bits for the modulus out of 256 available in words)

# Now substract the modulus, and test a < M
# (underflow) with the last borrow
let (borrow, a_minus_M) = asy.br.llvm_sub_overflow(a, M)

# If it underflows here, it means that it was
# smaller than the modulus and we don't need `a-M`
let t = asy.br.select(borrow, a, a_minus_M)
asy.store(rr, t)

proc modadd*(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, b, M: ValueRef) =
## Generate an optimized modular addition kernel
## with parameters `a, b, modulus: Limbs -> Limbs`
Expand All @@ -142,11 +99,70 @@ proc modadd*(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, b, M: ValueRef) =
let b = asy.load2(fd.intBufTy, bb, "b")
let M = asy.load2(fd.intBufTy, MM, "M")

let (carry, apb) = asy.br.llvm_add_overflow(a, b)
if fd.spareBits >= 1:
asy.finalSubNoOverflow(fd, rr, apb, M)
let apb = asy.br.add(a, b, "a_plus_b")
if false:
# 33% more instructions
# https://github.com/llvm/llvm-project/issues/103717

# Now substract the modulus, and test apb < M
# (underflow) with the last borrow
let (borrow, apb_minus_M) = asy.br.llvm_sub_overflow(apb, M)

# If it underflows here, it means that it was
# smaller than the modulus and we don't need `a-M`
let t = asy.br.select(borrow, apb, apb_minus_M)
asy.store(rr, t)

else:
# 1 or 2 extra instructions
# https://github.com/llvm/llvm-project/issues/103841
# https://github.com/llvm/llvm-project/issues/103855

let s = constInt(fd.intBufTy, fd.w * fd.numWords - 1)

let apb_minus_M = asy.br.sub(apb, M)
let underflow = asy.br.lshr(apb_minus_M, s)
let borrow = asy.br.trunc(underflow, asy.ctx.int1_t())

let t = asy.br.select(borrow, apb, apb_minus_M)
asy.store(rr, t)
else:
asy.finalSubMayOverflow(fd, rr, apb, M, carry)
if false:
let (carry, apb) = asy.br.llvm_add_overflow(a, b, "a_plus_b")

# Mask: contains 0xFFFF or 0x0000
let (_, mask) = asy.br.subborrow(fd.zero, fd.zero, carry)

# Now substract the modulus, and test a < M
# (underflow) with the last borrow
let (borrow, apb_minus_M) = asy.br.llvm_sub_overflow(apb, M)

# If it underflows here, it means that it was
# smaller than the modulus and we don't need `a-M`
let (ctl, _) = asy.br.subborrow(mask, fd.zero, borrow)

let t = asy.br.select(ctl, apb, apb_minus_M)
asy.store(rr, t)
else:
let biggerIntBits = fd.w * (fd.numWords+1)
let biggerInt = asy.ctx.int_t(biggerIntBits)

let ax = asy.br.zext(a, biggerInt, "ax")
let bx = asy.br.zext(b, biggerInt, "bx")

let apb = asy.br.add(ax, bx, "a_plus_b")

let mx = asy.br.zext(M, biggerInt, "mx")
let apb_minus_M = asy.br.sub(apb, mx, "apb_minus_M")

let s = constInt(biggerInt, biggerIntBits - 1)
let underflow = asy.br.lshr(apb_minus_M, s)
let borrow = asy.br.trunc(underflow, asy.ctx.int1_t())

let tLarge = asy.br.select(borrow, apb, apb_minus_M)
let t = asy.br.trunc(tLarge, fd.intBufTy)
asy.store(rr, t)

asy.br.retVoid()

Expand Down

0 comments on commit 73955d4

Please sign in to comment.