diff --git a/mlkem/native/x86_64/basemul.S b/mlkem/native/x86_64/basemul.S index d90b3bc06..d6c387774 100644 --- a/mlkem/native/x86_64/basemul.S +++ b/mlkem/native/x86_64/basemul.S @@ -11,6 +11,7 @@ #include "consts.h" #include "params.h" +// Polynomials to be multiplied are denoted a+bX (rsi arg) and c+dX (rdx arg) .macro schoolbook off vmovdqa _16XQINV*2(%rcx),%ymm0 vmovdqa (64*\off+ 0)*2(%rsi),%ymm1 # a0 @@ -18,6 +19,7 @@ vmovdqa (64*\off+16)*2(%rsi),%ymm2 # b0 vmovdqa (64*\off+32)*2(%rsi),%ymm3 # a1 vmovdqa (64*\off+48)*2(%rsi),%ymm4 # b1 +// Prepare Montgomery twists vpmullw %ymm0,%ymm1,%ymm9 # a0.lo vpmullw %ymm0,%ymm2,%ymm10 # b0.lo vpmullw %ymm0,%ymm3,%ymm11 # a1.lo @@ -26,6 +28,7 @@ vpmullw %ymm0,%ymm4,%ymm12 # b1.lo vmovdqa (64*\off+ 0)*2(%rdx),%ymm5 # c0 vmovdqa (64*\off+16)*2(%rdx),%ymm6 # d0 +// Compute high-parts of monomials in (a0+b0*X)*(c0+d0*X) vpmulhw %ymm5,%ymm1,%ymm13 # a0c0.hi vpmulhw %ymm6,%ymm1,%ymm1 # a0d0.hi vpmulhw %ymm5,%ymm2,%ymm14 # b0c0.hi @@ -34,6 +37,8 @@ vpmulhw %ymm6,%ymm2,%ymm2 # b0d0.hi vmovdqa (64*\off+32)*2(%rdx),%ymm7 # c1 vmovdqa (64*\off+48)*2(%rdx),%ymm8 # d1 +// Compute high-parts of monomials in (a1+b1*X)*(c1+d1*X) +// Don't yet accumulate nor reduce X^2 vpmulhw %ymm7,%ymm3,%ymm15 # a1c1.hi vpmulhw %ymm8,%ymm3,%ymm3 # a1d1.hi vpmulhw %ymm7,%ymm4,%ymm0 # b1c1.hi @@ -41,16 +46,21 @@ vpmulhw %ymm8,%ymm4,%ymm4 # b1d1.hi vmovdqa %ymm13,(%rsp) +// Compute low-parts of monomials in (a0+b0*X)*(c0+d0*X), +// using Montgomery twists calculated before vpmullw %ymm5,%ymm9,%ymm13 # a0c0.lo vpmullw %ymm6,%ymm9,%ymm9 # a0d0.lo vpmullw %ymm5,%ymm10,%ymm5 # b0c0.lo vpmullw %ymm6,%ymm10,%ymm10 # b0d0.lo +// Compute low-parts of monomials in (a1+b1*X)*(c1+d1*X), +// using Montgomery twists calculated before vpmullw %ymm7,%ymm11,%ymm6 # a1c1.lo vpmullw %ymm8,%ymm11,%ymm11 # a1d1.lo vpmullw %ymm7,%ymm12,%ymm7 # b1c1.lo vpmullw %ymm8,%ymm12,%ymm12 # b1d1.lo +// Compute 2nd high multiplication in Montgomery multiplication vmovdqa 
_16XQ*2(%rcx),%ymm8 vpmulhw %ymm8,%ymm13,%ymm13 vpmulhw %ymm8,%ymm9,%ymm9 @@ -61,6 +71,7 @@ vpmulhw %ymm8,%ymm11,%ymm11 vpmulhw %ymm8,%ymm7,%ymm7 vpmulhw %ymm8,%ymm12,%ymm12 +// Finish Montgomery multiplications vpsubw (%rsp),%ymm13,%ymm13 # -a0c0 vpsubw %ymm9,%ymm1,%ymm9 # a0d0 vpsubw %ymm5,%ymm14,%ymm5 # b0c0 @@ -71,6 +82,10 @@ vpsubw %ymm11,%ymm3,%ymm11 # a1d1 vpsubw %ymm7,%ymm0,%ymm7 # b1c1 vpsubw %ymm12,%ymm4,%ymm12 # b1d1 +// b0*d0 and b1*d1 need twisting by a twiddle, accounting +// for X^2=zeta in F_q[X]/(X^2-zeta). +// +// TODO: This could be precomputed in the mulcache vmovdqa (%r9),%ymm0 vmovdqa 32(%r9),%ymm1 vpmullw %ymm0,%ymm10,%ymm2 @@ -87,6 +102,12 @@ vpaddw %ymm7,%ymm11,%ymm11 vpsubw %ymm13,%ymm10,%ymm13 vpsubw %ymm12,%ymm6,%ymm6 +// Bounds: Note that we assume that a+b*X is coefficient-wise bounded by q +// in absolute value: This comes from the contract for +// polyvec_basemul_acc_montgomery_cached(). +// +// Then, each Montgomery multiplication has absolute value < q, +// and hence the coefficients of the output have absolute value < 2q. vmovdqa %ymm13,(64*\off+ 0)*2(%rdi) vmovdqa %ymm9,(64*\off+16)*2(%rdi) vmovdqa %ymm6,(64*\off+32)*2(%rdi) diff --git a/mlkem/native/x86_64/basemul.c b/mlkem/native/x86_64/basemul.c index 24dee0374..b51e98ab4 100644 --- a/mlkem/native/x86_64/basemul.c +++ b/mlkem/native/x86_64/basemul.c @@ -36,13 +36,15 @@ void polyvec_basemul_acc_montgomery_cached_avx2(poly *r, const polyvec *a, const polyvec *b, const polyvec_mulcache *b_cache) { - ((void)b_cache); // cache unused - - // TODO! Think through bounds - unsigned int i; poly t; + // TODO: Use mulcache for AVX2. So far, it is unused. + ((void)b_cache); + + // Coefficient-wise bound of each basemul is 2q + // Since we are accumulating at most 4 times, the + // overall bound is 8q < INT16_MAX. poly_basemul_montgomery_avx2(r, &a->vec[0], &b->vec[0]); for (i = 1; i < MLKEM_K; i++) {