From 112ab49eec41dad67651e89c6881610f87f3e46c Mon Sep 17 00:00:00 2001 From: Mamy Ratsimbazafy Date: Thu, 22 Aug 2024 10:52:26 +0200 Subject: [PATCH] failed attempt at solving towering / base field perf discrepancy, similar to https://github.com/mratsim/constantine/issues/446 --- .../assembly/limbs_asm_modular_x86.nim | 11 ++ .../extension_fields/assembly/fp2_asm_x86.nim | 85 ++++++++++ .../assembly/fp2_asm_x86_adx_bmi2.nim | 4 +- constantine/math/extension_fields/towers.nim | 158 +++++++++++------- constantine/platforms/ast_rebuilder.nim | 2 +- constantine/platforms/static_for.nim | 10 ++ .../x86/macro_assembler_x86_intel.nim | 4 +- 7 files changed, 206 insertions(+), 68 deletions(-) create mode 100644 constantine/math/extension_fields/assembly/fp2_asm_x86.nim diff --git a/constantine/math/arithmetic/assembly/limbs_asm_modular_x86.nim b/constantine/math/arithmetic/assembly/limbs_asm_modular_x86.nim index 0db85007..929b1dee 100644 --- a/constantine/math/arithmetic/assembly/limbs_asm_modular_x86.nim +++ b/constantine/math/arithmetic/assembly/limbs_asm_modular_x86.nim @@ -168,6 +168,9 @@ macro addmod_gen[N: static int](r_PIR: var Limbs[N], a_PIR, b_PIR, M_MEM: Limbs[ ctx.finalSubMayOverflowImpl(r, u, M, v, a_in_scratch = true, scratchReg = b.reuseRegister()) result.add ctx.generate() + return nnkBlockStmt.newTree( + newEmptyNode(), + result) func addmod_asm*(r: var Limbs, a, b, M: Limbs, spareBits: static int) = ## Constant-time modular addition @@ -228,6 +231,10 @@ macro submod_gen[N: static int](r_PIR: var Limbs[N], a_PIR, b_PIR, M_MEM: Limbs[ result.add ctx.generate() + return nnkBlockStmt.newTree( + newEmptyNode(), + result) + func submod_asm*(r: var Limbs, a, b, M: Limbs) = ## Constant-time modular substraction ## Warning, does not handle aliasing of a and b @@ -277,6 +284,10 @@ macro negmod_gen[N: static int](r_PIR: var Limbs[N], a_MEM, M_MEM: Limbs[N]): un result.add ctx.generate() + return nnkBlockStmt.newTree( + newEmptyNode(), + result) + func negmod_asm*(r: var Limbs, a, M: Limbs) = ## Constant-time modular negation negmod_gen(r, a, M) diff --git a/constantine/math/extension_fields/assembly/fp2_asm_x86.nim b/constantine/math/extension_fields/assembly/fp2_asm_x86.nim new file mode 100644 index 00000000..7ae107d6 --- /dev/null +++ b/constantine/math/extension_fields/assembly/fp2_asm_x86.nim @@ -0,0 +1,85 @@ +# Constantine +# Copyright (c) 2018-2019 Status Research & Development GmbH +# Copyright (c) 2020-Present Mamy André-Ratsimbazafy +# Licensed and distributed under either of +# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT). +# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0). +# at your option. This file may not be copied, modified, or distributed except according to those terms. + +import + # Internal + constantine/platforms/abstractions, + constantine/named/algebras, + constantine/math/arithmetic + +import constantine/math/arithmetic/assembly/limbs_asm_modular_x86 {.all.} + +# ############################################################ +# # +# Assembly implementation of 𝔽p2 # +# # +# ############################################################ + +static: doAssert UseASM_X86_64 + +# No exceptions allowed +{.push raises: [].} + +# 𝔽p2 addition law +# ------------------------------------------------------------ + +template aliasPtr(coord, name: untyped): untyped = + # The *_gen macros get confused by bracket [] and dot `.` expressions + # when deriving names so create aliases + # Furthermore the C compiler requires asm inputs to be lvalues + # and arrays should be passed as pointers (aren't they aren't if we use a dot expression) + let name {.inject.} = coord.mres.limbs.unsafeAddr() + +func fp2_add_asm*( + r: var array[2, Fp], + a, b: array[2, Fp]) = + ## Addition on Fp2 + # This specialized proc inline calls and limits data movement (for example register pop/push) + const spareBits = Fp.getSpareBits() + + aliasPtr r[0], r0 + aliasPtr r[1], r1 + aliasPtr a[0], a0 + aliasPtr a[1], a1 + aliasPtr b[0], b0 + aliasPtr b[1], b1 + let p = Fp.getModulus().limbs.unsafeAddr() + + addmod_gen(r0[], a0[], b0[], p[], spareBits) + addmod_gen(r1[], a1[], b1[], p[], spareBits) + +func fp2_sub_asm*( + r: var array[2, Fp], + a, b: array[2, Fp]) = + ## Substraction on Fp2 + # This specialized proc inline calls and limits data movement (for example register pop/push) + aliasPtr r[0], r0 + aliasPtr r[1], r1 + aliasPtr a[0], a0 + aliasPtr a[1], a1 + aliasPtr b[0], b0 + aliasPtr b[1], b1 + let p = Fp.getModulus().limbs.unsafeAddr() + + submod_gen(r0[], a0[], b0[], p[]) + submod_gen(r1[], a1[], b1[], p[]) + +func fp2_neg_asm*( + r: var array[2, Fp], + a: array[2, Fp]) = + ## Negation on Fp2 + # This specialized proc inline calls and limits data movement (for example register pop/push) + + aliasPtr r[0], r0 + aliasPtr r[1], r1 + aliasPtr a[0], a0 + aliasPtr a[1], a1 + let p = Fp.getModulus().limbs.unsafeAddr() + + negmod_gen(r0[], a0[], p[]) + negmod_gen(r1[], a1[], p[]) diff --git a/constantine/math/extension_fields/assembly/fp2_asm_x86_adx_bmi2.nim b/constantine/math/extension_fields/assembly/fp2_asm_x86_adx_bmi2.nim index ac25e67e..25cc623c 100644 --- a/constantine/math/extension_fields/assembly/fp2_asm_x86_adx_bmi2.nim +++ b/constantine/math/extension_fields/assembly/fp2_asm_x86_adx_bmi2.nim @@ -29,9 +29,9 @@ static: doAssert UseASM_X86_64 # No exceptions allowed {.push raises: [].} -template c0*(a: array): auto = +template c0(a: array): auto = a[0] -template c1*(a: array): auto = +template c1(a: array): auto = a[1] func has1extraBit(F: type Fp): bool = diff --git a/constantine/math/extension_fields/towers.nim b/constantine/math/extension_fields/towers.nim index 20fc0d51..ebe139b4 100644 --- a/constantine/math/extension_fields/towers.nim +++ b/constantine/math/extension_fields/towers.nim @@ -16,7 +16,7 @@ export Fp when UseASM_X86_64: import - ./assembly/fp2_asm_x86_adx_bmi2 + ./assembly/[fp2_asm_x86, fp2_asm_x86_adx_bmi2] # Note: to avoid burdening the Nim compiler, we rely on generic extension # to complain if the base field procedures don't exist @@ -57,8 +57,8 @@ type CubicExt[Fp2[Name]] Fp12*[Name: static Algebra] = - CubicExt[Fp4[Name]] - # QuadraticExt[Fp6[Name]] + # CubicExt[Fp4[Name]] + QuadraticExt[Fp6[Name]] template c0*(a: ExtensionField): auto = a.coords[0] @@ -80,6 +80,56 @@ template Name*(E: type ExtensionField): Algebra = template getModulus*(E: type ExtensionField): auto = E.F.getModulus() +# ############################################################ +# # +# Cost functions # +# # +# ############################################################ + +func prefer_3sqr_over_2mul(F: type ExtensionField): bool {.compileTime.} = + ## Returns true + ## if time(3sqr) < time(2mul) in the extension fields + + let a = default(F) + # No shortcut in the VM + when a.c0 is Fp12: + # Benchmarked on BLS12-381 + when a.c0.c0 is Fp6: + return true + elif a.c0.c0 is Fp4: + return false + else: return false + else: return false + +func has_large_NR_norm(Name: static Algebra): bool = + ## Returns true if the non-residue of the extension fields + ## has a large norm + + const j = Name.getNonResidueFp() + const u = Name.getNonResidueFp2()[0] + const v = Name.getNonResidueFp2()[1] + + const norm2 = u*u + (j*v)*(j*v) + + # Compute integer square root + var norm = 0 + while (norm+1) * (norm+1) <= norm2: + norm += 1 + + return norm > 5 + +func has_large_field_elem*(Name: static Algebra): bool = + ## Returns true if field element are large + ## and necessitate custom routine for assembly in particular + let a = default(Fp[Name]) + return a.mres.limbs.len > 6 + +# ############################################################ +# # +# Implementation # +# # +# ############################################################ + # Initialization # ------------------------------------------------------------------- @@ -148,36 +198,56 @@ func ccopy*(a: var ExtensionField, b: ExtensionField, ctl: SecretBool) = # Abelian group # ------------------------------------------------------------------- +func hasFp2x86asm(T: type ExtensionField): bool = + T is Fp2 and UseASM_X86_64 and not T.Name.has_large_field_elem() func neg*(r: var ExtensionField, a: ExtensionField) = ## Field out-of-place negation - staticFor i, 0, a.coords.len: - r.coords[i].neg(a.coords[i]) + when a.typeof().hasFp2x86asm(): + r.coords.fp2_neg_asm(a.coords) + else: + staticFor i, 0, a.coords.len: + r.coords[i].neg(a.coords[i]) func neg*(a: var ExtensionField) = ## Field in-place negation - staticFor i, 0, a.coords.len: - a.coords[i].neg() + when a.typeof().hasFp2x86asm(): + a.coords.fp2_neg_asm(a.coords) + else: + staticFor i, 0, a.coords.len: + a.coords[i].neg() func `+=`*(a: var ExtensionField, b: ExtensionField) = ## Addition in the extension field - staticFor i, 0, a.coords.len: - a.coords[i] += b.coords[i] + when a.typeof().hasFp2x86asm(): + a.coords.fp2_add_asm(a.coords, b.coords) + else: + staticFor i, 0, a.coords.len: + a.coords[i] += b.coords[i] func `-=`*(a: var ExtensionField, b: ExtensionField) = ## Substraction in the extension field - staticFor i, 0, a.coords.len: - a.coords[i] -= b.coords[i] + when a.typeof().hasFp2x86asm(): + a.coords.fp2_sub_asm(a.coords, b.coords) + else: + staticFor i, 0, a.coords.len: + a.coords[i] -= b.coords[i] func double*(r: var ExtensionField, a: ExtensionField) = ## Field out-of-place doubling - staticFor i, 0, a.coords.len: - r.coords[i].double(a.coords[i]) + when a.typeof().hasFp2x86asm(): + r.coords.fp2_add_asm(a.coords, a.coords) + else: + staticFor i, 0, a.coords.len: + r.coords[i].double(a.coords[i]) func double*(a: var ExtensionField) = ## Field in-place doubling - staticFor i, 0, a.coords.len: - a.coords[i].double() + when a.typeof().hasFp2x86asm(): + a.coords.fp2_add_asm(a.coords, a.coords) + else: + staticFor i, 0, a.coords.len: + a.coords[i].double() func div2*(a: var ExtensionField) = ## Field in-place division by 2 @@ -186,13 +256,19 @@ func div2*(a: var ExtensionField) = func sum*(r: var ExtensionField, a, b: ExtensionField) = ## Sum ``a`` and ``b`` into ``r`` - staticFor i, 0, a.coords.len: - r.coords[i].sum(a.coords[i], b.coords[i]) + when a.typeof().hasFp2x86asm(): + r.coords.fp2_add_asm(a.coords, b.coords) + else: + staticFor i, 0, a.coords.len: + r.coords[i].sum(a.coords[i], b.coords[i]) func diff*(r: var ExtensionField, a, b: ExtensionField) = ## Diff ``a`` and ``b`` into ``r`` - staticFor i, 0, a.coords.len: - r.coords[i].diff(a.coords[i], b.coords[i]) + when a.typeof().hasFp2x86asm(): + r.coords.fp2_sub_asm(a.coords, b.coords) + else: + staticFor i, 0, a.coords.len: + r.coords[i].diff(a.coords[i], b.coords[i]) func conj*(a: var QuadraticExt) = ## Computes the conjugate in-place @@ -692,50 +768,6 @@ func prod2x*( {.pop.} # inline -# ############################################################ -# # -# Cost functions # -# # -# ############################################################ - -func prefer_3sqr_over_2mul(F: type ExtensionField): bool {.compileTime.} = - ## Returns true - ## if time(3sqr) < time(2mul) in the extension fields - - let a = default(F) - # No shortcut in the VM - when a.c0 is Fp12: - # Benchmarked on BLS12-381 - when a.c0.c0 is Fp6: - return true - elif a.c0.c0 is Fp4: - return false - else: return false - else: return false - -func has_large_NR_norm(Name: static Algebra): bool = - ## Returns true if the non-residue of the extension fields - ## has a large norm - - const j = Name.getNonResidueFp() - const u = Name.getNonResidueFp2()[0] - const v = Name.getNonResidueFp2()[1] - - const norm2 = u*u + (j*v)*(j*v) - - # Compute integer square root - var norm = 0 - while (norm+1) * (norm+1) <= norm2: - norm += 1 - - return norm > 5 - -func has_large_field_elem*(Name: static Algebra): bool = - ## Returns true if field element are large - ## and necessitate custom routine for assembly in particular - let a = default(Fp[Name]) - return a.mres.limbs.len > 6 - # ############################################################ # # # Quadratic extensions # diff --git a/constantine/platforms/ast_rebuilder.nim b/constantine/platforms/ast_rebuilder.nim index 2c5347a9..d5866260 100644 --- a/constantine/platforms/ast_rebuilder.nim +++ b/constantine/platforms/ast_rebuilder.nim @@ -74,4 +74,4 @@ proc rebuildUntypedAst*(ast: NimNode, dropRootStmtList = false): NimNode = if dropRootStmtList and ast.kind == nnkStmtList: return rebuild(ast[0]) else: - result = rebuild(ast) \ No newline at end of file + result = rebuild(ast) diff --git a/constantine/platforms/static_for.nim b/constantine/platforms/static_for.nim index c83ba363..8c492bed 100644 --- a/constantine/platforms/static_for.nim +++ b/constantine/platforms/static_for.nim @@ -20,6 +20,16 @@ proc replaceNodes(ast: NimNode, what: NimNode, by: NimNode): NimNode = return node of nnkLiterals: return node + + # Rebuild untyped AST + # -------------------- + of nnkHiddenStdConv: + if node[1].kind == nnkIntLit: + return node[1] + else: + expectKind(node[1], nnkSym) + return ident($node[1]) + # -------------------- else: var rTree = node.kind.newTree() for child in node: diff --git a/constantine/platforms/x86/macro_assembler_x86_intel.nim b/constantine/platforms/x86/macro_assembler_x86_intel.nim index 357c0953..cfa640eb 100644 --- a/constantine/platforms/x86/macro_assembler_x86_intel.nim +++ b/constantine/platforms/x86/macro_assembler_x86_intel.nim @@ -130,7 +130,7 @@ const OutputReg = {asmOutputEarlyClobber, asmInputOutput, asmInputOutputEarlyClo func toString*(nimSymbol: NimNode): string = # We need to dereference the hidden pointer of var param - let isPtr = nimSymbol.kind in {nnkHiddenDeref, nnkPtrTy} + let isPtr = nimSymbol.kind in {nnkHiddenDeref, nnkPtrTy, nnkDerefExpr} let isAddr = nimSymbol.kind in {nnkInfix, nnkCall} and (nimSymbol[0].eqIdent"addr" or nimSymbol[0].eqIdent"unsafeAddr") let nimSymbol = if isPtr: nimSymbol[0] @@ -432,7 +432,7 @@ func generate*(a: Assembler_x86): NimNode = params.add newLit(": ") & inOperands.foldl(a & newLit(", ") & b) & newLit("\n") else: params.add newLit(":\n") - + let clobbers = [(a.isStackClobbered, "sp"), (a.areFlagsClobbered, "cc"), (memClobbered, "memory")]