From 80f3f17c4d3778a5751957164907fe598dedec74 Mon Sep 17 00:00:00 2001
From: "Thing-han, Lim" <15379156+potsrevennil@users.noreply.github.com>
Date: Wed, 4 Dec 2024 11:45:12 +0800
Subject: [PATCH 1/2] ci: fix Python files detection in format and lint scripts
Initially, the --include scripts/tests option was used with the Python
formatter black, because scripts/tests was the only Python script in the
repository. However, --include overrides black's default file detection
pattern, so only scripts/tests was being formatted. As more .py files
were added, this setup became insufficient.

This commit updates the regex to match both scripts/tests and all .py
files, ensuring that every Python script in the repository is detected
and formatted.
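To illustrate the behaviour (a sketch; the exact default pattern depends
on the black version in use):

    # black's default include pattern picks up every *.py file
    black --check .
    # --include REPLACES that default pattern instead of extending it,
    # so this only ever checked scripts/tests
    black --check --include scripts/tests .
    # the updated pattern matches scripts/tests as well as all .py files
    black --check --include "(scripts/tests|\.py$)" .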
Signed-off-by: Thing-han, Lim <15379156+potsrevennil@users.noreply.github.com>
---
scripts/ci/lint | 2 +-
scripts/format | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/scripts/ci/lint b/scripts/ci/lint
index 3c33aeb13..bcb550d4e 100755
--- a/scripts/ci/lint
+++ b/scripts/ci/lint
@@ -48,7 +48,7 @@ checkerr "Lint shell" "$(shfmt -s -l -i 2 -ci -fn $(shfmt -f $(git grep -l '' :/
echo "::endgroup::"
echo "::group::Linting python scripts with black"
-if ! diff=$(black --check --diff -q --include scripts/tests "$ROOT"); then
+if ! diff=$(black --check --diff -q --include "(scripts/tests|\.py$)" "$ROOT"); then
echo "::error title=Format error::$diff"
SUCCESS=false
echo ":x: Lint python" >>"$GITHUB_STEP_SUMMARY"
diff --git a/scripts/format b/scripts/format
index 267811204..486ad7e30 100755
--- a/scripts/format
+++ b/scripts/format
@@ -26,7 +26,7 @@ info "Formatting shell scripts"
shfmt -s -w -l -i 2 -ci -fn $(shfmt -f $(git grep -l '' :/))
info "Formatting python scripts"
-black --include scripts/tests "$ROOT"
+black --include "(scripts/tests|\.py$)" "$ROOT"
info "Formatting c files"
clang-format -i $(git ls-files ":/*.c" ":/*.h")
From 372e74a9f08708848f3416e5153e358abac9e891 Mon Sep 17 00:00:00 2001
From: "Thing-han, Lim" <15379156+potsrevennil@users.noreply.github.com>
Date: Wed, 4 Dec 2024 12:26:09 +0800
Subject: [PATCH 2/2] style: format all python scripts
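This is a mechanical change produced by the updated format script from
the previous commit (a sketch of the equivalent invocation, assuming
black is installed and run from the repository root):

    scripts/format
    # or, for the Python part only:
    black --include "(scripts/tests|\.py$)" .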
Signed-off-by: Thing-han, Lim <15379156+potsrevennil@users.noreply.github.com>
---
cbmc/proofs/lib/print_tool_versions.py | 8 +-
cbmc/proofs/lib/summarize.py | 18 +-
cbmc/proofs/run-cbmc-proofs.py | 228 +++++++++++++++--------
scripts/autogenerate_files.py | 240 +++++++++++++++++--------
scripts/lib/mlkem_test.py | 12 +-
test/test_bounds.py | 74 +++++---
6 files changed, 392 insertions(+), 188 deletions(-)
diff --git a/cbmc/proofs/lib/print_tool_versions.py b/cbmc/proofs/lib/print_tool_versions.py
index bdeb429e3..c97757b41 100755
--- a/cbmc/proofs/lib/print_tool_versions.py
+++ b/cbmc/proofs/lib/print_tool_versions.py
@@ -29,11 +29,12 @@ def _format_versions(table):
         if version:
             v_str = f"<code>{version}</code>"
         else:
-            v_str = 'not found'
+            v_str = "not found"
         lines.append(
             f'<tr><td style="font-weight: bold; padding-right: 1em; '
             f'text-align: right;">{tool}:</td>'
-            f'<td>{v_str}</td></tr>')
+            f"<td>{v_str}</td></tr>"
+        )
     lines.append("</table>")
     return "\n".join(lines)
@@ -55,7 +56,8 @@ def _get_tool_versions():
continue
if proc.returncode:
logging.error(
- "%s'%s --version' returned %s", err, tool, str(proc.returncode))
+ "%s'%s --version' returned %s", err, tool, str(proc.returncode)
+ )
continue
ret[tool] = out.strip()
return ret
diff --git a/cbmc/proofs/lib/summarize.py b/cbmc/proofs/lib/summarize.py
index 5dbd230c8..f90508e17 100644
--- a/cbmc/proofs/lib/summarize.py
+++ b/cbmc/proofs/lib/summarize.py
@@ -15,11 +15,13 @@
def get_args():
"""Parse arguments for summarize script."""
parser = argparse.ArgumentParser(description=DESCRIPTION)
- for arg in [{
+ for arg in [
+ {
"flags": ["--run-file"],
"help": "path to the Litani run.json file",
"required": True,
- }]:
+ }
+ ]:
flags = arg.pop("flags")
parser.add_argument(*flags, **arg)
return parser.parse_args()
@@ -111,7 +113,7 @@ def print_proof_results(out_file):
When printing, each string will render as a GitHub flavored Markdown table.
"""
output = "## Summary of CBMC proof results\n\n"
- with open(out_file, encoding='utf-8') as run_json:
+ with open(out_file, encoding="utf-8") as run_json:
run_dict = json.load(run_json)
status_table, proof_table = _get_status_and_proof_summaries(run_dict)
for summary in (status_table, proof_table):
@@ -126,12 +128,12 @@ def print_proof_results(out_file):
print(output, file=handle)
handle.flush()
else:
- logging.warning(
- "$GITHUB_STEP_SUMMARY not set, not writing summary file")
+ logging.warning("$GITHUB_STEP_SUMMARY not set, not writing summary file")
msg = (
"Click the 'Summary' button to view a Markdown table "
- "summarizing all proof results")
+ "summarizing all proof results"
+ )
if run_dict["status"] != "success":
logging.error("Not all proofs passed.")
logging.error(msg)
@@ -139,10 +141,10 @@ def print_proof_results(out_file):
logging.info(msg)
-if __name__ == '__main__':
+if __name__ == "__main__":
args = get_args()
logging.basicConfig(format="%(levelname)s: %(message)s")
try:
print_proof_results(args.run_file)
- except Exception as ex: # pylint: disable=broad-except
+ except Exception as ex: # pylint: disable=broad-except
logging.critical("Could not print results. Exception: %s", str(ex))
diff --git a/cbmc/proofs/run-cbmc-proofs.py b/cbmc/proofs/run-cbmc-proofs.py
index 557e65e8e..f2db280f3 100755
--- a/cbmc/proofs/run-cbmc-proofs.py
+++ b/cbmc/proofs/run-cbmc-proofs.py
@@ -71,11 +71,14 @@ def get_project_name():
cmd = [
"make",
"--no-print-directory",
- "-f", "Makefile.common",
+ "-f",
+ "Makefile.common",
"echo-project-name",
]
logging.debug(" ".join(cmd))
- proc = subprocess.run(cmd, universal_newlines=True, stdout=subprocess.PIPE, check=False)
+ proc = subprocess.run(
+ cmd, universal_newlines=True, stdout=subprocess.PIPE, check=False
+ )
if proc.returncode:
logging.critical("could not run make to determine project name")
sys.exit(1)
@@ -83,84 +86,103 @@ def get_project_name():
logging.warning(
"project name has not been set; using generic name instead. "
"Set the PROJECT_NAME value in Makefile-project-defines to "
- "remove this warning")
+ "remove this warning"
+ )
return ""
return proc.stdout.strip()
def get_args():
pars = argparse.ArgumentParser(
- description=DESCRIPTION, epilog=EPILOG,
- formatter_class=argparse.RawDescriptionHelpFormatter)
- for arg in [{
+ description=DESCRIPTION,
+ epilog=EPILOG,
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ )
+ for arg in [
+ {
"flags": ["-j", "--parallel-jobs"],
"type": int,
"metavar": "N",
"help": "run at most N proof jobs in parallel",
- }, {
+ },
+ {
"flags": ["--fail-on-proof-failure"],
"action": "store_true",
"help": "exit with return code `10' if any proof failed"
- " (default: exit 0)",
- }, {
+ " (default: exit 0)",
+ },
+ {
"flags": ["--no-standalone"],
"action": "store_true",
"help": "only configure proofs: do not initialize nor run",
- }, {
+ },
+ {
"flags": ["-p", "--proofs"],
"nargs": "+",
"metavar": "DIR",
"help": "only run proof in directory DIR (can pass more than one)",
- }, {
+ },
+ {
"flags": ["--project-name"],
"metavar": "NAME",
"default": get_project_name(),
"help": "project name for report. Default: %(default)s",
- }, {
+ },
+ {
"flags": ["--marker-file"],
"metavar": "FILE",
"default": "cbmc-proof.txt",
"help": (
- "name of file that marks proof directories. Default: "
- "%(default)s"),
- }, {
+ "name of file that marks proof directories. Default: " "%(default)s"
+ ),
+ },
+ {
"flags": ["--no-memory-profile"],
"action": "store_true",
- "help": "disable memory profiling, even if Litani supports it"
- }, {
+ "help": "disable memory profiling, even if Litani supports it",
+ },
+ {
"flags": ["--no-expensive-limit"],
"action": "store_true",
"help": "do not limit parallelism of 'EXPENSIVE' jobs",
- }, {
+ },
+ {
"flags": ["--expensive-jobs-parallelism"],
"metavar": "N",
"default": 1,
"type": int,
"help": (
"how many proof jobs marked 'EXPENSIVE' to run in parallel. "
- "Default: %(default)s"),
- }, {
+ "Default: %(default)s"
+ ),
+ },
+ {
"flags": ["--verbose"],
"action": "store_true",
"help": "verbose output",
- }, {
+ },
+ {
"flags": ["--debug"],
"action": "store_true",
"help": "debug output",
- }, {
+ },
+ {
"flags": ["--summarize"],
"action": "store_true",
"help": "summarize proof results with two tables on stdout",
- }, {
+ },
+ {
"flags": ["--version"],
"action": "version",
"version": "CBMC starter kit 2.10",
- "help": "display version and exit"
- }, {
+ "help": "display version and exit",
+ },
+ {
"flags": ["--no-coverage"],
"action": "store_true",
- "help": "do property checking without coverage checking"
- }]:
+ "help": "do property checking without coverage checking",
+ },
+ ]:
flags = arg.pop("flags")
pars.add_argument(*flags, **arg)
return pars.parse_args()
@@ -171,8 +193,7 @@ def set_up_logging(verbose):
level = logging.DEBUG
else:
level = logging.WARNING
- logging.basicConfig(
- format="run-cbmc-proofs: %(message)s", level=level)
+ logging.basicConfig(format="run-cbmc-proofs: %(message)s", level=level)
def task_pool_size():
@@ -184,8 +205,12 @@ def task_pool_size():
def print_counter(counter):
# pylint: disable=consider-using-f-string
- print("\rConfiguring CBMC proofs: "
- "{complete:{width}} / {total:{width}}".format(**counter), end="", file=sys.stderr)
+ print(
+ "\rConfiguring CBMC proofs: "
+ "{complete:{width}} / {total:{width}}".format(**counter),
+ end="",
+ file=sys.stderr,
+ )
def get_proof_dirs(proof_root, proof_list, marker_file):
@@ -207,8 +232,8 @@ def get_proof_dirs(proof_root, proof_list, marker_file):
if proofs_remaining:
logging.critical(
- "The following proofs were not found: %s",
- ", ".join(proofs_remaining))
+ "The following proofs were not found: %s", ", ".join(proofs_remaining)
+ )
sys.exit(1)
@@ -237,16 +262,20 @@ def run_build(litani, jobs, fail_on_proof_failure, summarize):
logging.error("One or more proofs failed")
sys.exit(10)
+
def get_litani_path(proof_root):
cmd = [
"make",
"--no-print-directory",
f"PROOF_ROOT={proof_root}",
- "-f", "Makefile.common",
+ "-f",
+ "Makefile.common",
"litani-path",
]
logging.debug(" ".join(cmd))
- proc = subprocess.run(cmd, universal_newlines=True, stdout=subprocess.PIPE, check=False)
+ proc = subprocess.run(
+ cmd, universal_newlines=True, stdout=subprocess.PIPE, check=False
+ )
if proc.returncode:
logging.critical("Could not determine path to litani")
sys.exit(1)
@@ -256,7 +285,8 @@ def get_litani_path(proof_root):
def get_litani_capabilities(litani_path):
cmd = [litani_path, "print-capabilities"]
proc = subprocess.run(
- cmd, text=True, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, check=False)
+ cmd, text=True, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, check=False
+ )
if proc.returncode:
return []
try:
@@ -279,11 +309,14 @@ def check_uid_uniqueness(proof_dir, proof_uids):
logging.critical(
"The Makefile in directory '%s' should have a different "
"PROOF_UID than the Makefile in directory '%s'",
- proof_dir, proof_uids[match["uid"]])
+ proof_dir,
+ proof_uids[match["uid"]],
+ )
sys.exit(1)
logging.critical(
- "The Makefile in directory '%s' should contain a line like", proof_dir)
+ "The Makefile in directory '%s' should contain a line like", proof_dir
+ )
logging.critical("PROOF_UID = ...")
logging.critical("with a unique identifier for the proof.")
sys.exit(1)
@@ -301,8 +334,15 @@ def should_enable_pools(litani_caps, args):
return "pools" in litani_caps
-async def configure_proof_dirs( # pylint: disable=too-many-arguments
- queue, counter, proof_uids, enable_pools, enable_memory_profiling, report_target, debug):
+async def configure_proof_dirs( # pylint: disable=too-many-arguments
+ queue,
+ counter,
+ proof_uids,
+ enable_pools,
+ enable_memory_profiling,
+ report_target,
+ debug,
+):
while True:
print_counter(counter)
path = str(await queue.get())
@@ -310,19 +350,32 @@ async def configure_proof_dirs( # pylint: disable=too-many-arguments
check_uid_uniqueness(path, proof_uids)
pools = ["ENABLE_POOLS=true"] if enable_pools else []
- profiling = [
- "ENABLE_MEMORY_PROFILING=true"] if enable_memory_profiling else []
+ profiling = ["ENABLE_MEMORY_PROFILING=true"] if enable_memory_profiling else []
# delete old reports
proc = await asyncio.create_subprocess_exec(
- "make", "veryclean", cwd=path,
- stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE)
+ "make",
+ "veryclean",
+ cwd=path,
+ stdout=asyncio.subprocess.PIPE,
+ stderr=asyncio.subprocess.PIPE,
+ )
# Allow interactive tasks to preempt proof configuration
proc = await asyncio.create_subprocess_exec(
- "nice", "-n", "15", "make", *pools,
- *profiling, "-B", report_target, "" if debug else "--quiet", cwd=path,
- stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE)
+ "nice",
+ "-n",
+ "15",
+ "make",
+ *pools,
+ *profiling,
+ "-B",
+ report_target,
+ "" if debug else "--quiet",
+ cwd=path,
+ stdout=asyncio.subprocess.PIPE,
+ stderr=asyncio.subprocess.PIPE,
+ )
stdout, stderr = await proc.communicate()
logging.debug("returncode: %s", str(proc.returncode))
logging.debug("stdout:")
@@ -341,13 +394,20 @@ async def configure_proof_dirs( # pylint: disable=too-many-arguments
def add_tool_version_job():
cmd = [
- "litani", "add-job",
- "--command", "./lib/print_tool_versions.py .",
- "--description", "printing out tool versions",
- "--phony-outputs", str(uuid.uuid4()),
- "--pipeline-name", "print_tool_versions",
- "--ci-stage", "report",
- "--tags", "front-page-text",
+ "litani",
+ "add-job",
+ "--command",
+ "./lib/print_tool_versions.py .",
+ "--description",
+ "printing out tool versions",
+ "--phony-outputs",
+ str(uuid.uuid4()),
+ "--pipeline-name",
+ "print_tool_versions",
+ "--ci-stage",
+ "report",
+ "--tags",
+ "front-page-text",
]
proc = subprocess.run(cmd)
if proc.returncode:
@@ -355,7 +415,7 @@ def add_tool_version_job():
sys.exit(1)
-async def main(): # pylint: disable=too-many-locals
+async def main(): # pylint: disable=too-many-locals
args = get_args()
set_up_logging(args.verbose)
@@ -364,13 +424,19 @@ async def main(): # pylint: disable=too-many-locals
litani_caps = get_litani_capabilities(litani)
enable_pools = should_enable_pools(litani_caps, args)
- init_pools = [
- "--pools", f"expensive:{args.expensive_jobs_parallelism}"
- ] if enable_pools else []
+ init_pools = (
+ ["--pools", f"expensive:{args.expensive_jobs_parallelism}"]
+ if enable_pools
+ else []
+ )
if not args.no_standalone:
cmd = [
- str(litani), "init", *init_pools, "--project", args.project_name,
+ str(litani),
+ "init",
+ *init_pools,
+ "--project",
+ args.project_name,
"--no-print-out-dir",
]
@@ -378,13 +444,19 @@ async def main(): # pylint: disable=too-many-locals
out_prefix = proof_root / "output"
out_symlink = out_prefix / "latest"
out_index = out_symlink / "html" / "index.html"
- cmd.extend([
- "--output-prefix", str(out_prefix),
- "--output-symlink", str(out_symlink),
- ])
+ cmd.extend(
+ [
+ "--output-prefix",
+ str(out_prefix),
+ "--output-symlink",
+ str(out_symlink),
+ ]
+ )
print(
"\nFor your convenience, the output of this run will be symbolically linked to ",
- out_index, "\n")
+ out_index,
+ "\n",
+ )
logging.debug(" ".join(cmd))
proc = subprocess.run(cmd, check=False)
@@ -392,8 +464,7 @@ async def main(): # pylint: disable=too-many-locals
logging.critical("Failed to run litani init")
sys.exit(1)
- proof_dirs = list(get_proof_dirs(
- proof_root, args.proofs, args.marker_file))
+ proof_dirs = list(get_proof_dirs(proof_root, args.proofs, args.marker_file))
if not proof_dirs:
logging.critical("No proof directories found")
sys.exit(1)
@@ -407,7 +478,7 @@ async def main(): # pylint: disable=too-many-locals
"fail": [],
"complete": 0,
"total": len(proof_dirs),
- "width": int(math.log10(len(proof_dirs))) + 1
+ "width": int(math.log10(len(proof_dirs))) + 1,
}
proof_uids = {}
@@ -417,9 +488,17 @@ async def main(): # pylint: disable=too-many-locals
report_target = "_report_no_coverage" if args.no_coverage else "_report"
for _ in range(task_pool_size()):
- task = asyncio.create_task(configure_proof_dirs(
- proof_queue, counter, proof_uids, enable_pools,
- enable_memory_profiling, report_target, args.debug))
+ task = asyncio.create_task(
+ configure_proof_dirs(
+ proof_queue,
+ counter,
+ proof_uids,
+ enable_pools,
+ enable_memory_profiling,
+ report_target,
+ args.debug,
+ )
+ )
tasks.append(task)
await proof_queue.join()
@@ -431,12 +510,15 @@ async def main(): # pylint: disable=too-many-locals
if counter["fail"]:
logging.critical(
- "Failed to configure the following proofs:\n%s", "\n".join(
- [str(f) for f in counter["fail"]]))
+ "Failed to configure the following proofs:\n%s",
+ "\n".join([str(f) for f in counter["fail"]]),
+ )
sys.exit(1)
if not args.no_standalone:
- run_build(litani, args.parallel_jobs, args.fail_on_proof_failure, args.summarize)
+ run_build(
+ litani, args.parallel_jobs, args.fail_on_proof_failure, args.summarize
+ )
if __name__ == "__main__":
diff --git a/scripts/autogenerate_files.py b/scripts/autogenerate_files.py
index 6d3ed0cd7..0e23f69db 100644
--- a/scripts/autogenerate_files.py
+++ b/scripts/autogenerate_files.py
@@ -15,6 +15,7 @@
# It currently covers:
# - zeta values for the reference NTT and invNTT
+
def gen_header():
yield "/*"
yield " * Copyright (c) 2024 The mlkem-native project authors"
@@ -27,12 +28,15 @@ def gen_header():
yield " */"
yield ""
+
def update_file(filename, content, dry_run=False):
# Format content
p = subprocess.run(["clang-format"], capture_output=True, input=content, text=True)
if p.returncode != 0:
- print(f"Failed to auto-format autogenerated code (clang-format return code {p.returncode}")
+ print(
+ f"Failed to auto-format autogenerated code (clang-format return code {p.returncode}"
+ )
exit(1)
content = p.stdout
@@ -46,16 +50,20 @@ def update_file(filename, content, dry_run=False):
with open(filename, "r") as f:
current_content = f.read()
if current_content != content:
- print(f"Autogenerated file {filename} needs updating. Have you called scripts/autogenerated.py?")
+ print(
+ f"Autogenerated file {filename} needs updating. Have you called scripts/autogenerated.py?"
+ )
exit(1)
-def bitreverse(i,n):
+
+def bitreverse(i, n):
r = 0
for _ in range(n):
- r = 2*r + (i & 1)
+ r = 2 * r + (i & 1)
i >>= 1
return r
+
def signed_reduce(a):
"""Return signed canonical representative of a mod b"""
c = a % modulus
@@ -63,6 +71,7 @@ def signed_reduce(a):
c -= modulus
return c
+
def gen_c_zetas():
"""Generate source and header file for zeta values used in
the reference NTT and invNTT"""
@@ -75,12 +84,13 @@ def gen_c_zetas():
zeta.append(signed_reduce(pow(root_of_unity, i, modulus) * montgomery_factor))
# The source code stores the zeta table in bit reversed form
- yield from (zeta[bitreverse(i,7)] for i in range(128))
+ yield from (zeta[bitreverse(i, 7)] for i in range(128))
+
def gen_c_zeta_file(dry_run=False):
def gen():
yield from gen_header()
- yield "#include \"ntt.h\""
+ yield '#include "ntt.h"'
yield ""
yield "/*"
yield " * Table of zeta values used in the reference NTT and inverse NTT."
@@ -90,7 +100,9 @@ def gen():
yield from map(lambda t: str(t) + ",", gen_c_zetas())
yield "};"
yield ""
- update_file("mlkem/zetas.c", '\n'.join(gen()), dry_run=dry_run)
+
+ update_file("mlkem/zetas.c", "\n".join(gen()), dry_run=dry_run)
+
def prepare_root_for_barrett(root):
"""Takes a constant that the code needs to Barrett-multiply with,
@@ -113,121 +125,181 @@ def round_to_even(t):
root_twisted = round_to_even((root * 2**16) / modulus) // 2
return root, root_twisted
+
def gen_aarch64_root_of_unity_for_block(layer, block, inv=False):
# We are computing a negacyclic NTT; the twiddles needed here is
# the second half of the twiddles for a cyclic NTT of twice the size.
- log = bitreverse(pow(2,layer) + block, 7)
+ log = bitreverse(pow(2, layer) + block, 7)
if inv is True:
log = -log
root, root_twisted = prepare_root_for_barrett(pow(root_of_unity, log, modulus))
return root, root_twisted
+
def gen_aarch64_fwd_ntt_zetas_layer01234():
# Layers 0,1,2 are merged
- yield from gen_aarch64_root_of_unity_for_block(0,0)
- yield from gen_aarch64_root_of_unity_for_block(1,0)
- yield from gen_aarch64_root_of_unity_for_block(1,1)
- yield from gen_aarch64_root_of_unity_for_block(2,0)
- yield from gen_aarch64_root_of_unity_for_block(2,1)
- yield from gen_aarch64_root_of_unity_for_block(2,2)
- yield from gen_aarch64_root_of_unity_for_block(2,3)
- yield from (0,0) # Padding
+ yield from gen_aarch64_root_of_unity_for_block(0, 0)
+ yield from gen_aarch64_root_of_unity_for_block(1, 0)
+ yield from gen_aarch64_root_of_unity_for_block(1, 1)
+ yield from gen_aarch64_root_of_unity_for_block(2, 0)
+ yield from gen_aarch64_root_of_unity_for_block(2, 1)
+ yield from gen_aarch64_root_of_unity_for_block(2, 2)
+ yield from gen_aarch64_root_of_unity_for_block(2, 3)
+ yield from (0, 0) # Padding
# Layers 3,4,5,6 are merged, but we emit roots for 3,4
# in separate arrays than those for 5,6
- for block in range(8): # There are 8 blocks in Layer 4
- yield from gen_aarch64_root_of_unity_for_block(3,block)
- yield from gen_aarch64_root_of_unity_for_block(4,2*block+0)
- yield from gen_aarch64_root_of_unity_for_block(4,2*block+1)
- yield from (0,0) # Padding
+ for block in range(8): # There are 8 blocks in Layer 4
+ yield from gen_aarch64_root_of_unity_for_block(3, block)
+ yield from gen_aarch64_root_of_unity_for_block(4, 2 * block + 0)
+ yield from gen_aarch64_root_of_unity_for_block(4, 2 * block + 1)
+ yield from (0, 0) # Padding
+
def gen_aarch64_fwd_ntt_zetas_layer56():
# Layers 3,4,5,6 are merged, but we emit roots for 3,4
# in separate arrays than those for 5,6
for block in range(8):
+
def double_ith(t, i):
yield from (t[i], t[i])
+
# Ordering of blocks is adjusted to suit the transposed internal
# presentation of the data
for i in range(2):
- yield from double_ith(gen_aarch64_root_of_unity_for_block(5,4*block+0), i)
- yield from double_ith(gen_aarch64_root_of_unity_for_block(5,4*block+1), i)
- yield from double_ith(gen_aarch64_root_of_unity_for_block(5,4*block+2), i)
- yield from double_ith(gen_aarch64_root_of_unity_for_block(5,4*block+3), i)
+ yield from double_ith(
+ gen_aarch64_root_of_unity_for_block(5, 4 * block + 0), i
+ )
+ yield from double_ith(
+ gen_aarch64_root_of_unity_for_block(5, 4 * block + 1), i
+ )
+ yield from double_ith(
+ gen_aarch64_root_of_unity_for_block(5, 4 * block + 2), i
+ )
+ yield from double_ith(
+ gen_aarch64_root_of_unity_for_block(5, 4 * block + 3), i
+ )
for i in range(2):
- yield from double_ith(gen_aarch64_root_of_unity_for_block(6,8*block+0), i)
- yield from double_ith(gen_aarch64_root_of_unity_for_block(6,8*block+2), i)
- yield from double_ith(gen_aarch64_root_of_unity_for_block(6,8*block+4), i)
- yield from double_ith(gen_aarch64_root_of_unity_for_block(6,8*block+6), i)
+ yield from double_ith(
+ gen_aarch64_root_of_unity_for_block(6, 8 * block + 0), i
+ )
+ yield from double_ith(
+ gen_aarch64_root_of_unity_for_block(6, 8 * block + 2), i
+ )
+ yield from double_ith(
+ gen_aarch64_root_of_unity_for_block(6, 8 * block + 4), i
+ )
+ yield from double_ith(
+ gen_aarch64_root_of_unity_for_block(6, 8 * block + 6), i
+ )
for i in range(2):
- yield from double_ith(gen_aarch64_root_of_unity_for_block(6,8*block+1), i)
- yield from double_ith(gen_aarch64_root_of_unity_for_block(6,8*block+3), i)
- yield from double_ith(gen_aarch64_root_of_unity_for_block(6,8*block+5), i)
- yield from double_ith(gen_aarch64_root_of_unity_for_block(6,8*block+7), i)
+ yield from double_ith(
+ gen_aarch64_root_of_unity_for_block(6, 8 * block + 1), i
+ )
+ yield from double_ith(
+ gen_aarch64_root_of_unity_for_block(6, 8 * block + 3), i
+ )
+ yield from double_ith(
+ gen_aarch64_root_of_unity_for_block(6, 8 * block + 5), i
+ )
+ yield from double_ith(
+ gen_aarch64_root_of_unity_for_block(6, 8 * block + 7), i
+ )
+
def gen_aarch64_inv_ntt_zetas_layer01234():
# Layers 3,4,5,6 are merged, but we emit roots for 3,4
# in separate arrays than those for 5,6
- for block in range(8): # There are 8 blocks in Layer 4
- yield from gen_aarch64_root_of_unity_for_block(3,block,inv=True)
- yield from gen_aarch64_root_of_unity_for_block(4,2*block+0,inv=True)
- yield from gen_aarch64_root_of_unity_for_block(4,2*block+1,inv=True)
- yield from (0,0) # Padding
+ for block in range(8): # There are 8 blocks in Layer 4
+ yield from gen_aarch64_root_of_unity_for_block(3, block, inv=True)
+ yield from gen_aarch64_root_of_unity_for_block(4, 2 * block + 0, inv=True)
+ yield from gen_aarch64_root_of_unity_for_block(4, 2 * block + 1, inv=True)
+ yield from (0, 0) # Padding
# Layers 0,1,2 are merged
- yield from gen_aarch64_root_of_unity_for_block(0,0,inv=True)
- yield from gen_aarch64_root_of_unity_for_block(1,0,inv=True)
- yield from gen_aarch64_root_of_unity_for_block(1,1,inv=True)
- yield from gen_aarch64_root_of_unity_for_block(2,0,inv=True)
- yield from gen_aarch64_root_of_unity_for_block(2,1,inv=True)
- yield from gen_aarch64_root_of_unity_for_block(2,2,inv=True)
- yield from gen_aarch64_root_of_unity_for_block(2,3,inv=True)
- yield from (0,0) # Padding
+ yield from gen_aarch64_root_of_unity_for_block(0, 0, inv=True)
+ yield from gen_aarch64_root_of_unity_for_block(1, 0, inv=True)
+ yield from gen_aarch64_root_of_unity_for_block(1, 1, inv=True)
+ yield from gen_aarch64_root_of_unity_for_block(2, 0, inv=True)
+ yield from gen_aarch64_root_of_unity_for_block(2, 1, inv=True)
+ yield from gen_aarch64_root_of_unity_for_block(2, 2, inv=True)
+ yield from gen_aarch64_root_of_unity_for_block(2, 3, inv=True)
+ yield from (0, 0) # Padding
+
def gen_aarch64_inv_ntt_zetas_layer56():
# Layers 3,4,5,6 are merged, but we emit roots for 3,4
# in separate arrays than those for 5,6
for block in range(8):
+
def double_ith(t, i):
yield from (t[i], t[i])
+
# Ordering of blocks is adjusted to suit the transposed internal
# presentation of the data
for i in range(2):
- yield from double_ith(gen_aarch64_root_of_unity_for_block(5,4*block+0, inv=True), i)
- yield from double_ith(gen_aarch64_root_of_unity_for_block(5,4*block+1, inv=True), i)
- yield from double_ith(gen_aarch64_root_of_unity_for_block(5,4*block+2, inv=True), i)
- yield from double_ith(gen_aarch64_root_of_unity_for_block(5,4*block+3, inv=True), i)
+ yield from double_ith(
+ gen_aarch64_root_of_unity_for_block(5, 4 * block + 0, inv=True), i
+ )
+ yield from double_ith(
+ gen_aarch64_root_of_unity_for_block(5, 4 * block + 1, inv=True), i
+ )
+ yield from double_ith(
+ gen_aarch64_root_of_unity_for_block(5, 4 * block + 2, inv=True), i
+ )
+ yield from double_ith(
+ gen_aarch64_root_of_unity_for_block(5, 4 * block + 3, inv=True), i
+ )
for i in range(2):
- yield from double_ith(gen_aarch64_root_of_unity_for_block(6,8*block+0, inv=True), i)
- yield from double_ith(gen_aarch64_root_of_unity_for_block(6,8*block+2, inv=True), i)
- yield from double_ith(gen_aarch64_root_of_unity_for_block(6,8*block+4, inv=True), i)
- yield from double_ith(gen_aarch64_root_of_unity_for_block(6,8*block+6, inv=True), i)
+ yield from double_ith(
+ gen_aarch64_root_of_unity_for_block(6, 8 * block + 0, inv=True), i
+ )
+ yield from double_ith(
+ gen_aarch64_root_of_unity_for_block(6, 8 * block + 2, inv=True), i
+ )
+ yield from double_ith(
+ gen_aarch64_root_of_unity_for_block(6, 8 * block + 4, inv=True), i
+ )
+ yield from double_ith(
+ gen_aarch64_root_of_unity_for_block(6, 8 * block + 6, inv=True), i
+ )
for i in range(2):
- yield from double_ith(gen_aarch64_root_of_unity_for_block(6,8*block+1, inv=True), i)
- yield from double_ith(gen_aarch64_root_of_unity_for_block(6,8*block+3, inv=True), i)
- yield from double_ith(gen_aarch64_root_of_unity_for_block(6,8*block+5, inv=True), i)
- yield from double_ith(gen_aarch64_root_of_unity_for_block(6,8*block+7, inv=True), i)
+ yield from double_ith(
+ gen_aarch64_root_of_unity_for_block(6, 8 * block + 1, inv=True), i
+ )
+ yield from double_ith(
+ gen_aarch64_root_of_unity_for_block(6, 8 * block + 3, inv=True), i
+ )
+ yield from double_ith(
+ gen_aarch64_root_of_unity_for_block(6, 8 * block + 5, inv=True), i
+ )
+ yield from double_ith(
+ gen_aarch64_root_of_unity_for_block(6, 8 * block + 7, inv=True), i
+ )
+
def gen_aarch64_mulcache_twiddles():
for idx in range(64):
- root = pow(root_of_unity, bitreverse(64+idx,7), modulus)
+ root = pow(root_of_unity, bitreverse(64 + idx, 7), modulus)
yield prepare_root_for_barrett(root)[0]
yield prepare_root_for_barrett(-root)[0]
+
def gen_aarch64_mulcache_twiddles_twisted():
for idx in range(64):
- root = pow(root_of_unity, bitreverse(64+idx,7), modulus)
+ root = pow(root_of_unity, bitreverse(64 + idx, 7), modulus)
yield prepare_root_for_barrett(root)[1]
yield prepare_root_for_barrett(-root)[1]
+
def gen_aarch64_fwd_ntt_zeta_file(dry_run=False):
def gen():
yield from gen_header()
- yield "#include \"arith_native_aarch64.h\""
+ yield '#include "arith_native_aarch64.h"'
yield ""
yield "#ifdef MLKEM_USE_NATIVE_AARCH64"
yield ""
- yield "/*"
+ yield "/*"
yield " * Table of zeta values used in the AArch64 forward NTT"
yield " * See autogenerate_files.py for details."
yield " */"
@@ -260,7 +332,11 @@ def gen():
yield "int empty_cu_aarch64_zetas;"
yield "#endif /* MLKEM_USE_NATIVE_AARCH64 */"
yield ""
- update_file("mlkem/native/aarch64/aarch64_zetas.c", '\n'.join(gen()), dry_run=dry_run)
+
+ update_file(
+ "mlkem/native/aarch64/aarch64_zetas.c", "\n".join(gen()), dry_run=dry_run
+ )
+
def signed_reduce_u16(x):
x = x % 2**16
@@ -268,6 +344,7 @@ def signed_reduce_u16(x):
x -= 2**16
return x
+
def prepare_root_for_montmul(root):
"""Takes a constant that the code needs to Montgomery-multiply with,
and returns the pair of (a) the signed canonical representative of its
@@ -279,44 +356,49 @@ def prepare_root_for_montmul(root):
root_twisted = signed_reduce_u16(root * pow(modulus, -1, 2**16))
return root, root_twisted
+
def gen_avx2_root_of_unity_for_block(layer, block, inv=False):
# We are computing a negacyclic NTT; the twiddles needed here is
# the second half of the twiddles for a cyclic NTT of twice the size.
- log = bitreverse(pow(2,layer) + block, 7)
+ log = bitreverse(pow(2, layer) + block, 7)
if inv is True:
log = -log
root, root_twisted = prepare_root_for_montmul(pow(root_of_unity, log, modulus))
return root, root_twisted
+
def gen_avx2_fwd_ntt_zetas():
def gen_twiddles(layer, block, repeat):
"""Generates twisted twiddle, then twiddle, for given layer and block.
Repeat both the given number of times."""
root, root_twisted = gen_avx2_root_of_unity_for_block(layer, block)
- return [root]*repeat, [root_twisted]*repeat
+ return [root] * repeat, [root_twisted] * repeat
def gen_twiddles_many(layer, block_base, block_offsets, repeat):
"""Generates twisted twiddles, then twiddles, of each (layer, block_base + i)
pair for i in block_offsets. Each twiddle is repeated `repeat` times."""
- root_pairs = list(map(lambda x: gen_twiddles(layer, block_base + x, repeat), block_offsets))
+ root_pairs = list(
+ map(lambda x: gen_twiddles(layer, block_base + x, repeat), block_offsets)
+ )
yield from (r for l in root_pairs for r in l[1])
yield from (r for l in root_pairs for r in l[0])
# Layers 0 twiddle
yield from gen_twiddles_many(0, 0, range(1), 4)
# Padding so that the subsequent twiddles are 16-byte aligned
- yield from [0]*8
+ yield from [0] * 8
# Layer 1-6 twiddles, separated by whether they belong to the upper or lower half
for i in range(2):
- yield from gen_twiddles_many(1, i*(2**0), range(1), 16)
- yield from gen_twiddles_many(2, i*(2**1), range(2), 8)
- yield from gen_twiddles_many(3, i*(2**2), range(4), 4)
- yield from gen_twiddles_many(4, i*(2**3), range(8), 2)
- yield from gen_twiddles_many(5, i*(2**4), range(16), 1)
- yield from gen_twiddles_many(6, i*(2**5), range(0,32,2), 1)
- yield from gen_twiddles_many(6, i*(2**5), range(1,32,2), 1)
+ yield from gen_twiddles_many(1, i * (2**0), range(1), 16)
+ yield from gen_twiddles_many(2, i * (2**1), range(2), 8)
+ yield from gen_twiddles_many(3, i * (2**2), range(4), 4)
+ yield from gen_twiddles_many(4, i * (2**3), range(8), 2)
+ yield from gen_twiddles_many(5, i * (2**4), range(16), 1)
+ yield from gen_twiddles_many(6, i * (2**5), range(0, 32, 2), 1)
+ yield from gen_twiddles_many(6, i * (2**5), range(1, 32, 2), 1)
+
def gen_avx2_fwd_ntt_zeta_file(dry_run=False):
def gen():
@@ -328,17 +410,21 @@ def gen():
yield ""
yield from map(lambda t: str(t) + ",", gen_avx2_fwd_ntt_zetas())
yield ""
- update_file("mlkem/native/x86_64/x86_64_zetas.i", '\n'.join(gen()), dry_run=dry_run)
+
+ update_file("mlkem/native/x86_64/x86_64_zetas.i", "\n".join(gen()), dry_run=dry_run)
+
def _main():
parser = argparse.ArgumentParser(
- formatter_class=argparse.ArgumentDefaultsHelpFormatter)
- parser.add_argument("--dry-run", default=False, action='store_true')
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+ parser.add_argument("--dry-run", default=False, action="store_true")
args = parser.parse_args()
gen_c_zeta_file(args.dry_run)
gen_aarch64_fwd_ntt_zeta_file(args.dry_run)
gen_avx2_fwd_ntt_zeta_file(args.dry_run)
+
if __name__ == "__main__":
_main()
diff --git a/scripts/lib/mlkem_test.py b/scripts/lib/mlkem_test.py
index 9e82521e3..69c2032a3 100644
--- a/scripts/lib/mlkem_test.py
+++ b/scripts/lib/mlkem_test.py
@@ -709,19 +709,27 @@ def all(opt: bool):
def cbmc(self, k):
config_logger(self.verbose)
+
def run_cbmc(mlkem_k):
envvars = {"MLKEM_K": mlkem_k}
cpucount = os.cpu_count()
p = subprocess.Popen(
- ["python3", "run-cbmc-proofs.py", "--summarize", "--no-coverage", f"-j{cpucount}"],
+ [
+ "python3",
+ "run-cbmc-proofs.py",
+ "--summarize",
+ "--no-coverage",
+ f"-j{cpucount}",
+ ],
cwd="cbmc/proofs",
env=os.environ.copy() | envvars,
)
p.communicate()
assert p.returncode == 0
+
if k == "ALL":
run_cbmc("2")
run_cbmc("3")
run_cbmc("4")
- else:
+ else:
run_cbmc(k)
diff --git a/test/test_bounds.py b/test/test_bounds.py
index 7da75716d..afb37d6c7 100644
--- a/test/test_bounds.py
+++ b/test/test_bounds.py
@@ -11,16 +11,18 @@
from functools import lru_cache
# Global constants
-R = 2**16
-Q = 3329
+R = 2**16
+Q = 3329
Qinv = pow(-Q, -1, R)
#
# Barrett multiplication via doubling
#
+
def round_even(x):
- return 2*round(x/2)
+ return 2 * round(x / 2)
+
@lru_cache(maxsize=None)
def barrett_twiddle(b):
@@ -28,69 +30,81 @@ def barrett_twiddle(b):
via doubling-high-multiply."""
return round_even(b * R / Q) // 2
-def sqrdmulh_i16(a,b):
+
+def sqrdmulh_i16(a, b):
"""Doubling multiply high with rouding"""
# We cannot use round() here because of its behaviour
# on multiples of 0.5: round(-.5) = round(0.5) = 0
- return ((2 * a * b + 2**15) // 2**16)
+ return (2 * a * b + 2**15) // 2**16
+
def barmul(a, b):
"""Compute doubling Barrett multiplication of a and b"""
b_twiddle = barrett_twiddle(b)
return a * b - Q * sqrdmulh_i16(a, b_twiddle)
+
#
# Montgomery multiplication
#
+
def lift_signed_i16(x):
"""Returns signed canonical representative modulo R=2^16."""
x = x % R
- if x >= R//2:
+ if x >= R // 2:
x -= R
return x
+
@lru_cache(maxsize=None)
def montmul_neg_twiddle(b):
return (b * Qinv) % R
+
def montmul_neg(a, b):
b_twiddle = montmul_neg_twiddle(b)
return (a * b + Q * lift_signed_i16(a * b_twiddle)) // R
+
#
# Generic test functions
#
+
def test_all_i16(f):
- for a in range(-R//2, R//2):
+ for a in range(-R // 2, R // 2):
if a % 1000 == 0:
print(f"{a} ...")
- for b in range(-Q//2,Q//2):
- f(a,b)
- f(-a,b)
+ for b in range(-Q // 2, Q // 2):
+ f(a, b)
+ f(-a, b)
+
-def test_random(f,num_tests=10000000, bound=2*R):
+def test_random(f, num_tests=10000000, bound=2 * R):
print(f"Randomly checking Barrett<->Montgomery relation ({num_tests} tests)...")
for i in range(num_tests):
if i % 100000 == 0:
print(f"... run {i} tests ({((i * 1000) // num_tests)/10}%)")
a = random.randrange(-bound, bound)
b = random.randrange(-bound, bound)
- f(a,b)
+ f(a, b)
+
#
# Test relation between Barrett and Montgomery multiplication
# (Proposition 1 in https://eprint.iacr.org/2021/986.pdf)
#
+
@lru_cache(maxsize=None)
def modq_even(a):
- return a - Q * round_even(a/Q)
+ return a - Q * round_even(a / Q)
+
-def barmul_test(a,b):
+def barmul_test(a, b):
bp = modq_even(b * R)
- r0 = barmul(a,b)
+ r0 = barmul(a, b)
r1 = montmul_neg(a, bp)
if r0 != r1:
print(f"barmul test failure for {a,b}!")
@@ -98,12 +112,15 @@ def barmul_test(a,b):
print(f"Montgomery multiplication: {r1} (factor {bp})")
assert False
+
def bar_mont_test_all_i16():
test_all_i16(barmul_test)
+
def bar_test_random():
test_random(barmul_test)
+
#
# Test bound on Barrett multiplication
#
@@ -112,36 +129,41 @@ def bar_test_random():
# where 0.0508 appears as a close upper bound for Q/2**16.
#
-def bar_bound_test(a,b, max_scale=[]):
+
+def bar_bound_test(a, b, max_scale=[]):
if a == 0:
return
- C = abs(a)/Q
- ab = barmul(a,b)
+ C = abs(a) / Q
+ ab = barmul(a, b)
Cp = abs(ab) / Q
- scale_bound = 0.0508 # Upper bound to Q/2**16
- scale = (Cp - 1/2) / C
+ scale_bound = 0.0508 # Upper bound to Q/2**16
+ scale = (Cp - 1 / 2) / C
if len(max_scale) == 0 or scale > max_scale[-1]:
max_scale.append(scale)
print(f"New scale bound for {(a,b)}: {scale}")
- if Cp >= scale_bound * C + 1/2:
+ if Cp >= scale_bound * C + 1 / 2:
print(f"bar bound test failure for (a,b)={(a,b)}")
print(f"barmul(a,b): {ab}")
print(f"C (=a/q): {C}")
print(f"Cp (=barmul(a,b)/q): {Cp}")
assert False
+
def bar_bound_test_all_i16():
test_all_i16(bar_bound_test)
+
#
# NTT bounds progression
#
-def funciter(f,n,x):
+
+def funciter(f, n, x):
"""Compute f^n(x)"""
if n == 0:
return x
- return funciter(f, n-1, f(x))
+ return funciter(f, n - 1, f(x))
+
def ntt_layer_bound_growth(factor):
"""If the inputs to a CT-based layer of the NTT are bound by C*Q,
@@ -151,12 +173,14 @@ def ntt_layer_bound_growth(factor):
# Each coefficient is replaced by a +- t*b where a,b are input coefficients,
# t is a suitable twiddle, and * is Barrett multiplication. a is thus bound
# by C*q, while t*b is bound by 0.0508*C + 1/2 (see above).
- return (lambda C: C + factor * C + 1/2)
+ return lambda C: C + factor * C + 1 / 2
+
def ntt_layer_bound(factor, layers=7, initial=Q):
"""Returns a bound on the absolute value of coefficients after a fixed number
of layers, assuming an initial absolute bound `initial`."""
- return funciter(ntt_layer_bound_growth(factor),layers, initial/Q)
+ return funciter(ntt_layer_bound_growth(factor), layers, initial / Q)
+
barmul_ntt_layer_bound = ntt_layer_bound(0.0508)
montmul_ntt_layer_bound = ntt_layer_bound(0.0204)