diff --git a/.gitignore b/.gitignore index dcafc72..d0ce31f 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,5 @@ scripts/cluster magebench-venv cluster.json __pycache__ +ckks_keys +logs diff --git a/experiment.py b/experiment.py new file mode 100644 index 0000000..2aeab31 --- /dev/null +++ b/experiment.py @@ -0,0 +1,93 @@ +import time +import remote + +def party_from_global_id(cluster, global_id): + if global_id < len(cluster.machines) // 2: + return 0 # evaluator + else: + return 1 # garbler + +def clear_memory_caches(cluster, worker_ids): + cluster.for_each_concurrently(lambda machine, id: remote.exec_sync(machine.public_ip_address, "sudo swapoff -a; sudo sync; echo 3 | sudo tee /proc/sys/vm/drop_caches"), worker_ids) + +def run_lan_experiment(cluster, problem_name, problem_size, protocol, scenario, worker_ids, log_name = "/dev/null", generate_fresh_input = True, generate_fresh_memprog = True): + if protocol == "halfgates": + assert len(worker_ids) % 2 == 0 + workers_per_party = len(worker_ids) // 2 + elif protocol == "ckks": + workers_per_party = len(worker_ids) + else: + raise RuntimeError("Unknown protocol {0}".format(protocol)) + + program_name = "{0}_{1}".format(problem_name, problem_size) + config_file = "~/config/{0}/config_{1}_{2}.yaml".format("1gb" if scenario == "mage" else "unbounded", protocol, workers_per_party) + + if isinstance(log_name, int): + log_name = program_name + "_t{0}".format(log_name) + elif log_name is None: + log_name = program_name + elif not isinstance(log_name, str): + raise RuntimeError("log_name must be a string (got {0})".format(repr(log_name))) + + def generate_input(machine, global_id): + remote.exec_script(machine.public_ip_address, "./scripts/generate_input.sh", "{0} {1} {2} {3} {4}".format(problem_name, problem_size, protocol, global_id % workers_per_party, workers_per_party)) + + if generate_fresh_input: + cluster.for_each_concurrently(generate_input, worker_ids) + + def generate_memprog(machine, global_id): + party = 
party_from_global_id(cluster, global_id) + local_id = global_id % workers_per_party + remote.exec_script(machine.public_ip_address, "./scripts/generate_memprog.sh", "{0} {1} {2} {3} {4} {5} {6}".format(problem_name, problem_size, protocol, config_file, party, local_id, log_name)) + + if generate_fresh_memprog: + cluster.for_each_concurrently(generate_memprog, worker_ids) + + def run_mage(machine, global_id): + party = party_from_global_id(cluster, global_id) + local_id = global_id % workers_per_party + if party == 1: + time.sleep(30) # Wait for the evaluator to start first + remote.exec_script(machine.public_ip_address, "./scripts/run_mage.sh", "{0} {1} {2} {3} {4} {5} {6} {7}".format(scenario, protocol, config_file, party, local_id, program_name, log_name, "true")) + + if protocol != "ckks": + time.sleep(70) # Wait for TIME-WAIT state to expire + clear_memory_caches(cluster, worker_ids) + cluster.for_each_concurrently(run_mage, worker_ids) + +def run_halfgates_baseline_experiment(cluster, problem_size, scenario, worker_ids, log_name = "/dev/null"): + assert len(worker_ids) == 2 + if not isinstance(log_name, str): + raise RuntimeError("log_name must be a string (got {0})".format(repr(log_name))) + + def run_halfgates_baseline(machine, global_id): + party = party_from_global_id(cluster, global_id) + if party == 1: + time.sleep(30) + if global_id == worker_ids[0]: + other_worker_id = worker_ids[1] + else: + assert global_id == worker_ids[1] + other_worker_id = worker_ids[0] + remote.exec_script(machine.public_ip_address, "./scripts/run_halfgates_baseline.sh", "{0} {1} {2} {3} {4}".format(scenario, party, problem_size, cluster.machines[other_worker_id].private_ip_address, log_name)) + + time.sleep(70) + clear_memory_caches(cluster, worker_ids) + cluster.for_each_concurrently(run_halfgates_baseline, worker_ids) + +def run_ckks_baseline_experiment(cluster, problem_size, scenario, worker_ids, log_name = "/dev/null", generate_fresh_input = True): + if not 
isinstance(log_name, str): + raise RuntimeError("log_name must be a string (got {0})".format(repr(log_name))) + + def generate_input(machine, global_id): + remote.exec_script(machine.public_ip_address, "./scripts/generate_input.sh", "{0} {1} {2} {3} {4}".format("real_statistics", problem_size, "ckks", 0, 1)) + + if generate_fresh_input: + cluster.for_each_concurrently(generate_input, worker_ids) + + def run_ckks_baseline(machine, global_id): + remote.exec_script(machine.public_ip_address, "./scripts/run_ckks_baseline.sh", "{0} {1} {2}".format(scenario, problem_size, log_name)) + + # time.sleep(70) + clear_memory_caches(cluster, worker_ids) + cluster.for_each_concurrently(run_ckks_baseline, worker_ids) diff --git a/magebench.py b/magebench.py index 1c54b7e..7fa5aad 100755 --- a/magebench.py +++ b/magebench.py @@ -1,20 +1,14 @@ #!/usr/bin/env python3 +import argparse import os +import shutil import sys import time import cluster +import experiment import remote -CLUSTER_NAME = "mage-cluster" -CLUSTER_SIZE = 2 - -def party_from_global_id(cluster, global_id): - if global_id < len(cluster.machines) // 2: - return 0 # evaluator - else: - return 1 # garbler - def validate_protocol(protocol): protocol = protocol.lower() if protocol not in ("halfgates", "ckks"): @@ -28,100 +22,215 @@ def validate_scenario(scenario): print("Scenario must be unbounded, os, or mage (got {0})".format(scenario)) return scenario -def provision(machine, id): +def provision_machine(machine, id): remote.exec_script(machine.public_ip_address, "./scripts/provision.sh") remote.copy_to(machine.public_ip_address, False, "./cluster.json", "~") remote.exec_script(machine.public_ip_address, "./scripts/generate_configs.py", "~/cluster.json 0 ~/config") -def run_lan_experiment(cluster, problem_name, problem_size, protocol, scenario, worker_ids, log_name = "/dev/null", generate_fresh_input = True, generate_fresh_memprog = True): - if protocol == "halfgates": - assert len(worker_ids) % 2 == 0 - 
workers_per_party = len(worker_ids) // 2 - elif protocol == "ckks": - workers_per_party = len(worker_ids) - else: - raise RuntimeError("Unknown protocol {0}".format(protocol)) - - program_name = "{0}_{1}".format(problem_name, problem_size) - config_file = "~/config/{0}/config_{1}_{2}.yaml".format("1gb" if scenario == "mage" else "unbounded", protocol, workers_per_party) - - if isinstance(log_name, int): - log_name = program_name + "_t{0}".format(log_name) - elif log_name is None: - log_name = program_name - elif not isinstance(log_name, str): - raise RuntimeError("log_name must be a string (got {0})".format(repr(log_name))) - - def generate_input(machine, global_id): - remote.exec_script(machine.public_ip_address, "./scripts/generate_input.sh", "{0} {1} {2} {3} {4}".format(problem_name, problem_size, protocol, global_id % workers_per_party, workers_per_party)) - - if generate_fresh_input: - cluster.for_each_concurrently(generate_input, worker_ids) - - def generate_memprog(machine, global_id): - party = party_from_global_id(cluster, global_id) - local_id = global_id % workers_per_party - remote.exec_script(machine.public_ip_address, "./scripts/generate_memprog.sh", "{0} {1} {2} {3} {4} {5} {6}".format(problem_name, problem_size, protocol, config_file, party, local_id, log_name)) - - if generate_fresh_memprog: - cluster.for_each_concurrently(generate_memprog, worker_ids) - - def run_mage(machine, global_id): - party = party_from_global_id(cluster, global_id) - local_id = global_id % workers_per_party - if party == 1: - time.sleep(30) # Wait for the evaluator to start first - remote.exec_script(machine.public_ip_address, "./scripts/run_mage.sh", "{0} {1} {2} {3} {4} {5} {6} {7}".format(scenario, protocol, config_file, party, local_id, program_name, log_name, "true")) - - time.sleep(70) # Wait for TIME-WAIT state to expire - cluster.for_each_concurrently(run_mage, worker_ids) +def copy_ckks_keys(machine, id): + remote.copy_to(machine.public_ip_address, True, 
"./ckks_keys", "~") + remote.exec_sync(machine.public_ip_address, "cp ~/ckks_keys/* ~/work/mage/bin") + +def generate_ckks_keys(c): + shutil.rmtree("./ckks_keys", ignore_errors = True) + remote.exec_sync(c.machines[0].public_ip_address, "cd ~/work/mage/bin; ./ckks_utils keygen; mkdir -p ~/ckks_keys; cp *.ckks ~/ckks_keys") + try: + remote.copy_from(c.machines[0].public_ip_address, True, "~/ckks_keys", ".") + c.for_each_concurrently(copy_ckks_keys, range(1, len(c.machines))) + finally: + shutil.rmtree("./ckks_keys") + +def provision_cluster(c): + c.for_each_concurrently(provision_machine) + generate_ckks_keys(c) + +def logs_directory(id): + return os.path.join(".", "logs", "{0:02d}".format(id)) + +def fetch_logs_from(machine, id): + directory = logs_directory(id) + remote.copy_from(machine.public_ip_address, False, "~/logs/*", directory) + +def spawn(args): + if os.path.exists("cluster.json"): + print("Cluster already exists!") + print("To create a new cluster, first run \"{0} deallocate\"".format(sys.argv[0])) + sys.exit(1) + assert args.size > 0 + print("Spawning cluster...") + c = cluster.spawn(args.name, args.size) + c.save_to_file("cluster.json") + print("Waiting three minutes for the machines to start up...") + time.sleep(180) + print("Provisioning the machines...") + c.for_each_concurrently(provision_machine) + print("Done.") + +def provision(args): + c = cluster.Cluster.load_from_file("cluster.json") + print("Provisioning the machines...") + provision_cluster(c) + print("Done.") + +def parse_program(program): + try: + index = program.rfind("_") + if index == -1: + raise ValueError + problem_name = program[:index] + problem_size = program[index + 1:] + return (problem_name, problem_size) + except ValueError: + print("Invalid program name (must be of the form _): {0}".format(program)) + sys.exit(2) +def parse_program_list(program_list): + result = [] + for program in program_list: + result.append(parse_program(program)) + return result -if __name__ == 
"__main__": - if len(sys.argv) < 2: - print("Usage: {0} spawn/provision/run/deallocate".format(sys.argv[0])) +def run(args): + c = cluster.Cluster.load_from_file("cluster.json") + if len(sys.argv) not in (6, 7): + print("Usage: {0} problem_name problem_size protocol scenario [log_tag]".format(sys.argv[0])) sys.exit(2) - - if sys.argv[1] == "spawn": - if os.path.exists("cluster.json"): - print("Cluster already exists!") - print("To create a new cluster, first run \"{0} deallocate\"".format(sys.argv[0])) - sys.exit(1) - print("Spawning cluster...") - c = cluster.spawn(CLUSTER_NAME, CLUSTER_SIZE) - c.save_to_file("cluster.json") - print("Waiting three minutes for the machines to start up...") - time.sleep(180) - print("Provisioning the machines...") - c.for_each_concurrently(provision) - print("Done.") - elif sys.argv[1] == "provision": - c = cluster.Cluster.load_from_file("cluster.json") - print("Provisioning the machines...") - c.for_each_concurrently(provision) - print("Done.") - elif sys.argv[1] == "run": - c = cluster.Cluster.load_from_file("cluster.json") - if len(sys.argv) not in (6, 7): - print("Usage: {0} problem_name problem_size protocol scenario [log_tag]".format(sys.argv[0])) - sys.exit(2) - problem_name = sys.argv[2] - problem_size = int(sys.argv[3]) - protocol = validate_protocol(sys.argv[4]) - scenario = validate_scenario(sys.argv[5]) - worker_ids = (0, 1) - log_name = "{0}_{1}_{2}".format(problem_name, problem_size, scenario) - if len(sys.argv) >= 7: - log_name += "_{0}".format(sys.argv[6]) - run_lan_experiment(c, problem_name, problem_size, protocol, scenario, worker_ids, log_name) - elif sys.argv[1] == "deallocate": - print("Deallocating cluster...") - cluster.deallocate(CLUSTER_NAME, CLUSTER_SIZE) - try: - os.remove("cluster.json") - except os.FileNotFoundError: - pass - print("Done.") + problem_name, problem_size = parse_program(args.program) + if problem_name.startswith("real"): + protocol = "ckks" + worker_ids = (0,) else: - print("Unknown 
command \"{0}\"".format(sys.argv[1])) - sys.exit(2) + protocol = "halfgates" + worker_ids = (0, 1) + scenario = args.scenario + log_name = "{0}_{1}_{2}".format(problem_name, problem_size, scenario) + if args.tag is not None: + log_name += "_{0}".format(args.tag) + experiment.run_lan_experiment(c, problem_name, problem_size, protocol, scenario, worker_ids, log_name) + +def run_single_core_experiments(args): + if args.programs is None: + args.programs = ("merge_sorted_1048576", "full_sort_1048576", "loop_join_2048", "matrix_vector_multiply_8192", "binary_fc_layer_16384", "real_sum_65536", "real_statistics_16384", "real_matrix_vector_multiply_256", "real_naive_matrix_multiply_128", "real_tiled_matrix_multiply_128") + if args.scenarios is None: + args.scenarios = ("mage", "unbounded", "os") + + c = cluster.Cluster.load_from_file("cluster.json") + parsed_programs = parse_program_list(args.programs) + for problem_name, problem_size in parsed_programs: + if problem_name.startswith("real"): + protocol = "ckks" + worker_ids = (0,) + else: + protocol = "halfgates" + worker_ids = (0, 1) + for trial in range(1, args.trials + 1): + for scenario in args.scenarios: + log_name = "single_machine_{0}_{1}_{2}_t{3}".format(problem_name, problem_size, scenario, trial) + experiment.run_lan_experiment(c, problem_name, problem_size, protocol, scenario, worker_ids, log_name) + + +def run_halfgates_baseline(args): + if args.sizes is None: + args.sizes = tuple(2 ** i for i in range(10, 21)) + if args.scenarios is None: + args.scenarios = ("mage", "unbounded", "os", "emp") + + c = cluster.Cluster.load_from_file("cluster.json") + assert len(c.machines) == 2 + + problem_name = "merge_sorted" + protocol = "halfgates" + worker_ids = (0, 1) + + for problem_size in args.sizes: + for trial in range(1, args.trials + 1): + for scenario in args.scenarios: + log_name = "halfgates_baseline_{0}_{1}_{2}_t{3}".format(problem_name, problem_size, scenario, trial) + if scenario == "emp": + 
experiment.run_halfgates_baseline_experiment(c, problem_size, "os", worker_ids, log_name)
+                else:
+                    experiment.run_lan_experiment(c, problem_name, problem_size, protocol, scenario, worker_ids, log_name)
+
+def run_ckks_baseline(args):
+    if args.sizes is None:
+        args.sizes = tuple(2 ** i for i in range(6, 15))
+    if args.scenarios is None:
+        args.scenarios = ("mage", "unbounded", "os", "seal")
+
+    c = cluster.Cluster.load_from_file("cluster.json")
+    assert len(c.machines) == 1 or len(c.machines) == 2
+
+    problem_name = "real_statistics"
+    protocol = "ckks"
+    worker_ids = (0,)
+
+    for problem_size in args.sizes:
+        for trial in range(1, args.trials + 1):
+            for scenario in args.scenarios:
+                log_name = "ckks_baseline_{0}_{1}_{2}_t{3}".format(problem_name, problem_size, scenario, trial)
+                if scenario == "seal":
+                    experiment.run_ckks_baseline_experiment(c, problem_size, "os", worker_ids, log_name)
+                else:
+                    experiment.run_lan_experiment(c, problem_name, problem_size, protocol, scenario, worker_ids, log_name)
+
+def deallocate(args):
+    print("Deallocating cluster...")
+    cluster.deallocate("mage-cluster", 2)  # NOTE(review): CLUSTER_NAME/CLUSTER_SIZE were deleted by this change (NameError); spawn defaults used -- TODO read name/size from cluster.json or CLI options
+    try:
+        os.remove("cluster.json")
+    except FileNotFoundError:  # os.FileNotFoundError does not exist; it is a builtin
+        pass
+    print("Done.")
+
+def fetch_logs(args):
+    print("Fetching logs...")
+    c = cluster.Cluster.load_from_file("cluster.json")
+    for id in range(len(c.machines)):
+        os.makedirs(logs_directory(id), exist_ok = True)
+    c.for_each_concurrently(fetch_logs_from)
+    print("Done.")
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description = "Run benchmark experiments on MAGE.")
+    subparsers = parser.add_subparsers(dest = "command", required = True)  # NOTE(review): without required=True, a bare invocation crashes on args.func
+
+    parser_spawn = subparsers.add_parser("spawn")
+    parser_spawn.add_argument("-n", "--name", default = "mage-cluster")
+    parser_spawn.add_argument("-z", "--size", type = int, default = 2)
+    parser_spawn.set_defaults(func = spawn)
+
+    parser_provision = subparsers.add_parser("provision")
+    parser_provision.set_defaults(func = provision)
+
+    parser_run = 
subparsers.add_parser("run") + parser_run.add_argument("program") + parser_run.add_argument("scenario") + parser_run.add_argument("--tag") + parser_run.set_defaults(func = run) + + parser_run_sce = subparsers.add_parser("run-single-core-experiments") + parser_run_sce.add_argument("-p", "--programs", action = "extend", nargs = "+") + parser_run_sce.add_argument("-s", "--scenarios", action = "extend", nargs = "+", choices = ("unbounded", "mage", "os")) + parser_run_sce.add_argument("-t", "--trials", type = int, default = 1) + parser_run_sce.set_defaults(func = run_single_core_experiments) + + parser_run_hgb = subparsers.add_parser("run-halfgates-baseline") + parser_run_hgb.add_argument("-z", "--sizes", action = "extend", nargs = "+", type = int) + parser_run_hgb.add_argument("-s", "--scenarios", action = "extend", nargs = "+", choices = ("unbounded", "mage", "os", "emp")) + parser_run_hgb.add_argument("-t", "--trials", type = int, default = 1) + parser_run_hgb.set_defaults(func = run_halfgates_baseline) + + parser_run_ckb = subparsers.add_parser("run-ckks-baseline") + parser_run_ckb.add_argument("-z", "--sizes", action = "extend", nargs = "+", type = int) + parser_run_ckb.add_argument("-s", "--scenarios", action = "extend", nargs = "+", choices = ("unbounded", "mage", "os", "seal")) + parser_run_ckb.add_argument("-t", "--trials", type = int, default = 1) + parser_run_ckb.set_defaults(func = run_ckks_baseline) + + parser_deallocate = subparsers.add_parser("deallocate") + parser_deallocate.set_defaults(func = deallocate) + + parser_fetch_logs = subparsers.add_parser("fetch-logs") + parser_fetch_logs.set_defaults(func = fetch_logs) + + args = parser.parse_args() + args.func(args) diff --git a/remote.py b/remote.py index 8274557..5c7798f 100644 --- a/remote.py +++ b/remote.py @@ -25,15 +25,15 @@ def copy_to(ip_address, directory, local_location, remote_location = "~"): assert remote_location.strip() != "" command = ("scp", "-o", "StrictHostKeyChecking=no", "-i", "mage", 
local_location, "mage@{0}:{1}".format(ip_address, remote_location)) if directory: - command = ("scp", "-o", "StrictHostKeyChecking=no", "-r") + command[1:] + command = ("scp", "-o", "StrictHostKeyChecking=no", "-r") + command[3:] subprocess.run(command, check = True) -def copy_from(ip_address, directory, remote_location, location_location = "."): +def copy_from(ip_address, directory, remote_location, local_location = "."): assert local_location.strip() != "" assert remote_location.strip() != "" command = ("scp", "-o", "StrictHostKeyChecking=no", "-i", "mage", "mage@{0}:{1}".format(ip_address, remote_location), local_location) if directory: - command = ("scp", "-o", "StrictHostKeyChecking=no", "-r") + command[1:] + command = ("scp", "-o", "StrictHostKeyChecking=no", "-r") + command[3:] subprocess.run(command, check = True) def exec_script(ip_address, local_location, args = "", sync = True): diff --git a/scripts/run_ckks_baseline.sh b/scripts/run_ckks_baseline.sh new file mode 100755 index 0000000..bab78dd --- /dev/null +++ b/scripts/run_ckks_baseline.sh @@ -0,0 +1,49 @@ +#!/usr/bin/env bash + +set -x + +SCENARIO=$1 +PROBLEM_SIZE=$2 +LOG_NAME=$3 + +if [[ -z $LOG_NAME ]] +then + echo "Usage:" $0 "scenario problem_size log_file_name" + exit +fi + +PROBLEM_NAME=real_statistics +PROGRAM=${PROBLEM_NAME}_${PROBLEM_SIZE} +WORKER=0 + +pushd ~/work/mage/bin + +PREFIX= +if [[ $SCENARIO = "unbounded" ]] +then + sudo swapoff -a + PREFIX="sudo" +elif [[ $SCENARIO = "os" ]] +then + sudo swapoff -a + sudo swapon /dev/disk/cloud/azure_resource-part2 + PREFIX="sudo cgexec -g memory:memprog1gb" +else + echo "Unknown/unsupported scenario" $SCENARIO + exit 2 +fi + +sudo free +sudo sync +echo 3 | sudo tee /proc/sys/vm/drop_caches +sudo free + +$PREFIX ./ckks_utils $PROBLEM_NAME $PROBLEM_SIZE ${PROGRAM}_${WORKER}_garbler.input ${PROGRAM}_${WORKER}.output > ~/logs/${LOG_NAME}.log + +if [[ $CHECK_RESULT = true ]] +then + ./ckks_utils decrypt_file 1 ${PROGRAM}_${WORKER}.output + ./ckks_utils 
float_file_decode ${PROGRAM}_${WORKER}.output > decoded.output + ./ckks_utils float_file_decode ${PROGRAM}_${WORKER}.expected > expected.output + diff decoded.output expected.output > ~/logs/${LOG_NAME}.result +fi diff --git a/scripts/run_halfgates_baseline.sh b/scripts/run_halfgates_baseline.sh new file mode 100755 index 0000000..0ef7fe4 --- /dev/null +++ b/scripts/run_halfgates_baseline.sh @@ -0,0 +1,39 @@ +#!/usr/bin/env bash + +set -x + +SCENARIO=$1 +PARTY=$2 +PROBLEM_SIZE=$3 +OTHER_IP=$4 +LOG_NAME=$5 + +if [[ -z $LOG_NAME ]] +then + echo "Usage:" $0 "scenario party_id log_file_name" + exit 2 +fi + +pushd ~/work/emp-sh2pc/bin + +PREFIX= +if [[ $SCENARIO = "unbounded" ]] +then + sudo swapoff -a + PREFIX="sudo" +elif [[ $SCENARIO = "os" ]] +then + sudo swapoff -a + sudo swapon /dev/disk/cloud/azure_resource-part2 + PREFIX="sudo cgexec -g memory:memprog1gb" +else + echo "Unknown/unsupported scenario" $SCENARIO + exit 2 +fi + +sudo free +sudo sync +echo 3 | sudo tee /proc/sys/vm/drop_caches +sudo free + +$PREFIX ./merge_sorted $PARTY 50000 $PROBLEM_SIZE $OTHER_IP > ~/logs/${LOG_NAME}.log