Skip to content

Commit

Permalink
AVX2: Work through basemul and document bounds
Browse files Browse the repository at this point in the history
This commit adds a few comments to the AVX2 assembly for the
base multiplication in the NTT domain. In particular, it explains
that each coefficient in the output of the basemul is bounded by
2q in absolute value and that, therefore, the native AVX2 implementation
of polyvec_basemul_acc_montgomery_cached_avx2() does not overflow.

Signed-off-by: Hanno Becker <[email protected]>
  • Loading branch information
hanno-becker committed Dec 2, 2024
1 parent a2bf116 commit 2c4cdd2
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 4 deletions.
21 changes: 21 additions & 0 deletions mlkem/native/x86_64/basemul.S
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@
#include "consts.h"
#include "params.h"

// Polynomials to be multiplied are denoted a+bX (rsi arg) and c+dX (rdx arg)
.macro schoolbook off
vmovdqa _16XQINV*2(%rcx),%ymm0
vmovdqa (64*\off+ 0)*2(%rsi),%ymm1 # a0
vmovdqa (64*\off+16)*2(%rsi),%ymm2 # b0
vmovdqa (64*\off+32)*2(%rsi),%ymm3 # a1
vmovdqa (64*\off+48)*2(%rsi),%ymm4 # b1

// Prepare Montgomery twists
vpmullw %ymm0,%ymm1,%ymm9 # a0.lo
vpmullw %ymm0,%ymm2,%ymm10 # b0.lo
vpmullw %ymm0,%ymm3,%ymm11 # a1.lo
Expand All @@ -26,6 +28,7 @@ vpmullw %ymm0,%ymm4,%ymm12 # b1.lo
vmovdqa (64*\off+ 0)*2(%rdx),%ymm5 # c0
vmovdqa (64*\off+16)*2(%rdx),%ymm6 # d0

// Compute high-parts of monomials in (a0+b0*X)*(c0+d0*X)
vpmulhw %ymm5,%ymm1,%ymm13 # a0c0.hi
vpmulhw %ymm6,%ymm1,%ymm1 # a0d0.hi
vpmulhw %ymm5,%ymm2,%ymm14 # b0c0.hi
Expand All @@ -34,23 +37,30 @@ vpmulhw %ymm6,%ymm2,%ymm2 # b0d0.hi
vmovdqa (64*\off+32)*2(%rdx),%ymm7 # c1
vmovdqa (64*\off+48)*2(%rdx),%ymm8 # d1

// Compute high-parts of monomials in (a1+b1*X)*(c1+d1*X)
// Don't yet accumulate nor reduce X^2
vpmulhw %ymm7,%ymm3,%ymm15 # a1c1.hi
vpmulhw %ymm8,%ymm3,%ymm3 # a1d1.hi
vpmulhw %ymm7,%ymm4,%ymm0 # b1c1.hi
vpmulhw %ymm8,%ymm4,%ymm4 # b1d1.hi

vmovdqa %ymm13,(%rsp)

// Compute low-parts of monomials in (a0+b0*X)*(c0+d0*X),
// using Montgomery twists calculated before
vpmullw %ymm5,%ymm9,%ymm13 # a0c0.lo
vpmullw %ymm6,%ymm9,%ymm9 # a0d0.lo
vpmullw %ymm5,%ymm10,%ymm5 # b0c0.lo
vpmullw %ymm6,%ymm10,%ymm10 # b0d0.lo

// Compute low-parts of monomials in (a1+b1*X)*(c1+d1*X),
// using Montgomery twists calculated before
vpmullw %ymm7,%ymm11,%ymm6 # a1c1.lo
vpmullw %ymm8,%ymm11,%ymm11 # a1d1.lo
vpmullw %ymm7,%ymm12,%ymm7 # b1c1.lo
vpmullw %ymm8,%ymm12,%ymm12 # b1d1.lo

// Compute 2nd high multiplication in Montgomery multiplication
vmovdqa _16XQ*2(%rcx),%ymm8
vpmulhw %ymm8,%ymm13,%ymm13
vpmulhw %ymm8,%ymm9,%ymm9
Expand All @@ -61,6 +71,7 @@ vpmulhw %ymm8,%ymm11,%ymm11
vpmulhw %ymm8,%ymm7,%ymm7
vpmulhw %ymm8,%ymm12,%ymm12

// Finish Montgomery multiplications
vpsubw (%rsp),%ymm13,%ymm13 # -a0c0
vpsubw %ymm9,%ymm1,%ymm9 # a0d0
vpsubw %ymm5,%ymm14,%ymm5 # b0c0
Expand All @@ -71,6 +82,10 @@ vpsubw %ymm11,%ymm3,%ymm11 # a1d1
vpsubw %ymm7,%ymm0,%ymm7 # b1c1
vpsubw %ymm12,%ymm4,%ymm12 # b1d1

// b0*d0 and b1*d1 need twisting by a twiddle, accounting
// for X^2=zeta in F_q[X]/(X^2-zeta).
//
// TODO: This could be precomputed in the mulcache
vmovdqa (%r9),%ymm0
vmovdqa 32(%r9),%ymm1
vpmullw %ymm0,%ymm10,%ymm2
Expand All @@ -87,6 +102,12 @@ vpaddw %ymm7,%ymm11,%ymm11
vpsubw %ymm13,%ymm10,%ymm13
vpsubw %ymm12,%ymm6,%ymm6

// Bounds: Note that we assume that a+b*X is coefficient-wise bounded by q
// in absolute value: This comes from the contract for
// polyvec_basemul_acc_montgomery_cached().
//
// Then, the result of each Montgomery multiplication has absolute value < q.
// Each output coefficient is a sum or difference of two such results,
// and hence the coefficients of the output have absolute value < 2q.
vmovdqa %ymm13,(64*\off+ 0)*2(%rdi)
vmovdqa %ymm9,(64*\off+16)*2(%rdi)
vmovdqa %ymm6,(64*\off+32)*2(%rdi)
Expand Down
10 changes: 6 additions & 4 deletions mlkem/native/x86_64/basemul.c
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,15 @@ void polyvec_basemul_acc_montgomery_cached_avx2(poly *r, const polyvec *a,
const polyvec *b,
const polyvec_mulcache *b_cache)
{
((void)b_cache); // cache unused

// TODO! Think through bounds

unsigned int i;
poly t;

// TODO: Use mulcache for AVX2. So far, it is unused.
((void)b_cache);

// The coefficient-wise bound of each basemul is 2q.
// Since we are accumulating at most 4 times, the
// overall bound is 8q < INT16_MAX.
poly_basemul_montgomery_avx2(r, &a->vec[0], &b->vec[0]);
for (i = 1; i < MLKEM_K; i++)
{
Expand Down

0 comments on commit 2c4cdd2

Please sign in to comment.