diff --git a/mlkem/ntt.c b/mlkem/ntt.c index 3bb0694a7..572f78396 100644 --- a/mlkem/ntt.c +++ b/mlkem/ntt.c @@ -42,8 +42,8 @@ // 5 -- 7 // STATIC_TESTABLE -void ntt_butterfly_block(int16_t r[MLKEM_N], int16_t zeta, int start, int len, - int bound) +void ntt_butterfly_block(int16_t r[MLKEM_N], int16_t zeta, int16_t zeta_twisted, + int start, int len, int bound) __contract__( requires(0 <= start && start < MLKEM_N) requires(1 <= len && len <= MLKEM_N / 2 && start + 2 * len <= MLKEM_N) @@ -69,7 +69,7 @@ __contract__( invariant(array_abs_bound(r, j + len, MLKEM_N - 1, bound))) { int16_t t; - t = fqmul(r[j + len], zeta); + t = fqmul_bar(r[j + len], zeta, zeta_twisted); r[j + len] = r[j] - t; r[j] = r[j] + t; } @@ -107,8 +107,9 @@ __contract__( invariant(array_abs_bound(r, 0, start - 1, (layer * MLKEM_Q - 1) + MLKEM_Q)) invariant(array_abs_bound(r, start, MLKEM_N - 1, layer * MLKEM_Q - 1))) { - int16_t zeta = zetas[k++]; - ntt_butterfly_block(r, zeta, start, len, layer * MLKEM_Q - 1); + int16_t zeta = zetas[k]; + int16_t zeta_twisted = zetas_twisted[k++]; + ntt_butterfly_block(r, zeta, zeta_twisted, start, len, layer * MLKEM_Q - 1); } } @@ -178,7 +179,8 @@ __contract__( // Normalised form of k == MLKEM_N / len - 1 - start / (2 * len) invariant(2 * len * k + start == 2 * MLKEM_N - 2 * len)) { - int16_t zeta = zetas[k--]; + int16_t zeta = zetas[k]; + int16_t zeta_twisted = zetas_twisted[k--]; for (int j = start; j < start + len; j++) __loop__( invariant(start <= j && j <= start + len) @@ -188,7 +190,7 @@ __contract__( int16_t t = r[j]; r[j] = barrett_reduce(t + r[j + len]); r[j + len] = r[j + len] - t; - r[j + len] = fqmul(r[j + len], zeta); + r[j + len] = fqmul_bar(r[j + len], zeta, zeta_twisted); } } } diff --git a/mlkem/ntt.h b/mlkem/ntt.h index 48fbd43c7..399eba715 100644 --- a/mlkem/ntt.h +++ b/mlkem/ntt.h @@ -13,6 +13,9 @@ #define zetas MLKEM_NAMESPACE(zetas) extern const int16_t zetas[128]; +#define zetas_twisted MLKEM_NAMESPACE(zetas_twisted) +extern const 
int16_t zetas_twisted[128]; + /************************************************* * Name: poly_ntt * diff --git a/mlkem/poly.c b/mlkem/poly.c index ccf214f06..8ca496096 100644 --- a/mlkem/poly.c +++ b/mlkem/poly.c @@ -517,8 +517,10 @@ void poly_mulcache_compute(poly_mulcache *x, const poly *a) for (i = 0; i < MLKEM_N / 4; i++) __loop__(invariant(i >= 0 && i <= MLKEM_N / 4)) { - x->coeffs[2 * i + 0] = fqmul(a->coeffs[4 * i + 1], zetas[64 + i]); - x->coeffs[2 * i + 1] = fqmul(a->coeffs[4 * i + 3], -zetas[64 + i]); + x->coeffs[2 * i + 0] = + fqmul_bar(a->coeffs[4 * i + 1], zetas[64 + i], zetas_twisted[64 + i]); + x->coeffs[2 * i + 1] = + fqmul_bar(a->coeffs[4 * i + 3], -zetas[64 + i], -zetas_twisted[64 + i]); } POLY_BOUND(x, MLKEM_Q); } diff --git a/mlkem/reduce.c b/mlkem/reduce.c index cac77880c..7cf4a09fa 100644 --- a/mlkem/reduce.c +++ b/mlkem/reduce.c @@ -111,6 +111,19 @@ int16_t fqmul(int16_t a, int16_t b) return res; } +int16_t fqmul_bar(int16_t a, int16_t b, int16_t b_twisted) +{ + SCALAR_BOUND(b, HALF_Q, "fqmul input"); + + int16_t quot = ((int32_t)a * b_twisted) >> 16; + uint16_t prod_low = a * b; + uint16_t round_low = quot * MLKEM_Q; + uint16_t r = prod_low - round_low; + + SCALAR_BOUND((int16_t)r, MLKEM_Q, "fqmul output"); + return (int16_t)r; +} + // To divide by MLKEM_Q using Barrett multiplication, the "magic number" // multiplier is round_to_nearest(2**26/MLKEM_Q) #define BPOWER 26 diff --git a/mlkem/reduce.h b/mlkem/reduce.h index 64521ebed..026c56dbc 100644 --- a/mlkem/reduce.h +++ b/mlkem/reduce.h @@ -58,5 +58,23 @@ __contract__( ensures(return_value > -MLKEM_Q && return_value < MLKEM_Q) ); +/************************************************* + * Name: fqmul_bar + * + * Description: Barrett multiplication modulo q=3329 + * (https://eprint.iacr.org/2021/986) + * + * Arguments: - int16_t a: first factor + * Can be any int16_t. + * - int16_t b: second factor. 
+ * Must be signed canonical (abs value <(q+1)/2) + * - int16_t b_twisted: Barrett twist of second factor + * + * Returns 16-bit integer congruent to a*b mod q, and + * smaller than q in absolute value. + * + **************************************************/ +#define fqmul_bar MLKEM_NAMESPACE(fqmul_bar) +int16_t fqmul_bar(int16_t a, int16_t b, int16_t b_twisted); #endif diff --git a/mlkem/zetas.c b/mlkem/zetas.c index d19c104b5..41de2488a 100644 --- a/mlkem/zetas.c +++ b/mlkem/zetas.c @@ -9,16 +9,34 @@ // Table of zeta values used in the reference NTT and inverse NTT. // See autogenerate_files.py for details. const int16_t zetas[128] = { - -1044, -758, -359, -1517, 1493, 1422, 287, 202, -171, 622, 1577, - 182, 962, -1202, -1474, 1468, 573, -1325, 264, 383, -829, 1458, - -1602, -130, -681, 1017, 732, 608, -1542, 411, -205, -1571, 1223, - 652, -552, 1015, -1293, 1491, -282, -1544, 516, -8, -320, -666, - -1618, -1162, 126, 1469, -853, -90, -271, 830, 107, -1421, -247, - -951, -398, 961, -1508, -725, 448, -1065, 677, -1275, -1103, 430, - 555, 843, -1251, 871, 1550, 105, 422, 587, 177, -235, -291, - -460, 1574, 1653, -246, 778, 1159, -147, -777, 1483, -602, 1119, - -1590, 644, -872, 349, 418, 329, -156, -75, 817, 1097, 603, - 610, 1322, -1285, -1465, 384, -1215, -136, 1218, -1335, -874, 220, - -1187, -1659, -1185, -1530, -1278, 794, -1510, -854, -870, 478, -108, - -308, 996, 991, 958, -1460, 1522, 1628, + 1, -1600, -749, -40, -687, 630, -1432, 848, 1062, -1410, 193, + 797, -543, -69, 569, -1583, 296, -882, 1339, 1476, -283, 56, + -1089, 1333, 1426, -1235, 535, -447, -936, -450, -1355, 821, 289, + 331, -76, -1573, 1197, -1025, -1052, -1274, 650, -1352, -816, 632, + -464, 33, 1320, -1414, -1010, 1435, 807, 452, 1438, -461, 1534, + -927, -682, -712, 1481, 648, -855, -219, 1227, 910, 17, -568, + 583, -680, 1637, 723, -1041, 1100, 1409, -667, -48, 233, 756, + -1173, -314, -279, -1626, 1651, -540, -1540, -1482, 952, 1461, -642, + 939, -1021, -892, -941, 733, -992, 
268, 641, 1584, -1031, -1292, + -109, 375, -780, -1239, 1645, 1063, 319, -556, 757, -1230, 561, + -863, -735, -525, 1092, 403, 1026, 1143, -1179, -554, 886, -1607, + 1212, -1455, 1029, -1219, -394, 885, -1175, +}; + +const int16_t zetas_twisted[128] = { + 19, -31499, -14746, -788, -13525, 12402, -28191, 16694, 20906, + -27758, 3799, 15690, -10690, -1359, 11201, -31164, 5827, -17364, + 26360, 29057, -5572, 1102, -21439, 26241, 28072, -24313, 10532, + -8800, -18427, -8859, -26676, 16162, 5689, 6516, -1497, -30967, + 23564, -20179, -20711, -25081, 12796, -26617, -16065, 12441, -9135, + 649, 25986, -27837, -19884, 28249, 15886, 8898, 28309, -9076, + 30198, -18250, -13427, -14017, 29155, 12756, -16832, -4312, 24155, + 17914, 334, -11182, 11477, -13387, 32226, 14233, -20494, 21655, + 27738, -13131, -945, 4586, 14882, -23093, -6182, -5493, -32011, + 32502, -10631, -30318, -29176, 18741, 28761, -12639, 18485, -20100, + -17561, -18525, 14430, -19529, 5275, 12618, 31183, -20297, -25435, + -2146, 7382, -15356, -24392, 32384, 20926, 6279, -10946, 14902, + -24215, 11044, -16990, -14470, -10336, 21497, 7933, 20198, 22501, + -23211, -10907, 17442, -31637, 23859, -28644, 20257, -23998, -7757, + 17422, -23132, }; diff --git a/scripts/autogenerate_files.py b/scripts/autogenerate_files.py index 4b95043fa..887cc77f5 100644 --- a/scripts/autogenerate_files.py +++ b/scripts/autogenerate_files.py @@ -4,6 +4,7 @@ import subprocess import argparse +import math import os modulus = 3329 @@ -59,7 +60,46 @@ def signed_reduce(a): c -= modulus return c -def gen_c_zetas(): +def prepare_root_for_montmul(root): + """Takes a constant that the code needs to Montgomery-multiply with, + and returns the pair of (a) the signed canonical representative of its + Montgomery form, (b) the twisted constant used in the low-mul part of + the Montgomery multiplication.""" + + # Convert to Montgomery form and pick canonical signed representative + root = signed_reduce(root * montgomery_factor) + root_twisted = 
signed_reduce_u16(root * pow(modulus, -1, 2**16)) + return root, root_twisted + +def prepare_root_for_barrett(root, even=True): + """Takes a constant that the code needs to Barrett-multiply with, + and returns the pair of (a) its signed canonical form, (b) the + twisted constant used in the high-mul part of the Barrett multiplication.""" + + # Signed canonical reduction + root = signed_reduce(root) + + def round_to_suitable(t): + if even is True: + rt = round(t) + if rt % 2 == 0: + return rt + # Make sure to pick a rounding target + # that's <= 1 away from t in absolute value. + if rt <= t: + return rt + 1 + return rt - 1 + else: + return math.floor(t) + + root_twisted = round_to_suitable((root * 2**16) / modulus) + + if even is True: + root_twisted = root_twisted // 2 + + return root, root_twisted + +def gen_c_zetas(twisted=False): """Generate source and header file for zeta values used in the reference NTT and invNTT""" @@ -68,7 +108,11 @@ def gen_c_zetas(): zeta = [] for i in range(128): - zeta.append(signed_reduce(pow(root_of_unity, i, modulus) * montgomery_factor)) + root, root_twisted = prepare_root_for_barrett(pow(root_of_unity, i, modulus), even=False) + if twisted is False: + zeta.append(root) + else: + zeta.append(root_twisted) # The source code stores the zeta table in bit reversed form yield from (zeta[bitreverse(i,7)] for i in range(128)) @@ -84,29 +128,12 @@ def gen(): yield from map(lambda t: str(t) + ",", gen_c_zetas()) yield "};" yield "" + yield "const int16_t zetas_twisted[128] = {" + yield from map(lambda t: str(t) + ",", gen_c_zetas(twisted=True)) + yield "};" + yield "" update_file("mlkem/zetas.c", '\n'.join(gen()), dry_run=dry_run) -def prepare_root_for_barrett(root): - """Takes a constant that the code needs to Barrett-multiply with, - and returns the pair of (a) its signed canonical form, (b) the - twisted constant used in the high-mul part of the Barrett multiplication.""" - - # Signed canonical reduction - root = signed_reduce(root) - - 
def round_to_even(t): - rt = round(t) - if rt % 2 == 0: - return rt - # Make sure to pick a rounding target - # that's <= 1 away from x in absolute value. - if rt <= t: - return rt + 1 - return rt - 1 - - root_twisted = round_to_even((root * 2**16) / modulus) // 2 - return root, root_twisted - def gen_aarch64_root_of_unity_for_block(layer, block, inv=False): # We are computing a negacyclic NTT; the twiddles needed here is # the second half of the twiddles for a cyclic NTT of twice the size. @@ -260,17 +287,6 @@ def signed_reduce_u16(x): x -= 2**16 return x -def prepare_root_for_montmul(root): - """Takes a constant that the code needs to Montgomery-multiply with, - and returns the pair of (a) the signed canonical representative of its - Montgomery form, (b) the twisted constant used in the low-mul part of - the Montgomery multiplication.""" - - # Convert to Montgomery form and pick canonical signed representative - root = signed_reduce(root * montgomery_factor) - root_twisted = signed_reduce_u16(root * pow(modulus, -1, 2**16)) - return root, root_twisted - def gen_avx2_root_of_unity_for_block(layer, block, inv=False): # We are computing a negacyclic NTT; the twiddles needed here is # the second half of the twiddles for a cyclic NTT of twice the size.