diff --git a/mlkem/poly.c b/mlkem/poly.c index d520e1d6a..5dec2d08c 100644 --- a/mlkem/poly.c +++ b/mlkem/poly.c @@ -7,6 +7,7 @@ #include "reduce.h" #include "cbd.h" #include "symmetric.h" +#include "verify.h" /************************************************************ * Name: scalar_compress_q_16 @@ -260,7 +261,6 @@ void poly_frombytes(poly *r, const uint8_t a[KYBER_POLYBYTES]) { **************************************************/ void poly_frommsg(poly *r, const uint8_t msg[KYBER_INDCPA_MSGBYTES]) { unsigned int i, j; - int16_t mask; #if (KYBER_INDCPA_MSGBYTES != KYBER_N/8) #error "KYBER_INDCPA_MSGBYTES must be equal to KYBER_N/8 bytes!" @@ -268,8 +268,8 @@ void poly_frommsg(poly *r, const uint8_t msg[KYBER_INDCPA_MSGBYTES]) { for (i = 0; i < KYBER_N / 8; i++) { for (j = 0; j < 8; j++) { - mask = -(int16_t)((msg[i] >> j) & 1); - r->coeffs[8 * i + j] = mask & ((KYBER_Q + 1) / 2); + r->coeffs[8 * i + j] = 0; + cmov_int16(r->coeffs + 8 * i + j, ((KYBER_Q + 1) / 2), (msg[i] >> j) & 1); } } } diff --git a/mlkem/verify.c b/mlkem/verify.c index 955dd5e3b..e826f8954 100644 --- a/mlkem/verify.c +++ b/mlkem/verify.c @@ -46,3 +46,19 @@ void cmov(uint8_t *r, const uint8_t *x, size_t len, uint8_t b) { r[i] ^= b & (r[i] ^ x[i]); } } + +/************************************************* +* Name: cmov_int16 +* +* Description: Copy input v to *r if b is 1, don't modify *r if b is 0. +* Requires b to be in {0,1}; +* Runs in constant time. +* +* Arguments: int16_t *r: pointer to output int16_t +* int16_t v: input int16_t +* uint8_t b: Condition bit; has to be in {0,1} +**************************************************/ +void cmov_int16(int16_t *r, int16_t v, uint16_t b) { + b = -b; + *r ^= b & ((*r) ^ v); +} diff --git a/mlkem/verify.h b/mlkem/verify.h index e28d02447..d6f362653 100644 --- a/mlkem/verify.h +++ b/mlkem/verify.h @@ -12,4 +12,7 @@ int verify(const uint8_t *a, const uint8_t *b, size_t len); #define cmov KYBER_NAMESPACE(cmov) void cmov(uint8_t *r, const uint8_t *x, size_t len, uint8_t b); +#define cmov_int16 KYBER_NAMESPACE(cmov_int16) +void cmov_int16(int16_t *r, int16_t v, uint16_t b); + #endif