From 569e02933166915ec76d17640a21a8a591e36f8f Mon Sep 17 00:00:00 2001 From: Mamy Ratsimbazafy Date: Wed, 14 Aug 2024 11:50:55 +0200 Subject: [PATCH] LLVM: field addition with saturated fields (#456) * feat(LLVM): add codegenerator for saturated field add/sub * LLVM: WIP refactor - boilerplate, linkage, assembly sections, ... * feat(llvm): try (and fail) to workaround bad modular addition codegen with inline function. * llvm: partial workaround failure around https://github.com/llvm/llvm-project/issues/102868\#issuecomment-2284935755 module inlining breaks machine instruction fusion * llvm: define our own addcarry/subborrow which properly optimize on x86 (but not ARM see https://github.com/llvm/llvm-project/issues/102062) * llvm: use builtin llvm.uadd.with.overflow.iXXX to try to generate optimal code (and fail for i320 and i384 https://github.com/llvm/llvm-project/issues/103717) --- PLANNING.md | 8 + .../assembly/limbs_asm_modular_x86.nim | 2 +- constantine/math_compiler/README.md | 83 +++ constantine/math_compiler/codegen_nvidia.nim | 20 +- .../math_compiler/impl_fields_globals.nim | 216 ++++++ .../math_compiler/impl_fields_nvidia.nim | 23 +- constantine/math_compiler/impl_fields_sat.nim | 153 ++++ constantine/math_compiler/ir.nim | 674 +++++++++++------- constantine/math_compiler/pub_fields.nim | 30 + constantine/platforms/abis/llvm_abi.nim | 160 ++++- .../extended_precision_64bit_uint128.nim | 4 +- .../extended_precision_x86_64_msvc.nim | 6 +- .../{nvidia_inlineasm.nim => asm_nvidia.nim} | 0 .../platforms/llvm/asm_x86.nim | 0 constantine/platforms/llvm/llvm.nim | 46 +- .../platforms/llvm/super_instructions.nim | 392 ++++++++++ research/codegen/x86_instr.nim | 96 --- research/codegen/x86_poc.nim | 386 +++------- tests/gpu/t_nvidia_fp.nim | 12 +- 19 files changed, 1626 insertions(+), 685 deletions(-) create mode 100644 constantine/math_compiler/README.md create mode 100644 constantine/math_compiler/impl_fields_globals.nim create mode 100644 constantine/math_compiler/impl_fields_sat.nim create mode 100644 constantine/math_compiler/pub_fields.nim rename constantine/platforms/llvm/{nvidia_inlineasm.nim => asm_nvidia.nim} (100%) rename research/codegen/x86_inlineasm.nim => constantine/platforms/llvm/asm_x86.nim (100%) create mode 100644 constantine/platforms/llvm/super_instructions.nim delete mode 100644 research/codegen/x86_instr.nim diff --git a/PLANNING.md b/PLANNING.md index 7fd6f956..b4258d26 100644 --- a/PLANNING.md +++ b/PLANNING.md @@ -101,6 +101,14 @@ Other tracks are stretch goals, contributions towards them are accepted. - introduce batchAffine_vartime - Optimized square_repeated in assembly for Montgomery and Crandall/Pseudo-Mersenne primes - Optimized elliptic curve directly calling assembly without ADX checks and limited input/output movement in registers or using function multi-versioning. +- LLVM IR: + - use internal or private linkage type + - look into calling conventions like "fast" or "Tail fast" + - check if returning a value from function is propely optimized + compared to in-place result + - use readnone (pure) and readmem attribute for functions + - look into passing parameter as arrays instead of pointers? + - use hot function attribute ### User Experience track diff --git a/constantine/math/arithmetic/assembly/limbs_asm_modular_x86.nim b/constantine/math/arithmetic/assembly/limbs_asm_modular_x86.nim index 12da5fcd..0db85007 100644 --- a/constantine/math/arithmetic/assembly/limbs_asm_modular_x86.nim +++ b/constantine/math/arithmetic/assembly/limbs_asm_modular_x86.nim @@ -80,7 +80,7 @@ proc finalSubMayOverflowImpl*( ctx.mov scratch[i], a[i] ctx.sbb scratch[i], M[i] - # If it overflows here, it means that it was + # If it underflows here, it means that it was # smaller than the modulus and we don't need `scratch` ctx.sbb scratchReg, 0 diff --git a/constantine/math_compiler/README.md b/constantine/math_compiler/README.md new file mode 100644 index 00000000..2ff81b44 --- /dev/null +++ b/constantine/math_compiler/README.md @@ -0,0 +1,83 @@ +# Cryptography primitive compiler + +This implements a cryptography compiler that can be used to produce +- high-performance JIT code for GPUs +- or assembly files, for CPUs when we want to ensure + there are no side-channel regressions for secret data +- or vectorized assembly file, as LLVM IR is significantly + more convenient to model vector operation + +There are also LLVM IR => FPGA translators that might be useful +in the future. + +## Platforms limitations + +- X86 cannot use dual carry-chain ADCX/ADOX easily. + - no native support for clearing a flag with `xor` + and keeping it clear. + - inline assembly cannot use the raw ASM printer. + so workflow will need to compile -> decompile. +- Nvidia GPUs cannot lower types larger than 64-bit, hence we cannot use i256 for example. +- AMD GPUs have a 1/4 throughput for i32 MUL compared to f32 MUL or i24 MUL +- non-x86 targets may not be as optimized for matching + pattern for addcarry and subborrow, even with @llvm.usub.with.overflow + +## ABI + +Internal functions are: +- prefixed with `_` +- Linkage: internal +- calling convention: "fast" +- mark `hot` for field arithmetic functions + +Internal global constants are: +- prefixed with `_` +- Linkage: linkonce_odr (so they are merged with globals of the same name) + +External functions use default convention. + +We ensure parameters / return value fit in registers: +- https://llvm.org/docs/Frontend/PerformanceTips.html + +TODO: +- function alignment: look into + - https://www.bazhenov.me/posts/2024-02-performance-roulette/ + - https://lkml.org/lkml/2015/5/21/443 +- function multiversioning +- aggregate alignment (via datalayout) + +Naming convention for internal procedures: +- _big_add_u64x4 +- _finalsub_mayo_u64x4 -> final substraction may overflow +- _finalsub_noo_u64x4 -> final sub no overflow +- _mod_add_u64x4 +- _mod_add2x_u64x8 -> FpDbl backend +- _mty_mulur_u64x4b2 -> unreduced Montgomery multiplication (unreduced result valid iff 2 spare bits) +- _mty_mul_u64x4b1 -> reduced Montgomery multiplication (result valid iff at least 1 spare bit) +- _mty_mul_u64x4 -> reduced Montgomery multiplication +- _mty_nsqrur_u64x4b2 -> unreduced square n times +- _mty_nsqr_u64x4b1 -> reduced square n times +- _mty_sqr_u64x4 -> square +- _mty_red_u64x4 -> reduction u64x4 <- u64x8 +- _pmp_red_mayo_u64x4 -> Pseudo-Mersenne Prime partial reduction may overflow (secp256k1) +- _pmp_red_noo_u64x4 -> Pseudo-Mersenne Prime partial reduction no overflow +- _secp256k1_red -> special reduction +- _fp2x_sqr2x_u64x4 -> Fp2 complex, Fp -> FpDbl lazy reduced squaring +- _fp2g_sqr2x_u64x4 -> Fp2 generic/non-complex (do we pass the mul-non-residue as parameter?) +- _fp2_sqr_u64x4 -> Fp2 (pass the mul-by-non-residue function as parameter) +- _fp4o2_mulnr1pi_u64x4 -> Fp4 over Fp2 mul with (1+i) non-residue optimization +- _fp4o2_mulbynr_u64x4 +- _fp12_add_u64x4 +- _fp12o4o2_mul_u64x4 -> Fp12 over Fp4 over Fp2 +- _ecg1swjac_adda0_u64x4 -> Shortweierstrass G1 jacobian addition a=0 +- _ecg1swjac_add_u64x4_var -> Shortweierstrass G1 jacobian vartime addition +- _ectwprj_add_u64x4 -> Twisted Edwards Projective addition + +Vectorized: +- _big_add_u64x4v4 +- _big_add_u32x8v8 + +Naming for external procedures: +- bls12_381_fp_add +- bls12_381_fr_add +- bls12_381_fp12_add diff --git a/constantine/math_compiler/codegen_nvidia.nim b/constantine/math_compiler/codegen_nvidia.nim index 19e92019..fdc4c393 100644 --- a/constantine/math_compiler/codegen_nvidia.nim +++ b/constantine/math_compiler/codegen_nvidia.nim @@ -9,12 +9,12 @@ import constantine/platforms/abis/nvidia_abi {.all.}, constantine/platforms/abis/c_abi, - constantine/platforms/llvm/[llvm, nvidia_inlineasm], + constantine/platforms/llvm/llvm, constantine/platforms/primitives, ./ir export - nvidia_abi, nvidia_inlineasm, + nvidia_abi, Flag, flag, wrapOpenArrayLenType # ############################################################ @@ -115,22 +115,6 @@ proc cudaDeviceInit*(deviceID = 0'i32): CUdevice = # # ############################################################ -proc tagCudaKernel(module: ModuleRef, fn: FnDef) = - ## Tag a function as a Cuda Kernel, i.e. callable from host - - doAssert fn.fnTy.getReturnType().isVoid(), block: - "Kernels must not return values but function returns " & $fn.fnTy.getReturnType().getTypeKind() - - let ctx = module.getContext() - module.addNamedMetadataOperand( - "nvvm.annotations", - ctx.asValueRef(ctx.metadataNode([ - fn.fnImpl.asMetadataRef(), - ctx.metadataNode("kernel"), - constInt(ctx.int32_t(), 1, LlvmBool(false)).asMetadataRef() - ])) - ) - proc wrapInCallableCudaKernel*(module: ModuleRef, fn: FnDef) = ## Create a public wrapper of a cuda device function ## diff --git a/constantine/math_compiler/impl_fields_globals.nim b/constantine/math_compiler/impl_fields_globals.nim new file mode 100644 index 00000000..faac591c --- /dev/null +++ b/constantine/math_compiler/impl_fields_globals.nim @@ -0,0 +1,216 @@ +# Constantine +# Copyright (c) 2018-2019 Status Research & Development GmbH +# Copyright (c) 2020-Present Mamy André-Ratsimbazafy +# Licensed and distributed under either of +# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT). +# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0). +# at your option. This file may not be copied, modified, or distributed except according to those terms. + +import + constantine/platforms/bithacks, + constantine/platforms/llvm/llvm, + constantine/serialization/[io_limbs, codecs], + constantine/named/deriv/precompute + +import ./ir + +# ############################################################ +# +# Metadata precomputation +# +# ############################################################ + +# Constantine on CPU is configured at compile-time for several properties that need to be runtime configuration GPUs: +# - word size (32-bit or 64-bit) +# - curve properties access like modulus bitsize or -1/M[0] a.k.a. m0ninv +# - constants are stored in freestanding `const` +# +# This is because it's not possible to store a BigInt[254] and a BigInt[384] +# in a generic way in the same structure, especially without using heap allocation. +# And with Nim's dead code elimination, unused curves are not compiled in. +# +# As there would be no easy way to dynamically retrieve (via an array or a table) +# const BLS12_381_modulus = ... +# const BN254_Snarks_modulus = ... +# +# - We would need a macro to properly access each constant. +# - We would need to create a 32-bit and a 64-bit version. +# - Unused curves would be compiled in the program. +# +# Note: on GPU we don't manipulate secrets hence branches and dynamic memory allocations are allowed. +# +# As GPU is a niche usage, instead we recreate the relevant `precompute` and IO procedures +# with dynamic wordsize support. + +type + DynWord = uint32 or uint64 + BigNum[T: DynWord] = object + bits: uint32 + limbs: seq[T] + +# Serialization +# ------------------------------------------------ + +func byteLen(bits: SomeInteger): SomeInteger {.inline.} = + ## Length in bytes to serialize BigNum + (bits + 7) shr 3 # (bits + 8 - 1) div 8 + +func fromHex[T](a: var BigNum[T], s: string) = + var bytes = newSeq[byte](a.bits.byteLen()) + bytes.paddedFromHex(s, bigEndian) + + # 2. Convert canonical uint to BigNum + const wordBitwidth = sizeof(T) * 8 + a.limbs.unmarshal(bytes, wordBitwidth, bigEndian) + +func fromHex[T](BN: type BigNum[T], bits: uint32, s: string): BN = + const wordBitwidth = sizeof(T) * 8 + let numWords = wordsRequired(bits, wordBitwidth) + + result.bits = bits + result.limbs.setLen(numWords) + result.fromHex(s) + +func toHexLlvm*[T](a: BigNum[T]): string = + ## Conversion to big-endian hex suitable for LLVM literals + ## It MUST NOT have a prefix + ## This is variable-time + # 1. Convert BigInt to canonical uint + const wordBitwidth = sizeof(T) * 8 + var bytes = newSeq[byte](byteLen(a.bits)) + bytes.marshal(a.limbs, wordBitwidth, bigEndian) + + # 2. Convert canonical uint to hex + const hexChars = "0123456789abcdef" + result = newString(2 * bytes.len) + for i in 0 ..< bytes.len: + let bi = bytes[i] + result[2*i] = hexChars[bi shr 4 and 0xF] + result[2*i+1] = hexChars[bi and 0xF] + +# Checks +# ------------------------------------------------ + +func checkValidModulus(M: BigNum) = + const wordBitwidth = uint32(BigNum.T.sizeof() * 8) + let expectedMsb = M.bits-1 - wordBitwidth * (M.limbs.len.uint32 - 1) + let msb = log2_vartime(M.limbs[M.limbs.len-1]) + + doAssert msb == expectedMsb, "Internal Error: the modulus must use all declared bits and only those:\n" & + " Modulus '0x" & M.toHexLlvm() & "' is declared with " & $M.bits & + " bits but uses " & $(msb + wordBitwidth * uint32(M.limbs.len - 1)) & " bits." + +# Fields metadata +# ------------------------------------------------ + +func negInvModWord[T](M: BigNum[T]): T = + ## Returns the Montgomery domain magic constant for the input modulus: + ## + ## µ ≡ -1/M[0] (mod SecretWord) + ## + ## M[0] is the least significant limb of M + ## M must be odd and greater than 2. + ## + ## Assuming 64-bit words: + ## + ## µ ≡ -1/M[0] (mod 2^64) + checkValidModulus(M) + return M.limbs[0].negInvModWord() + +# ############################################################ +# +# Globals in IR +# +# ############################################################ + +proc getModulusPtr*(asy: Assembler_LLVM, fd: FieldDescriptor): ValueRef = + let modname = fd.name & "_mod" + var M = asy.module.getGlobal(cstring modname) + if M.isNil(): + M = asy.defineGlobalConstant( + name = modname, + section = fd.name, + constIntOfStringAndSize(fd.intBufTy, fd.modulus, 16), + fd.intBufTy, + alignment = 64 + ) + return M + +proc getM0ninv*(asy: Assembler_LLVM, fd: FieldDescriptor): ValueRef = + let m0ninvname = fd.name & "_m0ninv" + var m0ninv = asy.module.getGlobal(cstring m0ninvname) + if m0ninv.isNil(): + if fd.w == 32: + let M = BigNum[uint32].fromHex(fd.bits, fd.modulus) + m0ninv = asy.defineGlobalConstant( + name = m0ninvname, + section = fd.name, + constInt(fd.wordTy, M.negInvModWord()), + fd.wordTy + ) + else: + let M = BigNum[uint64].fromHex(fd.bits, fd.modulus) + m0ninv = asy.defineGlobalConstant( + name = m0ninvname, + section = fd.name, + constInt(fd.wordTy, M.negInvModWord()), + fd.wordTy + ) + + + return m0ninv + +when isMainModule: + let asy = Assembler_LLVM.new("test_module", bkX86_64_Linux) + let fd = asy.ctx.configureField( + "bls12_381_fp", + 381, + "1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaab", + v = 1, w = 64) + + discard asy.getModulusPtr(fd) + discard asy.getM0ninv(fd) + + echo "=========================================" + echo "LLVM IR\n" + + echo asy.module + echo "=========================================" + + asy.module.verify(AbortProcessAction) + + # -------------------------------------------- + # See the assembly - note it might be different from what the JIT compiler did + initializeFullNativeTarget() + + const triple = "x86_64-pc-linux-gnu" + + let machine = createTargetMachine( + target = toTarget(triple), + triple = triple, + cpu = "", + features = "adx,bmi2", # TODO check the proper way to pass options + level = CodeGenLevelAggressive, + reloc = RelocDefault, + codeModel = CodeModelDefault + ) + + let pbo = createPassBuilderOptions() + let err = asy.module.runPasses( + "default,function-attrs,memcpyopt,sroa,mem2reg,gvn,dse,instcombine,inline,adce", + machine, + pbo + ) + if not err.pointer().isNil(): + writeStackTrace() + let errMsg = err.getErrorMessage() + stderr.write("\"codegenX86_64\" for module '" & astToStr(module) & "' " & $instantiationInfo() & + " exited with error: " & $cstring(errMsg) & '\n') + errMsg.dispose() + quit 1 + + echo "=========================================" + echo "Assembly\n" + + echo machine.emitTo[:string](asy.module, AssemblyFile) + echo "=========================================" diff --git a/constantine/math_compiler/impl_fields_nvidia.nim b/constantine/math_compiler/impl_fields_nvidia.nim index 0ffbb5b1..5843d02d 100644 --- a/constantine/math_compiler/impl_fields_nvidia.nim +++ b/constantine/math_compiler/impl_fields_nvidia.nim @@ -7,8 +7,8 @@ # at your option. This file may not be copied, modified, or distributed except according to those terms. import - ../platforms/llvm/llvm, - ./ir, ./codegen_nvidia + constantine/platforms/llvm/[llvm, asm_nvidia], + ./ir # ############################################################ # @@ -40,8 +40,13 @@ import # but the carry codegen of madc.hi.cc.u64 has off-by-one # - https://forums.developer.nvidia.com/t/incorrect-result-of-ptx-code/221067 # - old 32-bit bug: https://forums.developer.nvidia.com/t/wrong-result-returned-by-madc-hi-u64-ptx-instruction-for-specific-operands/196094 +# +# See instruction throughput +# - https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#arithmetic-instructions +# +# We cannot use i256 on Nvidia target: https://github.com/llvm/llvm-project/blob/llvmorg-18.1.8/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp#L244-L276 -proc finalSubMayOverflow*(asy: Assembler_LLVM, cm: CurveMetadata, field: Field, r, a: Array) = +proc finalSubMayOverflow(asy: Assembler_LLVM, cm: CurveMetadata, field: Field, r, a: Array) = ## If a >= Modulus: r <- a-M ## else: r <- a ## @@ -74,7 +79,7 @@ proc finalSubMayOverflow*(asy: Assembler_LLVM, cm: CurveMetadata, field: Field, for i in 0 ..< N: r[i] = bld.slct(scratch[i], a[i], underflowedModulus) -proc finalSubNoOverflow*(asy: Assembler_LLVM, cm: CurveMetadata, field: Field, r, a: Array) = +proc finalSubNoOverflow(asy: Assembler_LLVM, cm: CurveMetadata, field: Field, r, a: Array) = ## If a >= Modulus: r <- a-M ## else: r <- a ## @@ -165,8 +170,8 @@ proc field_sub_gen*(asy: Assembler_LLVM, cm: CurveMetadata, field: Field): FnDef let t = bld.makeArray(fieldTy) let N = cm.getNumWords(field) let zero = case cm.wordSize - of size32: constInt(asy.i32_t, 0) - of size64: constInt(asy.i64_t, 0) + of w32: constInt(asy.i32_t, 0) + of w64: constInt(asy.i64_t, 0) t[0] = bld.sub_bo(a[0], b[0]) for i in 1 ..< N: @@ -258,8 +263,8 @@ proc field_mul_CIOS_sparebit_gen(asy: Assembler_LLVM, cm: CurveMetadata, field: let m0ninv = ValueRef cm.getMontgomeryNegInverse0(field) let M = (seq[ValueRef])(cm.getModulus(field)) let zero = case cm.wordSize - of size32: constInt(asy.i32_t, 0) - of size64: constInt(asy.i64_t, 0) + of w32: constInt(asy.i32_t, 0) + of w64: constInt(asy.i64_t, 0) for i in 0 ..< N: # Multiplication @@ -354,4 +359,4 @@ proc field_mul_CIOS_sparebit_gen(asy: Assembler_LLVM, cm: CurveMetadata, field: proc field_mul_gen*(asy: Assembler_LLVM, cm: CurveMetadata, field: Field, skipFinalSub = false): FnDef = ## Generate an optimized modular addition kernel ## with parameters `a, b, modulus: Limbs -> Limbs` - return asy.field_mul_CIOS_sparebit_gen(cm, field, skipFinalSub) \ No newline at end of file + return asy.field_mul_CIOS_sparebit_gen(cm, field, skipFinalSub) diff --git a/constantine/math_compiler/impl_fields_sat.nim b/constantine/math_compiler/impl_fields_sat.nim new file mode 100644 index 00000000..cd52f96f --- /dev/null +++ b/constantine/math_compiler/impl_fields_sat.nim @@ -0,0 +1,153 @@ +# Constantine +# Copyright (c) 2018-2019 Status Research & Development GmbH +# Copyright (c) 2020-Present Mamy André-Ratsimbazafy +# Licensed and distributed under either of +# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT). +# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0). +# at your option. This file may not be copied, modified, or distributed except according to those terms. + +import + constantine/platforms/llvm/[llvm, super_instructions], + ./ir + +# ############################################################ +# +# Field arithmetic with saturated limbs +# +# ############################################################ +# +# This implements field operations in pure LLVM +# using saturated limbs, i.e. 64-bit words on 64-bit platforms. +# +# This relies on hardware addition-with-carry and substraction-with-borrow +# for efficiency. +# +# As such it is not suitable for platforms with no carry flags such as: +# - WASM +# - MIPS +# - RISC-V +# - Metal +# +# It may be suitable for Intel GPUs as the virtual ISA does support add-carry +# +# It is (theoretically) suitable for: +# - ARM +# - AMD GPUs +# +# The following backends have better optimizations through assembly: +# - x86: access to ADOX and ADCX interleaved double-carry chain +# - Nvidia: access to multiply accumulate instruction +# and non-interleaved double-carry chain +# +# Hardware limitations +# -------------------- +# +# AMD GPUs may benefits from using 24-bit limbs +# - https://www.amd.com/content/dam/amd/en/documents/radeon-tech-docs/programmer-references/AMD_OpenCL_Programming_Optimization_Guide2.pdf +# p2-23: +# Generally, the throughput and latency for 32-bit integer operations is the same +# as for single-precision floating point operations. +# 24-bit integer MULs and MADs have four times the throughput of 32-bit integer +# multiplies. 24-bit signed and unsigned integers are natively supported on the +# GCN family of devices. The use of OpenCL built-in functions for mul24 and mad24 +# is encouraged. Note that mul24 can be useful for array indexing operations +# Doc from 2015, it might not apply to RDNA family +# - https://free.eol.cn/edu_net/edudown/AMDppt/OpenCL%20Programming%20and%20Optimization%20-%20Part%20I.pdf +# slide 24 +# +# - https://chipsandcheese.com/2023/01/07/microbenchmarking-amds-rdna-3-graphics-architecture/ +# "Since Turing, Nvidia also achieves very good integer multiplication performance. +# Integer multiplication appears to be extremely rare in shader code, +# and AMD doesn’t seem to have optimized for it. +# 32-bit integer multiplication executes at around a quarter of FP32 rate, +# and latency is pretty high too." +# +# Software limitations +# -------------------- +# +# Unfortunately implementing unrolled using word size is fraught with perils +# for add-carry / sub-borrow +# AMDGPU crash: https://github.com/llvm/llvm-project/issues/102058 +# ARM64 missed optim: https://github.com/llvm/llvm-project/issues/102062 +# +# and while using @llvm.usub.with.overflow.i64 allows ARM64 to solve the missing optimization +# it is also missed on AMDGPU (or nvidia) +# +# And implementing them with i256 / i384 is similarly tricky +# https://github.com/llvm/llvm-project/issues/102868 + +const SectionName = "ctt.fields" + +proc finalSubMayOverflow*(asy: Assembler_LLVM, fd: FieldDescriptor, rr, a, M, carry: ValueRef) = + ## If a >= Modulus: r <- a-M + ## else: r <- a + ## + ## This is constant-time straightline code. + ## Due to warp divergence, the overhead of doing comparison with shortcutting might not be worth it on GPU. + ## + ## To be used when the final substraction can + ## also overflow the limbs (a 2^256 order of magnitude modulus stored in n words of total max size 2^256) + + # Mask: contains 0xFFFF or 0x0000 + let (_, mask) = asy.br.subborrow(fd.zero, fd.zero, carry) + + # Now substract the modulus, and test a < M + # (underflow) with the last borrow + let (borrow, a_minus_M) = asy.br.llvm_sub_overflow(a, M) + + # If it underflows here, it means that it was + # smaller than the modulus and we don't need `a-M` + let (ctl, _) = asy.br.subborrow(mask, fd.zero, borrow) + + let t = asy.br.select(ctl, a, a_minus_M) + asy.store(rr, t) + +proc finalSubNoOverflow*(asy: Assembler_LLVM, fd: FieldDescriptor, rr, a, M: ValueRef) = + ## If a >= Modulus: r <- a-M + ## else: r <- a + ## + ## This is constant-time straightline code. + ## Due to warp divergence, the overhead of doing comparison with shortcutting might not be worth it on GPU. + ## + ## To be used when the modulus does not use the full bitwidth of the storing words + ## (say using 255 bits for the modulus out of 256 available in words) + + # Now substract the modulus, and test a < M + # (underflow) with the last borrow + let (borrow, a_minus_M) = asy.br.llvm_sub_overflow(a, M) + + # If it underflows here, it means that it was + # smaller than the modulus and we don't need `a-M` + let t = asy.br.select(borrow, a, a_minus_M) + asy.store(rr, t) + +proc modadd*(asy: Assembler_LLVM, fd: FieldDescriptor, r, a, b, M: ValueRef) = + ## Generate an optimized modular addition kernel + ## with parameters `a, b, modulus: Limbs -> Limbs` + + let red = if fd.spareBits >= 1: "noo" + else: "mayo" + let name = "_modadd_" & red & ".u" & $fd.w & "x" & $fd.numWords + asy.llvmInternalFnDef( + name, SectionName, + asy.void_t, toTypes([r, a, b, M]), + {kHot}): + + tagParameter(1, "sret") + + let (rr, aa, bb, MM) = llvmParams + + # Pointers are opaque in LLVM now + let a = asy.load2(fd.intBufTy, aa, "a") + let b = asy.load2(fd.intBufTy, bb, "b") + let M = asy.load2(fd.intBufTy, MM, "M") + + let (carry, apb) = asy.br.llvm_add_overflow(a, b) + if fd.spareBits >= 1: + asy.finalSubNoOverflow(fd, rr, apb, M) + else: + asy.finalSubMayOverflow(fd, rr, apb, M, carry) + + asy.br.retVoid() + + asy.callFn(name, [r, a, b, M]) diff --git a/constantine/math_compiler/ir.nim b/constantine/math_compiler/ir.nim index 1523fdab..ab80e036 100644 --- a/constantine/math_compiler/ir.nim +++ b/constantine/math_compiler/ir.nim @@ -7,12 +7,9 @@ # at your option. This file may not be copied, modified, or distributed except according to those terms. import - constantine/named/algebras, - constantine/named/deriv/precompute, - constantine/math/io/io_bigints, - constantine/platforms/[primitives, bithacks], - constantine/platforms/llvm/llvm, - constantine/serialization/[endians, codecs, io_limbs] + constantine/platforms/bithacks, + constantine/platforms/llvm/[llvm, super_instructions], + std/tables # ############################################################ # @@ -21,280 +18,198 @@ import # ############################################################ type + AttrKind* = enum + # Other important properties like + # - norecurse + # - memory side-effects memory(argmem: readwrtite) + # can be deduced. + kHot, + kInline, + kAlwaysInline, + kNoInline + Assembler_LLVM* = ref object - # LLVM ctx*: ContextRef module*: ModuleRef - builder*: BuilderRef - i1_t*, i32_t*, i64_t*, i128_t*, void_t*: TypeRef + br*: BuilderRef + datalayout: TargetDataRef + psize*: int32 + publicCC: CallingConvention backend*: Backend + byteOrder: ByteOrder + + # It doesn't seem possible to retrieve a function type + # from its value, so we store them here. + # If we store the type we might as well store the impl + # and we store whether it's internal to apply the fastcc calling convention + fns: Table[string, tuple[ty: TypeRef, impl: ValueRef, internal: bool]] + attrs: array[AttrKind, AttributeRef] + + # Convenience + void_t*: TypeRef Backend* = enum + bkAmdGpu bkNvidiaPTX bkX86_64_Linux - FnDef* = tuple[fnTy: TypeRef, fnImpl: ValueRef] - # calling getTypeOf on a ValueRef function - # loses type information like return value type or arity - proc finalizeAssemblerLLVM(asy: Assembler_LLVM) = if not asy.isNil: - asy.builder.dispose() + asy.br.dispose() asy.module.dispose() asy.ctx.dispose() + # asy.datalayout.dispose() # unnecessary when module is cleared + +proc configure(asy: var Assembler_LLVM, backend: Backend) = + case backend + of bkAmdGpu: + asy.module.setTarget("amdgcn-amd-amdhsa") + + const datalayout1 {.used.} = + "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-" & + "i64:64-" & + "v16:16-v24:32-" & + "v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-" & + "n32:64-S32-A5-G1-ni:7" + + const datalayout2 = + "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:256:256:32-p8:128:128-" & + "i64:64-" & + "v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-" & + "n32:64-S32-A5-G1-ni:7:8" + + asy.module.setDataLayout(datalayout2) + + of bkNvidiaPTX: + asy.module.setTarget("nvptx64-nvidia-cuda") + # Datalayout for NVVM IR 1.8 (CUDA 11.6) + asy.module.setDataLayout( + "e-" & "p:64:64:64-" & + "i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-i128:128:128-" & + "f32:32:32-f64:64:64-" & + "v16:16:16-v32:32:32-v64:64:64-v128:128:128-" & + "n16:32:64") + of bkX86_64_Linux: + asy.module.setTarget("x86_64-pc-linux-gnu") + + asy.datalayout = asy.module.getDataLayout() + asy.psize = int32 asy.datalayout.getPointerSize() + asy.backend = backend + asy.byteOrder = asy.dataLayout.getEndianness() proc new*(T: type Assembler_LLVM, backend: Backend, moduleName: cstring): Assembler_LLVM = new result, finalizeAssemblerLLVM result.ctx = createContext() result.module = result.ctx.createModule(moduleName) + result.br = result.ctx.createBuilder() + result.datalayout = result.module.getDataLayout() - case backend - of bkNvidiaPTX: - result.module.setTarget("nvptx64-nvidia-cuda") - # Datalayout for NVVM IR 1.8 (CUDA 11.6) - result.module.setDataLayout("e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-i128:128:128-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64") - of bkX86_64_Linux: - {.warning : "The x86 LLVM backend is incomplete and for research purposes only".} - result.module.setTarget("x86_64-pc-linux-gnu") - - result.builder = result.ctx.createBuilder() - result.i1_t = result.ctx.int1_t() - result.i32_t = result.ctx.int32_t() - result.i64_t = result.ctx.int64_t() - result.i128_t = result.ctx.int128_t() result.void_t = result.ctx.void_t() - result.backend = backend + + result.configure(backend) + + result.attrs[kHot] = result.ctx.createAttr("hot") + result.attrs[kInline] = result.ctx.createAttr("inlinehint") + result.attrs[kAlwaysInline] = result.ctx.createAttr("alwaysinline") + result.attrs[kNoInline] = result.ctx.createAttr("noinline") + result.attrs[kNoInline] = result.ctx.createAttr("sret") # ############################################################ # -# Metadata precomputation +# Syntax Sugar # # ############################################################ -# Constantine on CPU is configured at compile-time for several properties that need to be runtime configuration GPUs: -# - word size (32-bit or 64-bit) -# - curve properties access like modulus bitsize or -1/M[0] a.k.a. m0ninv -# - constants are stored in freestanding `const` -# -# This is because it's not possible to store a BigInt[254] and a BigInt[384] -# in a generic way in the same structure, especially without using heap allocation. -# And with Nim's dead code elimination, unused curves are not compiled in. -# -# As there would be no easy way to dynamically retrieve (via an array or a table) -# const BLS12_381_modulus = ... -# const BN254_Snarks_modulus = ... -# -# - We would need a macro to properly access each constant. -# - We would need to create a 32-bit and a 64-bit version. -# - Unused curves would be compiled in the program. -# -# Note: on GPU we don't manipulate secrets hence branches and dynamic memory allocations are allowed. -# -# As GPU is a niche usage, instead we recreate the relevant `precompute` and IO procedures -# with dynamic wordsize support. - -type - DynWord = uint32 or uint64 - BigNum[T: DynWord] = object - bits: uint32 - limbs: seq[T] +func i1*(asy: Assembler_LLVM, v: SomeInteger): ValueRef = + constInt(asy.ctx.int1_t(), v) -# Serialization -# ------------------------------------------------ +func i32*(asy: Assembler_LLVM, v: SomeInteger): ValueRef = + constInt(asy.ctx.int32_t(), v) -func byteLen(bits: SomeInteger): SomeInteger {.inline.} = - ## Length in bytes to serialize BigNum - (bits + 7) shr 3 # (bits + 8 - 1) div 8 +# ############################################################ +# +# Intermediate Representation +# +# ############################################################ -func wordsRequired(bits, wordBitwidth: SomeInteger): SomeInteger {.inline.} = +func wordsRequired*(bits, wordBitwidth: SomeInteger): SomeInteger {.inline.} = ## Compute the number of limbs required ## from the announced bit length - debug: doAssert wordBitwidth == 32 or wordBitwidth == 64 # Power of 2 + doAssert wordBitwidth == 32 or wordBitwidth == 64 # Power of 2 (bits + wordBitwidth - 1) shr log2_vartime(uint32 wordBitwidth) # 5x to 55x faster than dividing by wordBitwidth -func fromHex[T](a: var BigNum[T], s: string) = - var bytes = newSeq[byte](a.bits.byteLen()) - bytes.paddedFromHex(s, bigEndian) - - # 2. Convert canonical uint to BigNum - const wordBitwidth = sizeof(T) * 8 - a.limbs.unmarshal(bytes, wordBitwidth, bigEndian) - -func fromHex[T](BN: type BigNum[T], bits: uint32, s: string): BN = - const wordBitwidth = sizeof(T) * 8 - let numWords = wordsRequired(bits, wordBitwidth) - - result.bits = bits - result.limbs.setLen(numWords) - result.fromHex(s) +type + FieldDescriptor* = object + name*: string + modulus*: string # Modulus as Big-Endian uppercase hex, NOT prefixed with 0x + # primeKind*: PrimeKind + + # Word: i32, i64 but can also be v4i32, v16i32 ... + wordTy*: TypeRef + word2xTy*: TypeRef # Double the word size + v*, w*: uint32 + numWords*: uint32 + zero*, zero_i1*: ValueRef + intBufTy*: TypeRef # int type, multiple of the word size, that can store the field elements + # For example a 381 bit field is stored in 384-bit ints (whether on 32 or 64-bit platforms) + + # Field metadata + fieldTy*: TypeRef + bits*: uint32 + spareBits*: uint8 -func toHex[T](a: BigNum[T]): string = - ## Conversion to big-endian hex - ## This is variable-time - # 1. Convert BigInt to canonical uint - const wordBitwidth = sizeof(T) * 8 - var bytes = newSeq[byte](byteLen(a.bits)) - bytes.marshal(a.limbs, wordBitwidth, bigEndian) +proc configureField*(ctx: ContextRef, + name: string, + modBits: int, modulus: string, + v, w: int): FieldDescriptor = + ## Configure a field descriptor with: + ## - v: vector length + ## - w: base word size in bits + ## - a `modulus` of bitsize `modBits` + ## + ## - Name is a prefix for example + ## `mycurve_fp_` - # 2 Convert canonical uint to hex - return bytes.toHex() + let v = uint32 v + let w = uint32 w + let modBits = uint32 modBits -# Checks -# ------------------------------------------------ + result.name = name + result.modulus = modulus -func checkValidModulus(M: BigNum) = - const wordBitwidth = uint32(BigNum.T.sizeof() * 8) - let expectedMsb = M.bits-1 - wordBitwidth * (M.limbs.len.uint32 - 1) - let msb = log2_vartime(M.limbs[M.limbs.len-1]) + doAssert v == 1, "At the moment SIMD vectorization is not supported." + result.v = v + result.w = w - doAssert msb == expectedMsb, "Internal Error: the modulus must use all declared bits and only those:\n" & - " Modulus '" & M.toHex() & "' is declared with " & $M.bits & - " bits but uses " & $(msb + wordBitwidth * uint32(M.limbs.len - 1)) & " bits." + result.numWords = wordsRequired(modBits, w) + result.wordTy = ctx.int_t(w) + result.word2xTy = ctx.int_t(w+w) + result.zero = constInt(result.wordTy, 0) + result.zero_i1 = constInt(ctx.int1_t(), 0) -# Fields metadata -# ------------------------------------------------ + let next_multiple_wordsize = result.numWords * w + result.intBufTy = ctx.int_t(next_multiple_wordsize) -func negInvModWord[T](M: BigNum[T]): T = - ## Returns the Montgomery domain magic constant for the input modulus: - ## - ## µ ≡ -1/M[0] (mod SecretWord) - ## - ## M[0] is the least significant limb of M - ## M must be odd and greater than 2. - ## - ## Assuming 64-bit words: - ## - ## µ ≡ -1/M[0] (mod 2^64) - checkValidModulus(M) - return M.limbs[0].negInvModWord() + result.fieldTy = array_t(result.wordTy, result.numWords) + result.bits = modBits + result.spareBits = uint8(next_multiple_wordsize - modBits) -# ############################################################ -# -# Intermediate Representation -# -# ############################################################ +proc definePrimitives*(asy: Assembler_LLVM, fd: FieldDescriptor) = + asy.ctx.def_llvm_add_overflow(asy.module, fd.wordTy) + asy.ctx.def_llvm_add_overflow(asy.module, fd.intBufTy) + asy.ctx.def_llvm_sub_overflow(asy.module, fd.wordTy) + asy.ctx.def_llvm_sub_overflow(asy.module, fd.intBufTy) -type - WordSize* = enum - size32 - size64 - - Field* = enum - fp - fr - - FieldConst* = object - wordTy: TypeRef - fieldTy: TypeRef - modulus*: seq[ConstValueRef] - m0ninv*: ConstValueRef - bits*: uint32 - spareBits*: uint8 + asy.ctx.def_addcarry(asy.module, asy.ctx.int1_t(), fd.wordTy) + asy.ctx.def_subborrow(asy.module, asy.ctx.int1_t(), fd.wordTy) - CurveMetadata* = object - curve*: Algebra - prefix*: string - wordSize*: WordSize - fp*: FieldConst - fr*: FieldConst - - Opcode* = enum - opFpAdd = "fp_add" - opFrAdd = "fr_add" - opFpSub = "fp_sub" - opFrSub = "fr_sub" - opFpMul = "fp_mul" - opFrMul = "fr_mul" - opFpMulSkipFinalSub = "fp_mul_skip_final_sub" - opFrMulSkipFinalSub = "fr_mul_skip_final_sub" - -proc setFieldConst(fc: var FieldConst, ctx: ContextRef, wordSize: WordSize, modBits: uint32, modulus: string) = - let wordTy = case wordSize - of size32: ctx.int32_t() - of size64: ctx.int64_t() - - let wordBitwidth = case wordSize - of size32: 32'u32 - of size64: 64'u32 - - let numWords = wordsRequired(modBits, wordBitwidth) - - fc.wordTy = wordTy - fc.fieldTy = array_t(wordTy, numWords) - - case wordSize - of size32: - let m = BigNum[uint32].fromHex(modBits, modulus) - fc.modulus.setlen(m.limbs.len) - for i in 0 ..< m.limbs.len: - fc.modulus[i] = ctx.int32_t().constInt(m.limbs[i]) - - fc.m0ninv = ctx.int32_t().constInt(m.negInvModWord()) - - of size64: - let m = BigNum[uint64].fromHex(modBits, modulus) - fc.modulus.setlen(m.limbs.len) - for i in 0 ..< m.limbs.len: - fc.modulus[i] = ctx.int64_t().constInt(m.limbs[i]) - - fc.m0ninv = ctx.int64_t().constInt(m.negInvModWord()) - - debug: doAssert numWords == fc.modulus.len.uint32 - fc.bits = modBits - fc.spareBits = uint8(numWords*wordBitwidth - modBits) - -proc init*( - C: type CurveMetadata, ctx: ContextRef, - prefix: string, wordSize: WordSize, - fpBits: uint32, fpMod: string, - frBits: uint32, frMod: string): CurveMetadata = - - result = C(prefix: prefix, wordSize: wordSize) - result.fp.setFieldConst(ctx, wordSize, fpBits, fpMod) - result.fr.setFieldConst(ctx, wordSize, frBits, frMod) - -proc genSymbol*(cm: CurveMetadata, opcode: Opcode): string {.inline.} = - cm.prefix & - (if cm.wordSize == size32: "32b_" else: "64b_") & - $opcode - -func getFieldType*(cm: CurveMetadata, field: Field): TypeRef {.inline.} = - if field == fp: - return cm.fp.fieldTy - else: - return cm.fr.fieldTy - -func getNumWords*(cm: CurveMetadata, field: Field): int {.inline.} = - case field - of fp: - return cm.fp.modulus.len - of fr: - return cm.fr.modulus.len - -func getModulus*(cm: CurveMetadata, field: Field): lent seq[ConstValueRef] {.inline.} = - case field - of fp: - return cm.fp.modulus - of fr: - return cm.fr.modulus - -func getMontgomeryNegInverse0*(cm: CurveMetadata, field: Field): lent ConstValueRef {.inline.} = - case field - of fp: - return cm.fp.m0ninv - of fr: - return cm.fr.m0ninv - -func getSpareBits*(cm: CurveMetadata, field: Field): uint8 {.inline.} = - if field == fp: - return cm.fp.sparebits - else: - return cm.fr.sparebits +proc wordTy*(fd: FieldDescriptor, value: SomeInteger) = + constInt(fd.wordTy, value) # ############################################################ # -# Syntax Sugar +# Aggregate Types # # ############################################################ @@ -309,34 +224,34 @@ func getSpareBits*(cm: CurveMetadata, field: Field): uint8 {.inline.} = type Array* = object builder: BuilderRef - p: ValueRef + buf*: ValueRef arrayTy: TypeRef elemTy: TypeRef int32_t: TypeRef -proc asArray*(builder: BuilderRef, arrayPtr: ValueRef, arrayTy: TypeRef): Array = +proc asArray*(asy: Assembler_LLVM, arrayPtr: ValueRef, arrayTy: TypeRef): Array = Array( - builder: builder, - p: arrayPtr, + builder: asy.br, + buf: arrayPtr, arrayTy: arrayTy, elemTy: arrayTy.getElementType(), int32_t: arrayTy.getContext().int32_t() ) -proc makeArray*(builder: BuilderRef, arrayTy: TypeRef): Array = +proc makeArray*(asy: Assembler_LLVM, arrayTy: TypeRef): Array = Array( - builder: builder, - p: builder.alloca(arrayTy), + builder: asy.br, + buf: asy.br.alloca(arrayTy), arrayTy: arrayTy, elemTy: arrayTy.getElementType(), int32_t: arrayTy.getContext().int32_t() ) -proc makeArray*(builder: BuilderRef, elemTy: TypeRef, len: uint32): Array = +proc makeArray*(asy: Assembler_LLVM, elemTy: TypeRef, len: uint32): Array = let arrayTy = array_t(elemTy, len) Array( - builder: builder, - p: builder.alloca(arrayTy), + builder: asy.br, + buf: asy.br.alloca(arrayTy), arrayTy: arrayTy, elemTy: elemTy, int32_t: arrayTy.getContext().int32_t() @@ -344,13 +259,276 @@ proc makeArray*(builder: BuilderRef, elemTy: TypeRef, len: uint32): Array = proc `[]`*(a: Array, index: SomeInteger): ValueRef {.inline.}= # First dereference the array pointer with 0, then access the `index` - let pelem = a.builder.getElementPtr2_InBounds(a.arrayTy, a.p, [ValueRef constInt(a.int32_t, 0), ValueRef constInt(a.int32_t, uint64 index)]) + let pelem = a.builder.getElementPtr2_InBounds(a.arrayTy, a.buf, [ValueRef constInt(a.int32_t, 0), ValueRef constInt(a.int32_t, uint64 index)]) a.builder.load2(a.elemTy, pelem) proc `[]=`*(a: Array, index: SomeInteger, val: ValueRef) {.inline.}= - let pelem = a.builder.getElementPtr2_InBounds(a.arrayTy, a.p, [ValueRef constInt(a.int32_t, 0), ValueRef constInt(a.int32_t, uint64 index)]) + let pelem = a.builder.getElementPtr2_InBounds(a.arrayTy, a.buf, [ValueRef constInt(a.int32_t, 0), ValueRef constInt(a.int32_t, uint64 index)]) a.builder.store(val, pelem) -proc store*(builder: BuilderRef, dst: Array, src: Array) {.inline.}= - let v = builder.load2(src.arrayTy, src.p) - builder.store(v, dst.p) +proc store*(asy: Assembler_LLVM, dst: Array, src: Array) {.inline.}= + let v = asy.br.load2(src.arrayTy, src.buf) + asy.br.store(v, dst.buf) + +proc store*(asy: Assembler_LLVM, dst: Array, src: ValueRef) {.inline.}= + ## Heterogeneous store of i256 into 4xuint64 + doAssert asy.byteOrder == kLittleEndian + asy.br.store(src, dst.buf) + +# Conversion to native LLVM int +# ------------------------------- + +proc asLlvmIntPtr*(asy: Assembler_LLVM, a: Array): ValueRef = + doAssert asy.byteOrder == kLittleEndian, "Only little-endian machines are supported at the moment." + let bits = asy.datalayout.getSizeInBits(a.arrayTy) + let pInt = pointer_t(asy.ctx.int_t(uint32 bits)) + asy.br.bitcast(a.buf, pInt) + +proc asLlvmIntPtr*(asy: Assembler_LLVM, v: ValueRef, ty: TypeRef): ValueRef = + doAssert asy.byteOrder == kLittleEndian, "Only little-endian machines are supported at the moment." + let pInt = pointer_t(ty) + asy.br.bitcast(v, pInt) + +# ############################################################ +# +# Globals +# +# ############################################################ + +proc loadGlobal*(asy: Assembler_LLVM, name: string): ValueRef = + let g = asy.module.getGlobal(cstring name) + doAssert not result.isNil(), "The global '" & name & "' has not been declared in the module" + let ty = result.getTypeOf() + return asy.br.load2(ty, g, name = "g") + +proc defineGlobalConstant*( + asy: Assembler_LLVM, + name, section: string, + value: ValueRef, + ty: TypeRef, alignment = -1): ValueRef = + ## Declare a global constant + ## name: The name of the constant + ## section: globals are kept near each other in memory to improve locality + ## and avoid page-faults + ## an alignment of -1 leaves it at default for the ISA. + ## Otherwise configure the alignment in bytes. + ## + ## Return a pointer to the global + let g = asy.module.addGlobal(ty, cstring name) + g.setGlobal(value) + if alignment > 0: + g.setAlignment(cuint alignment) + # We intentionally keep globals internal: + # - for performance, this may avoids a translation table, + # they also can be inlined. + # - for forward-compatibility, for example to expose the modulus + # a function can handle non-matching in internal representation + # for example if we want to have different endianness of words on bigEndian machine. + # g.setLinkage(linkInternal) + g.setImmutable() + + # Group related globals in the same section + # This doesn't prevent globals from being optimized away + # if they are fully inlined or unused. + # This has the following benefits: + # - They might all be loaded in memory if they share a cacheline + # - If a section is unused, it can be garbage collected by the linker + g.setSection(cstring("ctt." & section & ".constants")) + return g + +# ############################################################ +# +# ISA configuration +# +# ############################################################ + +proc tagCudaKernel(asy: Assembler_LLVM, fn: ValueRef) = + ## Tag a function as a Cuda Kernel, i.e. callable from host + + let returnTy = fn.getTypeOf().getReturnType() + doAssert returnTy.isVoid(), block: + "Kernels must not return values but function returns " & $returnTy.getTypeKind() + + asy.module.addNamedMetadataOperand( + "nvvm.annotations", + asy.ctx.asValueRef(asy.ctx.metadataNode([ + fn.asMetadataRef(), + asy.ctx.metadataNode("kernel"), + asy.i32(1).asMetadataRef() + ])) + ) + +proc setPublic(asy: Assembler_LLVM, fn: ValueRef) = + case asy.backend + of bkAmdGpu: fn.setFnCallConv(AMDGPU_KERNEL) + of bkNvidiaPtx: asy.tagCudaKernel(fn) + else: discard + +# ############################################################ +# +# Function Definition and calling convention +# +# ############################################################ + +# Most architectures can pass up to 4 or 6 arguments directly into registers +# And we allow LLVM to use the best calling convention possible with "Fast". +# +# Recommendation: +# https://llvm.org/docs/Frontend/PerformanceTips.html +# +# Avoid creating values of aggregate types (i.e. structs and arrays). +# In particular, avoid loading and storing them, +# or manipulating them with insertvalue and extractvalue instructions. +# Instead, only load and store individual fields of the aggregate. +# +# There are some exceptions to this rule: +# - It is fine to use values of aggregate type in global variable initializers. +# - It is fine to return structs, if this is done to represent the return of multiple values in registers. +# - It is fine to work with structs returned by LLVM intrinsics, such as the with.overflow family of intrinsics. +# - It is fine to use aggregate types without creating values. For example, they are commonly used in getelementptr instructions or attributes like sret. +# +# Furthermore for aggregate types like struct we need to check the number of elements +# - https://groups.google.com/g/llvm-dev/c/CafdpEzOEp0 +# - https://stackoverflow.com/questions/27386912/prevent-clang-from-expanding-arguments-that-are-aggregate-types +# - https://people.freebsd.org/~obrien/amd64-elf-abi.pdf +# Though that might be overkill for functions tagged 'internal' linkage and 'Fast' CC +# +# Hopefully the compiler will remove the unnecessary lod/store/register movement, especially when inlining. + +proc wrapTypesForFnCall[N: static int]( + asy: AssemblerLLVM, + paramTypes: array[N, TypeRef] + ): tuple[wrapped, src: array[N, TypeRef]] = + ## Wrap parameters that would need more than 3x registers + ## into a pointer. + ## There are 2 such cases: + ## - An array/struct of more than 3 elements, for example 4x uint32 + ## - A type larger than 3x the pointer size, for example 4x uint64 + ## Vectors are passed by special SIMD registers + ## + ## Due to LLVM opaque pointers, we return the wrapped and src types + + for i in 0 ..< paramTypes.len: + let ty = paramTypes[i] + let tk = ty.getTypeKind() + if tk in {tkVector, tkScalableVector}: + ## There are special SIMD registers for vectors + result.wrapped[i] = paramTypes[i] + result.src[i] = paramTypes[i] + elif asy.datalayout.getAbiSize(ty).int32 > 3*asy.psize: + result.wrapped[i] = pointer_t(paramTypes[i]) + result.src[i] = paramTypes[i] + else: + case tk + of tkArray: + if ty.getArrayLength() >= 3: + result.wrapped[i] = pointer_t(paramTypes[i]) + result.src[i] = paramTypes[i] + else: + result.wrapped[i] = paramTypes[i] + result.src[i] = paramTypes[i] + of tkStruct: + if ty.getNumElements() >= 3: + result.wrapped[i] = pointer_t(paramTypes[i]) + result.src[i] = paramTypes[i] + else: + result.wrapped[i] = paramTypes[i] + result.src[i] = paramTypes[i] + else: + result.wrapped[i] = paramTypes[i] + result.src[i] = paramTypes[i] + +proc addAttributes(asy: Assembler_LLVM, fn: ValueRef, attrs: set[AttrKind]) = + for attr in attrs: + fn.addAttribute(kAttrFnIndex, asy.attrs[attr]) + + fn.addAttribute(kAttrFnIndex, asy.attrs[kHot]) + +template llvmFnDef[N: static int]( + asy: Assembler_LLVM, + name, sectionName: string, + returnType: TypeRef, + paramTypes: array[N, TypeRef], + internal: bool, + attrs: set[AttrKind], + body: untyped) = + ## This setups common prologue to implement a function in LLVM + ## Function parameters are available with the `llvmParams` magic variable + let paramsTys = asy.wrapTypesForFnCall(paramTypes) + + var fn = asy.module.getFunction(cstring name) + if fn.pointer.isNil(): + var savedLoc = asy.br.getInsertBlock() + + let fnTy = function_t(returnType, paramsTys.wrapped) + fn = asy.module.addFunction(cstring name, fnTy) + + asy.fns[name] = (fnTy, fn, internal) + + let blck = asy.ctx.appendBasicBlock(fn) + asy.br.positionAtEnd(blck) + + if savedLoc.pointer.isNil(): + # We're installing the first function + # of the call tree, return to its basic block + savedLoc = blck + + let llvmParams {.inject.} = unpackParams(asy.br, paramsTys) + template tagParameter(idx: int, attr: string) {.inject.} = + let a = asy.ctx.createAttr(attr) + fn.addAttribute(cint idx, a) + body + + if internal: + fn.setFnCallConv(Fast) + fn.setLinkage(linkInternal) + else: + asy.setPublic(fn) + fn.setSection(sectionName) + asy.addAttributes(fn, attrs) + + asy.br.positionAtEnd(savedLoc) + +template llvmInternalFnDef*[N: static int]( + asy: Assembler_LLVM, + name, sectionName: string, + returnType: TypeRef, + paramTypes: array[N, TypeRef], + attrs: set[AttrKind] = {}, + body: untyped) = + llvmFnDef(asy, name, sectionName, returnType, paramTypes, internal = true, attrs, body) + +template llvmPublicFnDef*[N: static int]( + asy: Assembler_LLVM, + name, sectionName: string, + returnType: TypeRef, + paramTypes: array[N, TypeRef], + body: untyped) = + llvmFnDef(asy, name, sectionName, returnType, paramTypes, internal = false, {}, body) + +proc callFn*( + asy: Assembler_LLVM, + name: string, + params: openArray[ValueRef]): ValueRef {.discardable.} = + + if asy.fns[name].ty.getReturnType().getTypeKind() == tkVoid: + result = asy.br.call2(asy.fns[name].ty, asy.fns[name].impl, params) + else: + result = asy.br.call2(asy.fns[name].ty, asy.fns[name].impl, params, cstring(name)) + + if asy.fns[name].internal: + result.setInstrCallConv(Fast) + +# ############################################################ +# +# Forward to Builder +# +# ############################################################ + +# {.experimental: "dotOperators".} dos not seem to work within templates?macros + +template load2*(asy: Assembler_LLVM, ty: TypeRef, `ptr`: ValueRef, name: cstring = ""): ValueRef = + asy.br.load2(ty, `ptr`, name) + +template store*(asy: Assembler_LLVM, dst, src: ValueRef, name: cstring = "") = + asy.br.store(src, dst) diff --git a/constantine/math_compiler/pub_fields.nim b/constantine/math_compiler/pub_fields.nim new file mode 100644 index 00000000..51f92da3 --- /dev/null +++ b/constantine/math_compiler/pub_fields.nim @@ -0,0 +1,30 @@ +# Constantine +# Copyright (c) 2018-2019 Status Research & Development GmbH +# Copyright (c) 2020-Present Mamy André-Ratsimbazafy +# Licensed and distributed under either of +# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT). +# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0). +# at your option. This file may not be copied, modified, or distributed except according to those terms. + +import + constantine/platforms/llvm/llvm, + ./ir, + ./impl_fields_globals, + ./impl_fields_sat + +proc genFpAdd*(asy: Assembler_LLVM, fd: FieldDescriptor): string = + ## Generate a public field addition proc + ## with signature + ## void name(FieldType r, FieldType a, FieldType b) + ## with r the result and a, b the operants + ## and return the corresponding name to call it + + let name = fd.name & "_add" + asy.llvmPublicFnDef(name, "ctt." & fd.name, asy.void_t, [fd.fieldTy, fd.fieldTy, fd.fieldTy]): + let M = asy.getModulusPtr(fd) + + let (r, a, b) = llvmParams + asy.modadd(fd, r, a, b, M) + asy.br.retVoid() + + return name diff --git a/constantine/platforms/abis/llvm_abi.nim b/constantine/platforms/abis/llvm_abi.nim index 7fcf342f..81b19b20 100644 --- a/constantine/platforms/abis/llvm_abi.nim +++ b/constantine/platforms/abis/llvm_abi.nim @@ -38,12 +38,14 @@ type ContextRef* = distinct pointer ModuleRef* = distinct pointer TargetRef* = distinct pointer + TargetDataRef* = distinct pointer ExecutionEngineRef* = distinct pointer TargetMachineRef* = distinct pointer PassBuilderOptionsRef* = distinct pointer TypeRef* = distinct pointer ValueRef* = distinct pointer MetadataRef = distinct pointer + AttributeRef* = distinct pointer LLVMstring = distinct cstring ErrorMessageString = distinct cstring ## A string with a buffer owned by LLVM @@ -122,16 +124,19 @@ proc verify(module: ModuleRef, failureAction: VerifierFailureAction, msg: var LL {.push used.} proc initializeX86AsmPrinter() {.importc: "LLVMInitializeX86AsmPrinter".} +proc initializeX86AsmParser() {.importc: "LLVMInitializeX86AsmParser".} proc initializeX86Target() {.importc: "LLVMInitializeX86Target".} proc initializeX86TargetInfo() {.importc: "LLVMInitializeX86TargetInfo".} proc initializeX86TargetMC() {.importc: "LLVMInitializeX86TargetMC".} proc initializeNVPTXAsmPrinter() {.importc: "LLVMInitializeNVPTXAsmPrinter".} +proc initializeNVPTXAsmParser() {.importc: "LLVMInitializeNVPTXAsmParser".} proc initializeNVPTXTarget() {.importc: "LLVMInitializeNVPTXTarget".} proc initializeNVPTXTargetInfo() {.importc: "LLVMInitializeNVPTXTargetInfo".} proc initializeNVPTXTargetMC() {.importc: "LLVMInitializeNVPTXTargetMC".} proc initializeAMDGPUAsmPrinter() {.importc: "LLVMInitializeAMDGPUAsmPrinter".} +proc initializeAMDGPUAsmParser() {.importc: "LLVMInitializeAMDGPUAsmParser".} proc initializeAMDGPUTarget() {.importc: "LLVMInitializeAMDGPUTarget".} proc initializeAMDGPUTargetInfo() {.importc: "LLVMInitializeAMDGPUTargetInfo".} proc initializeAMDGPUTargetMC() {.importc: "LLVMInitializeAMDGPUTargetMC".} @@ -186,19 +191,33 @@ type CodeGenFileType* {.size: sizeof(cint).} = enum AssemblyFile, ObjectFile - TargetDataRef* = distinct pointer - TargetLibraryInfoRef* = distinct pointer - # "" proc createTargetMachine*( target: TargetRef, triple, cpu, features: cstring, level: CodeGenOptLevel, reloc: RelocMode, codeModel: CodeModel): TargetMachineRef {.importc: "LLVMCreateTargetMachine".} proc dispose*(m: TargetMachineRef) {.importc: "LLVMDisposeTargetMachine".} -proc createTargetDataLayout*(t: TargetMachineRef): TargetDataRef {.importc: "LLVMCreateTargetDataLayout".} +proc getDataLayout*(t: TargetMachineRef): TargetDataRef {.importc: "LLVMCreateTargetDataLayout".} +proc getDataLayout*(module: ModuleRef): TargetDataRef {.importc: "LLVMGetModuleDataLayout".} proc dispose*(m: TargetDataRef) {.importc: "LLVMDisposeTargetData".} proc setDataLayout*(module: ModuleRef, dataLayout: TargetDataRef) {.importc: "LLVMSetModuleDataLayout".} +proc getPointerSize*(datalayout: TargetDataRef): cuint {.importc: "LLVMPointerSize".} +proc getSizeInBits*(datalayout: TargetDataRef, ty: TypeRef): culonglong {.importc: "LLVMSizeOfTypeInBits".} + ## Computes the size of a type in bits for a target. +proc getStoreSize*(datalayout: TargetDataRef, ty: TypeRef): culonglong {.importc: "LLVMStoreSizeOfType".} + ## Computes the storage size of a type in bytes for a target. +proc getAbiSize*(datalayout: TargetDataRef, ty: TypeRef): culonglong {.importc: "LLVMABISizeOfType".} + ## Computes the ABI size of a type in bytes for a target. + +type + ByteOrder {.size: sizeof(cint).} = enum + kBigEndian + kLittleEndian + +proc getEndianness*(datalayout: TargetDataref): ByteOrder {.importc: "LLVMByteOrder".} + + proc targetMachineEmitToFile*(t: TargetMachineRef, m: ModuleRef, fileName: cstring, codegen: CodeGenFileType, errorMessage: var LLVMstring): LLVMBool {.importc: "LLVMTargetMachineEmitToFile".} proc targetMachineEmitToMemoryBuffer*(t: TargetMachineRef, m: ModuleRef, @@ -280,12 +299,17 @@ proc getIntTypeWidth*(ty: TypeRef): uint32 {.importc: "LLVMGetIntTypeWidth".} proc struct_t*( ctx: ContextRef, elemTypes: openArray[TypeRef], - packed: LlvmBool): TypeRef {.wrapOpenArrayLenType: cuint, importc: "LLVMStructTypeInContext".} + packed = LlvmBool(false)): TypeRef {.wrapOpenArrayLenType: cuint, importc: "LLVMStructTypeInContext".} proc array_t*(elemType: TypeRef, elemCount: uint32): TypeRef {.importc: "LLVMArrayType".} +proc vector_t*(elemType: TypeRef, elemCount: uint32): TypeRef {.importc: "LLVMVectorType".} + ## Create a SIMD vector type (for SSE, AVX or Neon for example) proc pointerType(elementType: TypeRef; addressSpace: cuint): TypeRef {.used, importc: "LLVMPointerType".} proc getElementType*(arrayOrVectorTy: TypeRef): TypeRef {.importc: "LLVMGetElementType".} +proc getArrayLength*(arrayTy: TypeRef): uint64 {.importc: "LLVMGetArrayLength2".} +proc getNumElements*(structTy: TypeRef): cuint {.importc: "LLVMCountStructElementTypes".} +proc getVectorSize*(vecTy: TypeRef): cuint {.importc: "LLVMGetVectorSize".} # Functions # ------------------------------------------------------------ @@ -537,6 +561,44 @@ type # The highest possible ID. Must be some 2^k - 1. MaxID = 1023 +type + Linkage {.size: sizeof(cint).} = enum + # https://web.archive.org/web/20240224034505/https://bluesadi.me/2024/01/05/Linkage-types-in-LLVM/ + # Weak linkage means unreferenced globals may not be discarded when linking. + # + # Also relevant: https://stackoverflow.com/a/55599037 + # The necessity of making code relocatable in order allow shared objects to be loaded a different addresses + # in different process means that statically allocated variables, + # whether they have global or local scope, + # can't be accessed with directly with a single instruction on most architectures. + # The only exception I know of is the 64-bit x86 architecture, as you see above. + # It supports memory operands that are both PC-relative and have large 32-bit displacements + # that can reach any variable defined in the same component. + linkExternal, ## Externally visible function + linkAvailableExternally, ## no description + linkOnceAny, ## Keep one copy of function when linking (inline) + linkOnceODR, ## Same, but only replaced by something equivalent. (ODR: one definition rule) + linkOnceODRAutoHide, ## Obsolete + linkWeakAny, ## Keep one copy of function when linking (weak) + linkWeakODR, ## Same, but only replaced by something equivalent. + linkAppending, ## Special purpose, only applies to global arrays + linkInternal, ## Rename collisions when linking (static functions) + linkPrivate, ## Like Internal, but omit from symbol table + linkDLLImport, ## Obsolete + linkDLLExport, ## Obsolete + linkExternalWeak, ## ExternalWeak linkage description + linkGhost, ## Obsolete + linkCommon, ## Tentative definitions + linkLinkerPrivate, ## Like Private, but linker removes. + linkLinkerPrivateWeak ## Like LinkerPrivate, but is weak. + + Visibility {.size: sizeof(cint).} = enum + # Note: Function with internal or private linkage must have default visibility + visDefault + visHidden + visProtected + + proc function_t*( returnType: TypeRef, paramTypes: openArray[TypeRef], @@ -546,11 +608,33 @@ proc addFunction*(m: ModuleRef, name: cstring, ty: TypeRef): ValueRef {.importc: ## Declare a function `name` in a module. ## Returns a handle to specify its instructions +proc getFunction*(m: ModuleRef, name: cstring): ValueRef {.importc: "LLVMGetNamedFunction".} + ## Get a function by name from the curent module. + ## Return nil if not found. + proc getReturnType*(functionTy: TypeRef): TypeRef {.importc: "LLVMGetReturnType".} proc countParamTypes*(functionTy: TypeRef): uint32 {.importc: "LLVMCountParamTypes".} -proc getCallingConvention*(function: ValueRef): CallingConvention {.importc: "LLVMGetFunctionCallConv".} -proc setCallingConvention*(function: ValueRef, cc: CallingConvention) {.importc: "LLVMSetFunctionCallConv".} +proc getCalledFunctionType*(fn: ValueRef): TypeRef {.importc: "LLVMGetCalledFunctionType".} + +proc getFnCallConv*(function: ValueRef): CallingConvention {.importc: "LLVMGetFunctionCallConv".} +proc setFnCallConv*(function: ValueRef, cc: CallingConvention) {.importc: "LLVMSetFunctionCallConv".} + +proc getInstrCallConv*(instr: ValueRef): CallingConvention {.importc: "LLVMGetInstructionCallConv".} +proc setInstrCallConv*(instr: ValueRef, cc: CallingConvention) {.importc: "LLVMSetInstructionCallConv".} + +type + AttributeIndex* {.size: sizeof(cint).} = enum + ## Attribute index is either -1 for the function + ## 0 for the return value + ## or 1..n for each function parameter + kAttrFnIndex = -1 + kAttrRetIndex = 0 + +proc toAttrId*(name: openArray[char]): cuint {.importc: "LLVMGetEnumAttributeKindForName".} +proc toAttr*(ctx: ContextRef, attr_id: uint64, val = 0'u64): AttributeRef {.importc: "LLVMCreateEnumAttribute".} +proc addAttribute*(fn: ValueRef, index: cint, attr: AttributeRef) {.importc: "LLVMAddAttributeAtIndex".} +proc addAttribute*(fn: ValueRef, index: AttributeIndex, attr: AttributeRef) {.importc: "LLVMAddAttributeAtIndex".} # ############################################################ # @@ -560,6 +644,18 @@ proc setCallingConvention*(function: ValueRef, cc: CallingConvention) {.importc: # {.push header: "".} +proc getGlobal*(module: ModuleRef, name: cstring): ValueRef {.importc: "LLVMGetNamedGlobal".} +proc addGlobal*(module: ModuleRef, ty: TypeRef, name: cstring): ValueRef {.importc: "LLVMAddGlobal".} +proc setGlobal*(globalVar: ValueRef, constantVal: ValueRef) {.importc: "LLVMSetInitializer".} +proc setImmutable*(globalVar: ValueRef, immutable = LlvmBool(true)) {.importc: "LLVMSetGlobalConstant".} + +proc getGlobalParent*(global: ValueRef): ModuleRef {.importc: "LLVMGetGlobalParent".} + +proc setLinkage*(global: ValueRef, linkage: Linkage) {.importc: "LLVMSetLinkage".} +proc setVisibility*(global: ValueRef, vis: Visibility) {.importc: "LLVMSetVisibility".} +proc setAlignment*(v: ValueRef, bytes: cuint) {.importc: "LLVMSetAlignment".} +proc setSection*(global: ValueRef, section: cstring) {.importc: "LLVMSetSection".} + proc getTypeOf*(v: ValueRef): TypeRef {.importc: "LLVMTypeOf".} proc getValueName2(v: ValueRef, rLen: var csize_t): cstring {.used, importc: "LLVMGetValueName2".} ## Returns the name of a valeu if it exists. @@ -578,6 +674,9 @@ proc toLLVMstring(v: ValueRef): LLVMstring {.used, importc: "LLVMPrintValueToStr # https://llvm.org/doxygen/group__LLVMCCoreValueConstant.html proc constInt(ty: TypeRef, n: culonglong, signExtend: LlvmBool): ValueRef {.used, importc: "LLVMConstInt".} +proc constIntOfArbitraryPrecision(ty: TypeRef, numWords: cuint, words: ptr uint64): ValueRef {.used, importc: "LLVMConstIntOfArbitraryPrecision".} +proc constIntOfStringAndSize(ty: TypeRef, text: openArray[char], radix: uint8): ValueRef {.used, importc: "LLVMConstIntOfStringAndSize".} + proc constReal*(ty: TypeRef, n: cdouble): ValueRef {.importc: "LLVMConstReal".} proc constNull*(ty: TypeRef): ValueRef {.importc: "LLVMConstNull".} @@ -587,6 +686,15 @@ proc constArray*( constantVals: openArray[ValueRef], ): ValueRef {.wrapOpenArrayLenType: cuint, importc: "LLVMConstArray".} +# Undef & Poison +# ------------------------------------------------------------ +# https://llvm.org/devmtg/2020-09/slides/Lee-UndefPoison.pdf + +proc poison*(ty: TypeRef): ValueRef {.importc: "LLVMGetPoison".} +proc undef*(ty: TypeRef): ValueRef {.importc: "LLVMGetUndef".} + + + # ############################################################ # # IR builder @@ -601,17 +709,17 @@ type ## An instruction builder represents a point within a basic block and is ## the exclusive means of building instructions using the C interface. - IntPredicate* {.size: sizeof(cint).} = enum - IntEQ = 32 ## equal - IntNE ## not equal - IntUGT ## unsigned greater than - IntUGE ## unsigned greater or equal - IntULT ## unsigned less than - IntULE ## unsigned less or equal - IntSGT ## signed greater than - IntSGE ## signed greater or equal - IntSLT ## signed less than - IntSLE ## signed less or equal + Predicate* {.size: sizeof(cint).} = enum + kEQ = 32 ## equal + kNE ## not equal + kUGT ## unsigned greater than + kUGE ## unsigned greater or equal + kULT ## unsigned less than + kULE ## unsigned less or equal + kSGT ## signed greater than + kSGE ## signed greater or equal + kSLT ## signed less than + kSLE ## signed less or equal InlineAsmDialect* {.size: sizeof(cint).} = enum InlineAsmDialectATT @@ -622,7 +730,7 @@ type # Instantiation # ------------------------------------------------------------ -proc appendBasicBlock*(ctx: ContextRef, fn: ValueRef, name: cstring): BasicBlockRef {.importc: "LLVMAppendBasicBlockInContext".} +proc appendBasicBlock*(ctx: ContextRef, fn: ValueRef, name: cstring = ""): BasicBlockRef {.importc: "LLVMAppendBasicBlockInContext".} ## Append a basic block to the end of a function proc createBuilder*(ctx: ContextRef): BuilderRef {.importc: "LLVMCreateBuilderInContext".} @@ -675,19 +783,27 @@ proc call2*( proc add*(builder: BuilderRef, lhs, rhs: ValueRef, name: cstring = ""): ValueRef {.importc: "LLVMBuildAdd".} proc addNSW*(builder: BuilderRef, lhs, rhs: ValueRef, name: cstring = ""): ValueRef {.importc: "LLVMBuildNSWAdd".} + ## Addition No Signed Wrap, i.e. guaranteed to not overflow proc addNUW*(builder: BuilderRef, lhs, rhs: ValueRef, name: cstring = ""): ValueRef {.importc: "LLVMBuildNUWAdd".} + ## Addition No Unsigned Wrap, i.e. guaranteed to not overflow proc sub*(builder: BuilderRef, lhs, rhs: ValueRef, name: cstring = ""): ValueRef {.importc: "LLVMBuildSub".} proc subNSW*(builder: BuilderRef, lhs, rhs: ValueRef, name: cstring = ""): ValueRef {.importc: "LLVMBuildNSWSub".} + ## Substraction No Signed Wrap, i.e. guaranteed to not overflow proc subNUW*(builder: BuilderRef, lhs, rhs: ValueRef, name: cstring = ""): ValueRef {.importc: "LLVMBuildNUWSub".} + ## Substraction No Unsigned Wrap, i.e. guaranteed to not overflow proc neg*(builder: BuilderRef, lhs, rhs: ValueRef, name: cstring = ""): ValueRef {.importc: "LLVMBuildNeg".} proc negNSW*(builder: BuilderRef, lhs, rhs: ValueRef, name: cstring = ""): ValueRef {.importc: "LLVMBuildNSWNeg".} + ## Negation No Signed Wrap, i.e. guaranteed to not overflow proc negNUW*(builder: BuilderRef, lhs, rhs: ValueRef, name: cstring = ""): ValueRef {.importc: "LLVMBuildNUWNeg".} + ## Negation No Unsigned Wrap, i.e. guaranteed to not overflow proc mul*(builder: BuilderRef, lhs, rhs: ValueRef, name: cstring = ""): ValueRef {.importc: "LLVMBuildMul".} proc mulNSW*(builder: BuilderRef, lhs, rhs: ValueRef, name: cstring = ""): ValueRef {.importc: "LLVMBuildNSWMul".} + ## Multiplication No Signed Wrap, i.e. guaranteed to not overflow proc mulNUW*(builder: BuilderRef, lhs, rhs: ValueRef, name: cstring = ""): ValueRef {.importc: "LLVMBuildNUWMul".} + ## Multiplication No Unsigned Wrap, i.e. guaranteed to not overflow proc divU*(builder: BuilderRef, lhs, rhs: ValueRef, name: cstring = ""): ValueRef {.importc: "LLVMBuildUDiv".} proc divU_exact*(builder: BuilderRef, lhs, rhs: ValueRef, name: cstring = ""): ValueRef {.importc: "LLVMBuildExactUDiv".} @@ -706,9 +822,9 @@ proc `xor`*(builder: BuilderRef, lhs, rhs: ValueRef, name: cstring = ""): ValueR proc `not`*(builder: BuilderRef, val: ValueRef, name: cstring = ""): ValueRef {.importc: "LLVMBuildNot".} proc select*(builder: BuilderRef, condition, then, otherwise: ValueRef, name: cstring = ""): ValueRef {.importc: "LLVMBuildSelect".} -proc icmp*(builder: BuilderRef, op: IntPredicate, lhs, rhs: ValueRef, name: cstring = ""): ValueRef {.importc: "LLVMBuildICmp".} +proc icmp*(builder: BuilderRef, op: Predicate, lhs, rhs: ValueRef, name: cstring = ""): ValueRef {.importc: "LLVMBuildICmp".} -proc bitcast*(builder: BuilderRef, val: ValueRef, destTy: TypeRef, name: cstring = ""): ValueRef {.importc: "LLVMBuildBitcast".} +proc bitcast*(builder: BuilderRef, val: ValueRef, destTy: TypeRef, name: cstring = ""): ValueRef {.importc: "LLVMBuildBitCast".} proc trunc*(builder: BuilderRef, val: ValueRef, destTy: TypeRef, name: cstring = ""): ValueRef {.importc: "LLVMBuildTrunc".} proc zext*(builder: BuilderRef, val: ValueRef, destTy: TypeRef, name: cstring = ""): ValueRef {.importc: "LLVMBuildZExt".} ## Zero-extend @@ -722,7 +838,7 @@ proc alloca*(builder: BuilderRef, ty: TypeRef, name: cstring = ""): ValueRef {.i proc allocaArray*(builder: BuilderRef, ty: TypeRef, length: ValueRef, name: cstring = ""): ValueRef {.importc: "LLVMBuildArrayAlloca".} proc extractValue*(builder: BuilderRef, aggVal: ValueRef, index: uint32, name: cstring = ""): ValueRef {.importc: "LLVMBuildExtractValue".} -proc insertValue*(builder: BuilderRef, aggVal: ValueRef, eltVal: ValueRef, index: uint32, name: cstring = ""): ValueRef {.discardable, importc: "LLVMBuildInsertValue".} +proc insertValue*(builder: BuilderRef, aggVal: ValueRef, eltVal: ValueRef, index: uint32, name: cstring = ""): ValueRef {.importc: "LLVMBuildInsertValue".} proc getElementPtr2*( builder: BuilderRef, diff --git a/constantine/platforms/intrinsics/extended_precision_64bit_uint128.nim b/constantine/platforms/intrinsics/extended_precision_64bit_uint128.nim index 345acd8f..f0e438a2 100644 --- a/constantine/platforms/intrinsics/extended_precision_64bit_uint128.nim +++ b/constantine/platforms/intrinsics/extended_precision_64bit_uint128.nim @@ -85,7 +85,7 @@ func muladd2*(hi, lo: var Ct[uint64], a, b, c1, c2: Ct[uint64]) {.inline.}= {.emit:["*",lo, " = (NU64)", dblPrec,";"].} func smul*(hi, lo: var Ct[uint64], a, b: Ct[uint64]) {.inline.} = - ## Extended precision multiplication + ## Signed extended precision multiplication ## (hi, lo) <- a*b ## ## Inputs are intentionally unsigned @@ -103,4 +103,4 @@ func smul*(hi, lo: var Ct[uint64], a, b: Ct[uint64]) {.inline.} = {.emit:[lo, " = (NU64)", dblPrec,";"].} else: {.emit:["*",hi, " = (NU64)(", dblPrec," >> ", 64'u64, ");"].} - {.emit:["*",lo, " = (NU64)", dblPrec,";"].} \ No newline at end of file + {.emit:["*",lo, " = (NU64)", dblPrec,";"].} diff --git a/constantine/platforms/intrinsics/extended_precision_x86_64_msvc.nim b/constantine/platforms/intrinsics/extended_precision_x86_64_msvc.nim index 3216b859..06bde04d 100644 --- a/constantine/platforms/intrinsics/extended_precision_x86_64_msvc.nim +++ b/constantine/platforms/intrinsics/extended_precision_x86_64_msvc.nim @@ -77,12 +77,12 @@ func smul128(a, b: Ct[uint64], hi: var Ct[uint64]): Ct[uint64] {.importc:"_mul12 ## as we use their unchecked raw representation for cryptography func smul*(hi, lo: var Ct[uint64], a, b: Ct[uint64]) {.inline.} = - ## Extended precision multiplication + ## Signed extended precision multiplication ## (hi, lo) <- a*b ## ## Inputs are intentionally unsigned ## as we use their unchecked raw representation for cryptography - ## + ## ## This is constant-time on most hardware ## See: https://www.bearssl.org/ctmul.html - lo = smul128(a, b, hi) \ No newline at end of file + lo = smul128(a, b, hi) diff --git a/constantine/platforms/llvm/nvidia_inlineasm.nim b/constantine/platforms/llvm/asm_nvidia.nim similarity index 100% rename from constantine/platforms/llvm/nvidia_inlineasm.nim rename to constantine/platforms/llvm/asm_nvidia.nim diff --git a/research/codegen/x86_inlineasm.nim b/constantine/platforms/llvm/asm_x86.nim similarity index 100% rename from research/codegen/x86_inlineasm.nim rename to constantine/platforms/llvm/asm_x86.nim diff --git a/constantine/platforms/llvm/llvm.nim b/constantine/platforms/llvm/llvm.nim index d222a306..e87e7989 100644 --- a/constantine/platforms/llvm/llvm.nim +++ b/constantine/platforms/llvm/llvm.nim @@ -7,6 +7,7 @@ # at your option. This file may not be copied, modified, or distributed except according to those terms. import constantine/platforms/abis/llvm_abi {.all.} +import std/macros export llvm_abi # ############################################################ @@ -146,11 +147,17 @@ proc emitTo*[T: string or seq[byte]](t: TargetMachineRef, m: ModuleRef, codegen: # Builder # ------------------------------------------------------------ +proc getCurrentFunction*(builder: BuilderRef): ValueRef = + builder.getInsertBlock().getBasicBlockParent() + proc getContext*(builder: BuilderRef): ContextRef = # LLVM C API does not expose IRBuilder.getContext() # making this unnecessary painful # https://github.com/llvm/llvm-project/issues/59875 - builder.getInsertBlock().getBasicBlockParent().getTypeOf().getContext() + builder.getCurrentFunction().getTypeOf().getContext() + +proc getCurrentModule*(builder: BuilderRef): ModuleRef = + builder.getCurrentFunction().getGlobalParent() # Types # ------------------------------------------------------------ @@ -172,12 +179,37 @@ proc array_t*(elemType: TypeRef, elemCount: SomeInteger): TypeRef {.inline.}= proc function_t*(returnType: TypeRef, paramTypes: openArray[TypeRef]): TypeRef {.inline.} = function_t(returnType, paramTypes, isVarArg = LlvmBool(false)) +# Functions +# ------------------------------------------------------------ + +proc createAttr*(ctx: ContextRef, name: openArray[char]): AttributeRef = + ctx.toAttr(name.toAttrId()) + +proc toTypes*[N: static int](v: array[N, ValueRef]): array[N, TypeRef] = + for i in 0 ..< v.len: + result[i] = v[i].getTypeOf() + +macro unpackParams*[N: static int]( + br: BuilderRef, + paramsTys: tuple[wrapped, src: array[N, TypeRef]]): untyped = + ## Unpack function parameters. + ## + ## The new function basic block MUST be setup before calling unpackParams. + ## + ## In the future we may automatically unwrap types. + + result = nnkPar.newTree() + for i in 0 ..< N: + result.add quote do: + # let tySrc = `paramsTys`.src[`i`] + # let tyCC = `paramsTys`.wrapped[`i`] + let fn = `br`.getCurrentFunction() + fn.getParam(uint32 `i`) + # Values # ------------------------------------------------------------ -type - ConstValueRef* = distinct ValueRef - AnyValueRef* = ValueRef or ConstValueRef +proc isNil*(v: ValueRef): bool {.borrow.} proc getName*(v: ValueRef): string = var rLen: csize_t @@ -186,7 +218,5 @@ proc getName*(v: ValueRef): string = result = newString(rLen.int) copyMem(result[0].addr, rStr, rLen.int) -proc constInt*(ty: TypeRef, n: uint64, signExtend = false): ConstValueRef {.inline.} = - ConstValueRef constInt(ty, culonglong(n), LlvmBool(signExtend)) - -proc getTypeOf*(v: ConstValueRef): TypeRef {.borrow.} +proc constInt*(ty: TypeRef, n: SomeInteger, signExtend = false): ValueRef {.inline.} = + constInt(ty, culonglong(n), LlvmBool(signExtend)) diff --git a/constantine/platforms/llvm/super_instructions.nim b/constantine/platforms/llvm/super_instructions.nim new file mode 100644 index 00000000..08646288 --- /dev/null +++ b/constantine/platforms/llvm/super_instructions.nim @@ -0,0 +1,392 @@ +# Constantine +# Copyright (c) 2018-2019 Status Research & Development GmbH +# Copyright (c) 2020-Present Mamy André-Ratsimbazafy +# Licensed and distributed under either of +# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT). +# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0). +# at your option. This file may not be copied, modified, or distributed except according to those terms. + +import ./llvm + +# ############################################################ +# +# LLVM IR super-instructions +# +# ############################################################ + +# This defines a collection of LLVM IR super-instructions +# Ideally those super-instructions compile-down +# to ISA optimized single instructions +# +# To ensure this, tests can be consulted at: +# https://github.com/llvm/llvm-project/blob/llvmorg-18.1.8/llvm/test/CodeGen/ + +# Add-carry: +# - https://github.com/llvm/llvm-project/blob/llvmorg-18.1.8/llvm/test/CodeGen/X86/add-of-carry.ll +# - https://github.com/llvm/llvm-project/blob/llvmorg-18.1.8/llvm/test/CodeGen/X86/addcarry.ll +# - https://github.com/llvm/llvm-project/blob/llvmorg-18.1.8/llvm/test/CodeGen/X86/addcarry2.ll +# - https://github.com/llvm/llvm-project/blob/llvmorg-18.1.8/llvm/test/CodeGen/X86/adx-intrinsics.ll +# - https://github.com/llvm/llvm-project/blob/llvmorg-18.1.8/llvm/test/CodeGen/X86/adx-intrinsics-upgrade.ll +# - https://github.com/llvm/llvm-project/blob/llvmorg-18.1.8/llvm/test/CodeGen/X86/apx/adc.ll +# +# Sub-borrow +# - https://github.com/llvm/llvm-project/blob/llvmorg-18.1.8/llvm/test/CodeGen/X86/sub-with-overflow.ll +# - https://github.com/llvm/llvm-project/blob/llvmorg-18.1.8/llvm/test/CodeGen/AArch64/cgp-usubo.ll +# - https://github.com/llvm/llvm-project/blob/llvmorg-18.1.8/llvm/test/CodeGen/X86/cgp-usubo.ll +# - https://github.com/llvm/llvm-project/blob/llvmorg-18.1.8/llvm/test/CodeGen/X86/apx/sbb.ll +# +# Multiplication +# - https://github.com/llvm/llvm-project/blob/llvmorg-18.1.8/llvm/test/CodeGen/X86/mulx32.ll +# - https://github.com/llvm/llvm-project/blob/llvmorg-18.1.8/llvm/test/CodeGen/X86/mulx64.ll + +# Warning 1: +# +# There is no guarantee of constant-time with LLVM IR +# It MAY introduce branches. +# For workload that involves private keys or secrets +# assembly MUST be used +# +# Alternatively an assembly source file must be generated +# and checked in the repo to avoid regressions should +# the compiler "progress" +# +# - https://github.com/mratsim/constantine/wiki/Constant-time-arithmetics#fighting-the-compiler +# - https://blog.cr.yp.to/20240803-clang.html +# - https://www.cl.cam.ac.uk/~rja14/Papers/whatyouc.pdf +# +# Warning 2: +# +# Unfortunately implementing unrolled bigint arithmetic using word size +# is fraught with perils for add-carry / sub-borrow +# AMDGPU crash: https://github.com/llvm/llvm-project/issues/102058 +# ARM64 missed optim: https://github.com/llvm/llvm-project/issues/102062 +# +# and while using @llvm.usub.with.overflow.i64 allows ARM64 to solve the missing optimization +# it is also missed on AMDGPU (or nvidia) + +proc hi(bld: BuilderRef, val: ValueRef, baseTy: TypeRef, oversize: uint32, prefix: string): ValueRef = + let ctx = bld.getContext() + let bits = baseTy.getIntTypeWidth() + let overTy = ctx.int_t(bits + oversize) + + # %hi_shift_1 = zext i8 64 to i128 + let s = constInt(ctx.int8_t(), oversize) + let shift = bld.zext(s, overTy, name = cstring(prefix & "S_")) + # %hiLarge_1 = lshr i128 %input, %hi_shift_1 + let hiLarge = bld.lshr(val, shift, name = cstring(prefix & "L_")) + # %hi_1 = trunc i128 %hiLarge_1 to i64 + let hi = bld.trunc(hiLarge, baseTy, name = cstring(prefix & "_")) + + return hi + +const SectionName = "ctt.superinstructions" + +proc getInstrName(baseName: string, ty: TypeRef, builtin = false): string = + var w, v: int # Wordsize and vector size + if ty.getTypeKind() == tkInteger: + w = int ty.getIntTypeWidth() + v = 1 + elif ty.getTypeKind() == tkVector: + v = int ty.getVectorSize() + w = int ty.getElementType().getIntTypeWidth() + else: + doAssert false, "Invalid input type: " & $ty + + return baseName & + (if v != 1: ".v" & $v else: ".") & + (if builtin: "i" else: "u") & $w + + +proc def_llvm_add_overflow*(ctx: ContextRef, m: ModuleRef, wordTy: TypeRef) = + let name = "llvm.uadd.with.overflow".getInstrName(wordTy, builtin = true) + + let br {.inject.} = ctx.createBuilder() + defer: br.dispose() + + var fn = m.getFunction(cstring name) + if fn.pointer.isNil(): + let retTy = ctx.struct_t([wordTy, ctx.int1_t()]) + let fnTy = function_t(retTy, [wordTy, wordTy]) + discard m.addFunction(cstring name, fnTy) + +proc llvm_add_overflow*(br: BuilderRef, a, b: ValueRef, name = ""): tuple[carryOut, r: ValueRef] = + ## (cOut, result) <- a+b+cIn + let ty = a.getTypeOf() + let intrin_name = "llvm.uadd.with.overflow".getInstrName(ty, builtin = true) + + let fn = br.getCurrentModule().getFunction(cstring intrin_name) + doAssert not fn.pointer.isNil, "Function '" & intrin_name & "' does not exist in the module\n" + + let ctx = br.getContext() + + let retTy = ctx.struct_t([ty, ctx.int1_t()]) + let fnTy = function_t(retTy, [ty, ty]) + let addo = br.call2(fnTy, fn, [a, b], cstring name) + let lo = br.extractValue(addo, 0, cstring(name & ".lo")) + let cOut = br.extractValue(addo, 1, cstring(name & ".carry")) + return (cOut, lo) + +proc def_llvm_sub_overflow*(ctx: ContextRef, m: ModuleRef, wordTy: TypeRef) = + let name = "llvm.usub.with.overflow".getInstrName(wordTy, builtin = true) + + let br {.inject.} = ctx.createBuilder() + defer: br.dispose() + + var fn = m.getFunction(cstring name) + if fn.pointer.isNil(): + let retTy = ctx.struct_t([wordTy, ctx.int1_t()]) + let fnTy = function_t(retTy, [wordTy, wordTy]) + discard m.addFunction(cstring name, fnTy) + +proc llvm_sub_overflow*(br: BuilderRef, a, b: ValueRef, name = ""): tuple[borrowOut, r: ValueRef] = + ## (cOut, result) <- a+b+cIn + let ty = a.getTypeOf() + let intrin_name = "llvm.usub.with.overflow".getInstrName(ty, builtin = true) + + let fn = br.getCurrentModule().getFunction(cstring intrin_name) + doAssert not fn.pointer.isNil, "Function '" & intrin_name & "' does not exist in the module\n" + + let ctx = br.getContext() + + let retTy = ctx.struct_t([ty, ctx.int1_t()]) + let fnTy = function_t(retTy, [ty, ty]) + let subo = br.call2(fnTy, fn, [a, b], cstring name) + let lo = br.extractValue(subo, 0, cstring(name & ".lo")) + let bOut = br.extractValue(subo, 1, cstring(name & ".borrow")) + return (bOut, lo) + +template defSuperInstruction[N: static int]( + module: ModuleRef, baseName: string, + returnType: TypeRef, + paramTypes: array[N, TypeRef], + body: untyped) = + ## Boilerplate for super instruction definition + ## Creates a magic `llvmParams` variable to tuple-destructure + ## to access the inputs + ## and `br` for building the instructions + let ty = paramTypes[0] + let name = baseName.getInstrName(ty) + + let ctx = module.getContext() + let br {.inject.} = ctx.createBuilder() + defer: br.dispose() + + var fn = module.getFunction(cstring name) + if fn.pointer.isNil(): + let fnTy = function_t(returnType, paramTypes) + fn = module.addFunction(cstring name, fnTy) + let blck = ctx.appendBasicBlock(fn) + br.positionAtEnd(blck) + + let llvmParams {.inject.} = unpackParams(br, (paramTypes, paramTypes)) + template tagParameter(idx: int, attr: string) {.inject, used.} = + let a = asy.ctx.createAttr(cstring attr) + fn.addAttribute(cint idx, a) + body + + fn.setFnCallConv(Fast) + fn.setLinkage(linkInternal) + fn.setSection(SectionName) + fn.addAttribute(kAttrFnIndex, ctx.createAttr("alwaysinline")) + +proc def_addcarry*(ctx: ContextRef, m: ModuleRef, carryTy, wordTy: TypeRef) = + ## Define (carryOut, result) <- a+b+carryIn + + let retType = ctx.struct_t([carryTy, wordTy]) + let inType = [wordTy, wordTy, carryTy] + + m.defSuperInstruction("addcarry", retType, inType): + let (a, b, carryIn) = llvmParams + + let (carry0, add) = br.llvm_add_overflow(a, b, "a_plus_b") + let cIn = br.zext(carryIn, wordTy, name = "carryIn") + let (carry1, adc) = br.llvm_add_overflow(cIn, add, "a_plus_b_plus_cIn") + let carryOut = br.`or`(carry0, carry1, name = "carryOut") + + var ret = br.insertValue(poison(retType), adc, 1, "lo") + ret = br.insertValue(ret, carryOut, 0, "ret") + br.ret(ret) + +proc addcarry*(br: BuilderRef, a, b, carryIn: ValueRef): tuple[carryOut, r: ValueRef] = + ## (cOut, result) <- a+b+cIn + let ty = a.getTypeOf() + let tyC = carryIn.getTypeOf() + let name = "addcarry".getInstrName(ty) + + let fn = br.getCurrentModule().getFunction(cstring name) + doAssert not fn.pointer.isNil, "Function '" & name & "' does not exist in the module\n" + + let retTy = br.getContext().struct_t([tyC, ty]) + let fnTy = function_t(retTy, [ty, ty, tyC]) + let adc = br.call2(fnTy, fn, [a, b, carryIn], name = "adc") + adc.setInstrCallConv(Fast) + let lo = br.extractValue(adc, 1, name = "adc.lo") + let cOut = br.extractValue(adc, 0, name = "adc.carry") + return (cOut, lo) + +proc def_subborrow*(ctx: ContextRef, m: ModuleRef, borrowTy, wordTy: TypeRef) = + ## Define (borrowOut, result) <- a-b-borrowIn + + let retType = ctx.struct_t([borrowTy, wordTy]) + let inType = [wordTy, wordTy, borrowTy] + + m.defSuperInstruction("subborrow", retType, inType): + let (a, b, borrowIn) = llvmParams + + let (borrow0, sub) = br.llvm_sub_overflow(a, b, "a_minus_b") + let bIn = br.zext(borrowIn, wordTy, name = "borrowIn") + let (borrow1, sbb) = br.llvm_sub_overflow(sub, bIn, "sbb") + let borrowOut = br.`or`(borrow0, borrow1, name = "borrowOut") + + var ret = br.insertValue(poison(retType), sbb, 1, "lo") + ret = br.insertValue(ret, borrowOut, 0, "ret") + br.ret(ret) + +proc subborrow*(br: BuilderRef, a, b, borrowIn: ValueRef): tuple[borrowOut, r: ValueRef] = + ## (cOut, result) <- a+b+cIn + let ty = a.getTypeOf() + let tyC = borrowIn.getTypeOf() + let name = "subborrow".getInstrName(ty) + + let fn = br.getCurrentModule().getFunction(cstring name) + doAssert not fn.pointer.isNil, "Function '" & name & "' does not exist in the module\n" + + let retTy = br.getContext().struct_t([tyC, ty]) + let fnTy = function_t(retTy, [ty, ty, tyC]) + let sbb = br.call2(fnTy, fn, [a, b, borrowIn], name = "sbb") + sbb.setInstrCallConv(Fast) + let lo = br.extractValue(sbb, 1, name = "sbb.lo") + let bOut = br.extractValue(sbb, 0, name = "sbb.borrow") + return (bOut, lo) + +proc mulExt*(bld: BuilderRef, a, b: ValueRef): tuple[hi, lo: ValueRef] = + ## Extended precision multiplication + ## (hi, lo) <- a*b + let ctx = bld.getContext() + let ty = a.getTypeOf() + let bits = ty.getIntTypeWidth() + let dbl = bits shl 1 + let dblTy = ctx.int_t(dbl) + + let a = bld.zext(a, dblTy, name = "mulx0_") + let b = bld.zext(b, dblTy, name = "mulx1_") + let r = bld.mulNUW(a, b, name = "mulx_") + + let lo = bld.trunc(r, ty, name = "mullo_") + let hi = bld.hi(r, ty, oversize = bits, prefix = "mulhi_") + return (hi, lo) + +proc smulExt*(bld: BuilderRef, a, b: ValueRef): tuple[hi, lo: ValueRef] = + ## Signed extended precision multiplication + ## (hi, lo) <- a*b + let ctx = bld.getContext() + let ty = a.getTypeOf() + let bits = ty.getIntTypeWidth() + let dbl = bits shl 1 + let dblTy = ctx.int_t(dbl) + + let a = bld.sext(a, dblTy, name = "smulx0_") + let b = bld.sext(b, dblTy, name = "smulx1_") + let r = bld.mulNSW(a, b, name = "smulx0_") + + let lo = bld.trunc(r, ty, name = "smullo_") + let hi = bld.hi(r, ty, oversize = bits, prefix = "smulhi_") + return (hi, lo) + +proc muladd1*(bld: BuilderRef, a, b, c: ValueRef): tuple[hi, lo: ValueRef] = + ## Extended precision multiplication + addition + ## (hi, lo) <- a*b + c + ## + ## Note: 0xFFFFFFFF² -> (hi: 0xFFFFFFFE, lo: 0x00000001) + ## so adding any c cannot overflow + let ctx = bld.getContext() + let ty = a.getTypeOf() + let bits = ty.getIntTypeWidth() + let dbl = bits shl 1 + let dblTy = ctx.int_t(dbl) + + let a = bld.zext(a, dblTy, name = "fmax0_") + let b = bld.zext(b, dblTy, name = "fmax1_") + let ab = bld.mulNUW(a, b, name = "fmax01_") + + let c = bld.zext(c, dblTy, name = "fmax2_") + let r = bld.addNUW(ab, c, name = "fmax_") + + let lo = bld.trunc(r, ty, name = "fmalo_") + let hi = bld.hi(r, ty, oversize = bits, prefix = "fmahi_") + return (hi, lo) + +proc muladd2*(bld: BuilderRef, a, b, c1, c2: ValueRef): tuple[hi, lo: ValueRef] = + ## Extended precision multiplication + addition + addition + ## (hi, lo) <- a*b + c1 + c2 + ## + ## Note: 0xFFFFFFFF² -> (hi: 0xFFFFFFFE, lo: 0x00000001) + ## so adding 0xFFFFFFFF leads to (hi: 0xFFFFFFFF, lo: 0x00000000) + ## and we have enough space to add again 0xFFFFFFFF without overflowing + let ctx = bld.getContext() + let ty = a.getTypeOf() + let bits = ty.getIntTypeWidth() + let dbl = bits shl 1 + let dblTy = ctx.int_t(dbl) + + let a = bld.zext(a, dblTy, name = "fmaa0_") + let b = bld.zext(b, dblTy, name = "fmaa1_") + let ab = bld.mulNUW(a, b, name = "fmaa01_") + + let c1 = bld.zext(c1, dblTy, name = "fmaa2_") + let abc1 = bld.addNUW(ab, c1, name = "fmaa012_") + let c2 = bld.zext(c2, dblTy, name = "fmaa3_") + let r = bld.addNUW(abc1, c2, name = "fmaa_") + + let lo = bld.trunc(r, ty, name = "fmaalo_") + let hi = bld.hi(r, ty, oversize = bits, prefix = "fmaahi_") + return (hi, lo) + +proc mulAcc*(bld: BuilderRef, tuv: var ValueRef, a, b: ValueRef) = + ## (t, u, v) <- (t, u, v) + a * b + let ctx = bld.getContext() + + let ty = a.getTypeOf() + let bits = ty.getIntTypeWidth() + + let x3ty = tuv.getTypeOf() + let x3bits = x3ty.getIntTypeWidth() + + doAssert bits * 3 == x3bits + + let dbl = bits shl 1 + let dblTy = ctx.int_t(dbl) + + let a = bld.zext(a, dblTy, name = "mac0_") + let b = bld.zext(b, dblTy, name = "mac1_") + let ab = bld.mulNUW(a, b, name = "mac01_") + + let wide_ab = bld.zext(ab, x3ty, name = "mac01x_") + let r = bld.addNUW(tuv, wide_ab, "mac_") + + tuv = r + +proc mulDoubleAcc*(bld: BuilderRef, tuv: var ValueRef, a, b: ValueRef) = + ## (t, u, v) <- (t, u, v) + 2 * a * b + let ctx = bld.getContext() + + let ty = a.getTypeOf() + let bits = ty.getIntTypeWidth() + + let x3ty = tuv.getTypeOf() + let x3bits = x3ty.getIntTypeWidth() + + doAssert bits * 3 == x3bits + + let dbl = bits shl 1 + let dblTy = ctx.int_t(dbl) + + let a = bld.zext(a, dblTy, name = "macd0_") + let b = bld.zext(b, dblTy, name = "macd1_") + let ab = bld.mulNUW(a, b, name = "macd01_") + + let wide_ab = bld.zext(ab, x3ty, name = "macd01x_") + let r1 = bld.addNUW(tuv, wide_ab, "macdpart_") + let r2 = bld.addNUW(r1, wide_ab, "macd_") + + tuv = r2 diff --git a/research/codegen/x86_instr.nim b/research/codegen/x86_instr.nim deleted file mode 100644 index a4a19219..00000000 --- a/research/codegen/x86_instr.nim +++ /dev/null @@ -1,96 +0,0 @@ -# Constantine -# Copyright (c) 2018-2019 Status Research & Development GmbH -# Copyright (c) 2020-Present Mamy André-Ratsimbazafy -# Licensed and distributed under either of -# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT). -# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0). -# at your option. This file may not be copied, modified, or distributed except according to those terms. - -import - constantine/platforms/abis/c_abi, - constantine/platforms/llvm/llvm, - constantine/platforms/primitives, - constantine/math_compiler/ir, - ./x86_inlineasm - -export x86_inlineasm - -# ############################################################ -# -# x86 API -# -# ############################################################ - -proc defMulExt*(asy: Assembler_LLVM, wordSize: int): FnDef = - - let procName = if wordSize == 64: cstring"hw_mulExt64" - else: cstring"hw_mulExt32" - - let doublePrec_t = if wordSize == 64: asy.i128_t - else: asy.i64_t - - let mulExtTy = if wordSize == 64: function_t(doublePrec_t, [asy.i64_t, asy.i64_t]) - else: function_t(doublePrec_t, [asy.i32_t, asy.i32_t]) - let mulExtKernel = asy.module.addFunction(procName, mulExtTy) - let blck = asy.ctx.appendBasicBlock(mulExtKernel, "mulExtBody") - asy.builder.positionAtEnd(blck) - - let bld = asy.builder - - let a = bld.zext(mulExtKernel.getParam(0), doublePrec_t) - let b = bld.zext(mulExtKernel.getParam(1), doublePrec_t) - let r = bld.mul(a, b) - - bld.ret r - - return (mulExtTy, mulExtKernel) - -proc defHi*(asy: Assembler_LLVM, wordSize: int): FnDef = - - let procName = if wordSize == 64: cstring"hw_hi64" - else: cstring"hw_hi32" - let doublePrec_t = if wordSize == 64: asy.i128_t - else: asy.i64_t - let singlePrec_t = if wordSize == 64: asy.i64_t - else: asy.i32_t - - let hiTy = function_t(singlePrec_t, [doublePrec_t]) - - let hiKernel = asy.module.addFunction(procName, hiTy) - let blck = asy.ctx.appendBasicBlock(hiKernel, "hiBody") - asy.builder.positionAtEnd(blck) - - let bld = asy.builder - - # %1 = zext i32 64 to i128 - let shift = bld.zext(constInt(asy.i32_t, culonglong wordSize, signExtend = LlvmBool(0)), doublePrec_t) - # %hiLarge = lshr i128 %input, %1 - let hiLarge = bld.lshr(hiKernel.getParam(0), shift) - # %hi = trunc i128 %hiLarge to i64 - let hi = bld.trunc(hiLarge, singlePrec_t) - - bld.ret hi - - return (hiTy, hiKernel) - -proc defLo*(asy: Assembler_LLVM, wordSize: int): FnDef = - - let procName = if wordSize == 64: cstring"hw_lo64" - else: cstring"hw_lo32" - let doublePrec_t = if wordSize == 64: asy.i128_t - else: asy.i64_t - let singlePrec_t = if wordSize == 64: asy.i64_t - else: asy.i32_t - - let loTy = function_t(singlePrec_t, [doublePrec_t]) - - let loKernel = asy.module.addFunction(procName, loTy) - let blck = asy.ctx.appendBasicBlock(loKernel, "loBody") - asy.builder.positionAtEnd(blck) - - let bld = asy.builder - - # %lo = trunc i128 %input to i64 - let lo = bld.trunc(loKernel.getParam(0), singlePrec_t) - bld.ret lo - return (loTy, loKernel) diff --git a/research/codegen/x86_poc.nim b/research/codegen/x86_poc.nim index c5c376fe..4ec5b67f 100644 --- a/research/codegen/x86_poc.nim +++ b/research/codegen/x86_poc.nim @@ -8,140 +8,121 @@ import constantine/platforms/llvm/llvm, - constantine/platforms/primitives, - constantine/math_compiler/ir, - ./x86_instr - -echo "LLVM JIT compiler: Multiplication with MULX/ADOX/ADCX" - -proc big_mul_gen(asy: Assembler_LLVM): FnDef = - - - let procName = "big_mul_64x4" - let N = 4 - let ty = array_t(asy.i64_t, N) - let pty = pointer_t(ty) - - let bigMulTy = function_t(asy.void_t, [pty, pty, pty]) - let bigMulKernel = asy.module.addFunction(cstring procName, bigMulTy) - let blck = asy.ctx.appendBasicBlock(bigMulKernel, "bigMulBody") - asy.builder.positionAtEnd(blck) - - let bld = asy.builder - - let (hiTy, hiKernel) = asy.defHi(64) - proc hi(builder: BuilderRef, a: ValueRef): ValueRef = - return builder.call2( - hiTy, hiKernel, - [a], "hi64_" - ) - - let (loTy, loKernel) = asy.defLo(64) - proc lo(builder: BuilderRef, a: ValueRef): ValueRef = - return builder.call2( - loTy, loKernel, - [a], "lo64_" - ) - - let (mulExtTy, mulExtKernel) = asy.defMulExt(64) - bld.positionAtEnd(blck) - - proc mulx(builder: BuilderRef, a, b: ValueRef): tuple[hi, lo: ValueRef] = - # LLVM does not support multipel return value at the moment - # https://nondot.org/sabre/LLVMNotes/MultipleReturnValues.txt - # So we don't create an LLVM function - let t = builder.call2( - mulExtTy, mulExtKernel, - [a, b], "mulx64_" - ) - - builder.positionAtEnd(blck) - let lo = builder.lo(t) - let hi = builder.hi(t) - return (hi, lo) - - let r = bld.asArray(bigMulKernel.getParam(0), ty) - let a = bld.asArray(bigMulKernel.getParam(1), ty) - let b = bld.asArray(bigMulKernel.getParam(2), ty) - - let t = bld.makeArray(ty) - - block: # i = 0 - # TODO: properly implement add/adc in pure LLVM - - # TODO: ensure flags are cleared properly, compiler might optimize this away - t[0] = bld.`xor`(t[0], t[0]) - let (hi, lo) = bld.mulx(a[0], b[0]) - r[0] = lo - t[0] = hi - - for j in 1 ..< N: - let (hi , lo) = bld.mulx(a[j], b[0]) - t[j] = hi - # SHOWSTOPPER: LLVM ERROR: Inline asm not supported by this streamer because we don't have an asm parser for this target - discard bld.adcx_rr(t[j-1], lo) # Replace by LLVM IR uadd_with_overflow - - # SHOWSTOPPER: LLVM ERROR: Inline asm not supported by this streamer because we don't have an asm parser for this target - discard bld.adcx_rr(t[N-1], 0) - - # TODO: rotate t array - - # TODO: impl i in 1 ..< N - - bld.store(r, t) - bld.retVoid() - return (bigMulTy, bigMulKernel) - -when isMainModule: - # It's not the Nvidia PTX backend but it's fine + constantine/math_compiler/[ir, pub_fields] + +const Fields = [ + ( + "bn254_snarks_fp", 254, + "30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47" + ), + ( + "bn254_snarks_fr", 254, + "30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001" + ), + + ( + "secp256k1_fp", 256, + "fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f" + ), + ( + "secp256k1_fr", 256, + "fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141" + ), + ( + "bls12_381_fp", 381, + "1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaab" + ), + ( + "bls12_381_fr", 255, + "73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001" + ), + ( + "bls12_377_fp", 377, + "01ae3a4617c510eac63b05c06ca1493b1a22d9f300f5138f1ef3622fba094800170b5d44300000008508c00000000001" + ), + ( + "bls12_377_fr", 253, + "12ab655e9a2ca55660b44d1e5c37b00159aa76fed00000010a11800000000001" + ), + ( + "bls24_315_fp", 315, + "4c23a02b586d650d3f7498be97c5eafdec1d01aa27a1ae0421ee5da52bde5026fe802ff40300001" + ), + ( + "bls12_315_fr", 253, + "196deac24a9da12b25fc7ec9cf927a98c8c480ece644e36419d0c5fd00c00001" + ), + ( + "bls24_317_fp", 317, + "1058CA226F60892CF28FC5A0B7F9D039169A61E684C73446D6F339E43424BF7E8D512E565DAB2AAB" + ), + ( + "bls12_317_fr", 255, + "443F917EA68DAFC2D0B097F28D83CD491CD1E79196BF0E7AF000000000000001" + ), +] + +proc t_field_add() = let asy = Assembler_LLVM.new(bkX86_64_Linux, cstring("x86_poc")) - let bigMul = asy.big_mul_gen() + for F in Fields: + let fd = asy.ctx.configureField( + F[0], F[1], F[2], + v = 1, w = 64) - asy.module.verify(AbortProcessAction) + asy.definePrimitives(fd) + + discard asy.genFpAdd(fd) echo "=========================================" - echo "LLVM IR\n" + echo "LLVM IR unoptimized\n" echo asy.module echo "=========================================" - - var engine: ExecutionEngineRef - initializeFullNativeTarget() - createJITCompilerForModule(engine, asy.module, optLevel = 0) - - let jitMul = cast[proc(r: var array[4, uint64], a, b: array[4, uint64]){.noconv.}]( - engine.getFunctionAddress("big_mul_64x4") - ) - - var r: array[4, uint64] - r.jitMul([uint64 1, 2, 3, 4], [uint64 1, 1, 1, 1]) - echo "jitMul = ", r - - # block: - # Cleanup - Assembler_LLVM is auto-managed - # engine.dispose() # also destroys the module attached to it, which double_frees Assembler_LLVM asy.module - echo "LLVM JIT - calling big_mul_64x4 SUCCESS" + asy.module.verify(AbortProcessAction) # -------------------------------------------- - # See the assembly- note it might be different from what the JIT compiler did - + # See the assembly - note it might be different from what the JIT compiler did + initializeFullNativeTarget() const triple = "x86_64-pc-linux-gnu" let machine = createTargetMachine( target = toTarget(triple), triple = triple, cpu = "", - features = "adx,bmi2", # TODO check the proper way to pass options - level = CodeGenLevelAggressive, + features = "", # "adx,bmi2", # TODO check the proper way to pass options + level = CodeGenLevelDefault, reloc = RelocDefault, codeModel = CodeModelDefault ) + # Due to https://github.com/llvm/llvm-project/issues/102868 + # We want to reproduce the codegen from llc.cpp + # However we can't reproduce the code from either + # - LLVM16 https://github.com/llvm/llvm-project/blob/llvmorg-16.0.6/llvm/tools/llc/llc.cpp + # need legacy PassManagerRef and the PassManagerBuilder that interfaces between the + # legacy PssManagerRef and new PassBuilder has been deleted in LLVM17 + # + # - and contrary to what is claimed in https://llvm.org/docs/NewPassManager.html#id2 + # the C API of PassBuilderRef is ghost town. + # + # So we somewhat reproduce the optimization passes from + # https://reviews.llvm.org/D145835 + let pbo = createPassBuilderOptions() pbo.setMergeFunctions() let err = asy.module.runPasses( - "default,function-attrs,memcpyopt,sroa,mem2reg,gvn,dse,instcombine,inline,adce", + # "default,memcpyopt,sroa,mem2reg,function-attrs,inline,gvn,dse,aggressive-instcombine,adce", + "function(require,require,require,require,require)" & + ",function(aa-eval)" & + ",always-inline,hotcoldsplit,inferattrs,instrprof,recompute-globalsaa" & + ",cgscc(argpromotion,function-attrs)" & + ",require,partial-inliner,called-value-propagation" & + ",scc-oz-module-inliner,module-inline" & # Buggy optimization + ",function(verify,loop-mssa(loop-reduce),mergeicmps,expand-memcmp,instsimplify)" & + ",function(lower-constant-intrinsics,consthoist,partially-inline-libcalls,ee-instrument,scalarize-masked-mem-intrin,verify)" & + ",memcpyopt,sroa,dse,aggressive-instcombine,gvn,ipsccp,deadargelim,adce" & + "", machine, pbo ) @@ -154,174 +135,35 @@ when isMainModule: quit 1 echo "=========================================" - echo "Assembly\n" + echo "LLVM IR optimized\n" - echo machine.emitTo[:string](asy.module, AssemblyFile) + echo asy.module echo "=========================================" - # Output - # ------------------------------------------------------------------ - - #[ - LLVM JIT compiler: Multiplication with MULX/ADOX/ADCX - ========================================= - LLVM IR - - ; ModuleID = 'x86_poc' - source_filename = "x86_poc" - target triple = "x86_64-pc-linux-gnu" - - define void @big_mul_64x4(ptr %0, ptr %1, ptr %2) { - bigMulBody: - %3 = alloca [4 x i64], align 8 - %4 = getelementptr inbounds [4 x i64], ptr %3, i32 0, i32 0 - %5 = load i64, ptr %4, align 4 - %6 = getelementptr inbounds [4 x i64], ptr %3, i32 0, i32 0 - %7 = load i64, ptr %6, align 4 - %8 = xor i64 %5, %7 - %9 = getelementptr inbounds [4 x i64], ptr %3, i32 0, i32 0 - store i64 %8, ptr %9, align 4 - %10 = getelementptr inbounds [4 x i64], ptr %1, i32 0, i32 0 - %11 = load i64, ptr %10, align 4 - %12 = getelementptr inbounds [4 x i64], ptr %2, i32 0, i32 0 - %13 = load i64, ptr %12, align 4 - %mulx64_ = call i128 @hw_mulExt64(i64 %11, i64 %13) - %lo64_ = call i64 @hw_lo64(i128 %mulx64_) - %hi64_ = call i64 @hw_hi64(i128 %mulx64_) - %14 = getelementptr inbounds [4 x i64], ptr %0, i32 0, i32 0 - store i64 %lo64_, ptr %14, align 4 - %15 = getelementptr inbounds [4 x i64], ptr %3, i32 0, i32 0 - store i64 %hi64_, ptr %15, align 4 - %16 = getelementptr inbounds [4 x i64], ptr %1, i32 0, i32 1 - %17 = load i64, ptr %16, align 4 - %18 = getelementptr inbounds [4 x i64], ptr %2, i32 0, i32 0 - %19 = load i64, ptr %18, align 4 - %mulx64_1 = call i128 @hw_mulExt64(i64 %17, i64 %19) - %lo64_2 = call i64 @hw_lo64(i128 %mulx64_1) - %hi64_3 = call i64 @hw_hi64(i128 %mulx64_1) - %20 = getelementptr inbounds [4 x i64], ptr %3, i32 0, i32 1 - store i64 %hi64_3, ptr %20, align 4 - %21 = getelementptr inbounds [4 x i64], ptr %3, i32 0, i32 0 - %22 = load i64, ptr %21, align 4 - %23 = call i64 asm "adcxq %2, %0;", "=r,%0,r"(i64 %22, i64 %lo64_2) - %24 = getelementptr inbounds [4 x i64], ptr %1, i32 0, i32 2 - %25 = load i64, ptr %24, align 4 - %26 = getelementptr inbounds [4 x i64], ptr %2, i32 0, i32 0 - %27 = load i64, ptr %26, align 4 - %mulx64_4 = call i128 @hw_mulExt64(i64 %25, i64 %27) - %lo64_5 = call i64 @hw_lo64(i128 %mulx64_4) - %hi64_6 = call i64 @hw_hi64(i128 %mulx64_4) - %28 = getelementptr inbounds [4 x i64], ptr %3, i32 0, i32 2 - store i64 %hi64_6, ptr %28, align 4 - %29 = getelementptr inbounds [4 x i64], ptr %3, i32 0, i32 1 - %30 = load i64, ptr %29, align 4 - %31 = call i64 asm "adcxq %2, %0;", "=r,%0,r"(i64 %30, i64 %lo64_5) - %32 = getelementptr inbounds [4 x i64], ptr %1, i32 0, i32 3 - %33 = load i64, ptr %32, align 4 - %34 = getelementptr inbounds [4 x i64], ptr %2, i32 0, i32 0 - %35 = load i64, ptr %34, align 4 - %mulx64_7 = call i128 @hw_mulExt64(i64 %33, i64 %35) - %lo64_8 = call i64 @hw_lo64(i128 %mulx64_7) - %hi64_9 = call i64 @hw_hi64(i128 %mulx64_7) - %36 = getelementptr inbounds [4 x i64], ptr %3, i32 0, i32 3 - store i64 %hi64_9, ptr %36, align 4 - %37 = getelementptr inbounds [4 x i64], ptr %3, i32 0, i32 2 - %38 = load i64, ptr %37, align 4 - %39 = call i64 asm "adcxq %2, %0;", "=r,%0,r"(i64 %38, i64 %lo64_8) - %40 = getelementptr inbounds [4 x i64], ptr %3, i32 0, i32 3 - %41 = load i64, ptr %40, align 4 - %42 = call i64 asm "adcxq %2, %0;", "=r,%0,r"(i64 %41, i64 0) - %43 = load [4 x i64], ptr %3, align 4 - store [4 x i64] %43, ptr %0, align 4 - ret void - } - - define i64 @hw_hi64(i128 %0) { - hiBody: - %1 = lshr i128 %0, 64 - %2 = trunc i128 %1 to i64 - ret i64 %2 - } - - define i64 @hw_lo64(i128 %0) { - loBody: - %1 = trunc i128 %0 to i64 - ret i64 %1 - } - - define i128 @hw_mulExt64(i64 %0, i64 %1) { - mulExtBody: - %2 = zext i64 %0 to i128 - %3 = zext i64 %1 to i128 - %4 = mul i128 %2, %3 - ret i128 %4 - } + echo "=========================================" + echo "Assembly\n" - ========================================= - jitMul = [0, 0, 0, 0] - LLVM JIT - calling big_mul_64x4 SUCCESS - ========================================= - Assembly + echo machine.emitTo[:string](asy.module, AssemblyFile) + echo "=========================================" - .text - .file "x86_poc" - .globl big_mul_64x4 - .p2align 4, 0x90 - .type big_mul_64x4,@function - big_mul_64x4: - .cfi_startproc - movq %rdx, %rcx - movq (%rdx), %rax - mulq (%rsi) - movq %rdx, %r8 - movq %rax, (%rdi) - movq (%rcx), %rcx - movq %rcx, %rax - mulq 8(%rsi) - movq %rdx, %r9 - movq %rcx, %rax - mulq 16(%rsi) - movq %rdx, %r10 - movq %rcx, %rax - mulq 24(%rsi) - movq %r8, (%rdi) - movq %r9, 8(%rdi) - movq %r10, 16(%rdi) - movq %rdx, 24(%rdi) - retq - .Lfunc_end0: - .size big_mul_64x4, .Lfunc_end0-big_mul_64x4 - .cfi_endproc + # var engine: ExecutionEngineRef + # initializeFullNativeTarget() + # createJITCompilerForModule(engine, asy.module, optLevel = 3) - .globl hw_hi64 - .p2align 4, 0x90 - .type hw_hi64,@function - hw_hi64: - movq %rsi, %rax - retq - .Lfunc_end1: - .size hw_hi64, .Lfunc_end1-hw_hi64 + # let fn32 = cm32.genSymbol(opFpAdd) + # let fn64 = cm64.genSymbol(opFpAdd) - .globl hw_lo64 - .p2align 4, 0x90 - .type hw_lo64,@function - hw_lo64: - movq %rdi, %rax - retq - .Lfunc_end2: - .size hw_lo64, .Lfunc_end2-hw_lo64 + # let jitFpAdd64 = cast[proc(r: var array[4, uint64], a, b: array[4, uint64]){.noconv.}]( + # engine.getFunctionAddress(cstring fn64) + # ) - .globl hw_mulExt64 - .p2align 4, 0x90 - .type hw_mulExt64,@function - hw_mulExt64: - movq %rsi, %rax - mulq %rdi - retq - .Lfunc_end3: - .size hw_mulExt64, .Lfunc_end3-hw_mulExt64 + # var r: array[4, uint64] + # r.jitFpAdd64([uint64 1, 2, 3, 4], [uint64 1, 1, 1, 1]) + # echo "jitFpAdd64 = ", r - .section ".note.GNU-stack","",@progbits + # # block: + # # Cleanup - Assembler_LLVM is auto-managed + # # engine.dispose() # also destroys the module attached to it, which double_frees Assembler_LLVM asy.module + # echo "LLVM JIT - calling FpAdd64 SUCCESS" - ========================================= - ]# +t_field_add() diff --git a/tests/gpu/t_nvidia_fp.nim b/tests/gpu/t_nvidia_fp.nim index b3aa873f..af4de007 100644 --- a/tests/gpu/t_nvidia_fp.nim +++ b/tests/gpu/t_nvidia_fp.nim @@ -70,9 +70,9 @@ proc t_field_add(curve: static Algebra) = # Codegen # ------------------------- let asy = Assembler_LLVM.new(bkNvidiaPTX, cstring("t_nvidia_" & $curve)) - let cm32 = CurveMetadata.init(asy, curve, size32) + let cm32 = CurveMetadata.init(asy, curve, w32) asy.genFieldAddPTX(cm32) - let cm64 = CurveMetadata.init(asy, curve, size64) + let cm64 = CurveMetadata.init(asy, curve, w64) asy.genFieldAddPTX(cm64) let ptx = asy.codegenNvidiaPTX(sm) @@ -124,9 +124,9 @@ proc t_field_sub(curve: static Algebra) = # Codegen # ------------------------- let asy = Assembler_LLVM.new(bkNvidiaPTX, cstring("t_nvidia_" & $curve)) - let cm32 = CurveMetadata.init(asy, curve, size32) + let cm32 = CurveMetadata.init(asy, curve, w32) asy.genFieldSubPTX(cm32) - let cm64 = CurveMetadata.init(asy, curve, size64) + let cm64 = CurveMetadata.init(asy, curve, w64) asy.genFieldSubPTX(cm64) let ptx = asy.codegenNvidiaPTX(sm) @@ -178,14 +178,14 @@ proc t_field_mul(curve: static Algebra) = # Codegen # ------------------------- let asy = Assembler_LLVM.new(bkNvidiaPTX, cstring("t_nvidia_" & $curve)) - let cm32 = CurveMetadata.init(asy, curve, size32) + let cm32 = CurveMetadata.init(asy, curve, w32) asy.genFieldMulPTX(cm32) # 64-bit integer fused-multiply-add with carry is buggy: # https://gist.github.com/mratsim/a34df1e091925df15c13208df7eda569#file-mul-py # https://forums.developer.nvidia.com/t/incorrect-result-of-ptx-code/221067 - # let cm64 = CurveMetadata.init(asy, curve, size64) + # let cm64 = CurveMetadata.init(asy, curve, w64) # asy.genFieldMulPTX(cm64) let ptx = asy.codegenNvidiaPTX(sm)