diff --git a/Makefile b/Makefile index 0d98d7a83..06c98b064 100644 --- a/Makefile +++ b/Makefile @@ -15,7 +15,7 @@ CFLAGS_NISTRANDOMBYTES = ${CFLAGS} ${INCLUDE_NISTRANDOM} NISTFLAGS += -Wno-unused-result -O3 -fomit-frame-pointer RM = /bin/rm -ASM_CLEAN = mlkem/asm/clean/rej_uniform_asm.s +ASM_CLEAN = mlkem/asm/clean/rej_uniform_asm.s mlkem/asm/clean/poly_asm.s SOURCES = $(ASM_CLEAN) mlkem/kem.c mlkem/indcpa.c mlkem/polyvec.c mlkem/poly.c mlkem/ntt.c mlkem/cbd.c mlkem/reduce.c mlkem/verify.c mlkem/rej_uniform.c SOURCESKECCAK = $(SOURCES) fips202/keccakf1600.c fips202/fips202.c mlkem/symmetric-shake.c diff --git a/mlkem/asm/clean/poly_asm.h b/mlkem/asm/clean/poly_asm.h new file mode 100644 index 000000000..5cdf94ca2 --- /dev/null +++ b/mlkem/asm/clean/poly_asm.h @@ -0,0 +1,15 @@ +// SPDX-License-Identifier: Apache-2.0 +#ifndef POLY_ASM_H +#define POLY_ASM_H + +#include "params.h" + +void poly_frommsg_asm(int16_t coeffs[KYBER_N], + const uint8_t msg[KYBER_INDCPA_MSGBYTES], + const uint16_t bits[8]); + +void poly_tomsg_asm(uint8_t msg[KYBER_INDCPA_MSGBYTES], + const int16_t coeffs[KYBER_N], + const uint16_t position[8]); + +#endif diff --git a/mlkem/asm/clean/poly_asm.s b/mlkem/asm/clean/poly_asm.s new file mode 100644 index 000000000..a6f59e16b --- /dev/null +++ b/mlkem/asm/clean/poly_asm.s @@ -0,0 +1,224 @@ +// SPDX-License-Identifier: Apache-2.0 +/************************************************* +* Name: poly_frommsg_asm +* +* Description: Convert 32-byte message to polynomial +* +* Arguments: - int16_t *coeffs: pointer to output polynomial +* - const uint8_t *msg: pointer to input message +* - const uint16_t *bits: pointer to bit_table +**************************************************/ +.align 4 +.global poly_frommsg_asm +.global _poly_frommsg_asm +poly_frommsg_asm: +_poly_frommsg_asm: + + /* Input registers */ + coeffs .req x0 + msg .req x1 + bit_table .req x2 + + /* Temporary registers */ + iter .req x9 + tmp .req w10 + + /* Vector registers */ + const .req v16 + bits .req v17 + bitsq .req q17 + a0 .req v18 + a0q .req q18 + + /* Vectorize code start */ + mov tmp, #1665 // (KYBER_Q + 1) / 2 + dup const.8h, tmp + ldr bitsq, [bit_table] + mov iter, xzr + loop: + ldrb tmp, [msg, iter] + dup a0.8h, tmp + and a0.16b, a0.16b, bits.16b + cmeq a0.8h, a0.8h, #0 + bic a0.16b, const.16b, a0.16b + str a0q, [coeffs, iter, lsl #4] + add iter, iter, #1 + cmp iter, #32 // KYBER_N / 8 + b.ne loop + ret + + /* Input registers */ + .unreq coeffs + .unreq msg + .unreq bit_table + + /* Temporary registers */ + .unreq iter + .unreq tmp + + /* Vector registers */ + .unreq const + .unreq bits + .unreq bitsq + .unreq a0 + .unreq a0q +/************************************************* +* Name: poly_tomsg_asm +* +* Description: Convert polynomial to 32-byte message +* +* Arguments: - uint8_t *msg: pointer to output message +* - int16_t *coeffs: pointer to input polynomial +**************************************************/ +.align 4 +.global poly_tomsg_asm +.global _poly_tomsg_asm +poly_tomsg_asm: +_poly_tomsg_asm: + + /* Input registers */ + msg .req x0 + coeffs .req x1 + position .req x2 + + /* Temporary registers */ + iter .req x9 + tmp .req w10 + idx_addr .req x11 + + r0 .req w12 + r1 .req w13 + r2 .req w14 + r3 .req w15 + + /* Vector registers */ + vhq .req v16 + vhqinv .req v17 + + a0 .req v18 + a1 .req v19 + a2 .req v20 + a3 .req v21 + + idx .req v22 + idxq .req q22 + + t0 .req h23 + t1 .req h24 + t2 .req h25 + t3 .req h26 + + /* Vectorize code start */ + + mov w9, #1164 // KYBER_Q / 2 + dup vhq.8h, w9 + mov w10, #10079 // 2^26 / KYBER_Q / 2 + dup vhqinv.8h, w10 + ldr idxq, [position] + + mov iter, xzr + + loop32: + ld1 {a0.8h, a1.8h, a2.8h, a3.8h}, [x1], #64 + + /* t << = 1; */ + add a0.8h, a0.8h, a0.8h + add a1.8h, a1.8h, a1.8h + add a2.8h, a2.8h, a2.8h + add a3.8h, a3.8h, a3.8h + + /* t += KYBER_Q/2 */ + add a0.8h, a0.8h, vhq.8h + add a1.8h, a1.8h, vhq.8h + add a2.8h, a2.8h, vhq.8h + add a3.8h, a3.8h, vhq.8h + + /* + * t = t / KYBER_Q + * Instead of direct division, we multiply with inverse of KYBER_Q and utilize the sqdmulh instruction. + * To do so, we have a few options: + * 80635 = round(2^28/KYBER_Q) as in the reference C implementation + * However, we need number that fit in the range [-2^15..2^15] + * So we pick: + * 20159 = round(2^26/KYBER_Q) + * Because we use sqdmulh instruction, the constant will be: + * 10079 = round(2^26/KYBER_Q/2) + * sqdmulh helps us shift right by 16, we need additional shift right by 10 to complete shift right by 26. + * The other approach is to use smull/umull instructions, but they are inefficient. + */ + sqdmulh a0.8h, a0.8h, vhqinv.8h + sqdmulh a1.8h, a1.8h, vhqinv.8h + sqdmulh a2.8h, a2.8h, vhqinv.8h + sqdmulh a3.8h, a3.8h, vhqinv.8h + + ushr a0.8h, a0.8h, #10 + ushr a1.8h, a1.8h, #10 + ushr a2.8h, a2.8h, #10 + ushr a3.8h, a3.8h, #10 + + /* t = t & 1 */ + + bic a0.8h, #62 + bic a1.8h, #62 + bic a2.8h, #62 + bic a3.8h, #62 + + /* Position the bits */ + ushl a0.8h, a0.8h, idx.8h + ushl a1.8h, a1.8h, idx.8h + ushl a2.8h, a2.8h, idx.8h + ushl a3.8h, a3.8h, idx.8h + + /* Extract the result */ + addv t0, a0.8h + addv t1, a1.8h + addv t2, a2.8h + addv t3, a3.8h + + fmov r0, t0 + fmov r1, t1 + fmov r2, t2 + fmov r3, t3 + + strb r0, [x0], #1 + strb r1, [x0], #1 + strb r2, [x0], #1 + strb r3, [x0], #1 + + add iter, iter, #4 + cmp iter, #32 + b.ne loop32 + + ret + + /* Input registers */ + .unreq msg + .unreq coeffs + .unreq position + + /* Temporary registers */ + .unreq iter + .unreq tmp + .unreq idx_addr + + .unreq r0 + .unreq r1 + .unreq r2 + .unreq r3 + + /* Vector registers */ + .unreq vhq + .unreq vhqinv + + .unreq a0 + .unreq a1 + .unreq a2 + .unreq a3 + + .unreq idx + .unreq idxq + + .unreq t0 + .unreq t1 + .unreq t2 + .unreq t3 diff --git a/mlkem/asm/clean/rej_uniform_asm.h b/mlkem/asm/clean/rej_uniform_asm.h index ad03effc4..78ea611b6 100644 --- a/mlkem/asm/clean/rej_uniform_asm.h +++ b/mlkem/asm/clean/rej_uniform_asm.h @@ -1,6 +1,6 @@ // SPDX-License-Identifier: Apache-2.0 -#ifndef REJ_UNIFORM_ASM -#define REJ_UNIFORM_ASM +#ifndef REJ_UNIFORM_ASM_H +#define REJ_UNIFORM_ASM_H unsigned int rej_uniform_asm(int16_t *r, const uint8_t *buf, diff --git a/mlkem/poly.c b/mlkem/poly.c index 9860877d1..6383de44a 100644 --- a/mlkem/poly.c +++ b/mlkem/poly.c @@ -7,6 +7,8 @@ #include "reduce.h" #include "cbd.h" #include "symmetric.h" +#include "poly_asm.h" +#include "rej_uniform.h" /************************************************************ * Name: scalar_compress_q_16 @@ -258,19 +260,10 @@ void poly_frombytes(poly *r, const uint8_t a[KYBER_POLYBYTES]) { * - const uint8_t *msg: pointer to input message **************************************************/ void poly_frommsg(poly *r, const uint8_t msg[KYBER_INDCPA_MSGBYTES]) { - unsigned int i, j; - int16_t mask; - #if (KYBER_INDCPA_MSGBYTES != KYBER_N/8) #error "KYBER_INDCPA_MSGBYTES must be equal to KYBER_N/8 bytes!" #endif - - for (i = 0; i < KYBER_N / 8; i++) { - for (j = 0; j < 8; j++) { - mask = -(int16_t)((msg[i] >> j) & 1); - r->coeffs[8 * i + j] = mask & ((KYBER_Q + 1) / 2); - } - } + poly_frommsg_asm(r->coeffs, msg, bit_table); } /************************************************* @@ -281,24 +274,9 @@ void poly_frommsg(poly *r, const uint8_t msg[KYBER_INDCPA_MSGBYTES]) { * Arguments: - uint8_t *msg: pointer to output message * - const poly *a: pointer to input polynomial **************************************************/ -void poly_tomsg(uint8_t msg[KYBER_INDCPA_MSGBYTES], const poly *a) { - unsigned int i, j; - uint32_t t; - - for (i = 0; i < KYBER_N / 8; i++) { - msg[i] = 0; - for (j = 0; j < 8; j++) { - t = a->coeffs[8 * i + j]; - // t += ((int16_t)t >> 15) & KYBER_Q; - // t = (((t << 1) + KYBER_Q/2)/KYBER_Q) & 1; - t <<= 1; - t += 1665; - t *= 80635; - t >>= 28; - t &= 1; - msg[i] |= t << j; - } - } +void poly_tomsg(uint8_t msg[KYBER_INDCPA_MSGBYTES], const poly *r) { + const uint16_t position[8] = {0, 1, 2, 3, 4, 5, 6, 7}; + poly_tomsg_asm(msg, r->coeffs, position); } /************************************************* diff --git a/mlkem/rej_uniform.h b/mlkem/rej_uniform.h index 9a5d2bf7d..102932fac 100644 --- a/mlkem/rej_uniform.h +++ b/mlkem/rej_uniform.h @@ -8,4 +8,6 @@ unsigned int rej_uniform(int16_t *r, const uint8_t *buf, unsigned int buflen); +const uint16_t bit_table[8]; + #endif