Skip to content

Commit

Permalink
AVX2 implementation of mulcache
Browse files Browse the repository at this point in the history
Slight performance improvement:

|--------------------------------------|
|  MLKEM 512 | Before | After | Improv |
|    keypair |  15101 | 14835 |  1.02x |
|     encaps |  19664 | 18631 |  1.05x |
|     decaps |  25824 | 24420 |  1.05x |
|--------------------------------------|
|  MLKEM 768 | Before | After | Improv |
|    keypair |  26187 | 25157 |  1.04x |
|     encaps |  28014 | 27248 |  1.03x |
|     decaps |  36989 | 35659 |  1.03x |
|--------------------------------------|
| MLKEM 1024 | Before | After | Improv |
|    keypair |  36014 | 35630 |  1.01x |
|     encaps |  39797 | 39347 |  1.01x |
|     decaps |  52139 | 51524 |  1.01x |
|--------------------------------------|

Measured on Intel(R) Xeon(R) Platinum 8488C.

Signed-off-by: dkostic <[email protected]>
  • Loading branch information
dkostic committed Dec 24, 2024
1 parent aeb97b1 commit 7f5da24
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 39 deletions.
6 changes: 5 additions & 1 deletion mlkem/native/x86_64/src/arith_native_x86_64.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,11 @@ void reduce_avx2(__m256i *r, const __m256i *qdata);

#define basemul_avx2 MLKEM_NAMESPACE(basemul_avx2)
void basemul_avx2(__m256i *r, const __m256i *a, const __m256i *b,
const __m256i *qdata);
const __m256i *qdata, const __m256i *bcache);

#define poly_mulcache_compute_avx2 \

Check failure on line 42 in mlkem/native/x86_64/src/arith_native_x86_64.h

View workflow job for this annotation

GitHub Actions / Linting (ubuntu-latest)

Format error

mlkem/native/x86_64/src/arith_native_x86_64.h require to be formatted
MLKEM_NAMESPACE(poly_mulcache_compute_avx2)
void poly_mulcache_compute_avx2(poly_mulcache *x, const poly *y);

#define polyvec_basemul_acc_montgomery_cached_avx2 \
MLKEM_NAMESPACE(polyvec_basemul_acc_montgomery_cached_avx2)
Expand Down
35 changes: 10 additions & 25 deletions mlkem/native/x86_64/src/basemul.S
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ vmovdqa (64*\off+16)*2(%rdx),%ymm6 # d0
vpmulhw %ymm5,%ymm1,%ymm13 # a0c0.hi
vpmulhw %ymm6,%ymm1,%ymm1 # a0d0.hi
vpmulhw %ymm5,%ymm2,%ymm14 # b0c0.hi
vpmulhw %ymm6,%ymm2,%ymm2 # b0d0.hi

vmovdqa (64*\off+32)*2(%rdx),%ymm7 # c1
vmovdqa (64*\off+48)*2(%rdx),%ymm8 # d1
Expand All @@ -43,7 +42,6 @@ vmovdqa (64*\off+48)*2(%rdx),%ymm8 # d1
vpmulhw %ymm7,%ymm3,%ymm15 # a1c1.hi
vpmulhw %ymm8,%ymm3,%ymm3 # a1d1.hi
vpmulhw %ymm7,%ymm4,%ymm0 # b1c1.hi
vpmulhw %ymm8,%ymm4,%ymm4 # b1d1.hi

vmovdqa %ymm13,(%rsp)

Expand All @@ -52,14 +50,20 @@ vmovdqa %ymm13,(%rsp)
vpmullw %ymm5,%ymm9,%ymm13 # a0c0.lo
vpmullw %ymm6,%ymm9,%ymm9 # a0d0.lo
vpmullw %ymm5,%ymm10,%ymm5 # b0c0.lo
vpmullw %ymm6,%ymm10,%ymm10 # b0d0.lo

/* Compute low-parts of monomials in (a1+b1*X)*(c1+d1*X), */
/* using Montgomery twists calculated before */
vpmullw %ymm7,%ymm11,%ymm6 # a1c1.lo
vpmullw %ymm8,%ymm11,%ymm11 # a1d1.lo
vpmullw %ymm7,%ymm12,%ymm7 # b1c1.lo
vpmullw %ymm8,%ymm12,%ymm12 # b1d1.lo

/* Use cached (d * zeta) to compute the high part of (b*d) terms */
vmovdqa (32*\off+ 0)*2(%r8),%ymm8 # d0z
vpmulhw %ymm8,%ymm2,%ymm2 # b0d0z.hi
vpmullw %ymm8,%ymm10,%ymm10 # b0d0z.lo
vmovdqa (32*\off+16)*2(%r8),%ymm8 # d1z
vpmulhw %ymm8,%ymm4,%ymm4 # b1d1z.hi
vpmullw %ymm8,%ymm12,%ymm12 # b1d1z.lo

/* Compute 2nd high multiplication in Montgomery multiplication */
vmovdqa _16XQ*2(%rcx),%ymm8
Expand All @@ -83,21 +87,6 @@ vpsubw %ymm11,%ymm3,%ymm11 # a1d1
vpsubw %ymm7,%ymm0,%ymm7 # b1c1
vpsubw %ymm12,%ymm4,%ymm12 # b1d1

/* b0*d0 and b1*d1 need twisting by a twiddle, accounting
* for X^2=zeta in F_q[X]/(X^2-zeta).
*
* TODO: This could be precomputed in the mulcache */
vmovdqa (%r9),%ymm0
vmovdqa 32(%r9),%ymm1
vpmullw %ymm0,%ymm10,%ymm2
vpmullw %ymm0,%ymm12,%ymm3
vpmulhw %ymm1,%ymm10,%ymm10
vpmulhw %ymm1,%ymm12,%ymm12
vpmulhw %ymm8,%ymm2,%ymm2
vpmulhw %ymm8,%ymm3,%ymm3
vpsubw %ymm2,%ymm10,%ymm10 # rb0d0
vpsubw %ymm3,%ymm12,%ymm12 # rb1d1

vpaddw %ymm5,%ymm9,%ymm9
vpaddw %ymm7,%ymm11,%ymm11
vpsubw %ymm13,%ymm10,%ymm13
Expand All @@ -115,23 +104,19 @@ vmovdqa %ymm11,(64*\off+48)*2(%rdi)
.text
.global MLKEM_ASM_NAMESPACE(basemul_avx2)
MLKEM_ASM_NAMESPACE(basemul_avx2):
mov %rsp,%r8
mov %rsp,%r11
and $-32,%rsp
sub $32,%rsp

lea (_ZETAS_EXP+176)*2(%rcx),%r9
schoolbook 0

add $32*2,%r9
schoolbook 1

add $192*2,%r9
schoolbook 2

add $32*2,%r9
schoolbook 3

mov %r8,%rsp
mov %r11,%rsp
ret

#endif /* MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT */
55 changes: 47 additions & 8 deletions mlkem/native/x86_64/src/basemul.c
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,57 @@
#if defined(MLKEM_NATIVE_ARITH_BACKEND_X86_64_DEFAULT)

#include "consts.h"

#include "poly.h"
#include "polyvec.h"

#include "arith_native_x86_64.h"

static void poly_basemul_montgomery_avx2(poly *r, const poly *a, const poly *b)
int16_t zetas_avx2[64] = {

Check failure on line 17 in mlkem/native/x86_64/src/basemul.c

View workflow job for this annotation

GitHub Actions / Linting (ubuntu-latest)

Format error

mlkem/native/x86_64/src/basemul.c require to be formatted
-1103, 555, -1251, 1550, 422, 177, -291, 1574, -246, 1159, -777, -602, -1590, -872, 418, -156,

Check failure on line 18 in mlkem/native/x86_64/src/basemul.c

View workflow job for this annotation

GitHub Actions / Linting (ubuntu-latest)

Format error

mlkem/native/x86_64/src/basemul.c require to be formatted

Check failure on line 18 in mlkem/native/x86_64/src/basemul.c

View workflow job for this annotation

GitHub Actions / Linting (ubuntu-latest)

Format error

mlkem/native/x86_64/src/basemul.c require to be formatted

Check failure on line 18 in mlkem/native/x86_64/src/basemul.c

View workflow job for this annotation

GitHub Actions / Linting (ubuntu-latest)

Format error

mlkem/native/x86_64/src/basemul.c require to be formatted

Check failure on line 18 in mlkem/native/x86_64/src/basemul.c

View workflow job for this annotation

GitHub Actions / Linting (ubuntu-latest)

Format error

mlkem/native/x86_64/src/basemul.c require to be formatted

Check failure on line 18 in mlkem/native/x86_64/src/basemul.c

View workflow job for this annotation

GitHub Actions / Linting (ubuntu-latest)

Format error

mlkem/native/x86_64/src/basemul.c require to be formatted

Check failure on line 18 in mlkem/native/x86_64/src/basemul.c

View workflow job for this annotation

GitHub Actions / Linting (ubuntu-latest)

Format error

mlkem/native/x86_64/src/basemul.c require to be formatted

Check failure on line 18 in mlkem/native/x86_64/src/basemul.c

View workflow job for this annotation

GitHub Actions / Linting (ubuntu-latest)

Format error

mlkem/native/x86_64/src/basemul.c require to be formatted
430, 843, 871, 105, 587, -235, -460, 1653, 778, -147, 1483, 1119, 644, 349, 329, -75,
817, 603, 1322, -1465, -1215, 1218, -874, -1187, -1185, -1278, -1510, -870, -108, 996, 958, 1522,
1097, 610, -1285, 384, -136, -1335, 220, 1670, -1530, 794, -854, 478, -308, 991, -1460, 1628,
};

#define QINV (-3327) /* q^-1 mod 2^16 */

/*
 * Fill the multiplication cache x for polynomial y (AVX2 version).
 *
 * For each of the four 64-coefficient chunks of y, the two 16-lane
 * vectors at offsets +16 and +48 (labelled d0 and d1 in the basemul
 * assembly) are Montgomery-multiplied lane-wise by 16 entries of
 * zetas_avx2. The two products are stored back-to-back in x, giving
 * 32 cached coefficients per chunk.
 *
 * Arguments: - poly_mulcache *x: output cache; written with aligned
 *                                256-bit stores
 *            - const poly *y:    input polynomial; read with aligned
 *                                256-bit loads
 */
void poly_mulcache_compute_avx2(poly_mulcache *x, const poly *y)
{
  unsigned int j;
  __m256i q, qinv, a0, a1, z, t0, t1, s0, s1, r0, r1;

  q = _mm256_set1_epi16(MLKEM_Q);
  qinv = _mm256_set1_epi16(QINV);

  for (j = 0; j < 4; j++)
  {
    a0 = _mm256_load_si256((const __m256i *)&y->coeffs[64 * j + 16]);
    a1 = _mm256_load_si256((const __m256i *)&y->coeffs[64 * j + 48]);
    /* zetas_avx2 is a plain int16_t array with no alignment specifier,
     * so it is only guaranteed 2-byte alignment; an aligned 256-bit load
     * could fault. Use the unaligned load for the table. */
    z = _mm256_loadu_si256((const __m256i *)&zetas_avx2[16 * j]);

    /* Low halves: (a * q^-1 * zeta) mod 2^16 */
    t0 = _mm256_mullo_epi16(a0, qinv);
    t1 = _mm256_mullo_epi16(a1, qinv);
    t0 = _mm256_mullo_epi16(t0, z);
    t1 = _mm256_mullo_epi16(t1, z);

    /* High halves of a * zeta and of the correction term t * q */
    s0 = _mm256_mulhi_epi16(a0, z);
    s1 = _mm256_mulhi_epi16(a1, z);
    r0 = _mm256_mulhi_epi16(t0, q);
    r1 = _mm256_mulhi_epi16(t1, q);

    /* Montgomery result: hi(a * zeta) - hi(t * q) */
    r0 = _mm256_sub_epi16(s0, r0);
    r1 = _mm256_sub_epi16(s1, r1);

    _mm256_store_si256((__m256i *)&x->coeffs[32 * j], r0);
    _mm256_store_si256((__m256i *)&x->coeffs[32 * j + 16], r1);
  }
}

static void poly_basemul_montgomery_avx2(poly *r,
const poly *a, const poly *b,
const poly_mulcache *b_cache)
{
basemul_avx2((__m256i *)r->coeffs, (const __m256i *)a->coeffs,
(const __m256i *)b->coeffs, qdata.vec);
basemul_avx2((__m256i *)r->coeffs,
(const __m256i *)a->coeffs, (const __m256i *)b->coeffs,
qdata.vec, (const __m256i *)b_cache->coeffs);
}

/*
Expand Down Expand Up @@ -44,16 +86,13 @@ void polyvec_basemul_acc_montgomery_cached_avx2(poly *r, const polyvec *a,
unsigned int i;
poly t;

/* TODO: Use mulcache for AVX2. So far, it is unused. */
((void)b_cache);

/* Coefficient-wise bound of each basemul is 2q.
* Since we are accumulating at most 4 times, the
* overall bound is 8q < INT16_MAX. */
poly_basemul_montgomery_avx2(r, &a->vec[0], &b->vec[0]);
poly_basemul_montgomery_avx2(r, &a->vec[0], &b->vec[0], &b_cache->vec[0]);
for (i = 1; i < MLKEM_K; i++)
{
poly_basemul_montgomery_avx2(&t, &a->vec[i], &b->vec[i]);
poly_basemul_montgomery_avx2(&t, &a->vec[i], &b->vec[i], &b_cache->vec[i]);
poly_add_avx2(r, r, &t);
}
}
Expand Down
4 changes: 1 addition & 3 deletions mlkem/native/x86_64/src/default_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,7 @@ static INLINE void poly_tomont_native(poly *data)

static INLINE void poly_mulcache_compute_native(poly_mulcache *x, const poly *y)
{
/* AVX2 backend does not use mulcache */
((void)y);
((void)x);
poly_mulcache_compute_avx2(x, y);
}

static INLINE void polyvec_basemul_acc_montgomery_cached_native(
Expand Down
1 change: 1 addition & 0 deletions mlkem/poly.c
Original file line number Diff line number Diff line change
Expand Up @@ -555,3 +555,4 @@ void poly_mulcache_compute(poly_mulcache *x, const poly *a)
* of poly_basemul_montgomery_cached() does still include the check. */
}
#endif /* MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE */

2 changes: 1 addition & 1 deletion mlkem/poly.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ typedef struct
typedef struct
{
int16_t coeffs[MLKEM_N >> 1];
} poly_mulcache;
} ALIGN poly_mulcache;

/************************************************************
* Name: scalar_compress_d1
Expand Down
2 changes: 1 addition & 1 deletion mlkem/polyvec.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ typedef struct
typedef struct
{
poly_mulcache vec[MLKEM_K];
} polyvec_mulcache;
} ALIGN polyvec_mulcache;

#define polyvec_compress_du MLKEM_NAMESPACE(polyvec_compress_du)
/*************************************************
Expand Down

0 comments on commit 7f5da24

Please sign in to comment.