From d04c220b3f3c0bf7c0af28328ccc1de425f4de78 Mon Sep 17 00:00:00 2001 From: "Matthias J. Kannwischer" Date: Thu, 13 Jun 2024 13:22:29 +0800 Subject: [PATCH] Fix secret-dependent branch in poly_fromsg (#55) See pq-crystals/kyber@9b8d306 See https://groups.google.com/a/list.nist.gov/g/pqc-forum/c/hqbtIGFKIpU/m/cnE3pbueBgAJ Signed-off-by: Matthias J. Kannwischer Co-authored-by: Hanno Becker --- mlkem/poly.c | 6 +++--- mlkem/verify.c | 16 ++++++++++++++++ mlkem/verify.h | 3 +++ 3 files changed, 22 insertions(+), 3 deletions(-) diff --git a/mlkem/poly.c b/mlkem/poly.c index d2c63d1c9..cb91b8240 100644 --- a/mlkem/poly.c +++ b/mlkem/poly.c @@ -6,6 +6,7 @@ #include "reduce.h" #include "cbd.h" #include "symmetric.h" +#include "verify.h" /************************************************************ * Name: scalar_compress_q_16 @@ -290,7 +291,6 @@ void poly_frombytes(poly *r, const uint8_t a[KYBER_POLYBYTES]) { **************************************************/ void poly_frommsg(poly *r, const uint8_t msg[KYBER_INDCPA_MSGBYTES]) { unsigned int i, j; - int16_t mask; #if (KYBER_INDCPA_MSGBYTES != KYBER_N/8) #error "KYBER_INDCPA_MSGBYTES must be equal to KYBER_N/8 bytes!" @@ -298,8 +298,8 @@ void poly_frommsg(poly *r, const uint8_t msg[KYBER_INDCPA_MSGBYTES]) { for (i = 0; i < KYBER_N / 8; i++) { for (j = 0; j < 8; j++) { - mask = -(int16_t)((msg[i] >> j) & 1); - r->coeffs[8 * i + j] = mask & ((KYBER_Q + 1) / 2); + r->coeffs[8 * i + j] = 0; + cmov_int16(r->coeffs + 8 * i + j, ((KYBER_Q + 1) / 2), (msg[i] >> j) & 1); } } } diff --git a/mlkem/verify.c b/mlkem/verify.c index 955dd5e3b..e826f8954 100644 --- a/mlkem/verify.c +++ b/mlkem/verify.c @@ -46,3 +46,19 @@ void cmov(uint8_t *r, const uint8_t *x, size_t len, uint8_t b) { r[i] ^= b & (r[i] ^ x[i]); } } + +/************************************************* +* Name: cmov_int16 +* +* Description: Copy input v to *r if b is 1, don't modify *r if b is 0. +* Requires b to be in {0,1}; +* Runs in constant time. +* +* Arguments: int16_t *r: pointer to output int16_t +* int16_t v: input int16_t +* uint8_t b: Condition bit; has to be in {0,1} +**************************************************/ +void cmov_int16(int16_t *r, int16_t v, uint16_t b) { + b = -b; + *r ^= b & ((*r) ^ v); +} diff --git a/mlkem/verify.h b/mlkem/verify.h index e28d02447..d6f362653 100644 --- a/mlkem/verify.h +++ b/mlkem/verify.h @@ -12,4 +12,7 @@ int verify(const uint8_t *a, const uint8_t *b, size_t len); #define cmov KYBER_NAMESPACE(cmov) void cmov(uint8_t *r, const uint8_t *x, size_t len, uint8_t b); +#define cmov_int16 KYBER_NAMESPACE(cmov_int16) +void cmov_int16(int16_t *r, int16_t v, uint16_t b); + #endif