LLVM: field addition with saturated fields (#456)
* feat(LLVM): add codegenerator for saturated field add/sub

* LLVM: WIP refactor - boilerplate, linkage, assembly sections, ...

* feat(llvm): try (and fail) to work around bad modular addition codegen with an inline function.

* llvm: partial workaround failure around https://github.com/llvm/llvm-project/issues/102868#issuecomment-2284935755: module inlining breaks machine instruction fusion

* llvm: define our own addcarry/subborrow which properly optimize on x86 (but not on ARM, see llvm/llvm-project#102062)

* llvm: use builtin llvm.uadd.with.overflow.iXXX to try to generate optimal code (and fail for i320 and i384 llvm/llvm-project#103717)
mratsim authored Aug 14, 2024
1 parent 1e34ec2 commit 569e029
Showing 19 changed files with 1,626 additions and 685 deletions.
8 changes: 8 additions & 0 deletions PLANNING.md
@@ -101,6 +101,14 @@ Other tracks are stretch goals, contributions towards them are accepted.
- introduce batchAffine_vartime
- Optimized square_repeated in assembly for Montgomery and Crandall/Pseudo-Mersenne primes
- Optimized elliptic curve directly calling assembly without ADX checks and limited input/output movement in registers or using function multi-versioning.
- LLVM IR:
  - use internal or private linkage type
  - look into calling conventions like "fast" or "tail fast"
  - check if returning a value from a function is properly optimized
    compared to an in-place result
  - use `readnone` (pure) and `readonly` attributes for functions
  - look into passing parameters as arrays instead of pointers?
  - use the `hot` function attribute

### User Experience track

@@ -80,7 +80,7 @@ proc finalSubMayOverflowImpl*(
ctx.mov scratch[i], a[i]
ctx.sbb scratch[i], M[i]

-  # If it overflows here, it means that it was
+  # If it underflows here, it means that it was
# smaller than the modulus and we don't need `scratch`
ctx.sbb scratchReg, 0
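For intuition, the corrected comment in the hunk above can be modeled outside assembly. Here is a minimal Python sketch of a borrow-chain final subtraction (helper names are illustrative, this is not Constantine's constant-time code): if the limb-wise subtraction of the modulus ends with a borrow (underflow), the input was already below the modulus and the scratch result is discarded.

```python
# Model of a conditional final subtraction on 4 x 64-bit limbs.
# 'a' may be in [0, 2M); we compute scratch = a - M with a borrow chain.
# If the subtraction underflows (final borrow set), a < M and we keep 'a',
# otherwise we keep 'scratch'. Names here are illustrative only.

MASK64 = (1 << 64) - 1

def final_sub_may_underflow(a_limbs, m_limbs):
    scratch = []
    borrow = 0
    for ai, mi in zip(a_limbs, m_limbs):
        d = ai - mi - borrow
        borrow = 1 if d < 0 else 0
        scratch.append(d & MASK64)
    # borrow == 1  <=>  a < M: the result stays 'a'
    return a_limbs if borrow else scratch

def to_limbs(x, n=4):
    # Split an integer into n little-endian 64-bit limbs
    return [(x >> (64 * i)) & MASK64 for i in range(n)]

def from_limbs(limbs):
    # Recombine little-endian 64-bit limbs into an integer
    return sum(l << (64 * i) for i, l in enumerate(limbs))
```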

83 changes: 83 additions & 0 deletions constantine/math_compiler/README.md
@@ -0,0 +1,83 @@
# Cryptography primitive compiler

This implements a cryptography compiler that can be used to produce
- high-performance JIT code for GPUs
- or assembly files for CPUs, when we want to ensure
  there are no side-channel regressions for secret data
- or vectorized assembly files, as LLVM IR is significantly
  more convenient for modeling vector operations

There are also LLVM IR => FPGA translators that might be useful
in the future.

## Platform limitations

- X86 cannot easily use the dual carry-chains ADCX/ADOX:
  - no native support for clearing a flag with `xor`
    and keeping it clear.
  - inline assembly cannot use the raw ASM printer,
    so the workflow will need to compile -> decompile.
- Nvidia GPUs cannot lower types larger than 64 bits, hence we cannot use i256 for example.
- AMD GPUs have 1/4 the throughput for i32 MUL compared to f32 MUL or i24 MUL.
- non-x86 targets may not be as optimized at matching
  the addcarry and subborrow patterns, even with @llvm.usub.with.overflow.
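For context, the addcarry/subborrow pattern those backends must recognize is plain ripple-carry limb arithmetic. A short Python sketch (illustrative only, not the actual codegen) of the multi-limb addition we want LLVM to lower to one ADD followed by a chain of ADC instructions:

```python
# Schoolbook multi-limb addition: each limb addition produces a carry
# that feeds the next limb. On x86 this should lower to ADD + ADC...ADC.
MASK64 = (1 << 64) - 1

def big_add(a, b):
    """Add two equal-length little-endian u64 limb arrays.

    Returns (result_limbs, carry_out).
    """
    out, carry = [], 0
    for x, y in zip(a, b):
        s = x + y + carry
        out.append(s & MASK64)
        carry = s >> 64
    return out, carry
```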

## ABI

Internal functions are:
- prefixed with `_`
- Linkage: internal
- calling convention: "fast"
- mark `hot` for field arithmetic functions

Internal global constants are:
- prefixed with `_`
- Linkage: linkonce_odr (so they are merged with globals of the same name)

External functions use default convention.

We ensure parameters / return values fit in registers:
- https://llvm.org/docs/Frontend/PerformanceTips.html

TODO:
- function alignment: look into
- https://www.bazhenov.me/posts/2024-02-performance-roulette/
- https://lkml.org/lkml/2015/5/21/443
- function multiversioning
- aggregate alignment (via datalayout)

Naming convention for internal procedures:
- _big_add_u64x4
- _finalsub_mayo_u64x4 -> final subtraction may overflow
- _finalsub_noo_u64x4 -> final sub no overflow
- _mod_add_u64x4
- _mod_add2x_u64x8 -> FpDbl backend
- _mty_mulur_u64x4b2 -> unreduced Montgomery multiplication (unreduced result valid iff 2 spare bits)
- _mty_mul_u64x4b1 -> reduced Montgomery multiplication (result valid iff at least 1 spare bit)
- _mty_mul_u64x4 -> reduced Montgomery multiplication
- _mty_nsqrur_u64x4b2 -> unreduced square n times
- _mty_nsqr_u64x4b1 -> reduced square n times
- _mty_sqr_u64x4 -> square
- _mty_red_u64x4 -> reduction u64x4 <- u64x8
- _pmp_red_mayo_u64x4 -> Pseudo-Mersenne Prime partial reduction may overflow (secp256k1)
- _pmp_red_noo_u64x4 -> Pseudo-Mersenne Prime partial reduction no overflow
- _secp256k1_red -> special reduction
- _fp2x_sqr2x_u64x4 -> Fp2 complex, Fp -> FpDbl lazy reduced squaring
- _fp2g_sqr2x_u64x4 -> Fp2 generic/non-complex (do we pass the mul-non-residue as parameter?)
- _fp2_sqr_u64x4 -> Fp2 (pass the mul-by-non-residue function as parameter)
- _fp4o2_mulnr1pi_u64x4 -> Fp4 over Fp2 mul with (1+i) non-residue optimization
- _fp4o2_mulbynr_u64x4
- _fp12_add_u64x4
- _fp12o4o2_mul_u64x4 -> Fp12 over Fp4 over Fp2
- _ecg1swjac_adda0_u64x4 -> Short Weierstrass G1 Jacobian addition, a=0
- _ecg1swjac_add_u64x4_var -> Short Weierstrass G1 Jacobian vartime addition
- _ectwprj_add_u64x4 -> Twisted Edwards projective addition

Vectorized:
- _big_add_u64x4v4
- _big_add_u32x8v8

Naming for external procedures:
- bls12_381_fp_add
- bls12_381_fr_add
- bls12_381_fp12_add
20 changes: 2 additions & 18 deletions constantine/math_compiler/codegen_nvidia.nim
@@ -9,12 +9,12 @@
import
constantine/platforms/abis/nvidia_abi {.all.},
constantine/platforms/abis/c_abi,
- constantine/platforms/llvm/[llvm, nvidia_inlineasm],
+ constantine/platforms/llvm/llvm,
constantine/platforms/primitives,
./ir

export
- nvidia_abi, nvidia_inlineasm,
+ nvidia_abi,
Flag, flag, wrapOpenArrayLenType

# ############################################################
@@ -115,22 +115,6 @@ proc cudaDeviceInit*(deviceID = 0'i32): CUdevice =
#
# ############################################################

- proc tagCudaKernel(module: ModuleRef, fn: FnDef) =
-   ## Tag a function as a Cuda Kernel, i.e. callable from host
-
-   doAssert fn.fnTy.getReturnType().isVoid(), block:
-     "Kernels must not return values but function returns " & $fn.fnTy.getReturnType().getTypeKind()
-
-   let ctx = module.getContext()
-   module.addNamedMetadataOperand(
-     "nvvm.annotations",
-     ctx.asValueRef(ctx.metadataNode([
-       fn.fnImpl.asMetadataRef(),
-       ctx.metadataNode("kernel"),
-       constInt(ctx.int32_t(), 1, LlvmBool(false)).asMetadataRef()
-     ]))
-   )

proc wrapInCallableCudaKernel*(module: ModuleRef, fn: FnDef) =
## Create a public wrapper of a cuda device function
##
216 changes: 216 additions & 0 deletions constantine/math_compiler/impl_fields_globals.nim
@@ -0,0 +1,216 @@
# Constantine
# Copyright (c) 2018-2019 Status Research & Development GmbH
# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
# Licensed and distributed under either of
# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
# at your option. This file may not be copied, modified, or distributed except according to those terms.

import
constantine/platforms/bithacks,
constantine/platforms/llvm/llvm,
constantine/serialization/[io_limbs, codecs],
constantine/named/deriv/precompute

import ./ir

# ############################################################
#
# Metadata precomputation
#
# ############################################################

# Constantine on CPU is configured at compile-time for several properties that need to be runtime configuration on GPUs:
# - word size (32-bit or 64-bit)
# - access to curve properties like the modulus bitsize or -1/M[0] a.k.a. m0ninv
# - constants are stored in freestanding `const`
#
# This is because it's not possible to store a BigInt[254] and a BigInt[384]
# in a generic way in the same structure, especially without using heap allocation.
# And with Nim's dead code elimination, unused curves are not compiled in.
#
# As there would be no easy way to dynamically retrieve (via an array or a table)
# const BLS12_381_modulus = ...
# const BN254_Snarks_modulus = ...
#
# - We would need a macro to properly access each constant.
# - We would need to create a 32-bit and a 64-bit version.
# - Unused curves would be compiled in the program.
#
# Note: on GPU we don't manipulate secrets hence branches and dynamic memory allocations are allowed.
#
# As GPU is a niche usage, instead we recreate the relevant `precompute` and IO procedures
# with dynamic wordsize support.

type
DynWord = uint32 or uint64
BigNum[T: DynWord] = object
bits: uint32
limbs: seq[T]

# Serialization
# ------------------------------------------------

func byteLen(bits: SomeInteger): SomeInteger {.inline.} =
## Length in bytes to serialize BigNum
(bits + 7) shr 3 # (bits + 8 - 1) div 8

func fromHex[T](a: var BigNum[T], s: string) =
var bytes = newSeq[byte](a.bits.byteLen())
bytes.paddedFromHex(s, bigEndian)

# 2. Convert canonical uint to BigNum
const wordBitwidth = sizeof(T) * 8
a.limbs.unmarshal(bytes, wordBitwidth, bigEndian)

func fromHex[T](BN: type BigNum[T], bits: uint32, s: string): BN =
const wordBitwidth = sizeof(T) * 8
let numWords = wordsRequired(bits, wordBitwidth)

result.bits = bits
result.limbs.setLen(numWords)
result.fromHex(s)

func toHexLlvm*[T](a: BigNum[T]): string =
## Conversion to big-endian hex suitable for LLVM literals
## It MUST NOT have a prefix
## This is variable-time
# 1. Convert BigInt to canonical uint
const wordBitwidth = sizeof(T) * 8
var bytes = newSeq[byte](byteLen(a.bits))
bytes.marshal(a.limbs, wordBitwidth, bigEndian)

# 2. Convert canonical uint to hex
const hexChars = "0123456789abcdef"
result = newString(2 * bytes.len)
for i in 0 ..< bytes.len:
let bi = bytes[i]
result[2*i] = hexChars[bi shr 4 and 0xF]
result[2*i+1] = hexChars[bi and 0xF]
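The same limbs -> big-endian bytes -> prefix-free hex pipeline can be expressed compactly in Python (a sketch for illustration under the same little-endian u64 limb convention, not the library code):

```python
def to_hex_llvm(limbs, bits):
    """Big-endian, prefix-free hex of little-endian u64 limbs (variable-time)."""
    value = sum(l << (64 * i) for i, l in enumerate(limbs))
    nbytes = (bits + 7) >> 3  # same formula as byteLen above
    return value.to_bytes(nbytes, "big").hex()
```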

# Checks
# ------------------------------------------------

func checkValidModulus(M: BigNum) =
const wordBitwidth = uint32(BigNum.T.sizeof() * 8)
let expectedMsb = M.bits-1 - wordBitwidth * (M.limbs.len.uint32 - 1)
let msb = log2_vartime(M.limbs[M.limbs.len-1])

doAssert msb == expectedMsb, "Internal Error: the modulus must use all declared bits and only those:\n" &
" Modulus '0x" & M.toHexLlvm() & "' is declared with " & $M.bits &
" bits but uses " & $(msb + wordBitwidth * uint32(M.limbs.len - 1)) & " bits."

# Fields metadata
# ------------------------------------------------

func negInvModWord[T](M: BigNum[T]): T =
## Returns the Montgomery domain magic constant for the input modulus:
##
## µ ≡ -1/M[0] (mod SecretWord)
##
## M[0] is the least significant limb of M
## M must be odd and greater than 2.
##
## Assuming 64-bit words:
##
## µ ≡ -1/M[0] (mod 2^64)
checkValidModulus(M)
return M.limbs[0].negInvModWord()
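The constant computed above can be cross-checked with a short Python sketch of the classic Hensel-lifting derivation of µ (illustrative only, not Constantine's implementation): starting from an inverse correct to 3 bits, each Newton step x <- x*(2 - x*m0) doubles the number of correct low bits.

```python
def neg_inv_mod_word(m0, word_bits=64):
    """Compute µ = -1/m0 mod 2^word_bits for odd m0, via Hensel lifting.

    If x*m0 ≡ 1 (mod 2^k) then x*(2 - x*m0) ≡ 1 (mod 2^(2k)),
    so precision doubles at every step.
    """
    assert m0 & 1, "modulus must be odd"
    mask = (1 << word_bits) - 1
    x = m0  # odd squares are ≡ 1 (mod 8), so x is correct to 3 bits
    for _ in range(5):  # 3 -> 6 -> 12 -> 24 -> 48 -> 96 bits >= 64
        x = (x * (2 - x * m0)) & mask
    return (-x) & mask  # negate: µ ≡ -1/m0
```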

# ############################################################
#
# Globals in IR
#
# ############################################################

proc getModulusPtr*(asy: Assembler_LLVM, fd: FieldDescriptor): ValueRef =
let modname = fd.name & "_mod"
var M = asy.module.getGlobal(cstring modname)
if M.isNil():
M = asy.defineGlobalConstant(
name = modname,
section = fd.name,
constIntOfStringAndSize(fd.intBufTy, fd.modulus, 16),
fd.intBufTy,
alignment = 64
)
return M

proc getM0ninv*(asy: Assembler_LLVM, fd: FieldDescriptor): ValueRef =
let m0ninvname = fd.name & "_m0ninv"
var m0ninv = asy.module.getGlobal(cstring m0ninvname)
if m0ninv.isNil():
if fd.w == 32:
let M = BigNum[uint32].fromHex(fd.bits, fd.modulus)
m0ninv = asy.defineGlobalConstant(
name = m0ninvname,
section = fd.name,
constInt(fd.wordTy, M.negInvModWord()),
fd.wordTy
)
else:
let M = BigNum[uint64].fromHex(fd.bits, fd.modulus)
m0ninv = asy.defineGlobalConstant(
name = m0ninvname,
section = fd.name,
constInt(fd.wordTy, M.negInvModWord()),
fd.wordTy
)


return m0ninv

when isMainModule:
let asy = Assembler_LLVM.new("test_module", bkX86_64_Linux)
let fd = asy.ctx.configureField(
"bls12_381_fp",
381,
"1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaab",
v = 1, w = 64)

discard asy.getModulusPtr(fd)
discard asy.getM0ninv(fd)

echo "========================================="
echo "LLVM IR\n"

echo asy.module
echo "========================================="

asy.module.verify(AbortProcessAction)

# --------------------------------------------
# See the assembly - note it might be different from what the JIT compiler did
initializeFullNativeTarget()

const triple = "x86_64-pc-linux-gnu"

let machine = createTargetMachine(
target = toTarget(triple),
triple = triple,
cpu = "",
features = "adx,bmi2", # TODO check the proper way to pass options
level = CodeGenLevelAggressive,
reloc = RelocDefault,
codeModel = CodeModelDefault
)

let pbo = createPassBuilderOptions()
let err = asy.module.runPasses(
"default<O3>,function-attrs,memcpyopt,sroa,mem2reg,gvn,dse,instcombine,inline,adce",
machine,
pbo
)
if not err.pointer().isNil():
writeStackTrace()
let errMsg = err.getErrorMessage()
stderr.write("\"codegenX86_64\" for module '" & astToStr(module) & "' " & $instantiationInfo() &
" exited with error: " & $cstring(errMsg) & '\n')
errMsg.dispose()
quit 1

echo "========================================="
echo "Assembly\n"

echo machine.emitTo[:string](asy.module, AssemblyFile)
echo "========================================="
