Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
LLVM: field addition with saturated fields (#456)
* feat(LLVM): add codegenerator for saturated field add/sub
* LLVM: WIP refactor - boilerplate, linkage, assembly sections, ...
* feat(llvm): try (and fail) to workaround bad modular addition codegen with inline function.
* llvm: partial workaround failure around https://github.com/llvm/llvm-project/issues/102868#issuecomment-2284935755 module inlining breaks machine instruction fusion
* llvm: define our own addcarry/subborrow which properly optimize on x86 (but not ARM see llvm/llvm-project#102062)
* llvm: use builtin llvm.uadd.with.overflow.iXXX to try to generate optimal code (and fail for i320 and i384 llvm/llvm-project#103717)
Showing 19 changed files with 1,626 additions and 685 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
# Cryptography primitive compiler | ||
|
||
This implements a cryptography compiler that can be used to produce: | ||
- high-performance JIT code for GPUs | ||
- assembly files for CPUs, when we want to ensure | ||
  there are no side-channel regressions for secret data | ||
- vectorized assembly files, as LLVM IR makes it significantly | ||
  more convenient to model vector operations | ||
|
||
There are also LLVM IR => FPGA translators that might be useful | ||
in the future. | ||
|
||
## Platform limitations | ||
|
||
- x86 cannot easily use the dual carry-chain ADCX/ADOX instructions: | ||
  - there is no native support for clearing a flag with `xor` | ||
    and keeping it clear. | ||
  - inline assembly cannot use the raw ASM printer, | ||
    so the workflow will need to compile -> decompile. | ||
- Nvidia GPUs cannot lower types larger than 64-bit, hence we cannot use i256 for example. | ||
- AMD GPUs have 1/4 throughput for i32 MUL compared to f32 MUL or i24 MUL. | ||
- non-x86 targets may not be as well optimized at pattern-matching | ||
  addcarry and subborrow, even with @llvm.usub.with.overflow. | ||
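
The add-with-carry pattern these limitations refer to can be sketched in plain Nim. This is only an illustration of the carry chain that a backend is expected to recognize and fuse into `adc`-style instructions; it is not the IR the code generator emits, and the names `addCarry` and `bigAdd` are hypothetical.

```nim
# Hypothetical helpers for illustration only; not part of the code generator.

func addCarry(a, b: uint64, carryIn: bool): tuple[sum: uint64, carryOut: bool] =
  ## Add two words plus an incoming carry.
  ## Unsigned arithmetic in Nim wraps modulo 2^64.
  let partial = a + b
  let c1 = partial < a                        # a + b overflowed
  let total = partial + uint64(ord(carryIn))
  let c2 = total < partial                    # adding the carry overflowed
  (total, c1 or c2)

func bigAdd(a, b: array[4, uint64]): tuple[r: array[4, uint64], carry: bool] =
  ## 256-bit addition as a chain of four 64-bit add-with-carry steps,
  ## least significant limb first.
  var carry = false
  for i in 0 ..< 4:
    let (s, c) = addCarry(a[i], b[i], carry)
    result.r[i] = s
    carry = c
  result.carry = carry
```

Each step depends on the carry of the previous one, which is why the quality of the generated code hinges on the backend keeping the carry in a flag rather than materializing it in a register.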
|
||
## ABI | ||
|
||
Internal functions are: | ||
- prefixed with `_` | ||
- Linkage: internal | ||
- calling convention: "fast" | ||
- marked `hot` for field arithmetic functions | ||
|
||
Internal global constants are: | ||
- prefixed with `_` | ||
- Linkage: linkonce_odr (so they are merged with globals of the same name) | ||
|
||
External functions use default convention. | ||
|
||
We ensure parameters / return value fit in registers: | ||
- https://llvm.org/docs/Frontend/PerformanceTips.html | ||
|
||
TODO: | ||
- function alignment: look into | ||
  - https://www.bazhenov.me/posts/2024-02-performance-roulette/ | ||
  - https://lkml.org/lkml/2015/5/21/443 | ||
- function multiversioning | ||
- aggregate alignment (via datalayout) | ||
|
||
Naming convention for internal procedures: | ||
- _big_add_u64x4 | ||
- _finalsub_mayo_u64x4 -> final subtraction may overflow | ||
- _finalsub_noo_u64x4 -> final sub no overflow | ||
- _mod_add_u64x4 | ||
- _mod_add2x_u64x8 -> FpDbl backend | ||
- _mty_mulur_u64x4b2 -> unreduced Montgomery multiplication (unreduced result valid iff 2 spare bits) | ||
- _mty_mul_u64x4b1 -> reduced Montgomery multiplication (result valid iff at least 1 spare bit) | ||
- _mty_mul_u64x4 -> reduced Montgomery multiplication | ||
- _mty_nsqrur_u64x4b2 -> unreduced square n times | ||
- _mty_nsqr_u64x4b1 -> reduced square n times | ||
- _mty_sqr_u64x4 -> square | ||
- _mty_red_u64x4 -> reduction u64x4 <- u64x8 | ||
- _pmp_red_mayo_u64x4 -> Pseudo-Mersenne Prime partial reduction may overflow (secp256k1) | ||
- _pmp_red_noo_u64x4 -> Pseudo-Mersenne Prime partial reduction no overflow | ||
- _secp256k1_red -> special reduction | ||
- _fp2x_sqr2x_u64x4 -> Fp2 complex, Fp -> FpDbl lazy reduced squaring | ||
- _fp2g_sqr2x_u64x4 -> Fp2 generic/non-complex (do we pass the mul-non-residue as parameter?) | ||
- _fp2_sqr_u64x4 -> Fp2 (pass the mul-by-non-residue function as parameter) | ||
- _fp4o2_mulnr1pi_u64x4 -> Fp4 over Fp2 mul with (1+i) non-residue optimization | ||
- _fp4o2_mulbynr_u64x4 | ||
- _fp12_add_u64x4 | ||
- _fp12o4o2_mul_u64x4 -> Fp12 over Fp4 over Fp2 | ||
- _ecg1swjac_adda0_u64x4 -> Short Weierstrass G1 Jacobian addition, a=0 | ||
- _ecg1swjac_add_u64x4_var -> Short Weierstrass G1 Jacobian vartime addition | ||
- _ectwprj_add_u64x4 -> Twisted Edwards Projective addition | ||
|
||
Vectorized: | ||
- _big_add_u64x4v4 | ||
- _big_add_u32x8v8 | ||
|
||
Naming for external procedures: | ||
- bls12_381_fp_add | ||
- bls12_381_fr_add | ||
- bls12_381_fp12_add |
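
Most internal names listed above follow the pattern `_<op>_u<wordBits>x<numWords>`, with an optional `b<spareBits>` suffix and, for vectorized procedures, a `v<lanes>` suffix. A minimal sketch of a name builder under that reading is shown below; `internalName` and its parameters are hypothetical, and irregular names such as `_secp256k1_red` or `_ecg1swjac_add_u64x4_var` are not covered.

```nim
# Hypothetical helper for illustration; it only covers the regular
# `_<op>_u<wordBits>x<numWords>[b<spareBits>][v<lanes>]` names.

func internalName(op: string, wordBits, numWords: int,
                  spareBits = 0, lanes = 0): string =
  ## Build an internal procedure name from its components.
  result = "_" & op & "_u" & $wordBits & "x" & $numWords
  if spareBits > 0:
    result.add("b" & $spareBits)   # e.g. `b2`: valid iff 2 spare bits
  if lanes > 0:
    result.add("v" & $lanes)       # e.g. `v4`: 4 vector lanes
```

For instance, `internalName("mty_mulur", 64, 4, spareBits = 2)` yields `_mty_mulur_u64x4b2`.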
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,216 @@ | ||
# Constantine | ||
# Copyright (c) 2018-2019 Status Research & Development GmbH | ||
# Copyright (c) 2020-Present Mamy André-Ratsimbazafy | ||
# Licensed and distributed under either of | ||
# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT). | ||
# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0). | ||
# at your option. This file may not be copied, modified, or distributed except according to those terms. | ||
|
||
import | ||
  constantine/platforms/bithacks, | ||
  constantine/platforms/llvm/llvm, | ||
  constantine/serialization/[io_limbs, codecs], | ||
  constantine/named/deriv/precompute | ||
|
||
import ./ir | ||
|
||
# ############################################################ | ||
# | ||
# Metadata precomputation | ||
# | ||
# ############################################################ | ||
|
||
# Constantine on CPU is configured at compile-time for several properties that need to be runtime configuration on GPUs: | ||
# - word size (32-bit or 64-bit) | ||
# - curve properties access like modulus bitsize or -1/M[0] a.k.a. m0ninv | ||
# - constants are stored in freestanding `const` | ||
# | ||
# This is because it's not possible to store a BigInt[254] and a BigInt[384] | ||
# in a generic way in the same structure, especially without using heap allocation. | ||
# And with Nim's dead code elimination, unused curves are not compiled in. | ||
# | ||
# As there would be no easy way to dynamically retrieve (via an array or a table) | ||
# const BLS12_381_modulus = ... | ||
# const BN254_Snarks_modulus = ... | ||
# | ||
# - We would need a macro to properly access each constant. | ||
# - We would need to create a 32-bit and a 64-bit version. | ||
# - Unused curves would be compiled in the program. | ||
# | ||
# Note: on GPU we don't manipulate secrets hence branches and dynamic memory allocations are allowed. | ||
# | ||
# As GPU is a niche usage, instead we recreate the relevant `precompute` and IO procedures | ||
# with dynamic wordsize support. | ||
|
||
type | ||
  DynWord = uint32 or uint64 | ||
  BigNum[T: DynWord] = object | ||
    bits: uint32 | ||
    limbs: seq[T] | ||
|
||
# Serialization | ||
# ------------------------------------------------ | ||
|
||
func byteLen(bits: SomeInteger): SomeInteger {.inline.} = | ||
  ## Length in bytes to serialize BigNum | ||
  (bits + 7) shr 3 # (bits + 8 - 1) div 8 | ||
|
||
func fromHex[T](a: var BigNum[T], s: string) = | ||
  # 1. Parse the hex string into a canonical big-endian byte buffer | ||
  var bytes = newSeq[byte](a.bits.byteLen()) | ||
  bytes.paddedFromHex(s, bigEndian) | ||
|
||
  # 2. Convert canonical uint to BigNum | ||
  const wordBitwidth = sizeof(T) * 8 | ||
  a.limbs.unmarshal(bytes, wordBitwidth, bigEndian) | ||
|
||
func fromHex[T](BN: type BigNum[T], bits: uint32, s: string): BN = | ||
  const wordBitwidth = sizeof(T) * 8 | ||
  let numWords = wordsRequired(bits, wordBitwidth) | ||
|
||
  result.bits = bits | ||
  result.limbs.setLen(numWords) | ||
  result.fromHex(s) | ||
|
||
func toHexLlvm*[T](a: BigNum[T]): string = | ||
  ## Conversion to big-endian hex suitable for LLVM literals | ||
  ## It MUST NOT have a prefix | ||
  ## This is variable-time | ||
  # 1. Convert BigInt to canonical uint | ||
  const wordBitwidth = sizeof(T) * 8 | ||
  var bytes = newSeq[byte](byteLen(a.bits)) | ||
  bytes.marshal(a.limbs, wordBitwidth, bigEndian) | ||
|
||
  # 2. Convert canonical uint to hex | ||
  const hexChars = "0123456789abcdef" | ||
  result = newString(2 * bytes.len) | ||
  for i in 0 ..< bytes.len: | ||
    let bi = bytes[i] | ||
    result[2*i] = hexChars[bi shr 4 and 0xF] | ||
    result[2*i+1] = hexChars[bi and 0xF] | ||
|
||
# Checks | ||
# ------------------------------------------------ | ||
|
||
func checkValidModulus(M: BigNum) = | ||
  const wordBitwidth = uint32(BigNum.T.sizeof() * 8) | ||
  let expectedMsb = M.bits-1 - wordBitwidth * (M.limbs.len.uint32 - 1) | ||
  let msb = log2_vartime(M.limbs[M.limbs.len-1]) | ||
|
||
  doAssert msb == expectedMsb, "Internal Error: the modulus must use all declared bits and only those:\n" & | ||
    " Modulus '0x" & M.toHexLlvm() & "' is declared with " & $M.bits & | ||
    " bits but uses " & $(msb + wordBitwidth * uint32(M.limbs.len - 1)) & " bits." | ||
|
||
# Fields metadata | ||
# ------------------------------------------------ | ||
|
||
func negInvModWord[T](M: BigNum[T]): T = | ||
  ## Returns the Montgomery domain magic constant for the input modulus: | ||
  ## | ||
  ##   µ ≡ -1/M[0] (mod SecretWord) | ||
  ## | ||
  ## M[0] is the least significant limb of M | ||
  ## M must be odd and greater than 2. | ||
  ## | ||
  ## Assuming 64-bit words: | ||
  ## | ||
  ##   µ ≡ -1/M[0] (mod 2^64) | ||
  checkValidModulus(M) | ||
  return M.limbs[0].negInvModWord() | ||
|
||
# ############################################################ | ||
# | ||
# Globals in IR | ||
# | ||
# ############################################################ | ||
|
||
proc getModulusPtr*(asy: Assembler_LLVM, fd: FieldDescriptor): ValueRef = | ||
  let modname = fd.name & "_mod" | ||
  var M = asy.module.getGlobal(cstring modname) | ||
  if M.isNil(): | ||
    M = asy.defineGlobalConstant( | ||
      name = modname, | ||
      section = fd.name, | ||
      constIntOfStringAndSize(fd.intBufTy, fd.modulus, 16), | ||
      fd.intBufTy, | ||
      alignment = 64 | ||
    ) | ||
  return M | ||
|
||
proc getM0ninv*(asy: Assembler_LLVM, fd: FieldDescriptor): ValueRef = | ||
  let m0ninvname = fd.name & "_m0ninv" | ||
  var m0ninv = asy.module.getGlobal(cstring m0ninvname) | ||
  if m0ninv.isNil(): | ||
    if fd.w == 32: | ||
      let M = BigNum[uint32].fromHex(fd.bits, fd.modulus) | ||
      m0ninv = asy.defineGlobalConstant( | ||
        name = m0ninvname, | ||
        section = fd.name, | ||
        constInt(fd.wordTy, M.negInvModWord()), | ||
        fd.wordTy | ||
      ) | ||
    else: | ||
      let M = BigNum[uint64].fromHex(fd.bits, fd.modulus) | ||
      m0ninv = asy.defineGlobalConstant( | ||
        name = m0ninvname, | ||
        section = fd.name, | ||
        constInt(fd.wordTy, M.negInvModWord()), | ||
        fd.wordTy | ||
      ) | ||
|
||
  return m0ninv | ||
|
||
when isMainModule: | ||
  let asy = Assembler_LLVM.new("test_module", bkX86_64_Linux) | ||
  let fd = asy.ctx.configureField( | ||
    "bls12_381_fp", | ||
    381, | ||
    "1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaab", | ||
    v = 1, w = 64) | ||
|
||
  discard asy.getModulusPtr(fd) | ||
  discard asy.getM0ninv(fd) | ||
|
||
  echo "=========================================" | ||
  echo "LLVM IR\n" | ||
|
||
  echo asy.module | ||
  echo "=========================================" | ||
|
||
  asy.module.verify(AbortProcessAction) | ||
|
||
  # -------------------------------------------- | ||
  # See the assembly - note it might be different from what the JIT compiler did | ||
  initializeFullNativeTarget() | ||
|
||
  const triple = "x86_64-pc-linux-gnu" | ||
|
||
  let machine = createTargetMachine( | ||
    target = toTarget(triple), | ||
    triple = triple, | ||
    cpu = "", | ||
    features = "adx,bmi2", # TODO check the proper way to pass options | ||
    level = CodeGenLevelAggressive, | ||
    reloc = RelocDefault, | ||
    codeModel = CodeModelDefault | ||
  ) | ||
|
||
  let pbo = createPassBuilderOptions() | ||
  let err = asy.module.runPasses( | ||
    "default<O3>,function-attrs,memcpyopt,sroa,mem2reg,gvn,dse,instcombine,inline,adce", | ||
    machine, | ||
    pbo | ||
  ) | ||
  if not err.pointer().isNil(): | ||
    writeStackTrace() | ||
    let errMsg = err.getErrorMessage() | ||
    stderr.write("\"codegenX86_64\" for module '" & astToStr(module) & "' " & $instantiationInfo() & | ||
                 " exited with error: " & $cstring(errMsg) & '\n') | ||
    errMsg.dispose() | ||
    quit 1 | ||
|
||
  echo "=========================================" | ||
  echo "Assembly\n" | ||
|
||
  echo machine.emitTo[:string](asy.module, AssemblyFile) | ||
  echo "=========================================" |
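
The `negInvModWord` doc comment in this file defines µ ≡ -1/M[0] (mod 2^64), but the word-level routine it delegates to is not part of this diff. One common way to compute such a constant is Newton (Hensel) lifting of the inverse modulo a power of two. The sketch below assumes that technique and 64-bit words; it is not necessarily how Constantine computes it, and the name `negInvModWord64` is hypothetical.

```nim
# Hypothetical standalone sketch; not taken from Constantine.

func negInvModWord64(m0: uint64): uint64 =
  ## Returns -1/m0 (mod 2^64) for an odd word m0.
  ## Unsigned arithmetic in Nim wraps modulo 2^64, which is what we want here.
  doAssert (m0 and 1) == 1, "the least significant limb of the modulus must be odd"
  # For odd m0, m0 * m0 ≡ 1 (mod 8), so m0 is its own inverse to 3 bits.
  var inv = m0
  # Each Newton step doubles the number of correct low bits:
  # 3 -> 6 -> 12 -> 24 -> 48 -> 96, so 5 steps cover 64 bits.
  for _ in 0 ..< 5:
    inv = inv * (2'u64 - m0 * inv)
  # Negate to obtain -1/m0 (mod 2^64).
  result = 0'u64 - inv
```

A quick sanity check for any result `mu` is that `m0 * mu` wraps to `high(uint64)`, i.e. to -1 modulo 2^64.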