From 80f3f17c4d3778a5751957164907fe598dedec74 Mon Sep 17 00:00:00 2001 From: "Thing-han, Lim" <15379156+potsrevennil@users.noreply.github.com> Date: Wed, 4 Dec 2024 11:45:12 +0800 Subject: [PATCH 1/2] ci: fix Python files detection in format and lint scripts Initially, the --include scripts/tests option was used with the Python formatter black because scripts/tests was the only Python script. However, this flag overrides the default file detection pattern, meaning only scripts/tests was being formatted. As more .py files have been added, this setup became insufficient. This commit updates the regex to include both scripts/tests and all .py files, ensuring proper detection and formatting of all Python scripts in the repository. Signed-off-by: Thing-han, Lim <15379156+potsrevennil@users.noreply.github.com> --- scripts/ci/lint | 2 +- scripts/format | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/ci/lint b/scripts/ci/lint index 3c33aeb13..bcb550d4e 100755 --- a/scripts/ci/lint +++ b/scripts/ci/lint @@ -48,7 +48,7 @@ checkerr "Lint shell" "$(shfmt -s -l -i 2 -ci -fn $(shfmt -f $(git grep -l '' :/ echo "::endgroup::" echo "::group::Linting python scripts with black" -if ! diff=$(black --check --diff -q --include scripts/tests "$ROOT"); then +if ! diff=$(black --check --diff -q --include "(scripts/tests|\.py$)" "$ROOT"); then echo "::error title=Format error::$diff" SUCCESS=false echo ":x: Lint python" >>"$GITHUB_STEP_SUMMARY" diff --git a/scripts/format b/scripts/format index 267811204..486ad7e30 100755 --- a/scripts/format +++ b/scripts/format @@ -26,7 +26,7 @@ info "Formatting shell scripts" shfmt -s -w -l -i 2 -ci -fn $(shfmt -f $(git grep -l '' :/)) info "Formatting python scripts" -black --include scripts/tests "$ROOT" +black --include "(scripts/tests|\.py$)" "$ROOT" info "Formatting c files" clang-format -i $(git ls-files ":/*.c" ":/*.h") From 372e74a9f08708848f3416e5153e358abac9e891 Mon Sep 17 00:00:00 2001 From: "Thing-han, Lim" <15379156+potsrevennil@users.noreply.github.com> Date: Wed, 4 Dec 2024 12:26:09 +0800 Subject: [PATCH 2/2] style: format all python scripts Signed-off-by: Thing-han, Lim <15379156+potsrevennil@users.noreply.github.com> --- cbmc/proofs/lib/print_tool_versions.py | 8 +- cbmc/proofs/lib/summarize.py | 18 +- cbmc/proofs/run-cbmc-proofs.py | 228 +++++++++++++++-------- scripts/autogenerate_files.py | 240 +++++++++++++++++-------- scripts/lib/mlkem_test.py | 12 +- test/test_bounds.py | 74 +++++--- 6 files changed, 392 insertions(+), 188 deletions(-) diff --git a/cbmc/proofs/lib/print_tool_versions.py b/cbmc/proofs/lib/print_tool_versions.py index bdeb429e3..c97757b41 100755 --- a/cbmc/proofs/lib/print_tool_versions.py +++ b/cbmc/proofs/lib/print_tool_versions.py @@ -29,11 +29,12 @@ def _format_versions(table): if version: v_str = f'
<code><pre>{version}</pre></code>
' else: - v_str = 'not found' + v_str = "not found" lines.append( f'{tool}:' - f'{v_str}') + f"{v_str}" + ) lines.append("") return "\n".join(lines) @@ -55,7 +56,8 @@ def _get_tool_versions(): continue if proc.returncode: logging.error( - "%s'%s --version' returned %s", err, tool, str(proc.returncode)) + "%s'%s --version' returned %s", err, tool, str(proc.returncode) + ) continue ret[tool] = out.strip() return ret diff --git a/cbmc/proofs/lib/summarize.py b/cbmc/proofs/lib/summarize.py index 5dbd230c8..f90508e17 100644 --- a/cbmc/proofs/lib/summarize.py +++ b/cbmc/proofs/lib/summarize.py @@ -15,11 +15,13 @@ def get_args(): """Parse arguments for summarize script.""" parser = argparse.ArgumentParser(description=DESCRIPTION) - for arg in [{ + for arg in [ + { "flags": ["--run-file"], "help": "path to the Litani run.json file", "required": True, - }]: + } + ]: flags = arg.pop("flags") parser.add_argument(*flags, **arg) return parser.parse_args() @@ -111,7 +113,7 @@ def print_proof_results(out_file): When printing, each string will render as a GitHub flavored Markdown table. """ output = "## Summary of CBMC proof results\n\n" - with open(out_file, encoding='utf-8') as run_json: + with open(out_file, encoding="utf-8") as run_json: run_dict = json.load(run_json) status_table, proof_table = _get_status_and_proof_summaries(run_dict) for summary in (status_table, proof_table): @@ -126,12 +128,12 @@ def print_proof_results(out_file): print(output, file=handle) handle.flush() else: - logging.warning( - "$GITHUB_STEP_SUMMARY not set, not writing summary file") + logging.warning("$GITHUB_STEP_SUMMARY not set, not writing summary file") msg = ( "Click the 'Summary' button to view a Markdown table " - "summarizing all proof results") + "summarizing all proof results" + ) if run_dict["status"] != "success": logging.error("Not all proofs passed.") logging.error(msg) @@ -139,10 +141,10 @@ def print_proof_results(out_file): logging.info(msg) -if __name__ == '__main__': +if __name__ == "__main__": args = get_args() logging.basicConfig(format="%(levelname)s: %(message)s") try: print_proof_results(args.run_file) - except Exception as ex: # pylint: disable=broad-except + except Exception as ex: # pylint: disable=broad-except logging.critical("Could not print results. Exception: %s", str(ex)) diff --git a/cbmc/proofs/run-cbmc-proofs.py b/cbmc/proofs/run-cbmc-proofs.py index 557e65e8e..f2db280f3 100755 --- a/cbmc/proofs/run-cbmc-proofs.py +++ b/cbmc/proofs/run-cbmc-proofs.py @@ -71,11 +71,14 @@ def get_project_name(): cmd = [ "make", "--no-print-directory", - "-f", "Makefile.common", + "-f", + "Makefile.common", "echo-project-name", ] logging.debug(" ".join(cmd)) - proc = subprocess.run(cmd, universal_newlines=True, stdout=subprocess.PIPE, check=False) + proc = subprocess.run( + cmd, universal_newlines=True, stdout=subprocess.PIPE, check=False + ) if proc.returncode: logging.critical("could not run make to determine project name") sys.exit(1) @@ -83,84 +86,103 @@ def get_project_name(): logging.warning( "project name has not been set; using generic name instead. 
" "Set the PROJECT_NAME value in Makefile-project-defines to " - "remove this warning") + "remove this warning" + ) return "" return proc.stdout.strip() def get_args(): pars = argparse.ArgumentParser( - description=DESCRIPTION, epilog=EPILOG, - formatter_class=argparse.RawDescriptionHelpFormatter) - for arg in [{ + description=DESCRIPTION, + epilog=EPILOG, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + for arg in [ + { "flags": ["-j", "--parallel-jobs"], "type": int, "metavar": "N", "help": "run at most N proof jobs in parallel", - }, { + }, + { "flags": ["--fail-on-proof-failure"], "action": "store_true", "help": "exit with return code `10' if any proof failed" - " (default: exit 0)", - }, { + " (default: exit 0)", + }, + { "flags": ["--no-standalone"], "action": "store_true", "help": "only configure proofs: do not initialize nor run", - }, { + }, + { "flags": ["-p", "--proofs"], "nargs": "+", "metavar": "DIR", "help": "only run proof in directory DIR (can pass more than one)", - }, { + }, + { "flags": ["--project-name"], "metavar": "NAME", "default": get_project_name(), "help": "project name for report. Default: %(default)s", - }, { + }, + { "flags": ["--marker-file"], "metavar": "FILE", "default": "cbmc-proof.txt", "help": ( - "name of file that marks proof directories. Default: " - "%(default)s"), - }, { + "name of file that marks proof directories. Default: " "%(default)s" + ), + }, + { "flags": ["--no-memory-profile"], "action": "store_true", - "help": "disable memory profiling, even if Litani supports it" - }, { + "help": "disable memory profiling, even if Litani supports it", + }, + { "flags": ["--no-expensive-limit"], "action": "store_true", "help": "do not limit parallelism of 'EXPENSIVE' jobs", - }, { + }, + { "flags": ["--expensive-jobs-parallelism"], "metavar": "N", "default": 1, "type": int, "help": ( "how many proof jobs marked 'EXPENSIVE' to run in parallel. 
" - "Default: %(default)s"), - }, { + "Default: %(default)s" + ), + }, + { "flags": ["--verbose"], "action": "store_true", "help": "verbose output", - }, { + }, + { "flags": ["--debug"], "action": "store_true", "help": "debug output", - }, { + }, + { "flags": ["--summarize"], "action": "store_true", "help": "summarize proof results with two tables on stdout", - }, { + }, + { "flags": ["--version"], "action": "version", "version": "CBMC starter kit 2.10", - "help": "display version and exit" - }, { + "help": "display version and exit", + }, + { "flags": ["--no-coverage"], "action": "store_true", - "help": "do property checking without coverage checking" - }]: + "help": "do property checking without coverage checking", + }, + ]: flags = arg.pop("flags") pars.add_argument(*flags, **arg) return pars.parse_args() @@ -171,8 +193,7 @@ def set_up_logging(verbose): level = logging.DEBUG else: level = logging.WARNING - logging.basicConfig( - format="run-cbmc-proofs: %(message)s", level=level) + logging.basicConfig(format="run-cbmc-proofs: %(message)s", level=level) def task_pool_size(): @@ -184,8 +205,12 @@ def task_pool_size(): def print_counter(counter): # pylint: disable=consider-using-f-string - print("\rConfiguring CBMC proofs: " - "{complete:{width}} / {total:{width}}".format(**counter), end="", file=sys.stderr) + print( + "\rConfiguring CBMC proofs: " + "{complete:{width}} / {total:{width}}".format(**counter), + end="", + file=sys.stderr, + ) def get_proof_dirs(proof_root, proof_list, marker_file): @@ -207,8 +232,8 @@ def get_proof_dirs(proof_root, proof_list, marker_file): if proofs_remaining: logging.critical( - "The following proofs were not found: %s", - ", ".join(proofs_remaining)) + "The following proofs were not found: %s", ", ".join(proofs_remaining) + ) sys.exit(1) @@ -237,16 +262,20 @@ def run_build(litani, jobs, fail_on_proof_failure, summarize): logging.error("One or more proofs failed") sys.exit(10) + def get_litani_path(proof_root): cmd = [ "make", "--no-print-directory", f"PROOF_ROOT={proof_root}", - "-f", "Makefile.common", + "-f", + "Makefile.common", "litani-path", ] logging.debug(" ".join(cmd)) - proc = subprocess.run(cmd, universal_newlines=True, stdout=subprocess.PIPE, check=False) + proc = subprocess.run( + cmd, universal_newlines=True, stdout=subprocess.PIPE, check=False + ) if proc.returncode: logging.critical("Could not determine path to litani") sys.exit(1) @@ -256,7 +285,8 @@ def get_litani_path(proof_root): def get_litani_capabilities(litani_path): cmd = [litani_path, "print-capabilities"] proc = subprocess.run( - cmd, text=True, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, check=False) + cmd, text=True, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, check=False + ) if proc.returncode: return [] try: @@ -279,11 +309,14 @@ def check_uid_uniqueness(proof_dir, proof_uids): logging.critical( "The Makefile in directory '%s' should have a different " "PROOF_UID than the Makefile in directory '%s'", - proof_dir, proof_uids[match["uid"]]) + proof_dir, + proof_uids[match["uid"]], + ) sys.exit(1) logging.critical( - "The Makefile in directory '%s' should contain a line like", proof_dir) + "The Makefile in directory '%s' should contain a line like", proof_dir + ) logging.critical("PROOF_UID = ...") logging.critical("with a unique identifier for the proof.") sys.exit(1) @@ -301,8 +334,15 @@ def should_enable_pools(litani_caps, args): return "pools" in litani_caps -async def configure_proof_dirs( # pylint: disable=too-many-arguments - queue, counter, proof_uids, 
enable_pools, enable_memory_profiling, report_target, debug): +async def configure_proof_dirs( # pylint: disable=too-many-arguments + queue, + counter, + proof_uids, + enable_pools, + enable_memory_profiling, + report_target, + debug, +): while True: print_counter(counter) path = str(await queue.get()) @@ -310,19 +350,32 @@ async def configure_proof_dirs( # pylint: disable=too-many-arguments check_uid_uniqueness(path, proof_uids) pools = ["ENABLE_POOLS=true"] if enable_pools else [] - profiling = [ - "ENABLE_MEMORY_PROFILING=true"] if enable_memory_profiling else [] + profiling = ["ENABLE_MEMORY_PROFILING=true"] if enable_memory_profiling else [] # delete old reports proc = await asyncio.create_subprocess_exec( - "make", "veryclean", cwd=path, - stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE) + "make", + "veryclean", + cwd=path, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) # Allow interactive tasks to preempt proof configuration proc = await asyncio.create_subprocess_exec( - "nice", "-n", "15", "make", *pools, - *profiling, "-B", report_target, "" if debug else "--quiet", cwd=path, - stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE) + "nice", + "-n", + "15", + "make", + *pools, + *profiling, + "-B", + report_target, + "" if debug else "--quiet", + cwd=path, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) stdout, stderr = await proc.communicate() logging.debug("returncode: %s", str(proc.returncode)) logging.debug("stdout:") @@ -341,13 +394,20 @@ async def configure_proof_dirs( # pylint: disable=too-many-arguments def add_tool_version_job(): cmd = [ - "litani", "add-job", - "--command", "./lib/print_tool_versions.py .", - "--description", "printing out tool versions", - "--phony-outputs", str(uuid.uuid4()), - "--pipeline-name", "print_tool_versions", - "--ci-stage", "report", - "--tags", "front-page-text", + "litani", + "add-job", + "--command", + "./lib/print_tool_versions.py .", + "--description", + "printing out tool versions", + "--phony-outputs", + str(uuid.uuid4()), + "--pipeline-name", + "print_tool_versions", + "--ci-stage", + "report", + "--tags", + "front-page-text", ] proc = subprocess.run(cmd) if proc.returncode: @@ -355,7 +415,7 @@ def add_tool_version_job(): sys.exit(1) -async def main(): # pylint: disable=too-many-locals +async def main(): # pylint: disable=too-many-locals args = get_args() set_up_logging(args.verbose) @@ -364,13 +424,19 @@ async def main(): # pylint: disable=too-many-locals litani_caps = get_litani_capabilities(litani) enable_pools = should_enable_pools(litani_caps, args) - init_pools = [ - "--pools", f"expensive:{args.expensive_jobs_parallelism}" - ] if enable_pools else [] + init_pools = ( + ["--pools", f"expensive:{args.expensive_jobs_parallelism}"] + if enable_pools + else [] + ) if not args.no_standalone: cmd = [ - str(litani), "init", *init_pools, "--project", args.project_name, + str(litani), + "init", + *init_pools, + "--project", + args.project_name, "--no-print-out-dir", ] @@ -378,13 +444,19 @@ async def main(): # pylint: disable=too-many-locals out_prefix = proof_root / "output" out_symlink = out_prefix / "latest" out_index = out_symlink / "html" / "index.html" - cmd.extend([ - "--output-prefix", str(out_prefix), - "--output-symlink", str(out_symlink), - ]) + cmd.extend( + [ + "--output-prefix", + str(out_prefix), + "--output-symlink", + str(out_symlink), + ] + ) print( "\nFor your convenience, the output of this run will be symbolically linked to ", - out_index, "\n") 
+ out_index, + "\n", + ) logging.debug(" ".join(cmd)) proc = subprocess.run(cmd, check=False) @@ -392,8 +464,7 @@ async def main(): # pylint: disable=too-many-locals logging.critical("Failed to run litani init") sys.exit(1) - proof_dirs = list(get_proof_dirs( - proof_root, args.proofs, args.marker_file)) + proof_dirs = list(get_proof_dirs(proof_root, args.proofs, args.marker_file)) if not proof_dirs: logging.critical("No proof directories found") sys.exit(1) @@ -407,7 +478,7 @@ async def main(): # pylint: disable=too-many-locals "fail": [], "complete": 0, "total": len(proof_dirs), - "width": int(math.log10(len(proof_dirs))) + 1 + "width": int(math.log10(len(proof_dirs))) + 1, } proof_uids = {} @@ -417,9 +488,17 @@ async def main(): # pylint: disable=too-many-locals report_target = "_report_no_coverage" if args.no_coverage else "_report" for _ in range(task_pool_size()): - task = asyncio.create_task(configure_proof_dirs( - proof_queue, counter, proof_uids, enable_pools, - enable_memory_profiling, report_target, args.debug)) + task = asyncio.create_task( + configure_proof_dirs( + proof_queue, + counter, + proof_uids, + enable_pools, + enable_memory_profiling, + report_target, + args.debug, + ) + ) tasks.append(task) await proof_queue.join() @@ -431,12 +510,15 @@ async def main(): # pylint: disable=too-many-locals if counter["fail"]: logging.critical( - "Failed to configure the following proofs:\n%s", "\n".join( - [str(f) for f in counter["fail"]])) + "Failed to configure the following proofs:\n%s", + "\n".join([str(f) for f in counter["fail"]]), + ) sys.exit(1) if not args.no_standalone: - run_build(litani, args.parallel_jobs, args.fail_on_proof_failure, args.summarize) + run_build( + litani, args.parallel_jobs, args.fail_on_proof_failure, args.summarize + ) if __name__ == "__main__": diff --git a/scripts/autogenerate_files.py b/scripts/autogenerate_files.py index 6d3ed0cd7..0e23f69db 100644 --- a/scripts/autogenerate_files.py +++ b/scripts/autogenerate_files.py @@ -15,6 +15,7 @@ # It currently covers: # - zeta values for the reference NTT and invNTT + def gen_header(): yield "/*" yield " * Copyright (c) 2024 The mlkem-native project authors" @@ -27,12 +28,15 @@ def gen_header(): yield " */" yield "" + def update_file(filename, content, dry_run=False): # Format content p = subprocess.run(["clang-format"], capture_output=True, input=content, text=True) if p.returncode != 0: - print(f"Failed to auto-format autogenerated code (clang-format return code {p.returncode}") + print( + f"Failed to auto-format autogenerated code (clang-format return code {p.returncode}" + ) exit(1) content = p.stdout @@ -46,16 +50,20 @@ def update_file(filename, content, dry_run=False): with open(filename, "r") as f: current_content = f.read() if current_content != content: - print(f"Autogenerated file {filename} needs updating. Have you called scripts/autogenerated.py?") + print( + f"Autogenerated file {filename} needs updating. Have you called scripts/autogenerated.py?" 
+ ) exit(1) -def bitreverse(i,n): + +def bitreverse(i, n): r = 0 for _ in range(n): - r = 2*r + (i & 1) + r = 2 * r + (i & 1) i >>= 1 return r + def signed_reduce(a): """Return signed canonical representative of a mod b""" c = a % modulus @@ -63,6 +71,7 @@ def signed_reduce(a): c -= modulus return c + def gen_c_zetas(): """Generate source and header file for zeta values used in the reference NTT and invNTT""" @@ -75,12 +84,13 @@ def gen_c_zetas(): zeta.append(signed_reduce(pow(root_of_unity, i, modulus) * montgomery_factor)) # The source code stores the zeta table in bit reversed form - yield from (zeta[bitreverse(i,7)] for i in range(128)) + yield from (zeta[bitreverse(i, 7)] for i in range(128)) + def gen_c_zeta_file(dry_run=False): def gen(): yield from gen_header() - yield "#include \"ntt.h\"" + yield '#include "ntt.h"' yield "" yield "/*" yield " * Table of zeta values used in the reference NTT and inverse NTT." @@ -90,7 +100,9 @@ def gen(): yield from map(lambda t: str(t) + ",", gen_c_zetas()) yield "};" yield "" - update_file("mlkem/zetas.c", '\n'.join(gen()), dry_run=dry_run) + + update_file("mlkem/zetas.c", "\n".join(gen()), dry_run=dry_run) + def prepare_root_for_barrett(root): """Takes a constant that the code needs to Barrett-multiply with, @@ -113,121 +125,181 @@ def round_to_even(t): root_twisted = round_to_even((root * 2**16) / modulus) // 2 return root, root_twisted + def gen_aarch64_root_of_unity_for_block(layer, block, inv=False): # We are computing a negacyclic NTT; the twiddles needed here is # the second half of the twiddles for a cyclic NTT of twice the size. - log = bitreverse(pow(2,layer) + block, 7) + log = bitreverse(pow(2, layer) + block, 7) if inv is True: log = -log root, root_twisted = prepare_root_for_barrett(pow(root_of_unity, log, modulus)) return root, root_twisted + def gen_aarch64_fwd_ntt_zetas_layer01234(): # Layers 0,1,2 are merged - yield from gen_aarch64_root_of_unity_for_block(0,0) - yield from gen_aarch64_root_of_unity_for_block(1,0) - yield from gen_aarch64_root_of_unity_for_block(1,1) - yield from gen_aarch64_root_of_unity_for_block(2,0) - yield from gen_aarch64_root_of_unity_for_block(2,1) - yield from gen_aarch64_root_of_unity_for_block(2,2) - yield from gen_aarch64_root_of_unity_for_block(2,3) - yield from (0,0) # Padding + yield from gen_aarch64_root_of_unity_for_block(0, 0) + yield from gen_aarch64_root_of_unity_for_block(1, 0) + yield from gen_aarch64_root_of_unity_for_block(1, 1) + yield from gen_aarch64_root_of_unity_for_block(2, 0) + yield from gen_aarch64_root_of_unity_for_block(2, 1) + yield from gen_aarch64_root_of_unity_for_block(2, 2) + yield from gen_aarch64_root_of_unity_for_block(2, 3) + yield from (0, 0) # Padding # Layers 3,4,5,6 are merged, but we emit roots for 3,4 # in separate arrays than those for 5,6 - for block in range(8): # There are 8 blocks in Layer 4 - yield from gen_aarch64_root_of_unity_for_block(3,block) - yield from gen_aarch64_root_of_unity_for_block(4,2*block+0) - yield from gen_aarch64_root_of_unity_for_block(4,2*block+1) - yield from (0,0) # Padding + for block in range(8): # There are 8 blocks in Layer 4 + yield from gen_aarch64_root_of_unity_for_block(3, block) + yield from gen_aarch64_root_of_unity_for_block(4, 2 * block + 0) + yield from gen_aarch64_root_of_unity_for_block(4, 2 * block + 1) + yield from (0, 0) # Padding + def gen_aarch64_fwd_ntt_zetas_layer56(): # Layers 3,4,5,6 are merged, but we emit roots for 3,4 # in separate arrays than those for 5,6 for block in range(8): + def double_ith(t, i): 
yield from (t[i], t[i]) + # Ordering of blocks is adjusted to suit the transposed internal # presentation of the data for i in range(2): - yield from double_ith(gen_aarch64_root_of_unity_for_block(5,4*block+0), i) - yield from double_ith(gen_aarch64_root_of_unity_for_block(5,4*block+1), i) - yield from double_ith(gen_aarch64_root_of_unity_for_block(5,4*block+2), i) - yield from double_ith(gen_aarch64_root_of_unity_for_block(5,4*block+3), i) + yield from double_ith( + gen_aarch64_root_of_unity_for_block(5, 4 * block + 0), i + ) + yield from double_ith( + gen_aarch64_root_of_unity_for_block(5, 4 * block + 1), i + ) + yield from double_ith( + gen_aarch64_root_of_unity_for_block(5, 4 * block + 2), i + ) + yield from double_ith( + gen_aarch64_root_of_unity_for_block(5, 4 * block + 3), i + ) for i in range(2): - yield from double_ith(gen_aarch64_root_of_unity_for_block(6,8*block+0), i) - yield from double_ith(gen_aarch64_root_of_unity_for_block(6,8*block+2), i) - yield from double_ith(gen_aarch64_root_of_unity_for_block(6,8*block+4), i) - yield from double_ith(gen_aarch64_root_of_unity_for_block(6,8*block+6), i) + yield from double_ith( + gen_aarch64_root_of_unity_for_block(6, 8 * block + 0), i + ) + yield from double_ith( + gen_aarch64_root_of_unity_for_block(6, 8 * block + 2), i + ) + yield from double_ith( + gen_aarch64_root_of_unity_for_block(6, 8 * block + 4), i + ) + yield from double_ith( + gen_aarch64_root_of_unity_for_block(6, 8 * block + 6), i + ) for i in range(2): - yield from double_ith(gen_aarch64_root_of_unity_for_block(6,8*block+1), i) - yield from double_ith(gen_aarch64_root_of_unity_for_block(6,8*block+3), i) - yield from double_ith(gen_aarch64_root_of_unity_for_block(6,8*block+5), i) - yield from double_ith(gen_aarch64_root_of_unity_for_block(6,8*block+7), i) + yield from double_ith( + gen_aarch64_root_of_unity_for_block(6, 8 * block + 1), i + ) + yield from double_ith( + gen_aarch64_root_of_unity_for_block(6, 8 * block + 3), i + ) + yield from double_ith( + gen_aarch64_root_of_unity_for_block(6, 8 * block + 5), i + ) + yield from double_ith( + gen_aarch64_root_of_unity_for_block(6, 8 * block + 7), i + ) + def gen_aarch64_inv_ntt_zetas_layer01234(): # Layers 3,4,5,6 are merged, but we emit roots for 3,4 # in separate arrays than those for 5,6 - for block in range(8): # There are 8 blocks in Layer 4 - yield from gen_aarch64_root_of_unity_for_block(3,block,inv=True) - yield from gen_aarch64_root_of_unity_for_block(4,2*block+0,inv=True) - yield from gen_aarch64_root_of_unity_for_block(4,2*block+1,inv=True) - yield from (0,0) # Padding + for block in range(8): # There are 8 blocks in Layer 4 + yield from gen_aarch64_root_of_unity_for_block(3, block, inv=True) + yield from gen_aarch64_root_of_unity_for_block(4, 2 * block + 0, inv=True) + yield from gen_aarch64_root_of_unity_for_block(4, 2 * block + 1, inv=True) + yield from (0, 0) # Padding # Layers 0,1,2 are merged - yield from gen_aarch64_root_of_unity_for_block(0,0,inv=True) - yield from gen_aarch64_root_of_unity_for_block(1,0,inv=True) - yield from gen_aarch64_root_of_unity_for_block(1,1,inv=True) - yield from gen_aarch64_root_of_unity_for_block(2,0,inv=True) - yield from gen_aarch64_root_of_unity_for_block(2,1,inv=True) - yield from gen_aarch64_root_of_unity_for_block(2,2,inv=True) - yield from gen_aarch64_root_of_unity_for_block(2,3,inv=True) - yield from (0,0) # Padding + yield from gen_aarch64_root_of_unity_for_block(0, 0, inv=True) + yield from gen_aarch64_root_of_unity_for_block(1, 0, inv=True) + yield from 
gen_aarch64_root_of_unity_for_block(1, 1, inv=True) + yield from gen_aarch64_root_of_unity_for_block(2, 0, inv=True) + yield from gen_aarch64_root_of_unity_for_block(2, 1, inv=True) + yield from gen_aarch64_root_of_unity_for_block(2, 2, inv=True) + yield from gen_aarch64_root_of_unity_for_block(2, 3, inv=True) + yield from (0, 0) # Padding + def gen_aarch64_inv_ntt_zetas_layer56(): # Layers 3,4,5,6 are merged, but we emit roots for 3,4 # in separate arrays than those for 5,6 for block in range(8): + def double_ith(t, i): yield from (t[i], t[i]) + # Ordering of blocks is adjusted to suit the transposed internal # presentation of the data for i in range(2): - yield from double_ith(gen_aarch64_root_of_unity_for_block(5,4*block+0, inv=True), i) - yield from double_ith(gen_aarch64_root_of_unity_for_block(5,4*block+1, inv=True), i) - yield from double_ith(gen_aarch64_root_of_unity_for_block(5,4*block+2, inv=True), i) - yield from double_ith(gen_aarch64_root_of_unity_for_block(5,4*block+3, inv=True), i) + yield from double_ith( + gen_aarch64_root_of_unity_for_block(5, 4 * block + 0, inv=True), i + ) + yield from double_ith( + gen_aarch64_root_of_unity_for_block(5, 4 * block + 1, inv=True), i + ) + yield from double_ith( + gen_aarch64_root_of_unity_for_block(5, 4 * block + 2, inv=True), i + ) + yield from double_ith( + gen_aarch64_root_of_unity_for_block(5, 4 * block + 3, inv=True), i + ) for i in range(2): - yield from double_ith(gen_aarch64_root_of_unity_for_block(6,8*block+0, inv=True), i) - yield from double_ith(gen_aarch64_root_of_unity_for_block(6,8*block+2, inv=True), i) - yield from double_ith(gen_aarch64_root_of_unity_for_block(6,8*block+4, inv=True), i) - yield from double_ith(gen_aarch64_root_of_unity_for_block(6,8*block+6, inv=True), i) + yield from double_ith( + gen_aarch64_root_of_unity_for_block(6, 8 * block + 0, inv=True), i + ) + yield from double_ith( + gen_aarch64_root_of_unity_for_block(6, 8 * block + 2, inv=True), i + ) + yield from double_ith( + gen_aarch64_root_of_unity_for_block(6, 8 * block + 4, inv=True), i + ) + yield from double_ith( + gen_aarch64_root_of_unity_for_block(6, 8 * block + 6, inv=True), i + ) for i in range(2): - yield from double_ith(gen_aarch64_root_of_unity_for_block(6,8*block+1, inv=True), i) - yield from double_ith(gen_aarch64_root_of_unity_for_block(6,8*block+3, inv=True), i) - yield from double_ith(gen_aarch64_root_of_unity_for_block(6,8*block+5, inv=True), i) - yield from double_ith(gen_aarch64_root_of_unity_for_block(6,8*block+7, inv=True), i) + yield from double_ith( + gen_aarch64_root_of_unity_for_block(6, 8 * block + 1, inv=True), i + ) + yield from double_ith( + gen_aarch64_root_of_unity_for_block(6, 8 * block + 3, inv=True), i + ) + yield from double_ith( + gen_aarch64_root_of_unity_for_block(6, 8 * block + 5, inv=True), i + ) + yield from double_ith( + gen_aarch64_root_of_unity_for_block(6, 8 * block + 7, inv=True), i + ) + def gen_aarch64_mulcache_twiddles(): for idx in range(64): - root = pow(root_of_unity, bitreverse(64+idx,7), modulus) + root = pow(root_of_unity, bitreverse(64 + idx, 7), modulus) yield prepare_root_for_barrett(root)[0] yield prepare_root_for_barrett(-root)[0] + def gen_aarch64_mulcache_twiddles_twisted(): for idx in range(64): - root = pow(root_of_unity, bitreverse(64+idx,7), modulus) + root = pow(root_of_unity, bitreverse(64 + idx, 7), modulus) yield prepare_root_for_barrett(root)[1] yield prepare_root_for_barrett(-root)[1] + def gen_aarch64_fwd_ntt_zeta_file(dry_run=False): def gen(): yield from gen_header() - yield 
"#include \"arith_native_aarch64.h\"" + yield '#include "arith_native_aarch64.h"' yield "" yield "#ifdef MLKEM_USE_NATIVE_AARCH64" yield "" - yield "/*" + yield "/*" yield " * Table of zeta values used in the AArch64 forward NTT" yield " * See autogenerate_files.py for details." yield " */" @@ -260,7 +332,11 @@ def gen(): yield "int empty_cu_aarch64_zetas;" yield "#endif /* MLKEM_USE_NATIVE_AARCH64 */" yield "" - update_file("mlkem/native/aarch64/aarch64_zetas.c", '\n'.join(gen()), dry_run=dry_run) + + update_file( + "mlkem/native/aarch64/aarch64_zetas.c", "\n".join(gen()), dry_run=dry_run + ) + def signed_reduce_u16(x): x = x % 2**16 @@ -268,6 +344,7 @@ def signed_reduce_u16(x): x -= 2**16 return x + def prepare_root_for_montmul(root): """Takes a constant that the code needs to Montgomery-multiply with, and returns the pair of (a) the signed canonical representative of its @@ -279,44 +356,49 @@ def prepare_root_for_montmul(root): root_twisted = signed_reduce_u16(root * pow(modulus, -1, 2**16)) return root, root_twisted + def gen_avx2_root_of_unity_for_block(layer, block, inv=False): # We are computing a negacyclic NTT; the twiddles needed here is # the second half of the twiddles for a cyclic NTT of twice the size. - log = bitreverse(pow(2,layer) + block, 7) + log = bitreverse(pow(2, layer) + block, 7) if inv is True: log = -log root, root_twisted = prepare_root_for_montmul(pow(root_of_unity, log, modulus)) return root, root_twisted + def gen_avx2_fwd_ntt_zetas(): def gen_twiddles(layer, block, repeat): """Generates twisted twiddle, then twiddle, for given layer and block. Repeat both the given number of times.""" root, root_twisted = gen_avx2_root_of_unity_for_block(layer, block) - return [root]*repeat, [root_twisted]*repeat + return [root] * repeat, [root_twisted] * repeat def gen_twiddles_many(layer, block_base, block_offsets, repeat): """Generates twisted twiddles, then twiddles, of each (layer, block_base + i) pair for i in block_offsets. 
Each twiddle is repeated `repeat` times.""" - root_pairs = list(map(lambda x: gen_twiddles(layer, block_base + x, repeat), block_offsets)) + root_pairs = list( + map(lambda x: gen_twiddles(layer, block_base + x, repeat), block_offsets) + ) yield from (r for l in root_pairs for r in l[1]) yield from (r for l in root_pairs for r in l[0]) # Layers 0 twiddle yield from gen_twiddles_many(0, 0, range(1), 4) # Padding so that the subsequent twiddles are 16-byte aligned - yield from [0]*8 + yield from [0] * 8 # Layer 1-6 twiddles, separated by whether they belong to the upper or lower half for i in range(2): - yield from gen_twiddles_many(1, i*(2**0), range(1), 16) - yield from gen_twiddles_many(2, i*(2**1), range(2), 8) - yield from gen_twiddles_many(3, i*(2**2), range(4), 4) - yield from gen_twiddles_many(4, i*(2**3), range(8), 2) - yield from gen_twiddles_many(5, i*(2**4), range(16), 1) - yield from gen_twiddles_many(6, i*(2**5), range(0,32,2), 1) - yield from gen_twiddles_many(6, i*(2**5), range(1,32,2), 1) + yield from gen_twiddles_many(1, i * (2**0), range(1), 16) + yield from gen_twiddles_many(2, i * (2**1), range(2), 8) + yield from gen_twiddles_many(3, i * (2**2), range(4), 4) + yield from gen_twiddles_many(4, i * (2**3), range(8), 2) + yield from gen_twiddles_many(5, i * (2**4), range(16), 1) + yield from gen_twiddles_many(6, i * (2**5), range(0, 32, 2), 1) + yield from gen_twiddles_many(6, i * (2**5), range(1, 32, 2), 1) + def gen_avx2_fwd_ntt_zeta_file(dry_run=False): def gen(): @@ -328,17 +410,21 @@ def gen(): yield "" yield from map(lambda t: str(t) + ",", gen_avx2_fwd_ntt_zetas()) yield "" - update_file("mlkem/native/x86_64/x86_64_zetas.i", '\n'.join(gen()), dry_run=dry_run) + + update_file("mlkem/native/x86_64/x86_64_zetas.i", "\n".join(gen()), dry_run=dry_run) + def _main(): parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument("--dry-run", default=False, action='store_true') + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument("--dry-run", default=False, action="store_true") args = parser.parse_args() gen_c_zeta_file(args.dry_run) gen_aarch64_fwd_ntt_zeta_file(args.dry_run) gen_avx2_fwd_ntt_zeta_file(args.dry_run) + if __name__ == "__main__": _main() diff --git a/scripts/lib/mlkem_test.py b/scripts/lib/mlkem_test.py index 9e82521e3..69c2032a3 100644 --- a/scripts/lib/mlkem_test.py +++ b/scripts/lib/mlkem_test.py @@ -709,19 +709,27 @@ def all(opt: bool): def cbmc(self, k): config_logger(self.verbose) + def run_cbmc(mlkem_k): envvars = {"MLKEM_K": mlkem_k} cpucount = os.cpu_count() p = subprocess.Popen( - ["python3", "run-cbmc-proofs.py", "--summarize", "--no-coverage", f"-j{cpucount}"], + [ + "python3", + "run-cbmc-proofs.py", + "--summarize", + "--no-coverage", + f"-j{cpucount}", + ], cwd="cbmc/proofs", env=os.environ.copy() | envvars, ) p.communicate() assert p.returncode == 0 + if k == "ALL": run_cbmc("2") run_cbmc("3") run_cbmc("4") - else: + else: run_cbmc(k) diff --git a/test/test_bounds.py b/test/test_bounds.py index 7da75716d..afb37d6c7 100644 --- a/test/test_bounds.py +++ b/test/test_bounds.py @@ -11,16 +11,18 @@ from functools import lru_cache # Global constants -R = 2**16 -Q = 3329 +R = 2**16 +Q = 3329 Qinv = pow(-Q, -1, R) # # Barrett multiplication via doubling # + def round_even(x): - return 2*round(x/2) + return 2 * round(x / 2) + @lru_cache(maxsize=None) def barrett_twiddle(b): @@ -28,69 +30,81 @@ def barrett_twiddle(b): via doubling-high-multiply.""" return round_even(b * R / 
Q) // 2 -def sqrdmulh_i16(a,b): + +def sqrdmulh_i16(a, b): """Doubling multiply high with rouding""" # We cannot use round() here because of its behaviour # on multiples of 0.5: round(-.5) = round(0.5) = 0 - return ((2 * a * b + 2**15) // 2**16) + return (2 * a * b + 2**15) // 2**16 + def barmul(a, b): """Compute doubling Barrett multiplication of a and b""" b_twiddle = barrett_twiddle(b) return a * b - Q * sqrdmulh_i16(a, b_twiddle) + # # Montgomery multiplication # + def lift_signed_i16(x): """Returns signed canonical representative modulo R=2^16.""" x = x % R - if x >= R//2: + if x >= R // 2: x -= R return x + @lru_cache(maxsize=None) def montmul_neg_twiddle(b): return (b * Qinv) % R + def montmul_neg(a, b): b_twiddle = montmul_neg_twiddle(b) return (a * b + Q * lift_signed_i16(a * b_twiddle)) // R + # # Generic test functions # + def test_all_i16(f): - for a in range(-R//2, R//2): + for a in range(-R // 2, R // 2): if a % 1000 == 0: print(f"{a} ...") - for b in range(-Q//2,Q//2): - f(a,b) - f(-a,b) + for b in range(-Q // 2, Q // 2): + f(a, b) + f(-a, b) + -def test_random(f,num_tests=10000000, bound=2*R): +def test_random(f, num_tests=10000000, bound=2 * R): print(f"Randomly checking Barrett<->Montgomery relation ({num_tests} tests)...") for i in range(num_tests): if i % 100000 == 0: print(f"... run {i} tests ({((i * 1000) // num_tests)/10}%)") a = random.randrange(-bound, bound) b = random.randrange(-bound, bound) - f(a,b) + f(a, b) + # # Test relation between Barrett and Montgomery multiplication # (Proposition 1 in https://eprint.iacr.org/2021/986.pdf) # + @lru_cache(maxsize=None) def modq_even(a): - return a - Q * round_even(a/Q) + return a - Q * round_even(a / Q) + -def barmul_test(a,b): +def barmul_test(a, b): bp = modq_even(b * R) - r0 = barmul(a,b) + r0 = barmul(a, b) r1 = montmul_neg(a, bp) if r0 != r1: print(f"barmul test failure for {a,b}!") @@ -98,12 +112,15 @@ def barmul_test(a,b): print(f"Montgomery multiplication: {r1} (factor {bp})") assert False + def bar_mont_test_all_i16(): test_all_i16(barmul_test) + def bar_test_random(): test_random(barmul_test) + # # Test bound on Barrett multiplication # @@ -112,36 +129,41 @@ def bar_test_random(): # where 0.0508 appears as a close upper boun for Q/2**16. # -def bar_bound_test(a,b, max_scale=[]): + +def bar_bound_test(a, b, max_scale=[]): if a == 0: return - C = abs(a)/Q - ab = barmul(a,b) + C = abs(a) / Q + ab = barmul(a, b) Cp = abs(ab) / Q - scale_bound = 0.0508 # Upper bound to Q/2**16 - scale = (Cp - 1/2) / C + scale_bound = 0.0508 # Upper bound to Q/2**16 + scale = (Cp - 1 / 2) / C if len(max_scale) == 0 or scale > max_scale[-1]: max_scale.append(scale) print(f"New scale bound for {(a,b)}: {scale}") - if Cp >= scale_bound * C + 1/2: + if Cp >= scale_bound * C + 1 / 2: print(f"bar bound test failure for (a,b)={(a,b)}") print(f"barmul(a,b): {ab}") print(f"C (=a/q): {C}") print(f"Cp (=barmul(a,b)/q): {Cp}") assert False + def bar_bound_test_all_i16(): test_all_i16(bar_bound_test) + # # NTT bounds progression # -def funciter(f,n,x): + +def funciter(f, n, x): """Compute f^n(x)""" if n == 0: return x - return funciter(f, n-1, f(x)) + return funciter(f, n - 1, f(x)) + def ntt_layer_bound_growth(factor): """If the inputs to a CT-based layer of the NTT are bound by C*Q, @@ -151,12 +173,14 @@ def ntt_layer_bound_growth(factor): # Each coefficient is replaced by a +- t*b where a,b are input coefficients, # t is a suitable twiddle, and * is Barrett multiplication. 
# a is thus bound
    # by C*q, while t*b is bound by 0.0508*C + 1/2 (see above).
    return lambda C: C + factor * C + 1 / 2


def ntt_layer_bound(factor, layers=7, initial=Q):
    """Returns a bound on the absolute value of coefficients after a fixed number of layers, assuming an initial absolute bound `initial`."""
    return funciter(ntt_layer_bound_growth(factor), layers, initial / Q)


barmul_ntt_layer_bound = ntt_layer_bound(0.0508)
montmul_ntt_layer_bound = ntt_layer_bound(0.0204)
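---

A note on the new --include pattern from the first patch: black documents
--include as a regular expression that is matched against the files and
directories found during its recursive search, so "(scripts/tests|\.py$)"
selects both the extensionless scripts/tests entry point and every *.py file.
Below is a minimal, self-contained sketch of that matching behaviour; the
listed paths are illustrative examples, not an exhaustive list of files in
the repository:

    import re

    # The regex passed to black via --include in scripts/format and scripts/ci/lint.
    PATTERN = re.compile(r"(scripts/tests|\.py$)")

    # black applies the include pattern as a regex search on each candidate path,
    # so unanchored alternatives like "scripts/tests" match anywhere in the path.
    for path in ("scripts/tests", "scripts/lib/mlkem_test.py", "scripts/format"):
        print(path, "->", "included" if PATTERN.search(path) else "skipped")

Running the sketch shows scripts/tests and mlkem_test.py included while the
shell script scripts/format is skipped, which is the behaviour the commit
message describes.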