Skip to content

Commit

Permalink
AVX2: Add polyvec_compress_avx2 and polyvec_decompress_avx2
Browse files Browse the repository at this point in the history
This commits adds the AVX2 intrinsic implementation of
polyvec_compress and polyvec_decompress from the official
Kyber repository.
As a part of #224
it was identified that the majority of the performance difference
in keypair and decaps of our current implementation and
the Kyber AVX2 implementation is due to the AVX2 polyvec_compress
and polyvec_decompress.

This commit adds these two functions to the native interface
and adds the AVX2 intrinic-based implementations from the Kyber
repository. These are almost verbatim copies.
The only two differences are:
 1) The AVX2 impelementations requires the uint8_t buffer
 to be slightly larger than MLKEM_POLYVECCOMPRESSEDBYTES, so that
 full vectors can be stored/loaded. The official implementation allocated
 those bytes on top level of the function. That would be slightly
 messy in our implementation, so I instead allocate the larger buffer
 in polyvec_compress_avx2/polyvec_decompress_avx2 itself and copy the
 inputs/outputs.
 2) The official AVX2 implementation extended the poly type to
   also be accessible as a __m256i*.
   I changed this to a cast as we guarantee the alignment in another way.

Below are the performance results on my 13th Gen Intel i7-1360P (Raptor Lake)
using gcc 14.2.1 from the Arch Linux repo.

| part     | Our code 6aa6118 |Kyber repo|Our code(+polyvec_{,de}compress) |
| -------- | ----------------  | -------- | ------------------------------- |
| 512 kg   | 22353             | 22348    | 22252                           |
| 512 enc  | 27820             | 24868    | 26472                           |
| 512 dec  | 35663             | 34984    | 33107                           |
| 768 kg   | 39626             | 38070    | 41590                           |
| 768 enc  | 43605             | 39056    | 44049                           |
| 768 dec  | 54916             | 53726    | 53432                           |
| 1024 kg  | 58983             | 53532    | 57411                           |
| 1024 enc | 65402             | 56698    | 61613                           |
| 1024 dec | 80370             | 75874    | 74681                           |

Signed-off-by: Matthias J. Kannwischer <[email protected]>
  • Loading branch information
mkannwischer committed Nov 18, 2024
1 parent 78eecac commit 0561120
Show file tree
Hide file tree
Showing 5 changed files with 288 additions and 1 deletion.
33 changes: 33 additions & 0 deletions mlkem/native/arith_native.h
Original file line number Diff line number Diff line change
Expand Up @@ -240,5 +240,38 @@ static inline int rej_uniform_native(int16_t *r, unsigned int len,
const uint8_t *buf, unsigned int buflen);
#endif /* MLKEM_USE_NATIVE_REJ_UNIFORM */

#if defined(MLKEM_USE_NATIVE_POLYVEC_COMPRESS)
/*************************************************
* Name: polyvec_compress_native
*
* Description: Compress and serialize vector of polynomials
*
* Arguments: - uint8_t *r: pointer to output byte array
* (needs space for MLKEM_POLYVECCOMPRESSEDBYTES)
* - const polyvec *a: pointer to input vector of polynomials.
* Coefficients must be unsigned canonical,
* i.e. in [0,1,..,MLKEM_Q-1].
**************************************************/
static inline void polyvec_compress_native(
uint8_t r[MLKEM_POLYVECCOMPRESSEDBYTES], const polyvec *a);
#endif /* MLKEM_USE_NATIVE_POLYVEC_COMPRESS */

#if defined(MLKEM_USE_NATIVE_POLYVEC_DECOMPRESS)
/*************************************************
* Name: polyvec_decompress_native
*
* Description: De-serialize and decompress vector of polynomials;
* approximate inverse of polyvec_compress
*
* Arguments: - polyvec *r: pointer to output vector of polynomials.
* Output will have coefficients normalized to [0,..,q-1].
* - const uint8_t *a: pointer to input byte array
* (of length MLKEM_POLYVECCOMPRESSEDBYTES)
**************************************************/
static inline void polyvec_decompress_native(
polyvec *r, const uint8_t a[MLKEM_POLYVECCOMPRESSEDBYTES]);

#endif /* MLKEM_USE_NATIVE_POLYVEC_DECOMPRESS */

#endif /* MLKEM_USE_NATIVE */
#endif /* MLKEM_ARITH_NATIVE_H */
8 changes: 8 additions & 0 deletions mlkem/native/x86_64/arith_native_x86_64.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,14 @@ void nttfrombytes_avx2(__m256i *r, const uint8_t *a, const __m256i *qdata);
#define tomont_avx2 MLKEM_NAMESPACE(tomont_avx2)
void tomont_avx2(__m256i *r, const __m256i *qdata);

#define polyvec_compress_avx2 MLKEM_NAMESPACE(polyvec_compress_avx2)
void polyvec_compress_avx2(uint8_t r[MLKEM_POLYVECCOMPRESSEDBYTES],
const polyvec *a);

#define polyvec_decompress_avx2 MLKEM_NAMESPACE(polyvec_decompress_avx2)
void polyvec_decompress_avx2(polyvec *r,
const uint8_t a[MLKEM_POLYVECCOMPRESSEDBYTES]);

#endif /* MLKEM_USE_NATIVE_X86_64 && SYS_X86_64_AVX2 */

#endif /* MLKEM_X86_64_NATIVE_H */
215 changes: 215 additions & 0 deletions mlkem/native/x86_64/polyvec_compress_avx2.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
// SPDX-License-Identifier: Apache-2.0

// Implementation from Kyber reference repository
// https://github.com/pq-crystals/kyber/blob/main/avx2

#include "config.h"

#if defined(MLKEM_USE_NATIVE_X86_64) && defined(SYS_X86_64_AVX2)

#include "arith_native_x86_64.h"

#include <immintrin.h>
#include <stdint.h>
#include <string.h>
#include "consts.h"
#include "params.h"



#if (MLKEM_POLYVECCOMPRESSEDBYTES == (MLKEM_K * 320))
static void poly_compress10(uint8_t r[320], const __m256i *restrict a) {
unsigned int i;
__m256i f0, f1, f2;
__m128i t0, t1;
const __m256i v = _mm256_load_si256(&qdata.vec[_16XV / 16]);
const __m256i v8 = _mm256_slli_epi16(v, 3);
const __m256i off = _mm256_set1_epi16(15);
const __m256i shift1 = _mm256_set1_epi16(1 << 12);
const __m256i mask = _mm256_set1_epi16(1023);
const __m256i shift2 =
_mm256_set1_epi64x((1024LL << 48) + (1LL << 32) + (1024 << 16) + 1);
const __m256i sllvdidx = _mm256_set1_epi64x(12);
const __m256i shufbidx =
_mm256_set_epi8(8, 4, 3, 2, 1, 0, -1, -1, -1, -1, -1, -1, 12, 11, 10, 9,
-1, -1, -1, -1, -1, -1, 12, 11, 10, 9, 8, 4, 3, 2, 1, 0);

for (i = 0; i < MLKEM_N / 16; i++) {
f0 = _mm256_load_si256(&a[i]);
f1 = _mm256_mullo_epi16(f0, v8);
f2 = _mm256_add_epi16(f0, off);
f0 = _mm256_slli_epi16(f0, 3);
f0 = _mm256_mulhi_epi16(f0, v);
f2 = _mm256_sub_epi16(f1, f2);
f1 = _mm256_andnot_si256(f1, f2);
f1 = _mm256_srli_epi16(f1, 15);
f0 = _mm256_sub_epi16(f0, f1);
f0 = _mm256_mulhrs_epi16(f0, shift1);
f0 = _mm256_and_si256(f0, mask);
f0 = _mm256_madd_epi16(f0, shift2);
f0 = _mm256_sllv_epi32(f0, sllvdidx);
f0 = _mm256_srli_epi64(f0, 12);
f0 = _mm256_shuffle_epi8(f0, shufbidx);
t0 = _mm256_castsi256_si128(f0);
t1 = _mm256_extracti128_si256(f0, 1);
t0 = _mm_blend_epi16(t0, t1, 0xE0);
_mm_storeu_si128((__m128i *)&r[20 * i + 0], t0);
memcpy(&r[20 * i + 16], &t1, 4);
}
}

static void poly_decompress10(__m256i *restrict r, const uint8_t a[320 + 12]) {
unsigned int i;
__m256i f;
const __m256i q = _mm256_set1_epi32((MLKEM_Q << 16) + 4 * MLKEM_Q);
const __m256i shufbidx =
_mm256_set_epi8(11, 10, 10, 9, 9, 8, 8, 7, 6, 5, 5, 4, 4, 3, 3, 2, 9, 8,
8, 7, 7, 6, 6, 5, 4, 3, 3, 2, 2, 1, 1, 0);
const __m256i sllvdidx = _mm256_set1_epi64x(4);
const __m256i mask = _mm256_set1_epi32((32736 << 16) + 8184);

for (i = 0; i < MLKEM_N / 16; i++) {
f = _mm256_loadu_si256((__m256i *)&a[20 * i]);
f = _mm256_permute4x64_epi64(f, 0x94);
f = _mm256_shuffle_epi8(f, shufbidx);
f = _mm256_sllv_epi32(f, sllvdidx);
f = _mm256_srli_epi16(f, 1);
f = _mm256_and_si256(f, mask);
f = _mm256_mulhrs_epi16(f, q);
_mm256_store_si256(&r[i], f);
}
}

#elif (MLKEM_POLYVECCOMPRESSEDBYTES == (MLKEM_K * 352))
static void poly_compress11(uint8_t r[352 + 2], const __m256i *restrict a) {
unsigned int i;
__m256i f0, f1, f2;
__m128i t0, t1;
const __m256i v = _mm256_load_si256(&qdata.vec[_16XV / 16]);
const __m256i v8 = _mm256_slli_epi16(v, 3);
const __m256i off = _mm256_set1_epi16(36);
const __m256i shift1 = _mm256_set1_epi16(1 << 13);
const __m256i mask = _mm256_set1_epi16(2047);
const __m256i shift2 =
_mm256_set1_epi64x((2048LL << 48) + (1LL << 32) + (2048 << 16) + 1);
const __m256i sllvdidx = _mm256_set1_epi64x(10);
const __m256i srlvqidx = _mm256_set_epi64x(30, 10, 30, 10);
const __m256i shufbidx =
_mm256_set_epi8(4, 3, 2, 1, 0, 0, -1, -1, -1, -1, 10, 9, 8, 7, 6, 5, -1,
-1, -1, -1, -1, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0);

for (i = 0; i < MLKEM_N / 16; i++) {
f0 = _mm256_load_si256(&a[i]);
f1 = _mm256_mullo_epi16(f0, v8);
f2 = _mm256_add_epi16(f0, off);
f0 = _mm256_slli_epi16(f0, 3);
f0 = _mm256_mulhi_epi16(f0, v);
f2 = _mm256_sub_epi16(f1, f2);
f1 = _mm256_andnot_si256(f1, f2);
f1 = _mm256_srli_epi16(f1, 15);
f0 = _mm256_sub_epi16(f0, f1);
f0 = _mm256_mulhrs_epi16(f0, shift1);
f0 = _mm256_and_si256(f0, mask);
f0 = _mm256_madd_epi16(f0, shift2);
f0 = _mm256_sllv_epi32(f0, sllvdidx);
f1 = _mm256_bsrli_epi128(f0, 8);
f0 = _mm256_srlv_epi64(f0, srlvqidx);
f1 = _mm256_slli_epi64(f1, 34);
f0 = _mm256_add_epi64(f0, f1);
f0 = _mm256_shuffle_epi8(f0, shufbidx);
t0 = _mm256_castsi256_si128(f0);
t1 = _mm256_extracti128_si256(f0, 1);
t0 = _mm_blendv_epi8(t0, t1, _mm256_castsi256_si128(shufbidx));
_mm_storeu_si128((__m128i *)&r[22 * i + 0], t0);
_mm_storel_epi64((__m128i *)&r[22 * i + 16], t1);
}
}

static void poly_decompress11(__m256i *restrict r, const uint8_t a[352 + 10]) {
unsigned int i;
__m256i f;
const __m256i q = _mm256_load_si256(&qdata.vec[_16XQ / 16]);
const __m256i shufbidx =
_mm256_set_epi8(13, 12, 12, 11, 10, 9, 9, 8, 8, 7, 6, 5, 5, 4, 4, 3, 10,
9, 9, 8, 7, 6, 6, 5, 5, 4, 3, 2, 2, 1, 1, 0);
const __m256i srlvdidx = _mm256_set_epi32(0, 0, 1, 0, 0, 0, 1, 0);
const __m256i srlvqidx = _mm256_set_epi64x(2, 0, 2, 0);
const __m256i shift =
_mm256_set_epi16(4, 32, 1, 8, 32, 1, 4, 32, 4, 32, 1, 8, 32, 1, 4, 32);
const __m256i mask = _mm256_set1_epi16(32752);

for (i = 0; i < MLKEM_N / 16; i++) {
f = _mm256_loadu_si256((__m256i *)&a[22 * i]);
f = _mm256_permute4x64_epi64(f, 0x94);
f = _mm256_shuffle_epi8(f, shufbidx);
f = _mm256_srlv_epi32(f, srlvdidx);
f = _mm256_srlv_epi64(f, srlvqidx);
f = _mm256_mullo_epi16(f, shift);
f = _mm256_srli_epi16(f, 1);
f = _mm256_and_si256(f, mask);
f = _mm256_mulhrs_epi16(f, q);
_mm256_store_si256(&r[i], f);
}
}

#endif

/*************************************************
* Name: polyvec_compress
*
* Description: Compress and serialize vector of polynomials
*
* Arguments: - uint8_t *r: pointer to output byte array
* (needs space for MLKEM_POLYVECCOMPRESSEDBYTES)
* - polyvec *a: pointer to input vector of polynomials
**************************************************/
void polyvec_compress_avx2(uint8_t r[MLKEM_POLYVECCOMPRESSEDBYTES],
const polyvec *a) {
unsigned int i;
// TODO: can we eliminate the extra bytes?
// pad input so we can always store full vectors

#if (MLKEM_POLYVECCOMPRESSEDBYTES == (MLKEM_K * 320))
for (i = 0; i < MLKEM_K; i++)
poly_compress10(&r[320 * i], (__m256i *)&a->vec[i]);
#elif (MLKEM_POLYVECCOMPRESSEDBYTES == (MLKEM_K * 352))
uint8_t rpadded[MLKEM_POLYVECCOMPRESSEDBYTES + 2];
for (i = 0; i < MLKEM_K; i++)
poly_compress11(&rpadded[352 * i], (__m256i *)&a->vec[i]);
memcpy(r, rpadded, MLKEM_POLYVECCOMPRESSEDBYTES);
#endif
}

/*************************************************
* Name: polyvec_decompress
*
* Description: De-serialize and decompress vector of polynomials;
* approximate inverse of polyvec_compress
*
* Arguments: - polyvec *r: pointer to output vector of polynomials
* - const uint8_t *a: pointer to input byte array
* (of length MLKEM_POLYVECCOMPRESSEDBYTES)
**************************************************/
void polyvec_decompress_avx2(polyvec *r,
const uint8_t a[MLKEM_POLYVECCOMPRESSEDBYTES]) {
unsigned int i;
// TODO: can we eliminate the extra bytes?
// pad input so we can always load full vectors

#if (MLKEM_POLYVECCOMPRESSEDBYTES == (MLKEM_K * 320))
uint8_t apadded[MLKEM_POLYVECCOMPRESSEDBYTES + 12];
memcpy(apadded, a, MLKEM_POLYVECCOMPRESSEDBYTES);
for (i = 0; i < MLKEM_K; i++)
poly_decompress10((__m256i *)&r->vec[i], &apadded[320 * i]);
#elif (MLKEM_POLYVECCOMPRESSEDBYTES == (MLKEM_K * 352))
uint8_t apadded[MLKEM_POLYVECCOMPRESSEDBYTES + 10];
memcpy(apadded, a, MLKEM_POLYVECCOMPRESSEDBYTES);
for (i = 0; i < MLKEM_K; i++)
poly_decompress11((__m256i *)&r->vec[i], &apadded[352 * i]);
#endif
}

#else /* MLKEM_USE_NATIVE_X86_64 && SYS_X86_64_AVX2 */
// Dummy declaration for compilers disliking empty compilation units
int empty_cu_compress_avx2;
#endif /* MLKEM_USE_NATIVE_X86_64 && SYS_X86_64_AVX2 */
15 changes: 14 additions & 1 deletion mlkem/native/x86_64/profiles/default.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
#include "../arith_native_x86_64.h"
#include "../consts.h"

#include "params.h"
#include "poly.h"
#include "polyvec.h"

#define MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER

Expand All @@ -26,6 +28,8 @@
#define MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE
#define MLKEM_USE_NATIVE_POLY_TOBYTES
#define MLKEM_USE_NATIVE_POLY_FROMBYTES
#define MLKEM_USE_NATIVE_POLYVEC_COMPRESS
#define MLKEM_USE_NATIVE_POLYVEC_DECOMPRESS

#define INVNTT_BOUND_NATIVE \
(14870 + 1) // Bound from the official Kyber repository
Expand Down Expand Up @@ -89,4 +93,13 @@ static inline void poly_frombytes_native(poly *r,
nttfrombytes_avx2((__m256i *)r->coeffs, a, qdata.vec);
}

#endif /* MLKEM_ARITH_NATIVE_PROFILE_H */
static inline void polyvec_compress_native(
uint8_t r[MLKEM_POLYVECCOMPRESSEDBYTES], const polyvec *a) {
polyvec_compress_avx2(r, a);
}

static inline void polyvec_decompress_native(
polyvec *r, const uint8_t a[MLKEM_POLYVECCOMPRESSEDBYTES]) {
polyvec_decompress_avx2(r, a);
}
#endif /* MLKEM_ARITH_NATIVE_PROFILE_H */
18 changes: 18 additions & 0 deletions mlkem/polyvec.c
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#include "poly.h"

#include "debug/debug.h"

#if !defined(MLKEM_USE_NATIVE_POLYVEC_COMPRESS)
void polyvec_compress(uint8_t r[MLKEM_POLYVECCOMPRESSEDBYTES],
const polyvec *a) {
POLYVEC_UBOUND(a, MLKEM_Q);
Expand Down Expand Up @@ -92,7 +94,15 @@ void polyvec_compress(uint8_t r[MLKEM_POLYVECCOMPRESSEDBYTES],
#error "MLKEM_POLYVECCOMPRESSEDBYTES needs to be in {320*MLKEM_K, 352*MLKEM_K}"
#endif
}
#else /* MLKEM_USE_NATIVE_POLYVEC_COMPRESS */
void polyvec_compress(uint8_t r[MLKEM_POLYVECCOMPRESSEDBYTES],
const polyvec *a) {
POLYVEC_UBOUND(a, MLKEM_Q);
polyvec_compress_native(r, a);
}
#endif /* MLKEM_USE_NATIVE_POLYVEC_COMPRESS */

#if !defined(MLKEM_USE_NATIVE_POLYVEC_DECOMPRESS)
void polyvec_decompress(polyvec *r,
const uint8_t a[MLKEM_POLYVECCOMPRESSEDBYTES]) {
#if (MLKEM_POLYVECCOMPRESSEDBYTES == (MLKEM_K * 352))
Expand Down Expand Up @@ -174,6 +184,14 @@ void polyvec_decompress(polyvec *r,

POLYVEC_UBOUND(r, MLKEM_Q);
}
#else /* MLKEM_USE_NATIVE_POLYVEC_DECOMPRESS */
void polyvec_decompress(polyvec *r,
const uint8_t a[MLKEM_POLYVECCOMPRESSEDBYTES]) {
polyvec_decompress_native(r, a);
POLYVEC_UBOUND(r, MLKEM_Q);
}
#endif /* MLKEM_USE_NATIVE_POLYVEC_DECOMPRESS */


void polyvec_tobytes(uint8_t r[MLKEM_POLYVECBYTES], const polyvec *a) {
unsigned int i;
Expand Down

0 comments on commit 0561120

Please sign in to comment.