diff --git a/.gitignore b/.gitignore index bf7673e..16eec57 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,12 @@ /server /server/* -/opt \ No newline at end of file +/opt +/local +/cmake-3.24.1 +/bin +/deps +/doc +/logs +/venv +/Programs diff --git a/BMR/AndJob.cpp b/BMR/AndJob.cpp old mode 100755 new mode 100644 diff --git a/BMR/AndJob.h b/BMR/AndJob.h old mode 100755 new mode 100644 diff --git a/BMR/BooleanCircuit.h b/BMR/BooleanCircuit.h old mode 100755 new mode 100644 diff --git a/BMR/CommonParty.cpp b/BMR/CommonParty.cpp old mode 100755 new mode 100644 diff --git a/BMR/CommonParty.h b/BMR/CommonParty.h old mode 100755 new mode 100644 diff --git a/BMR/CommonParty.hpp b/BMR/CommonParty.hpp old mode 100755 new mode 100644 diff --git a/BMR/GarbledGate.cpp b/BMR/GarbledGate.cpp old mode 100755 new mode 100644 diff --git a/BMR/GarbledGate.h b/BMR/GarbledGate.h old mode 100755 new mode 100644 diff --git a/BMR/Gate.h b/BMR/Gate.h old mode 100755 new mode 100644 diff --git a/BMR/Key.cpp b/BMR/Key.cpp old mode 100755 new mode 100644 diff --git a/BMR/Key.h b/BMR/Key.h old mode 100755 new mode 100644 diff --git a/BMR/Party.cpp b/BMR/Party.cpp old mode 100755 new mode 100644 index beddd64..0fe11a0 --- a/BMR/Party.cpp +++ b/BMR/Party.cpp @@ -249,6 +249,7 @@ FakeProgramParty::FakeProgramParty(int argc, const char** argv) : } cout << "Compiler: " << prev << endl; P = new PlainPlayer(N, 0); + Share::MAC_Check::setup(*P); if (argc > 4) threshold = atoi(argv[4]); cout << "Threshold for multi-threaded evaluation: " << threshold << endl; @@ -280,6 +281,7 @@ FakeProgramParty::~FakeProgramParty() cerr << "Dynamic storage: " << 1e-9 * dynamic_memory.capacity_in_bytes() << " GB" << endl; #endif + Share::MAC_Check::teardown(); } void FakeProgramParty::_compute_prfs_outputs(Key* keys) diff --git a/BMR/Party.h b/BMR/Party.h old mode 100755 new mode 100644 diff --git a/BMR/ProgramParty.hpp b/BMR/ProgramParty.hpp old mode 100755 new mode 100644 diff --git a/BMR/RealGarbleWire.h b/BMR/RealGarbleWire.h old mode 100755 new mode 100644 index 9fa2dc5..115d0bc --- a/BMR/RealGarbleWire.h +++ b/BMR/RealGarbleWire.h @@ -48,8 +48,6 @@ class RealGarbleWire : public PRFRegister static void inputbvec(GC::Processor>& processor, ProcessorBase& input_processor, const vector& args); - RealGarbleWire(const Register& reg) : PRFRegister(reg) {} - void garble(PRFOutputs& prf_output, const RealGarbleWire& left, const RealGarbleWire& right); diff --git a/BMR/RealGarbleWire.hpp b/BMR/RealGarbleWire.hpp old mode 100755 new mode 100644 index 760a20b..c9e31fc --- a/BMR/RealGarbleWire.hpp +++ b/BMR/RealGarbleWire.hpp @@ -110,7 +110,7 @@ void RealGarbleWire::inputbvec( { GarbleInputter inputter; processor.inputbvec(inputter, input_processor, args, - inputter.party.P->my_num()); + *inputter.party.P); } template diff --git a/BMR/RealProgramParty.h b/BMR/RealProgramParty.h old mode 100755 new mode 100644 diff --git a/BMR/RealProgramParty.hpp b/BMR/RealProgramParty.hpp old mode 100755 new mode 100644 index ae69cb7..4213941 --- a/BMR/RealProgramParty.hpp +++ b/BMR/RealProgramParty.hpp @@ -64,7 +64,6 @@ RealProgramParty::RealProgramParty(int argc, const char** argv) : online_opts = {opt, argc, argv, 1000}; else online_opts = {opt, argc, argv}; - assert(not online_opts.interactive); online_opts.finalize(opt, argc, argv); this->load(online_opts.progname); @@ -97,8 +96,6 @@ RealProgramParty::RealProgramParty(int argc, const char** argv) : if (online_opts.live_prep) { mac_key.randomize(prng); - if (T::needs_ot) - BaseMachine::s().ot_setups.push_back({*P, true}); prep = new typename T::LivePrep(0, usage); } else @@ -107,10 +104,12 @@ RealProgramParty::RealProgramParty(int argc, const char** argv) : prep = new Sub_Data_Files(N, prep_dir, usage); } + T::MAC_Check::setup(*P); MC = new typename T::MAC_Check(mac_key); garble_processor.reset(program); - this->processor.open_input_file(N.my_num(), 0); + this->processor.open_input_file(N.my_num(), 0, online_opts.cmd_private_input_file); + this->processor.setup_redirection(P->my_num(), 0, online_opts, this->processor.out); shared_proc = new SubProcessor(dummy_proc, *MC, *prep, *P); @@ -218,6 +217,7 @@ RealProgramParty::~RealProgramParty() delete garble_inputter; delete garble_protocol; cout << "Data sent = " << data_sent * 1e-6 << " MB" << endl; + T::MAC_Check::teardown(); } template diff --git a/BMR/Register.cpp b/BMR/Register.cpp old mode 100755 new mode 100644 diff --git a/BMR/Register.h b/BMR/Register.h old mode 100755 new mode 100644 index f348f7b..4def659 --- a/BMR/Register.h +++ b/BMR/Register.h @@ -152,7 +152,7 @@ class Register { * for pipelining matters. */ - Register(int n_parties); + Register(); void init(int n_parties); void init(int rfd, int n_parties); @@ -235,6 +235,9 @@ class Phase template static void ands(T& processor, const vector& args) { processor.ands(args); } template + static void andrsvec(T& processor, const vector& args) + { processor.andrsvec(args); } + template static void xors(T& processor, const vector& args) { processor.xors(args); } template static void inputb(T& processor, const vector& args) { processor.input(args); } @@ -278,10 +281,6 @@ class ProgramRegister : public Phase, public Register static int threshold(int) { throw not_implemented(); } - static Register new_reg(); - static Register tmp_reg() { return new_reg(); } - static Register and_reg() { return new_reg(); } - template static void store(NoMemory& dest, const vector >& accesses) { (void)dest; (void)accesses; } @@ -306,8 +305,6 @@ class ProgramRegister : public Phase, public Register void other_input(Input&, int) {} char get_output() { return 0; } - - ProgramRegister(const Register& reg) : Register(reg) {} }; class PRFRegister : public ProgramRegister @@ -319,8 +316,6 @@ class PRFRegister : public ProgramRegister static void load(vector >& accesses, const NoMemory& source); - PRFRegister(const Register& reg) : ProgramRegister(reg) {} - void op(const PRFRegister& left, const PRFRegister& right, Function func); void XOR(const Register& left, const Register& right); void input(party_id_t from, char input = -1); @@ -396,8 +391,6 @@ class EvalRegister : public ProgramRegister static void convcbit(Integer& dest, const GC::Clear& source, GC::Processor>& proc); - EvalRegister(const Register& reg) : ProgramRegister(reg) {} - void op(const ProgramRegister& left, const ProgramRegister& right, Function func); void XOR(const Register& left, const Register& right); @@ -427,8 +420,6 @@ class GarbleRegister : public ProgramRegister static void load(vector >& accesses, const NoMemory& source); - GarbleRegister(const Register& reg) : ProgramRegister(reg) {} - void op(const Register& left, const Register& right, Function func); void XOR(const Register& left, const Register& right); void input(party_id_t from, char value = -1); @@ -452,8 +443,6 @@ class RandomRegister : public ProgramRegister static void load(vector >& accesses, const NoMemory& source); - RandomRegister(const Register& reg) : ProgramRegister(reg) {} - void randomize(); void op(const Register& left, const Register& right, Function func); @@ -469,12 +458,6 @@ class RandomRegister : public ProgramRegister }; -inline Register::Register(int n_parties) : - garbled_entry(n_parties), external(NO_SIGNAL), - mask(NO_SIGNAL), keys(n_parties) -{ -} - inline void KeyVector::operator=(const KeyVector& other) { resize(other.size()); diff --git a/BMR/Register.hpp b/BMR/Register.hpp old mode 100755 new mode 100644 index bd214a8..6179069 --- a/BMR/Register.hpp +++ b/BMR/Register.hpp @@ -14,15 +14,7 @@ void ProgramRegister::inputbvec(T& processor, ProcessorBase& input_processor, const vector& args) { NoOpInputter inputter; - int my_num = -1; - try - { - my_num = ProgramParty::s().P->my_num(); - } - catch (exception&) - { - } - processor.inputbvec(inputter, input_processor, args, my_num); + processor.inputbvec(inputter, input_processor, args, *ProgramParty::s().P); } template @@ -31,7 +23,7 @@ void EvalRegister::inputbvec(T& processor, ProcessorBase& input_processor, { EvalInputter inputter; processor.inputbvec(inputter, input_processor, args, - ProgramParty::s().P->my_num()); + *ProgramParty::s().P); } template diff --git a/BMR/Register_inline.h b/BMR/Register_inline.h old mode 100755 new mode 100644 index 6a275da..7694c46 --- a/BMR/Register_inline.h +++ b/BMR/Register_inline.h @@ -9,10 +9,10 @@ #include "CommonParty.h" #include "Party.h" - -inline Register ProgramRegister::new_reg() +inline Register::Register() : + garbled_entry(CommonParty::s().get_n_parties()), external(NO_SIGNAL), + mask(NO_SIGNAL), keys(CommonParty::s().get_n_parties()) { - return Register(CommonParty::s().get_n_parties()); } #endif /* BMR_REGISTER_INLINE_H_ */ diff --git a/BMR/SpdzWire.h b/BMR/SpdzWire.h old mode 100755 new mode 100644 diff --git a/BMR/TrustedParty.cpp b/BMR/TrustedParty.cpp old mode 100755 new mode 100644 diff --git a/BMR/TrustedParty.h b/BMR/TrustedParty.h old mode 100755 new mode 100644 diff --git a/BMR/Wire.h b/BMR/Wire.h old mode 100755 new mode 100644 diff --git a/BMR/common.h b/BMR/common.h old mode 100755 new mode 100644 diff --git a/BMR/config.h b/BMR/config.h old mode 100755 new mode 100644 diff --git a/BMR/msg_types.cpp b/BMR/msg_types.cpp old mode 100755 new mode 100644 diff --git a/BMR/msg_types.h b/BMR/msg_types.h old mode 100755 new mode 100644 diff --git a/BMR/network/Client.cpp b/BMR/network/Client.cpp old mode 100755 new mode 100644 diff --git a/BMR/network/Client.h b/BMR/network/Client.h old mode 100755 new mode 100644 diff --git a/BMR/network/Node.cpp b/BMR/network/Node.cpp old mode 100755 new mode 100644 diff --git a/BMR/network/Node.h b/BMR/network/Node.h old mode 100755 new mode 100644 diff --git a/BMR/network/Server.cpp b/BMR/network/Server.cpp old mode 100755 new mode 100644 diff --git a/BMR/network/Server.h b/BMR/network/Server.h old mode 100755 new mode 100644 diff --git a/BMR/network/common.h b/BMR/network/common.h old mode 100755 new mode 100644 diff --git a/BMR/network/utils.cpp b/BMR/network/utils.cpp old mode 100755 new mode 100644 diff --git a/BMR/network/utils.h b/BMR/network/utils.h old mode 100755 new mode 100644 diff --git a/BMR/prf.h b/BMR/prf.h old mode 100755 new mode 100644 diff --git a/BMR/proto_utils.cpp b/BMR/proto_utils.cpp old mode 100755 new mode 100644 diff --git a/BMR/proto_utils.h b/BMR/proto_utils.h old mode 100755 new mode 100644 diff --git a/CHANGELOG.md b/CHANGELOG.md old mode 100755 new mode 100644 index 18cc92a..f201d46 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,13 +1,39 @@ The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here. -## 0.3.2 (Mai 27, 2022) +## 0.3.4 (Nov 9, 2022) + +- Decision tree learning +- Optimized oblivious shuffle in Rep3 +- Optimized daBit generation in Rep3 and semi-honest HE-based 2PC +- Optimized element-vector AND in SemiBin +- Optimized input protocol in Shamir-based protocols +- Square-root ORAM (@Quitlox) +- Improved ORAM in binary circuits +- UTF-8 outputs + +## 0.3.3 (Aug 25, 2022) + +- Use SoftSpokenOT to avoid unclear security of KOS OT extension candidate +- Fix security bug in MAC check when using multithreading +- Fix security bug to prevent selective failure attack by checking earlier +- Fix security bug in Mama: insufficient sacrifice. +- Inverse permutation (@Quitlox) +- Easier direct compilation (@eriktaubeneck) +- Generally allow element-vector operations +- Increase maximum register size to 2^54 +- Client example in Python +- Uniform base OTs across platforms +- Multithreaded base OT computation +- Faster random bit generation in two-player Semi(2k) + +## 0.3.2 (May 27, 2022) - Secure shuffling - O(n log n) radix sorting - Documented BGV encryption interface - Optimized matrix multiplication in dealer protocol - Fixed security bug in homomorphic encryption parameter generation -- Fixed Security bug in Temi matrix multiplication +- Fixed security bug in Temi matrix multiplication ## 0.3.1 (Apr 19, 2022) diff --git a/CONFIG b/CONFIG old mode 100755 new mode 100644 index cef15e0..b0de2dc --- a/CONFIG +++ b/CONFIG @@ -12,7 +12,7 @@ PREP_DIR = '-DPREP_DIR="Player-Data/"' SSL_DIR = '-DSSL_DIR="Player-Data/"' # set for SHE preprocessing (SPDZ and Overdrive) -USE_NTL = 0 +USE_NTL = 1 # set for using GF(2^128) # unset for GF(2^40) @@ -31,24 +31,21 @@ ARCH = -mtune=native -msse4.1 -msse4.2 -maes -mpclmul -mavx -mavx2 -mbmi2 -madx ARCH = -march=native MACHINE := $(shell uname -m) +ARM := $(shell uname -m | grep x86; echo $$?) OS := $(shell uname -s) ifeq ($(MACHINE), x86_64) -# set this to 0 to avoid using AVX for OT ifeq ($(OS), Linux) -CHECK_AVX := $(shell grep -q avx /proc/cpuinfo; echo $$?) -ifeq ($(CHECK_AVX), 0) AVX_OT = 1 else AVX_OT = 0 endif else -AVX_OT = 1 -endif -else ARCH = AVX_OT = 0 endif +USE_KOS = 0 + # allow to set compiler in CONFIG.mine CXX = g++ @@ -70,8 +67,11 @@ endif # MOD = -DMAX_MOD_SZ=10 -DGFP_MOD_SZ=5 LDLIBS = -lmpirxx -lmpir -lsodium $(MY_LDLIBS) +LDLIBS += -Wl,-rpath -Wl,$(CURDIR)/local/lib -L$(CURDIR)/local/lib LDLIBS += -lboost_system -lssl -lcrypto +CFLAGS += -I./local/include + ifeq ($(USE_NTL),1) CFLAGS += -DUSE_NTL LDLIBS := -lntl $(LDLIBS) @@ -87,7 +87,7 @@ else BOOST = -lboost_thread $(MY_BOOST) endif -CFLAGS += $(ARCH) $(MY_CFLAGS) $(GDEBUG) -Wextra -Wall $(OPTIM) -I$(ROOT) -pthread $(PROF) $(DEBUG) $(MOD) $(GF2N_LONG) $(PREP_DIR) $(SSL_DIR) $(SECURE) -std=c++11 -Werror +CFLAGS += $(ARCH) $(MY_CFLAGS) $(GDEBUG) -Wextra -Wall $(OPTIM) -I$(ROOT) -I$(ROOT)/deps -pthread $(PROF) $(DEBUG) $(MOD) $(GF2N_LONG) $(PREP_DIR) $(SSL_DIR) $(SECURE) -std=c++11 -Werror CPPFLAGS = $(CFLAGS) LD = $(CXX) @@ -98,3 +98,9 @@ ifeq ($(USE_NTL),1) CFLAGS += -Wno-error=unused-parameter -Wno-error=deprecated-copy endif endif + +ifeq ($(USE_KOS),1) +CFLAGS += -DUSE_KOS +else +CFLAGS += -std=c++17 +endif diff --git a/Compiler/GC/__init__.py b/Compiler/GC/__init__.py old mode 100755 new mode 100644 diff --git a/Compiler/GC/instructions.py b/Compiler/GC/instructions.py old mode 100755 new mode 100644 index e53b718..73a8af2 --- a/Compiler/GC/instructions.py +++ b/Compiler/GC/instructions.py @@ -13,6 +13,7 @@ import Compiler.tools as tools import collections import itertools +import math class SecretBitsAF(base.RegisterArgFormat): reg_type = 'sb' @@ -50,6 +51,7 @@ class ClearBitsAF(base.RegisterArgFormat): INPUTBVEC = 0x247, SPLIT = 0x248, CONVCBIT2S = 0x249, + ANDRSVEC = 0x24a, XORCBI = 0x210, BITDECC = 0x211, NOTCB = 0x212, @@ -155,6 +157,52 @@ class andrs(BinaryVectorInstruction): def add_usage(self, req_node): req_node.increment(('bit', 'triple'), sum(self.args[::4])) + req_node.increment(('bit', 'mixed'), + sum(int(math.ceil(x / 64)) for x in self.args[::4])) + +class andrsvec(base.VarArgsInstruction, base.Mergeable, + base.DynFormatInstruction): + """ Constant-vector AND of secret bit registers (vectorized version). + + :param: total number of arguments to follow (int) + :param: number of arguments to follow for one operation / + operation vector size plus three (int) + :param: vector size (int) + :param: result vector (sbit) + :param: (repeat)... + :param: constant operand (sbits) + :param: vector operand + :param: (repeat)... + :param: (repeat from number of arguments to follow for one operation)... + + """ + code = opcodes['ANDRSVEC'] + + def __init__(self, *args, **kwargs): + super(andrsvec, self).__init__(*args, **kwargs) + for i, n in self.bases(iter(self.args)): + size = self.args[i + 1] + for x in self.args[i + 2:i + n]: + assert x.n == size + + @classmethod + def dynamic_arg_format(cls, args): + yield 'int' + for i, n in cls.bases(args): + yield 'int' + n_args = (n - 3) // 2 + assert n_args > 0 + for j in range(n_args): + yield 'sbw' + for j in range(n_args + 1): + yield 'sb' + yield 'int' + + def add_usage(self, req_node): + for i, n in self.bases(iter(self.args)): + size = self.args[i + 1] + req_node.increment(('bit', 'triple'), size * (n - 3) // 2) + req_node.increment(('bit', 'mixed'), size) class ands(BinaryVectorInstruction): """ Bitwise AND of secret bit register vector. @@ -342,7 +390,8 @@ class stmcb(base.DirectMemoryWriteInstruction, base.VectorInstruction): code = opcodes['STMCB'] arg_format = ['cb','long'] -class ldmsbi(base.ReadMemoryInstruction, base.VectorInstruction): +class ldmsbi(base.ReadMemoryInstruction, base.VectorInstruction, + base.IndirectMemoryInstruction): """ Copy secret bit memory cell with run-time address to secret bit register. @@ -351,8 +400,10 @@ class ldmsbi(base.ReadMemoryInstruction, base.VectorInstruction): """ code = opcodes['LDMSBI'] arg_format = ['sbw','ci'] + direct = staticmethod(ldmsb) -class stmsbi(base.WriteMemoryInstruction, base.VectorInstruction): +class stmsbi(base.WriteMemoryInstruction, base.VectorInstruction, + base.IndirectMemoryInstruction): """ Copy secret bit register to secret bit memory cell with run-time address. @@ -361,8 +412,10 @@ class stmsbi(base.WriteMemoryInstruction, base.VectorInstruction): """ code = opcodes['STMSBI'] arg_format = ['sb','ci'] + direct = staticmethod(stmsb) -class ldmcbi(base.ReadMemoryInstruction, base.VectorInstruction): +class ldmcbi(base.ReadMemoryInstruction, base.VectorInstruction, + base.IndirectMemoryInstruction): """ Copy clear bit memory cell with run-time address to clear bit register. @@ -371,8 +424,10 @@ class ldmcbi(base.ReadMemoryInstruction, base.VectorInstruction): """ code = opcodes['LDMCBI'] arg_format = ['cbw','ci'] + direct = staticmethod(ldmcb) -class stmcbi(base.WriteMemoryInstruction, base.VectorInstruction): +class stmcbi(base.WriteMemoryInstruction, base.VectorInstruction, + base.IndirectMemoryInstruction): """ Copy clear bit register to clear bit memory cell with run-time address. @@ -381,6 +436,7 @@ class stmcbi(base.WriteMemoryInstruction, base.VectorInstruction): """ code = opcodes['STMCBI'] arg_format = ['cb','ci'] + direct = staticmethod(stmcb) class ldmsdi(base.ReadMemoryInstruction): code = opcodes['LDMSDI'] @@ -597,6 +653,7 @@ def dynamic_arg_format(cls, args): for i, n in cls.bases(args): yield 'int' yield 'p' + assert n > 3 for j in range(n - 3): yield 'sbw' yield 'int' diff --git a/Compiler/GC/program.py b/Compiler/GC/program.py old mode 100755 new mode 100644 diff --git a/Compiler/GC/types.py b/Compiler/GC/types.py old mode 100755 new mode 100644 index fdd9872..f70ee64 --- a/Compiler/GC/types.py +++ b/Compiler/GC/types.py @@ -17,6 +17,7 @@ import Compiler.GC.instructions as inst import operator import math +import itertools from functools import reduce class bits(Tape.Register, _structure, _bit): @@ -58,12 +59,12 @@ def compose(cls, items, bit_length=1): @classmethod def bit_compose(cls, bits): bits = list(bits) - if len(bits) == 1: + if len(bits) == 1 and isinstance(bits[0], cls): return bits[0] bits = list(bits) for i in range(len(bits)): if util.is_constant(bits[i]): - bits[i] = sbit(bits[i]) + bits[i] = cls.bit_type(bits[i]) res = cls.new(n=len(bits)) if len(bits) <= cls.unit: cls.bitcom(res, *(sbit.conv(bit) for bit in bits)) @@ -172,7 +173,7 @@ def load_other(self, other): else: try: bits = other.bit_decompose() - bits = bits[:self.n] + [sbit(0)] * (self.n - len(bits)) + bits = bits[:self.n] + [self.bit_type(0)] * (self.n - len(bits)) other = self.bit_compose(bits) assert(isinstance(other, type(self))) assert(other.n == self.n) @@ -197,6 +198,8 @@ def __and__(self, other): return 0 elif self.is_long_one(other): return self + elif isinstance(other, _vec): + return other & other.from_vec([self]) else: return self._and(other) @read_mem_value @@ -235,6 +238,18 @@ def if_else(self, x, y): This will output 1. """ return result_conv(x, y)(self & (x ^ y) ^ y) + def zero_if_not(self, condition): + if util.is_constant(condition): + return self * condition + else: + return self * cbit.conv(condition) + def expand(self, length): + if self.n in (length, None): + return self + elif self.n == 1: + return self.get_type(length).bit_compose([self] * length) + else: + raise CompilerError('cannot expand from %s to %s' % (self.n, length)) class cbits(bits): """ Clear bits register. Helper type with limited functionality. """ @@ -246,6 +261,7 @@ class cbits(bits): bitdec = inst.bitdecc conv_regint = staticmethod(lambda n, x, y: inst.convcint(x, y)) conv_cint_vec = inst.convcintvec + mov = staticmethod(lambda x, y: inst.addcbi(x, y, 0)) @classmethod def bit_compose(cls, bits): return sum(bit << i for i, bit in enumerate(bits)) @@ -258,7 +274,13 @@ def conv_regint_by_bit(cls, n, res, other): def conv(cls, other): if isinstance(other, cbits) and cls.n != None and \ cls.n // cls.unit == other.n // cls.unit: - return other + if isinstance(other, cls): + return other + else: + res = cls() + for i in range(math.ceil(cls.n / cls.unit)): + cls.mov(res[i], other[i]) + return res else: return super(cbits, cls).conv(other) types = {} @@ -289,8 +311,15 @@ def clear_op(self, other, c_inst, ci_inst, op): return op(self, cbits(other)) __add__ = lambda self, other: \ self.clear_op(other, inst.addcb, inst.addcbi, operator.add) - __sub__ = lambda self, other: \ - self.clear_op(-other, inst.addcb, inst.addcbi, operator.add) + def __sub__(self, other): + try: + return self + -other + except: + return type(self)(regint(self) - regint(other)) + def __rsub__(self, other): + return type(self)(other - regint(self)) + def __neg__(self): + return type(self)(-regint(self)) def _xor(self, other): if isinstance(other, (sbits, sbitvec)): return NotImplemented @@ -490,6 +519,8 @@ def __mul__(self, other): if isinstance(other, int): return self.mul_int(other) try: + if (self.n, other.n) == (1, 1): + return self & other if min(self.n, other.n) != 1: raise NotImplementedError('high order multiplication') n = max(self.n, other.n) @@ -581,7 +612,15 @@ def trans(cls, rows): rows = list(rows) if len(rows) == 1 and rows[0].n <= rows[0].unit: return rows[0].bit_decompose() - n_columns = rows[0].n + for row in rows: + try: + n_columns = row.n + break + except: + pass + for i in range(len(rows)): + if util.is_zero(rows[i]): + rows[i] = cls.get_type(n_columns)(0) for row in rows: assert(row.n == n_columns) if n_columns == 1 and len(rows) <= cls.unit: @@ -605,7 +644,7 @@ def bit_adder(*args, **kwargs): def ripple_carry_adder(*args, **kwargs): return sbitint.ripple_carry_adder(*args, **kwargs) -class sbitvec(_vec): +class sbitvec(_vec, _bit): """ Vector of registers of secret bits, effectively a matrix of secret bits. This facilitates parallel arithmetic operations in binary circuits. Container types are not supported, use :py:obj:`sbitvec.get_type` for that. @@ -613,7 +652,7 @@ class sbitvec(_vec): You can access the rows by member :py:obj:`v` and the columns by calling :py:obj:`elements`. - There are three ways to create an instance: + There are four ways to create an instance: 1. By transposition:: @@ -646,8 +685,14 @@ class sbitvec(_vec): This should output:: [1, 0, 1] + + 4. Private input:: + + x = sbitvec.get_type(32).get_input_from(player) + """ bit_extend = staticmethod(lambda v, n: v[:n] + [0] * (n - len(v))) + is_clear = False @classmethod def get_type(cls, n): """ Create type for fixed-length vector of registers of secret bits. @@ -661,6 +706,9 @@ def malloc(size, creator_tape=None): return sbit.malloc(size * n, creator_tape=creator_tape) @staticmethod def n_elements(): + return 1 + @staticmethod + def mem_size(): return n @classmethod def get_input_from(cls, player): @@ -680,10 +728,11 @@ def from_vec(cls, vector): res.v = _complement_two_extend(list(vector), n)[:n] return res def __init__(self, other=None, size=None): - assert size in (None, 1) if other is not None: if util.is_constant(other): - self.v = [sbit((other >> i) & 1) for i in range(n)] + t = sbits.get_type(size or 1) + self.v = [t(((other >> i) & 1) * ((1 << t.n) - 1)) + for i in range(n)] elif isinstance(other, _vec): self.v = self.bit_extend(other.v, n) elif isinstance(other, (list, tuple)): @@ -691,36 +740,41 @@ def __init__(self, other=None, size=None): else: self.v = sbits.get_type(n)(other).bit_decompose() assert len(self.v) == n + assert size is None or size == self.v[0].n @classmethod - def load_mem(cls, address): + def load_mem(cls, address, size=None): + if size not in (None, 1): + assert isinstance(address, int) or len(address) == 1 + sb = sbits.get_type(size) + return cls.from_vec(sb.bit_compose( + sbit.load_mem(address + i + j * n) for j in range(size)) + for i in range(n)) if not isinstance(address, int) and len(address) == n: return cls.from_vec(sbit.load_mem(x) for x in address) else: return cls.from_vec(sbit.load_mem(address + i) for i in range(n)) def store_in_mem(self, address): + size = 1 for x in self.v: - assert util.is_constant(x) or x.n == 1 - v = [sbit.conv(x) for x in self.v] + if not util.is_constant(x): + size = max(size, x.n) + v = [sbits.get_type(size).conv(x) for x in self.v] if not isinstance(address, int) and len(address) == n: + assert max_n == 1 for x, y in zip(v, address): x.store_in_mem(y) else: + assert isinstance(address, int) or len(address) == 1 for i in range(n): - v[i].store_in_mem(address + i) + for j, x in enumerate(v[i].bit_decompose()): + x.store_in_mem(address + i + j * n) def reveal(self): - if len(self) > cbits.unit: - return self.elements()[0].reveal() - revealed = [cbit() for i in range(len(self))] - for i in range(len(self)): - try: - inst.reveal(1, revealed[i], self.v[i]) - except: - revealed[i] = cbit.conv(self.v[i]) - return cbits.get_type(len(self)).bit_compose(revealed) + return util.untuplify([x.reveal() for x in self.elements()]) @classmethod - def two_power(cls, nn): - return cls.from_vec([0] * nn + [1] + [0] * (n - nn - 1)) + def two_power(cls, nn, size=1): + return cls.from_vec( + [0] * nn + [sbits.get_type(size)().long_one()] + [0] * (n - nn - 1)) def coerce(self, other): if util.is_constant(other): return self.from_vec(util.bit_decompose(other, n)) @@ -733,8 +787,12 @@ def bit_compose(cls, bits): bits += [0] * (n - len(bits)) assert len(bits) == n return cls.from_vec(bits) + def zero_if_not(self, condition): + return self.from_vec(x.zero_if_not(condition) for x in self.v) def __str__(self): return 'sbitvec(%d)' % n + sbitvecn.basic_type = sbitvecn + sbitvecn.reg_type = 'sb' return sbitvecn @classmethod def from_vec(cls, vector): @@ -802,16 +860,15 @@ def coerce(self, other): return other def __xor__(self, other): other = self.coerce(other) - return self.from_vec(x ^ y for x, y in zip(self.v, other)) + return self.from_vec(x ^ y for x, y in zip(*self.expand(other))) def __and__(self, other): - return self.from_vec(x & y for x, y in zip(self.v, other.v)) + return self.from_vec(x & y for x, y in zip(*self.expand(other))) + __rxor__ = __xor__ + __rand__ = __and__ + def __invert__(self): + return self.from_vec(~x for x in self.v) def if_else(self, x, y): - assert(len(self.v) == 1) - try: - return self.from_vec(util.if_else(self.v[0], a, b) \ - for a, b in zip(x, y)) - except: - return util.if_else(self.v[0], x, y) + return util.if_else(self.v[0], x, y) def __iter__(self): return iter(self.v) def __len__(self): @@ -824,6 +881,7 @@ def conv(cls, other): return cls.from_vec(other.v) else: return cls(other) + hard_conv = conv @property def size(self): if not self.v or util.is_constant(self.v[0]): @@ -853,6 +911,34 @@ def half_adder(self, other): def __mul__(self, other): if isinstance(other, int): return self.from_vec(x * other for x in self.v) + if isinstance(other, sbitvec): + if len(other.v) == 1: + other = other.v[0] + elif len(self.v) == 1: + self, other = other, self.v[0] + else: + raise CompilerError('no operand of lenght 1: %d/%d', + (len(self.v), len(other.v))) + if not isinstance(other, sbits): + return NotImplemented + ops = [] + for x in self.v: + if not util.is_zero(x): + assert x.n == other.n + ops.append(x) + if ops: + prods = [sbits.get_type(other.n)() for i in ops] + inst.andrsvec(3 + 2 * len(ops), other.n, *prods, other, *ops) + res = [] + i = 0 + for x in self.v: + if util.is_zero(x): + res.append(0) + else: + res.append(prods[i]) + i += 1 + return sbitvec.from_vec(res) + __rmul__ = __mul__ def __add__(self, other): return self.from_vec(x + y for x, y in zip(self.v, other)) def bit_and(self, other): @@ -861,6 +947,46 @@ def bit_xor(self, other): return self ^ other def right_shift(self, m, k, security=None, signed=True): return self.from_vec(self.v[m:]) + def tree_reduce(self, function): + elements = self.elements() + while len(elements) > 1: + size = len(elements) + half = size // 2 + left = elements[:half] + right = elements[half:2*half] + odd = elements[2*half:] + sides = [self.from_vec(sbitvec(x).v) for x in (left, right)] + red = function(*sides) + elements = red.elements() + elements += odd + return self.from_vec(sbitvec(elements).v) + @classmethod + def comp_result(cls, x): + return cls.get_type(1).from_vec([x]) + def expand(self, other, expand=True): + m = 1 + for x in itertools.chain(self.v, other.v if isinstance(other, sbitvec) else []): + try: + m = max(m, x.n) + except: + pass + res = [] + if not util.is_constant(other): + other = self.coerce(other) + for y in self, other: + if isinstance(y, int): + res.append([x * sbits.get_type(m)().long_one() + for x in util.bit_decompose(y, len(self.v))]) + else: + res.append([x.expand(m) if (expand and isinstance(x, bits)) else x for x in y.v]) + return res + def demux(self): + if len(self) == 1: + return sbitvec.from_vec([self.v[0].bit_not(), self.v[0]]) + a = sbitvec.from_vec(self.v[:len(self) // 2]).demux() + b = sbitvec.from_vec(self.v[len(self) // 2:]).demux() + prod = [a * bb for bb in b.v] + return sbitvec.from_vec(reduce(operator.add, (x.v for x in prod))) class bit(object): n = 1 @@ -914,8 +1040,8 @@ class cbit(bit, cbits): sbits.default_type = sbits class bitsBlock(oram.Block): - value_type = sbits def __init__(self, value, start, lengths, entries_per_block): + self.value_type = type(value) oram.Block.__init__(self, value, lengths) length = sum(self.lengths) used_bits = entries_per_block * length @@ -960,7 +1086,10 @@ def _store(self, value, address): cbits.dynamic_array = Array def _complement_two_extend(bits, k): - return bits[:k] + [bits[-1]] * (k - len(bits)) + if len(bits) == 1: + return bits + [0] * (k - len(bits)) + else: + return bits[:k] + [bits[-1]] * (k - len(bits)) class _sbitintbase: def extend(self, n): @@ -1110,7 +1239,7 @@ def pow2(self, k): :param k: bit length of input """ return _sbitintbase.pow2(self, k) -class sbitintvec(sbitvec, _number, _bitint, _sbitintbase): +class sbitintvec(sbitvec, _bitint, _number, _sbitintbase): """ Vector of signed integers for parallel binary computation:: @@ -1145,19 +1274,34 @@ def elements(self): def __add__(self, other): if util.is_zero(other): return self - other = self.coerce(other) - assert(len(self.v) == len(other.v)) - v = sbitint.bit_adder(self.v, other.v) - return self.from_vec(v) + a, b = self.expand(other) + v = sbitint.bit_adder(a, b) + return self.get_type(len(v)).from_vec(v) __radd__ = __add__ + __sub__ = _bitint.__sub__ + def __rsub__(self, other): + a, b = self.expand(other) + return self.from_vec(b) - self.from_vec(a) def __mul__(self, other): if isinstance(other, sbits): return self.from_vec(other * x for x in self.v) + elif len(self.v) == 1: + return other * self.v[0] elif isinstance(other, sbitfixvec): return NotImplemented + my_bits, other_bits = self.expand(other, False) matrix = [] - for i, b in enumerate(util.bit_decompose(other)): - matrix.append([x & b for x in self.v[:len(self.v)-i]]) + m = float('inf') + for x in itertools.chain(my_bits, other_bits): + try: + m = min(m, x.n) + except: + pass + for i, b in enumerate(other_bits): + if m == 1: + matrix.append([x * b for x in my_bits[:len(self.v)-i]]) + else: + matrix.append((sbitvec.from_vec(my_bits[:len(self.v)-i]) * b).v) v = sbitint.wallace_tree_from_matrix(matrix) return self.from_vec(v[:len(self.v)]) __rmul__ = __mul__ @@ -1188,6 +1332,8 @@ class cbitfix(object): store_in_mem = lambda self, *args: self.v.store_in_mem(*args) @classmethod def _new(cls, value): + if isinstance(value, list): + return [cls._new(x) for x in value] res = cls() if cls.k < value.unit: bits = value.bit_decompose(cls.k) @@ -1265,7 +1411,7 @@ class cls(_fix): cls.set_precision(f, k) return cls._new(cls.int_type(other), k, f) -class sbitfixvec(_fix): +class sbitfixvec(_fix, _vec): """ Vector of fixed-point numbers for parallel binary computation. Use :py:obj:`set_precision()` to change the precision. @@ -1302,14 +1448,19 @@ def set_precision(cls, f, k=None): super(sbitfixvec, cls).set_precision(f=f, k=k) cls.int_type = sbitintvec.get_type(cls.k) @classmethod - def get_input_from(cls, player): + def get_input_from(cls, player, size=1): """ Secret input from :py:obj:`player`. :param: player (int) """ - v = [sbit() for i in range(sbitfix.k)] + v = [0] * sbitfix.k sbits._check_input_player(player) - inst.inputbvec(len(v) + 3, sbitfix.f, player, *v) + for i in range(size): + vv = [sbit() for i in range(sbitfix.k)] + inst.inputbvec(len(v) + 3, sbitfix.f, player, *vv) + for j in range(sbitfix.k): + tmp = vv[j] << i + v[j] = tmp ^ v[j] return cls._new(cls.int_type.from_vec(v)) def __init__(self, value=None, *args, **kwargs): if isinstance(value, (list, tuple)): diff --git a/Compiler/__init__.py b/Compiler/__init__.py old mode 100755 new mode 100644 index 9a22da4..6a0d6b1 --- a/Compiler/__init__.py +++ b/Compiler/__init__.py @@ -2,30 +2,3 @@ from .GC import types as GC_types import inspect from .config import * -from .compilerLib import run - - -# add all instructions to the program VARS dictionary -compilerLib.VARS = {} -instr_classes = [t[1] for t in inspect.getmembers(instructions, inspect.isclass)] - -for mod in (types, GC_types): - instr_classes += [t[1] for t in inspect.getmembers(mod, inspect.isclass)\ - if t[1].__module__ == mod.__name__] - -instr_classes += [t[1] for t in inspect.getmembers(library, inspect.isfunction)\ - if t[1].__module__ == library.__name__] - -for op in instr_classes: - compilerLib.VARS[op.__name__] = op - -# add open and input separately due to name conflict -compilerLib.VARS['open'] = instructions.asm_open -compilerLib.VARS['vopen'] = instructions.vasm_open -compilerLib.VARS['gopen'] = instructions.gasm_open -compilerLib.VARS['vgopen'] = instructions.vgasm_open -compilerLib.VARS['input'] = instructions.asm_input -compilerLib.VARS['ginput'] = instructions.gasm_input - -compilerLib.VARS['comparison'] = comparison -compilerLib.VARS['floatingpoint'] = floatingpoint diff --git a/Compiler/allocator.py b/Compiler/allocator.py old mode 100755 new mode 100644 index bf431ca..e5c99a7 --- a/Compiler/allocator.py +++ b/Compiler/allocator.py @@ -261,6 +261,7 @@ def longest_paths_merge(self): instructions = self.instructions merge_nodes = self.open_nodes depths = self.depths + self.req_num = defaultdict(lambda: 0) if not merge_nodes: return 0 @@ -281,6 +282,7 @@ def longest_paths_merge(self): print('Merging %d %s in round %d/%d' % \ (len(merge), t.__name__, i, len(merges))) self.do_merge(merge) + self.req_num[t.__name__, 'round'] += 1 preorder = None @@ -530,7 +532,9 @@ def eliminate_dead_code(self): can_eliminate_defs = True for reg in inst.get_def(): for dup in reg.duplicates: - if not dup.can_eliminate: + if not (dup.can_eliminate and reduce( + operator.and_, + (x.can_eliminate for x in dup.vector), True)): can_eliminate_defs = False break # remove if instruction has result that isn't used diff --git a/Compiler/circuit.py b/Compiler/circuit.py old mode 100755 new mode 100644 index 9c4187f..92bd84e --- a/Compiler/circuit.py +++ b/Compiler/circuit.py @@ -137,8 +137,6 @@ def sha3_256(x): 0x4a43f8804b0ad882fa493be44dff80f562d661a05647c15166d71ebff8c6ffa7 0xf0d7aa0ab2d92d580bb080e17cbb52627932ba37f085d3931270d31c39357067 - Note that :py:obj:`sint` to :py:obj:`sbitvec` conversion is only - implemented for computation modulo a power of two. """ global Keccak_f @@ -236,10 +234,10 @@ def circuit(cls, name): return cls._circuits[name] def __init__(self, value): - if isinstance(value, sbitvec): + if isinstance(value, (sbitint, sbitintvec)): + self.value = self.circuit('i2f')(sbitvec.conv(value)) + elif isinstance(value, sbitvec): self.value = value - elif isinstance(value, (sbitint, sbitintvec)): - self.value = self.circuit('i2f')(sbitvec(value)) elif util.is_constant_float(value): self.value = sbitvec(sbits.get_type(64)( struct.unpack('Q', struct.pack('d', value))[0])) diff --git a/Compiler/circuit_oram.py b/Compiler/circuit_oram.py old mode 100755 new mode 100644 index f5ddebf..a2cada5 --- a/Compiler/circuit_oram.py +++ b/Compiler/circuit_oram.py @@ -1,5 +1,6 @@ -from Compiler.path_oram import * +from Compiler.oram import * +from Compiler.path_oram import PathORAM, XOR from Compiler.util import bit_compose def first_diff(a_bits, b_bits): diff --git a/Compiler/comparison.py b/Compiler/comparison.py old mode 100755 new mode 100644 index 23bee21..1a139ef --- a/Compiler/comparison.py +++ b/Compiler/comparison.py @@ -87,15 +87,14 @@ def LtzRing(a, k): carry = CarryOutRawLE(*reversed(list(x[:-1] for x in summands))) msb = carry ^ summands[0][-1] ^ summands[1][-1] return sint.conv(msb) - return - elif program.options.ring: + else: from . import floatingpoint require_ring_size(k, 'comparison') m = k - 1 shift = int(program.options.ring) - k r_prime, r_bin = MaskingBitsInRing(k) tmp = a - r_prime - c_prime = (tmp << shift).reveal() >> shift + c_prime = (tmp << shift).reveal(False) >> shift a = r_bin[0].bit_decompose_clear(c_prime, m) b = r_bin[:m] u = CarryOutRaw(a[::-1], b[::-1]) @@ -190,7 +189,7 @@ def TruncLeakyInRing(a, k, m, signed): r = sint.bit_compose(r_bits) if signed: a += (1 << (k - 1)) - shifted = ((a << (n_shift - m)) + (r << n_shift)).reveal() + shifted = ((a << (n_shift - m)) + (r << n_shift)).reveal(False) masked = shifted >> n_shift u = sint() BitLTL(u, masked, r_bits[:n_bits], 0) @@ -231,7 +230,7 @@ def Mod2mRing(a_prime, a, k, m, signed): shift = int(program.options.ring) - m r_prime, r_bin = MaskingBitsInRing(m, True) tmp = a + r_prime - c_prime = (tmp << shift).reveal() >> shift + c_prime = (tmp << shift).reveal(False) >> shift u = sint() BitLTL(u, c_prime, r_bin[:m], 0) res = (u << m) + c_prime - r_prime @@ -261,7 +260,7 @@ def Mod2mField(a_prime, a, k, m, kappa, signed): t[1] = a adds(t[2], t[0], t[1]) adds(t[3], t[2], r_prime) - asm_open(c, t[3]) + asm_open(True, c, t[3]) modc(c_prime, c, c2m) if const_rounds: BitLTC1(u, c_prime, r, kappa) @@ -510,7 +509,7 @@ def PreMulC_with_inverses_and_vectors(p, a): movs(w[0], r[0]) movs(a_vec[0], a[0]) vmuls(k, t[0], w, a_vec) - vasm_open(k, m, t[0]) + vasm_open(k, True, m, t[0]) PreMulC_end(p, a, c, m, z) def PreMulC_with_inverses(p, a): @@ -538,7 +537,7 @@ def PreMulC_with_inverses(p, a): w[1][0] = r[0][0] for i in range(k): muls(t[0][i], w[1][i], a[i]) - asm_open(m[i], t[0][i]) + asm_open(True, m[i], t[0][i]) PreMulC_end(p, a, c, m, z) def PreMulC_without_inverses(p, a): @@ -563,7 +562,7 @@ def PreMulC_without_inverses(p, a): #adds(tt[0][i], t[0][i], a[i]) #subs(tt[1][i], tt[0][i], a[i]) #startopen(tt[1][i]) - asm_open(u[i], t[0][i]) + asm_open(True, u[i], t[0][i]) for i in range(k-1): muls(v[i], r[i+1], s[i]) w[0] = r[0] @@ -579,7 +578,7 @@ def PreMulC_without_inverses(p, a): mulm(z[i], s[i], u_inv[i]) for i in range(k): muls(t[1][i], w[i], a[i]) - asm_open(m[i], t[1][i]) + asm_open(True, m[i], t[1][i]) PreMulC_end(p, a, c, m, z) def PreMulC_end(p, a, c, m, z): @@ -637,6 +636,7 @@ def Mod2(a_0, a, k, kappa, signed): t = [program.curr_block.new_reg('s') for i in range(6)] c2k1 = program.curr_block.new_reg('c') PRandM(r_dprime, r_prime, [r_0], k, 1, kappa) + r_0 = r_prime mulsi(t[0], r_dprime, 2) if signed: ld2i(c2k1, k - 1) @@ -645,7 +645,7 @@ def Mod2(a_0, a, k, kappa, signed): t[1] = a adds(t[2], t[0], t[1]) adds(t[3], t[2], r_prime) - asm_open(c, t[3]) + asm_open(True, c, t[3]) from . import floatingpoint c_0 = floatingpoint.bits(c, 1)[0] mulci(tc, c_0, 2) diff --git a/Compiler/compilerLib.py b/Compiler/compilerLib.py old mode 100755 new mode 100644 index b2898e2..462a5d1 --- a/Compiler/compilerLib.py +++ b/Compiler/compilerLib.py @@ -1,94 +1,405 @@ -from Compiler.program import Program +import inspect +import os +import re +import sys +import tempfile +from optparse import OptionParser + +from Compiler.exceptions import CompilerError + from .GC import types as GC_types +from .program import Program, defaults -import sys -import re, tempfile, os - - -def run(args, options): - """ Compile a file and output a Program object. - - If options.merge_opens is set to True, will attempt to merge any - parallelisable open instructions. """ - - prog = Program(args, options) - VARS['program'] = prog - if options.binary: - VARS['sint'] = GC_types.sbitintvec.get_type(int(options.binary)) - VARS['sfix'] = GC_types.sbitfixvec - for i in 'cint', 'cfix', 'cgf2n', 'sintbit', 'sgf2n', 'sgf2nint', \ - 'sgf2nuint', 'sgf2nuint32', 'sgf2nfloat', 'sfloat', 'cfloat', \ - 'squant': - del VARS[i] - - print('Compiling file', prog.infile) - f = open(prog.infile, 'rb') - - changed = False - if options.flow_optimization: - output = [] - if_stack = [] - for line in open(prog.infile): - if if_stack and not re.match(if_stack[-1][0], line): - if_stack.pop() - m = re.match( - '(\s*)for +([a-zA-Z_]+) +in +range\(([0-9a-zA-Z_]+)\):', - line) - if m: - output.append('%s@for_range_opt(%s)\n' % (m.group(1), - m.group(3))) - output.append('%sdef _(%s):\n' % (m.group(1), m.group(2))) - changed = True - continue - m = re.match('(\s*)if(\W.*):', line) - if m: - if_stack.append((m.group(1), len(output))) - output.append('%s@if_(%s)\n' % (m.group(1), m.group(2))) - output.append('%sdef _():\n' % (m.group(1))) - changed = True - continue - m = re.match('(\s*)elif\s+', line) - if m: - raise CompilerError('elif not supported') - if if_stack: - m = re.match('%selse:' % if_stack[-1][0], line) - if m: - start = if_stack[-1][1] - ws = if_stack[-1][0] - output[start] = re.sub(r'^%s@if_\(' % ws, r'%s@if_e(' % ws, - output[start]) - output.append('%s@else_\n' % ws) - output.append('%sdef _():\n' % ws) - continue - output.append(line) - if changed: - infile = tempfile.NamedTemporaryFile('w+', delete=False) - for line in output: - infile.write(line) - infile.seek(0) + +class Compiler: + def __init__(self, custom_args=None, usage=None): + if usage: + self.usage = usage else: - infile = open(prog.infile) - else: - infile = open(prog.infile) + self.usage = "usage: %prog [options] filename [args]" + self.custom_args = custom_args + self.build_option_parser() + self.VARS = {} + + def build_option_parser(self): + parser = OptionParser(usage=self.usage) + parser.add_option( + "-n", + "--nomerge", + action="store_false", + dest="merge_opens", + default=defaults.merge_opens, + help="don't attempt to merge open instructions", + ) + parser.add_option("-o", "--output", dest="outfile", help="specify output file") + parser.add_option( + "-a", + "--asm-output", + dest="asmoutfile", + help="asm output file for debugging", + ) + parser.add_option( + "-g", + "--galoissize", + dest="galois", + default=defaults.galois, + help="bit length of Galois field", + ) + parser.add_option( + "-d", + "--debug", + action="store_true", + dest="debug", + help="keep track of trace for debugging", + ) + parser.add_option( + "-c", + "--comparison", + dest="comparison", + default="log", + help="comparison variant: log|plain|inv|sinv", + ) + parser.add_option( + "-M", + "--preserve-mem-order", + action="store_true", + dest="preserve_mem_order", + default=defaults.preserve_mem_order, + help="preserve order of memory instructions; possible efficiency loss", + ) + parser.add_option( + "-O", + "--optimize-hard", + action="store_true", + dest="optimize_hard", + help="currently not in use", + ) + parser.add_option( + "-u", + "--noreallocate", + action="store_true", + dest="noreallocate", + default=defaults.noreallocate, + help="don't reallocate", + ) + parser.add_option( + "-m", + "--max-parallel-open", + dest="max_parallel_open", + default=defaults.max_parallel_open, + help="restrict number of parallel opens", + ) + parser.add_option( + "-D", + "--dead-code-elimination", + action="store_true", + dest="dead_code_elimination", + default=defaults.dead_code_elimination, + help="eliminate instructions with unused result", + ) + parser.add_option( + "-p", + "--profile", + action="store_true", + dest="profile", + help="profile compilation", + ) + parser.add_option( + "-s", + "--stop", + action="store_true", + dest="stop", + help="stop on register errors", + ) + parser.add_option( + "-R", + "--ring", + dest="ring", + default=defaults.ring, + help="bit length of ring (default: 0 for field)", + ) + parser.add_option( + "-B", + "--binary", + dest="binary", + default=defaults.binary, + help="bit length of sint in binary circuit (default: 0 for arithmetic)", + ) + parser.add_option( + "-G", + "--garbled-circuit", + dest="garbled", + action="store_true", + help="compile for binary circuits only (default: false)", + ) + parser.add_option( + "-F", + "--field", + dest="field", + default=defaults.field, + help="bit length of sint modulo prime (default: 64)", + ) + parser.add_option( + "-P", + "--prime", + dest="prime", + default=defaults.prime, + help="prime modulus (default: not specified)", + ) + parser.add_option( + "-I", + "--insecure", + action="store_true", + dest="insecure", + help="activate insecure functionality for benchmarking", + ) + parser.add_option( + "-b", + "--budget", + dest="budget", + default=defaults.budget, + help="set budget for optimized loop unrolling " "(default: 100000)", + ) + parser.add_option( + "-X", + "--mixed", + action="store_true", + dest="mixed", + help="mixing arithmetic and binary computation", + ) + parser.add_option( + "-Y", + "--edabit", + action="store_true", + dest="edabit", + help="mixing arithmetic and binary computation using edaBits", + ) + parser.add_option( + "-Z", + "--split", + default=defaults.split, + dest="split", + help="mixing arithmetic and binary computation " + "using direct conversion if supported " + "(number of parties as argument)", + ) + parser.add_option( + "--invperm", + action="store_true", + dest="invperm", + help="speedup inverse permutation (only use in two-party, " + "semi-honest environment)" + ) + parser.add_option( + "-C", + "--CISC", + action="store_true", + dest="cisc", + help="faster CISC compilation mode", + ) + parser.add_option( + "-K", + "--keep-cisc", + dest="keep_cisc", + help="don't translate CISC instructions", + ) + parser.add_option( + "-l", + "--flow-optimization", + action="store_true", + dest="flow_optimization", + help="optimize control flow", + ) + parser.add_option( + "-v", + "--verbose", + action="store_true", + dest="verbose", + help="more verbose output", + ) + self.parser = parser + + def parse_args(self): + self.options, self.args = self.parser.parse_args(self.custom_args) + if self.options.optimize_hard: + print("Note that -O/--optimize-hard currently has no effect") + + def build_program(self, name=None): + self.prog = Program(self.args, self.options, name=name) + + def build_vars(self): + from . import comparison, floatingpoint, instructions, library, types + + # add all instructions to the program VARS dictionary + instr_classes = [ + t[1] for t in inspect.getmembers(instructions, inspect.isclass) + ] + + for mod in (types, GC_types): + instr_classes += [ + t[1] + for t in inspect.getmembers(mod, inspect.isclass) + if t[1].__module__ == mod.__name__ + ] + + instr_classes += [ + t[1] + for t in inspect.getmembers(library, inspect.isfunction) + if t[1].__module__ == library.__name__ + ] + + for op in instr_classes: + self.VARS[op.__name__] = op + + # backward compatibility for deprecated classes + self.VARS["sbitint"] = GC_types.sbitintvec + self.VARS["sbitfix"] = GC_types.sbitfixvec + + # add open and input separately due to name conflict + self.VARS["vopen"] = instructions.vasm_open + self.VARS["gopen"] = instructions.gasm_open + self.VARS["vgopen"] = instructions.vgasm_open + self.VARS["ginput"] = instructions.gasm_input + + self.VARS["comparison"] = comparison + self.VARS["floatingpoint"] = floatingpoint + + self.VARS["program"] = self.prog + if self.options.binary: + self.VARS["sint"] = GC_types.sbitintvec.get_type(int(self.options.binary)) + self.VARS["sfix"] = GC_types.sbitfixvec + for i in [ + "cint", + "cfix", + "cgf2n", + "sintbit", + "sgf2n", + "sgf2nint", + "sgf2nuint", + "sgf2nuint32", + "sgf2nfloat", + "cfloat", + "squant", + ]: + del self.VARS[i] + + def prep_compile(self, name=None): + self.parse_args() + if len(self.args) < 1 and name is None: + self.parser.print_help() + exit(1) + self.build_program(name=name) + self.build_vars() + + def compile_file(self): + """Compile a file and output a Program object. + + If options.merge_opens is set to True, will attempt to merge any + parallelisable open instructions.""" + print("Compiling file", self.prog.infile) + + with open(self.prog.infile, "r") as f: + changed = False + if self.options.flow_optimization: + output = [] + if_stack = [] + for line in f: + if if_stack and not re.match(if_stack[-1][0], line): + if_stack.pop() + m = re.match( + r"(\s*)for +([a-zA-Z_]+) +in " r"+range\(([0-9a-zA-Z_]+)\):", + line, + ) + if m: + output.append( + "%s@for_range_opt(%s)\n" % (m.group(1), m.group(3)) + ) + output.append("%sdef _(%s):\n" % (m.group(1), m.group(2))) + changed = True + continue + m = re.match(r"(\s*)if(\W.*):", line) + if m: + if_stack.append((m.group(1), len(output))) + output.append("%s@if_(%s)\n" % (m.group(1), m.group(2))) + output.append("%sdef _():\n" % (m.group(1))) + changed = True + continue + m = re.match(r"(\s*)elif\s+", line) + if m: + raise CompilerError("elif not supported") + if if_stack: + m = re.match("%selse:" % if_stack[-1][0], line) + if m: + start = if_stack[-1][1] + ws = if_stack[-1][0] + output[start] = re.sub( + r"^%s@if_\(" % ws, r"%s@if_e(" % ws, output[start] + ) + output.append("%s@else_\n" % ws) + output.append("%sdef _():\n" % ws) + continue + output.append(line) + if changed: + infile = tempfile.NamedTemporaryFile("w+", delete=False) + for line in output: + infile.write(line) + infile.seek(0) + else: + infile = open(self.prog.infile) + else: + infile = open(self.prog.infile) + + # make compiler modules directly accessible + sys.path.insert(0, "Compiler") + # create the tapes + exec(compile(infile.read(), infile.name, "exec"), self.VARS) + + if changed and not self.options.debug: + os.unlink(infile.name) + + return self.finalize_compile() + + def register_function(self, name=None): + """ + To register a function to be compiled, use this as a decorator. + Example: + + @compiler.register_function('test-mpc') + def test_mpc(compiler): + ... + """ + + def inner(func): + self.compile_name = name or func.__name__ + self.compile_function = func + return func - # make compiler modules directly accessible - sys.path.insert(0, 'Compiler') - # create the tapes - exec(compile(infile.read(), infile.name, 'exec'), VARS) + return inner - if changed and not options.debug: - os.unlink(infile.name) + def compile_func(self): + if not (hasattr(self, "compile_name") and hasattr(self, "compile_func")): + raise CompilerError( + "No function to compile. " + "Did you decorate a function with @register_fuction(name)?" + ) + self.prep_compile(self.compile_name) + print( + "Compiling: {} from {}".format(self.compile_name, self.compile_func.__name__) + ) + self.compile_function() + self.finalize_compile() - prog.finalize() + def finalize_compile(self): + self.prog.finalize() - if prog.req_num: - print('Program requires at most:') - for x in prog.req_num.pretty(): - print(x) + if self.prog.req_num: + print("Program requires at most:") + for x in self.prog.req_num.pretty(): + print(x) - if prog.verbose: - print('Program requires:', repr(prog.req_num)) - print('Cost:', 0 if prog.req_num is None else prog.req_num.cost()) - print('Memory size:', dict(prog.allocated_mem)) + if self.prog.verbose: + print("Program requires:", repr(self.prog.req_num)) + print("Cost:", 0 if self.prog.req_num is None else self.prog.req_num.cost()) + print("Memory size:", dict(self.prog.allocated_mem)) - return prog + return self.prog diff --git a/Compiler/config.py b/Compiler/config.py old mode 100755 new mode 100644 diff --git a/Compiler/dijkstra.py b/Compiler/dijkstra.py old mode 100755 new mode 100644 index 45d25e6..fd57e1b --- a/Compiler/dijkstra.py +++ b/Compiler/dijkstra.py @@ -99,7 +99,7 @@ def bubble_up(self, start): bits.reverse() bits = [0] + floatingpoint.PreOR(bits, self.levels) bits = [bits[i+1] - bits[i] for i in range(self.levels)] - shift = sum([bit << i for i,bit in enumerate(bits)]) + shift = self.int_type.bit_compose(bits) childpos = MemValue(start * shift) @for_range(self.levels - 1) def f(i): @@ -215,12 +215,13 @@ def dump(self, msg=''): print_ln() print_ln() -def dijkstra(source, edges, e_index, oram_type, n_loops=None, int_type=sint): - basic_type = int_type.basic_type +def dijkstra(source, edges, e_index, oram_type, n_loops=None, int_type=None): vert_loops = n_loops * e_index.size // edges.size \ if n_loops else -1 dist = oram_type(e_index.size, entry_size=(32,log2(e_index.size)), \ - init_rounds=vert_loops, value_type=basic_type) + init_rounds=vert_loops, value_type=int_type) + int_type = dist.value_type + basic_type = int_type.basic_type #visited = ORAM(e_index.size) #previous = oram_type(e_index.size) Q = HeapQ(e_index.size, oram_type, init_rounds=vert_loops, \ @@ -240,7 +241,7 @@ def dijkstra(source, edges, e_index, oram_type, n_loops=None, int_type=sint): u = MemValue(basic_type(0)) @for_range(n_loops or edges.size) def f(i): - cint(i).print_reg('loop') + print_ln('loop %s', i) time() u.write(if_else(last_edge, Q.pop(last_edge), u)) #visited.access(u, True, last_edge) diff --git a/Compiler/exceptions.py b/Compiler/exceptions.py old mode 100755 new mode 100644 diff --git a/Compiler/floatingpoint.py b/Compiler/floatingpoint.py old mode 100755 new mode 100644 index c596240..7786f73 --- a/Compiler/floatingpoint.py +++ b/Compiler/floatingpoint.py @@ -28,13 +28,15 @@ def shift_two(n, pos): def maskRing(a, k): shift = int(program.Program.prog.options.ring) - k - if program.Program.prog.use_dabit: + if program.Program.prog.use_edabit(): + r_prime, r = types.sint.get_edabit(k) + elif program.Program.prog.use_dabit: rr, r = zip(*(types.sint.get_dabit() for i in range(k))) r_prime = types.sint.bit_compose(rr) else: r = [types.sint.get_random_bit() for i in range(k)] r_prime = types.sint.bit_compose(r) - c = ((a + r_prime) << shift).reveal() >> shift + c = ((a + r_prime) << shift).reveal(False) >> shift return c, r def maskField(a, k, kappa): @@ -45,7 +47,7 @@ def maskField(a, k, kappa): comparison.PRandM(r_dprime, r_prime, r, k, k, kappa) # always signed due to usage in equality testing a += two_power(k) - asm_open(c, a + two_power(k) * r_dprime + r_prime) + asm_open(True, c, a + two_power(k) * r_dprime + r_prime) return c, r @instructions_base.ret_cisc @@ -231,7 +233,7 @@ def Inv(a): ldi(one, 1) inverse(t[0], t[1]) s = t[0]*a - asm_open(c[0], s) + asm_open(True, c[0], s) # avoid division by zero for benchmarking divc(c[1], one, c[0]) #divc(c[1], c[0], one) @@ -279,7 +281,7 @@ def BitDecRingRaw(a, k, m): else: r_bits = [types.sint.get_random_bit() for i in range(m)] r = types.sint.bit_compose(r_bits) - shifted = ((a - r) << n_shift).reveal() + shifted = ((a - r) << n_shift).reveal(False) masked = shifted >> n_shift bits = r_bits[0].bit_adder(r_bits, masked.bit_decompose(m)) return bits @@ -297,7 +299,7 @@ def BitDecFieldRaw(a, k, m, kappa, bits_to_compute=None): r = [types.sint() for i in range(m)] comparison.PRandM(r_dprime, r_prime, r, k, m, kappa) pow2 = two_power(k + kappa) - asm_open(c, pow2 + two_power(k) + a - two_power(m)*r_dprime - r_prime) + asm_open(True, c, pow2 + two_power(k) + a - two_power(m)*r_dprime - r_prime) res = r[0].bit_adder(r, list(r[0].bit_decompose_clear(c,m))) instructions_base.reset_global_vector_size() return res @@ -309,6 +311,7 @@ def BitDecField(a, k, m, kappa, bits_to_compute=None): @instructions_base.ret_cisc def Pow2(a, l, kappa): + comparison.program.curr_tape.require_bit_length(l - 1) m = int(ceil(log(l, 2))) t = BitDec(a, m, m, kappa) return Pow2_from_bits(t) @@ -339,10 +342,10 @@ def B2U_from_Pow2(pow2a, l, kappa): if program.Program.prog.options.ring: n_shift = int(program.Program.prog.options.ring) - l assert n_shift > 0 - c = ((pow2a + types.sint.bit_compose(r)) << n_shift).reveal() >> n_shift + c = ((pow2a + types.sint.bit_compose(r)) << n_shift).reveal(False) >> n_shift else: comparison.PRandInt(t, kappa) - asm_open(c, pow2a + two_power(l) * t + + asm_open(True, c, pow2a + two_power(l) * t + sum(two_power(i) * r[i] for i in range(l))) comparison.program.curr_tape.require_bit_length(l + kappa) c = list(r_bits[0].bit_decompose_clear(c, l)) @@ -384,11 +387,11 @@ def Trunc(a, l, m, kappa=None, compute_modulo=False, signed=False): r_dprime += t1 - t2 if program.Program.prog.options.ring: n_shift = int(program.Program.prog.options.ring) - l - c = ((a + r_dprime + r_prime) << n_shift).reveal() >> n_shift + c = ((a + r_dprime + r_prime) << n_shift).reveal(False) >> n_shift else: comparison.PRandInt(rk, kappa) r_dprime += two_power(l) * rk - asm_open(c, a + r_dprime + r_prime) + asm_open(True, c, a + r_dprime + r_prime) for i in range(1,l): ci[i] = c % two_power(i) c_dprime = sum(ci[i]*(x[i-1] - x[i]) for i in range(1,l)) @@ -414,7 +417,7 @@ def TruncInRing(to_shift, l, pow2m): rev *= pow2m r_bits = [types.sint.get_random_bit() for i in range(l)] r = types.sint.bit_compose(r_bits) - shifted = (rev - (r << n_shift)).reveal() + shifted = (rev - (r << n_shift)).reveal(False) masked = shifted >> n_shift bits = types.intbitint.bit_adder(r_bits, masked.bit_decompose(l)) return types.sint.bit_compose(reversed(bits)) @@ -455,7 +458,7 @@ def Int2FL(a, gamma, l, kappa=None): v = t.right_shift(gamma - l - 1, gamma - 1, kappa, signed=False) else: v = 2**(l-gamma+1) * t - p = (p + gamma - 1 - l) * (1 -z) + p = (p + gamma - 1 - l) * z.bit_not() return v, p, z, s def FLRound(x, mode): @@ -528,7 +531,7 @@ def TruncPrRing(a, k, m, signed=True): msb = r_bits[-1] n_shift = n_ring - (k + 1) tmp = a + r - masked = (tmp << n_shift).reveal() + masked = (tmp << n_shift).reveal(False) shifted = (masked << 1 >> (n_shift + m + 1)) overflow = msb.bit_xor(masked >> (n_ring - 1)) res = shifted - upper + \ @@ -549,7 +552,7 @@ def TruncPrField(a, k, m, kappa=None): k, m, kappa, use_dabit=False) two_to_m = two_power(m) r = two_to_m * r_dprime + r_prime - c = (b + r).reveal() + c = (b + r).reveal(False) c_prime = c % two_to_m a_prime = c_prime - r_prime d = (a - a_prime) / two_to_m @@ -665,14 +668,14 @@ def get_bits_loop(): def _(): for i in range(bit_length): tbits[j][i].link(sint.get_random_bit()) - c = regint(BITLT(tbits[j], pbits, bit_length).reveal()) + c = regint(BITLT(tbits[j], pbits, bit_length).reveal(False)) done[j].link(c) return (sum(done) != a.size) for j in range(a.size): for i in range(bit_length): movs(bbits[i][j], tbits[j][i]) b = sint.bit_compose(bbits) - c = (a-b).reveal() + c = (a-b).reveal(False) cmodp = c t = bbits[0].bit_decompose_clear(p - c, bit_length) c = longint(c, bit_length) diff --git a/Compiler/graph.py b/Compiler/graph.py old mode 100755 new mode 100644 diff --git a/Compiler/gs.py b/Compiler/gs.py old mode 100755 new mode 100644 diff --git a/Compiler/instructions.py b/Compiler/instructions.py old mode 100755 new mode 100644 index 5f5b82d..c513183 --- a/Compiler/instructions.py +++ b/Compiler/instructions.py @@ -387,6 +387,14 @@ class use(base.Instruction): code = base.opcodes['USE'] arg_format = ['int','int','int'] + @classmethod + def get_usage(cls, args): + from .program import field_types, data_types + from .util import find_in_dict + return {(find_in_dict(field_types, args[0].i), + find_in_dict(data_types, args[1].i)): + args[2].i} + class use_inp(base.Instruction): """ Input usage. Necessary to avoid reusage while using preprocessing from files. @@ -398,6 +406,13 @@ class use_inp(base.Instruction): code = base.opcodes['USE_INP'] arg_format = ['int','int','int'] + @classmethod + def get_usage(cls, args): + from .program import field_types, data_types + from .util import find_in_dict + return {(find_in_dict(field_types, args[0].i), 'input', args[1].i): + args[2].i} + class use_edabit(base.Instruction): """ edaBit usage. Necessary to avoid reusage while using preprocessing from files. Also used to multithreading for expensive @@ -410,6 +425,10 @@ class use_edabit(base.Instruction): code = base.opcodes['USE_EDABIT'] arg_format = ['int','int','int'] + @classmethod + def get_usage(cls, args): + return {('sedabit' if args[0].i else 'edabit', args[1].i): args[2].i} + class use_matmul(base.Instruction): """ Matrix multiplication usage. Used for multithreading of preprocessing. @@ -471,6 +490,11 @@ class use_prep(base.Instruction): code = base.opcodes['USE_PREP'] arg_format = ['str','int'] + @classmethod + def get_usage(cls, args): + return {('gf2n' if cls.__name__ == 'guse_prep' else 'modp', + args[0].str): args[1].i} + class nplayers(base.Instruction): """ Store number of players in clear integer register. @@ -590,6 +614,18 @@ class submr(base.SubBase): code = base.opcodes['SUBMR'] arg_format = ['sw','c','s'] +@base.vectorize +class prefixsums(base.Instruction): + """ Prefix sum. + + :param: result (sint) + :param: input (sint) + + """ + __slots__ = [] + code = base.opcodes['PREFIXSUMS'] + arg_format = ['sw','s'] + @base.gf2n @base.vectorize class mulc(base.MulBase): @@ -783,30 +819,6 @@ def has_var_args(self): return True -### -### Special GF(2) arithmetic instructions -### - -@base.vectorize -class gmulbitc(base.MulBase): - r""" Clear GF(2^n) by clear GF(2) multiplication """ - __slots__ = [] - code = base.opcodes['GMULBITC'] - arg_format = ['cgw','cg','cg'] - - def is_gf2n(self): - return True - -@base.vectorize -class gmulbitm(base.MulBase): - r""" Secret GF(2^n) by clear GF(2) multiplication """ - __slots__ = [] - code = base.opcodes['GMULBITM'] - arg_format = ['sgw','sg','cg'] - - def is_gf2n(self): - return True - ### ### Arithmetic with immediate values ### @@ -1051,6 +1063,7 @@ class shrci(base.ClearShiftInstruction): code = base.opcodes['SHRCI'] op = '__rshift__' +@base.gf2n @base.vectorize class shrsi(base.ClearShiftInstruction): """ Bitwise right shift of secret register (vector) by (constant) @@ -1404,7 +1417,6 @@ def get_players(self): for i, t in self.bases(iter(self.args)): yield self.args[i + sum(self.types[t]) + 1] -@base.vectorize class inputmixedreg(inputmixed_base): """ Store private input in secret registers (vectors). The input is read as integer or floating-point number and the latter is then @@ -1424,6 +1436,21 @@ class inputmixedreg(inputmixed_base): """ code = base.opcodes['INPUTMIXEDREG'] player_arg_type = 'ci' + is_vec = lambda self: True + + def __init__(self, *args): + inputmixed_base.__init__(self, *args) + for i, t in self.bases(iter(self.args)): + n = self.types[t][0] + for j in range(i + 1, i + 1 + n): + assert args[j].size == self.get_size() + + def get_size(self): + return self.args[1].size + + def get_code(self): + return inputmixed_base.get_code( + self, self.get_size() if self.get_size() > 1 else 0) def add_usage(self, req_node): # player 0 as proxy @@ -1602,7 +1629,7 @@ class print_char(base.IOInstruction): arg_format = ['int'] def __init__(self, ch): - super(print_char, self).__init__(ord(ch)) + super(print_char, self).__init__(ch) class print_char4(base.IOInstruction): """ Output four bytes. @@ -1706,6 +1733,7 @@ class writesockets(base.IOInstruction): from registers into a socket for a specified client id. If the protocol uses MACs, the client should be different for every party. + :param: number of arguments to follow :param: client id (regint) :param: message type (must be 0) :param: vector size (int) @@ -2161,14 +2189,19 @@ class gconvgf2n(base.Instruction): class asm_open(base.VarArgsInstruction): """ Reveal secret registers (vectors) to clear registers (vectors). - :param: number of argument to follow (multiple of two) + :param: number of argument to follow (odd number) + :param: check after opening (0/1) :param: destination (cint) :param: source (sint) :param: (repeat the last two)... """ __slots__ = [] code = base.opcodes['OPEN'] - arg_format = tools.cycle(['cw','s']) + arg_format = tools.chain(['int'], tools.cycle(['cw','s'])) + + def merge(self, other): + self.args[0] |= other.args[0] + self.args += other.args[1:] @base.gf2n @base.vectorize @@ -2280,6 +2313,7 @@ def dynamic_arg_format(self, args): yield 'int' for i, n in self.bases(args): yield 's' + field + 'w' + assert n > 2 for j in range(n - 2): yield 's' + field yield 'int' @@ -2407,8 +2441,41 @@ class trunc_pr(base.VarArgsInstruction): code = base.opcodes['TRUNC_PR'] arg_format = tools.cycle(['sw','s','int','int']) +class shuffle_base(base.DataInstruction): + n_relevant_parties = 2 + + @staticmethod + def logn(n): + return int(math.ceil(math.log(n, 2))) + + @classmethod + def n_swaps(cls, n): + logn = cls.logn(n) + return logn * 2 ** logn - 2 ** logn + 1 + + def add_gen_usage(self, req_node, n): + # hack for unknown usage + req_node.increment(('bit', 'inverse'), float('inf')) + # minimal usage with two relevant parties + logn = self.logn(n) + n_switches = self.n_swaps(n) + for i in range(self.n_relevant_parties): + req_node.increment((self.field_type, 'input', i), n_switches) + # multiplications for bit check + req_node.increment((self.field_type, 'triple'), + n_switches * self.n_relevant_parties) + + def add_apply_usage(self, req_node, n, record_size): + req_node.increment(('bit', 'inverse'), float('inf')) + logn = self.logn(n) + n_switches = self.n_swaps(n) * self.n_relevant_parties + if n != 2 ** logn: + record_size += 1 + req_node.increment((self.field_type, 'triple'), + n_switches * record_size) + @base.gf2n -class secshuffle(base.VectorInstruction, base.DataInstruction): +class secshuffle(base.VectorInstruction, shuffle_base): """ Secure shuffling. :param: destination (sint) @@ -2424,9 +2491,10 @@ def __init__(self, *args, **kwargs): assert len(args[0]) > args[2] def add_usage(self, req_node): - req_node.increment((self.field_type, 'input', 0), float('inf')) + self.add_gen_usage(req_node, len(self.args[0])) + self.add_apply_usage(req_node, len(self.args[0]), self.args[2]) -class gensecshuffle(base.DataInstruction): +class gensecshuffle(shuffle_base): """ Generate secure shuffle to bit used several times. :param: destination (regint) @@ -2438,9 +2506,9 @@ class gensecshuffle(base.DataInstruction): arg_format = ['ciw','int'] def add_usage(self, req_node): - req_node.increment((self.field_type, 'input', 0), float('inf')) + self.add_gen_usage(req_node, self.args[1]) -class applyshuffle(base.VectorInstruction, base.DataInstruction): +class applyshuffle(base.VectorInstruction, shuffle_base): """ Generate secure shuffle to bit used several times. :param: destination (sint) @@ -2460,7 +2528,7 @@ def __init__(self, *args, **kwargs): assert len(args[0]) > args[2] def add_usage(self, req_node): - req_node.increment((self.field_type, 'triple', 0), float('inf')) + self.add_apply_usage(req_node, len(self.args[0]), self.args[2]) class delshuffle(base.Instruction): """ Delete secure shuffle. @@ -2471,6 +2539,26 @@ class delshuffle(base.Instruction): code = base.opcodes['DELSHUFFLE'] arg_format = ['ci'] +class inverse_permutation(base.VectorInstruction, shuffle_base): + """ Calculate the inverse permutation of a secret permutation. + + :param: destination (sint) + :param: source (sint) + + """ + __slots__ = [] + code = base.opcodes['INVPERM'] + arg_format = ['sw', 's'] + + def __init__(self, *args, **kwargs): + super(inverse_permutation, self).__init__(*args, **kwargs) + assert len(args[0]) == len(args[1]) + + def add_usage(self, req_node): + self.add_gen_usage(req_node, len(self.args[0])) + self.add_apply_usage(req_node, len(self.args[0]), 1) + + class check(base.Instruction): """ Force MAC check in current thread and all idle thread if current @@ -2498,7 +2586,7 @@ def expand(self): c = [program.curr_block.new_reg('c') for i in range(2)] square(s[0], s[1]) subs(s[2], self.args[1], s[0]) - asm_open(c[0], s[2]) + asm_open(False, c[0], s[2]) mulc(c[1], c[0], c[0]) mulm(s[3], self.args[1], c[0]) adds(s[4], s[3], s[3]) diff --git a/Compiler/instructions_base.py b/Compiler/instructions_base.py old mode 100755 new mode 100644 index d598d8a..f811e47 --- a/Compiler/instructions_base.py +++ b/Compiler/instructions_base.py @@ -80,6 +80,7 @@ SUBSI = 0x2A, SUBCFI = 0x2B, SUBSFI = 0x2C, + PREFIXSUMS = 0x2D, # Multiplication/division MULC = 0x30, MULM = 0x31, @@ -111,6 +112,7 @@ GENSECSHUFFLE = 0xFB, APPLYSHUFFLE = 0xFC, DELSHUFFLE = 0xFD, + INVPERM = 0xFE, # Data access TRIPLE = 0x50, BIT = 0x51, @@ -207,8 +209,8 @@ CONDPRINTPLAIN = 0xE1, INTOUTPUT = 0xE6, FLOATOUTPUT = 0xE7, - GBITDEC = 0x184, - GBITCOM = 0x185, + GBITDEC = 0x18A, + GBITCOM = 0x18B, # Secure socket INITSECURESOCKET = 0x1BA, RESPSECURESOCKET = 0x1BB @@ -541,7 +543,7 @@ def add_usage(self, *args): def get_bytes(self): assert len(self.kwargs) < 2 - res = int_to_bytes(opcodes['CISC']) + res = LongArgFormat.encode(opcodes['CISC']) res += int_to_bytes(sum(len(x[0]) + 2 for x in self.calls) + 1) name = self.function.__name__ String.check(name) @@ -701,10 +703,16 @@ class ClearIntAF(RegisterArgFormat): reg_type = RegType.ClearInt class IntArgFormat(ArgFormat): + n_bits = 32 + @classmethod def check(cls, arg): - if not isinstance(arg, int) and not arg is None: - raise ArgumentError(arg, 'Expected an integer-valued argument') + if not arg is None: + if not isinstance(arg, int): + raise ArgumentError(arg, 'Expected an integer-valued argument') + if arg >= 2 ** cls.n_bits or arg < -2 ** cls.n_bits: + raise ArgumentError( + arg, 'Immediate value outside of %d-bit range' % cls.n_bits) @classmethod def encode(cls, arg): @@ -717,9 +725,11 @@ def __str__(self): return str(self.i) class LongArgFormat(IntArgFormat): + n_bits = 64 + @classmethod def encode(cls, arg): - return struct.pack('>Q', arg) + return list(struct.pack('>Q', arg)) def __init__(self, f): self.i = struct.unpack('>Q', f.read(8))[0] @@ -728,8 +738,6 @@ class ImmediateModpAF(IntArgFormat): @classmethod def check(cls, arg): super(ImmediateModpAF, cls).check(arg) - if arg >= 2**32 or arg < -2**32: - raise ArgumentError(arg, 'Immediate value outside of 32-bit range') class ImmediateGF2NAF(IntArgFormat): @classmethod @@ -740,6 +748,8 @@ def check(cls, arg): class PlayerNoAF(IntArgFormat): @classmethod def check(cls, arg): + if not util.is_constant(arg): + raise CompilerError('Player number must be known at compile time') super(PlayerNoAF, cls).check(arg) if arg > 256: raise ArgumentError(arg, 'Player number > 256') @@ -822,7 +832,7 @@ def get_code(self, prefix=0): return (prefix << self.code_length) + self.code def get_encoding(self): - enc = int_to_bytes(self.get_code()) + enc = LongArgFormat.encode(self.get_code()) # add the number of registers if instruction flagged as has var args if self.has_var_args(): enc += int_to_bytes(len(self.args)) @@ -957,7 +967,7 @@ def __init__(self, f): except AttributeError: pass read = lambda: struct.unpack('>I', f.read(4))[0] - full_code = read() + full_code = struct.unpack('>Q', f.read(8))[0] code = full_code % (1 << Instruction.code_length) self.size = full_code >> Instruction.code_length self.type = cls.reverse_opcodes[code] @@ -1106,12 +1116,16 @@ class IOInstruction(DoNotEliminateInstruction): @classmethod def str_to_int(cls, s): """ Convert a 4 character string to an integer. """ + try: + s = bytearray(s, 'utf8') + except: + pass if len(s) > 4: raise CompilerError('String longer than 4 characters') n = 0 for c in reversed(s.ljust(4)): n <<= 8 - n += ord(c) + n += c return n class AsymmetricCommunicationInstruction(DoNotEliminateInstruction): diff --git a/Compiler/library.py b/Compiler/library.py old mode 100755 new mode 100644 index ef2fe1a..1f8fdc9 --- a/Compiler/library.py +++ b/Compiler/library.py @@ -64,6 +64,7 @@ def print_str(s, *args): variables/registers with ``%s``. """ def print_plain_str(ss): """ Print a plain string (no custom formatting options) """ + ss = bytearray(ss, 'utf8') i = 1 while 4*i <= len(ss): print_char4(ss[4*(i-1):4*i]) @@ -138,7 +139,7 @@ def print_str_if(cond, ss, *args): """ Print string conditionally. See :py:func:`print_ln_if` for details. """ if util.is_constant(cond): if cond: - print_ln(ss, *args) + print_str(ss, *args) else: subs = ss.split('%s') assert len(subs) == len(args) + 1 @@ -154,7 +155,8 @@ def print_str_if(cond, ss, *args): print_str_if(cond, *_expand_to_print(val)) else: print_str_if(cond, str(val)) - s += '\0' * ((-len(s)) % 4) + s = bytearray(s, 'utf8') + s += b'\0' * ((-len(s)) % 4) while s: cond.print_if(s[:4]) s = s[4:] @@ -243,6 +245,10 @@ def store_in_mem(value, address): try: value.store_in_mem(address) except AttributeError: + if isinstance(value, (list, tuple)): + for i, x in enumerate(value): + store_in_mem(x, address + i) + return # legacy if value.is_clear: if isinstance(address, cint): @@ -261,11 +267,13 @@ def reveal(secret): try: return secret.reveal() except AttributeError: + if secret.is_clear: + return secret if secret.is_gf2n: res = cgf2n() else: res = cint() - instructions.asm_open(res, secret) + instructions.asm_open(True, res, secret) return res @vectorize @@ -282,13 +290,13 @@ def get_arg(): ldarg(res) return res -def make_array(l): +def make_array(l, t=None): if isinstance(l, program.Tape.Register): - res = Array(len(l), type(l)) + res = Array(len(l), t or type(l)) res[:] = l else: l = list(l) - res = Array(len(l), type(l[0]) if l else cint) + res = Array(len(l), t or type(l[0]) if l else cint) res.assign(l) return res @@ -459,6 +467,24 @@ def wrapper(self, *args): return block(*args) return wrapper +# def cond_swap(x,y): +# from .types import SubMultiArray +# if isinstance(x, (Array, SubMultiArray)): +# b = x[0] > y[0] +# return list(zip(*[b.cond_swap(xx, yy) for xx, yy in zip(x, y)])) +# b = x < y +# if isinstance(x, sfloat): +# res = ([], []) +# for i,j in enumerate(('v','p','z','s')): +# xx = x.__getattribute__(j) +# yy = y.__getattribute__(j) +# bx = b * xx +# by = b * yy +# res[0].append(bx + yy - by) +# res[1].append(xx - bx + by) +# return sfloat(*res[0]), sfloat(*res[1]) +# return b.cond_swap(y, x) + def cond_swap(x,y): b = x < y if isinstance(x, sfloat): @@ -501,12 +527,15 @@ def odd_even_merge_sort(a): if len(a) == 1: return elif len(a) % 2 == 0: + aa = a + a = list(a) lower = a[:len(a)//2] upper = a[len(a)//2:] odd_even_merge_sort(lower) odd_even_merge_sort(upper) a[:] = lower + upper odd_even_merge(a) + aa[:] = a else: raise CompilerError('Length of list must be power of two') @@ -874,15 +903,15 @@ def loop_fn(i): # known loop count if condition(start): get_tape().req_node.children[-1].aggregator = \ - lambda x: ((stop - start) // step) * x[0] + lambda x: int(ceil(((stop - start) / step))) * x[0] def for_range(start, stop=None, step=None): """ Decorator to execute loop bodies consecutively. Arguments work as - in Python :py:func:`range`, but they can by any public + in Python :py:func:`range`, but they can be any public integer. Information has to be passed out via container types such - as :py:class:`~Compiler.types.Array` or declaring registers as - :py:obj:`global`. Note that changing Python data structures such + as :py:class:`~Compiler.types.Array` or using :py:func:`update`. + Note that changing Python data structures such as lists within the loop is not possible, but the compiler cannot warn about this. @@ -897,13 +926,11 @@ def for_range(start, stop=None, step=None): @for_range(n) def _(i): a[i] = i - global x - x += 1 + x.update(x + 1) Note that you cannot overwrite data structures such as - :py:class:`~Compiler.types.Array` in a loop even when using - :py:obj:`global`. Use :py:func:`~Compiler.types.Array.assign` - instead. + :py:class:`~Compiler.types.Array` in a loop. Use + :py:func:`~Compiler.types.Array.assign` instead. """ def decorator(loop_body): range_loop(loop_body, start, stop, step) @@ -1013,9 +1040,11 @@ def write_state_to_memory(r): def f(i): state = tuplify(initializer()) start_block = get_block() + j = i * n_parallel + one = regint(1) for k in range(n_parallel): - j = i * n_parallel + k state = reducer(tuplify(loop_body(j)), state) + j += one if n_parallel > 1 and start_block != get_block(): print('WARNING: parallelization broken ' 'by control flow instruction') @@ -1514,11 +1543,17 @@ class State: pass state = State() if callable(condition): condition = condition() + try: + if not condition.is_clear: + raise CompilerError('cannot branch on secret values') + except AttributeError: + pass state.condition = regint.conv(condition) state.start_block = instructions.program.curr_block state.req_child = get_tape().open_scope(lambda x: x[0].max(x[1]), \ name='if-block') state.has_else = False + state.caller = [frame[1:] for frame in inspect.stack()[1:]] instructions.program.curr_tape.if_states.append(state) def else_then(): @@ -1884,7 +1919,7 @@ def FPDiv(a, b, k, f, kappa, simplex_flag=False, nearest=False): theta = int(ceil(log(k/3.5) / log(2))) base.set_global_vector_size(b.size) - alpha = b.get_type(2 * k).two_power(2*f) + alpha = b.get_type(2 * k).two_power(2*f, size=b.size) w = AppRcr(b, k, f, kappa, simplex_flag, nearest).extend(2 * k) x = alpha - b.extend(2 * k) * w base.reset_global_vector_size() diff --git a/Compiler/ml.py b/Compiler/ml.py old mode 100755 new mode 100644 index 02f0f04..c667e1d --- a/Compiler/ml.py +++ b/Compiler/ml.py @@ -73,8 +73,13 @@ def log_e(x): return mpc_math.log_fx(x, math.e) +use_mux = False + def exp(x): - return mpc_math.pow_fx(math.e, x) + if use_mux: + return mpc_math.mux_exp(math.e, x) + else: + return mpc_math.pow_fx(math.e, x) def get_limit(x): exp_limit = 2 ** (x.k - x.f - 1) @@ -148,7 +153,7 @@ def argmax(x): """ Compute index of maximum element. :param x: iterable - :returns: sint + :returns: sint or 0 if :py:obj:`x` has length 1 """ def op(a, b): comp = (a[1] > b[1]) @@ -164,13 +169,16 @@ def softmax(x): return softmax_from_exp(exp_for_softmax(x)[0]) def exp_for_softmax(x): - m = util.max(x) + m = util.max(x) - get_limit(x[0]) + math.log(len(x)) mv = m.expand_to_vector(len(x)) try: x = x.get_vector() except AttributeError: x = sfix(x) - return (x - mv > -get_limit(x)).if_else(exp(x - mv), 0), m + if use_mux: + return exp(x - mv), m + else: + return (x - mv > -get_limit(x)).if_else(exp(x - mv), 0), m def softmax_from_exp(x): return x / sum(x) @@ -1072,6 +1080,7 @@ def __init__(self, shape, strides=(1, 2, 2, 1), ksize=(1, 2, 2, 1), self.nabla_Y = Tensor(output_shape, sfix) self.N = shape[0] self.comparisons = MultiArray([self.N, self.X.sizes[3], + output_shape[1], output_shape[2], ksize[1] * ksize[2]], sint) def __repr__(self): @@ -1091,26 +1100,28 @@ def m(a, b): red = util.tree_reduce(m, [(x[0], [1] if training else []) for x in pool]) self.Y[bi][i][j][k] = red[0] - for i, x in enumerate(red[1]): - self.comparisons[bi][k][i] = x + for ii, x in enumerate(red[1]): + self.comparisons[bi][k][i][j][ii] = x self.traverse(batch, process) def backward(self, compute_nabla_X=True, batch=None): if compute_nabla_X: self.nabla_X.alloc() + self.nabla_X.assign_all(0) def process(pool, bi, k, i, j): - for (x, h_in, w_in, h, w), c in zip(pool, - self.comparisons[bi][k]): + for (x, h_in, w_in, h, w), c \ + in zip(pool, self.comparisons[bi][k][i][j]): hh = h * h_in ww = w * w_in - self.nabla_X[bi][hh][ww][k] = \ - util.if_else(h_in * w_in, c * self.nabla_Y[bi][i][j][k], - self.nabla_X[bi][hh][ww][k]) + res = h_in * w_in * c * self.nabla_Y[bi][i][j][k] + self.nabla_X[bi][hh][ww][k] += res self.traverse(batch, process) def traverse(self, batch, process): need_padding = [self.strides[i] * (self.Y.sizes[i] - 1) + self.ksize[i] > self.X.sizes[i] for i in range(4)] + overlap = reduce(operator.or_, + (x < y for x, y in zip(self.strides, self.ksize))) @for_range_opt_multithread(self.n_threads, [len(batch), self.X.sizes[3]]) def _(l, k): @@ -1120,6 +1131,8 @@ def _(i): h_base = self.strides[1] * i @for_range_opt(self.Y.sizes[2]) def _(j): + if overlap: + break_point() w_base = self.strides[2] * j pool = [] for ii in range(self.ksize[1]): @@ -2002,6 +2015,9 @@ def from_args(program, layers): return res def __init__(self, report_loss=None): + if get_program().options.binary: + raise CompilerError( + 'machine learning code not compatible with binary circuits') self.tol = 0.000 self.report_loss = report_loss self.X_by_label = None @@ -2384,6 +2400,11 @@ def output_weights(self): for layer in self.layers: layer.output_weights() + def summary(self): + sizes = [var.total_size() for var in self.thetas] + print(sizes) + print('Trainable params:', sum(sizes)) + class Adam(Optimizer): """ Adam/AMSgrad optimizer. @@ -2653,9 +2674,7 @@ def trainable_variables(self): return list(self.opt.thetas) def summary(self): - sizes = [var.total_size() for var in self.trainable_variables] - print(sizes) - print('Trainable params:', sum(sizes)) + self.opt.summary() def build(self, input_shape, batch_size=128): data_input_shape = input_shape diff --git a/Compiler/mpc_math.py b/Compiler/mpc_math.py old mode 100755 new mode 100644 index 47253dc..8f09bd7 --- a/Compiler/mpc_math.py +++ b/Compiler/mpc_math.py @@ -8,6 +8,8 @@ import math +import operator +from functools import reduce from Compiler import floatingpoint from Compiler import types from Compiler import comparison @@ -295,7 +297,6 @@ class my_fix(type(a)): intbitint = types.intbitint n_shift = int(types.program.options.ring) - a.k if types.program.use_split(): - assert not zero_output from Compiler.GC.types import sbitvec if types.program.use_split() == 3: x = a.v.split_to_two_summands(a.k) @@ -327,6 +328,7 @@ class my_fix(type(a)): s = sint.conv(bits[-1]) lower = sint.bit_compose(sint.conv(b) for b in bits[:a.f]) higher_bits = bits[a.f:n_bits] + bits_to_check = bits[n_bits:-1] else: if types.program.use_edabit(): l = sint.get_edabit(a.f, True) @@ -338,7 +340,7 @@ class my_fix(type(a)): r_bits = [sint.get_random_bit() for i in range(a.k)] r = sint.bit_compose(r_bits) lower_r = sint.bit_compose(r_bits[:a.f]) - shifted = ((a.v - r) << n_shift).reveal() + shifted = ((a.v - r) << n_shift).reveal(False) masked_bits = (shifted >> n_shift).bit_decompose(a.k) lower_overflow = comparison.CarryOutRaw(masked_bits[a.f-1::-1], r_bits[a.f-1::-1]) @@ -398,6 +400,36 @@ class my_fix(type(a)): return s.if_else(1 / g, g) +def mux_exp(x, y, block_size=8): + assert util.is_constant_float(x) + from Compiler.GC.types import sbitvec, sbits + bits = sbitvec.from_vec(y.v.bit_decompose(y.k, maybe_mixed=True)).v + sign = bits[-1] + m = math.log(2 ** (y.k - y.f - 1), x) + del bits[int(math.ceil(math.log(m, 2))) + y.f:] + parts = [] + for i in range(0, len(bits), block_size): + one_hot = sbitvec.from_vec(bits[i:i + block_size]).demux().v + exp = [] + try: + for j in range(len(one_hot)): + exp.append(types.cfix.int_rep(x ** (j * 2 ** (i - y.f)), y.f)) + except OverflowError: + pass + exp = list(filter(lambda x: x < 2 ** (y.k - 1), exp)) + bin_part = [0] * max(x.bit_length() for x in exp) + for j in range(len(bin_part)): + for k, (a, b) in enumerate(zip(one_hot, exp)): + bin_part[j] ^= a if util.bit_decompose(b, len(bin_part))[j] \ + else 0 + if util.is_zero(bin_part[j]): + bin_part[j] = sbits.get_type(y.size)(0) + if i == 0: + bin_part[j] = sign.if_else(0, bin_part[j]) + parts.append(y._new(y.int_type(sbitvec.from_vec(bin_part)))) + return util.tree_reduce(operator.mul, parts) + + @types.vectorize @instructions_base.sfix_cisc def log2_fx(x, use_division=True): @@ -420,6 +452,8 @@ def log2_fx(x, use_division=True): p -= x.f vlen = x.f v = x._new(v, k=x.k, f=x.f) + elif isinstance(x, types._register): + return log2_fx(types.sfix(x), use_division) else: d = types.sfloat(x) v, p, vlen = d.v, d.p, d.vlen @@ -501,7 +535,7 @@ def abs_fx(x): # # @return floored sint value of x def floor_fx(x): - return load_sint(floatingpoint.Trunc(x.v, x.k - x.f, x.f, x.kappa), type(x)) + return load_sint(floatingpoint.Trunc(x.v, x.k, x.f, x.kappa), type(x)) ### sqrt methods diff --git a/Compiler/non_linear.py b/Compiler/non_linear.py old mode 100755 new mode 100644 index 01cb4db..66e8290 --- a/Compiler/non_linear.py +++ b/Compiler/non_linear.py @@ -32,6 +32,8 @@ def trunc_pr(self, a, k, m, signed=True): return shift_two(a, m) prog = program.Program.prog if prog.use_trunc_pr: + if not prog.options.ring: + prog.curr_tape.require_bit_length(k + prog.security) if signed and prog.use_trunc_pr != -1: a += (1 << (k - 1)) res = sint() diff --git a/Compiler/oram.py b/Compiler/oram.py old mode 100755 new mode 100644 index d4b4343..bbaa393 --- a/Compiler/oram.py +++ b/Compiler/oram.py @@ -805,11 +805,11 @@ def batch_init(self, values): class TrivialORAM(RefTrivialORAM, AbstractORAM): """ Trivial ORAM (obviously). """ ref_type = RefTrivialORAM - def __init__(self, size, value_type=sint, value_length=1, index_size=None, \ + def __init__(self, size, value_type=None, value_length=1, index_size=None, \ entry_size=None, contiguous=True, init_rounds=-1): self.index_size = index_size or log2(size) - self.value_type = value_type - self.index_type = value_type.get_type(self.index_size) + self.value_type = value_type or sint + self.index_type = self.value_type.get_type(self.index_size) if entry_size is None: self.value_length = value_length self.entry_size = [None] * value_length @@ -862,15 +862,16 @@ def _read(self, index): empty_entry = self.empty_entry(False) demux_array(bit_decompose(index, self.index_size), \ self.index_vector) + t = self.value_type.get_type(None if None in self.entry_size else max(self.entry_size)) @map_sum(get_n_threads(self.size), n_parallel, self.size, \ - self.value_length + 1, [self.value_type.bit_type] + \ - [self.value_type.get_type(l) for l in self.entry_size]) + self.value_length + 1, t) def f(i): entry = self.ram[i] access_here = self.index_vector[i] return access_here * ValueTuple((entry.empty(),) + entry.x) - not_found = f()[0] - read_value = ValueTuple(f()[1:]) + not_found * empty_entry.x + not_found = self.value_type.bit_type(f()[0]) + read_value = ValueTuple(self.value_type.get_type(l)(x) for l, x in zip(self.entry_size, f()[1:])) + \ + not_found * empty_entry.x maybe_stop_timer(6) return read_value, not_found @method_block @@ -879,7 +880,9 @@ def _write(self, index, *new_value): empty_entry = self.empty_entry(False) demux_array(bit_decompose(index, self.index_size), \ self.index_vector) - new_value = make_array(new_value) + new_value = make_array( + new_value, self.value_type.get_type( + max(x or 0 for x in self.entry_size))) @for_range_multithread(get_n_threads(self.size), n_parallel, self.size) def f(i): entry = self.ram[i] @@ -895,7 +898,9 @@ def _access(self, index, write, new_empty, *new_value): empty_entry = self.empty_entry(False) index_vector = \ demux_array(bit_decompose(index, self.index_size)) - new_value = make_array(new_value) + new_value = make_array( + new_value, self.value_type.get_type( + max(x or 0 for x in self.entry_size))) new_empty = MemValue(new_empty) write = MemValue(write) @map_sum(get_n_threads(self.size), n_parallel, self.size, \ @@ -1029,8 +1034,9 @@ def get_n_threads_for_tree(size): class TreeORAM(AbstractORAM): """ Tree ORAM. """ - def __init__(self, size, value_type=sint, value_length=1, entry_size=None, \ + def __init__(self, size, value_type=None, value_length=1, entry_size=None, \ bucket_oram=TrivialORAM, init_rounds=-1): + value_type = value_type or sint print('create oram of size', size) self.bucket_oram = bucket_oram # heuristic bucket size @@ -1227,7 +1233,8 @@ def batch_init(self, values): """ Batch initalization. Obliviously shuffles and adds N entries to random leaf buckets. """ m = len(values) - assert((m & (m-1)) == 0) + if not (m & (m-1)) == 0: + raise CompilerError('Batch size must a power of 2.') if m != self.size: raise CompilerError('Batch initialization must have N values.') if self.value_type != sint: @@ -1675,6 +1682,39 @@ class OneLevelORAM(TreeORAM): pattern after one recursion. """ index_structure = BaseORAMIndexStructure +class BinaryORAM: + def __init__(self, size, value_type=None, **kwargs): + import circuit_oram + from Compiler.GC import types + n_bits = int(get_program().options.binary) + self.value_type = value_type or types.sbitintvec.get_type(n_bits) + self.index_type = self.value_type + oram_value_type = types.sbits.get_type(64) + if 'entry_size' not in kwargs: + kwargs['entry_size'] = n_bits + self.oram = circuit_oram.OptimalCircuitORAM( + size, value_type=oram_value_type, **kwargs) + self.size = size + def get_index(self, index): + return self.oram.value_type(self.index_type.conv(index).elements()[0]) + def __setitem__(self, index, value): + value = list(self.oram.value_type( + self.value_type.conv(v).elements()[0]) for v in tuplify(value)) + self.oram[self.get_index(index)] = value + def __getitem__(self, index): + value = self.oram[self.get_index(index)] + return untuplify(tuple(self.value_type(v) for v in tuplify(value))) + def read(self, index): + return self.oram.read(index) + def read_and_maybe_remove(self, index): + return self.oram.read_and_maybe_remove(index) + def access(self, *args): + return self.oram.access(*args) + def add(self, *args, **kwargs): + return self.oram.add(*args, **kwargs) + def delete(self, *args, **kwargs): + return self.oram.delete(*args, **kwargs) + def OptimalORAM(size,*args,**kwargs): """ Create an ORAM instance suitable for the size based on experiments. @@ -1683,6 +1723,10 @@ def OptimalORAM(size,*args,**kwargs): :param value_type: :py:class:`sint` (default) / :py:class:`sg2fn` / :py:class:`sfix` """ + if not util.is_constant(size): + raise CompilerError('ORAM size has be a compile-time constant') + if get_program().options.binary: + return BinaryORAM(size, *args, **kwargs) if optimal_threshold is None: if n_threads == 1: threshold = 2**11 @@ -1731,6 +1775,12 @@ class OptimalPackedORAMWithEmpty(PackedORAMWithEmpty): def test_oram(oram_type, N, value_type=sint, iterations=100): stop_grind() oram = oram_type(N, value_type=value_type, entry_size=32, init_rounds=0) + test_oram_initialized(oram, iterations) + return oram + +def test_oram_initialized(oram, iterations=100): + N = oram.size + value_type = oram.value_type value_type = value_type.get_type(32) index_type = value_type.get_type(log2(N)) start_grind() diff --git a/Compiler/path_oram.py b/Compiler/path_oram.py old mode 100755 new mode 100644 diff --git a/Compiler/permutation.py b/Compiler/permutation.py old mode 100755 new mode 100644 diff --git a/Compiler/program.py b/Compiler/program.py old mode 100755 new mode 100644 index 78b802e..dfe08f8 --- a/Compiler/program.py +++ b/Compiler/program.py @@ -4,39 +4,41 @@ object that holds various properties of the computation. """ -from Compiler.config import * -from Compiler.exceptions import * -from Compiler.instructions_base import RegType -import Compiler.instructions -import Compiler.instructions_base -import Compiler.instructions_base as inst_base -from . import allocator as al -from . import util -import random -import time -import sys, os, errno import inspect -from collections import defaultdict, deque import itertools import math -from functools import reduce +import os import re +import sys +from collections import defaultdict, deque +from functools import reduce + +import Compiler.instructions +import Compiler.instructions_base +import Compiler.instructions_base as inst_base +from Compiler.config import REG_MAX, USER_MEM, COST +from Compiler.exceptions import CompilerError +from Compiler.instructions_base import RegType +from . import allocator as al +from . import util data_types = dict( - triple = 0, - square = 1, - bit = 2, - inverse = 3, - dabit = 4, + triple=0, + square=1, + bit=2, + inverse=3, + dabit=4, + mixed=5, ) field_types = dict( - modp = 0, - gf2n = 1, - bit = 2, + modp=0, + gf2n=1, + bit=2, ) + class defaults: debug = False verbose = False @@ -44,11 +46,13 @@ class defaults: ring = 0 field = 0 binary = 0 + garbled = False prime = None galois = 40 budget = 100000 mixed = False edabit = False + invperm = False split = None cisc = False comparison = None @@ -62,8 +66,9 @@ class defaults: insecure = False keep_cisc = False + class Program(object): - """ A program consists of a list of tapes representing the whole + """A program consists of a list of tapes representing the whole computation. When compiling an :file:`.mpc` file, the single instances is @@ -71,20 +76,22 @@ class Program(object): from Python code, an instance has to be created before running any instructions. """ - def __init__(self, args, options=defaults): - from .non_linear import Ring, Prime, KnownPrime + + def __init__(self, args, options=defaults, name=None): + from .non_linear import KnownPrime, Prime + self.options = options self.verbose = options.verbose self.args = args + self.name = name self.init_names(args) self._security = 40 self.prime = None self.tapes = [] - if sum(x != 0 for x in(options.ring, options.field, - options.binary)) > 1: - raise CompilerError('can only use one out of -B, -R, -F') + if sum(x != 0 for x in (options.ring, options.field, options.binary)) > 1: + raise CompilerError("can only use one out of -B, -R, -F") if options.prime and (options.ring or options.binary): - raise CompilerError('can only use one out of -B, -R, -p') + raise CompilerError("can only use one out of -B, -R, -p") if options.ring: self.set_ring_size(int(options.ring)) else: @@ -93,19 +100,20 @@ def __init__(self, args, options=defaults): self.prime = int(options.prime) max_bit_length = int(options.prime).bit_length() - 2 if self.bit_length > max_bit_length: - raise CompilerError('integer bit length can be maximal %s' % - max_bit_length) + raise CompilerError( + "integer bit length can be maximal %s" % max_bit_length + ) self.bit_length = self.bit_length or max_bit_length self.non_linear = KnownPrime(self.prime) else: self.non_linear = Prime(self.security) if not self.bit_length: self.bit_length = 64 - print('Default bit length:', self.bit_length) - print('Default security parameter:', self.security) + print("Default bit length:", self.bit_length) + print("Default security parameter:", self.security) self.galois_length = int(options.galois) if self.verbose: - print('Galois length:', self.galois_length) + print("Galois length:", self.galois_length) self.tape_counter = 0 self._curr_tape = None self.DEBUG = options.debug @@ -119,29 +127,44 @@ def __init__(self, args, options=defaults): self.public_input_file = None self.types = {} self.budget = int(self.options.budget) - self.to_merge = [Compiler.instructions.asm_open_class, \ - Compiler.instructions.gasm_open_class, \ - Compiler.instructions.muls_class, \ - Compiler.instructions.gmuls_class, \ - Compiler.instructions.mulrs_class, \ - Compiler.instructions.gmulrs, \ - Compiler.instructions.dotprods_class, \ - Compiler.instructions.gdotprods_class, \ - Compiler.instructions.asm_input_class, \ - Compiler.instructions.gasm_input_class, - Compiler.instructions.inputfix_class, - Compiler.instructions.inputfloat_class, - Compiler.instructions.inputmixed_class, - Compiler.instructions.trunc_pr_class, - Compiler.instructions_base.Mergeable] + self.to_merge = [ + Compiler.instructions.asm_open_class, + Compiler.instructions.gasm_open_class, + Compiler.instructions.muls_class, + Compiler.instructions.gmuls_class, + Compiler.instructions.mulrs_class, + Compiler.instructions.gmulrs, + Compiler.instructions.dotprods_class, + Compiler.instructions.gdotprods_class, + Compiler.instructions.asm_input_class, + Compiler.instructions.gasm_input_class, + Compiler.instructions.inputfix_class, + Compiler.instructions.inputfloat_class, + Compiler.instructions.inputmixed_class, + Compiler.instructions.trunc_pr_class, + Compiler.instructions_base.Mergeable, + ] import Compiler.GC.instructions as gc - self.to_merge += [gc.ldmsdi, gc.stmsdi, gc.ldmsd, gc.stmsd, \ - gc.stmsdci, gc.xors, gc.andrs, gc.ands, gc.inputb] + + self.to_merge += [ + gc.ldmsdi, + gc.stmsdi, + gc.ldmsd, + gc.stmsd, + gc.stmsdci, + gc.andrs, + gc.ands, + gc.inputb, + gc.inputbvec, + gc.reveal, + ] self.use_trunc_pr = False """ Setting whether to use special probabilistic truncation. """ self.use_dabit = options.mixed """ Setting whether to use daBits for non-linear functionality. """ self._edabit = options.edabit + """ Whether to use the low-level INVPERM instruction (only implemented with the assumption of a semi-honest two-party environment)""" + self._invperm = options.invperm self._split = False if options.split: self.use_split(int(options.split)) @@ -153,7 +176,8 @@ def __init__(self, args, options=defaults): self.n_running_threads = None self.input_files = {} Program.prog = self - from . import instructions_base, instructions, types, comparison + from . import comparison, instructions, instructions_base, types + instructions.program = self instructions_base.program = self types.program = self @@ -164,53 +188,53 @@ def get_args(self): return self.args def max_par_tapes(self): - """ Upper bound on number of tapes that will be run in parallel. - (Excludes empty tapes) """ + """Upper bound on number of tapes that will be run in parallel. + (Excludes empty tapes)""" return self.n_threads - + def init_names(self, args): # ignore path to file - source must be in Programs/Source - if 'Programs' in os.listdir(os.getcwd()): + if "Programs" in os.listdir(os.getcwd()): # compile prog in ./Programs/Source directory - self.programs_dir = os.getcwd() + '/Programs' + self.programs_dir = os.getcwd() + "/Programs" else: # assume source is in main SPDZ directory - self.programs_dir = sys.path[0] + '/Programs' + self.programs_dir = sys.path[0] + "/Programs" if self.verbose: - print('Compiling program in', self.programs_dir) - + print("Compiling program in", self.programs_dir) + # create extra directories if needed - for dirname in ['Public-Input', 'Bytecode', 'Schedules']: - if not os.path.exists(self.programs_dir + '/' + dirname): - os.mkdir(self.programs_dir + '/' + dirname) - - progname = args[0].split('/')[-1] - if progname.endswith('.mpc'): - progname = progname[:-4] - - if os.path.exists(args[0]): - self.infile = args[0] - else: - self.infile = self.programs_dir + '/Source/' + progname + '.mpc' + for dirname in ["Public-Input", "Bytecode", "Schedules"]: + if not os.path.exists(self.programs_dir + "/" + dirname): + os.mkdir(self.programs_dir + "/" + dirname) + + if self.name is None: + self.name = args[0].split("/")[-1] + if self.name.endswith(".mpc"): + self.name = self.name[:-4] + + if os.path.exists(args[0]): + self.infile = args[0] + else: + self.infile = self.programs_dir + "/Source/" + self.name + ".mpc" """ self.name is input file name (minus extension) + any optional arguments. Used to generate output filenames """ if self.options.outfile: - self.name = self.options.outfile + '-' + progname + self.name = self.options.outfile + "-" + self.name else: - self.name = progname + self.name = self.name if len(args) > 1: - self.name += '-' + '-'.join(re.sub('/', '_', arg) - for arg in args[1:]) - self.progname = progname + self.name += "-" + "-".join(re.sub("/", "_", arg) for arg in args[1:]) def set_ring_size(self, ring_size): from .non_linear import Ring + for tape in self.tapes: - prev = tape.req_bit_length['p'] + prev = tape.req_bit_length["p"] if prev and prev != ring_size: - raise CompilerError('cannot have different ring sizes') + raise CompilerError("cannot have different ring sizes") self.bit_length = ring_size - 1 self.non_linear = Ring(ring_size) self.options.ring = str(ring_size) @@ -234,7 +258,8 @@ def g(): :param function: Python function defining the thread :param args: arguments to the function :param name: name used for files - :param single_thread: Boolean indicating whether tape will never be run in parallel to itself + :param single_thread: Boolean indicating whether tape will + never be run in parallel to itself :returns: tape handle """ @@ -258,20 +283,22 @@ def run_tape(self, tape_index, arg): return self.run_tapes([[tape_index, arg]])[0] def run_tapes(self, args): - """ Run tapes in parallel. See :py:func:`new_tape` for an example. + """Run tapes in parallel. See :py:func:`new_tape` for an example. - :param args: list of tape handles or tuples of tape handle and extra argument (for :py:func:`~Compiler.library.get_arg`) + :param args: list of tape handles or tuples of tape handle and extra + argument (for :py:func:`~Compiler.library.get_arg`) :returns: list of thread numbers """ if not self.curr_tape.singular: - raise CompilerError('Compiler does not support ' \ - 'recursive spawning of threads') + raise CompilerError( + "Compiler does not support " "recursive spawning of threads" + ) args = [list(util.tuplify(arg)) for arg in args] singular_tapes = set() for arg in args: if self.tapes[arg[0]].singular: if arg[0] in singular_tapes: - raise CompilerError('cannot run singular tape in parallel') + raise CompilerError("cannot run singular tape in parallel") singular_tapes.add(arg[0]) assert len(arg) assert len(arg) <= 2 @@ -286,59 +313,60 @@ def run_tapes(self, args): else: thread_numbers.append(self.n_threads) self.n_threads += 1 - self.curr_tape.start_new_basicblock(name='pre-run_tape') - Compiler.instructions.run_tape(*sum(([x] + list(y) for x, y in - zip(thread_numbers, args)), [])) - self.curr_tape.start_new_basicblock(name='post-run_tape') + self.curr_tape.start_new_basicblock(name="pre-run_tape") + Compiler.instructions.run_tape( + *sum(([x] + list(y) for x, y in zip(thread_numbers, args)), []) + ) + self.curr_tape.start_new_basicblock(name="post-run_tape") for arg in args: - self.curr_tape.req_node.children.append( - self.tapes[arg[0]].req_tree) + self.curr_tape.req_node.children.append(self.tapes[arg[0]].req_tree) return thread_numbers def join_tape(self, thread_number): self.join_tapes([thread_number]) def join_tapes(self, thread_numbers): - """ Wait for completion of tapes. See :py:func:`new_tape` for an example. + """Wait for completion of tapes. See :py:func:`new_tape` for an example. :param thread_numbers: list of thread numbers """ - self.curr_tape.start_new_basicblock(name='pre-join_tape') + self.curr_tape.start_new_basicblock(name="pre-join_tape") for thread_number in thread_numbers: Compiler.instructions.join_tape(thread_number) self.curr_tape.free_threads.add(thread_number) - self.curr_tape.start_new_basicblock(name='post-join_tape') + self.curr_tape.start_new_basicblock(name="post-join_tape") def update_req(self, tape): if self.req_num is None: self.req_num = tape.req_num else: self.req_num += tape.req_num - + def write_bytes(self): - """ Write all non-empty threads and schedule to files. """ + """Write all non-empty threads and schedule to files.""" nonempty_tapes = [t for t in self.tapes] - sch_filename = self.programs_dir + '/Schedules/%s.sch' % self.name - sch_file = open(sch_filename, 'w') - print('Writing to', sch_filename) - sch_file.write(str(self.max_par_tapes()) + '\n') - sch_file.write(str(len(nonempty_tapes)) + '\n') - sch_file.write(' '.join(tape.name for tape in nonempty_tapes) + '\n') - sch_file.write('1 0\n') - sch_file.write('0\n') - sch_file.write(' '.join(sys.argv) + '\n') - req = max(x.req_bit_length['p'] for x in self.tapes) + sch_filename = self.programs_dir + "/Schedules/%s.sch" % self.name + sch_file = open(sch_filename, "w") + print("Writing to", sch_filename) + sch_file.write(str(self.max_par_tapes()) + "\n") + sch_file.write(str(len(nonempty_tapes)) + "\n") + sch_file.write(" ".join("%s:%d" % (tape.name, len(tape)) + for tape in nonempty_tapes) + "\n") + sch_file.write("1 0\n") + sch_file.write("0\n") + sch_file.write(" ".join(sys.argv) + "\n") + req = max(x.req_bit_length["p"] for x in self.tapes) if self.options.ring: - sch_file.write('R:%s' % self.options.ring) + sch_file.write("R:%s" % self.options.ring) elif self.options.prime: - sch_file.write('p:%s' % self.options.prime) + sch_file.write("p:%s" % self.options.prime) else: - sch_file.write('lgp:%s' % req) - sch_file.write('\n') - sch_file.write('opts: %s\n' % ' '.join(self.relevant_opts)) + sch_file.write("lgp:%s" % req) + sch_file.write("\n") + sch_file.write("opts: %s\n" % " ".join(self.relevant_opts)) for tape in self.tapes: tape.write_bytes() @@ -347,12 +375,12 @@ def finalize_tape(self, tape): tape.optimize(self.options) tape.write_bytes() if self.options.asmoutfile: - tape.write_str(self.options.asmoutfile + '-' + tape.name) + tape.write_str(self.options.asmoutfile + "-" + tape.name) tape.purge() - + @property def curr_tape(self): - """ The tape that is currently running.""" + """The tape that is currently running.""" if self._curr_tape is None: assert not self.tapes self._curr_tape = Tape(self.name, self) @@ -365,13 +393,13 @@ def curr_tape(self, value): @property def curr_block(self): - """ The basic block that is currently being created. """ + """The basic block that is currently being created.""" return self.curr_tape.active_basicblock - + def malloc(self, size, mem_type, reg_type=None, creator_tape=None): - """ Allocate memory from the top """ + """Allocate memory from the top""" if not isinstance(size, int): - raise CompilerError('size must be known at compile time') + raise CompilerError("size must be known at compile time") if size == 0: return if isinstance(mem_type, type): @@ -389,8 +417,7 @@ def malloc(self, size, mem_type, reg_type=None, creator_tape=None): single_size = size size *= self.n_running_threads else: - raise CompilerError('cannot allocate memory ' - 'outside main thread') + raise CompilerError("cannot allocate memory " "outside main thread") blocks = self.free_mem_blocks[mem_type] addr = blocks.pop(size) if addr is not None: @@ -400,24 +427,23 @@ def malloc(self, size, mem_type, reg_type=None, creator_tape=None): self.allocated_mem[mem_type] += size if len(str(addr)) != len(str(addr + size)) and self.verbose: print("Memory of type '%s' now of size %d" % (mem_type, addr + size)) - if addr + size >= 2 ** 32: - raise CompilerError("allocation exceeded for type '%s'" % - mem_type) - self.allocated_mem_blocks[addr,mem_type] = size + if addr + size >= 2**64: + raise CompilerError("allocation exceeded for type '%s'" % mem_type) + self.allocated_mem_blocks[addr, mem_type] = size if single_size: from .library import get_thread_number, runtime_error_if + tn = get_thread_number() - runtime_error_if(tn > self.n_running_threads, 'malloc') + runtime_error_if(tn > self.n_running_threads, "malloc") return addr + single_size * (tn - 1) else: return addr def free(self, addr, mem_type): - """ Free memory """ - if self.curr_block.alloc_pool \ - is not self.curr_tape.basicblocks[0].alloc_pool: - raise CompilerError('Cannot free memory within function block') - size = self.allocated_mem_blocks.pop((addr,mem_type)) + """Free memory""" + if self.curr_block.alloc_pool is not self.curr_tape.basicblocks[0].alloc_pool: + raise CompilerError("Cannot free memory within function block") + size = self.allocated_mem_blocks.pop((addr, mem_type)) self.free_mem_blocks[mem_type].push(addr, size) def finalize(self): @@ -435,47 +461,49 @@ def finalize(self): if self.options.asmoutfile: for tape in self.tapes: - tape.write_str(self.options.asmoutfile + '-' + tape.name) + tape.write_str(self.options.asmoutfile + "-" + tape.name) def finalize_memory(self): - from . import library - self.curr_tape.start_new_basicblock(None, 'memory-usage') + self.curr_tape.start_new_basicblock(None, "memory-usage") # reset register counter to 0 if not self.options.noreallocate: self.curr_tape.init_registers() - for mem_type,size in sorted(self.allocated_mem.items()): - if size: - #print "Memory of type '%s' of size %d" % (mem_type, size) + for mem_type, size in sorted(self.allocated_mem.items()): + if size and (not self.options.garbled or \ + mem_type not in ('s', 'sg', 'c', 'cg')): + # print "Memory of type '%s' of size %d" % (mem_type, size) if mem_type in self.types: self.types[mem_type].load_mem(size - 1, mem_type) else: from Compiler.types import _get_type + _get_type(mem_type).load_mem(size - 1, mem_type) if self.verbose: if self.saved: - print('Saved %s memory units through reallocation' % self.saved) + print("Saved %s memory units through reallocation" % self.saved) def public_input(self, x): - """ Append a value to the public input file. """ + """Append a value to the public input file.""" if self.public_input_file is None: - self.public_input_file = open(self.programs_dir + - '/Public-Input/%s' % self.name, 'w') - self.public_input_file.write('%s\n' % str(x)) + self.public_input_file = open( + self.programs_dir + "/Public-Input/%s" % self.name, "w" + ) + self.public_input_file.write("%s\n" % str(x)) def set_bit_length(self, bit_length): - """ Change the integer bit length for non-linear functions. """ + """Change the integer bit length for non-linear functions.""" self.bit_length = bit_length - print('Changed bit length for comparisons etc. to', bit_length) + print("Changed bit length for comparisons etc. to", bit_length) def set_security(self, security): self._security = security self.non_linear.set_security(security) - print('Changed statistical security for comparison etc. to', security) + print("Changed statistical security for comparison etc. to", security) @property def security(self): - """ The statistical security parameter for non-linear - functions. """ + """The statistical security parameter for non-linear + functions.""" return self._security @security.setter @@ -483,7 +511,8 @@ def security(self, security): self.set_security(security) def optimize_for_gc(self): - pass + import Compiler.GC.instructions as gc + self.to_merge += [gc.xors] def get_tape_counter(self): res = self.tape_counter @@ -493,7 +522,7 @@ def get_tape_counter(self): @property def use_trunc_pr(self): if not self._use_trunc_pr: - self.relevant_opts.add('trunc_pr') + self.relevant_opts.add("trunc_pr") return self._use_trunc_pr @use_trunc_pr.setter @@ -501,7 +530,7 @@ def use_trunc_pr(self, change): self._use_trunc_pr = change def use_edabit(self, change=None): - """ Setting whether to use edaBits for non-linear + """Setting whether to use edaBits for non-linear functionality (default: false). :param change: change setting if not :py:obj:`None` @@ -509,16 +538,30 @@ def use_edabit(self, change=None): """ if change is None: if not self._edabit: - self.relevant_opts.add('edabit') + self.relevant_opts.add("edabit") return self._edabit else: self._edabit = change + def use_invperm(self, change=None): + """ Set whether to use the low-level INVPERM instruction to inverse a permutation (see sint.inverse_permutation). The INVPERM instruction assumes a semi-honest two-party environment. If false, a general protocol implemented in the high-level language is used. + + :param change: change setting if not :py:obj:`None` + :returns: setting if :py:obj:`change` is :py:obj:`None` + """ + if change is None: + if not self._invperm: + self.relevant_opts.add("invperm") + return self._invperm + else: + self._invperm = change + + def use_edabit_for(self, *args): return True def use_split(self, change=None): - """ Setting whether to use local arithmetic-binary share + """Setting whether to use local arithmetic-binary share conversion for non-linear functionality (default: false). :param change: change setting if not :py:obj:`None` @@ -526,16 +569,16 @@ def use_split(self, change=None): """ if change is None: if not self._split: - self.relevant_opts.add('split') + self.relevant_opts.add("split") return self._split else: if change and not self.options.ring: - raise CompilerError('splitting only supported for rings') - assert change > 1 or change == False + raise CompilerError("splitting only supported for rings") + assert change > 1 or change is False self._split = change def use_square(self, change=None): - """ Setting whether to use preprocessed square tuples + """Setting whether to use preprocessed square tuples (default: false). :param change: change setting if not :py:obj:`None` @@ -559,22 +602,24 @@ def linear_rounds(self, change=None): self._linear_rounds = change def options_from_args(self): - """ Set a number of options from the command-line arguments. """ - if 'trunc_pr' in self.args: + """Set a number of options from the command-line arguments.""" + if "trunc_pr" in self.args: self.use_trunc_pr = True - if 'signed_trunc_pr' in self.args: + if "signed_trunc_pr" in self.args: self.use_trunc_pr = -1 - if 'split' in self.args or 'split3' in self.args: + if "split" in self.args or "split3" in self.args: self.use_split(3) for arg in self.args: - m = re.match('split([0-9]+)', arg) + m = re.match("split([0-9]+)", arg) if m: self.use_split(int(m.group(1))) - if 'raw' in self.args: + if "raw" in self.args: self.always_raw(True) - if 'edabit' in self.args: + if "edabit" in self.args: self.use_edabit(True) - if 'linear_rounds' in self.args: + if "invperm" in self.args: + self.use_invperm(True) + if "linear_rounds" in self.args: self.linear_rounds(True) def disable_memory_warnings(self): @@ -583,28 +628,32 @@ def disable_memory_warnings(self): @staticmethod def read_tapes(schedule): - m = re.search(r'([^/]*)\.mpc', schedule) + m = re.search(r"([^/]*)\.mpc", schedule) if m: schedule = m.group(1) if not os.path.exists(schedule): - schedule = 'Programs/Schedules/%s.sch' % schedule + schedule = "Programs/Schedules/%s.sch" % schedule try: lines = open(schedule).readlines() except FileNotFoundError: - print('%s not found, have you compiled the program?' % schedule, - file=sys.stderr) + print( + "%s not found, have you compiled the program?" % schedule, + file=sys.stderr, + ) sys.exit(1) - for tapename in lines[2].split(' '): + for tapename in lines[2].split(" "): yield tapename.strip() + class Tape: - """ A tape contains a list of basic blocks, onto which instructions are added. """ + """A tape contains a list of basic blocks, onto which instructions are added.""" + def __init__(self, name, program): - """ Set prime p and the initial instructions and registers. """ + """Set prime p and the initial instructions and registers.""" self.program = program - name += '-%d' % program.get_tape_counter() + name += "-%d" % program.get_tape_counter() self.init_names(name) self.init_registers() self.req_tree = self.ReqNode(name) @@ -643,6 +692,7 @@ def __init__(self, parent, name, scope, exit_condition=None): self.purged = False self.n_rounds = 0 self.n_to_merge = 0 + self.rounds = Tape.ReqNum() self.warn_about_mem = parent.program.warn_about_mem[-1] def __len__(self): @@ -658,9 +708,9 @@ def set_return(self, previous_block, sub_block): def adjust_return(self): offset = self.sub_block.get_offset(self) self.previous_block.return_address_store.args[1] = offset - + def set_exit(self, condition, exit_true=None): - """ Sets the block which we start from next, depending on the condition. + """Sets the block which we start from next, depending on the condition. (Default is to go to next block in the list) """ @@ -668,34 +718,33 @@ def set_exit(self, condition, exit_true=None): self.exit_block = exit_true for reg in condition.get_used(): reg.can_eliminate = False - + def add_jump(self): - """ Add the jump for this block's exit condition to list of - instructions (must be done after merging) """ + """Add the jump for this block's exit condition to list of + instructions (must be done after merging)""" self.instructions.append(self.exit_condition) - + def get_offset(self, next_block): return next_block.offset - (self.offset + len(self.instructions)) - + def adjust_jump(self): - """ Set the correct relative jump offset """ + """Set the correct relative jump offset""" offset = self.get_offset(self.exit_block) self.exit_condition.set_relative_jump(offset) - #print 'Basic block %d jumps to %d (%d)' % (next_block_index, jump_index, offset) def purge(self, retain_usage=True): def relevant(inst): - req_node = Tape.ReqNode('') + req_node = Tape.ReqNode("") req_node.num = Tape.ReqNum() inst.add_usage(req_node) return req_node.num != {} + if retain_usage: - self.usage_instructions = list(filter(relevant, - self.instructions)) + self.usage_instructions = list(filter(relevant, self.instructions)) else: self.usage_instructions = [] if len(self.usage_instructions) > 1000: - print('Retaining %d instructions' % len(self.usage_instructions)) + print("Retaining %d instructions" % len(self.usage_instructions)) del self.instructions self.purged = True @@ -706,14 +755,15 @@ def add_usage(self, req_node): instructions = self.instructions for inst in instructions: inst.add_usage(req_node) - req_node.num['all', 'round'] += self.n_rounds - req_node.num['all', 'inv'] += self.n_to_merge + req_node.num["all", "round"] += self.n_rounds + req_node.num["all", "inv"] += self.n_to_merge + req_node.num += self.rounds def expand_cisc(self): new_instructions = [] - if self.parent.program.options.keep_cisc != None: - skip = ['LTZ', 'Trunc'] - skip += self.parent.program.options.keep_cisc.split(',') + if self.parent.program.options.keep_cisc is not None: + skip = ["LTZ", "Trunc"] + skip += self.parent.program.options.keep_cisc.split(",") else: skip = [] for inst in self.instructions: @@ -726,38 +776,45 @@ def __str__(self): return self.name def is_empty(self): - """ Returns True if the list of basic blocks is empty. + """Returns True if the list of basic blocks is empty. Note: False is returned even when tape only contains basic blocks with no instructions. However, these are removed when - optimize is called. """ + optimize is called.""" if not self.purged: - self._is_empty = (len(self.basicblocks) == 0) + self._is_empty = len(self.basicblocks) == 0 return self._is_empty - def start_new_basicblock(self, scope=False, name=''): + def start_new_basicblock(self, scope=False, name=""): # use False because None means no scope if scope is False: scope = self.active_basicblock - suffix = '%s-%d' % (name, self.block_counter) + suffix = "%s-%d" % (name, self.block_counter) self.block_counter += 1 - sub = self.BasicBlock(self, self.name + '-' + suffix, scope) + sub = self.BasicBlock(self, self.name + "-" + suffix, scope) self.basicblocks.append(sub) self.active_basicblock = sub self.req_node.add_block(sub) - #print 'Compiling basic block', sub.name + # print 'Compiling basic block', sub.name def init_registers(self): self.reg_counter = RegType.create_dict(lambda: 0) - + def init_names(self, name): self.name = name - self.outfile = self.program.programs_dir + '/Bytecode/' + self.name + '.bc' + self.outfile = self.program.programs_dir + "/Bytecode/" + self.name + ".bc" + + def __len__(self): + if self.purged: + return self.size + else: + return sum(len(block) for block in self.basicblocks) def purge(self): + self.size = len(self) for block in self.basicblocks: block.purge() - self._is_empty = (len(self.basicblocks) == 0) + self._is_empty = len(self.basicblocks) == 0 del self.basicblocks del self.active_basicblock self.purged = True @@ -767,19 +824,29 @@ def wrapper(self, *args, **kwargs): if self.purged: return return function(self, *args, **kwargs) + return wrapper @unpurged def optimize(self, options): if len(self.basicblocks) == 0: - print('Tape %s is empty' % self.name) + print("Tape %s is empty" % self.name) return if self.if_states: - raise CompilerError('Unclosed if/else blocks') + print("Tracebacks for open blocks:") + for state in self.if_states: + try: + print(util.format_trace(state.caller)) + except AttributeError: + pass + print() + raise CompilerError("Unclosed if/else blocks, see tracebacks above") if self.program.verbose: - print('Processing tape', self.name, 'with %d blocks' % len(self.basicblocks)) + print( + "Processing tape", self.name, "with %d blocks" % len(self.basicblocks) + ) for block in self.basicblocks: al.determine_scope(block, options) @@ -787,41 +854,58 @@ def optimize(self, options): # merge open instructions # need to do this if there are several blocks if (options.merge_opens and self.merge_opens) or options.dead_code_elimination: - for i,block in enumerate(self.basicblocks): + for i, block in enumerate(self.basicblocks): if len(block.instructions) > 0 and self.program.verbose: - print('Processing basic block %s, %d/%d, %d instructions' % \ - (block.name, i, len(self.basicblocks), \ - len(block.instructions))) + print( + "Processing basic block %s, %d/%d, %d instructions" + % ( + block.name, + i, + len(self.basicblocks), + len(block.instructions), + ) + ) # the next call is necessary for allocation later even without merging - merger = al.Merger(block, options, \ - tuple(self.program.to_merge)) + merger = al.Merger(block, options, tuple(self.program.to_merge)) if options.dead_code_elimination: if len(block.instructions) > 1000000: - print('Eliminate dead code...') + print("Eliminate dead code...") merger.eliminate_dead_code() if options.merge_opens and self.merge_opens: if len(block.instructions) == 0: block.used_from_scope = util.set_by_id() continue if len(block.instructions) > 1000000: - print('Merging instructions...') + print("Merging instructions...") numrounds = merger.longest_paths_merge() block.n_rounds = numrounds block.n_to_merge = len(merger.open_nodes) + if options.verbose: + block.rounds = merger.req_num if merger.counter and self.program.verbose: - print('Block requires', \ - ', '.join('%d %s' % (y, x.__name__) \ - for x, y in list(merger.counter.items()))) + print( + "Block requires", + ", ".join( + "%d %s" % (y, x.__name__) + for x, y in list(merger.counter.items()) + ), + ) if merger.counter and self.program.verbose: - print('Block requires %s rounds' % \ - ', '.join('%d %s' % (y, x.__name__) \ - for x, y in list(merger.rounds.items()))) + print( + "Block requires %s rounds" + % ", ".join( + "%d %s" % (y, x.__name__) + for x, y in list(merger.rounds.items()) + ) + ) # free memory merger = None if options.dead_code_elimination: - block.instructions = [x for x in block.instructions if x is not None] + block.instructions = [ + x for x in block.instructions if x is not None + ] if not (options.merge_opens and self.merge_opens): - print('Not merging instructions in tape %s' % self.name) + print("Not merging instructions in tape %s" % self.name) if options.cisc: self.expand_cisc() @@ -846,19 +930,27 @@ def optimize(self, options): reg_counts = self.count_regs() if options.noreallocate: if self.program.verbose: - print('Tape register usage:', dict(reg_counts)) + print("Tape register usage:", dict(reg_counts)) else: if self.program.verbose: - print('Tape register usage before re-allocation:', - dict(reg_counts)) - print('modp: %d clear, %d secret' % (reg_counts[RegType.ClearModp], reg_counts[RegType.SecretModp])) - print('GF2N: %d clear, %d secret' % (reg_counts[RegType.ClearGF2N], reg_counts[RegType.SecretGF2N])) - print('Re-allocating...') + print("Tape register usage before re-allocation:", dict(reg_counts)) + print( + "modp: %d clear, %d secret" + % (reg_counts[RegType.ClearModp], reg_counts[RegType.SecretModp]) + ) + print( + "GF2N: %d clear, %d secret" + % (reg_counts[RegType.ClearGF2N], reg_counts[RegType.SecretGF2N]) + ) + print("Re-allocating...") allocator = al.StraightlineAllocator(REG_MAX, self.program) + def alloc(block): - for reg in sorted(block.used_from_scope, - key=lambda x: (x.reg_type, x.i)): + for reg in sorted( + block.used_from_scope, key=lambda x: (x.reg_type, x.i) + ): allocator.alloc_reg(reg, block.alloc_pool) + def alloc_loop(block): left = deque([block]) while left: @@ -866,73 +958,84 @@ def alloc_loop(block): alloc(block) for child in block.children: left.append(child) - for i,block in enumerate(reversed(self.basicblocks)): + + for i, block in enumerate(reversed(self.basicblocks)): if len(block.instructions) > 1000000: - print('Allocating %s, %d/%d' % \ - (block.name, i, len(self.basicblocks))) + print( + "Allocating %s, %d/%d" % (block.name, i, len(self.basicblocks)) + ) if block.exit_condition is not None: jump = block.exit_condition.get_relative_jump() - if isinstance(jump, int) and jump < 0 and \ - block.exit_block.scope is not None: + if ( + isinstance(jump, int) + and jump < 0 + and block.exit_block.scope is not None + ): alloc_loop(block.exit_block.scope) allocator.process(block.instructions, block.alloc_pool) allocator.finalize(options) if self.program.verbose: - print('Tape register usage:', dict(allocator.usage)) + print("Tape register usage:", dict(allocator.usage)) # offline data requirements if self.program.verbose: - print('Compile offline data requirements...') + print("Compile offline data requirements...") self.req_num = self.req_tree.aggregate() if self.program.verbose: - print('Tape requires', self.req_num) - for req,num in sorted(self.req_num.items()): - if num == float('inf') or num >= 2 ** 32: + print("Tape requires", self.req_num) + for req, num in sorted(self.req_num.items()): + if num == float("inf") or num >= 2**32: num = -1 if req[1] in data_types: self.basicblocks[-1].instructions.append( - Compiler.instructions.use(field_types[req[0]], \ - data_types[req[1]], num, \ - add_to_prog=False)) - elif req[1] == 'input': + Compiler.instructions.use( + field_types[req[0]], data_types[req[1]], num, add_to_prog=False + ) + ) + elif req[1] == "input": self.basicblocks[-1].instructions.append( - Compiler.instructions.use_inp(field_types[req[0]], \ - req[2], num, \ - add_to_prog=False)) - elif req[0] == 'modp': + Compiler.instructions.use_inp( + field_types[req[0]], req[2], num, add_to_prog=False + ) + ) + elif req[0] == "modp": self.basicblocks[-1].instructions.append( - Compiler.instructions.use_prep(req[1], num, \ - add_to_prog=False)) - elif req[0] == 'gf2n': + Compiler.instructions.use_prep(req[1], num, add_to_prog=False) + ) + elif req[0] == "gf2n": self.basicblocks[-1].instructions.append( - Compiler.instructions.guse_prep(req[1], num, \ - add_to_prog=False)) - elif req[0] == 'edabit': + Compiler.instructions.guse_prep(req[1], num, add_to_prog=False) + ) + elif req[0] == "edabit": self.basicblocks[-1].instructions.append( - Compiler.instructions.use_edabit(False, req[1], num, \ - add_to_prog=False)) - elif req[0] == 'sedabit': + Compiler.instructions.use_edabit( + False, req[1], num, add_to_prog=False + ) + ) + elif req[0] == "sedabit": self.basicblocks[-1].instructions.append( - Compiler.instructions.use_edabit(True, req[1], num, \ - add_to_prog=False)) - elif req[0] == 'matmul': + Compiler.instructions.use_edabit( + True, req[1], num, add_to_prog=False + ) + ) + elif req[0] == "matmul": self.basicblocks[-1].instructions.append( - Compiler.instructions.use_matmul(*req[1], num, \ - add_to_prog=False)) + Compiler.instructions.use_matmul(*req[1], num, add_to_prog=False) + ) if not self.is_empty(): # bit length requirement - for x in ('p', '2'): + for x in ("p", "2"): if self.req_bit_length[x]: bl = self.req_bit_length[x] if self.program.options.ring: bl = -int(self.program.options.ring) self.basicblocks[-1].instructions.append( - Compiler.instructions.reqbl(bl, - add_to_prog=False)) + Compiler.instructions.reqbl(bl, add_to_prog=False) + ) if self.program.verbose: - print('Tape requires prime bit length', self.req_bit_length['p']) - print('Tape requires galois bit length', self.req_bit_length['2']) + print("Tape requires prime bit length", self.req_bit_length["p"]) + print("Tape requires galois bit length", self.req_bit_length["2"]) @unpurged def expand_cisc(self): @@ -941,93 +1044,100 @@ def expand_cisc(self): @unpurged def _get_instructions(self): - return itertools.chain.\ - from_iterable(b.instructions for b in self.basicblocks) + return itertools.chain.from_iterable(b.instructions for b in self.basicblocks) @unpurged def get_encoding(self): - """ Get the encoding of the program, in human-readable format. """ + """Get the encoding of the program, in human-readable format.""" return [i.get_encoding() for i in self._get_instructions() if i is not None] - + @unpurged def get_bytes(self): - """ Get the byte encoding of the program as an actual string of bytes. """ - return b"".join(i.get_bytes() for i in self._get_instructions() if i is not None) - + """Get the byte encoding of the program as an actual string of bytes.""" + return b"".join( + i.get_bytes() for i in self._get_instructions() if i is not None + ) + @unpurged def write_encoding(self, filename): - """ Write the readable encoding to a file. """ - print('Writing to', filename) - f = open(filename, 'w') + """Write the readable encoding to a file.""" + print("Writing to", filename) + f = open(filename, "w") for line in self.get_encoding(): - f.write(str(line) + '\n') + f.write(str(line) + "\n") f.close() - + @unpurged def write_str(self, filename): - """ Write the sequence of instructions to a file. """ - print('Writing to', filename) - f = open(filename, 'w') + """Write the sequence of instructions to a file.""" + print("Writing to", filename) + f = open(filename, "w") n = 0 for block in self.basicblocks: if block.instructions: - f.write('# %s\n' % block.name) + f.write("# %s\n" % block.name) for line in block.instructions: - f.write('%s # %d\n' % (line, n)) + f.write("%s # %d\n" % (line, n)) n += 1 f.close() - + @unpurged def write_bytes(self, filename=None): - """ Write the program's byte encoding to a file. """ + """Write the program's byte encoding to a file.""" if filename is None: filename = self.outfile - if not filename.endswith('.bc'): - filename += '.bc' - if not 'Bytecode' in filename: - filename = self.program.programs_dir + '/Bytecode/' + filename - print('Writing to', filename) - f = open(filename, 'wb') + if not filename.endswith(".bc"): + filename += ".bc" + if "Bytecode" not in filename: + filename = self.program.programs_dir + "/Bytecode/" + filename + print("Writing to", filename) + f = open(filename, "wb") for i in self._get_instructions(): if i is not None: f.write(i.get_bytes()) f.close() - + def new_reg(self, reg_type, size=None): return self.Register(reg_type, self, size=size) - + def count_regs(self, reg_type=None): if reg_type is None: return self.reg_counter else: return self.reg_counter[reg_type] - + def __str__(self): return self.name class ReqNum(defaultdict): def __init__(self, init={}): super(Tape.ReqNum, self).__init__(lambda: 0, init) + def __add__(self, other): res = Tape.ReqNum() - for i,count in list(self.items()): - res[i] += count - for i,count in list(other.items()): + for i, count in list(self.items()): + res[i] += count + for i, count in list(other.items()): res[i] += count return res + def __mul__(self, other): res = Tape.ReqNum() for i in self: res[i] = other * self[i] return res + __rmul__ = __mul__ + def set_all(self, value): - if value == float('inf') and self['all', 'inv'] > 0: - print('Going to unknown from %s' % self) + if Program.prog.options.verbose and \ + value == float("inf") and self["all", "inv"] > 0: + print("Going to unknown from %s" % self) res = Tape.ReqNum() for i in self: res[i] = value return res + def max(self, other): res = Tape.ReqNum() for i in self: @@ -1035,82 +1145,105 @@ def max(self, other): for i in other: res[i] = max(self[i], other[i]) return res + def cost(self): - return sum(num * COST[req[0]][req[1]] for req,num in list(self.items()) \ - if req[1] != 'input' and req[0] != 'edabit') + return sum( + num * COST[req[0]][req[1]] + for req, num in list(self.items()) + if req[1] != "input" and req[0] != "edabit" + ) + def pretty(self): - t = lambda x: 'integer' if x == 'modp' else x + def t(x): + return "integer" if x == "modp" else x + res = [] for req, num in self.items(): domain = t(req[0]) - n = '%12.0f' % num - if req[1] == 'input': - res += ['%s %s inputs from player %d' \ - % (n, domain, req[2])] - elif domain.endswith('edabit'): - if domain == 'sedabit': - eda = 'strict edabits' + if num < 0: + num = float('inf') + n = "%12.0f" % num + if req[1] == "input": + res += ["%s %s inputs from player %d" % (n, domain, req[2])] + elif domain.endswith("edabit"): + if domain == "sedabit": + eda = "strict edabits" else: - eda = 'loose edabits' - res += ['%s %s of length %d' % (n, eda, req[1])] - elif domain == 'matmul': - res += ['%s matrix multiplications (%dx%d * %dx%d)' % - (n, req[1][0], req[1][1], req[1][1], req[1][2])] - elif req[0] != 'all': - res += ['%s %s %ss' % (n, domain, req[1])] - if self['all','round']: - res += ['% 12.0f virtual machine rounds' % self['all','round']] + eda = "loose edabits" + res += ["%s %s of length %d" % (n, eda, req[1])] + elif domain == "matmul": + res += [ + "%s matrix multiplications (%dx%d * %dx%d)" + % (n, req[1][0], req[1][1], req[1][1], req[1][2]) + ] + elif req[0] != "all": + res += ["%s %s %ss" % (n, domain, req[1])] + if self["all", "round"]: + res += ["% 12.0f virtual machine rounds" % self["all", "round"]] return res + def __str__(self): - return ', '.join(self.pretty()) + return ", ".join(self.pretty()) + def __repr__(self): return repr(dict(self)) class ReqNode(object): - __slots__ = ['num', 'children', 'name', 'blocks'] + __slots__ = ["num", "children", "name", "blocks"] + def __init__(self, name): self.children = [] self.name = name self.blocks = [] + def aggregate(self, *args): self.num = Tape.ReqNum() for block in self.blocks: block.add_usage(self) - res = reduce(lambda x,y: x + y.aggregate(self.name), - self.children, self.num) + res = reduce( + lambda x, y: x + y.aggregate(self.name), self.children, self.num + ) return res + def increment(self, data_type, num=1): self.num[data_type] += num + def add_block(self, block): self.blocks.append(block) class ReqChild(object): - __slots__ = ['aggregator', 'nodes', 'parent'] + __slots__ = ["aggregator", "nodes", "parent"] + def __init__(self, aggregator, parent): self.aggregator = aggregator self.nodes = [] self.parent = parent + def aggregate(self, name): res = self.aggregator([node.aggregate() for node in self.nodes]) try: n_reps = self.aggregator([1]) - n_rounds = res['all', 'round'] - n_invs = res['all', 'inv'] + n_rounds = res["all", "round"] + n_invs = res["all", "inv"] if (n_invs / n_rounds) * 1000 < n_reps: - print(self.nodes[0].blocks[0].name, 'blowing up rounds: ', \ - '(%d / %d) ** 3 < %d' % (n_rounds, n_reps, n_invs)) - except: + print( + self.nodes[0].blocks[0].name, + "blowing up rounds: ", + "(%d / %d) ** 3 < %d" % (n_rounds, n_reps, n_invs), + ) + except Exception: pass return res + def add_node(self, tape, name): new_node = Tape.ReqNode(name) self.nodes.append(new_node) tape.req_node = new_node - def open_scope(self, aggregator, scope=False, name=''): + def open_scope(self, aggregator, scope=False, name=""): child = self.ReqChild(aggregator, self.req_node) self.req_node.children.append(child) - child.add_node(self, '%s-%d' % (name, len(self.basicblocks))) + child.add_node(self, "%s-%d" % (name, len(self.basicblocks))) self.start_new_basicblock(name=name) return child @@ -1118,21 +1251,21 @@ def close_scope(self, outer_scope, parent_req_node, name): self.req_node = parent_req_node self.start_new_basicblock(outer_scope, name) - def require_bit_length(self, bit_length, t='p'): - if t == 'p': + def require_bit_length(self, bit_length, t="p"): + if t == "p": if self.program.prime: - if (bit_length >= self.program.prime.bit_length() - 1): + if bit_length >= self.program.prime.bit_length() - 1: raise CompilerError( - 'required bit length %d too much for %d' % \ - (bit_length, self.program.prime)) - self.req_bit_length[t] = max(bit_length + 1, \ - self.req_bit_length[t]) + "required bit length %d too much for %d" + % (bit_length, self.program.prime) + ) + self.req_bit_length[t] = max(bit_length + 1, self.req_bit_length[t]) else: self.req_bit_length[t] = max(bit_length, self.req_bit_length) @staticmethod def read_instructions(tapename): - tape = open('Programs/Bytecode/%s.bc' % tapename, 'rb') + tape = open("Programs/Bytecode/%s.bc" % tapename, "rb") while tape.peek(): yield inst_base.ParsedInstruction(tape) @@ -1140,23 +1273,35 @@ class _no_truth(object): __slots__ = [] def __bool__(self): - raise CompilerError('Cannot derive truth value from register, ' - "consider using 'compile.py -l'") + raise CompilerError( + "Cannot derive truth value from register, " + "consider using 'compile.py -l'" + ) class Register(_no_truth): """ Class for creating new registers. The register's index is automatically assigned based on the block's reg_counter dictionary. """ - __slots__ = ["reg_type", "program", "absolute_i", "relative_i", \ - "size", "vector", "vectorbase", "caller", \ - "can_eliminate", "duplicates"] - maximum_size = 2 ** (32 - inst_base.Instruction.code_length) - 1 + + __slots__ = [ + "reg_type", + "program", + "absolute_i", + "relative_i", + "size", + "vector", + "vectorbase", + "caller", + "can_eliminate", + "duplicates", + ] + maximum_size = 2 ** (64 - inst_base.Instruction.code_length) - 1 def __init__(self, reg_type, program, size=None, i=None): - """ Creates a new register. - reg_type must be one of those defined in RegType. """ - if Compiler.instructions_base.get_global_instruction_type() == 'gf2n': + """Creates a new register. + reg_type must be one of those defined in RegType.""" + if Compiler.instructions_base.get_global_instruction_type() == "gf2n": if reg_type == RegType.ClearModp: reg_type = RegType.ClearGF2N elif reg_type == RegType.SecretModp: @@ -1166,7 +1311,7 @@ def __init__(self, reg_type, program, size=None, i=None): if size is None: size = Compiler.instructions_base.get_global_vector_size() if size is not None and size > self.maximum_size: - raise CompilerError('vector too large: %d' % size) + raise CompilerError("vector too large: %d" % size) self.size = size self.vectorbase = self self.relative_i = 0 @@ -1176,7 +1321,7 @@ def __init__(self, reg_type, program, size=None, i=None): self.i = program.reg_counter[reg_type] program.reg_counter[reg_type] += size else: - self.i = float('inf') + self.i = float("inf") self.vector = [] self.can_eliminate = True self.duplicates = util.set_by_id([self]) @@ -1197,13 +1342,14 @@ def set_size(self, size): if self.size == size: return else: - raise CompilerError('Mismatch of instruction and register size:' - ' %s != %s' % (self.size, size)) + raise CompilerError( + "Mismatch of instruction and register size:" + " %s != %s" % (self.size, size) + ) def set_vectorbase(self, vectorbase): if self.vectorbase is not self: - raise CompilerError('Cannot assign one register' \ - 'to several vectors') + raise CompilerError("Cannot assign one register" "to several vectors") self.relative_i = self.i - vectorbase.i self.vectorbase = vectorbase @@ -1211,7 +1357,7 @@ def _new_by_number(self, i, size=1): return Tape.Register(self.reg_type, self.program, size=size, i=i) def get_vector(self, base=0, size=None): - if size == None: + if size is None: size = self.size if base == 0 and size == self.size: return self @@ -1220,7 +1366,7 @@ def get_vector(self, base=0, size=None): res = self._new_by_number(self.i + base, size=size) res.set_vectorbase(self) self.create_vector_elements() - res.vector = self.vector[base:base+size] + res.vector = self.vector[base : base + size] return res def create_vector_elements(self): @@ -1256,16 +1402,34 @@ def link(self, other): for dup in self.duplicates: dup.duplicates = self.duplicates + def update(self, other): + """ + Update register. Useful in loops like + :py:func:`~Compiler.library.for_range`. + + :param other: any convertible type + + """ + other = self.conv(other) + if self.program != other.program: + raise CompilerError( + 'cannot update register with one from another thread') + self.link(other) + @property def is_gf2n(self): - return self.reg_type == RegType.ClearGF2N or \ - self.reg_type == RegType.SecretGF2N - + return ( + self.reg_type == RegType.ClearGF2N + or self.reg_type == RegType.SecretGF2N + ) + @property def is_clear(self): - return self.reg_type == RegType.ClearModp or \ - self.reg_type == RegType.ClearGF2N or \ - self.reg_type == RegType.ClearInt + return ( + self.reg_type == RegType.ClearModp + or self.reg_type == RegType.ClearGF2N + or self.reg_type == RegType.ClearInt + ) def __str__(self): return self.reg_type + str(self.i) diff --git a/Compiler/sorting.py b/Compiler/sorting.py old mode 100755 new mode 100644 index 248b3ea..fc619b7 --- a/Compiler/sorting.py +++ b/Compiler/sorting.py @@ -3,12 +3,7 @@ def dest_comp(B): Bt = B.transpose() - Bt_flat = Bt.get_vector() - St_flat = Bt.value_type.Array(len(Bt_flat)) - St_flat.assign(Bt_flat) - @library.for_range(len(St_flat) - 1) - def _(i): - St_flat[i + 1] = St_flat[i + 1] + St_flat[i] + St_flat = Bt.get_vector().prefix_sum() Tt_flat = Bt.get_vector() * St_flat.get_vector() Tt = types.Matrix(*Bt.sizes, B.value_type) Tt.assign_vector(Tt_flat) @@ -37,8 +32,14 @@ def radix_sort(k, D, n_bits=None, signed=True): bs = types.Matrix.create_from(k.get_vector().bit_decompose(n_bits)) if signed and len(bs) > 1: bs[-1][:] = bs[-1][:].bit_not() - B = types.sint.Matrix(len(k), 2) - h = types.Array.create_from(types.sint(types.regint.inc(len(k)))) + radix_sort_from_matrix(bs, D) + +def radix_sort_from_matrix(bs, D): + n = len(D) + for b in bs: + assert(len(b) == n) + B = types.sint.Matrix(n, 2) + h = types.Array.create_from(types.sint(types.regint.inc(n))) @library.for_range(len(bs)) def _(i): b = bs[i] diff --git a/Compiler/tools.py b/Compiler/tools.py old mode 100755 new mode 100644 diff --git a/Compiler/types.py b/Compiler/types.py old mode 100755 new mode 100644 index 098f493..b55a1ba --- a/Compiler/types.py +++ b/Compiler/types.py @@ -127,8 +127,14 @@ def vectorized_operation(self, *args, **kwargs): if (isinstance(args[0], Tape.Register) or isinstance(args[0], sfloat)) \ and not isinstance(args[0], bits) \ and args[0].size != self.size: - raise VectorMismatch('Different vector sizes of operands: %d/%d' - % (self.size, args[0].size)) + if min(args[0].size, self.size) == 1: + size = max(args[0].size, self.size) + self = self.expand_to_vector(size) + args = list(args) + args[0] = args[0].expand_to_vector(size) + else: + raise VectorMismatch('Different vector sizes of operands: %d/%d' + % (self.size, args[0].size)) set_global_vector_size(self.size) try: res = operation(self, *args, **kwargs) @@ -214,6 +220,14 @@ def read_mem_operation(self, *args, **kwargs): copy_doc(read_mem_operation, operation) return read_mem_operation +def type_comp(operation): + def type_check(self, other, *args, **kwargs): + if not isinstance(other, (type(self), int, regint, self.clear_type)): + return NotImplemented + return operation(self, other, *args, **kwargs) + copy_doc(type_check, operation) + return type_check + def inputmixed(*args): # helper to cover both cases if isinstance(args[-1], int): @@ -249,8 +263,11 @@ def __mul__(self, other): try: return self.mul(other) except VectorMismatch: - # try reverse multiplication - return NotImplemented + if type(self) != type(other) and 1 in (self.size, other.size): + # try reverse multiplication + return NotImplemented + else: + raise __radd__ = __add__ __rmul__ = __mul__ @@ -324,6 +341,9 @@ def __abs__(self): def popcnt_bits(bits): return sum(bits) + def zero_if_not(self, condition): + return condition * self + class _int(Tape._no_truth): """ Integer functionality. """ @@ -452,6 +472,10 @@ def carry_out(self, a, b): s = a ^ b return a ^ (s & (self ^ a)) + def cond_swap(self, a, b): + prod = self * (a ^ b) + return a ^ prod, b ^ prod + class _gf2n(_bit): """ :math:`\mathrm{GF}(2^n)` functionality. """ @@ -733,7 +757,14 @@ def expand_to_vector(self, size=None): self.mov(res[i], self) return res -class _clear(_register): +class _arithmetic_register(_register): + """ Arithmetic circuit type. """ + def __init__(self, *args, **kwargs): + if program.options.garbled: + raise CompilerError('functionality only available in arithmetic circuits') + super(_arithmetic_register, self).__init__(*args, **kwargs) + +class _clear(_arithmetic_register): """ Clear domain-dependent type. """ __slots__ = [] mov = staticmethod(movc) @@ -1069,6 +1100,8 @@ def __eq__(self, other): def __ne__(self, other): return 1 - (self == other) + equal = lambda self, other, *args, **kwargs: self.__eq__(other) + def __lshift__(self, other): """ Clear left shift. @@ -1165,7 +1198,7 @@ def digest(self, num_bytes): def print_if(self, string): """ Output if value is non-zero. - :param string: Python string """ + :param string: bytearray """ cond_print_str(self, string) def output_if(self, cond): @@ -1638,7 +1671,7 @@ def output_if(self, cond): def _condition(self): if program.options.binary: - from GC.types import cbits + from .GC.types import cbits return cbits.get_type(64)(self) else: return cint(self) @@ -1651,6 +1684,8 @@ def binary_output(self, player=None): """ if player == None: player = -1 + if not util.is_constant(player): + raise CompilerError('Player number must be known at compile time') intoutput(player, self) class localint(Tape._no_truth): @@ -1818,7 +1853,7 @@ def bit_decompose(self, bit_length): res += x.bit_decompose(64) return res[:bit_length] -class _secret(_register, _secret_structure): +class _secret(_arithmetic_register, _secret_structure): __slots__ = [] mov = staticmethod(set_instruction_type(movs)) @@ -2074,12 +2109,15 @@ def __rsub__(self, other): return self.secret_op(other, subs, submr, subsfi, True) __rsub__.__doc__ = __sub__.__doc__ - @vectorize def __truediv__(self, other): """ Secret field division. :param other: any compatible type """ - return self * (self.clear_type(1) / other) + try: + one = self.clear_type(1, size=other.size) + except AttributeError: + one = self.clear_type(1) + return self * (one / other) @vectorize def __rtruediv__(self, other): @@ -2106,12 +2144,12 @@ def secure_shuffle(self, unit_size=1): @set_instruction_type @vectorize - def reveal(self): + def reveal(self, check=True): """ Reveal secret value publicly. :rtype: relevant clear type """ res = self.clear_type() - asm_open(res, self) + asm_open(check, res, self) return res @set_instruction_type @@ -2126,6 +2164,21 @@ def reveal_to(self, player): res = personal(player, masked.reveal() - mask[1]) return res + @set_instruction_type + @vectorize + def raw_right_shift(self, length): + """ Local right shift in supported protocols. + In integer-like protocols, the output is potentially off by one. + + :param length: number of bits + """ + res = type(self)() + shrsi(res, self, length) + return res + + def raw_mod2m(self, m): + return self - (self.raw_right_shift(m) << m) + class sint(_secret, _int): """ @@ -2144,9 +2197,7 @@ class sint(_secret, _int): signed integer in a restricted range, see below. The same holds for ``abs()``, shift operators (``<<, >>``), modulo (``%``), and exponentation (``**``). Modulo only works if the right-hand - operator is a compile-time power of two, and exponentiation only - works if the base is two or if the exponent is a compile-time - integer. + operator is a compile-time power of two. Most non-linear operations require compile-time parameters for bit length and statistical security. They default to the global @@ -2319,6 +2370,9 @@ def reveal_to_clients(cls, clients, values): n_clients = clients.length else: n_clients = len(clients) + set_global_vector_size(1) + clients = Array.create_from(regint.conv(clients)) + reset_global_vector_size() @library.for_range(n_clients) def loop_body(i): @@ -2441,6 +2495,7 @@ def __abs__(self): return (self >= 0).if_else(self, -self) @read_mem_value + @type_comp @vectorize def __lt__(self, other, bit_length=None, security=None): """ Secret comparison (signed). @@ -2455,6 +2510,7 @@ def __lt__(self, other, bit_length=None, security=None): return res @read_mem_value + @type_comp @vectorize def __gt__(self, other, bit_length=None, security=None): res = sintbit() @@ -2463,18 +2519,25 @@ def __gt__(self, other, bit_length=None, security=None): security or program.security) return res + @read_mem_value + @type_comp def __le__(self, other, bit_length=None, security=None): return 1 - self.greater_than(other, bit_length, security) + @read_mem_value + @type_comp def __ge__(self, other, bit_length=None, security=None): return 1 - self.less_than(other, bit_length, security) @read_mem_value + @type_comp @vectorize def __eq__(self, other, bit_length=None, security=None): return floatingpoint.EQZ(self - other, bit_length or program.bit_length, security or program.security) + @read_mem_value + @type_comp def __ne__(self, other, bit_length=None, security=None): return 1 - self.equal(other, bit_length, security) @@ -2645,12 +2708,21 @@ def int_div(self, other, bit_length=None, security=None): comparison.Trunc(res, tmp, 2 * k, k, kappa, True) return res + @vectorize + def int_mod(self, other, bit_length=None): + """ Secret integer modulo. + + :param other: sint + :param bit_length: bit length of input (default: global bit length) + """ + return self - other * self.int_div(other, bit_length=bit_length) + def trunc_zeros(self, n_zeros, bit_length=None, signed=True): bit_length = bit_length or program.bit_length return comparison.TruncZeros(self, bit_length, n_zeros, signed) @staticmethod - def two_power(n): + def two_power(n, size=None): return floatingpoint.two_power(n) def split_to_n_summands(self, length, n): @@ -2668,16 +2740,6 @@ def split_to_two_summands(self, length, get_carry=False): columns = self.split_to_n_summands(length, n) return _bitint.wallace_tree_without_finish(columns, get_carry) - @vectorize - def raw_right_shift(self, length): - res = sint() - shrsi(res, self, length) - return res - - def raw_mod2m(self, m): - return self - (self.raw_right_shift(m) << m) - - @vectorize def reveal_to(self, player): """ Reveal secret value to :py:obj:`player`. @@ -2685,13 +2747,14 @@ def reveal_to(self, player): :returns: :py:class:`personal` """ if not util.is_constant(player): - secret_mask = sint() - player_mask = cint() - inputmaskreg(secret_mask, player_mask, regint.conv(player)) + secret_mask = sint(size=self.size) + player_mask = cint(size=self.size) + inputmaskreg(secret_mask, player_mask, + regint.conv(player).expand_to_vector(self.size)) return personal(player, - (self + secret_mask).reveal() - player_mask) + (self + secret_mask).reveal(False) - player_mask) else: - res = personal(player, self.clear_type()) + res = personal(player, self.clear_type(size=self.size)) privateoutput(self.size, player, res._v, self) return res @@ -2763,6 +2826,30 @@ def secure_permute(self, shuffle, unit_size=1, reverse=False): applyshuffle(res, self, unit_size, shuffle, reverse) return res + def inverse_permutation(self): + if program.use_invperm(): + # If enabled, we use the low-level INVPERM instruction. + # This instruction has only been implemented for a semi-honest two-party environement. + res = sint(size=self.size) + inverse_permutation(res, self) + else: + shuffle = sint.get_secure_shuffle(len(self)) + shuffled = self.secure_permute(shuffle).reveal() + idx = Array.create_from(shuffled) + res = Array.create_from(sint(regint.inc(len(self)))) + res.secure_permute(shuffle, reverse=False) + res.assign_slice_vector(idx, res.get_vector()) + library.break_point() + res = res.get_vector() + return res + + @vectorize + def prefix_sum(self): + """ Prefix sum. """ + res = sint() + prefixsums(res, self) + return res + class sintbit(sint): """ :py:class:`sint` holding a bit, supporting binary operations (``&, |, ^``). """ @@ -2811,7 +2898,9 @@ def __xor__(self, other): elif util.is_zero(other): return self elif util.is_one(other): - return 1 + res = sintbit() + submr(res, cint(1), self) + return res else: return NotImplemented @@ -2824,6 +2913,10 @@ def __rsub__(self, other): else: return super(sintbit, self).__rsub__(other) + __rand__ = __and__ + __rxor__ = __xor__ + __ror__ = __or__ + class sgf2n(_secret, _gf2n): """ Secret :math:`\mathrm{GF}(2^n)` value. n is chosen at runtime. A @@ -2841,6 +2934,7 @@ class sgf2n(_secret, _gf2n): instruction_type = 'gf2n' clear_type = cgf2n reg_type = 'sg' + long_one = staticmethod(lambda: 1) @classmethod def get_type(cls, length): @@ -2990,6 +3084,7 @@ class _bitint(Tape._no_truth): bits = None log_rounds = False linear_rounds = False + comp_result = staticmethod(lambda x: x) @staticmethod def half_adder(a, b): @@ -3209,12 +3304,16 @@ def wallace_reduction(cls, a, b, c, get_carry=True): del carries[-1] return sums, carries + def expand(self, other): + a = self.bit_decompose() + b = util.bit_decompose(other, self.n_bits) + return a, b + def __sub__(self, other): if type(other) == sgf2n: raise CompilerError('Unclear subtraction') - a = self.bit_decompose() - b = util.bit_decompose(other, self.n_bits) from util import bit_not, bit_and, bit_xor + a, b = self.expand(other) n = 1 for x in (a + b): try: @@ -3261,8 +3360,7 @@ def prep_comparison(a, b): a[-1], b[-1] = b[-1], a[-1] def comparison(self, other, const_rounds=False, index=None): - a = self.bit_decompose() - b = util.bit_decompose(other, self.n_bits) + a, b = self.expand(other) self.prep_comparison(a, b) if const_rounds: return self.get_highest_different_bits(a, b, index) @@ -3272,30 +3370,33 @@ def comparison(self, other, const_rounds=False, index=None): def __lt__(self, other): if program.options.comparison == 'log': x, not_equal = self.comparison(other) - return util.if_else(not_equal, x, 0) + res = util.if_else(not_equal, x, 0) else: - return self.comparison(other, True, 1) + res = self.comparison(other, True, 1) + return self.comp_result(res) def __le__(self, other): if program.options.comparison == 'log': x, not_equal = self.comparison(other) - return util.if_else(not_equal, x, 1) + res = util.if_else(not_equal, x, x.long_one()) else: - return 1 - self.comparison(other, True, 0) + res = self.comparison(other, True, 0).bit_not() + return self.comp_result(res) def __ge__(self, other): - return 1 - (self < other) + return (self < other).bit_not() def __gt__(self, other): - return 1 - (self <= other) + return (self <= other).bit_not() def __eq__(self, other, bit_length=None, security=None): diff = self ^ other - diff_bits = [1 - x for x in diff.bit_decompose()[:bit_length]] - return floatingpoint.KMul(diff_bits) + diff_bits = [x.bit_not() for x in diff.bit_decompose()[:bit_length]] + return self.comp_result(util.tree_reduce(lambda x, y: x.bit_and(y), + diff_bits)) def __ne__(self, other): - return 1 - (self == other) + return (self == other).bit_not() equal = __eq__ @@ -3849,7 +3950,6 @@ def print_plain(self): def output_if(self, cond): cond_print_plain(cint.conv(cond), self.v, cint(-self.f, size=self.size)) - @vectorize def binary_output(self, player=None): """ Write double-precision floating-point number to ``Player-Data/Binary-Output-P-``. @@ -3858,7 +3958,11 @@ def binary_output(self, player=None): """ if player == None: player = -1 + if not util.is_constant(player): + raise CompilerError('Player number must be known at compile time') + set_global_vector_size(self.size) floatoutput(player, self.v, cint(-self.f), cint(0), cint(0)) + reset_global_vector_size() class _single(_number, _secret_structure): """ Representation as single integer preserving the order """ @@ -3878,6 +3982,8 @@ def receive_from_client(cls, n, client_id, message_type=ClientMessageType.NoType :param n: number of inputs (int) :param client_id: regint :param size: vector size (default 1) + :returns: list of length ``n`` + """ sint_inputs = cls.int_type.receive_from_client(n, client_id, message_type) @@ -3915,6 +4021,8 @@ def load_mem(cls, address, mem_type=None): def conv(cls, other): if isinstance(other, cls): return other + elif isinstance(other, (list, tuple)): + return type(other)(cls.conv(x) for x in other) else: try: return cls.from_sint(other) @@ -4092,6 +4200,7 @@ def get_vector(self): class _fix(_single): """ Secret fixed point type. """ __slots__ = ['v', 'f', 'k'] + is_clear = False def set_precision(cls, f, k = None): cls.f = f @@ -4153,7 +4262,7 @@ def conv(cls, other): if isinstance(other, _fix) and (cls.k, cls.f) == (other.k, other.f): return other else: - return cls(other) + return super(_fix, cls).conv(other) @classmethod def _new(cls, other, k=None, f=None): @@ -4317,6 +4426,18 @@ def bit_decompose(self, n_bits=None): """ Bit decomposition. """ return self.v.bit_decompose(n_bits or self.k) + def update(self, other): + """ + Update register. Useful in loops like + :py:func:`~Compiler.library.for_range`. + + :param other: any convertible type + + """ + other = self.conv(other) + assert self.f == other.f + self.v.update(other.v) + class sfix(_fix): """ Secret fixed-point number represented as secret integer, by multiplying with ``2^f`` and then rounding. See :py:class:`sint` @@ -4449,6 +4570,9 @@ def secure_permute(self, *args, **kwargs): return self._new(self.v.secure_permute(*args, **kwargs), k=self.k, f=self.f) + def prefix_sum(self): + return self._new(self.v.prefix_sum(), k=self.k, f=self.f) + class unreduced_sfix(_single): int_type = sint @@ -4705,6 +4829,8 @@ class sfloat(_number, _secret_structure): returning :py:class:`sint`. The other operand can be any of sint/cfix/regint/cint/int/float. + This data type only works with arithmetic computation. + :param v: initialization (sfloat/sfix/float/int/sint/cint/regint) """ __slots__ = ['v', 'p', 'z', 's', 'size'] @@ -4803,6 +4929,9 @@ def get_input_from(cls, player): @vectorize_init @read_mem_value def __init__(self, v, p=None, z=None, s=None, size=None): + if program.options.binary: + raise CompilerError( + 'floating-point operations not supported with binary circuits') self.size = get_global_vector_size() if p is None: if isinstance(v, sfloat): @@ -5057,10 +5186,12 @@ def __ge__(self, other): """ Secret floating-point comparison. """ return 1 - (self < other) + @vectorize def __gt__(self, other): """ Secret floating-point comparison. """ return self.conv(other) < self + @vectorize def __le__(self, other): """ Secret floating-point comparison. """ return self.conv(other) >= self @@ -5189,13 +5320,23 @@ class Array(_vectorizable): a[:] += b[:] """ + check_indices = True + @classmethod def create_from(cls, l): """ Convert Python iterator or vector to array. Basic type will be taken from first element, further elements must to be convertible to - that. """ + that. + + :param l: Python iterable or register vector + :returns: :py:class:`Array` of appropriate type containing the contents + of :py:obj:`l` + + """ if isinstance(l, cls): - return l + res = l.same_shape() + res[:] = l[:] + return res if isinstance(l, _number): tmp = l t = type(l) @@ -5216,7 +5357,6 @@ def __init__(self, length, value_type, address=None, debug=None, alloc=True): self.debug = debug self.creator_tape = program.curr_tape self.sink = None - self.check_indices = True if alloc: self.alloc() @@ -5251,7 +5391,8 @@ def get_address(self, index, size=None): # length can be None for single-element arrays length = 0 base = self.address + index * self.value_type.mem_size() - if size is not None and isinstance(base, _register): + if size is not None and isinstance(base, _register) \ + and not issubclass(self.value_type, _vec): base = regint._expand_address(base, size) self.address_cache[program.curr_block, key] = \ util.untuplify([base + i * length \ @@ -5316,7 +5457,7 @@ def maybe_get(self, condition, index): :param condition: 0/1 (regint/cint/int) :param index: regint/cint/int """ - return condition * self[condition * index] + return self[condition * index].zero_if_not(condition) def maybe_set(self, condition, index, value): """ Change entry if condition is true. @@ -5346,13 +5487,16 @@ def _load(self, address): return self.value_type.load_mem(address) def _store(self, value, address): - self.value_type.conv(value).store_in_mem(address) + tmp = self.value_type.conv(value) + if not isinstance(tmp, _vec) and tmp.size != self.value_type.mem_size(): + raise CompilerError('size mismatch in array assignment') + tmp.store_in_mem(address) def __len__(self): return self.length def total_size(self): - return len(self) * self.value_type.n_elements() + return self.length * self.value_type.n_elements() def __iter__(self): for i in range(self.length): @@ -5417,6 +5561,12 @@ def get_vector(self, base=0, size=None): get_part_vector = get_vector + def get_reverse_vector(self): + """ Return vector with content in reverse order. """ + size = self.length + address = regint.inc(size, size - 1, -1) + return self.value_type.load_mem(self.address + address, size=size) + def get_part(self, base, size): """ Part array. @@ -5476,7 +5626,8 @@ def input_from(self, player, budget=None, raw=False): try: self.assign(input_from(player, size=len(self))) except (TypeError, CompilerError): - @library.for_range_opt(len(self), budget=budget) + print (budget) + @library.for_range_opt(self.length, budget=budget) def _(i): self[i] = input_from(player) @@ -5515,7 +5666,6 @@ def __sub__(self, other): """ Vector subtraction. :param other: vector or container of same length and type that supports operations with type of this array """ - assert len(self) == len(other) return self.get_vector() - other def __mul__(self, value): @@ -5578,7 +5728,7 @@ def reveal(self): """ Reveal the whole array. :returns: Array of relevant clear type. """ - return Array.create_from(x.reveal() for x in self) + return Array.create_from(self.get_vector().reveal()) def reveal_list(self): """ Reveal as list. """ @@ -5638,13 +5788,14 @@ def sort(self, n_threads=None, batcher=False, n_bits=None): :param batcher: use Batcher's odd-even mergesort in any case :param n_bits: number of bits in keys (default: global bit length) """ - if batcher or self.value_type.n_elements() > 1: + if batcher or self.value_type.n_elements() > 1 or \ + program.options.binary: library.loopy_odd_even_merge_sort(self, n_threads=n_threads) else: if n_threads or 1 > 1: raise CompilerError('multi-threaded sorting only implemented ' 'with Batcher\'s odd-even mergesort') - import sorting + from . import sorting sorting.radix_sort(self, self, n_bits=n_bits) def Array(self, size): @@ -5701,7 +5852,8 @@ def __getitem__(self, index): self.sub_cache[key] = \ Array(self.sizes[1], self.value_type, \ self.address + index * self.sizes[1] * - self.value_type.n_elements(), \ + self.value_type.n_elements() * \ + self.value_type.mem_size(), \ debug=self.debug) else: self.sub_cache[key] = \ @@ -5730,6 +5882,13 @@ def __iter__(self): def to_array(self): return Array(self.total_size(), self.value_type, address=self.address) + def maybe_get(self, condition, index): + return self[condition * index] + + def maybe_set(self, condition, index, value): + for i, x in enumerate(value): + self.maybe_get(condition, index).maybe_set(condition, i, x) + def assign_all(self, value): """ Assign the same value to all entries. @@ -5879,17 +6038,16 @@ def input_from(self, player, budget=None, raw=False): """ Fill with inputs from player if supported by type. :param player: public (regint/cint/int) """ - budget = budget or Tape.Register.maximum_size - if (self.total_size() < budget) and \ - self.value_type.n_elements() == 1: + if util.is_constant(self.total_size()) and \ + self.value_type.n_elements() == 1 and \ + self.value_type.mem_size() == 1: if raw or program.always_raw(): input_from = self.value_type.get_raw_input_from else: input_from = self.value_type.get_input_from self.assign_vector(input_from(player, size=self.total_size())) else: - @library.for_range_opt(self.sizes[0], - budget=budget / self[0].total_size()) + @library.for_range_opt(self.sizes[0], budget=budget) def _(i): self[i].input_from(player, budget=budget, raw=raw) @@ -6042,6 +6200,7 @@ def _(base, size): assert n_threads is None if max(res_matrix.sizes) > 1000: raise AttributeError() + self.value_type.matrix_mul A = self.get_vector() B = other.get_vector() res_matrix.assign_vector( @@ -6054,12 +6213,12 @@ def _(i): try: res_matrix[i] = self.value_type.row_matrix_mul( self[i], other, res_params) - except AttributeError: + except (AttributeError, CompilerError): # fallback for binary circuits - @library.for_range(other.sizes[1]) + @library.for_range_opt(other.sizes[1]) def _(j): res_matrix[i][j] = 0 - @library.for_range(self.sizes[1]) + @library.for_range_opt(self.sizes[1]) def _(k): res_matrix[i][j] += self[i][k] * other[k][j] return res_matrix @@ -6178,13 +6337,7 @@ def _(i): res[i] = self.direct_mul_trans(other, indices=indices) def direct_mul_to_matrix(self, other): - """ Matrix multiplication in the virtual machine. - - :param self: :py:class:`Matrix` / 2-dimensional :py:class:`MultiArray` - :param other: :py:class:`Matrix` / 2-dimensional :py:class:`MultiArray` - :returns: :py:obj:`Matrix` - - """ + # Obsolete. Use dot(). res = self.value_type.Matrix(self.sizes[0], other.sizes[1]) res.assign_vector(self.direct_mul(other)) return res @@ -6274,13 +6427,15 @@ def transpose(self): res = Matrix(self.sizes[1], self.sizes[0], self.value_type) library.break_point() if self.value_type.n_elements() == 1: - @library.for_range_opt(self.sizes[0]) - def _(j): - res.set_column(j, self[j][:]) + nr = self.sizes[1] + nc = self.sizes[0] + a = regint.inc(nr * nc, 0, nr, 1, nc) + b = regint.inc(nr * nc, 0, 1, nc) + res[:] = self.value_type.load_mem(self.address + a + b) else: - @library.for_range_opt(self.sizes[1]) + @library.for_range_opt(self.sizes[1], budget=100) def _(i): - @library.for_range_opt(self.sizes[0]) + @library.for_range_opt(self.sizes[0], budget=100) def _(j): res[i][j] = self[j][i] library.break_point() @@ -6317,16 +6472,21 @@ def sort(self, key_indices=None, n_bits=None): :param n_bits: number of bits in keys (default: global bit length) """ + if program.options.binary: + assert key_indices is None + assert len(self.sizes) == 2 + library.loopy_odd_even_merge_sort(self) + return if key_indices is None: key_indices = (0,) * (len(self.sizes) - 1) key_indices = (None,) + util.tuplify(key_indices) - import sorting + from . import sorting keys = self.get_vector_by_indices(*key_indices) sorting.radix_sort(keys, self, n_bits=n_bits) def randomize(self, *args): """ Randomize according to data type. """ - if self.total_size() < program.options.budget: + if self.total_size() < program.budget: self.assign_vector( self.value_type.get_random(*args, size=self.total_size())) else: @@ -6334,6 +6494,12 @@ def randomize(self, *args): def _(i): self[i].randomize(*args) + def reveal(self): + """ Reveal to :py:obj:`MultiArray` of same shape. """ + res = MultiArray(self.sizes, self.value_type.clear_type) + res[:] = self.get_vector().reveal() + return res + def reveal_list(self): """ Reveal as list. """ return list(self.get_vector().reveal()) @@ -6354,7 +6520,8 @@ def print_reveal_nested(self, end='\n'): :param end: string to print after (default: line break) """ - if self.total_size() < program.options.budget: + if util.is_constant(self.total_size()) and \ + self.total_size() < program.budget: library.print_str('%s' + end, self.reveal_nested()) else: library.print_str('[') @@ -6443,7 +6610,7 @@ def __init__(self, rows, columns, value_type, debug=None, address=None): @staticmethod def create_from(rows): rows = list(rows) - if isinstance(rows[0], (list, tuple)): + if isinstance(rows[0], (list, tuple, Array)): t = type(rows[0][0]) else: t = type(rows[0]) @@ -6607,7 +6774,10 @@ def read(self): :return: relevant basic type instance """ self.check() if program.curr_block != self.last_write_block: - self.register = self.value_type.load_mem(self.address) + from Compiler.GC.types import sbitvec + self.register = self.value_type.load_mem( + self.address, size=self.size \ + if issubclass(self.value_type, (_register, sbitvec)) else None) self.last_write_block = program.curr_block return self.register @@ -6656,6 +6826,7 @@ def reveal(self): if_else = lambda self,*args,**kwargs: self.read().if_else(*args, **kwargs) bit_and = lambda self,other: self.read().bit_and(other) + bit_not = lambda self: self.read().bit_not() def expand_to_vector(self, size=None): if program.curr_block == self.last_write_block: diff --git a/Compiler/util.py b/Compiler/util.py old mode 100755 new mode 100644 index aa491e4..c1bedc2 --- a/Compiler/util.py +++ b/Compiler/util.py @@ -116,6 +116,11 @@ def round_to_int(x): return x.round_to_int() def tree_reduce(function, sequence): + try: + return sequence.tree_reduce(function) + except AttributeError: + pass + sequence = list(sequence) assert len(sequence) > 0 n = len(sequence) @@ -233,6 +238,9 @@ def mem_size(x): except AttributeError: return 1 +def find_in_dict(d, v): + return list(d.keys())[list(d.values()).index(v)] + class set_by_id(object): def __init__(self, init=[]): self.content = {} diff --git a/Dockerfile b/Dockerfile old mode 100755 new mode 100644 index 5d8c288..79dc702 --- a/Dockerfile +++ b/Dockerfile @@ -42,6 +42,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ automake \ build-essential \ clang-11 \ + cmake \ git \ libboost-dev \ libboost-thread-dev \ @@ -91,6 +92,7 @@ ARG cryptoplayers=0 ENV PLAYERS ${cryptoplayers} RUN ./Scripts/setup-ssl.sh ${cryptoplayers} ${ssl_dir} +RUN make boost libote ############################################################################### # Use this stage to a build a specific virtual machine. For example: # diff --git a/ECDSA/CurveElement.cpp b/ECDSA/CurveElement.cpp old mode 100755 new mode 100644 diff --git a/ECDSA/CurveElement.h b/ECDSA/CurveElement.h old mode 100755 new mode 100644 diff --git a/ECDSA/EcdsaOptions.h b/ECDSA/EcdsaOptions.h old mode 100755 new mode 100644 diff --git a/ECDSA/Fake-ECDSA.cpp b/ECDSA/Fake-ECDSA.cpp old mode 100755 new mode 100644 index 23f81b9..ecf7011 --- a/ECDSA/Fake-ECDSA.cpp +++ b/ECDSA/Fake-ECDSA.cpp @@ -22,4 +22,5 @@ int main() generate_mac_keys>(key, 2, prefix); make_mult_triples>(key, 2, 1000, false, prefix); make_inverse>(key, 2, 1000, false, prefix); + P256Element::finish(); } diff --git a/ECDSA/P256Element.cpp b/ECDSA/P256Element.cpp old mode 100755 new mode 100644 index 2c8c776..1ff3273 --- a/ECDSA/P256Element.cpp +++ b/ECDSA/P256Element.cpp @@ -14,7 +14,14 @@ void P256Element::init() curve = EC_GROUP_new_by_curve_name(NID_secp256k1); assert(curve != 0); auto modulus = EC_GROUP_get0_order(curve); - Scalar::init_field(BN_bn2dec(modulus), false); + auto mod = BN_bn2dec(modulus); + Scalar::init_field(mod, false); + free(mod); +} + +void P256Element::finish() +{ + EC_GROUP_free(curve); } P256Element::P256Element() @@ -42,6 +49,11 @@ P256Element::P256Element(word other) : BN_free(exp); } +P256Element::~P256Element() +{ + EC_POINT_free(point); +} + P256Element& P256Element::operator =(const P256Element& other) { assert(EC_POINT_copy(point, other.point) != 0); @@ -99,7 +111,7 @@ bool P256Element::operator ==(const P256Element& other) const return not cmp; } -void P256Element::pack(octetStream& os) const +void P256Element::pack(octetStream& os, int) const { octet* buffer; size_t length = EC_POINT_point2buf(curve, point, @@ -107,9 +119,10 @@ void P256Element::pack(octetStream& os) const assert(length != 0); os.store_int(length, 8); os.append(buffer, length); + free(buffer); } -void P256Element::unpack(octetStream& os) +void P256Element::unpack(octetStream& os, int) { size_t length = os.get_int(8); assert( diff --git a/ECDSA/P256Element.h b/ECDSA/P256Element.h old mode 100755 new mode 100644 index 4657b5d..bd005c8 --- a/ECDSA/P256Element.h +++ b/ECDSA/P256Element.h @@ -22,7 +22,7 @@ class P256Element : public ValueInterface EC_POINT* point; public: - typedef void next; + typedef P256Element next; typedef void Square; static const true_type invertible; @@ -32,11 +32,13 @@ class P256Element : public ValueInterface static string type_string() { return "P256"; } static void init(); + static void finish(); P256Element(); P256Element(const P256Element& other); P256Element(const Scalar& other); P256Element(word other); + ~P256Element(); P256Element& operator=(const P256Element& other); @@ -58,8 +60,8 @@ class P256Element : public ValueInterface bool is_zero() { return *this == P256Element(); } void add(octetStream& os) { *this += os.get(); } - void pack(octetStream& os) const; - void unpack(octetStream& os); + void pack(octetStream& os, int = -1) const; + void unpack(octetStream& os, int = -1); octetStream hash(size_t n_bytes) const; diff --git a/ECDSA/README.md b/ECDSA/README.md old mode 100755 new mode 100644 diff --git a/ECDSA/fake-spdz-ecdsa-party.cpp b/ECDSA/fake-spdz-ecdsa-party.cpp old mode 100755 new mode 100644 index f0e3257..ea19c8e --- a/ECDSA/fake-spdz-ecdsa-party.cpp +++ b/ECDSA/fake-spdz-ecdsa-party.cpp @@ -45,12 +45,13 @@ int main(int argc, const char** argv) string prefix = get_prep_sub_dir(PREP_DIR "ECDSA/", 2); read_mac_key(prefix, N, keyp); + pShare::MAC_Check::setup(P); + Share::MAC_Check::setup(P); + DataPositions usage; Sub_Data_Files prep(N, prefix, usage); typename pShare::Direct_MC MCp(keyp); ArithmeticProcessor _({}, 0); - BaseMachine machine; - machine.ot_setups.push_back({P, false}); SubProcessor proc(_, MCp, prep, P); pShare sk, __; @@ -60,4 +61,8 @@ int main(int argc, const char** argv) preprocessing(tuples, n_tuples, sk, proc, opts); check(tuples, sk, keyp, P); sign_benchmark(tuples, sk, MCp, P, opts); + + pShare::MAC_Check::teardown(); + Share::MAC_Check::teardown(); + P256Element::finish(); } diff --git a/ECDSA/hm-ecdsa-party.hpp b/ECDSA/hm-ecdsa-party.hpp old mode 100755 new mode 100644 index fc19e98..07520f3 --- a/ECDSA/hm-ecdsa-party.hpp +++ b/ECDSA/hm-ecdsa-party.hpp @@ -30,6 +30,8 @@ #include "GC/ThreadMaster.hpp" #include "GC/Secret.hpp" #include "Machines/ShamirMachine.hpp" +#include "Machines/MalRep.hpp" +#include "Machines/Rep.hpp" #include @@ -69,4 +71,5 @@ void run(int argc, const char** argv) preprocessing(tuples, n_tuples, sk, proc, opts); // check(tuples, sk, {}, P); sign_benchmark(tuples, sk, MCp, P, opts, prep_mul ? 0 : &proc); + P256Element::finish(); } diff --git a/ECDSA/mal-rep-ecdsa-party.cpp b/ECDSA/mal-rep-ecdsa-party.cpp old mode 100755 new mode 100644 diff --git a/ECDSA/mal-shamir-ecdsa-party.cpp b/ECDSA/mal-shamir-ecdsa-party.cpp old mode 100755 new mode 100644 diff --git a/ECDSA/mascot-ecdsa-party.cpp b/ECDSA/mascot-ecdsa-party.cpp old mode 100755 new mode 100644 diff --git a/ECDSA/ot-ecdsa-party.hpp b/ECDSA/ot-ecdsa-party.hpp old mode 100755 new mode 100644 index 569aa79..550c0ac --- a/ECDSA/ot-ecdsa-party.hpp +++ b/ECDSA/ot-ecdsa-party.hpp @@ -92,9 +92,6 @@ void run(int argc, const char** argv) P256Element::init(); P256Element::Scalar::next::init_field(P256Element::Scalar::pr(), false); - BaseMachine machine; - machine.ot_setups.push_back({P, true}); - P256Element::Scalar keyp; SeededPRNG G; keyp.randomize(G); @@ -102,6 +99,9 @@ void run(int argc, const char** argv) typedef T pShare; DataPositions usage; + pShare::MAC_Check::setup(P); + T::MAC_Check::setup(P); + OnlineOptions::singleton.batch_size = 1; typename pShare::Direct_MC MCp(keyp); ArithmeticProcessor _({}, 0); @@ -137,4 +137,8 @@ void run(int argc, const char** argv) preprocessing(tuples, n_tuples, sk, proc, opts); //check(tuples, sk, keyp, P); sign_benchmark(tuples, sk, MCp, P, opts, prep_mul ? 0 : &proc); + + pShare::MAC_Check::teardown(); + T::MAC_Check::teardown(); + P256Element::finish(); } diff --git a/ECDSA/preprocessing.hpp b/ECDSA/preprocessing.hpp old mode 100755 new mode 100644 diff --git a/ECDSA/rep-ecdsa-party.cpp b/ECDSA/rep-ecdsa-party.cpp old mode 100755 new mode 100644 diff --git a/ECDSA/semi-ecdsa-party.cpp b/ECDSA/semi-ecdsa-party.cpp old mode 100755 new mode 100644 diff --git a/ECDSA/shamir-ecdsa-party.cpp b/ECDSA/shamir-ecdsa-party.cpp old mode 100755 new mode 100644 diff --git a/ECDSA/sign.hpp b/ECDSA/sign.hpp old mode 100755 new mode 100644 diff --git a/ExternalIO/Client.h b/ExternalIO/Client.h old mode 100755 new mode 100644 diff --git a/ExternalIO/Client.hpp b/ExternalIO/Client.hpp old mode 100755 new mode 100644 diff --git a/ExternalIO/README.md b/ExternalIO/README.md old mode 100755 new mode 100644 index d4f9928..8932844 --- a/ExternalIO/README.md +++ b/ExternalIO/README.md @@ -2,19 +2,20 @@ The ExternalIO directory contains an example of managing I/O between external cl ## Working Examples -[bankers-bonus-client.cpp](./bankers-bonus-client.cpp) acts as a +[bankers-bonus-client.cpp](./bankers-bonus-client.cpp) and +[bankers-bonus-client.py](./bankers-bonus-client.py) act as a client to [bankers_bonus.mpc](../Programs/Source/bankers_bonus.mpc) and demonstrates sending input and receiving output as described by [Damgård et al.](https://eprint.iacr.org/2015/1006) The computation allows up to eight clients to input a number and computes the client -with the largest input. You can run it as follows from the main +with the largest input. You can run the C++ code as follows from the main directory: ``` make bankers-bonus-client.x ./compile.py bankers_bonus 1 Scripts/setup-ssl.sh Scripts/setup-clients.sh 3 -Scripts/.sh bankers_bonus-1 & +PLAYERS= Scripts/.sh bankers_bonus-1 & ./bankers-bonus-client.x 0 100 0 & ./bankers-bonus-client.x 1 200 0 & ./bankers-bonus-client.x 2 50 1 @@ -30,6 +31,11 @@ protocol script. The setup scripts generate the necessary SSL certificates and keys. Therefore, if you run the computation on different hosts, you will have to distribute the `*.pem` files. +For the Python client, make sure to install +[gmpy2](https://pypi.org/project/gmpy2), and run +`ExternalIO/bankers-bonus-client.py` instead of +`bankers-bonus-client.x`. + ## I/O MPC Instructions ### Connection Setup diff --git a/ExternalIO/bankers-bonus-client.cpp b/ExternalIO/bankers-bonus-client.cpp old mode 100755 new mode 100644 diff --git a/FHE/AddableVector.h b/FHE/AddableVector.h old mode 100755 new mode 100644 diff --git a/FHE/AddableVector.hpp b/FHE/AddableVector.hpp old mode 100755 new mode 100644 diff --git a/FHE/Ciphertext.cpp b/FHE/Ciphertext.cpp old mode 100755 new mode 100644 index 00e0513..62cbd52 --- a/FHE/Ciphertext.cpp +++ b/FHE/Ciphertext.cpp @@ -130,7 +130,7 @@ void Ciphertext::rerandomize(const FHE_PK& pk) assert(p != 0); for (auto& x : r) { - G.get(x, params->p0().numBits() - p.numBits() - 1); + G.get(x, params->p0().numBits() - p.numBits() - 1); x *= p; } tmp.from(r, 0); diff --git a/FHE/Ciphertext.h b/FHE/Ciphertext.h old mode 100755 new mode 100644 diff --git a/FHE/Diagonalizer.cpp b/FHE/Diagonalizer.cpp old mode 100755 new mode 100644 diff --git a/FHE/Diagonalizer.h b/FHE/Diagonalizer.h old mode 100755 new mode 100644 diff --git a/FHE/DiscreteGauss.cpp b/FHE/DiscreteGauss.cpp old mode 100755 new mode 100644 diff --git a/FHE/DiscreteGauss.h b/FHE/DiscreteGauss.h old mode 100755 new mode 100644 diff --git a/FHE/FFT.cpp b/FHE/FFT.cpp old mode 100755 new mode 100644 diff --git a/FHE/FFT.h b/FHE/FFT.h old mode 100755 new mode 100644 diff --git a/FHE/FFT_Data.cpp b/FHE/FFT_Data.cpp old mode 100755 new mode 100644 diff --git a/FHE/FFT_Data.h b/FHE/FFT_Data.h old mode 100755 new mode 100644 diff --git a/FHE/FHE_Keys.cpp b/FHE/FHE_Keys.cpp old mode 100755 new mode 100644 diff --git a/FHE/FHE_Keys.h b/FHE/FHE_Keys.h old mode 100755 new mode 100644 diff --git a/FHE/FHE_Params.cpp b/FHE/FHE_Params.cpp old mode 100755 new mode 100644 diff --git a/FHE/FHE_Params.h b/FHE/FHE_Params.h old mode 100755 new mode 100644 diff --git a/FHE/Generator.h b/FHE/Generator.h old mode 100755 new mode 100644 diff --git a/FHE/Matrix.cpp b/FHE/Matrix.cpp old mode 100755 new mode 100644 diff --git a/FHE/Matrix.h b/FHE/Matrix.h old mode 100755 new mode 100644 diff --git a/FHE/NTL-Subs.cpp b/FHE/NTL-Subs.cpp old mode 100755 new mode 100644 index 794e743..f397302 --- a/FHE/NTL-Subs.cpp +++ b/FHE/NTL-Subs.cpp @@ -368,7 +368,8 @@ ZZX Cyclotomic(int N) int phi_N(int N) { if (((N - 1) & N) != 0) - throw runtime_error("compile with NTL support"); + throw runtime_error( + "compile with NTL support (USE_NTL=1 in CONFIG.mine)"); else if (N == 1) return 1; else @@ -418,7 +419,8 @@ void init(Ring& Rg, int m, bool generate_poly) for (int i=0; i& elem) const */ } -void PPData::from_eval(vector& elem) const +void PPData::from_eval(vector&) const { // avoid warning - elem.empty(); throw not_implemented(); /* diff --git a/FHE/PPData.h b/FHE/PPData.h old mode 100755 new mode 100644 diff --git a/FHE/Plaintext.cpp b/FHE/Plaintext.cpp old mode 100755 new mode 100644 diff --git a/FHE/Plaintext.h b/FHE/Plaintext.h old mode 100755 new mode 100644 diff --git a/FHE/QGroup.cpp b/FHE/QGroup.cpp old mode 100755 new mode 100644 diff --git a/FHE/QGroup.h b/FHE/QGroup.h old mode 100755 new mode 100644 diff --git a/FHE/Random_Coins.cpp b/FHE/Random_Coins.cpp old mode 100755 new mode 100644 diff --git a/FHE/Random_Coins.h b/FHE/Random_Coins.h old mode 100755 new mode 100644 diff --git a/FHE/Ring.cpp b/FHE/Ring.cpp old mode 100755 new mode 100644 diff --git a/FHE/Ring.h b/FHE/Ring.h old mode 100755 new mode 100644 diff --git a/FHE/Ring_Element.cpp b/FHE/Ring_Element.cpp old mode 100755 new mode 100644 diff --git a/FHE/Ring_Element.h b/FHE/Ring_Element.h old mode 100755 new mode 100644 diff --git a/FHE/Rq_Element.cpp b/FHE/Rq_Element.cpp old mode 100755 new mode 100644 diff --git a/FHE/Rq_Element.h b/FHE/Rq_Element.h old mode 100755 new mode 100644 diff --git a/FHE/Subroutines.cpp b/FHE/Subroutines.cpp old mode 100755 new mode 100644 diff --git a/FHE/Subroutines.h b/FHE/Subroutines.h old mode 100755 new mode 100644 diff --git a/FHE/tools.h b/FHE/tools.h old mode 100755 new mode 100644 diff --git a/FHEOffline/CutAndChooseMachine.cpp b/FHEOffline/CutAndChooseMachine.cpp old mode 100755 new mode 100644 diff --git a/FHEOffline/CutAndChooseMachine.h b/FHEOffline/CutAndChooseMachine.h old mode 100755 new mode 100644 diff --git a/FHEOffline/DataSetup.cpp b/FHEOffline/DataSetup.cpp old mode 100755 new mode 100644 diff --git a/FHEOffline/DataSetup.h b/FHEOffline/DataSetup.h old mode 100755 new mode 100644 diff --git a/FHEOffline/DataSetup.hpp b/FHEOffline/DataSetup.hpp old mode 100755 new mode 100644 diff --git a/FHEOffline/DistDecrypt.cpp b/FHEOffline/DistDecrypt.cpp old mode 100755 new mode 100644 diff --git a/FHEOffline/DistDecrypt.h b/FHEOffline/DistDecrypt.h old mode 100755 new mode 100644 diff --git a/FHEOffline/DistKeyGen.cpp b/FHEOffline/DistKeyGen.cpp old mode 100755 new mode 100644 diff --git a/FHEOffline/DistKeyGen.h b/FHEOffline/DistKeyGen.h old mode 100755 new mode 100644 diff --git a/FHEOffline/EncCommit.cpp b/FHEOffline/EncCommit.cpp old mode 100755 new mode 100644 diff --git a/FHEOffline/EncCommit.h b/FHEOffline/EncCommit.h old mode 100755 new mode 100644 diff --git a/FHEOffline/FHE-Subroutines.cpp b/FHEOffline/FHE-Subroutines.cpp old mode 100755 new mode 100644 diff --git a/FHEOffline/Multiplier.cpp b/FHEOffline/Multiplier.cpp old mode 100755 new mode 100644 diff --git a/FHEOffline/Multiplier.h b/FHEOffline/Multiplier.h old mode 100755 new mode 100644 diff --git a/FHEOffline/PairwiseGenerator.cpp b/FHEOffline/PairwiseGenerator.cpp old mode 100755 new mode 100644 index dcbd29b..0fb7a14 --- a/FHEOffline/PairwiseGenerator.cpp +++ b/FHEOffline/PairwiseGenerator.cpp @@ -24,7 +24,7 @@ PairwiseGenerator::PairwiseGenerator(int thread_num, thread_num, machine.output, machine.get_prep_dir(P)), EC(P, machine.other_pks, machine.setup().FieldD, timers, machine, *this), MC(machine.setup().alphai), - n_ciphertexts(Proof::n_ciphertext_per_proof(machine.sec, machine.pk)), + n_ciphertexts(EC.proof.U), C(n_ciphertexts, machine.setup().params), volatile_memory(0), machine(machine) { diff --git a/FHEOffline/PairwiseGenerator.h b/FHEOffline/PairwiseGenerator.h old mode 100755 new mode 100644 diff --git a/FHEOffline/PairwiseMachine.cpp b/FHEOffline/PairwiseMachine.cpp old mode 100755 new mode 100644 index b19dd62..dd3f896 --- a/FHEOffline/PairwiseMachine.cpp +++ b/FHEOffline/PairwiseMachine.cpp @@ -17,15 +17,13 @@ PairwiseMachine::PairwiseMachine(Player& P) : { } -PairwiseMachine::PairwiseMachine(int argc, const char** argv) : - MachineBase(argc, argv), P(*new PlainPlayer(N, "pairwise")), - other_pks(N.num_players(), {setup_p.params, 0}), - pk(other_pks[N.my_num()]), sk(pk) +RealPairwiseMachine::RealPairwiseMachine(int argc, const char** argv) : + MachineBase(argc, argv), PairwiseMachine(*new PlainPlayer(N, "pairwise")) { init(); } -void PairwiseMachine::init() +void RealPairwiseMachine::init() { if (use_gf2n) { @@ -63,7 +61,7 @@ PairwiseSetup& PairwiseMachine::setup() } template -void PairwiseMachine::setup_keys() +void RealPairwiseMachine::setup_keys() { auto& N = P; PairwiseSetup& s = setup(); @@ -84,10 +82,11 @@ void PairwiseMachine::setup_keys() if (i != N.my_num()) other_pks[i].unpack(os[i]); set_mac_key(s.alphai); + Share::MAC_Check::setup(P); } template -void PairwiseMachine::set_mac_key(T alphai) +void RealPairwiseMachine::set_mac_key(T alphai) { typedef typename T::FD FD; auto& N = P; @@ -142,5 +141,5 @@ void PairwiseMachine::check(Player& P) const bundle.compare(P); } -template void PairwiseMachine::setup_keys(); -template void PairwiseMachine::setup_keys(); +template void RealPairwiseMachine::setup_keys(); +template void RealPairwiseMachine::setup_keys(); diff --git a/FHEOffline/PairwiseMachine.h b/FHEOffline/PairwiseMachine.h old mode 100755 new mode 100644 index c228344..a8a0c64 --- a/FHEOffline/PairwiseMachine.h +++ b/FHEOffline/PairwiseMachine.h @@ -10,7 +10,7 @@ #include "FHEOffline/SimpleMachine.h" #include "FHEOffline/PairwiseSetup.h" -class PairwiseMachine : public MachineBase +class PairwiseMachine : public virtual MachineBase { public: PairwiseSetup setup_p; @@ -23,15 +23,6 @@ class PairwiseMachine : public MachineBase vector enc_alphas; PairwiseMachine(Player& P); - PairwiseMachine(int argc, const char** argv); - - void init(); - - template - void setup_keys(); - - template - void set_mac_key(T alphai); template PairwiseSetup& setup(); @@ -42,4 +33,18 @@ class PairwiseMachine : public MachineBase void check(Player& P) const; }; +class RealPairwiseMachine : public virtual MachineBase, public virtual PairwiseMachine +{ +public: + RealPairwiseMachine(int argc, const char** argv); + + void init(); + + template + void setup_keys(); + + template + void set_mac_key(T alphai); +}; + #endif /* FHEOFFLINE_PAIRWISEMACHINE_H_ */ diff --git a/FHEOffline/PairwiseSetup.cpp b/FHEOffline/PairwiseSetup.cpp old mode 100755 new mode 100644 index 0197118..bc890ed --- a/FHEOffline/PairwiseSetup.cpp +++ b/FHEOffline/PairwiseSetup.cpp @@ -116,6 +116,14 @@ void secure_init(T& setup, Player& P, U& machine, ofstream file(filename); os.output(file); } + + if (OnlineOptions::singleton.verbose) + { + cerr << "Ciphertext length: " << params.p0().numBits(); + for (size_t i = 1; i < params.FFTD().size(); i++) + cerr << "+" << params.FFTD()[i].get_prime().numBits(); + cerr << endl; + } } template diff --git a/FHEOffline/PairwiseSetup.h b/FHEOffline/PairwiseSetup.h old mode 100755 new mode 100644 diff --git a/FHEOffline/Producer.cpp b/FHEOffline/Producer.cpp old mode 100755 new mode 100644 index 9b143dd..c3ab59e --- a/FHEOffline/Producer.cpp +++ b/FHEOffline/Producer.cpp @@ -577,7 +577,7 @@ void InputProducer::run(const Player& P, const FHE_PK& pk, for (int j = min; j < max; j++) { AddableVector C; - vector> m(EC.machine->sec, FieldD); + vector> m(personal_EC.proof.U, FieldD); if (j == P.my_num()) { for (auto& x : m) diff --git a/FHEOffline/Producer.h b/FHEOffline/Producer.h old mode 100755 new mode 100644 diff --git a/FHEOffline/Proof.cpp b/FHEOffline/Proof.cpp old mode 100755 new mode 100644 diff --git a/FHEOffline/Proof.h b/FHEOffline/Proof.h old mode 100755 new mode 100644 index 6059ef3..2eec043 --- a/FHEOffline/Proof.h +++ b/FHEOffline/Proof.h @@ -22,6 +22,8 @@ enum SlackType class Proof { + protected: + unsigned int sec; bool diagonal; @@ -153,14 +155,18 @@ class Proof class NonInteractiveProof : public Proof { + // sec = 0 used for protocols without proofs + static int comp_sec(int sec) { return sec > 0 ? max(COMP_SEC, sec) : 0; } + public: bigint static slack(int sec, int phim) - { return bigint(phim * sec * sec) << (sec / 2 + 8); } + { sec = comp_sec(sec); return bigint(phim * sec * sec) << (sec / 2 + 8); } NonInteractiveProof(int sec, const FHE_PK& pk, int extra_slack, bool diagonal = false) : - Proof(sec, pk, 1, diagonal) + Proof(comp_sec(sec), pk, 1, diagonal) { + sec = this->sec; bigint B; B=128*sec*sec; B <<= extra_slack; diff --git a/FHEOffline/Prover.cpp b/FHEOffline/Prover.cpp old mode 100755 new mode 100644 index d92f308..7127b8c --- a/FHEOffline/Prover.cpp +++ b/FHEOffline/Prover.cpp @@ -128,6 +128,7 @@ size_t Prover::NIZKPoK(Proof& P, octetStream& ciphertexts, octetStream& cl bool ok=false; int cnt=0; + (void) cnt; while (!ok) { cnt++; Stage_1(P,ciphertexts,c,pk); diff --git a/FHEOffline/Prover.h b/FHEOffline/Prover.h old mode 100755 new mode 100644 diff --git a/FHEOffline/Reshare.cpp b/FHEOffline/Reshare.cpp old mode 100755 new mode 100644 diff --git a/FHEOffline/Reshare.h b/FHEOffline/Reshare.h old mode 100755 new mode 100644 diff --git a/FHEOffline/Sacrificing.cpp b/FHEOffline/Sacrificing.cpp old mode 100755 new mode 100644 diff --git a/FHEOffline/Sacrificing.h b/FHEOffline/Sacrificing.h old mode 100755 new mode 100644 diff --git a/FHEOffline/SimpleDistDecrypt.cpp b/FHEOffline/SimpleDistDecrypt.cpp old mode 100755 new mode 100644 diff --git a/FHEOffline/SimpleDistDecrypt.h b/FHEOffline/SimpleDistDecrypt.h old mode 100755 new mode 100644 diff --git a/FHEOffline/SimpleEncCommit.cpp b/FHEOffline/SimpleEncCommit.cpp old mode 100755 new mode 100644 index 9129206..c161f1d --- a/FHEOffline/SimpleEncCommit.cpp +++ b/FHEOffline/SimpleEncCommit.cpp @@ -26,9 +26,10 @@ SimpleEncCommit::SimpleEncCommit(const PlayerBase& P, const FHE_PK& pk int thread_num, bool diagonal) : NonInteractiveProofSimpleEncCommit(P, pk, FTD, timers, machine, diagonal), - SimpleEncCommitFactory(pk, FTD, machine, diagonal) + SimpleEncCommitFactory(pk) { (void)thread_num; + this->init(this->proof, FTD); } template @@ -48,11 +49,15 @@ NonInteractiveProofSimpleEncCommit::NonInteractiveProofSimpleEncCommit( } template -SimpleEncCommitFactory::SimpleEncCommitFactory(const FHE_PK& pk, - const FD& FTD, const MachineBase& machine, bool diagonal) : +SimpleEncCommitFactory::SimpleEncCommitFactory(const FHE_PK& pk) : cnt(-1), n_calls(0), pk(pk) { - int sec = Proof::n_ciphertext_per_proof(machine.sec, pk, diagonal); +} + +template +void SimpleEncCommitFactory::init(const Proof& proof, const FD& FTD) +{ + int sec = proof.U; c.resize(sec, pk.get_params()); m.resize(sec, FTD); for (int i = 0; i < sec; i++) @@ -224,7 +229,7 @@ template SummingEncCommit::SummingEncCommit(const Player& P, const FHE_PK& pk, const FD& FTD, map& timers, const MachineBase& machine, int thread_num, bool diagonal) : - SimpleEncCommitFactory(pk, FTD, machine, diagonal), SimpleEncCommitBase_( + SimpleEncCommitFactory(pk), SimpleEncCommitBase_( machine), proof(machine.sec, pk, P.num_players(), diagonal), pk(pk), FTD( FTD), P(P), thread_num(thread_num), #ifdef LESS_ALLOC_MORE_MEM @@ -233,6 +238,7 @@ SummingEncCommit::SummingEncCommit(const Player& P, const FHE_PK& pk, #endif timers(timers) { + this->init(proof, FTD); } template diff --git a/FHEOffline/SimpleEncCommit.h b/FHEOffline/SimpleEncCommit.h old mode 100755 new mode 100644 index 9fccd9a..f3034fa --- a/FHEOffline/SimpleEncCommit.h +++ b/FHEOffline/SimpleEncCommit.h @@ -92,9 +92,9 @@ class SimpleEncCommitFactory virtual void create_more() = 0; public: - SimpleEncCommitFactory(const FHE_PK& pk, const FD& FTD, - const MachineBase& machine, bool diagonal = false); + SimpleEncCommitFactory(const FHE_PK& pk); virtual ~SimpleEncCommitFactory(); + void init(const Proof& proof, const FD& FTD); bool has_left() { return cnt >= 0; } void next(Plaintext_& mess, Ciphertext& C); virtual size_t report_size(ReportType type); diff --git a/FHEOffline/SimpleGenerator.cpp b/FHEOffline/SimpleGenerator.cpp old mode 100755 new mode 100644 index b2701b2..be5ee2c --- a/FHEOffline/SimpleGenerator.cpp +++ b/FHEOffline/SimpleGenerator.cpp @@ -12,7 +12,7 @@ template