-
Notifications
You must be signed in to change notification settings - Fork 0
/
standalone_sha3.mpc
75 lines (52 loc) · 2.94 KB
/
standalone_sha3.mpc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from Compiler.script_utils import output_utils
from Compiler.script_utils.data import data
from Compiler import ml
from Compiler import library
from Compiler.script_utils.audit import rand_smoothing
from Compiler.script_utils import config, timers, input_consistency
from Compiler.script_utils.consistency_cerebro import compute_commitment
class Sha3StandaloneConfig(config.BaseAuditModel):
    """Configuration for the standalone SHA-3 commitment program.

    Adds ``compute_input`` on top of ``config.BaseAuditModel``. The other
    settings this script reads (``dataset``, ``n_threads``, ``trunc_pr``,
    ``round_nearest``, ``sha3_approx_factor``) presumably come from the base
    model. NOTE(review): confirm in ``config``.
    """
    # True: hash every party's input objects; False: hash the output objects.
    compute_input: bool = True
# Parse the compiler's command-line options, then build the typed config
# for this program from the remaining program arguments.
program.options_from_args()
cfg = config.from_program_args(program.args, Sha3StandaloneConfig)
# Apply the configured truncation/rounding and threading settings globally,
# before any computation below is compiled.
program.use_trunc_pr = cfg.trunc_pr
sfix.round_nearest = cfg.round_nearest
ml.set_n_threads(cfg.n_threads)
# Only the length of the training split is used. It becomes the number of
# input parties whose commitments are computed below.
# NOTE(review): assumes train_dataset holds one entry per party - confirm
# against data._load_dataset_args.
train_dataset, _, _ = data._load_dataset_args(cfg.dataset)
n_players = len(train_dataset)
def _hash_dummy_input(sha_fun, len_items, timer_id, n_threads):
    """Hash a secret array of ``len_items`` placeholder values with ``sha_fun``.

    The array is filled with the constant 2 instead of the real secret-shared
    input. For a correct commitment the actual input shares would have to be
    loaded here; this is skipped for speed, since the data would already have
    been loaded in the full pipeline.

    The fill runs with ``timer_id`` paused, so only allocation and hashing are
    timed. The caller must have ``timer_id`` running on entry, and it is
    running again on return.

    Args:
        sha_fun: hash routine returned by ``input_consistency.compute_sha3_inner``.
        len_items: number of ``sint`` elements to hash.
        timer_id: id of the timer currently measuring this section.
        n_threads: thread count forwarded to ``sha_fun``.
    """
    input_comm = Array(len_items, sint)
    library.stop_timer(timer_id=timer_id)
    input_comm.assign_all(2)
    library.start_timer(timer_id=timer_id)
    sha_fun(input_comm, None, n_threads)


if cfg.compute_input:
    # Commit to every input object of every party, timed as one section.
    sha_fun = input_consistency.compute_sha3_inner(cfg.sha3_approx_factor)
    timer_id = timers.TIMER_INPUT_CONSISTENCY_CHECK
    library.start_timer(timer_id=timer_id)
    for player_id in range(n_players):
        objects = input_consistency.read_input_format_from_file(player_id)
        # `print` is plain Python and outputs while compiling; `print_ln` is
        # emitted into the program and prints when it runs.
        print("Player", player_id, "has", len(objects), "objects")
        for obj in objects:
            # An object's size is the total length of all its items.
            len_items = sum(item["length"] for item in obj["items"])
            print("Computing commitment of length ", len_items)
            print_ln("Computing commitment for player %s with size %s", player_id, len_items)
            _hash_dummy_input(sha_fun, len_items, timer_id, cfg.n_threads)
    library.stop_timer(timer_id=timer_id)
else:
    # Commit to every output object. The hash routine also reports its
    # bit-decomposition and variable-hash phases on dedicated timers.
    sha_fun = input_consistency.compute_sha3_inner(
        cfg.sha3_approx_factor,
        timer_bit_decompose=timers.TIMER_OUTPUT_CONSISTENCY_SHA_BIT_DECOMPOSE,
        timer_hash_variable=timers.TIMER_OUTPUT_CONSISTENCY_SHA_HASH_VARIABLE,
    )
    timer_id = timers.TIMER_OUTPUT_COMMIT
    library.start_timer(timer_id=timer_id)
    objects = input_consistency.read_output_format_from_file()
    for obj in objects:
        print_ln("Object type %s of length %s", obj["object_type"], obj["length"])
        _hash_dummy_input(sha_fun, obj["length"], timer_id, cfg.n_threads)
    library.stop_timer(timer_id=timer_id)

print_ln("Done computing commitments!")