diff --git a/mlkem/debug/debug.c b/mlkem/debug/debug.c index 94a5dbe7e..c95a3ca11 100644 --- a/mlkem/debug/debug.c +++ b/mlkem/debug/debug.c @@ -5,6 +5,15 @@ static char debug_buf[256]; +void mlkem_debug_assert(const char *file, int line, const char *description, + const int val) { + if (val == 0) { + snprintf(debug_buf, sizeof(debug_buf), "Assertion failed: %s (value %d)", + description, val); + mlkem_debug_print_error(file, line, debug_buf); + exit(1); + } +} void mlkem_debug_check_bounds(const char *file, int line, const char *description, const int16_t *ptr, unsigned len, int lower_bound_exclusive, diff --git a/mlkem/debug/debug.h b/mlkem/debug/debug.h index bcbc70ee8..13b18418c 100644 --- a/mlkem/debug/debug.h +++ b/mlkem/debug/debug.h @@ -7,6 +7,22 @@ #include #include +/************************************************* + * Name: mlkem_debug_assert + * + * Description: Check debug assertion + * + * Prints an error message to stderr and calls + * exit(1) if not. + * + * Arguments: - file: filename + * - line: line number + * - description: Textual description of assertion + * - val: Value asserted to be non-zero + **************************************************/ +void mlkem_debug_assert(const char *file, int line, const char *description, + const int val); + /************************************************* * Name: mlkem_debug_check_bounds * @@ -21,17 +37,36 @@ * - description: Textual description of check * - ptr: Base of array to be checked * - len: Number of int16_t in ptr - * - lower_bound_inclusive: Inclusive lower bound - * - upper_bound_inclusive: Inclusive upper bound + * - lower_bound_exclusive: Exclusive lower bound + * - upper_bound_exclusive: Exclusive upper bound **************************************************/ void mlkem_debug_check_bounds(const char *file, int line, const char *description, const int16_t *ptr, - unsigned len, int lower_bound_inclusive, - int upper_bound_inclusive); + unsigned len, int lower_bound_exclusive, + int upper_bound_exclusive); /* Print error message to stderr alongside file and line information */ void mlkem_debug_print_error(const char *file, int line, const char *msg); +/* Check assertion, calling exit() upon failure + * + * val: Value that's asserted to be non-zero + * msg: Message to print on failure + * + * Currently called CASSERT to avoid clash with CBMC assert. + */ +#define CASSERT(val, msg) \ + do { \ + mlkem_debug_assert(__FILE__, __LINE__, (msg), (val)); \ + } while (0) + +/* Check absolute bounds of scalar + * val: Scalar to be checked + * abs_bound: Exclusive upper bound on absolute value to check + * msg: Message to print on failure */ +#define SCALAR_BOUND(val, abs_bound, msg) \ + CASSERT((val) > -(abs_bound) && (val) < (abs_bound), msg) + /* Check that all coefficients in array of int16_t's are non-negative * and below an exclusive upper bound. * @@ -127,6 +162,12 @@ void mlkem_debug_print_error(const char *file, int line, const char *msg); #else /* MLKEM_DEBUG */ +#define CASSERT(...) \ + do { \ + } while (0) +#define SCALAR_BOUND(...) \ + do { \ + } while (0) #define BOUND(...) \ do { \ } while (0) diff --git a/mlkem/ntt.c b/mlkem/ntt.c index b7b7ff4aa..d4d455ab2 100644 --- a/mlkem/ntt.c +++ b/mlkem/ntt.c @@ -107,7 +107,7 @@ void poly_ntt(poly *p) { for (start = 0; start < 256; start = j + len) { zeta = zetas[k++]; for (j = start; j < start + len; j++) { - t = fqmul(zeta, r[j + len]); + t = fqmul(r[j + len], zeta); r[j + len] = r[j] - t; r[j] = r[j] + t; } @@ -173,7 +173,7 @@ void poly_invntt_tomont(poly *p) { t = r[j]; r[j] = barrett_reduce(t + r[j + len]); // abs < q/2 r[j + len] = r[j + len] - t; - r[j + len] = fqmul(zeta, r[j + len]); // abs < 3/4 q + r[j + len] = fqmul(r[j + len], zeta); // abs < 3/4 q } } } diff --git a/mlkem/poly.c b/mlkem/poly.c index 882163be8..c1d16e597 100644 --- a/mlkem/poly.c +++ b/mlkem/poly.c @@ -523,7 +523,7 @@ void poly_tomont(poly *r) { unsigned int i; const int16_t f = (1ULL << 32) % MLKEM_Q; // 1353 for (i = 0; i < MLKEM_N; i++) { - r->coeffs[i] = montgomery_reduce((int32_t)r->coeffs[i] * f); + r->coeffs[i] = fqmul(r->coeffs[i], f); } POLY_BOUND(r, MLKEM_Q); diff --git a/mlkem/reduce.h b/mlkem/reduce.h index f401777ca..110b32f8f 100644 --- a/mlkem/reduce.h +++ b/mlkem/reduce.h @@ -4,6 +4,7 @@ #include #include "cbmc.h" +#include "debug/debug.h" #include "params.h" #define MONT -1044 // 2^16 mod q @@ -29,17 +30,21 @@ ENSURES(RETURN_VALUE >= -HALF_Q && RETURN_VALUE <= HALF_Q); * Description: Multiplication followed by Montgomery reduction * * Arguments: - int16_t a: first factor - * - int16_t b: second factor + * Can be any int16_t. + * - int16_t b: second factor. + * Must be signed canonical (abs value <|q|/2) * - * Returns 16-bit integer congruent to a*b*R^{-1} mod q - * - * If one input is < |q|/2 in absolute value (which is given - * in the common case of multiplication with constants), the - * return value is < |q| in absolute value. + * Returns 16-bit integer congruent to a*b*R^{-1} mod q, and + * smaller than q in absolute value. * **************************************************/ static inline int16_t fqmul(int16_t a, int16_t b) { - return montgomery_reduce((int32_t)a * (int32_t)b); + SCALAR_BOUND(b, HALF_Q + 1, "fqmul input"); + + int16_t res = montgomery_reduce((int32_t)a * (int32_t)b); + + SCALAR_BOUND(res, MLKEM_Q, "fqmul output"); + return res; } #endif