Skip to content

Commit

Permalink
C NTT: Use Barrett instead of Montgomery multiplication
Browse files Browse the repository at this point in the history
This commit changes the C implementation of NTT and inverse NTT to
use Barrett multiplication (https://eprint.iacr.org/2021/986) for
multiplications with twiddles.

Using Barrett multiplication requires precomputing twisted twiddles,
which is done by adjusting the twiddle generation script
autogenerate_files.py.

Signed-off-by: Hanno Becker <[email protected]>
  • Loading branch information
hanno-becker committed Dec 1, 2024
1 parent b2c6403 commit 3f747d2
Show file tree
Hide file tree
Showing 7 changed files with 127 additions and 55 deletions.
16 changes: 9 additions & 7 deletions mlkem/ntt.c
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@
// 5 -- 7
//
STATIC_TESTABLE
void ntt_butterfly_block(int16_t r[MLKEM_N], int16_t zeta, int start, int len,
int bound)
void ntt_butterfly_block(int16_t r[MLKEM_N], int16_t zeta, int16_t zeta_twisted,
int start, int len, int bound)
__contract__(
requires(0 <= start && start < MLKEM_N)
requires(1 <= len && len <= MLKEM_N / 2 && start + 2 * len <= MLKEM_N)
Expand All @@ -69,7 +69,7 @@ __contract__(
invariant(array_abs_bound(r, j + len, MLKEM_N - 1, bound)))
{
int16_t t;
t = fqmul(r[j + len], zeta);
t = fqmul_bar(r[j + len], zeta, zeta_twisted);
r[j + len] = r[j] - t;
r[j] = r[j] + t;
}
Expand Down Expand Up @@ -107,8 +107,9 @@ __contract__(
invariant(array_abs_bound(r, 0, start - 1, (layer * MLKEM_Q - 1) + MLKEM_Q))
invariant(array_abs_bound(r, start, MLKEM_N - 1, layer * MLKEM_Q - 1)))
{
int16_t zeta = zetas[k++];
ntt_butterfly_block(r, zeta, start, len, layer * MLKEM_Q - 1);
int16_t zeta = zetas[k];
uint16_t zeta_twisted = zetas_twisted[k++];
ntt_butterfly_block(r, zeta, zeta_twisted, start, len, layer * MLKEM_Q - 1);
}
}

Expand Down Expand Up @@ -178,7 +179,8 @@ __contract__(
// Normalised form of k == MLKEM_N / len - 1 - start / (2 * len)
invariant(2 * len * k + start == 2 * MLKEM_N - 2 * len))
{
int16_t zeta = zetas[k--];
int16_t zeta = zetas[k];
uint16_t zeta_twisted = zetas_twisted[k--];
for (int j = start; j < start + len; j++)
__loop__(
invariant(start <= j && j <= start + len)
Expand All @@ -188,7 +190,7 @@ __contract__(
int16_t t = r[j];
r[j] = barrett_reduce(t + r[j + len]);
r[j + len] = r[j + len] - t;
r[j + len] = fqmul(r[j + len], zeta);
r[j + len] = fqmul_bar(r[j + len], zeta, zeta_twisted);
}
}
}
Expand Down
3 changes: 3 additions & 0 deletions mlkem/ntt.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
#define zetas MLKEM_NAMESPACE(zetas)
extern const int16_t zetas[128];

#define zetas_twisted MLKEM_NAMESPACE(zetas_twisted)
extern const int16_t zetas_twisted[128];

/*************************************************
* Name: poly_ntt
*
Expand Down
6 changes: 4 additions & 2 deletions mlkem/poly.c
Original file line number Diff line number Diff line change
Expand Up @@ -517,8 +517,10 @@ void poly_mulcache_compute(poly_mulcache *x, const poly *a)
for (i = 0; i < MLKEM_N / 4; i++)
__loop__(invariant(i >= 0 && i <= MLKEM_N / 4))
{
x->coeffs[2 * i + 0] = fqmul(a->coeffs[4 * i + 1], zetas[64 + i]);
x->coeffs[2 * i + 1] = fqmul(a->coeffs[4 * i + 3], -zetas[64 + i]);
x->coeffs[2 * i + 0] =
fqmul_bar(a->coeffs[4 * i + 1], zetas[64 + i], zetas_twisted[64 + i]);
x->coeffs[2 * i + 1] =
fqmul_bar(a->coeffs[4 * i + 3], -zetas[64 + i], -zetas_twisted[64 + i]);
}
POLY_BOUND(x, MLKEM_Q);
}
Expand Down
13 changes: 13 additions & 0 deletions mlkem/reduce.c
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,19 @@ int16_t fqmul(int16_t a, int16_t b)
return res;
}

int16_t fqmul_bar(int16_t a, int16_t b, int16_t b_twisted)
{
SCALAR_BOUND(b, HALF_Q, "fqmul input");

int16_t quot = ((int32_t)a * b_twisted) >> 16;
uint16_t prod_low = a * b;
uint16_t round_low = quot * MLKEM_Q;
uint16_t r = prod_low - round_low;

SCALAR_BOUND(r, MLKEM_Q, "fqmul output");
return (int16_t)r;
}

// To divide by MLKEM_Q using Barrett multiplication, the "magic number"
// multiplier is round_to_nearest(2**26/MLKEM_Q)
#define BPOWER 26
Expand Down
18 changes: 18 additions & 0 deletions mlkem/reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,5 +58,23 @@ __contract__(
ensures(return_value > -MLKEM_Q && return_value < MLKEM_Q)
);

/*************************************************
* Name: fqmul
*
* Description: Barrett multiplication modulo q=3329
* (https://eprint.iacr.org/2021/986)
*
* Arguments: - int16_t a: first factor
* Can be any int16_t.
* - int16_t b: second factor.
* Must be signed canonical (abs value <(q+1)/2)
* - int16_t b_twisted: Barrett twist of second factor
*
* Returns 16-bit integer congruent to a*b*R^{-1} mod q, and
* smaller than q in absolute value.
*
**************************************************/
#define fqmul_bar MLKEM_NAMESPACE(fqmul_bar)
int16_t fqmul_bar(int16_t a, int16_t b, int16_t b_twisted);

#endif
42 changes: 30 additions & 12 deletions mlkem/zetas.c
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,34 @@
// Table of zeta values used in the reference NTT and inverse NTT.
// See autogenerate_files.py for details.
const int16_t zetas[128] = {
-1044, -758, -359, -1517, 1493, 1422, 287, 202, -171, 622, 1577,
182, 962, -1202, -1474, 1468, 573, -1325, 264, 383, -829, 1458,
-1602, -130, -681, 1017, 732, 608, -1542, 411, -205, -1571, 1223,
652, -552, 1015, -1293, 1491, -282, -1544, 516, -8, -320, -666,
-1618, -1162, 126, 1469, -853, -90, -271, 830, 107, -1421, -247,
-951, -398, 961, -1508, -725, 448, -1065, 677, -1275, -1103, 430,
555, 843, -1251, 871, 1550, 105, 422, 587, 177, -235, -291,
-460, 1574, 1653, -246, 778, 1159, -147, -777, 1483, -602, 1119,
-1590, 644, -872, 349, 418, 329, -156, -75, 817, 1097, 603,
610, 1322, -1285, -1465, 384, -1215, -136, 1218, -1335, -874, 220,
-1187, -1659, -1185, -1530, -1278, 794, -1510, -854, -870, 478, -108,
-308, 996, 991, 958, -1460, 1522, 1628,
1, -1600, -749, -40, -687, 630, -1432, 848, 1062, -1410, 193,
797, -543, -69, 569, -1583, 296, -882, 1339, 1476, -283, 56,
-1089, 1333, 1426, -1235, 535, -447, -936, -450, -1355, 821, 289,
331, -76, -1573, 1197, -1025, -1052, -1274, 650, -1352, -816, 632,
-464, 33, 1320, -1414, -1010, 1435, 807, 452, 1438, -461, 1534,
-927, -682, -712, 1481, 648, -855, -219, 1227, 910, 17, -568,
583, -680, 1637, 723, -1041, 1100, 1409, -667, -48, 233, 756,
-1173, -314, -279, -1626, 1651, -540, -1540, -1482, 952, 1461, -642,
939, -1021, -892, -941, 733, -992, 268, 641, 1584, -1031, -1292,
-109, 375, -780, -1239, 1645, 1063, 319, -556, 757, -1230, 561,
-863, -735, -525, 1092, 403, 1026, 1143, -1179, -554, 886, -1607,
1212, -1455, 1029, -1219, -394, 885, -1175,
};

const int16_t zetas_twisted[128] = {
19, -31499, -14746, -788, -13525, 12402, -28191, 16694, 20906,
-27758, 3799, 15690, -10690, -1359, 11201, -31164, 5827, -17364,
26360, 29057, -5572, 1102, -21439, 26241, 28072, -24313, 10532,
-8800, -18427, -8859, -26676, 16162, 5689, 6516, -1497, -30967,
23564, -20179, -20711, -25081, 12796, -26617, -16065, 12441, -9135,
649, 25986, -27837, -19884, 28249, 15886, 8898, 28309, -9076,
30198, -18250, -13427, -14017, 29155, 12756, -16832, -4312, 24155,
17914, 334, -11182, 11477, -13387, 32226, 14233, -20494, 21655,
27738, -13131, -945, 4586, 14882, -23093, -6182, -5493, -32011,
32502, -10631, -30318, -29176, 18741, 28761, -12639, 18485, -20100,
-17561, -18525, 14430, -19529, 5275, 12618, 31183, -20297, -25435,
-2146, 7382, -15356, -24392, 32384, 20926, 6279, -10946, 14902,
-24215, 11044, -16990, -14470, -10336, 21497, 7933, 20198, 22501,
-23211, -10907, 17442, -31637, 23859, -28644, 20257, -23998, -7757,
17422, -23132,
};
84 changes: 50 additions & 34 deletions scripts/autogenerate_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import subprocess
import argparse
import math
import os

modulus = 3329
Expand Down Expand Up @@ -59,7 +60,46 @@ def signed_reduce(a):
c -= modulus
return c

def gen_c_zetas():
def prepare_root_for_montmul(root):
"""Takes a constant that the code needs to Montgomery-multiply with,
and returns the pair of (a) the signed canonical representative of its
Montgomery form, (b) the twisted constant used in the low-mul part of
the Montgomery multiplication."""

# Convert to Montgomery form and pick canonical signed representative
root = signed_reduce(root * montgomery_factor)
root_twisted = signed_reduce_u16(root * pow(modulus, -1, 2**16))
return root, root_twisted

def prepare_root_for_barrett(root, even=True):
"""Takes a constant that the code needs to Barrett-multiply with,
and returns the pair of (a) its signed canonical form, (b) the
twisted constant used in the high-mul part of the Barrett multiplication."""

# Signed canonical reduction
root = signed_reduce(root)

def round_to_suitable(t):
if even is True:
rt = round(t)
if rt % 2 == 0:
return rt
# Make sure to pick a rounding target
# that's <= 1 away from x in absolute value.
if rt <= t:
return rt + 1
return rt - 1
else:
return math.floor(t)

root_twisted = round_to_suitable((root * 2**16) / modulus)

if even is True:
root_twisted = root_twisted // 2

return root, root_twisted

def gen_c_zetas(twisted=False):
"""Generate source and header file for zeta values used in
the reference NTT and invNTT"""

Expand All @@ -68,7 +108,11 @@ def gen_c_zetas():

zeta = []
for i in range(128):
zeta.append(signed_reduce(pow(root_of_unity, i, modulus) * montgomery_factor))
root, root_twisted = prepare_root_for_barrett(pow(root_of_unity, i, modulus), even=False)
if twisted is False:
zeta.append(root)
else:
zeta.append(root_twisted)

# The source code stores the zeta table in bit reversed form
yield from (zeta[bitreverse(i,7)] for i in range(128))
Expand All @@ -84,29 +128,12 @@ def gen():
yield from map(lambda t: str(t) + ",", gen_c_zetas())
yield "};"
yield ""
yield "const int16_t zetas_twisted[128] = {"
yield from map(lambda t: str(t) + ",", gen_c_zetas(twisted=True))
yield "};"
yield ""
update_file("mlkem/zetas.c", '\n'.join(gen()), dry_run=dry_run)

def prepare_root_for_barrett(root):
"""Takes a constant that the code needs to Barrett-multiply with,
and returns the pair of (a) its signed canonical form, (b) the
twisted constant used in the high-mul part of the Barrett multiplication."""

# Signed canonical reduction
root = signed_reduce(root)

def round_to_even(t):
rt = round(t)
if rt % 2 == 0:
return rt
# Make sure to pick a rounding target
# that's <= 1 away from x in absolute value.
if rt <= t:
return rt + 1
return rt - 1

root_twisted = round_to_even((root * 2**16) / modulus) // 2
return root, root_twisted

def gen_aarch64_root_of_unity_for_block(layer, block, inv=False):
# We are computing a negacyclic NTT; the twiddles needed here is
# the second half of the twiddles for a cyclic NTT of twice the size.
Expand Down Expand Up @@ -260,17 +287,6 @@ def signed_reduce_u16(x):
x -= 2**16
return x

def prepare_root_for_montmul(root):
"""Takes a constant that the code needs to Montgomery-multiply with,
and returns the pair of (a) the signed canonical representative of its
Montgomery form, (b) the twisted constant used in the low-mul part of
the Montgomery multiplication."""

# Convert to Montgomery form and pick canonical signed representative
root = signed_reduce(root * montgomery_factor)
root_twisted = signed_reduce_u16(root * pow(modulus, -1, 2**16))
return root, root_twisted

def gen_avx2_root_of_unity_for_block(layer, block, inv=False):
# We are computing a negacyclic NTT; the twiddles needed here is
# the second half of the twiddles for a cyclic NTT of twice the size.
Expand Down

0 comments on commit 3f747d2

Please sign in to comment.