From 0354d5b25a801a14ec3b715c2f5ded3fed54f804 Mon Sep 17 00:00:00 2001
From: Mamy Ratsimbazafy
Date: Fri, 9 Aug 2024 03:08:19 +0200
Subject: [PATCH] LLVM: WIP refactor - boilerplate, linkage, assembly sections, ...

---
 constantine/math_compiler/README.md           |  83 +++
 constantine/math_compiler/codegen_nvidia.nim  |  16 -
 .../math_compiler/impl_fields_globals.nim     | 216 ++++++
 .../math_compiler/impl_fields_nvidia.nim      |  10 +-
 constantine/math_compiler/impl_fields_sat.nim | 155 +++--
 constantine/math_compiler/ir.nim              | 655 +++++++++++-------
 constantine/math_compiler/pub_fields.nim      |  30 +
 constantine/platforms/abis/llvm_abi.nim       |  88 ++-
 constantine/platforms/llvm/llvm.nim           |  49 +-
 .../platforms/llvm/super_instructions.nim     |  20 +-
 research/codegen/x86_poc.nim                  | 107 +--
 tests/gpu/t_nvidia_fp.nim                     |  12 +-
 12 files changed, 990 insertions(+), 451 deletions(-)
 create mode 100644 constantine/math_compiler/README.md
 create mode 100644 constantine/math_compiler/impl_fields_globals.nim
 create mode 100644 constantine/math_compiler/pub_fields.nim

diff --git a/constantine/math_compiler/README.md b/constantine/math_compiler/README.md
new file mode 100644
index 00000000..2ff81b44
--- /dev/null
+++ b/constantine/math_compiler/README.md
@@ -0,0 +1,83 @@
+# Cryptography primitive compiler
+
+This implements a cryptography compiler that can be used to produce
+- high-performance JIT code for GPUs
+- or assembly files, for CPUs when we want to ensure
+  there are no side-channel regressions for secret data
+- or vectorized assembly files, as LLVM IR is significantly
+  more convenient for modeling vector operations
+
+There are also LLVM IR => FPGA translators that might be useful
+in the future.
+
+## Platform limitations
+
+- X86 cannot use the dual carry-chain ADCX/ADOX easily:
+  - no native support for clearing a flag with `xor`
+    and keeping it clear.
+  - inline assembly cannot use the raw ASM printer,
+    so the workflow will need to compile -> decompile.
+- Nvidia GPUs cannot lower types larger than 64-bit, hence we cannot use i256, for example.
+- AMD GPUs have 1/4 throughput for i32 MUL compared to f32 MUL or i24 MUL.
+- non-x86 targets may not be as optimized for matching
+  patterns for addcarry and subborrow, even with @llvm.usub.with.overflow
+
+## ABI
+
+Internal functions are:
+- prefixed with `_`
+- Linkage: internal
+- calling convention: "fast"
+- marked `hot` for field arithmetic functions
+
+Internal global constants are:
+- prefixed with `_`
+- Linkage: linkonce_odr (so they are merged with globals of the same name)
+
+External functions use the default calling convention.
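+
+For illustration, here is a sketch of how these conventions are applied by the
+helpers introduced in this refactor (`llvmInternalFnDef` in ir.nim, which sets
+internal linkage and the "fast" calling convention); `asy`, `fd` and the
+parameters are assumed to be an already-configured `Assembler_LLVM`,
+`FieldDescriptor` and `ValueRef`s:
+
+```nim
+# Sketch, not verbatim compiler output.
+let name = "_mod_add_u" & $fd.w & "x" & $fd.numWords
+asy.llvmInternalFnDef(name, "ctt.fields", asy.void_t, toTypes([r, a, b, M])):
+  let (r, a, b, M) = llvmParams
+  # ... emit the modular addition body here ...
+  asy.br.retVoid()
+asy.callFn(name, [r, a, b, M])
+```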
+
+We ensure parameters / return values fit in registers:
+- https://llvm.org/docs/Frontend/PerformanceTips.html
+
+TODO:
+- function alignment: look into
+  - https://www.bazhenov.me/posts/2024-02-performance-roulette/
+  - https://lkml.org/lkml/2015/5/21/443
+- function multiversioning
+- aggregate alignment (via datalayout)
+
+Naming convention for internal procedures:
+- _big_add_u64x4
+- _finalsub_mayo_u64x4 -> final subtraction may overflow
+- _finalsub_noo_u64x4 -> final sub no overflow
+- _mod_add_u64x4
+- _mod_add2x_u64x8 -> FpDbl backend
+- _mty_mulur_u64x4b2 -> unreduced Montgomery multiplication (unreduced result valid iff 2 spare bits)
+- _mty_mul_u64x4b1 -> reduced Montgomery multiplication (result valid iff at least 1 spare bit)
+- _mty_mul_u64x4 -> reduced Montgomery multiplication
+- _mty_nsqrur_u64x4b2 -> unreduced square n times
+- _mty_nsqr_u64x4b1 -> reduced square n times
+- _mty_sqr_u64x4 -> square
+- _mty_red_u64x4 -> reduction u64x4 <- u64x8
+- _pmp_red_mayo_u64x4 -> Pseudo-Mersenne Prime partial reduction may overflow (secp256k1)
+- _pmp_red_noo_u64x4 -> Pseudo-Mersenne Prime partial reduction no overflow
+- _secp256k1_red -> special reduction
+- _fp2x_sqr2x_u64x4 -> Fp2 complex, Fp -> FpDbl lazy reduced squaring
+- _fp2g_sqr2x_u64x4 -> Fp2 generic/non-complex (do we pass the mul-non-residue as parameter?)
+- _fp2_sqr_u64x4 -> Fp2 (pass the mul-by-non-residue function as parameter)
+- _fp4o2_mulnr1pi_u64x4 -> Fp4 over Fp2 mul with (1+i) non-residue optimization
+- _fp4o2_mulbynr_u64x4
+- _fp12_add_u64x4
+- _fp12o4o2_mul_u64x4 -> Fp12 over Fp4 over Fp2
+- _ecg1swjac_adda0_u64x4 -> Short Weierstrass G1 Jacobian addition a=0
+- _ecg1swjac_add_u64x4_var -> Short Weierstrass G1 Jacobian vartime addition
+- _ectwprj_add_u64x4 -> Twisted Edwards projective addition
+
+Vectorized:
+- _big_add_u64x4v4
+- _big_add_u32x8v8
+
+Naming for external procedures:
+- bls12_381_fp_add
+- bls12_381_fr_add
+- bls12_381_fp12_add
diff --git a/constantine/math_compiler/codegen_nvidia.nim b/constantine/math_compiler/codegen_nvidia.nim
index 63245f95..fdc4c393 100644
--- a/constantine/math_compiler/codegen_nvidia.nim
+++ b/constantine/math_compiler/codegen_nvidia.nim
@@ -115,22 +115,6 @@ proc cudaDeviceInit*(deviceID = 0'i32): CUdevice =
 #
 # ############################################################
 
-proc tagCudaKernel(module: ModuleRef, fn: FnDef) =
-  ## Tag a function as a Cuda Kernel, i.e. callable from host
-
-  doAssert fn.fnTy.getReturnType().isVoid(), block:
-    "Kernels must not return values but function returns " & $fn.fnTy.getReturnType().getTypeKind()
-
-  let ctx = module.getContext()
-  module.addNamedMetadataOperand(
-    "nvvm.annotations",
-    ctx.asValueRef(ctx.metadataNode([
-      fn.fnImpl.asMetadataRef(),
-      ctx.metadataNode("kernel"),
-      constInt(ctx.int32_t(), 1, LlvmBool(false)).asMetadataRef()
-    ]))
-  )
-
 proc wrapInCallableCudaKernel*(module: ModuleRef, fn: FnDef) =
   ## Create a public wrapper of a cuda device function
   ##
diff --git a/constantine/math_compiler/impl_fields_globals.nim b/constantine/math_compiler/impl_fields_globals.nim
new file mode 100644
index 00000000..faac591c
--- /dev/null
+++ b/constantine/math_compiler/impl_fields_globals.nim
@@ -0,0 +1,216 @@
+# Constantine
+# Copyright (c) 2018-2019 Status Research & Development GmbH
+# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
+# Licensed and distributed under either of
+#   * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
+#   * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
+# at your option. This file may not be copied, modified, or distributed except according to those terms.
+
+import
+  constantine/platforms/bithacks,
+  constantine/platforms/llvm/llvm,
+  constantine/serialization/[io_limbs, codecs],
+  constantine/named/deriv/precompute
+
+import ./ir
+
+# ############################################################
+#
+#                Metadata precomputation
+#
+# ############################################################
+
+# Constantine on CPU is configured at compile-time for several properties that need runtime configuration on GPUs:
+# - word size (32-bit or 64-bit)
+# - access to curve properties like the modulus bitsize or -1/M[0], a.k.a. m0ninv
+# - constants are stored in freestanding `const`
+#
+# This is because it's not possible to store a BigInt[254] and a BigInt[384]
+# in a generic way in the same structure, especially without using heap allocation.
+# And with Nim's dead code elimination, unused curves are not compiled in.
+#
+# To dynamically retrieve (via an array or a table) constants like
+#   const BLS12_381_modulus = ...
+#   const BN254_Snarks_modulus = ...
+#
+# - we would need a macro to properly access each constant,
+# - we would need to create a 32-bit and a 64-bit version,
+# - and unused curves would be compiled into the program.
+#
+# Note: on GPU we don't manipulate secrets, hence branches and dynamic memory allocations are allowed.
+#
+# As GPU is a niche use case, we instead recreate the relevant `precompute` and IO procedures
+# with dynamic word-size support.
+
+type
+  DynWord = uint32 or uint64
+  BigNum[T: DynWord] = object
+    bits: uint32
+    limbs: seq[T]
+
+# Serialization
+# ------------------------------------------------
+
+func byteLen(bits: SomeInteger): SomeInteger {.inline.} =
+  ## Length in bytes to serialize BigNum
+  (bits + 7) shr 3 # (bits + 8 - 1) div 8
+
+func fromHex[T](a: var BigNum[T], s: string) =
+  var bytes = newSeq[byte](a.bits.byteLen())
+  bytes.paddedFromHex(s, bigEndian)
+
+  # 2. Convert canonical uint to BigNum
+  const wordBitwidth = sizeof(T) * 8
+  a.limbs.unmarshal(bytes, wordBitwidth, bigEndian)
+
+func fromHex[T](BN: type BigNum[T], bits: uint32, s: string): BN =
+  const wordBitwidth = sizeof(T) * 8
+  let numWords = wordsRequired(bits, wordBitwidth)
+
+  result.bits = bits
+  result.limbs.setLen(numWords)
+  result.fromHex(s)
+
+func toHexLlvm*[T](a: BigNum[T]): string =
+  ## Conversion to big-endian hex suitable for LLVM literals
+  ## It MUST NOT have a prefix
+  ## This is variable-time
+  # 1. Convert BigInt to canonical uint
+  const wordBitwidth = sizeof(T) * 8
+  var bytes = newSeq[byte](byteLen(a.bits))
+  bytes.marshal(a.limbs, wordBitwidth, bigEndian)
+
+  # 2. Convert canonical uint to hex
+  const hexChars = "0123456789abcdef"
+  result = newString(2 * bytes.len)
+  for i in 0 ..< bytes.len:
+    let bi = bytes[i]
+    result[2*i] = hexChars[bi shr 4 and 0xF]
+    result[2*i+1] = hexChars[bi and 0xF]
+
+# Checks
+# ------------------------------------------------
+
+func checkValidModulus(M: BigNum) =
+  const wordBitwidth = uint32(BigNum.T.sizeof() * 8)
+  let expectedMsb = M.bits-1 - wordBitwidth * (M.limbs.len.uint32 - 1)
+  let msb = log2_vartime(M.limbs[M.limbs.len-1])
+
+  doAssert msb == expectedMsb, "Internal Error: the modulus must use all declared bits and only those:\n" &
+    "    Modulus '0x" & M.toHexLlvm() & "' is declared with " & $M.bits &
+    " bits but uses " & $(msb + wordBitwidth * uint32(M.limbs.len - 1)) & " bits."
+
+# Fields metadata
+# ------------------------------------------------
+
+func negInvModWord[T](M: BigNum[T]): T =
+  ## Returns the Montgomery domain magic constant for the input modulus:
+  ##
+  ##   µ ≡ -1/M[0] (mod SecretWord)
+  ##
+  ## M[0] is the least significant limb of M
+  ## M must be odd and greater than 2.
+  ##
+  ## Assuming 64-bit words:
+  ##
+  ## µ ≡ -1/M[0] (mod 2^64)
+  checkValidModulus(M)
+  return M.limbs[0].negInvModWord()
+
+# ############################################################
+#
+#                    Globals in IR
+#
+# ############################################################
+
+proc getModulusPtr*(asy: Assembler_LLVM, fd: FieldDescriptor): ValueRef =
+  let modname = fd.name & "_mod"
+  var M = asy.module.getGlobal(cstring modname)
+  if M.isNil():
+    M = asy.defineGlobalConstant(
+      name = modname,
+      section = fd.name,
+      constIntOfStringAndSize(fd.intBufTy, fd.modulus, 16),
+      fd.intBufTy,
+      alignment = 64
+    )
+  return M
+
+proc getM0ninv*(asy: Assembler_LLVM, fd: FieldDescriptor): ValueRef =
+  let m0ninvname = fd.name & "_m0ninv"
+  var m0ninv = asy.module.getGlobal(cstring m0ninvname)
+  if m0ninv.isNil():
+    if fd.w == 32:
+      let M = BigNum[uint32].fromHex(fd.bits, fd.modulus)
+      m0ninv = asy.defineGlobalConstant(
+        name = m0ninvname,
+        section = fd.name,
+        constInt(fd.wordTy, M.negInvModWord()),
+        fd.wordTy
+      )
+    else:
+      let M = BigNum[uint64].fromHex(fd.bits, fd.modulus)
+      m0ninv = asy.defineGlobalConstant(
+        name = m0ninvname,
+        section = fd.name,
+        constInt(fd.wordTy, M.negInvModWord()),
+        fd.wordTy
+      )
+
+  return m0ninv
+
+when isMainModule:
+  let asy = Assembler_LLVM.new(bkX86_64_Linux, "test_module")
+  let fd = asy.ctx.configureField(
+    "bls12_381_fp",
+    381,
+    "1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaab",
+    v = 1, w = 64)
+
+  discard asy.getModulusPtr(fd)
+  discard asy.getM0ninv(fd)
+
+  echo "========================================="
+  echo "LLVM IR\n"
+
+  echo asy.module
+  echo "========================================="
+
+  asy.module.verify(AbortProcessAction)
+
+  # --------------------------------------------
+  # See the assembly - note it might be different from what the JIT compiler did
+  initializeFullNativeTarget()
+
+  const triple = "x86_64-pc-linux-gnu"
+
+  let machine = createTargetMachine(
+    target = toTarget(triple),
+    triple = triple,
+    cpu = "",
+    features = "adx,bmi2", # TODO check the proper way to pass options
+    level = CodeGenLevelAggressive,
+    reloc = RelocDefault,
+    codeModel = CodeModelDefault
+  )
+
+  let pbo = createPassBuilderOptions()
+  let err = asy.module.runPasses(
+    "default,function-attrs,memcpyopt,sroa,mem2reg,gvn,dse,instcombine,inline,adce",
+    machine,
+    pbo
+  )
+  if not err.pointer().isNil():
+    writeStackTrace()
+    let errMsg = 
err.getErrorMessage() + stderr.write("\"codegenX86_64\" for module '" & astToStr(module) & "' " & $instantiationInfo() & + " exited with error: " & $cstring(errMsg) & '\n') + errMsg.dispose() + quit 1 + + echo "=========================================" + echo "Assembly\n" + + echo machine.emitTo[:string](asy.module, AssemblyFile) + echo "=========================================" diff --git a/constantine/math_compiler/impl_fields_nvidia.nim b/constantine/math_compiler/impl_fields_nvidia.nim index 6b034701..5843d02d 100644 --- a/constantine/math_compiler/impl_fields_nvidia.nim +++ b/constantine/math_compiler/impl_fields_nvidia.nim @@ -43,6 +43,8 @@ import # # See instruction throughput # - https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#arithmetic-instructions +# +# We cannot use i256 on Nvidia target: https://github.com/llvm/llvm-project/blob/llvmorg-18.1.8/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp#L244-L276 proc finalSubMayOverflow(asy: Assembler_LLVM, cm: CurveMetadata, field: Field, r, a: Array) = ## If a >= Modulus: r <- a-M @@ -168,8 +170,8 @@ proc field_sub_gen*(asy: Assembler_LLVM, cm: CurveMetadata, field: Field): FnDef let t = bld.makeArray(fieldTy) let N = cm.getNumWords(field) let zero = case cm.wordSize - of size32: constInt(asy.i32_t, 0) - of size64: constInt(asy.i64_t, 0) + of w32: constInt(asy.i32_t, 0) + of w64: constInt(asy.i64_t, 0) t[0] = bld.sub_bo(a[0], b[0]) for i in 1 ..< N: @@ -261,8 +263,8 @@ proc field_mul_CIOS_sparebit_gen(asy: Assembler_LLVM, cm: CurveMetadata, field: let m0ninv = ValueRef cm.getMontgomeryNegInverse0(field) let M = (seq[ValueRef])(cm.getModulus(field)) let zero = case cm.wordSize - of size32: constInt(asy.i32_t, 0) - of size64: constInt(asy.i64_t, 0) + of w32: constInt(asy.i32_t, 0) + of w64: constInt(asy.i64_t, 0) for i in 0 ..< N: # Multiplication diff --git a/constantine/math_compiler/impl_fields_sat.nim b/constantine/math_compiler/impl_fields_sat.nim index 4910548d..2d214ba6 100644 --- a/constantine/math_compiler/impl_fields_sat.nim +++ b/constantine/math_compiler/impl_fields_sat.nim @@ -8,7 +8,7 @@ import constantine/platforms/llvm/[llvm, super_instructions], - ./ir, ./codegen_nvidia + ./ir # ############################################################ # @@ -30,15 +30,18 @@ import # # It may be suitable for Intel GPUs as the virtual ISA does support add-carry # -# It is suitable for: +# It is (theoretically) suitable for: # - ARM -# - AMD GPUs (for prototyping) +# - AMD GPUs # # The following backends have better optimizations through assembly: # - x86: access to ADOX and ADCX interleaved double-carry chain # - Nvidia: access to multiply accumulate instruction # and non-interleaved double-carry chain # +# Hardware limitations +# -------------------- +# # AMD GPUs may benefits from using 24-bit limbs # - https://www.amd.com/content/dam/amd/en/documents/radeon-tech-docs/programmer-references/AMD_OpenCL_Programming_Optimization_Guide2.pdf # p2-23: @@ -58,8 +61,21 @@ import # and AMD doesn’t seem to have optimized for it. # 32-bit integer multiplication executes at around a quarter of FP32 rate, # and latency is pretty high too." 
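+#
+# As an illustration of the generic approach, carries are recovered from
+# plain IR with an unsigned comparison, and the backend is trusted to
+# pattern-match this into hardware add-with-carry / sub-with-borrow.
+# A sketch of the pattern emitted below (see `modadd`), not verbatim output:
+#
+#   let apb   = asy.br.add(a, b, "a_plus_b")          # a + b  (mod 2^w)
+#   let carry = asy.br.icmp(kULT, apb, b, "overflow") # carry = (a + b) < b
+#
+# On x86, LLVM fuses such add/icmp (and sub/icmp) pairs into carry chains;
+# the software limitations below track targets where this matching fails.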
+#
+# Software limitations
+# --------------------
+#
+# Unfortunately implementing unrolled bigint arithmetic using word size
+# is fraught with perils for add-carry / sub-borrow
+# AMDGPU crash: https://github.com/llvm/llvm-project/issues/102058
+# ARM64 missed optim: https://github.com/llvm/llvm-project/issues/102062
+#
+# and while using @llvm.usub.with.overflow.i64 allows ARM64 to solve the missing optimization
+# it is also missed on AMDGPU (or nvidia)
+
+const SectionName = "ctt.fields"
+
-proc finalSubMayOverflow*(asy: Assembler_LLVM, cm: CurveMetadata, field: Field, r, a: Array, carry: ValueRef) =
+proc finalSubMayOverflow*(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, M, carry: ValueRef) =
   ## If a >= Modulus: r <- a-M
   ## else:            r <- a
   ##
@@ -69,34 +85,37 @@ proc finalSubMayOverflow*(asy: Assembler_LLVM, cm: CurveMetadata, field: Field,
   ## To be used when the final substraction can
   ## also overflow the limbs (a 2^256 order of magnitude modulus stored in n words of total max size 2^256)
 
-  let bld = asy.builder
-  let fieldTy = cm.getFieldType(field)
-  let wordTy = cm.getWordType(field)
-  let scratch = bld.makeArray(fieldTy)
-  let M = cm.getModulus(field)
-  let N = M.len
+  let name = "_finalsub_mayo_u" & $fd.w & "x" & $fd.numWords
+  asy.llvmInternalFnDef(name, SectionName, asy.void_t, toTypes([r, a, M, carry])):
+
+    let (rr, aa, MM, carry) = llvmParams
 
-  let zero_i1 = constInt(asy.i1_t, 0)
-  let zero = constInt(wordTy, 0)
+    let r = asy.asArray(rr, fd.fieldTy)
+    let M = asy.load2(fd.intBufTy, MM, "M")
+    # let aPtr = asy.asLlvmIntPtr(aa, fd.intBufTy) # Pointers are opaque in LLVM now
+    let a = asy.load2(fd.intBufTy, aa, "a")
 
-  # Mask: contains 0xFFFF or 0x0000
-  let (_, mask) = bld.subborrow(zero, zero, carry)
+    # Now subtract the modulus, and test a < M
+    # (underflow) with the last borrow.
+    # On x86 at least, LLVM can fuse sub and icmp into sub-with-borrow
+    let a_minus_M = asy.br.sub(a, M, "a_minus_M")
+    let borrow = asy.br.icmp(kULT, a, M, "borrow")
 
-  # Now substract the modulus, and test a < M
-  # (underflow) with the last borrow
-  var b: ValueRef
-  (b, scratch[0]) = bld.subborrow(a[0], M[0], zero_i1)
-  for i in 1 ..< N:
-    (b, scratch[i]) = bld.subborrow(a[i], M[i], b)
+    # Cases:
+    # No carry after a+b, no borrow after a-M -> return a-M
+    # carry after a+b, will borrow after a-M (last bit lost) -> return a-M
+    # carry after a+b, no borrow after a-M -> return a-M
+    # No carry after a+b, borrow after a-M -> return a
+    let notBorrow = asy.br.`not`(borrow, "notborrow")
+    let ctl = asy.br.`or`(carry, notBorrow, "needSub")
+    let t = asy.br.select(ctl, a_minus_M, a)
 
-  # If it underflows here, it means that it was
-  # smaller than the modulus and we don't need `scratch`
-  (b, _) = bld.subborrow(mask, zero, b)
+    asy.store(r, t)
+    asy.br.retVoid()
 
-  for i in 0 ..< N:
-    r[i] = bld.select(b, a[i], scratch[i])
+  asy.callFn(name, [r, a, M, carry])
 
-proc finalSubNoOverflow*(asy: Assembler_LLVM, cm: CurveMetadata, field: Field, r, a: Array) =
+proc finalSubNoOverflow*(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, M: ValueRef) =
   ## If a >= Modulus: r <- a-M
   ## else:            r <- a
   ##
@@ -106,61 +125,55 @@ proc finalSubNoOverflow*(asy: Assembler_LLVM, cm: CurveMetadata, field: Field, r
   ## To be used when the modulus does not use the full bitwidth of the storing words
   ## (say using 255 bits for the modulus out of 256 available in words)
 
-  let bld = asy.builder
-  let fieldTy = cm.getFieldType(field)
-  let scratch = bld.makeArray(fieldTy)
-  let M = cm.getModulus(field)
-  let N = M.len
+  let name = "_finalsub_noo_u" & $fd.w & "x" & $fd.numWords
+  asy.llvmInternalFnDef(name, SectionName, asy.void_t, toTypes([r, a, M])):
 
-  # Now substract the modulus, and test a < M with the last borrow
-  let zero_i1 = constInt(asy.i1_t, 0)
-  var b: ValueRef
-  (b, scratch[0]) = bld.subborrow(a[0], M[0], zero_i1)
-  for i in 1 ..< N:
-    (b, scratch[i]) = bld.subborrow(a[i], M[i], b)
+    let (rr, aa, MM) = llvmParams
 
-  # If it underflows here a was smaller than the modulus, which is what we want
-  for i in 0 ..< N:
-    r[i] = bld.select(b, a[i], scratch[i])
+    let r = asy.asArray(rr, fd.fieldTy)
+    let M = asy.load2(fd.intBufTy, MM, "M")
+    # Pointers are opaque in LLVM now
+    let a = asy.load2(fd.intBufTy, aa, "a")
 
-proc field_add_gen_sat*(asy: Assembler_LLVM, cm: CurveMetadata, field: Field): FnDef =
-  ## Generate an optimized modular addition kernel
-  ## with parameters `a, b, modulus: Limbs -> Limbs`
+    # Now subtract the modulus, and test a < M
+    # (underflow) with the last borrow
+    # On x86 at least, LLVM can fuse sub and icmp into sub-with-borrow
+    let a_minus_M = asy.br.sub(a, M, "a_minus_M")
+    let borrow = asy.br.icmp(kULT, a, M, "borrow")
+
+    # If it underflows here a was smaller than the modulus, which is what we want
+    let t = asy.br.select(borrow, a, a_minus_M)
 
-  let procName = cm.genSymbol(block:
-    case field
-    of fp: opFpAdd
-    of fr: opFrAdd)
-  let fieldTy = cm.getFieldType(field)
-  let pFieldTy = pointer_t(fieldTy)
+    asy.store(r, t)
+    asy.br.retVoid()
 
-  let addModTy = function_t(asy.void_t, [pFieldTy, pFieldTy, pFieldTy])
-  let addModKernel = asy.module.addFunction(cstring procName, addModTy)
-  let blck = asy.ctx.appendBasicBlock(addModKernel, "addModSatBody")
-  asy.builder.positionAtEnd(blck)
+  asy.callFn(name, [r, a, M])
 
-  let bld = asy.builder
+proc modadd*(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, b, M: 
ValueRef) = + ## Generate an optimized modular addition kernel + ## with parameters `a, b, modulus: Limbs -> Limbs` - let r = bld.asArray(addModKernel.getParam(0), fieldTy) - let a = bld.asArray(addModKernel.getParam(1), fieldTy) - let b = bld.asArray(addModKernel.getParam(2), fieldTy) + let red = if fd.spareBits >= 1: "noo" + else: "mayo" + let name = "_modadd_" & red & "_u" & $fd.w & "x" & $fd.numWords + asy.llvmInternalFnDef(name, SectionName, asy.void_t, toTypes([r, a, b, M])): - let t = bld.makeArray(fieldTy) - let N = cm.getNumWords(field) + let (r, aa, bb, M) = llvmParams - var c: ValueRef - let zero = constInt(asy.i1_t, 0) + # Pointers are opaque in LLVM now + let a = asy.load2(fd.intBufTy, aa, "a") + let b = asy.load2(fd.intBufTy, bb, "b") - (c, t[0]) = bld.addcarry(a[0], b[0], zero) - for i in 1 ..< N: - (c, t[i]) = bld.addcarry(a[i], b[i], c) + let apb = asy.br.add(a, b, "a_plus_b") + let t = asy.makeArray(fd.fieldTy) + asy.store(t, apb) - if cm.getSpareBits(field) >= 1: - asy.finalSubNoOverflow(cm, field, t, t) - else: - asy.finalSubMayOverflow(cm, field, t, t, c) + if fd.spareBits >= 1: + asy.finalSubNoOverflow(fd, r, t.buf, M) + else: + let carry = asy.br.icmp(kUlt, apb, b, "overflow") + asy.finalSubMayOverflow(fd, r, t.buf, M, carry) - bld.store(r, t) - bld.retVoid() + asy.br.retVoid() - return (addModTy, addModKernel) + asy.callFn(name, [r, a, b, M]) diff --git a/constantine/math_compiler/ir.nim b/constantine/math_compiler/ir.nim index 23464a72..bf6abbea 100644 --- a/constantine/math_compiler/ir.nim +++ b/constantine/math_compiler/ir.nim @@ -7,12 +7,9 @@ # at your option. This file may not be copied, modified, or distributed except according to those terms. import - constantine/named/algebras, - constantine/named/deriv/precompute, - constantine/math/io/io_bigints, - constantine/platforms/[primitives, bithacks], + constantine/platforms/bithacks, constantine/platforms/llvm/llvm, - constantine/serialization/[endians, codecs, io_limbs] + std/[tables, macros] # ############################################################ # @@ -22,285 +19,170 @@ import type Assembler_LLVM* = ref object - # LLVM ctx*: ContextRef module*: ModuleRef - builder*: BuilderRef - i1_t*, i32_t*, i64_t*, i128_t*, void_t*: TypeRef + br*: BuilderRef + datalayout: TargetDataRef + psize*: int32 + publicCC: CallingConvention backend*: Backend + byteOrder: ByteOrder + + # It doesn't seem possible to retrieve a function type + # from its value, so we store them here. 
+ # If we store the type we might as well store the impl + fns: Table[string, tuple[ty: TypeRef, impl: ValueRef]] + + # Convenience + void_t*: TypeRef Backend* = enum + bkAmdGpu bkNvidiaPTX bkX86_64_Linux - FnDef* = tuple[fnTy: TypeRef, fnImpl: ValueRef] - # calling getTypeOf on a ValueRef function - # loses type information like return value type or arity - proc finalizeAssemblerLLVM(asy: Assembler_LLVM) = if not asy.isNil: - asy.builder.dispose() + asy.br.dispose() asy.module.dispose() asy.ctx.dispose() + # asy.datalayout.dispose() # unnecessary when module is cleared + +proc configure(asy: var Assembler_LLVM, backend: Backend) = + case backend + of bkAmdGpu: + asy.module.setTarget("amdgcn-amd-amdhsa") + + const datalayout1 {.used.} = + "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-" & + "i64:64-" & + "v16:16-v24:32-" & + "v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-" & + "n32:64-S32-A5-G1-ni:7" + + const datalayout2 = + "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:256:256:32-p8:128:128-" & + "i64:64-" & + "v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-" & + "n32:64-S32-A5-G1-ni:7:8" + + asy.module.setDataLayout(datalayout2) + + of bkNvidiaPTX: + asy.module.setTarget("nvptx64-nvidia-cuda") + # Datalayout for NVVM IR 1.8 (CUDA 11.6) + asy.module.setDataLayout( + "e-" & "p:64:64:64-" & + "i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-i128:128:128-" & + "f32:32:32-f64:64:64-" & + "v16:16:16-v32:32:32-v64:64:64-v128:128:128-" & + "n16:32:64") + of bkX86_64_Linux: + asy.module.setTarget("x86_64-pc-linux-gnu") + + asy.datalayout = asy.module.getDataLayout() + asy.psize = int32 asy.datalayout.getPointerSize() + asy.backend = backend + asy.byteOrder = asy.dataLayout.getEndianness() proc new*(T: type Assembler_LLVM, backend: Backend, moduleName: cstring): Assembler_LLVM = new result, finalizeAssemblerLLVM result.ctx = createContext() result.module = result.ctx.createModule(moduleName) + result.br = result.ctx.createBuilder() + result.datalayout = result.module.getDataLayout() - case backend - of bkNvidiaPTX: - result.module.setTarget("nvptx64-nvidia-cuda") - # Datalayout for NVVM IR 1.8 (CUDA 11.6) - result.module.setDataLayout("e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-i128:128:128-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64") - of bkX86_64_Linux: - {.warning : "The x86 LLVM backend is incomplete and for research purposes only".} - result.module.setTarget("x86_64-pc-linux-gnu") - - result.builder = result.ctx.createBuilder() - result.i1_t = result.ctx.int1_t() - result.i32_t = result.ctx.int32_t() - result.i64_t = result.ctx.int64_t() - result.i128_t = result.ctx.int128_t() result.void_t = result.ctx.void_t() - result.backend = backend + + result.configure(backend) # ############################################################ # -# Metadata precomputation +# Syntax Sugar # # ############################################################ -# Constantine on CPU is configured at compile-time for several properties that need to be runtime configuration GPUs: -# - word size (32-bit or 64-bit) -# - curve properties access like modulus bitsize or -1/M[0] a.k.a. m0ninv -# - constants are stored in freestanding `const` -# -# This is because it's not possible to store a BigInt[254] and a BigInt[384] -# in a generic way in the same structure, especially without using heap allocation. -# And with Nim's dead code elimination, unused curves are not compiled in. 
-# -# As there would be no easy way to dynamically retrieve (via an array or a table) -# const BLS12_381_modulus = ... -# const BN254_Snarks_modulus = ... -# -# - We would need a macro to properly access each constant. -# - We would need to create a 32-bit and a 64-bit version. -# - Unused curves would be compiled in the program. -# -# Note: on GPU we don't manipulate secrets hence branches and dynamic memory allocations are allowed. -# -# As GPU is a niche usage, instead we recreate the relevant `precompute` and IO procedures -# with dynamic wordsize support. - -type - DynWord = uint32 or uint64 - BigNum[T: DynWord] = object - bits: uint32 - limbs: seq[T] +func i1*(asy: Assembler_LLVM, v: SomeInteger): ValueRef = + constInt(asy.ctx.int1_t(), v) -# Serialization -# ------------------------------------------------ +func i32*(asy: Assembler_LLVM, v: SomeInteger): ValueRef = + constInt(asy.ctx.int32_t(), v) -func byteLen(bits: SomeInteger): SomeInteger {.inline.} = - ## Length in bytes to serialize BigNum - (bits + 7) shr 3 # (bits + 8 - 1) div 8 +# ############################################################ +# +# Intermediate Representation +# +# ############################################################ -func wordsRequired(bits, wordBitwidth: SomeInteger): SomeInteger {.inline.} = +func wordsRequired*(bits, wordBitwidth: SomeInteger): SomeInteger {.inline.} = ## Compute the number of limbs required ## from the announced bit length - debug: doAssert wordBitwidth == 32 or wordBitwidth == 64 # Power of 2 + doAssert wordBitwidth == 32 or wordBitwidth == 64 # Power of 2 (bits + wordBitwidth - 1) shr log2_vartime(uint32 wordBitwidth) # 5x to 55x faster than dividing by wordBitwidth -func fromHex[T](a: var BigNum[T], s: string) = - var bytes = newSeq[byte](a.bits.byteLen()) - bytes.paddedFromHex(s, bigEndian) - - # 2. Convert canonical uint to BigNum - const wordBitwidth = sizeof(T) * 8 - a.limbs.unmarshal(bytes, wordBitwidth, bigEndian) - -func fromHex[T](BN: type BigNum[T], bits: uint32, s: string): BN = - const wordBitwidth = sizeof(T) * 8 - let numWords = wordsRequired(bits, wordBitwidth) - - result.bits = bits - result.limbs.setLen(numWords) - result.fromHex(s) - -func toHex[T](a: BigNum[T]): string = - ## Conversion to big-endian hex - ## This is variable-time - # 1. Convert BigInt to canonical uint - const wordBitwidth = sizeof(T) * 8 - var bytes = newSeq[byte](byteLen(a.bits)) - bytes.marshal(a.limbs, wordBitwidth, bigEndian) - - # 2 Convert canonical uint to hex - return bytes.toHex() - -# Checks -# ------------------------------------------------ +type + FieldDescriptor* = object + name*: string + modulus*: string # Modulus as Big-Endian uppercase hex, NOT prefixed with 0x + # primeKind*: PrimeKind + + # Word: i32, i64 but can also be v4i32, v16i32 ... 
+ wordTy*: TypeRef + word2xTy*: TypeRef # Double the word size + v*, w*: uint32 + numWords*: uint32 + zero*, zero_i1*: ValueRef + intBufTy*: TypeRef # int type, multiple of the word size, that can store the field elements + # For example a 381 bit field is stored in 384-bit ints (whether on 32 or 64-bit platforms) + + # Field metadata + fieldTy*: TypeRef + bits*: uint32 + spareBits*: uint8 -func checkValidModulus(M: BigNum) = - const wordBitwidth = uint32(BigNum.T.sizeof() * 8) - let expectedMsb = M.bits-1 - wordBitwidth * (M.limbs.len.uint32 - 1) - let msb = log2_vartime(M.limbs[M.limbs.len-1]) +proc configureField*(ctx: ContextRef, + name: string, + modBits: int, modulus: string, + v, w: int): FieldDescriptor = + ## Configure a field descriptor with: + ## - v: vector length + ## - w: base word size in bits + ## - a `modulus` of bitsize `modBits` + ## + ## - Name is a prefix for example + ## `mycurve_fp_` - doAssert msb == expectedMsb, "Internal Error: the modulus must use all declared bits and only those:\n" & - " Modulus '" & M.toHex() & "' is declared with " & $M.bits & - " bits but uses " & $(msb + wordBitwidth * uint32(M.limbs.len - 1)) & " bits." + let v = uint32 v + let w = uint32 w + let modBits = uint32 modBits -# Fields metadata -# ------------------------------------------------ + result.name = name + result.modulus = modulus -func negInvModWord[T](M: BigNum[T]): T = - ## Returns the Montgomery domain magic constant for the input modulus: - ## - ## µ ≡ -1/M[0] (mod SecretWord) - ## - ## M[0] is the least significant limb of M - ## M must be odd and greater than 2. - ## - ## Assuming 64-bit words: - ## - ## µ ≡ -1/M[0] (mod 2^64) - checkValidModulus(M) - return M.limbs[0].negInvModWord() + doAssert v == 1, "At the moment SIMD vectorization is not supported." 
+ result.v = v + result.w = w -# ############################################################ -# -# Intermediate Representation -# -# ############################################################ + result.numWords = wordsRequired(modBits, w) + result.wordTy = ctx.int_t(w) + result.word2xTy = ctx.int_t(w+w) + result.zero = constInt(result.wordTy, 0) + result.zero_i1 = constInt(ctx.int1_t(), 0) -type - WordSize* = enum - size32 - size64 - - Field* = enum - fp - fr - - FieldConst* = object - wordTy: TypeRef - fieldTy: TypeRef - modulus*: seq[ConstValueRef] - m0ninv*: ConstValueRef - bits*: uint32 - spareBits*: uint8 + let next_multiple_wordsize = result.numWords * w + result.intBufTy = ctx.int_t(next_multiple_wordsize) - CurveMetadata* = object - curve*: Algebra - prefix*: string - wordSize*: WordSize - fp*: FieldConst - fr*: FieldConst - - Opcode* = enum - opFpAdd = "fp_add" - opFrAdd = "fr_add" - opFpSub = "fp_sub" - opFrSub = "fr_sub" - opFpMul = "fp_mul" - opFrMul = "fr_mul" - opFpMulSkipFinalSub = "fp_mul_skip_final_sub" - opFrMulSkipFinalSub = "fr_mul_skip_final_sub" - -proc setFieldConst(fc: var FieldConst, ctx: ContextRef, wordSize: WordSize, modBits: uint32, modulus: string) = - let wordTy = case wordSize - of size32: ctx.int32_t() - of size64: ctx.int64_t() - - let wordBitwidth = case wordSize - of size32: 32'u32 - of size64: 64'u32 - - let numWords = wordsRequired(modBits, wordBitwidth) - - fc.wordTy = wordTy - fc.fieldTy = array_t(wordTy, numWords) - - case wordSize - of size32: - let m = BigNum[uint32].fromHex(modBits, modulus) - fc.modulus.setlen(m.limbs.len) - for i in 0 ..< m.limbs.len: - fc.modulus[i] = ctx.int32_t().constInt(m.limbs[i]) - - fc.m0ninv = ctx.int32_t().constInt(m.negInvModWord()) - - of size64: - let m = BigNum[uint64].fromHex(modBits, modulus) - fc.modulus.setlen(m.limbs.len) - for i in 0 ..< m.limbs.len: - fc.modulus[i] = ctx.int64_t().constInt(m.limbs[i]) - - fc.m0ninv = ctx.int64_t().constInt(m.negInvModWord()) - - debug: doAssert numWords == fc.modulus.len.uint32 - fc.bits = modBits - fc.spareBits = uint8(numWords*wordBitwidth - modBits) - -proc init*( - C: type CurveMetadata, ctx: ContextRef, - prefix: string, wordSize: WordSize, - fpBits: uint32, fpMod: string, - frBits: uint32, frMod: string): CurveMetadata = - - result = C(prefix: prefix, wordSize: wordSize) - result.fp.setFieldConst(ctx, wordSize, fpBits, fpMod) - result.fr.setFieldConst(ctx, wordSize, frBits, frMod) - -proc genSymbol*(cm: CurveMetadata, opcode: Opcode): string {.inline.} = - cm.prefix & - (if cm.wordSize == size32: "32b_" else: "64b_") & - $opcode - -func getFieldType*(cm: CurveMetadata, field: Field): TypeRef {.inline.} = - if field == fp: - return cm.fp.fieldTy - else: - return cm.fr.fieldTy + result.fieldTy = array_t(result.wordTy, result.numWords) + result.bits = modBits + result.spareBits = uint8(next_multiple_wordsize - modBits) -func getWordType*(cm: CurveMetadata, field: Field): TypeRef {.inline.} = - if field == fp: - return cm.fp.wordTy - else: - return cm.fr.wordTy - -func getNumWords*(cm: CurveMetadata, field: Field): int {.inline.} = - case field - of fp: - return cm.fp.modulus.len - of fr: - return cm.fr.modulus.len - -func getModulus*(cm: CurveMetadata, field: Field): lent seq[ConstValueRef] {.inline.} = - case field - of fp: - return cm.fp.modulus - of fr: - return cm.fr.modulus - -func getMontgomeryNegInverse0*(cm: CurveMetadata, field: Field): lent ConstValueRef {.inline.} = - case field - of fp: - return cm.fp.m0ninv - of fr: - return cm.fr.m0ninv - -func 
getSpareBits*(cm: CurveMetadata, field: Field): uint8 {.inline.} =
-  if field == fp:
-    return cm.fp.sparebits
-  else:
-    return cm.fr.sparebits
+proc wordTy*(fd: FieldDescriptor, value: SomeInteger): ValueRef =
+  constInt(fd.wordTy, value)
 
 # ############################################################
 #
-#                      Syntax Sugar
+#                     Aggregate Types
 #
 # ############################################################
 
@@ -315,34 +197,34 @@ func getSpareBits*(cm: CurveMetadata, field: Field): uint8 {.inline.} =
 type
   Array* = object
     builder: BuilderRef
-    p: ValueRef
+    buf*: ValueRef
     arrayTy: TypeRef
     elemTy: TypeRef
     int32_t: TypeRef
 
-proc asArray*(builder: BuilderRef, arrayPtr: ValueRef, arrayTy: TypeRef): Array =
+proc asArray*(asy: Assembler_LLVM, arrayPtr: ValueRef, arrayTy: TypeRef): Array =
   Array(
-    builder: builder,
-    p: arrayPtr,
+    builder: asy.br,
+    buf: arrayPtr,
     arrayTy: arrayTy,
     elemTy: arrayTy.getElementType(),
    int32_t: arrayTy.getContext().int32_t()
   )
 
-proc makeArray*(builder: BuilderRef, arrayTy: TypeRef): Array =
+proc makeArray*(asy: Assembler_LLVM, arrayTy: TypeRef): Array =
   Array(
-    builder: builder,
-    p: builder.alloca(arrayTy),
+    builder: asy.br,
+    buf: asy.br.alloca(arrayTy),
     arrayTy: arrayTy,
     elemTy: arrayTy.getElementType(),
     int32_t: arrayTy.getContext().int32_t()
   )
 
-proc makeArray*(builder: BuilderRef, elemTy: TypeRef, len: uint32): Array =
+proc makeArray*(asy: Assembler_LLVM, elemTy: TypeRef, len: uint32): Array =
   let arrayTy = array_t(elemTy, len)
   Array(
-    builder: builder,
-    p: builder.alloca(arrayTy),
+    builder: asy.br,
+    buf: asy.br.alloca(arrayTy),
     arrayTy: arrayTy,
     elemTy: elemTy,
     int32_t: arrayTy.getContext().int32_t()
   )
@@ -357,6 +239,273 @@ proc `[]=`*(a: Array, index: SomeInteger, val: ValueRef) {.inline.}=
   let pelem = a.builder.getElementPtr2_InBounds(a.arrayTy, a.p, [ValueRef constInt(a.int32_t, 0), ValueRef constInt(a.int32_t, uint64 index)])
   a.builder.store(val, pelem)
 
-proc store*(builder: BuilderRef, dst: Array, src: Array) {.inline.}=
-  let v = builder.load2(src.arrayTy, src.p)
-  builder.store(v, dst.p)
+proc store*(asy: Assembler_LLVM, dst: Array, src: Array) {.inline.}=
+  let v = asy.br.load2(src.arrayTy, src.buf)
+  asy.br.store(v, dst.buf)
+
+proc store*(asy: Assembler_LLVM, dst: Array, src: ValueRef) {.inline.}=
+  ## Heterogeneous store of i256 into 4xuint64
+  doAssert asy.byteOrder == kLittleEndian
+  asy.br.store(src, dst.buf)
+
+# Conversion to native LLVM int
+# -------------------------------
+
+proc asLlvmIntPtr*(asy: Assembler_LLVM, a: Array): ValueRef =
+  doAssert asy.byteOrder == kLittleEndian, "Only little-endian machines are supported at the moment."
+  let bits = asy.datalayout.getSizeInBits(a.arrayTy)
+  let pInt = pointer_t(asy.ctx.int_t(uint32 bits))
+  asy.br.bitcast(a.buf, pInt)
+
+proc asLlvmIntPtr*(asy: Assembler_LLVM, v: ValueRef, ty: TypeRef): ValueRef =
+  doAssert asy.byteOrder == kLittleEndian, "Only little-endian machines are supported at the moment."
+  let pInt = pointer_t(ty)
+  asy.br.bitcast(v, pInt)
+
+# ############################################################
+#
+#                         Globals
+#
+# ############################################################
+
+proc loadGlobal*(asy: Assembler_LLVM, name: string): ValueRef =
+  let g = asy.module.getGlobal(cstring name)
+  doAssert not g.isNil(), "The global '" & name & "' has not been declared in the module"
+  let ty = g.getTypeOf()
+  return asy.br.load2(ty, g, name = "g")
+
+proc defineGlobalConstant*(
+       asy: Assembler_LLVM,
+       name, section: string,
+       value: ValueRef,
+       ty: TypeRef, alignment = -1): ValueRef =
+  ## Declare a global constant
+  ## name: The name of the constant
+  ## section: globals are kept near each other in memory to improve locality
+  ##          and avoid page-faults
+  ## An alignment of -1 leaves it at the default for the ISA.
+  ## Otherwise configure the alignment in bytes.
+  ##
+  ## Returns a pointer to the global
+  let g = asy.module.addGlobal(ty, cstring name)
+  g.setGlobal(value)
+  if alignment > 0:
+    g.setAlignment(cuint alignment)
+  # We intentionally keep globals internal:
+  # - for performance, this may avoid a translation table,
+  #   they also can be inlined.
+  # - for forward-compatibility, for example to expose the modulus,
+  #   a function can handle a non-matching internal representation,
+  #   for example if we want a different endianness of words on big-endian machines.
+  # g.setLinkage(linkInternal)
+  g.setImmutable()
+
+  # Group related globals in the same section
+  # This doesn't prevent globals from being optimized away
+  # if they are fully inlined or unused.
+  # This has the following benefits:
+  # - They might all be loaded in memory if they share a cacheline
+  # - If a section is unused, it can be garbage collected by the linker
+  g.setSection(cstring("ctt." & section & ".constants"))
+  return g
+
+# ############################################################
+#
+#                    ISA configuration
+#
+# ############################################################
+
+proc tagCudaKernel(asy: Assembler_LLVM, fn: ValueRef) =
+  ## Tag a function as a Cuda Kernel, i.e. callable from host
+
+  let returnTy = fn.getTypeOf().getReturnType()
+  doAssert returnTy.isVoid(), block:
+    "Kernels must not return values but function returns " & $returnTy.getTypeKind()
+
+  asy.module.addNamedMetadataOperand(
+    "nvvm.annotations",
+    asy.ctx.asValueRef(asy.ctx.metadataNode([
+      fn.asMetadataRef(),
+      asy.ctx.metadataNode("kernel"),
+      asy.i32(1).asMetadataRef()
+    ]))
+  )
+
+proc setPublic(asy: Assembler_LLVM, fn: ValueRef) =
+  case asy.backend
+  of bkAmdGpu: fn.setCallingConvention(AMDGPU_KERNEL)
+  of bkNvidiaPtx: asy.tagCudaKernel(fn)
+  else: discard
+
+# ############################################################
+#
+#        Function Definition and calling convention
+#
+# ############################################################
+
+# Most architectures can pass up to 4 or 6 arguments directly in registers,
+# and we allow LLVM to use the best calling convention possible with "Fast".
+#
+# Recommendation:
+#   https://llvm.org/docs/Frontend/PerformanceTips.html
+#
+# Avoid creating values of aggregate types (i.e. structs and arrays).
+# In particular, avoid loading and storing them,
+# or manipulating them with insertvalue and extractvalue instructions.
+# Instead, only load and store individual fields of the aggregate.
+#
+# There are some exceptions to this rule:
+# - It is fine to use values of aggregate type in global variable initializers.
+# - It is fine to return structs, if this is done to represent the return of multiple values in registers.
+# - It is fine to work with structs returned by LLVM intrinsics, such as the with.overflow family of intrinsics.
+# - It is fine to use aggregate types without creating values. For example, they are commonly used in getelementptr instructions or attributes like sret.
+#
+# Furthermore for aggregate types like structs we need to check the number of elements
+# - https://groups.google.com/g/llvm-dev/c/CafdpEzOEp0
+# - https://stackoverflow.com/questions/27386912/prevent-clang-from-expanding-arguments-that-are-aggregate-types
+# - https://people.freebsd.org/~obrien/amd64-elf-abi.pdf
+# Though that might be overkill for functions tagged 'internal' linkage and 'Fast' CC
+#
+# Hopefully the compiler will remove the unnecessary load/store/register movement, especially when inlining.

+proc toTypes*[N: static int](v: array[N, ValueRef]): array[N, TypeRef] =
+  for i in 0 ..< v.len:
+    result[i] = v[i].getTypeOf()
+
+proc wrapTypesForFnCall[N: static int](
+       asy: Assembler_LLVM,
+       paramTypes: array[N, TypeRef]
+     ): tuple[wrapped, src: array[N, TypeRef]] =
+  ## Wrap parameters that would need more than 3x registers
+  ## into a pointer.
+  ## There are 2 such cases:
+  ## - An array/struct of 3 or more elements, for example 4x uint32
+  ## - A type larger than 3x the pointer size, for example 4x uint64
+  ## Vectors are passed by special SIMD registers
+  ##
+  ## Due to LLVM opaque pointers, we return the wrapped and src types
+
+  for i in 0 ..< paramTypes.len:
+    let ty = paramTypes[i]
+    let tk = ty.getTypeKind()
+    if tk in {tkVector, tkScalableVector}:
+      # There are special SIMD registers for vectors
+      result.wrapped[i] = paramTypes[i]
+      result.src[i] = paramTypes[i]
+    elif asy.datalayout.getAbiSize(ty).int32 > 3*asy.psize:
+      result.wrapped[i] = pointer_t(paramTypes[i])
+      result.src[i] = paramTypes[i]
+    else:
+      case tk
+      of tkArray:
+        if ty.getArrayLength() >= 3:
+          result.wrapped[i] = pointer_t(paramTypes[i])
+          result.src[i] = paramTypes[i]
+        else:
+          result.wrapped[i] = paramTypes[i]
+          result.src[i] = paramTypes[i]
+      of tkStruct:
+        if ty.getNumElements() >= 3:
+          result.wrapped[i] = pointer_t(paramTypes[i])
+          result.src[i] = paramTypes[i]
+        else:
+          result.wrapped[i] = paramTypes[i]
+          result.src[i] = paramTypes[i]
+      else:
+        result.wrapped[i] = paramTypes[i]
+        result.src[i] = paramTypes[i]
+
+macro unpackParams[N: static int](
+       br: BuilderRef,
+       paramsTys: tuple[wrapped, src: array[N, TypeRef]]): untyped =
+  ## Unpack function parameters.
+  ##
+  ## The new function basic block MUST be set up before calling unpackParams.
+  ##
+  ## In the future we may automatically unwrap types.
+
+  result = nnkPar.newTree()
+  for i in 0 ..< N:
+    result.add quote do:
+      # let tySrc = `paramsTys`.src[`i`]
+      # let tyCC = `paramsTys`.wrapped[`i`]
+      let fn = `br`.getCurrentFunction()
+      fn.getParam(uint32 `i`)
+
+template llvmFnDef[N: static int](
+      asy: Assembler_LLVM,
+      name, sectionName: string,
+      returnType: TypeRef,
+      paramTypes: array[N, TypeRef],
+      internal: bool,
+      body: untyped) =
+  ## This sets up the common prologue to implement a function in LLVM.
+  ## Function parameters are available with the `llvmParams` magic variable.
+
+  let paramsTys = asy.wrapTypesForFnCall(paramTypes)
+
+  var fn = asy.module.getFunction(cstring name)
+  if fn.pointer.isNil():
+    var savedLoc = asy.br.getInsertBlock()
+
+    let fnTy = function_t(returnType, paramsTys.wrapped)
+    fn = asy.module.addFunction(cstring name, fnTy)
+
+    asy.fns[name] = (fnTy, fn)
+
+    let blck = asy.ctx.appendBasicBlock(fn)
+    asy.br.positionAtEnd(blck)
+
+    if savedLoc.pointer.isNil():
+      # We're installing the first function
+      # of the call tree, return to its basic block
+      savedLoc = blck
+
+    let llvmParams {.inject.} = unpackParams(asy.br, paramsTys)
+    body
+
+    if internal:
+      fn.setCallingConvention(Fast)
+      fn.setLinkage(linkInternal)
+    else:
+      asy.setPublic(fn)
+    fn.setSection(sectionName)
+
+    asy.br.positionAtEnd(savedLoc)
+
+template llvmInternalFnDef*[N: static int](
+      asy: Assembler_LLVM,
+      name, sectionName: string,
+      returnType: TypeRef,
+      paramTypes: array[N, TypeRef],
+      body: untyped) =
+  llvmFnDef(asy, name, sectionName, returnType, paramTypes, internal = true, body)
+
+template llvmPublicFnDef*[N: static int](
+      asy: Assembler_LLVM,
+      name, sectionName: string,
+      returnType: TypeRef,
+      paramTypes: array[N, TypeRef],
+      body: untyped) =
+  llvmFnDef(asy, name, sectionName, returnType, paramTypes, internal = false, body)
+
+proc callFn*(
+      asy: Assembler_LLVM,
+      name: string,
+      params: openArray[ValueRef]): ValueRef {.discardable.} =
+
+  if asy.fns[name].ty.getReturnType().getTypeKind() == tkVoid:
+    asy.br.call2(asy.fns[name].ty, asy.fns[name].impl, params)
+  else:
+    asy.br.call2(asy.fns[name].ty, asy.fns[name].impl, params, cstring(name))
+
+# ############################################################
+#
+#                  Forward to Builder
+#
+# ############################################################
+
+# {.experimental: "dotOperators".} does not seem to work within templates/macros
+
+template load2*(asy: Assembler_LLVM, ty: TypeRef, `ptr`: ValueRef, name: cstring = ""): ValueRef =
+  asy.br.load2(ty, `ptr`, name)
diff --git a/constantine/math_compiler/pub_fields.nim b/constantine/math_compiler/pub_fields.nim
new file mode 100644
index 00000000..dc8c7126
--- /dev/null
+++ b/constantine/math_compiler/pub_fields.nim
@@ -0,0 +1,30 @@
+# Constantine
+# Copyright (c) 2018-2019 Status Research & Development GmbH
+# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
+# Licensed and distributed under either of
+#   * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
+#   * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
+# at your option. This file may not be copied, modified, or distributed except according to those terms.
+
+import
+  constantine/platforms/llvm/llvm,
+  ./ir,
+  ./impl_fields_globals,
+  ./impl_fields_sat
+
+proc genFpAdd*(asy: Assembler_LLVM, fd: FieldDescriptor): string =
+  ## Generate a public field addition proc
+  ## with signature
+  ##   void name(FieldType r, FieldType a, FieldType b)
+  ## with r the result and a, b the operands,
+  ## and return the corresponding name to call it
+
+  let name = fd.name & "_add"
+  asy.llvmPublicFnDef(name, fd.name, asy.void_t, [fd.fieldTy, fd.fieldTy, fd.fieldTy]):
+    let M = asy.getModulusPtr(fd)
+
+    let (r, a, b) = llvmParams
+    asy.modadd(fd, r, a, b, M)
+    asy.br.retVoid()
+
+  return name
diff --git a/constantine/platforms/abis/llvm_abi.nim b/constantine/platforms/abis/llvm_abi.nim
index 974e29c1..3a0da543 100644
--- a/constantine/platforms/abis/llvm_abi.nim
+++ b/constantine/platforms/abis/llvm_abi.nim
@@ -38,6 +38,7 @@ type
   ContextRef* = distinct pointer
   ModuleRef* = distinct pointer
   TargetRef* = distinct pointer
+  TargetDataRef* = distinct pointer
   ExecutionEngineRef* = distinct pointer
   TargetMachineRef* = distinct pointer
   PassBuilderOptionsRef* = distinct pointer
@@ -186,19 +187,33 @@ type
   CodeGenFileType* {.size: sizeof(cint).} = enum
     AssemblyFile, ObjectFile
 
-  TargetDataRef* = distinct pointer
-  TargetLibraryInfoRef* = distinct pointer
-# ""
 proc createTargetMachine*(
   target: TargetRef, triple, cpu, features: cstring,
   level: CodeGenOptLevel, reloc: RelocMode, codeModel: CodeModel): TargetMachineRef {.importc: "LLVMCreateTargetMachine".}
 proc dispose*(m: TargetMachineRef) {.importc: "LLVMDisposeTargetMachine".}
 
-proc createTargetDataLayout*(t: TargetMachineRef): TargetDataRef {.importc: "LLVMCreateTargetDataLayout".}
+proc getDataLayout*(t: TargetMachineRef): TargetDataRef {.importc: "LLVMCreateTargetDataLayout".}
+proc getDataLayout*(module: ModuleRef): TargetDataRef {.importc: "LLVMGetModuleDataLayout".}
 proc dispose*(m: TargetDataRef) {.importc: "LLVMDisposeTargetData".}
 proc setDataLayout*(module: ModuleRef, dataLayout: TargetDataRef) {.importc: "LLVMSetModuleDataLayout".}
+proc getPointerSize*(datalayout: TargetDataRef): cuint {.importc: "LLVMPointerSize".}
+proc getSizeInBits*(datalayout: TargetDataRef, ty: TypeRef): culonglong {.importc: "LLVMSizeOfTypeInBits".}
+  ## Computes the size of a type in bits for a target.
+proc getStoreSize*(datalayout: TargetDataRef, ty: TypeRef): culonglong {.importc: "LLVMStoreSizeOfType".}
+  ## Computes the storage size of a type in bytes for a target.
+proc getAbiSize*(datalayout: TargetDataRef, ty: TypeRef): culonglong {.importc: "LLVMABISizeOfType".}
+  ## Computes the ABI size of a type in bytes for a target.
+
+type
+  ByteOrder {.size: sizeof(cint).} = enum
+    kBigEndian
+    kLittleEndian
+
+proc getEndianness*(datalayout: TargetDataRef): ByteOrder {.importc: "LLVMByteOrder".}
+
 
 proc targetMachineEmitToFile*(t: TargetMachineRef, m: ModuleRef, fileName: cstring, codegen: CodeGenFileType, errorMessage: var LLVMstring): LLVMBool {.importc: "LLVMTargetMachineEmitToFile".}
 proc targetMachineEmitToMemoryBuffer*(t: TargetMachineRef, m: ModuleRef,
@@ -282,10 +297,14 @@ proc struct_t*(
   elemTypes: openArray[TypeRef],
   packed: LlvmBool): TypeRef {.wrapOpenArrayLenType: cuint, importc: "LLVMStructTypeInContext".}
 proc array_t*(elemType: TypeRef, elemCount: uint32): TypeRef {.importc: "LLVMArrayType".}
+proc vector_t*(elemType: TypeRef, elemCount: uint32): TypeRef {.importc: "LLVMVectorType".}
+  ## Create a SIMD vector type (for SSE, AVX or Neon for example)
 proc pointerType(elementType: TypeRef; addressSpace: cuint): TypeRef {.used, importc: "LLVMPointerType".}
 
 proc getElementType*(arrayOrVectorTy: TypeRef): TypeRef {.importc: "LLVMGetElementType".}
+proc getArrayLength*(arrayTy: TypeRef): uint64 {.importc: "LLVMGetArrayLength2".}
+proc getNumElements*(structTy: TypeRef): cuint {.importc: "LLVMCountStructElementTypes".}
 
 # Functions
 # ------------------------------------------------------------
@@ -537,6 +556,44 @@ type
     # The highest possible ID. Must be some 2^k - 1.
     MaxID = 1023
 
+type
+  Linkage {.size: sizeof(cint).} = enum
+    # https://web.archive.org/web/20240224034505/https://bluesadi.me/2024/01/05/Linkage-types-in-LLVM/
+    # Weak linkage means unreferenced globals may not be discarded when linking.
+    #
+    # Also relevant: https://stackoverflow.com/a/55599037
+    #   The necessity of making code relocatable in order to allow shared objects to be loaded at different addresses
+    #   in different processes means that statically allocated variables,
+    #   whether they have global or local scope,
+    #   can't be accessed directly with a single instruction on most architectures.
+    #   The only exception I know of is the 64-bit x86 architecture, as you see above.
+    #   It supports memory operands that are both PC-relative and have large 32-bit displacements
+    #   that can reach any variable defined in the same component.
+    linkExternal,            ## Externally visible function
+    linkAvailableExternally, ## no description
+    linkOnceAny,             ## Keep one copy of function when linking (inline)
+    linkOnceODR,             ## Same, but only replaced by something equivalent. (ODR: one definition rule)
+    linkOnceODRAutoHide,     ## Obsolete
+    linkWeakAny,             ## Keep one copy of function when linking (weak)
+    linkWeakODR,             ## Same, but only replaced by something equivalent.
+    linkAppending,           ## Special purpose, only applies to global arrays
+    linkInternal,            ## Rename collisions when linking (static functions)
+    linkPrivate,             ## Like Internal, but omit from symbol table
+    linkDLLImport,           ## Obsolete
+    linkDLLExport,           ## Obsolete
+    linkExternalWeak,        ## ExternalWeak linkage description
+    linkGhost,               ## Obsolete
+    linkCommon,              ## Tentative definitions
+    linkLinkerPrivate,       ## Like Private, but linker removes.
+    linkLinkerPrivateWeak    ## Like LinkerPrivate, but is weak.
+
+  Visibility {.size: sizeof(cint).} = enum
+    # Note: Function with internal or private linkage must have default visibility
+    visDefault
+    visHidden
+    visProtected
+
+
 proc function_t*(
   returnType: TypeRef,
   paramTypes: openArray[TypeRef],
@@ -546,9 +603,15 @@ proc addFunction*(m: ModuleRef, name: cstring, ty: TypeRef): ValueRef {.importc: "LLVMAddFunction".}
   ## Declare a function `name` in a module.
  ## Returns a handle to specify its instructions
 
+proc getFunction*(m: ModuleRef, name: cstring): ValueRef {.importc: "LLVMGetNamedFunction".}
+  ## Get a function by name from the current module.
+  ## Returns nil if not found.
+
 proc getReturnType*(functionTy: TypeRef): TypeRef {.importc: "LLVMGetReturnType".}
 proc countParamTypes*(functionTy: TypeRef): uint32 {.importc: "LLVMCountParamTypes".}
 
+proc getCalledFunctionType*(fn: ValueRef): TypeRef {.importc: "LLVMGetCalledFunctionType".}
+
 proc getCallingConvention*(function: ValueRef): CallingConvention {.importc: "LLVMGetFunctionCallConv".}
 proc setCallingConvention*(function: ValueRef, cc: CallingConvention) {.importc: "LLVMSetFunctionCallConv".}
 
@@ -560,6 +623,16 @@ proc setCallingConvention*(function: ValueRef, cc: CallingConvention) {.importc:
 
 # {.push header: "".}
 
+proc getGlobal*(module: ModuleRef, name: cstring): ValueRef {.importc: "LLVMGetNamedGlobal".}
+proc addGlobal*(module: ModuleRef, ty: TypeRef, name: cstring): ValueRef {.importc: "LLVMAddGlobal".}
+proc setGlobal*(globalVar: ValueRef, constantVal: ValueRef) {.importc: "LLVMSetInitializer".}
+proc setImmutable*(globalVar: ValueRef, immutable = LlvmBool(true)) {.importc: "LLVMSetGlobalConstant".}
+
+proc setLinkage*(global: ValueRef, linkage: Linkage) {.importc: "LLVMSetLinkage".}
+proc setVisibility*(global: ValueRef, vis: Visibility) {.importc: "LLVMSetVisibility".}
+proc setAlignment*(v: ValueRef, bytes: cuint) {.importc: "LLVMSetAlignment".}
+proc setSection*(global: ValueRef, section: cstring) {.importc: "LLVMSetSection".}
+
 proc getTypeOf*(v: ValueRef): TypeRef {.importc: "LLVMTypeOf".}
 proc getValueName2(v: ValueRef, rLen: var csize_t): cstring {.used, importc: "LLVMGetValueName2".}
   ## Returns the name of a valeu if it exists.
@@ -578,6 +651,9 @@ proc toLLVMstring(v: ValueRef): LLVMstring {.used, importc: "LLVMPrintValueToStr
 
 # https://llvm.org/doxygen/group__LLVMCCoreValueConstant.html
 proc constInt(ty: TypeRef, n: culonglong, signExtend: LlvmBool): ValueRef {.used, importc: "LLVMConstInt".}
+proc constIntOfArbitraryPrecision(ty: TypeRef, numWords: cuint, words: ptr uint64): ValueRef {.used, importc: "LLVMConstIntOfArbitraryPrecision".}
+proc constIntOfStringAndSize*(ty: TypeRef, text: openArray[char], radix: uint8): ValueRef {.importc: "LLVMConstIntOfStringAndSize".}
+
 proc constReal*(ty: TypeRef, n: cdouble): ValueRef {.importc: "LLVMConstReal".}
 
 proc constNull*(ty: TypeRef): ValueRef {.importc: "LLVMConstNull".}
@@ -622,7 +698,7 @@ type
 
 # Instantiation
 # ------------------------------------------------------------
 
-proc appendBasicBlock*(ctx: ContextRef, fn: ValueRef, name: cstring): BasicBlockRef {.importc: "LLVMAppendBasicBlockInContext".}
+proc appendBasicBlock*(ctx: ContextRef, fn: ValueRef, name: cstring = ""): BasicBlockRef {.importc: "LLVMAppendBasicBlockInContext".}
   ## Append a basic block to the end of a function
 
 proc createBuilder*(ctx: ContextRef): BuilderRef {.importc: "LLVMCreateBuilderInContext".}
@@ -716,7 +792,7 @@ proc select*(builder: BuilderRef, condition, then, otherwise: ValueRef, name: cs
 
 proc icmp*(builder: BuilderRef, op: Predicate, lhs, rhs: ValueRef, name: cstring = ""): ValueRef {.importc: "LLVMBuildICmp".}
 
-proc bitcast*(builder: BuilderRef, val: ValueRef, destTy: TypeRef, name: cstring = ""): ValueRef {.importc: "LLVMBuildBitcast".}
+proc bitcast*(builder: BuilderRef, val: ValueRef, destTy: TypeRef, name: cstring = ""): ValueRef {.importc: "LLVMBuildBitCast".}
 proc trunc*(builder: BuilderRef, val: ValueRef, destTy: TypeRef, name: 
cstring = ""): ValueRef {.importc: "LLVMBuildTrunc".} proc zext*(builder: BuilderRef, val: ValueRef, destTy: TypeRef, name: cstring = ""): ValueRef {.importc: "LLVMBuildZExt".} ## Zero-extend diff --git a/constantine/platforms/llvm/llvm.nim b/constantine/platforms/llvm/llvm.nim index 38addee5..838b229c 100644 --- a/constantine/platforms/llvm/llvm.nim +++ b/constantine/platforms/llvm/llvm.nim @@ -146,11 +146,14 @@ proc emitTo*[T: string or seq[byte]](t: TargetMachineRef, m: ModuleRef, codegen: # Builder # ------------------------------------------------------------ +proc getCurrentFunction*(builder: BuilderRef): ValueRef = + builder.getInsertBlock().getBasicBlockParent() + proc getContext*(builder: BuilderRef): ContextRef = # LLVM C API does not expose IRBuilder.getContext() # making this unnecessary painful # https://github.com/llvm/llvm-project/issues/59875 - builder.getInsertBlock().getBasicBlockParent().getTypeOf().getContext() + builder.getCurrentFunction().getTypeOf().getContext() # Types # ------------------------------------------------------------ @@ -175,15 +178,7 @@ proc function_t*(returnType: TypeRef, paramTypes: openArray[TypeRef]): TypeRef { # Values # ------------------------------------------------------------ -# TODO: remove ConstValueRef -# - This is used in `selConstraint` in asm_nvidia -# to choose the `n` literal constraint. -# Instead of inlining as literal and hurting instruction decoding -# with large 8 bytes value, we load from const memory. - -type - ConstValueRef* = distinct ValueRef - AnyValueRef* = ValueRef or ConstValueRef +proc isNil*(v: ValueRef): bool {.borrow.} proc getName*(v: ValueRef): string = var rLen: csize_t @@ -192,34 +187,6 @@ proc getName*(v: ValueRef): string = result = newString(rLen.int) copyMem(result[0].addr, rStr, rLen.int) -proc constInt*(ty: TypeRef, n: SomeInteger, signExtend = false): ConstValueRef {.inline.} = - ConstValueRef constInt(ty, culonglong(n), LlvmBool(signExtend)) - -proc getTypeOf*(v: ConstValueRef): TypeRef {.borrow.} -proc zext*(builder: BuilderRef, val: ConstValueRef, destTy: TypeRef, name: cstring = ""): ValueRef {.inline.} = - ## Zero-extend - builder.zext(ValueRef val, destTy, name) -proc sext*(builder: BuilderRef, val: ConstValueRef, destTy: TypeRef, name: cstring = ""): ValueRef {.inline.} = - ## Sign-extend - builder.sext(ValueRef val, destTy, name) - -proc add*(builder: BuilderRef, lhs, rhs: distinct AnyValueRef, name: cstring = ""): ValueRef {.inline.} = - builder.add(ValueRef lhs, ValueRef rhs, name) -proc addNSW*(builder: BuilderRef, lhs, rhs: distinct AnyValueRef, name: cstring = ""): ValueRef {.inline.} = - ## Addition No Signed Wrap, i.e. guaranteed to not overflow - builder.addNSW(ValueRef lhs, ValueRef rhs, name) -proc addNUW*(builder: BuilderRef, lhs, rhs: distinct AnyValueRef, name: cstring = ""): ValueRef {.inline.} = - ## Addition No Unsigned Wrap, i.e. guaranteed to not overflow - builder.addNUW(ValueRef lhs, ValueRef rhs, name) - -proc sub*(builder: BuilderRef, lhs, rhs: distinct AnyValueRef, name: cstring = ""): ValueRef {.inline.} = - builder.sub(ValueRef lhs, ValueRef rhs, name) -proc subNSW*(builder: BuilderRef, lhs, rhs: distinct AnyValueRef, name: cstring = ""): ValueRef {.inline.} = - ## Substraction No Signed Wrap, i.e. guaranteed to not overflow - builder.subNSW(ValueRef lhs, ValueRef rhs, name) -proc subNUW*(builder: BuilderRef, lhs, rhs: distinct AnyValueRef, name: cstring = ""): ValueRef {.inline.} = - ## Substraction No Unsigned Wrap, i.e. 
diff --git a/constantine/platforms/llvm/super_instructions.nim b/constantine/platforms/llvm/super_instructions.nim
index f090785e..7cf1e7a7 100644
--- a/constantine/platforms/llvm/super_instructions.nim
+++ b/constantine/platforms/llvm/super_instructions.nim
@@ -39,7 +39,7 @@ import ./llvm
 # - https://github.com/llvm/llvm-project/blob/llvmorg-18.1.8/llvm/test/CodeGen/X86/mulx32.ll
 # - https://github.com/llvm/llvm-project/blob/llvmorg-18.1.8/llvm/test/CodeGen/X86/mulx64.ll

-# Warning:
+# Warning 1:
 #
 # There is no guarantee of constant-time with LLVM IR
 # It MAY introduce branches.
@@ -53,6 +53,16 @@ import ./llvm
 # - https://github.com/mratsim/constantine/wiki/Constant-time-arithmetics#fighting-the-compiler
 # - https://blog.cr.yp.to/20240803-clang.html
 # - https://www.cl.cam.ac.uk/~rja14/Papers/whatyouc.pdf
+#
+# Warning 2:
+#
+# Unfortunately, implementing unrolled bigint arithmetic with word-sized operations
+# is fraught with perils for add-carry / sub-borrow:
+# AMDGPU crash: https://github.com/llvm/llvm-project/issues/102058
+# ARM64 missed optim: https://github.com/llvm/llvm-project/issues/102062
+#
+# and while @llvm.usub.with.overflow.i64 recovers the missed optimization on ARM64,
+# it remains missed on AMDGPU (and Nvidia).

 proc hi(bld: BuilderRef, val: ValueRef, baseTy: TypeRef, oversize: uint32, prefix: string): ValueRef =
   let ctx = bld.getContext()
@@ -69,11 +79,9 @@ proc hi(bld: BuilderRef, val: ValueRef, baseTy: TypeRef, oversize: uint32, prefi

   return hi

-proc addcarry*(bld: BuilderRef, a, b, carryIn: distinct AnyValueRef): tuple[carryOut, r: ValueRef] =
+proc addcarry*(bld: BuilderRef, a, b, carryIn: ValueRef): tuple[carryOut, r: ValueRef] =
   ## (cOut, result) <- a+b+cIn
-  let ctx = bld.getContext()
   let ty = a.getTypeOf()
-  let bits = ty.getIntTypeWidth()

   let add = bld.add(a, b, name = "adc01_")
   let carry0 = bld.icmp(kULT, add, b, name = "adc01c_")
@@ -84,11 +92,9 @@ proc addcarry*(bld: BuilderRef, a, b, carryIn: distinct AnyValueRef): tuple[carr

   return (carryOut, adc)

-proc subborrow*(bld: BuilderRef, a, b, borrowIn: distinct AnyValueRef): tuple[borrowOut, r: ValueRef] =
+proc subborrow*(bld: BuilderRef, a, b, borrowIn: ValueRef): tuple[borrowOut, r: ValueRef] =
   ## (bOut, result) <- a-b-bIn
-  let ctx = bld.getContext()
   let ty = a.getTypeOf()
-  let bits = ty.getIntTypeWidth()

   let sub = bld.sub(a, b, name = "sbb01_")
   let borrow0 = bld.icmp(kULT, a, b, name = "sbb01b_")
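These two primitives are the building blocks for the unrolled bigint operations the warnings above discuss. A sketch of a 4-limb addition in the spirit of `_big_add_u64x4`, purely illustrative (`bigAdd4` is hypothetical, and the chaining assumes `carryOut` comes back at the limb width; if it is an i1, a `zext` is needed between limbs):

    proc bigAdd4(bld: BuilderRef, a, b: array[4, ValueRef], zero: ValueRef): array[4, ValueRef] =
      ## r <- a + b, rippling the carry from limb 0 upward, final carry discarded
      var carry = zero                  # limb 0 has no carry-in
      for i in 0 ..< 4:
        let (cOut, r) = bld.addcarry(a[i], b[i], carry)
        result[i] = r
        carry = cOut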
diff --git a/research/codegen/x86_poc.nim b/research/codegen/x86_poc.nim
index a677158e..b7949b8c 100644
--- a/research/codegen/x86_poc.nim
+++ b/research/codegen/x86_poc.nim
@@ -7,35 +7,46 @@
 # at your option. This file may not be copied, modified, or distributed except according to those terms.

 import
-  constantine/named/algebras,
-  constantine/math/io/io_bigints,
-  constantine/platforms/llvm/llvm,
-  constantine/platforms/primitives,
-  constantine/math_compiler/[ir, impl_fields_sat]
-
-proc init(T: type CurveMetadata, asy: Assembler_LLVM, curve: static Algebra, wordSize: WordSize): T =
-  CurveMetadata.init(
-    asy.ctx,
-    $curve & "_", wordSize,
-    fpBits = uint32 Fp[curve].bits(),
-    fpMod = Fp[curve].getModulus().toHex(),
-    frBits = uint32 Fr[curve].bits(),
-    frMod = Fr[curve].getModulus().toHex())
-
-proc genFieldAddSat(asy: Assembler_LLVM, cm: CurveMetadata) =
-  let fpAdd = asy.field_add_gen_sat(cm, fp)
-  let frAdd = asy.field_add_gen_sat(cm, fr)
-
-
-proc t_field_add(curve: static Algebra) =
+  constantine/math_compiler/[ir, pub_fields]
+
+const Fields = [
+  (
+    "bn254_snarks_fp", 254,
+    "30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47"
+  ),
+  (
+    "bn254_snarks_fr", 254,
+    "30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001"
+  ),
+
+  (
+    "secp256k1_fp", 256,
+    "fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f"
+  ),
+  (
+    "secp256k1_fr", 256,
+    "fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141"
+  ),
+  (
+    "bls12_381_fp", 381,
+    "1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaab"
+  ),
+  (
+    "bls12_381_fr", 255,
+    "73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001"
+  ),
+]
+
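Each `Fields` entry is (external name, field bit width, modulus as big-endian hex without a 0x prefix), consumed by `configureField` in the driver below; `v` and `w` are presumably the vector width and word size in bits, going by the `_u64x4` / `_u64x4v4` naming scheme. Extending the sweep is one more tuple, e.g. a hypothetical Goldilocks entry with p = 2^64 - 2^32 + 1:

    (
      "goldilocks_fp", 64,
      "ffffffff00000001"
    ),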
+proc t_field_add() =
   let asy = Assembler_LLVM.new(bkX86_64_Linux, cstring("x86_poc"))
-  let cm32 = CurveMetadata.init(asy, curve, size32)
-  asy.genFieldAddSat(cm32)
-  let cm64 = CurveMetadata.init(asy, curve, size64)
-  asy.genFieldAddSat(cm64)
+  for F in Fields:
+    let fd = asy.ctx.configureField(
+      F[0], F[1], F[2],
+      v = 1, w = 64)

-  asy.module.verify(AbortProcessAction)
+    discard asy.genFpAdd(fd)

   echo "========================================="
   echo "LLVM IR\n"
@@ -43,29 +54,11 @@
   echo asy.module
   echo "========================================="

-  var engine: ExecutionEngineRef
-  initializeFullNativeTarget()
-  createJITCompilerForModule(engine, asy.module, optLevel = 3)
-
-  let fn32 = cm32.genSymbol(opFpAdd)
-  let fn64 = cm64.genSymbol(opFpAdd)
-
-  let jitFpAdd64 = cast[proc(r: var array[4, uint64], a, b: array[4, uint64]){.noconv.}](
-    engine.getFunctionAddress(cstring fn64)
-  )
-
-  var r: array[4, uint64]
-  r.jitFpAdd64([uint64 1, 2, 3, 4], [uint64 1, 1, 1, 1])
-  echo "jitFpAdd64 = ", r
-
-  # block:
-  #   Cleanup - Assembler_LLVM is auto-managed
-  #   engine.dispose() # also destroys the module attached to it, which double_frees Assembler_LLVM asy.module
-  echo "LLVM JIT - calling FpAdd64 SUCCESS"
+  asy.module.verify(AbortProcessAction)

   # --------------------------------------------
   # See the assembly - note it might be different from what the JIT compiler did
-
+  initializeFullNativeTarget()
   const triple = "x86_64-pc-linux-gnu"

   let machine = createTargetMachine(
@@ -99,4 +92,24 @@
   echo machine.emitTo[:string](asy.module, AssemblyFile)
   echo "========================================="

-t_field_add(Secp256k1)
+  # var engine: ExecutionEngineRef
+  # initializeFullNativeTarget()
+  # createJITCompilerForModule(engine, asy.module, optLevel = 3)
+
+  # let fn32 = cm32.genSymbol(opFpAdd)
+  # let fn64 = cm64.genSymbol(opFpAdd)
+
+  # let jitFpAdd64 = cast[proc(r: var array[4, uint64], a, b: array[4, uint64]){.noconv.}](
+  #   engine.getFunctionAddress(cstring fn64)
+  # )
+
+  # var r: array[4, uint64]
+  # r.jitFpAdd64([uint64 1, 2, 3, 4], [uint64 1, 1, 1, 1])
+  # echo "jitFpAdd64 = ", r
+
+  # # block:
+  # #   Cleanup - Assembler_LLVM is auto-managed
+  # #   engine.dispose() # also destroys the module attached to it, which double_frees Assembler_LLVM asy.module
+  # echo "LLVM JIT - calling FpAdd64 SUCCESS"
+
+t_field_add()
diff --git a/tests/gpu/t_nvidia_fp.nim b/tests/gpu/t_nvidia_fp.nim
index b3aa873f..af4de007 100644
--- a/tests/gpu/t_nvidia_fp.nim
+++ b/tests/gpu/t_nvidia_fp.nim
@@ -70,9 +70,9 @@ proc t_field_add(curve: static Algebra) =
   # Codegen
   # -------------------------
   let asy = Assembler_LLVM.new(bkNvidiaPTX, cstring("t_nvidia_" & $curve))
-  let cm32 = CurveMetadata.init(asy, curve, size32)
+  let cm32 = CurveMetadata.init(asy, curve, w32)
   asy.genFieldAddPTX(cm32)
-  let cm64 = CurveMetadata.init(asy, curve, size64)
+  let cm64 = CurveMetadata.init(asy, curve, w64)
   asy.genFieldAddPTX(cm64)

   let ptx = asy.codegenNvidiaPTX(sm)
@@ -124,9 +124,9 @@ proc t_field_sub(curve: static Algebra) =
   # Codegen
   # -------------------------
   let asy = Assembler_LLVM.new(bkNvidiaPTX, cstring("t_nvidia_" & $curve))
-  let cm32 = CurveMetadata.init(asy, curve, size32)
+  let cm32 = CurveMetadata.init(asy, curve, w32)
   asy.genFieldSubPTX(cm32)
-  let cm64 = CurveMetadata.init(asy, curve, size64)
+  let cm64 = CurveMetadata.init(asy, curve, w64)
   asy.genFieldSubPTX(cm64)

   let ptx = asy.codegenNvidiaPTX(sm)
@@ -178,14 +178,14 @@ proc t_field_mul(curve: static Algebra) =
   # Codegen
   # -------------------------
   let asy = Assembler_LLVM.new(bkNvidiaPTX, cstring("t_nvidia_" & $curve))
-  let cm32 = CurveMetadata.init(asy, curve, size32)
+  let cm32 = CurveMetadata.init(asy, curve, w32)
   asy.genFieldMulPTX(cm32)

   # 64-bit integer fused-multiply-add with carry is buggy:
   # https://gist.github.com/mratsim/a34df1e091925df15c13208df7eda569#file-mul-py
   # https://forums.developer.nvidia.com/t/incorrect-result-of-ptx-code/221067
-  # let cm64 = CurveMetadata.init(asy, curve, size64)
+  # let cm64 = CurveMetadata.init(asy, curve, w64)
   # asy.genFieldMulPTX(cm64)

   let ptx = asy.codegenNvidiaPTX(sm)