From b0a95d1dbb52ed902e854d21d79797abb3d90f40 Mon Sep 17 00:00:00 2001
From: Hanno Becker
Date: Mon, 2 Dec 2024 05:12:46 +0000
Subject: [PATCH] AVX2: Document bounds in NTT and invNTT

This commit documents the AVX2 assembly for the [inv]NTT in the x86_64
native backend. In particular, it tracks the bounds of data through the
various layers.

The invNTT implementation is highly aggressive in terms of minimizing
the number of modular reductions, which makes a pen-and-paper analysis
rather difficult. At the last invNTT layer, the bounds reasoning applied
is not enough to show absence of overflow. To remedy, an additional
reduction is added.

The analysis for the forward NTT is straightforward.

Ultimately, we confirm the absolute bound of 8*MLKEM_Q for the output of
forward and inverse NTT; this is the contractual bound that the
higher-level C-code is working with.

Signed-off-by: Hanno Becker
---
 mlkem/native/x86_64/fq.inc      |   2 +
 mlkem/native/x86_64/intt.S      | 138 ++++++++++++++++++++------------
 mlkem/native/x86_64/ntt.S       |  59 +++++++++-----
 mlkem/native/x86_64/shuffle.inc |  11 ++-
 4 files changed, 135 insertions(+), 75 deletions(-)

diff --git a/mlkem/native/x86_64/fq.inc b/mlkem/native/x86_64/fq.inc
index 3030fa0bc..826fb47cf 100644
--- a/mlkem/native/x86_64/fq.inc
+++ b/mlkem/native/x86_64/fq.inc
@@ -28,6 +28,8 @@ vpand %ymm0,%ymm\x,%ymm\x
 vpaddw %ymm\x,%ymm\r,%ymm\r
 .endm

+// Montgomery multiplication between b and ah,
+// with Montgomery twist of ah in al.
 .macro fqmulprecomp al,ah,b,x=12
 vpmullw %ymm\al,%ymm\b,%ymm\x
 vpmulhw %ymm\ah,%ymm\b,%ymm\b
diff --git a/mlkem/native/x86_64/intt.S b/mlkem/native/x86_64/intt.S
index b25a1e3dc..afad8977f 100644
--- a/mlkem/native/x86_64/intt.S
+++ b/mlkem/native/x86_64/intt.S
@@ -13,50 +13,47 @@ .include "shuffle.inc"
 .include "fq.inc"

+// Compute four GS butterflies between rh{0,1,2,3} and rl{0,1,2,3}.
+// Butterflies 0,1 use root zh0 and twisted root zl0, and butterflies
+// 2,3 use root zh1 and twisted root zl1.
+// Results are again in rl{0-3} and rh{0-3}.
 .macro butterfly rl0,rl1,rl2,rl3,rh0,rh1,rh2,rh3,zl0=2,zl1=2,zh0=3,zh1=3
-vpsubw %ymm\rl0,%ymm\rh0,%ymm12
-vpaddw %ymm\rh0,%ymm\rl0,%ymm\rl0
-vpsubw %ymm\rl1,%ymm\rh1,%ymm13
+vpsubw %ymm\rl0,%ymm\rh0,%ymm12    // ymm12 = rh0 - rl0
+vpaddw %ymm\rh0,%ymm\rl0,%ymm\rl0  // rl0 = rh0 + rl0
+vpsubw %ymm\rl1,%ymm\rh1,%ymm13    // ymm13 = rh1 - rl1

-vpmullw %ymm\zl0,%ymm12,%ymm\rh0
-vpaddw %ymm\rh1,%ymm\rl1,%ymm\rl1
-vpsubw %ymm\rl2,%ymm\rh2,%ymm14
+vpmullw %ymm\zl0,%ymm12,%ymm\rh0   // rh0 = (rh0 - rl0) * root0_twisted
+vpaddw %ymm\rh1,%ymm\rl1,%ymm\rl1  // rl1 = rh1 + rl1
+vpsubw %ymm\rl2,%ymm\rh2,%ymm14    // ymm14 = rh2 - rl2

-vpmullw %ymm\zl0,%ymm13,%ymm\rh1
-vpaddw %ymm\rh2,%ymm\rl2,%ymm\rl2
-vpsubw %ymm\rl3,%ymm\rh3,%ymm15
+vpmullw %ymm\zl0,%ymm13,%ymm\rh1   // rh1 = (rh1 - rl1) * root0_twisted
+vpaddw %ymm\rh2,%ymm\rl2,%ymm\rl2  // rl2 = rh2 + rl2
+vpsubw %ymm\rl3,%ymm\rh3,%ymm15    // ymm15 = rh3 - rl3

-vpmullw %ymm\zl1,%ymm14,%ymm\rh2
-vpaddw %ymm\rh3,%ymm\rl3,%ymm\rl3
-vpmullw %ymm\zl1,%ymm15,%ymm\rh3
+vpmullw %ymm\zl1,%ymm14,%ymm\rh2   // rh2 = (rh2 - rl2) * root1_twisted
+vpaddw %ymm\rh3,%ymm\rl3,%ymm\rl3  // rl3 = rh3 + rl3
+vpmullw %ymm\zl1,%ymm15,%ymm\rh3   // rh3 = (rh3 - rl3) * root1_twisted

-vpmulhw %ymm\zh0,%ymm12,%ymm12
-vpmulhw %ymm\zh0,%ymm13,%ymm13
+vpmulhw %ymm\zh0,%ymm12,%ymm12     // ymm12 = (rh0 - rl0) * root0
+vpmulhw %ymm\zh0,%ymm13,%ymm13     // ymm13 = (rh1 - rl1) * root0

-vpmulhw %ymm\zh1,%ymm14,%ymm14
-vpmulhw %ymm\zh1,%ymm15,%ymm15
+vpmulhw %ymm\zh1,%ymm14,%ymm14     // ymm14 = (rh2 - rl2) * root1
+vpmulhw %ymm\zh1,%ymm15,%ymm15     // ymm15 = (rh3 - rl3) * root1

-vpmulhw %ymm0,%ymm\rh0,%ymm\rh0
+vpmulhw %ymm0,%ymm\rh0,%ymm\rh0    // rh0 = Q * [(rh0 - rl0) * root0_twisted]
+vpmulhw %ymm0,%ymm\rh1,%ymm\rh1    // rh1 = Q * [(rh1 - rl1) * root0_twisted]
+vpmulhw %ymm0,%ymm\rh2,%ymm\rh2    // rh2 = Q * [(rh2 - rl2) * root1_twisted]
+vpmulhw %ymm0,%ymm\rh3,%ymm\rh3    // rh3 = Q * [(rh3 - rl3) * root1_twisted]

-vpmulhw %ymm0,%ymm\rh1,%ymm\rh1
-
-vpmulhw %ymm0,%ymm\rh2,%ymm\rh2
-vpmulhw %ymm0,%ymm\rh3,%ymm\rh3
-
-#
-
-#
-
-vpsubw %ymm\rh0,%ymm12,%ymm\rh0
-
-vpsubw %ymm\rh1,%ymm13,%ymm\rh1
-
-vpsubw %ymm\rh2,%ymm14,%ymm\rh2
-vpsubw %ymm\rh3,%ymm15,%ymm\rh3
+vpsubw %ymm\rh0,%ymm12,%ymm\rh0    // rh0 = montmul(rh0-rl0, root0)
+vpsubw %ymm\rh1,%ymm13,%ymm\rh1    // rh1 = montmul(rh1-rl1, root0)
+vpsubw %ymm\rh2,%ymm14,%ymm\rh2    // rh2 = montmul(rh2-rl2, root1)
+vpsubw %ymm\rh3,%ymm15,%ymm\rh3    // rh3 = montmul(rh3-rl3, root1)
 .endm

 .macro intt_levels0t5 off
 /* level 0 */
+/* no bounds assumptions */
 vmovdqa _16XFLO*2(%rsi),%ymm2
 vmovdqa _16XFHI*2(%rsi),%ymm3

@@ -80,6 +77,8 @@ fqmulprecomp 2,3,10
 fqmulprecomp 2,3,9
 fqmulprecomp 2,3,11

+/* bounds: coefficients < q */
+
 vpermq $0x4E,(_ZETAS_EXP+(1-\off)*224+208)*2(%rsi),%ymm15
 vpermq $0x4E,(_ZETAS_EXP+(1-\off)*224+176)*2(%rsi),%ymm1
 vpermq $0x4E,(_ZETAS_EXP+(1-\off)*224+224)*2(%rsi),%ymm2
@@ -92,6 +91,14 @@ vpshufb %ymm12,%ymm3,%ymm3

 butterfly 4,5,8,9,6,7,10,11,15,1,2,3

+// Montgomery multiplication of a value INT16_MAX).
+red16 7
+// global abs bound < 4q

 vmovdqa %ymm7,(128*\off+ 0)*2(%rdi)
 vmovdqa %ymm9,(128*\off+ 16)*2(%rdi)

@@ -176,10 +210,10 @@ vmovdqa (64*\off+176)*2(%rdi),%ymm11
 vpbroadcastq (_ZETAS_EXP+4)*2(%rsi),%ymm3

 butterfly 4,5,6,7,8,9,10,11
+// global abs bound < 8q

-.if \off == 0
-red16 4
-.endif
+// REF-CHANGE: The official AVX2 implementation has a `red16 4` for `off=0`.
+// We don't need this because of the earlier red16, which ensures an 8q bound.

 vmovdqa %ymm4,(64*\off+ 0)*2(%rdi)
 vmovdqa %ymm5,(64*\off+ 16)*2(%rdi)
diff --git a/mlkem/native/x86_64/ntt.S b/mlkem/native/x86_64/ntt.S
index 0c94c94d0..54c46a3e3 100644
--- a/mlkem/native/x86_64/ntt.S
+++ b/mlkem/native/x86_64/ntt.S
@@ -13,6 +13,7 @@ .include "shuffle.inc"

+// Compute steps 1 and 2 (of 3) of the Montgomery multiplications
 .macro mul rh0,rh1,rh2,rh3,zl0=15,zl1=15,zh0=2,zh1=2
 vpmullw %ymm\zl0,%ymm\rh0,%ymm12
 vpmullw %ymm\zl0,%ymm\rh1,%ymm13
@@ -27,6 +28,8 @@ vpmulhw %ymm\zh1,%ymm\rh2,%ymm\rh2
 vpmulhw %ymm\zh1,%ymm\rh3,%ymm\rh3
 .endm

+// Compute step 3 (of 3) of the Montgomery multiplications
+// Multiply-high with Q is signed; the full products are bounded by 2^15 * q in absolute value, so the outputs are bounded by q/2
 .macro reduce
 vpmulhw %ymm0,%ymm12,%ymm12
 vpmulhw %ymm0,%ymm13,%ymm13
@@ -35,28 +38,42 @@ vpmulhw %ymm0,%ymm14,%ymm14
 vpmulhw %ymm0,%ymm15,%ymm15
 .endm

+// Finish the Montgomery multiplications and compute the add/sub steps of the NTT butterflies
+//
+// At this point, the two high products of the 4 ongoing Montgomery
+// multiplications are in %ymm{12,13,14,15} and %ymm{rh{0,1,2,3}}, respectively.
+// The NTT coefficients that the results of the Montgomery multiplications
+// should be add/sub-ed with are in %ymm{rl{0,1,2,3}}.
+//
+// What's interesting here is that rather than completing the Montgomery
+// multiplications by computing %ymm{rh{i}} - %ymm{12+i} and then add/sub'ing
+// the result into %ymm{rl{i}}, the code first add/subs %ymm{rh{i}} into
+// %ymm{rl{i}} and only afterwards subtracts/adds %ymm{12+i}.
+//
+// Functionally, though, this is still a signed Montgomery multiplication
+// followed by an add/sub.
+//
+// Since the result of the Montgomery multiplication is bounded
+// by q in absolute value, the coefficients overall grow by not
+// more than q in absolute value per layer.
 .macro update rln,rl0,rl1,rl2,rl3,rh0,rh1,rh2,rh3
-vpaddw %ymm\rh0,%ymm\rl0,%ymm\rln
-vpsubw %ymm\rh0,%ymm\rl0,%ymm\rh0
-vpaddw %ymm\rh1,%ymm\rl1,%ymm\rl0
-
-vpsubw %ymm\rh1,%ymm\rl1,%ymm\rh1
-vpaddw %ymm\rh2,%ymm\rl2,%ymm\rl1
-vpsubw %ymm\rh2,%ymm\rl2,%ymm\rh2
-
-vpaddw %ymm\rh3,%ymm\rl3,%ymm\rl2
-vpsubw %ymm\rh3,%ymm\rl3,%ymm\rh3
-
-vpsubw %ymm12,%ymm\rln,%ymm\rln
-vpaddw %ymm12,%ymm\rh0,%ymm\rh0
-vpsubw %ymm13,%ymm\rl0,%ymm\rl0
-
-vpaddw %ymm13,%ymm\rh1,%ymm\rh1
-vpsubw %ymm14,%ymm\rl1,%ymm\rl1
-vpaddw %ymm14,%ymm\rh2,%ymm\rh2
-
-vpsubw %ymm15,%ymm\rl2,%ymm\rl2
-vpaddw %ymm15,%ymm\rh3,%ymm\rh3
+vpaddw %ymm\rh0,%ymm\rl0,%ymm\rln  // rln = rl0 + rh0
+vpsubw %ymm\rh0,%ymm\rl0,%ymm\rh0  // rh0 = rl0 - rh0
+vpaddw %ymm\rh1,%ymm\rl1,%ymm\rl0  // rl0 = rl1 + rh1
+vpsubw %ymm\rh1,%ymm\rl1,%ymm\rh1  // rh1 = rl1 - rh1
+vpaddw %ymm\rh2,%ymm\rl2,%ymm\rl1  // rl1 = rl2 + rh2
+vpsubw %ymm\rh2,%ymm\rl2,%ymm\rh2  // rh2 = rl2 - rh2
+vpaddw %ymm\rh3,%ymm\rl3,%ymm\rl2  // rl2 = rl3 + rh3
+vpsubw %ymm\rh3,%ymm\rl3,%ymm\rh3  // rh3 = rl3 - rh3
+
+vpsubw %ymm12,%ymm\rln,%ymm\rln    // rln = rh0 + rl0 - ymm12 = rl0 + (rh0 - ymm12)
+vpaddw %ymm12,%ymm\rh0,%ymm\rh0    // rh0 = rl0 - rh0 + ymm12 = rl0 - (rh0 - ymm12)
+vpsubw %ymm13,%ymm\rl0,%ymm\rl0    // rl0 = rl1 + rh1 - ymm13 = rl1 + (rh1 - ymm13)
+vpaddw %ymm13,%ymm\rh1,%ymm\rh1    // rh1 = rl1 - rh1 + ymm13 = rl1 - (rh1 - ymm13)
+vpsubw %ymm14,%ymm\rl1,%ymm\rl1    // rl1 = rh2 + rl2 - ymm14 = rl2 + (rh2 - ymm14)
+vpaddw %ymm14,%ymm\rh2,%ymm\rh2    // rh2 = rl2 - rh2 + ymm14 = rl2 - (rh2 - ymm14)
+vpsubw %ymm15,%ymm\rl2,%ymm\rl2    // rl2 = rh3 + rl3 - ymm15 = rl3 + (rh3 - ymm15)
+vpaddw %ymm15,%ymm\rh3,%ymm\rh3    // rh3 = rl3 - rh3 + ymm15 = rl3 - (rh3 - ymm15)
 .endm

 .macro level0 off
diff --git a/mlkem/native/x86_64/shuffle.inc b/mlkem/native/x86_64/shuffle.inc
index 9709fee2f..418a21e6f 100644
--- a/mlkem/native/x86_64/shuffle.inc
+++ b/mlkem/native/x86_64/shuffle.inc
@@ -11,12 +11,19 @@ vpunpcklqdq %ymm\r1,%ymm\r0,%ymm\r2
 vpunpckhqdq %ymm\r1,%ymm\r0,%ymm\r3
 .endm

+// Shuffle r0=(a0,b0,c0,d0,...), r1=(a1,b1,c1,d1,...) into
+// r2 = (a0,b0,a1,b1,e0,f0,e1,f1,...)
+// r3 = (c0,d0,c1,d1,g0,h0,g1,h1,...)
 .macro shuffle2 r0,r1,r2,r3
-#vpsllq $32,%ymm\r1,%ymm\r2
+// r2=(a1,b1,a1,b1,e1,f1,e1,f1,...)
 vmovsldup %ymm\r1,%ymm\r2
+// Conditional move
+// 0xAA = 0b10101010
+// r2=(a0,b0,a1,b1,e0,f0,e1,f1,...)
 vpblendd $0xAA,%ymm\r2,%ymm\r0,%ymm\r2
+// r0=(c0,d0,0,0,g0,h0,0,0,...)
 vpsrlq $32,%ymm\r0,%ymm\r0
-#vmovshdup %ymm\r0,%ymm\r0
+// r3=(c0,d0,c1,d1,g0,h0,g1,h1,...)
 vpblendd $0xAA,%ymm\r1,%ymm\r0,%ymm\r3
 .endm
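
The notes below are appended for illustration only and are not part of the patch.

First, a plain-C model of the four-instruction Montgomery multiplication used by
fqmulprecomp (the same vpmullw / vpmulhw / vpmulhw-by-Q / vpsubw pattern also
appears in the butterfly macros). The helper name fqmulprecomp_c is invented for
this sketch, and QINV = 62209 = q^-1 mod 2^16 is assumed as the twisting
convention. The exhaustive loop checks the fact the bounds annotations rely on:
for a constant of absolute value less than q, the result is less than q in
absolute value for every int16_t input.

#include <assert.h>
#include <stdint.h>

#define Q 3329
#define QINV 62209 /* q^-1 mod 2^16 (assumed twisting convention) */

/* C model of fqmulprecomp: al = (int16_t)(ah * QINV) is the precomputed
 * "twist" of the constant ah. The truncating casts rely on the usual
 * two's-complement behaviour, as the reference C code does. */
static int16_t fqmulprecomp_c(int16_t al, int16_t ah, int16_t b)
{
  int16_t lo = (int16_t)((int32_t)al * b);         /* vpmullw %ymm\al,%ymm\b */
  int16_t hi = (int16_t)(((int32_t)ah * b) >> 16); /* vpmulhw %ymm\ah,%ymm\b */
  int16_t t  = (int16_t)(((int32_t)Q * lo) >> 16); /* vpmulhw %ymm0 (= Q)    */
  return (int16_t)(hi - t);                        /* vpsubw                 */
}

int main(void)
{
  int16_t ah = 1441; /* any constant with |ah| < Q */
  int16_t al = (int16_t)((int32_t)ah * QINV);
  for (int32_t b = INT16_MIN; b <= INT16_MAX; b++)
  {
    int16_t r = fqmulprecomp_c(al, ah, (int16_t)b);
    /* |result| < Q for any int16_t input ... */
    assert(r > -Q && r < Q);
    /* ... and r is congruent to ah * b * 2^-16 modulo Q */
    assert((((int32_t)ah * b - (int32_t)r * 65536) % Q) == 0);
  }
  return 0;
}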
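
Second, a sketch of the instruction scheduling described above the update macro
in ntt.S, together with the per-layer bound book-keeping used throughout the
annotations. The function names are invented for this sketch; hi and t stand for
the two high products sitting in %ymm{rh{i}} and %ymm{12+i} right before update
runs. The only points being made are that rl + (hi - t) equals (rl + hi) - t in
wrapping 16-bit arithmetic, and that, assuming inputs bounded by q and a
Montgomery product bounded by q per layer, seven layers end up below
8*q = 8*MLKEM_Q, matching the contractual bound stated in the commit message.

#include <assert.h>
#include <stdint.h>

#define Q 3329

/* Reference order: finish the Montgomery multiplication (hi - t), then
 * add/sub it into the other coefficient. Casts rely on two's-complement
 * truncation, as elsewhere. */
static void butterfly_ref(int16_t *a, int16_t *b, int16_t hi, int16_t t)
{
  int16_t m  = (int16_t)(hi - t);
  int16_t a2 = (int16_t)(*a + m);
  int16_t b2 = (int16_t)(*a - m);
  *a = a2;
  *b = b2;
}

/* Order used by the update macro: add/sub hi into the coefficient first,
 * then subtract/add t. */
static void butterfly_update(int16_t *a, int16_t *b, int16_t hi, int16_t t)
{
  int16_t a2 = (int16_t)(*a + hi);
  int16_t b2 = (int16_t)(*a - hi);
  a2 = (int16_t)(a2 - t);
  b2 = (int16_t)(b2 + t);
  *a = a2;
  *b = b2;
}

int main(void)
{
  /* Spot-check the reordering on a grid of inputs. */
  for (int32_t a = -32768; a <= 32767; a += 997)
  {
    for (int32_t hi = -1664; hi <= 1664; hi += 37)
    {
      for (int32_t t = -1664; t <= 1664; t += 41)
      {
        int16_t a1 = (int16_t)a, b1 = 0, a2 = (int16_t)a, b2 = 0;
        butterfly_ref(&a1, &b1, (int16_t)hi, (int16_t)t);
        butterfly_update(&a2, &b2, (int16_t)hi, (int16_t)t);
        assert(a1 == a2 && b1 == b2);
      }
    }
  }

  /* Bound book-keeping: if each completed Montgomery product is < Q in
   * absolute value, every layer grows the coefficients by at most Q, so
   * inputs < Q end up < 8*Q after the 7 layers, still within int16_t. */
  int32_t bound = Q;
  for (int layer = 0; layer < 7; layer++)
  {
    bound += Q;
  }
  assert(bound == 8 * Q && bound <= INT16_MAX);
  return 0;
}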
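
Finally, a scalar model of the shuffle2 macro from shuffle.inc, to make the
word-level interleaving in the new comments concrete. The vec256 type and the
helper names are invented for this sketch; each helper mirrors one instruction
(vmovsldup, vpblendd $0xAA, vpsrlq $32) acting on a vector of 16 16-bit words.

#include <assert.h>
#include <stdint.h>

/* Word i of r0 plays the role of a0,b0,c0,d0,e0,... in the comments,
 * word i of r1 the role of a1,b1,c1,d1,e1,... */
typedef struct { int16_t w[16]; } vec256;

static vec256 movsldup(vec256 x) /* duplicate the even 32-bit lanes */
{
  vec256 r;
  for (int d = 0; d < 8; d++)
  {
    r.w[2 * d]     = x.w[2 * (d & ~1)];
    r.w[2 * d + 1] = x.w[2 * (d & ~1) + 1];
  }
  return r;
}

static vec256 blendd_AA(vec256 src1, vec256 src2) /* vpblendd $0xAA: odd dwords from src2 */
{
  vec256 r;
  for (int d = 0; d < 8; d++)
  {
    const vec256 *s = (d & 1) ? &src2 : &src1;
    r.w[2 * d]     = s->w[2 * d];
    r.w[2 * d + 1] = s->w[2 * d + 1];
  }
  return r;
}

static vec256 srlq32(vec256 x) /* shift each 64-bit lane right by 32 bits */
{
  vec256 r;
  for (int q = 0; q < 4; q++)
  {
    r.w[4 * q]     = x.w[4 * q + 2];
    r.w[4 * q + 1] = x.w[4 * q + 3];
    r.w[4 * q + 2] = 0;
    r.w[4 * q + 3] = 0;
  }
  return r;
}

static void shuffle2_c(vec256 r0, vec256 r1, vec256 *r2, vec256 *r3)
{
  vec256 t = movsldup(r1); /* (a1,b1,a1,b1,e1,f1,e1,f1,...) */
  *r2 = blendd_AA(r0, t);  /* (a0,b0,a1,b1,e0,f0,e1,f1,...) */
  vec256 s = srlq32(r0);   /* (c0,d0,0,0,g0,h0,0,0,...)     */
  *r3 = blendd_AA(s, r1);  /* (c0,d0,c1,d1,g0,h0,g1,h1,...) */
}

int main(void)
{
  vec256 r0, r1, r2, r3;
  for (int i = 0; i < 16; i++)
  {
    r0.w[i] = (int16_t)(100 + i); /* a0=100, b0=101, c0=102, d0=103, ... */
    r1.w[i] = (int16_t)(200 + i); /* a1=200, b1=201, c1=202, d1=203, ... */
  }
  shuffle2_c(r0, r1, &r2, &r3);
  assert(r2.w[0] == 100 && r2.w[1] == 101 && r2.w[2] == 200 && r2.w[3] == 201);
  assert(r3.w[0] == 102 && r3.w[1] == 103 && r3.w[2] == 202 && r3.w[3] == 203);
  return 0;
}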