Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AVX2 polyvec_{de,}compress #410

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions mlkem/native/arith_native.h
Original file line number Diff line number Diff line change
Expand Up @@ -240,5 +240,38 @@ static inline int rej_uniform_native(int16_t *r, unsigned int len,
const uint8_t *buf, unsigned int buflen);
#endif /* MLKEM_USE_NATIVE_REJ_UNIFORM */

#if defined(MLKEM_USE_NATIVE_POLYVEC_COMPRESS)
/*************************************************
* Name: polyvec_compress_native
*
* Description: Compress and serialize vector of polynomials
*
* Arguments: - uint8_t *r: pointer to output byte array
* (needs space for MLKEM_POLYVECCOMPRESSEDBYTES)
* - const polyvec *a: pointer to input vector of polynomials.
* Coefficients must be unsigned canonical,
* i.e. in [0,1,..,MLKEM_Q-1].
**************************************************/
static inline void polyvec_compress_native(
uint8_t r[MLKEM_POLYVECCOMPRESSEDBYTES], const polyvec *a);
#endif /* MLKEM_USE_NATIVE_POLYVEC_COMPRESS */

#if defined(MLKEM_USE_NATIVE_POLYVEC_DECOMPRESS)
/*************************************************
* Name: polyvec_decompress_native
*
* Description: De-serialize and decompress vector of polynomials;
* approximate inverse of polyvec_compress
*
* Arguments: - polyvec *r: pointer to output vector of polynomials.
* Output will have coefficients normalized to [0,..,q-1].
* - const uint8_t *a: pointer to input byte array
* (of length MLKEM_POLYVECCOMPRESSEDBYTES)
**************************************************/
static inline void polyvec_decompress_native(
polyvec *r, const uint8_t a[MLKEM_POLYVECCOMPRESSEDBYTES]);

#endif /* MLKEM_USE_NATIVE_POLYVEC_DECOMPRESS */

#endif /* MLKEM_USE_NATIVE */
#endif /* MLKEM_ARITH_NATIVE_H */
8 changes: 8 additions & 0 deletions mlkem/native/x86_64/arith_native_x86_64.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,14 @@ void nttfrombytes_avx2(__m256i *r, const uint8_t *a, const __m256i *qdata);
#define tomont_avx2 MLKEM_NAMESPACE(tomont_avx2)
void tomont_avx2(__m256i *r, const __m256i *qdata);

#define polyvec_compress_avx2 MLKEM_NAMESPACE(polyvec_compress_avx2)
void polyvec_compress_avx2(uint8_t r[MLKEM_POLYVECCOMPRESSEDBYTES],
const polyvec *a);

#define polyvec_decompress_avx2 MLKEM_NAMESPACE(polyvec_decompress_avx2)
void polyvec_decompress_avx2(polyvec *r,
const uint8_t a[MLKEM_POLYVECCOMPRESSEDBYTES]);

#endif /* MLKEM_USE_NATIVE_X86_64 && SYS_X86_64_AVX2 */

#endif /* MLKEM_X86_64_NATIVE_H */
215 changes: 215 additions & 0 deletions mlkem/native/x86_64/polyvec_compress_avx2.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
// SPDX-License-Identifier: Apache-2.0

// Implementation from Kyber reference repository
// https://github.com/pq-crystals/kyber/blob/main/avx2

#include "config.h"

#if defined(MLKEM_USE_NATIVE_X86_64) && defined(SYS_X86_64_AVX2)

#include "arith_native_x86_64.h"

#include <immintrin.h>
#include <stdint.h>
#include <string.h>
#include "consts.h"
#include "params.h"



#if (MLKEM_POLYVECCOMPRESSEDBYTES == (MLKEM_K * 320))
static void poly_compress10(uint8_t r[320], const __m256i *restrict a) {
unsigned int i;
__m256i f0, f1, f2;
__m128i t0, t1;
const __m256i v = _mm256_load_si256(&qdata.vec[_16XV / 16]);
const __m256i v8 = _mm256_slli_epi16(v, 3);
const __m256i off = _mm256_set1_epi16(15);
const __m256i shift1 = _mm256_set1_epi16(1 << 12);
const __m256i mask = _mm256_set1_epi16(1023);
const __m256i shift2 =
_mm256_set1_epi64x((1024LL << 48) + (1LL << 32) + (1024 << 16) + 1);
const __m256i sllvdidx = _mm256_set1_epi64x(12);
const __m256i shufbidx =
_mm256_set_epi8(8, 4, 3, 2, 1, 0, -1, -1, -1, -1, -1, -1, 12, 11, 10, 9,
-1, -1, -1, -1, -1, -1, 12, 11, 10, 9, 8, 4, 3, 2, 1, 0);

for (i = 0; i < MLKEM_N / 16; i++) {
f0 = _mm256_load_si256(&a[i]);
f1 = _mm256_mullo_epi16(f0, v8);
f2 = _mm256_add_epi16(f0, off);
f0 = _mm256_slli_epi16(f0, 3);
f0 = _mm256_mulhi_epi16(f0, v);
f2 = _mm256_sub_epi16(f1, f2);
f1 = _mm256_andnot_si256(f1, f2);
f1 = _mm256_srli_epi16(f1, 15);
f0 = _mm256_sub_epi16(f0, f1);
f0 = _mm256_mulhrs_epi16(f0, shift1);
f0 = _mm256_and_si256(f0, mask);
f0 = _mm256_madd_epi16(f0, shift2);
f0 = _mm256_sllv_epi32(f0, sllvdidx);
f0 = _mm256_srli_epi64(f0, 12);
f0 = _mm256_shuffle_epi8(f0, shufbidx);
t0 = _mm256_castsi256_si128(f0);
t1 = _mm256_extracti128_si256(f0, 1);
t0 = _mm_blend_epi16(t0, t1, 0xE0);
_mm_storeu_si128((__m128i *)&r[20 * i + 0], t0);
memcpy(&r[20 * i + 16], &t1, 4);
}
}

static void poly_decompress10(__m256i *restrict r, const uint8_t a[320 + 12]) {
unsigned int i;
__m256i f;
const __m256i q = _mm256_set1_epi32((MLKEM_Q << 16) + 4 * MLKEM_Q);
const __m256i shufbidx =
_mm256_set_epi8(11, 10, 10, 9, 9, 8, 8, 7, 6, 5, 5, 4, 4, 3, 3, 2, 9, 8,
8, 7, 7, 6, 6, 5, 4, 3, 3, 2, 2, 1, 1, 0);
const __m256i sllvdidx = _mm256_set1_epi64x(4);
const __m256i mask = _mm256_set1_epi32((32736 << 16) + 8184);

for (i = 0; i < MLKEM_N / 16; i++) {
f = _mm256_loadu_si256((__m256i *)&a[20 * i]);
f = _mm256_permute4x64_epi64(f, 0x94);
f = _mm256_shuffle_epi8(f, shufbidx);
f = _mm256_sllv_epi32(f, sllvdidx);
f = _mm256_srli_epi16(f, 1);
f = _mm256_and_si256(f, mask);
f = _mm256_mulhrs_epi16(f, q);
_mm256_store_si256(&r[i], f);
}
}

#elif (MLKEM_POLYVECCOMPRESSEDBYTES == (MLKEM_K * 352))
static void poly_compress11(uint8_t r[352 + 2], const __m256i *restrict a) {
unsigned int i;
__m256i f0, f1, f2;
__m128i t0, t1;
const __m256i v = _mm256_load_si256(&qdata.vec[_16XV / 16]);
const __m256i v8 = _mm256_slli_epi16(v, 3);
const __m256i off = _mm256_set1_epi16(36);
const __m256i shift1 = _mm256_set1_epi16(1 << 13);
const __m256i mask = _mm256_set1_epi16(2047);
const __m256i shift2 =
_mm256_set1_epi64x((2048LL << 48) + (1LL << 32) + (2048 << 16) + 1);
const __m256i sllvdidx = _mm256_set1_epi64x(10);
const __m256i srlvqidx = _mm256_set_epi64x(30, 10, 30, 10);
const __m256i shufbidx =
_mm256_set_epi8(4, 3, 2, 1, 0, 0, -1, -1, -1, -1, 10, 9, 8, 7, 6, 5, -1,
-1, -1, -1, -1, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0);

for (i = 0; i < MLKEM_N / 16; i++) {
f0 = _mm256_load_si256(&a[i]);
f1 = _mm256_mullo_epi16(f0, v8);
f2 = _mm256_add_epi16(f0, off);
f0 = _mm256_slli_epi16(f0, 3);
f0 = _mm256_mulhi_epi16(f0, v);
f2 = _mm256_sub_epi16(f1, f2);
f1 = _mm256_andnot_si256(f1, f2);
f1 = _mm256_srli_epi16(f1, 15);
f0 = _mm256_sub_epi16(f0, f1);
f0 = _mm256_mulhrs_epi16(f0, shift1);
f0 = _mm256_and_si256(f0, mask);
f0 = _mm256_madd_epi16(f0, shift2);
f0 = _mm256_sllv_epi32(f0, sllvdidx);
f1 = _mm256_bsrli_epi128(f0, 8);
f0 = _mm256_srlv_epi64(f0, srlvqidx);
f1 = _mm256_slli_epi64(f1, 34);
f0 = _mm256_add_epi64(f0, f1);
f0 = _mm256_shuffle_epi8(f0, shufbidx);
t0 = _mm256_castsi256_si128(f0);
t1 = _mm256_extracti128_si256(f0, 1);
t0 = _mm_blendv_epi8(t0, t1, _mm256_castsi256_si128(shufbidx));
_mm_storeu_si128((__m128i *)&r[22 * i + 0], t0);
_mm_storel_epi64((__m128i *)&r[22 * i + 16], t1);
}
}

static void poly_decompress11(__m256i *restrict r, const uint8_t a[352 + 10]) {
unsigned int i;
__m256i f;
const __m256i q = _mm256_load_si256(&qdata.vec[_16XQ / 16]);
const __m256i shufbidx =
_mm256_set_epi8(13, 12, 12, 11, 10, 9, 9, 8, 8, 7, 6, 5, 5, 4, 4, 3, 10,
9, 9, 8, 7, 6, 6, 5, 5, 4, 3, 2, 2, 1, 1, 0);
const __m256i srlvdidx = _mm256_set_epi32(0, 0, 1, 0, 0, 0, 1, 0);
const __m256i srlvqidx = _mm256_set_epi64x(2, 0, 2, 0);
const __m256i shift =
_mm256_set_epi16(4, 32, 1, 8, 32, 1, 4, 32, 4, 32, 1, 8, 32, 1, 4, 32);
const __m256i mask = _mm256_set1_epi16(32752);

for (i = 0; i < MLKEM_N / 16; i++) {
f = _mm256_loadu_si256((__m256i *)&a[22 * i]);
f = _mm256_permute4x64_epi64(f, 0x94);
f = _mm256_shuffle_epi8(f, shufbidx);
f = _mm256_srlv_epi32(f, srlvdidx);
f = _mm256_srlv_epi64(f, srlvqidx);
f = _mm256_mullo_epi16(f, shift);
f = _mm256_srli_epi16(f, 1);
f = _mm256_and_si256(f, mask);
f = _mm256_mulhrs_epi16(f, q);
_mm256_store_si256(&r[i], f);
}
}

#endif

/*************************************************
* Name: polyvec_compress
*
* Description: Compress and serialize vector of polynomials
*
* Arguments: - uint8_t *r: pointer to output byte array
* (needs space for MLKEM_POLYVECCOMPRESSEDBYTES)
* - polyvec *a: pointer to input vector of polynomials
**************************************************/
void polyvec_compress_avx2(uint8_t r[MLKEM_POLYVECCOMPRESSEDBYTES],
const polyvec *a) {
unsigned int i;
// TODO: can we eliminate the extra bytes?
// pad input so we can always store full vectors

#if (MLKEM_POLYVECCOMPRESSEDBYTES == (MLKEM_K * 320))
for (i = 0; i < MLKEM_K; i++)
poly_compress10(&r[320 * i], (__m256i *)&a->vec[i]);
#elif (MLKEM_POLYVECCOMPRESSEDBYTES == (MLKEM_K * 352))
uint8_t rpadded[MLKEM_POLYVECCOMPRESSEDBYTES + 2];
for (i = 0; i < MLKEM_K; i++)
poly_compress11(&rpadded[352 * i], (__m256i *)&a->vec[i]);
memcpy(r, rpadded, MLKEM_POLYVECCOMPRESSEDBYTES);
#endif
}

/*************************************************
* Name: polyvec_decompress
*
* Description: De-serialize and decompress vector of polynomials;
* approximate inverse of polyvec_compress
*
* Arguments: - polyvec *r: pointer to output vector of polynomials
* - const uint8_t *a: pointer to input byte array
* (of length MLKEM_POLYVECCOMPRESSEDBYTES)
**************************************************/
void polyvec_decompress_avx2(polyvec *r,
const uint8_t a[MLKEM_POLYVECCOMPRESSEDBYTES]) {
unsigned int i;
// TODO: can we eliminate the extra bytes?
// pad input so we can always load full vectors

#if (MLKEM_POLYVECCOMPRESSEDBYTES == (MLKEM_K * 320))
uint8_t apadded[MLKEM_POLYVECCOMPRESSEDBYTES + 12];
memcpy(apadded, a, MLKEM_POLYVECCOMPRESSEDBYTES);
for (i = 0; i < MLKEM_K; i++)
poly_decompress10((__m256i *)&r->vec[i], &apadded[320 * i]);
#elif (MLKEM_POLYVECCOMPRESSEDBYTES == (MLKEM_K * 352))
uint8_t apadded[MLKEM_POLYVECCOMPRESSEDBYTES + 10];
memcpy(apadded, a, MLKEM_POLYVECCOMPRESSEDBYTES);
for (i = 0; i < MLKEM_K; i++)
poly_decompress11((__m256i *)&r->vec[i], &apadded[352 * i]);
#endif
}

#else /* MLKEM_USE_NATIVE_X86_64 && SYS_X86_64_AVX2 */
// Dummy declaration for compilers disliking empty compilation units
int empty_cu_compress_avx2;
#endif /* MLKEM_USE_NATIVE_X86_64 && SYS_X86_64_AVX2 */
15 changes: 14 additions & 1 deletion mlkem/native/x86_64/profiles/default.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
#include "../arith_native_x86_64.h"
#include "../consts.h"

#include "params.h"
#include "poly.h"
#include "polyvec.h"

#define MLKEM_USE_NATIVE_NTT_CUSTOM_ORDER

Expand All @@ -26,6 +28,8 @@
#define MLKEM_USE_NATIVE_POLY_MULCACHE_COMPUTE
#define MLKEM_USE_NATIVE_POLY_TOBYTES
#define MLKEM_USE_NATIVE_POLY_FROMBYTES
#define MLKEM_USE_NATIVE_POLYVEC_COMPRESS
#define MLKEM_USE_NATIVE_POLYVEC_DECOMPRESS

#define INVNTT_BOUND_NATIVE \
(14870 + 1) // Bound from the official Kyber repository
Expand Down Expand Up @@ -89,4 +93,13 @@ static inline void poly_frombytes_native(poly *r,
nttfrombytes_avx2((__m256i *)r->coeffs, a, qdata.vec);
}

#endif /* MLKEM_ARITH_NATIVE_PROFILE_H */
static inline void polyvec_compress_native(
uint8_t r[MLKEM_POLYVECCOMPRESSEDBYTES], const polyvec *a) {
polyvec_compress_avx2(r, a);
}

static inline void polyvec_decompress_native(
polyvec *r, const uint8_t a[MLKEM_POLYVECCOMPRESSEDBYTES]) {
polyvec_decompress_avx2(r, a);
}
#endif /* MLKEM_ARITH_NATIVE_PROFILE_H */
28 changes: 18 additions & 10 deletions mlkem/polyvec.c
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#include "poly.h"

#include "debug/debug.h"

#if !defined(MLKEM_USE_NATIVE_POLYVEC_COMPRESS)
void polyvec_compress(uint8_t r[MLKEM_POLYVECCOMPRESSEDBYTES],
const polyvec *a) {
POLYVEC_UBOUND(a, MLKEM_Q);
Expand Down Expand Up @@ -92,17 +94,15 @@ void polyvec_compress(uint8_t r[MLKEM_POLYVECCOMPRESSEDBYTES],
#error "MLKEM_POLYVECCOMPRESSEDBYTES needs to be in {320*MLKEM_K, 352*MLKEM_K}"
#endif
}
#else /* MLKEM_USE_NATIVE_POLYVEC_COMPRESS */
void polyvec_compress(uint8_t r[MLKEM_POLYVECCOMPRESSEDBYTES],
const polyvec *a) {
POLYVEC_UBOUND(a, MLKEM_Q);
polyvec_compress_native(r, a);
}
#endif /* MLKEM_USE_NATIVE_POLYVEC_COMPRESS */

/*************************************************
* Name: polyvec_decompress
*
* Description: De-serialize and decompress vector of polynomials;
* approximate inverse of polyvec_compress
*
* Arguments: - polyvec *r: pointer to output vector of polynomials
* - const uint8_t *a: pointer to input byte array
* (of length MLKEM_POLYVECCOMPRESSEDBYTES)
**************************************************/
#if !defined(MLKEM_USE_NATIVE_POLYVEC_DECOMPRESS)
void polyvec_decompress(polyvec *r,
const uint8_t a[MLKEM_POLYVECCOMPRESSEDBYTES]) {
#if (MLKEM_POLYVECCOMPRESSEDBYTES == (MLKEM_K * 352))
Expand Down Expand Up @@ -184,6 +184,14 @@ void polyvec_decompress(polyvec *r,

POLYVEC_UBOUND(r, MLKEM_Q);
}
#else /* MLKEM_USE_NATIVE_POLYVEC_DECOMPRESS */
void polyvec_decompress(polyvec *r,
const uint8_t a[MLKEM_POLYVECCOMPRESSEDBYTES]) {
polyvec_decompress_native(r, a);
POLYVEC_UBOUND(r, MLKEM_Q);
}
#endif /* MLKEM_USE_NATIVE_POLYVEC_DECOMPRESS */


void polyvec_tobytes(uint8_t r[MLKEM_POLYVECBYTES], const polyvec *a) {
unsigned int i;
Expand Down
11 changes: 11 additions & 0 deletions mlkem/polyvec.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,17 @@ REQUIRES(FORALL(int, k0, 0, MLKEM_K - 1,
// clang-format on

#define polyvec_decompress MLKEM_NAMESPACE(polyvec_decompress)
/*************************************************
* Name: polyvec_decompress
*
* Description: De-serialize and decompress vector of polynomials;
* approximate inverse of polyvec_compress
*
* Arguments: - polyvec *r: pointer to output vector of polynomials.
* Output will have coefficients normalized to [0,..,q-1].
* - const uint8_t *a: pointer to input byte array
* (of length MLKEM_POLYVECCOMPRESSEDBYTES)
**************************************************/
void polyvec_decompress(
polyvec *r,
const uint8_t a[MLKEM_POLYVECCOMPRESSEDBYTES]) // clang-format off
Expand Down
5 changes: 5 additions & 0 deletions test/bench_components_mlkem.c
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,11 @@ static int bench(void) {
BENCH("gen_matrix", gen_matrix((polyvec *)data0, (uint8_t *)data1, 0))


BENCH("polyvec-compress",
polyvec_compress((uint8_t *)data0, (polyvec *)data1));
BENCH("polyvec-decompress",
polyvec_decompress((polyvec *)data0, (uint8_t *)data1));

Comment on lines +171 to +175
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this needs to be removed.

#if defined(MLKEM_USE_NATIVE_AARCH64)
BENCH("ntt-clean", ntt_asm_clean((int16_t *)data0));
BENCH("intt-clean", intt_asm_clean((int16_t *)data0));
Expand Down