AVX2: Document bounds in NTT and invNTT
This commit documents the AVX2 assembly for the [inv]NTT in the x86_64
native backend. In particular, it tracks the bounds of data through
the various layers.

The invNTT implementation aggressively minimizes the number of modular
reductions, which makes a pen-and-paper analysis rather difficult.

At the last invNTT layer, the bounds reasoning applied is not enough to
show the absence of overflow. To remedy this, an additional reduction is added.

The analysis for the forward NTT is straightforward.

Ultimately, we confirm the absolute bound of 8*MLKEM_Q for the output of
the forward and inverse NTT; this is the contractual bound that the
higher-level C code works with.

Signed-off-by: Hanno Becker <[email protected]>
hanno-becker committed Dec 2, 2024
1 parent b2c6403 commit b0a95d1
Showing 4 changed files with 135 additions and 75 deletions.
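
A quick numerical check of the bounds discussed above, as a minimal C sketch (assuming MLKEM_Q = 3329 and int16_t coefficients): 8*MLKEM_Q still fits in an int16_t, while one further unreduced doubling would not.

#include <assert.h>
#include <stdint.h>

#define MLKEM_Q 3329 /* assumed value of the MLKEM modulus */

int main(void)
{
    /* Contractual output bound of the forward/inverse NTT: 8q fits in int16_t */
    assert(8 * MLKEM_Q == 26632 && 8 * MLKEM_Q < INT16_MAX);
    /* One further unreduced GS layer would double this past INT16_MAX, which
     * is why intt.S inserts an additional reduction before its last layer */
    assert(16 * MLKEM_Q == 53264 && 16 * MLKEM_Q > INT16_MAX);
    return 0;
}
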
2 changes: 2 additions & 0 deletions mlkem/native/x86_64/fq.inc
@@ -28,6 +28,8 @@ vpand %ymm0,%ymm\x,%ymm\x
vpaddw %ymm\x,%ymm\r,%ymm\r
.endm

// Montgomery multiplication between b and ah,
// with Montgomery twist of ah in al.
.macro fqmulprecomp al,ah,b,x=12
vpmullw %ymm\al,%ymm\b,%ymm\x
vpmulhw %ymm\ah,%ymm\b,%ymm\b
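
For reference, a scalar C sketch of what fqmulprecomp computes on each 16-bit lane (a sketch only; the helper name is illustrative, MLKEM_Q = 3329 is assumed, and al is expected to hold ah * QINV mod 2^16):

#include <stdint.h>

#define MLKEM_Q 3329

/* One lane of fqmulprecomp: signed Montgomery multiplication b * ah * 2^-16
 * modulo q, using the precomputed twisted multiplier al = ah * QINV mod 2^16.
 * The result is congruent to b*ah*2^-16 mod q and small in absolute value. */
static int16_t fqmulprecomp_scalar(int16_t al, int16_t ah, int16_t b)
{
    int16_t lo   = (int16_t)(al * b);                        /* vpmullw           */
    int16_t hi   = (int16_t)(((int32_t)ah * b) >> 16);       /* vpmulhw           */
    int16_t corr = (int16_t)(((int32_t)MLKEM_Q * lo) >> 16); /* vpmulhw with 16XQ */
    return (int16_t)(hi - corr);                             /* vpsubw            */
}
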
138 changes: 86 additions & 52 deletions mlkem/native/x86_64/intt.S
@@ -13,50 +13,47 @@
.include "shuffle.inc"
.include "fq.inc"

// Compute four GS butterflies between rh{0,1,2,3} and rl{0,1,2,3}.
// Butterflies 0,1 use root zh0 and twisted root zl0, and butterflies
// 2,3 use root zh1 and twisted root zl1
// Results are again in rl{0-3} and rh{0-3}
.macro butterfly rl0,rl1,rl2,rl3,rh0,rh1,rh2,rh3,zl0=2,zl1=2,zh0=3,zh1=3
vpsubw %ymm\rl0,%ymm\rh0,%ymm12 // ymm12 = rh0 - rl0
vpaddw %ymm\rh0,%ymm\rl0,%ymm\rl0 // rl0 = rh0 + rl0
vpsubw %ymm\rl1,%ymm\rh1,%ymm13 // ymm13 = rh1 - rl1

vpmullw %ymm\zl0,%ymm12,%ymm\rh0 // rh0 = (rh0 - rl0) * root0_twisted
vpaddw %ymm\rh1,%ymm\rl1,%ymm\rl1 // rl1 = rh1 + rl1
vpsubw %ymm\rl2,%ymm\rh2,%ymm14 // ymm14 = rh2 - rl2

vpmullw %ymm\zl0,%ymm13,%ymm\rh1 // rh1 = (rh1 - rl1) * root0_twisted
vpaddw %ymm\rh2,%ymm\rl2,%ymm\rl2 // rl2 = rh2 + rl2
vpsubw %ymm\rl3,%ymm\rh3,%ymm15 // ymm15 = rh3 - rl3

vpmullw %ymm\zl1,%ymm14,%ymm\rh2 // rh2 = (rh2 - rl2) * root1_twisted
vpaddw %ymm\rh3,%ymm\rl3,%ymm\rl3 // rl3 = rh3 + rl3
vpmullw %ymm\zl1,%ymm15,%ymm\rh3 // rh3 = (rh3 - rl3) * root1_twisted

vpmulhw %ymm\zh0,%ymm12,%ymm12 // ymm12 = (rh0 - rl0) * root0
vpmulhw %ymm\zh0,%ymm13,%ymm13 // ymm13 = (rh1 - rl1) * root0

vpmulhw %ymm\zh1,%ymm14,%ymm14 // ymm14 = (rh2 - rl2) * root1
vpmulhw %ymm\zh1,%ymm15,%ymm15 // ymm15 = (rh3 - rl3) * root1

vpmulhw %ymm0,%ymm\rh0,%ymm\rh0 // rh0 = Q * [(rh0 - rl0) * root0_twisted]
vpmulhw %ymm0,%ymm\rh1,%ymm\rh1 // rh1 = Q * [(rh1 - rl1) * root0_twisted]
vpmulhw %ymm0,%ymm\rh2,%ymm\rh2 // rh2 = Q * [(rh2 - rl2) * root1_twisted]
vpmulhw %ymm0,%ymm\rh3,%ymm\rh3 // rh3 = Q * [(rh3 - rl3) * root1_twisted]

vpsubw %ymm\rh0,%ymm12,%ymm\rh0 // rh0 = montmul(rh0-rl0, root0)
vpsubw %ymm\rh1,%ymm13,%ymm\rh1 // rh1 = montmul(rh1-rl1, root0)
vpsubw %ymm\rh2,%ymm14,%ymm\rh2 // rh2 = montmul(rh2-rl2, root1)
vpsubw %ymm\rh3,%ymm15,%ymm\rh3 // rh3 = montmul(rh3-rl3, root1)
.endm
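
A scalar C sketch of one of these GS butterflies, mapping each instruction of the macro to a C statement (the names are illustrative; root_tw denotes the twisted root root * QINV mod 2^16, and MLKEM_Q = 3329 is assumed):

#include <stdint.h>

#define MLKEM_Q 3329

/* (lo, hi) -> (hi + lo, montmul(hi - lo, root)): one GS butterfly, which the
 * macro above performs on 16 lanes of four register pairs at a time. */
static void gs_butterfly(int16_t *lo, int16_t *hi, int16_t root, int16_t root_tw)
{
    int16_t diff = (int16_t)(*hi - *lo);                     /* vpsubw                 */
    int16_t sum  = (int16_t)(*hi + *lo);                     /* vpaddw                 */
    int16_t l    = (int16_t)(diff * root_tw);                /* vpmullw (twisted root) */
    int16_t h    = (int16_t)(((int32_t)diff * root) >> 16);  /* vpmulhw (root)         */
    int16_t c    = (int16_t)(((int32_t)MLKEM_Q * l) >> 16);  /* vpmulhw (16XQ)         */
    *lo = sum;                      /* magnitude may double                  */
    *hi = (int16_t)(h - c);         /* vpsubw: Montgomery product, abs < q   */
}

The sum is what doubles the coefficient bound at each layer, while the Montgomery product is the quantity whose tighter bound the comments below track.
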

.macro intt_levels0t5 off
/* level 0 */
/* no bounds assumptions */
vmovdqa _16XFLO*2(%rsi),%ymm2
vmovdqa _16XFHI*2(%rsi),%ymm3

@@ -80,6 +77,8 @@ fqmulprecomp 2,3,10
fqmulprecomp 2,3,9
fqmulprecomp 2,3,11

/* bounds: coefficients < q */

vpermq $0x4E,(_ZETAS_EXP+(1-\off)*224+208)*2(%rsi),%ymm15
vpermq $0x4E,(_ZETAS_EXP+(1-\off)*224+176)*2(%rsi),%ymm1
vpermq $0x4E,(_ZETAS_EXP+(1-\off)*224+224)*2(%rsi),%ymm2
@@ -92,6 +91,14 @@ vpshufb %ymm12,%ymm3,%ymm3

butterfly 4,5,8,9,6,7,10,11,15,1,2,3

// Montgomery multiplication of a value <C*q with a signed canonical
// twiddle has absolute value < q*(0.0254 * C + 1/2) (see reduce.c).
// In the above butterfly, the values multiplied with twiddles have
// absolute value <2q, so we get an absolute bound < q*(1/2 + 2 * 0.0254),
// which is < INT16_MAX/16.
//
// 4,5,8,9 abs bound < 2q; 6,7,10,11 abs bound < INT16_MAX/16

/* level 1 */
vpermq $0x4E,(_ZETAS_EXP+(1-\off)*224+144)*2(%rsi),%ymm2
vpermq $0x4E,(_ZETAS_EXP+(1-\off)*224+160)*2(%rsi),%ymm3
@@ -101,55 +108,82 @@ vpshufb %ymm1,%ymm3,%ymm3

butterfly 4,5,6,7,8,9,10,11,2,2,3,3

// For 8,9,10,11, it is sufficient to use the bound <q (much weaker
// than what we used above) for the absolute value of the Montgomery
// multiplication with a twiddle.
// 4,5 abs bound < 4q; 6,7 abs bound < INT16_MAX/8; 8,9,10,11 abs bound <q.

shuffle1 4,5,3,5 // 3,5 abs bound < 4q
shuffle1 6,7,4,7 // 4,7 abs bound < INT16_MAX/8
shuffle1 8,9,6,9 // 6,9 abs bound < q
shuffle1 10,11,8,11 // 8,11 abs bound < q

/* level 2 */
vmovdqa _REVIDXD*2(%rsi),%ymm12
vpermd (_ZETAS_EXP+(1-\off)*224+112)*2(%rsi),%ymm12,%ymm2
vpermd (_ZETAS_EXP+(1-\off)*224+128)*2(%rsi),%ymm12,%ymm10

butterfly 3,4,6,8,5,7,9,11,2,2,10,10
// 3 abs bound < 8q, 4 abs bound < INT16_MAX/4, 6,8 abs bound < 2q, 5,7,9,11 abs bound < q

vmovdqa _16XV*2(%rsi),%ymm1
red16 3
// 4 abs bound < INT16_MAX/4, 6,8 abs bound < 2q, 3,5,7,9,11 abs bound < q

shuffle2 3,4,10,4 // see comment on shuffle2:
// in 10,4 the even 16-bit pairs come from 3, so abs bound < q,
// while the odd 16-bit pairs come from 4, so abs bound < INT16_MAX/4
shuffle2 6,8,3,8 // 3,8 abs bound < 2q
shuffle2 5,7,6,7 // 6,7 abs bound < q
shuffle2 9,11,5,11 // 5,11 abs bound < q

/* level 3 */
vpermq $0x1B,(_ZETAS_EXP+(1-\off)*224+80)*2(%rsi),%ymm2
vpermq $0x1B,(_ZETAS_EXP+(1-\off)*224+96)*2(%rsi),%ymm9

butterfly 10,3,6,5,4,8,7,11,2,2,9,9
// 10 abs bound < INT16_MAX/2
// 3 abs bound < 4q, 5,6 abs bound < 2q
// 4,8,7,11 abs bound < q

shuffle4 10,3,9,3 // 3,9 abs bound < INT16_MAX/2
shuffle4 6,5,10,5 // 5,10 abs bound < 2q
shuffle4 4,8,6,8 // 6,8 abs bound < q
shuffle4 7,11,4,11 // 4,11 abs bound < q

/* level 4 */
vpermq $0x4E,(_ZETAS_EXP+(1-\off)*224+48)*2(%rsi),%ymm2
vpermq $0x4E,(_ZETAS_EXP+(1-\off)*224+64)*2(%rsi),%ymm7

butterfly 9,10,6,4,3,5,8,11,2,2,7,7

// 9 abs bound < INT16_MAX
// 10 abs bound < 4q, 4,6 abs bound <2q
// 3,5,8,11 abs bound < q
red16 9
// 10 abs bound < 4q, 4,6 abs bound <2q
// 3,5,8,9,11 abs bound < q

shuffle8 9,10,7,10 // 7,10 abs bound < 4q
shuffle8 6,4,9,4 // 4,9 abs bound < 2q
shuffle8 3,5,6,5 // 5,6 abs bound < q
shuffle8 8,11,3,11 // 3,11 abs bound < q

/* level 5 */
vmovdqa (_ZETAS_EXP+(1-\off)*224+16)*2(%rsi),%ymm2
vmovdqa (_ZETAS_EXP+(1-\off)*224+32)*2(%rsi),%ymm8

butterfly 7,9,6,3,10,4,5,11,2,2,8,8
// 7 abs bound <8q
// 9 abs bound <4q
// 6,3 abs bound < 2q
// 4,5,10,11 abs bound < q

// REF-CHANGE: The official AVX2 implementation does not
// have this reduction, but it is not readily clear how
// to improve the 8q bound on ymm7 to guarantee that
// layer 6 won't overflow (16q > INT16_MAX).
red16 7
// global abs bound < 4q

vmovdqa %ymm7,(128*\off+ 0)*2(%rdi)
vmovdqa %ymm9,(128*\off+ 16)*2(%rdi)
@@ -176,10 +210,10 @@ vmovdqa (64*\off+176)*2(%rdi),%ymm11
vpbroadcastq (_ZETAS_EXP+4)*2(%rsi),%ymm3

butterfly 4,5,6,7,8,9,10,11
// global abs bound < 8q

// REF-CHANGE: The official AVX2 implementation has a `red16 4` here for
// `off=0` (i.e. `.if \off == 0 / red16 4 / .endif`). We don't need it
// because of the earlier red16, which ensures an 8q bound.

vmovdqa %ymm4,(64*\off+ 0)*2(%rdi)
vmovdqa %ymm5,(64*\off+ 16)*2(%rdi)
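
For reference, the constant 0.0254 used in the bound comments above can be recovered from the shape of the signed Montgomery multiplication; a short derivation (a sketch, assuming |a| < C*q, canonical twiddles with |r| <= (q-1)/2, and ignoring the rounding of the high halves):

|montmul(a, r)| = |high16(a*r) - high16(q * (a*r_twisted mod+- 2^16))|
               <= |a|*|r|/2^16 + q/2
               <  (C*q) * (q/2) / 2^16 + q/2
               =  q * ((q/2^17)*C + 1/2)
               ~= q * (0.0254*C + 1/2),

since q/2^17 = 3329/131072 ~= 0.0254.
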
59 changes: 38 additions & 21 deletions mlkem/native/x86_64/ntt.S
@@ -13,6 +13,7 @@

.include "shuffle.inc"

// Compute steps 1 and 2 (of 3) of the Montgomery multiplication
.macro mul rh0,rh1,rh2,rh3,zl0=15,zl1=15,zh0=2,zh1=2
vpmullw %ymm\zl0,%ymm\rh0,%ymm12
vpmullw %ymm\zl0,%ymm\rh1,%ymm13
@@ -27,6 +28,8 @@ vpmulhw %ymm\zh1,%ymm\rh2,%ymm\rh2
vpmulhw %ymm\zh1,%ymm\rh3,%ymm\rh3
.endm

// Compute step 3 (of 3) of the Montgomery multiplication.
// Multiply-high is signed; the 32-bit products are bounded by 2^15 * q in
// absolute value, so the high halves computed here are smaller than q in
// absolute value.
.macro reduce
vpmulhw %ymm0,%ymm12,%ymm12
vpmulhw %ymm0,%ymm13,%ymm13
@@ -35,28 +38,42 @@ vpmulhw %ymm0,%ymm14,%ymm14
vpmulhw %ymm0,%ymm15,%ymm15
.endm

// Finish the Montgomery multiplications and compute the add/sub steps of the
// NTT butterflies.
//
// At this point, the two high products of the 4 ongoing Montgomery
// multiplications are in %ymm{12,13,14,15} and %ymm{rh{0,1,2,3}},
// respectively. The NTT coefficients that the results of the Montgomery
// multiplications should be add/sub-ed with are in %ymm{rl{0,1,2,3}}.
//
// What's interesting here is that, rather than completing the Montgomery
// multiplications by computing `%ymm{rh{i}} - %ymm{12+i}` and then
// add/sub'ing the result into %ymm{rl{i}}, we first add/sub %ymm{rh{i}}
// into %ymm{rl{i}} and only then sub/add %ymm{12+i} into those results.
//
// Functionally, though, this is still a signed Montgomery multiplication
// followed by an add/sub.
//
// Since the result of the Montgomery multiplication is bounded by q in
// absolute value, the coefficients overall grow by at most q in absolute
// value per layer.
.macro update rln,rl0,rl1,rl2,rl3,rh0,rh1,rh2,rh3
vpaddw %ymm\rh0,%ymm\rl0,%ymm\rln // rln = rl0 + rh0
vpsubw %ymm\rh0,%ymm\rl0,%ymm\rh0 // rh0 = rl0 - rh0
vpaddw %ymm\rh1,%ymm\rl1,%ymm\rl0 // rl0 = rl1 + rh1
vpsubw %ymm\rh1,%ymm\rl1,%ymm\rh1 // rh1 = rl1 - rh1
vpaddw %ymm\rh2,%ymm\rl2,%ymm\rl1 // rl1 = rl2 + rh2
vpsubw %ymm\rh2,%ymm\rl2,%ymm\rh2 // rh2 = rl2 - rh2
vpaddw %ymm\rh3,%ymm\rl3,%ymm\rl2 // rl2 = rl3 + rh3
vpsubw %ymm\rh3,%ymm\rl3,%ymm\rh3 // rh3 = rl3 - rh3

vpsubw %ymm12,%ymm\rln,%ymm\rln // rln = rh0 + rl0 - ymm12 = rl0 + (rh0 - ymm12)
vpaddw %ymm12,%ymm\rh0,%ymm\rh0 // rh0 = rl0 - rh0 + ymm12 = rl0 - (rh0 - ymm12)
vpsubw %ymm13,%ymm\rl0,%ymm\rl0 // rl0 = rl1 + rh1 - ymm13 = rl1 + (rh1 - ymm13)
vpaddw %ymm13,%ymm\rh1,%ymm\rh1 // rh1 = rl1 - rh1 + ymm13 = rl1 - (rh1 - ymm13)
vpsubw %ymm14,%ymm\rl1,%ymm\rl1 // rl1 = rh2 + rl2 - ymm14 = rl2 + (rh2 - ymm14)
vpaddw %ymm14,%ymm\rh2,%ymm\rh2 // rh2 = rl2 - rh2 + ymm14 = rl2 - (rh2 - ymm14)
vpsubw %ymm15,%ymm\rl2,%ymm\rl2 // rl2 = rh3 + rl3 - ymm15 = rl3 + (rh3 - ymm15)
vpaddw %ymm15,%ymm\rh3,%ymm\rh3 // rh3 = rl3 - rh3 + ymm15 = rl3 - (rh3 - ymm15)
.endm
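
A scalar C sketch of the full mul/reduce/update pipeline for a single coefficient pair, showing that the rearranged add/sub order is still a CT butterfly around a signed Montgomery multiplication (names are illustrative; root_tw = root * QINV mod 2^16, MLKEM_Q = 3329 assumed):

#include <stdint.h>

#define MLKEM_Q 3329

/* (rl, rh) -> (rl + t, rl - t) with t = montmul(rh, root), computed in the
 * same order as the mul/reduce/update macros above. */
static void ct_butterfly(int16_t *rl, int16_t *rh, int16_t root, int16_t root_tw)
{
    int16_t l = (int16_t)(*rh * root_tw);                /* mul:    vpmullw        */
    int16_t h = (int16_t)(((int32_t)*rh * root) >> 16);  /* mul:    vpmulhw        */
    int16_t c = (int16_t)(((int32_t)MLKEM_Q * l) >> 16); /* reduce: vpmulhw (16XQ) */
    int16_t add = (int16_t)(*rl + h);                    /* update: vpaddw         */
    int16_t sub = (int16_t)(*rl - h);                    /* update: vpsubw         */
    *rl = (int16_t)(add - c);  /* = rl + (h - c) = rl + montmul(rh, root) */
    *rh = (int16_t)(sub + c);  /* = rl - (h - c) = rl - montmul(rh, root) */
}
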

.macro level0 off
11 changes: 9 additions & 2 deletions mlkem/native/x86_64/shuffle.inc
@@ -11,12 +11,19 @@ vpunpcklqdq %ymm\r1,%ymm\r0,%ymm\r2
vpunpckhqdq %ymm\r1,%ymm\r0,%ymm\r3
.endm

// Shuffle r0=(a0,b0,c0,d0,...), r1=(a1,b1,c1,d1,...) into
// r2 = (a0,b0,a1,b1,e0,f0,e1,f1,...)
// r3 = (c0,d0,c1,d1,g0,h0,g1,h1,...)
.macro shuffle2 r0,r1,r2,r3
#vpsllq $32,%ymm\r1,%ymm\r2
// r2=(a1,b1,a1,b1,e1,f1,e1,f1,...)
vmovsldup %ymm\r1,%ymm\r2
// Blend of 32-bit lanes: 0xAA = 0b10101010 selects
// odd lanes from r2 (the copy of r1) and even lanes from r0
// r2=(a0,b0,a1,b1,e0,f0,e1,f1,...)
vpblendd $0xAA,%ymm\r2,%ymm\r0,%ymm\r2
// r0=(c0,d0,0,0,g0,h0,0,0,...)
vpsrlq $32,%ymm\r0,%ymm\r0
#vmovshdup %ymm\r0,%ymm\r0
// r3=(c0,d0,c1,d1,g0,h0,g1,h1,...)
vpblendd $0xAA,%ymm\r1,%ymm\r0,%ymm\r3
.endm
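
A scalar model of the resulting permutation, on plain arrays of 16 lanes per register (a sketch; the function name is illustrative):

#include <stdint.h>

/* shuffle2: for each 64-bit quadword, r2 takes the even 32-bit pair of r0
 * followed by the even 32-bit pair of r1; r3 takes the two odd pairs. */
static void shuffle2_model(const int16_t r0[16], const int16_t r1[16],
                           int16_t r2[16], int16_t r3[16])
{
    for (int w = 0; w < 4; w++) {   /* four quadwords per 256-bit register */
        r2[4*w + 0] = r0[4*w + 0];
        r2[4*w + 1] = r0[4*w + 1];
        r2[4*w + 2] = r1[4*w + 0];
        r2[4*w + 3] = r1[4*w + 1];

        r3[4*w + 0] = r0[4*w + 2];
        r3[4*w + 1] = r0[4*w + 3];
        r3[4*w + 2] = r1[4*w + 2];
        r3[4*w + 3] = r1[4*w + 3];
    }
}
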

