From 056112032b689bfa59322fdff8209390fa48736a Mon Sep 17 00:00:00 2001 From: "Matthias J. Kannwischer" Date: Thu, 14 Nov 2024 20:05:54 +0800 Subject: [PATCH] AVX2: Add polyvec_compress_avx2 and polyvec_decompress_avx2 This commit adds the AVX2 intrinsic implementation of polyvec_compress and polyvec_decompress from the official Kyber repository. As a part of https://github.com/pq-code-package/mlkem-native/issues/224 it was identified that the majority of the performance difference in keypair and decaps of our current implementation and the Kyber AVX2 implementation is due to the AVX2 polyvec_compress and polyvec_decompress. This commit adds these two functions to the native interface and adds the AVX2 intrinsic-based implementations from the Kyber repository. These are almost verbatim copies. The only two differences are: 1) The AVX2 implementations require the uint8_t buffer to be slightly larger than MLKEM_POLYVECCOMPRESSEDBYTES, so that full vectors can be stored/loaded. The official implementation allocated those bytes on top level of the function. That would be slightly messy in our implementation, so I instead allocate the larger buffer in polyvec_compress_avx2/polyvec_decompress_avx2 itself and copy the inputs/outputs. 2) The official AVX2 implementation extended the poly type to also be accessible as a __m256i*. I changed this to a cast as we guarantee the alignment in another way. Below are the performance results on my 13th Gen Intel i7-1360P (Raptor Lake) using gcc 14.2.1 from the Arch Linux repo. 
| part | Our code 6aa6118e |Kyber repo|Our code(+polyvec_{,de}compress) | | -------- | ---------------- | -------- | ------------------------------- | | 512 kg | 22353 | 22348 | 22252 | | 512 enc | 27820 | 24868 | 26472 | | 512 dec | 35663 | 34984 | 33107 | | 768 kg | 39626 | 38070 | 41590 | | 768 enc | 43605 | 39056 | 44049 | | 768 dec | 54916 | 53726 | 53432 | | 1024 kg | 58983 | 53532 | 57411 | | 1024 enc | 65402 | 56698 | 61613 | | 1024 dec | 80370 | 75874 | 74681 | Signed-off-by: Matthias J. Kannwischer --- mlkem/native/arith_native.h | 33 +++ mlkem/native/x86_64/arith_native_x86_64.h | 8 + mlkem/native/x86_64/polyvec_compress_avx2.c | 215 ++++++++++++++++++++ mlkem/native/x86_64/profiles/default.h | 15 +- mlkem/polyvec.c | 18 ++ 5 files changed, 288 insertions(+), 1 deletion(-) create mode 100644 mlkem/native/x86_64/polyvec_compress_avx2.c diff --git a/mlkem/native/arith_native.h b/mlkem/native/arith_native.h index 71581203a..b44dd1dcf 100644 --- a/mlkem/native/arith_native.h +++ b/mlkem/native/arith_native.h @@ -240,5 +240,38 @@ static inline int rej_uniform_native(int16_t *r, unsigned int len, const uint8_t *buf, unsigned int buflen); #endif /* MLKEM_USE_NATIVE_REJ_UNIFORM */ +#if defined(MLKEM_USE_NATIVE_POLYVEC_COMPRESS) +/************************************************* + * Name: polyvec_compress_native + * + * Description: Compress and serialize vector of polynomials + * + * Arguments: - uint8_t *r: pointer to output byte array + * (needs space for MLKEM_POLYVECCOMPRESSEDBYTES) + * - const polyvec *a: pointer to input vector of polynomials. + * Coefficients must be unsigned canonical, + * i.e. in [0,1,..,MLKEM_Q-1]. 
+ **************************************************/ +static inline void polyvec_compress_native( + uint8_t r[MLKEM_POLYVECCOMPRESSEDBYTES], const polyvec *a); +#endif /* MLKEM_USE_NATIVE_POLYVEC_COMPRESS */ + +#if defined(MLKEM_USE_NATIVE_POLYVEC_DECOMPRESS) +/************************************************* + * Name: polyvec_decompress_native + * + * Description: De-serialize and decompress vector of polynomials; + * approximate inverse of polyvec_compress + * + * Arguments: - polyvec *r: pointer to output vector of polynomials. + * Output will have coefficients normalized to [0,..,q-1]. + * - const uint8_t *a: pointer to input byte array + * (of length MLKEM_POLYVECCOMPRESSEDBYTES) + **************************************************/ +static inline void polyvec_decompress_native( + polyvec *r, const uint8_t a[MLKEM_POLYVECCOMPRESSEDBYTES]); + +#endif /* MLKEM_USE_NATIVE_POLYVEC_DECOMPRESS */ + #endif /* MLKEM_USE_NATIVE */ #endif /* MLKEM_ARITH_NATIVE_H */ diff --git a/mlkem/native/x86_64/arith_native_x86_64.h b/mlkem/native/x86_64/arith_native_x86_64.h index 96daa0a6a..10b4089fd 100644 --- a/mlkem/native/x86_64/arith_native_x86_64.h +++ b/mlkem/native/x86_64/arith_native_x86_64.h @@ -55,6 +55,14 @@ void nttfrombytes_avx2(__m256i *r, const uint8_t *a, const __m256i *qdata); #define tomont_avx2 MLKEM_NAMESPACE(tomont_avx2) void tomont_avx2(__m256i *r, const __m256i *qdata); +#define polyvec_compress_avx2 MLKEM_NAMESPACE(polyvec_compress_avx2) +void polyvec_compress_avx2(uint8_t r[MLKEM_POLYVECCOMPRESSEDBYTES], + const polyvec *a); + +#define polyvec_decompress_avx2 MLKEM_NAMESPACE(polyvec_decompress_avx2) +void polyvec_decompress_avx2(polyvec *r, + const uint8_t a[MLKEM_POLYVECCOMPRESSEDBYTES]); + #endif /* MLKEM_USE_NATIVE_X86_64 && SYS_X86_64_AVX2 */ #endif /* MLKEM_X86_64_NATIVE_H */ diff --git a/mlkem/native/x86_64/polyvec_compress_avx2.c b/mlkem/native/x86_64/polyvec_compress_avx2.c new file mode 100644 index 000000000..bff1a7ccb --- /dev/null +++ 
b/mlkem/native/x86_64/polyvec_compress_avx2.c @@ -0,0 +1,215 @@ +// SPDX-License-Identifier: Apache-2.0 + +// Implementation from Kyber reference repository +// https://github.com/pq-crystals/kyber/blob/main/avx2 + +#include "config.h" + +#if defined(MLKEM_USE_NATIVE_X86_64) && defined(SYS_X86_64_AVX2) + +#include "arith_native_x86_64.h" + +#include +#include +#include +#include "consts.h" +#include "params.h" + + + +#if (MLKEM_POLYVECCOMPRESSEDBYTES == (MLKEM_K * 320)) +static void poly_compress10(uint8_t r[320], const __m256i *restrict a) { + unsigned int i; + __m256i f0, f1, f2; + __m128i t0, t1; + const __m256i v = _mm256_load_si256(&qdata.vec[_16XV / 16]); + const __m256i v8 = _mm256_slli_epi16(v, 3); + const __m256i off = _mm256_set1_epi16(15); + const __m256i shift1 = _mm256_set1_epi16(1 << 12); + const __m256i mask = _mm256_set1_epi16(1023); + const __m256i shift2 = + _mm256_set1_epi64x((1024LL << 48) + (1LL << 32) + (1024 << 16) + 1); + const __m256i sllvdidx = _mm256_set1_epi64x(12); + const __m256i shufbidx = + _mm256_set_epi8(8, 4, 3, 2, 1, 0, -1, -1, -1, -1, -1, -1, 12, 11, 10, 9, + -1, -1, -1, -1, -1, -1, 12, 11, 10, 9, 8, 4, 3, 2, 1, 0); + + for (i = 0; i < MLKEM_N / 16; i++) { + f0 = _mm256_load_si256(&a[i]); + f1 = _mm256_mullo_epi16(f0, v8); + f2 = _mm256_add_epi16(f0, off); + f0 = _mm256_slli_epi16(f0, 3); + f0 = _mm256_mulhi_epi16(f0, v); + f2 = _mm256_sub_epi16(f1, f2); + f1 = _mm256_andnot_si256(f1, f2); + f1 = _mm256_srli_epi16(f1, 15); + f0 = _mm256_sub_epi16(f0, f1); + f0 = _mm256_mulhrs_epi16(f0, shift1); + f0 = _mm256_and_si256(f0, mask); + f0 = _mm256_madd_epi16(f0, shift2); + f0 = _mm256_sllv_epi32(f0, sllvdidx); + f0 = _mm256_srli_epi64(f0, 12); + f0 = _mm256_shuffle_epi8(f0, shufbidx); + t0 = _mm256_castsi256_si128(f0); + t1 = _mm256_extracti128_si256(f0, 1); + t0 = _mm_blend_epi16(t0, t1, 0xE0); + _mm_storeu_si128((__m128i *)&r[20 * i + 0], t0); + memcpy(&r[20 * i + 16], &t1, 4); + } +} + +static void poly_decompress10(__m256i 
*restrict r, const uint8_t a[320 + 12]) { + unsigned int i; + __m256i f; + const __m256i q = _mm256_set1_epi32((MLKEM_Q << 16) + 4 * MLKEM_Q); + const __m256i shufbidx = + _mm256_set_epi8(11, 10, 10, 9, 9, 8, 8, 7, 6, 5, 5, 4, 4, 3, 3, 2, 9, 8, + 8, 7, 7, 6, 6, 5, 4, 3, 3, 2, 2, 1, 1, 0); + const __m256i sllvdidx = _mm256_set1_epi64x(4); + const __m256i mask = _mm256_set1_epi32((32736 << 16) + 8184); + + for (i = 0; i < MLKEM_N / 16; i++) { + f = _mm256_loadu_si256((__m256i *)&a[20 * i]); + f = _mm256_permute4x64_epi64(f, 0x94); + f = _mm256_shuffle_epi8(f, shufbidx); + f = _mm256_sllv_epi32(f, sllvdidx); + f = _mm256_srli_epi16(f, 1); + f = _mm256_and_si256(f, mask); + f = _mm256_mulhrs_epi16(f, q); + _mm256_store_si256(&r[i], f); + } +} + +#elif (MLKEM_POLYVECCOMPRESSEDBYTES == (MLKEM_K * 352)) +static void poly_compress11(uint8_t r[352 + 2], const __m256i *restrict a) { + unsigned int i; + __m256i f0, f1, f2; + __m128i t0, t1; + const __m256i v = _mm256_load_si256(&qdata.vec[_16XV / 16]); + const __m256i v8 = _mm256_slli_epi16(v, 3); + const __m256i off = _mm256_set1_epi16(36); + const __m256i shift1 = _mm256_set1_epi16(1 << 13); + const __m256i mask = _mm256_set1_epi16(2047); + const __m256i shift2 = + _mm256_set1_epi64x((2048LL << 48) + (1LL << 32) + (2048 << 16) + 1); + const __m256i sllvdidx = _mm256_set1_epi64x(10); + const __m256i srlvqidx = _mm256_set_epi64x(30, 10, 30, 10); + const __m256i shufbidx = + _mm256_set_epi8(4, 3, 2, 1, 0, 0, -1, -1, -1, -1, 10, 9, 8, 7, 6, 5, -1, + -1, -1, -1, -1, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0); + + for (i = 0; i < MLKEM_N / 16; i++) { + f0 = _mm256_load_si256(&a[i]); + f1 = _mm256_mullo_epi16(f0, v8); + f2 = _mm256_add_epi16(f0, off); + f0 = _mm256_slli_epi16(f0, 3); + f0 = _mm256_mulhi_epi16(f0, v); + f2 = _mm256_sub_epi16(f1, f2); + f1 = _mm256_andnot_si256(f1, f2); + f1 = _mm256_srli_epi16(f1, 15); + f0 = _mm256_sub_epi16(f0, f1); + f0 = _mm256_mulhrs_epi16(f0, shift1); + f0 = _mm256_and_si256(f0, mask); + f0 = 
_mm256_madd_epi16(f0, shift2); + f0 = _mm256_sllv_epi32(f0, sllvdidx); + f1 = _mm256_bsrli_epi128(f0, 8); + f0 = _mm256_srlv_epi64(f0, srlvqidx); + f1 = _mm256_slli_epi64(f1, 34); + f0 = _mm256_add_epi64(f0, f1); + f0 = _mm256_shuffle_epi8(f0, shufbidx); + t0 = _mm256_castsi256_si128(f0); + t1 = _mm256_extracti128_si256(f0, 1); + t0 = _mm_blendv_epi8(t0, t1, _mm256_castsi256_si128(shufbidx)); + _mm_storeu_si128((__m128i *)&r[22 * i + 0], t0); + _mm_storel_epi64((__m128i *)&r[22 * i + 16], t1); + } +} + +static void poly_decompress11(__m256i *restrict r, const uint8_t a[352 + 10]) { + unsigned int i; + __m256i f; + const __m256i q = _mm256_load_si256(&qdata.vec[_16XQ / 16]); + const __m256i shufbidx = + _mm256_set_epi8(13, 12, 12, 11, 10, 9, 9, 8, 8, 7, 6, 5, 5, 4, 4, 3, 10, + 9, 9, 8, 7, 6, 6, 5, 5, 4, 3, 2, 2, 1, 1, 0); + const __m256i srlvdidx = _mm256_set_epi32(0, 0, 1, 0, 0, 0, 1, 0); + const __m256i srlvqidx = _mm256_set_epi64x(2, 0, 2, 0); + const __m256i shift = + _mm256_set_epi16(4, 32, 1, 8, 32, 1, 4, 32, 4, 32, 1, 8, 32, 1, 4, 32); + const __m256i mask = _mm256_set1_epi16(32752); + + for (i = 0; i < MLKEM_N / 16; i++) { + f = _mm256_loadu_si256((__m256i *)&a[22 * i]); + f = _mm256_permute4x64_epi64(f, 0x94); + f = _mm256_shuffle_epi8(f, shufbidx); + f = _mm256_srlv_epi32(f, srlvdidx); + f = _mm256_srlv_epi64(f, srlvqidx); + f = _mm256_mullo_epi16(f, shift); + f = _mm256_srli_epi16(f, 1); + f = _mm256_and_si256(f, mask); + f = _mm256_mulhrs_epi16(f, q); + _mm256_store_si256(&r[i], f); + } +} + +#endif + +/************************************************* + * Name: polyvec_compress + * + * Description: Compress and serialize vector of polynomials + * + * Arguments: - uint8_t *r: pointer to output byte array + * (needs space for MLKEM_POLYVECCOMPRESSEDBYTES) + * - polyvec *a: pointer to input vector of polynomials + **************************************************/ +void polyvec_compress_avx2(uint8_t r[MLKEM_POLYVECCOMPRESSEDBYTES], + const polyvec *a) { + 
unsigned int i; + // TODO: can we eliminate the extra bytes? + // pad input so we can always store full vectors + +#if (MLKEM_POLYVECCOMPRESSEDBYTES == (MLKEM_K * 320)) + for (i = 0; i < MLKEM_K; i++) + poly_compress10(&r[320 * i], (__m256i *)&a->vec[i]); +#elif (MLKEM_POLYVECCOMPRESSEDBYTES == (MLKEM_K * 352)) + uint8_t rpadded[MLKEM_POLYVECCOMPRESSEDBYTES + 2]; + for (i = 0; i < MLKEM_K; i++) + poly_compress11(&rpadded[352 * i], (__m256i *)&a->vec[i]); + memcpy(r, rpadded, MLKEM_POLYVECCOMPRESSEDBYTES); +#endif +} + +/************************************************* + * Name: polyvec_decompress + * + * Description: De-serialize and decompress vector of polynomials; + * approximate inverse of polyvec_compress + * + * Arguments: - polyvec *r: pointer to output vector of polynomials + * - const uint8_t *a: pointer to input byte array + * (of length MLKEM_POLYVECCOMPRESSEDBYTES) + **************************************************/ +void polyvec_decompress_avx2(polyvec *r, + const uint8_t a[MLKEM_POLYVECCOMPRESSEDBYTES]) { + unsigned int i; + // TODO: can we eliminate the extra bytes? 
+ // pad input so we can always load full vectors + +#if (MLKEM_POLYVECCOMPRESSEDBYTES == (MLKEM_K * 320)) + uint8_t apadded[MLKEM_POLYVECCOMPRESSEDBYTES + 12]; + memcpy(apadded, a, MLKEM_POLYVECCOMPRESSEDBYTES); + for (i = 0; i < MLKEM_K; i++) + poly_decompress10((__m256i *)&r->vec[i], &apadded[320 * i]); +#elif (MLKEM_POLYVECCOMPRESSEDBYTES == (MLKEM_K * 352)) + uint8_t apadded[MLKEM_POLYVECCOMPRESSEDBYTES + 10]; + memcpy(apadded, a, MLKEM_POLYVECCOMPRESSEDBYTES); + for (i = 0; i < MLKEM_K; i++) + poly_decompress11((__m256i *)&r->vec[i], &apadded[352 * i]); +#endif +} + +#else /* MLKEM_USE_NATIVE_X86_64 && SYS_X86_64_AVX2 */ +// Dummy declaration for compilers disliking empty compilation units +int empty_cu_compress_avx2; +#endif /* MLKEM_USE_NATIVE_X86_64 && SYS_X86_64_AVX2 */ diff --git a/mlkem/native/x86_64/profiles/default.h b/mlkem/native/x86_64/profiles/default.h index 7f9586fea..7ac609694 100644 --- a/mlkem/native/x86_64/profiles/default.h +++ b/mlkem/native/x86_64/profiles/default.h @@ -13,7 +13,9 @@ #include "../arith_native_x86_64.h" #include "../consts.h" +#include "params.h" #include "poly.h" +#include "polyvec.h" #define MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER @@ -26,6 +28,8 @@ #define MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE #define MLKEM_USE_NATIVE_POLY_TOBYTES #define MLKEM_USE_NATIVE_POLY_FROMBYTES +#define MLKEM_USE_NATIVE_POLYVEC_COMPRESS +#define MLKEM_USE_NATIVE_POLYVEC_DECOMPRESS #define INVNTT_BOUND_NATIVE \ (14870 + 1) // Bound from the official Kyber repository @@ -89,4 +93,13 @@ static inline void poly_frombytes_native(poly *r, nttfrombytes_avx2((__m256i *)r->coeffs, a, qdata.vec); } -#endif /* MLKEM_ARITH_NATIVE_PROFILE_H */ +static inline void polyvec_compress_native( + uint8_t r[MLKEM_POLYVECCOMPRESSEDBYTES], const polyvec *a) { + polyvec_compress_avx2(r, a); +} + +static inline void polyvec_decompress_native( + polyvec *r, const uint8_t a[MLKEM_POLYVECCOMPRESSEDBYTES]) { + polyvec_decompress_avx2(r, a); +} +#endif /* 
MLKEM_ARITH_NATIVE_PROFILE_H */ diff --git a/mlkem/polyvec.c b/mlkem/polyvec.c index 47705263b..e689155bc 100644 --- a/mlkem/polyvec.c +++ b/mlkem/polyvec.c @@ -8,6 +8,8 @@ #include "poly.h" #include "debug/debug.h" + +#if !defined(MLKEM_USE_NATIVE_POLYVEC_COMPRESS) void polyvec_compress(uint8_t r[MLKEM_POLYVECCOMPRESSEDBYTES], const polyvec *a) { POLYVEC_UBOUND(a, MLKEM_Q); @@ -92,7 +94,15 @@ void polyvec_compress(uint8_t r[MLKEM_POLYVECCOMPRESSEDBYTES], #error "MLKEM_POLYVECCOMPRESSEDBYTES needs to be in {320*MLKEM_K, 352*MLKEM_K}" #endif } +#else /* MLKEM_USE_NATIVE_POLYVEC_COMPRESS */ +void polyvec_compress(uint8_t r[MLKEM_POLYVECCOMPRESSEDBYTES], + const polyvec *a) { + POLYVEC_UBOUND(a, MLKEM_Q); + polyvec_compress_native(r, a); +} +#endif /* MLKEM_USE_NATIVE_POLYVEC_COMPRESS */ +#if !defined(MLKEM_USE_NATIVE_POLYVEC_DECOMPRESS) void polyvec_decompress(polyvec *r, const uint8_t a[MLKEM_POLYVECCOMPRESSEDBYTES]) { #if (MLKEM_POLYVECCOMPRESSEDBYTES == (MLKEM_K * 352)) @@ -174,6 +184,14 @@ void polyvec_decompress(polyvec *r, POLYVEC_UBOUND(r, MLKEM_Q); } +#else /* MLKEM_USE_NATIVE_POLYVEC_DECOMPRESS */ +void polyvec_decompress(polyvec *r, + const uint8_t a[MLKEM_POLYVECCOMPRESSEDBYTES]) { + polyvec_decompress_native(r, a); + POLYVEC_UBOUND(r, MLKEM_Q); +} +#endif /* MLKEM_USE_NATIVE_POLYVEC_DECOMPRESS */ + void polyvec_tobytes(uint8_t r[MLKEM_POLYVECBYTES], const polyvec *a) { unsigned int i;