From 378af3bd902ef382bd8c76ab46f5acbfe5b30d6d Mon Sep 17 00:00:00 2001 From: weiwee Date: Tue, 19 Jul 2022 04:47:26 -0800 Subject: [PATCH 01/11] chore: move par impl to subpackage Signed-off-by: weiwee --- .../blocks/rust_paillier_block/__init__.py | 78 +++--- .../tensor/impl/tensor/multithread.py | 47 +--- rust/fate-tensor/benches/base_bench.py | 166 ++++++++----- rust/fate-tensor/fate_tensor.pyi | 126 ---------- rust/fate-tensor/fate_tensor/__init__.py | 1 + rust/fate-tensor/fate_tensor/__init__.pyi | 69 +++++ rust/fate-tensor/fate_tensor/par/__init__.py | 0 rust/fate-tensor/fate_tensor/par/__init__.pyi | 70 ++++++ rust/fate-tensor/pyproject.toml | 4 +- rust/fate-tensor/src/block/matmul.rs | 10 - rust/fate-tensor/src/block/mod.rs | 4 - rust/fate-tensor/src/cb.rs | 125 ---------- rust/fate-tensor/src/lib.rs | 221 +--------------- rust/fate-tensor/src/par/cb.rs | 167 +++++++++++++ rust/fate-tensor/src/par/mod.rs | 235 ++++++++++++++++++ rust/fate-tensor/tests/test_base.py | 138 +++++----- 16 files changed, 745 insertions(+), 716 deletions(-) delete mode 100644 rust/fate-tensor/fate_tensor.pyi create mode 100644 rust/fate-tensor/fate_tensor/__init__.py create mode 100644 rust/fate-tensor/fate_tensor/__init__.pyi create mode 100644 rust/fate-tensor/fate_tensor/par/__init__.py create mode 100644 rust/fate-tensor/fate_tensor/par/__init__.pyi create mode 100644 rust/fate-tensor/src/par/cb.rs create mode 100644 rust/fate-tensor/src/par/mod.rs diff --git a/python/fate_arch/tensor/impl/blocks/rust_paillier_block/__init__.py b/python/fate_arch/tensor/impl/blocks/rust_paillier_block/__init__.py index 4c34742a68..05aa49300e 100644 --- a/python/fate_arch/tensor/impl/blocks/rust_paillier_block/__init__.py +++ b/python/fate_arch/tensor/impl/blocks/rust_paillier_block/__init__.py @@ -17,9 +17,7 @@ import pickle import typing -import fate_tensor import numpy as np -from fate_tensor import Cipherblock import torch from ....abc.block import ( @@ -36,10 +34,10 @@ class PaillierBlock(PHEBlockABC): - def __init__(self, cb: Cipherblock) -> None: + def __init__(self, cb) -> None: self._cb = cb - def create(self, cb: Cipherblock): + def create(self, cb): return PaillierBlock(cb) def __add__(self, other) -> "PaillierBlock": @@ -177,9 +175,8 @@ def serialize(self) -> bytes: class BlockPaillierEncryptor(PHEBlockEncryptorABC): - def __init__(self, pk: fate_tensor.PK, multithread=False) -> None: + def __init__(self, pk) -> None: self._pk = pk - self._multithread = multithread def encrypt(self, other) -> PaillierBlock: if isinstance(other, FPBlock): @@ -188,63 +185,46 @@ def encrypt(self, other) -> PaillierBlock: raise NotImplementedError(f"type {other} not supported") def _encrypt_numpy(self, other): - if self._multithread: - if isinstance(other, np.ndarray): - if other.dtype == np.float64: - return self._pk.encrypt_f64_par(other) - if other.dtype == np.float32: - return self._pk.encrypt_f32_par(other) - if other.dtype == np.int64: - return self._pk.encrypt_i64_par(other) - if other.dtype == np.int32: - return self._pk.encrypt_i32_par(other) - else: - if isinstance(other, np.ndarray): - if other.dtype == np.float64: - return self._pk.encrypt_f64(other) - if other.dtype == np.float32: - return self._pk.encrypt_f32(other) - if other.dtype == np.int64: - return self._pk.encrypt_i64(other) - if other.dtype == np.int32: - return self._pk.encrypt_i32(other) + if isinstance(other, np.ndarray): + if other.dtype == np.float64: + return self._pk.encrypt_f64(other) + if other.dtype == np.float32: + return self._pk.encrypt_f32(other) + if other.dtype == np.int64: + return self._pk.encrypt_i64(other) + if other.dtype == np.int32: + return self._pk.encrypt_i32(other) raise NotImplementedError(f"type {other} {other.dtype} not supported") class BlockPaillierDecryptor(PHEBlockDecryptorABC): - def __init__(self, sk: fate_tensor.SK, multithread=False) -> None: + def __init__(self, sk) -> None: self._sk = sk - self._multithread = multithread def decrypt(self, other: PaillierBlock, dtype=np.float64): return torch.from_numpy(self._decrypt_numpy(other._cb, dtype)) - def _decrypt_numpy(self, cb: Cipherblock, dtype=np.float64): - if self._multithread: - if dtype == np.float64: - return self._sk.decrypt_f64_par(cb) - if dtype == np.float32: - return self._sk.decrypt_f32_par(cb) - if dtype == np.int64: - return self._sk.decrypt_i64_par(cb) - if dtype == np.int32: - return self._sk.decrypt_i32_par(cb) - else: - if dtype == np.float64: - return self._sk.decrypt_f64(cb) - if dtype == np.float32: - return self._sk.decrypt_f32(cb) - if dtype == np.int64: - return self._sk.decrypt_i64(cb) - if dtype == np.int32: - return self._sk.decrypt_i32(cb) + def _decrypt_numpy(self, cb, dtype=np.float64): + if dtype == np.float64: + return self._sk.decrypt_f64(cb) + if dtype == np.float32: + return self._sk.decrypt_f32(cb) + if dtype == np.int64: + return self._sk.decrypt_i64(cb) + if dtype == np.int32: + return self._sk.decrypt_i32(cb) raise NotImplementedError("dtype = {dtype}") class BlockPaillierCipher(PHEBlockCipherABC): @classmethod def keygen( - cls, key_length=1024, multithread=False, + cls, key_length=1024 ) -> typing.Tuple[BlockPaillierEncryptor, BlockPaillierDecryptor]: + import fate_tensor + pubkey, prikey = fate_tensor.keygen(bit_size=key_length) - return (BlockPaillierEncryptor(pubkey, multithread), BlockPaillierDecryptor(prikey, multithread)) + return ( + BlockPaillierEncryptor(pubkey), + BlockPaillierDecryptor(prikey), + ) diff --git a/python/fate_arch/tensor/impl/tensor/multithread.py b/python/fate_arch/tensor/impl/tensor/multithread.py index 7aeead7885..e932ed033a 100644 --- a/python/fate_arch/tensor/impl/tensor/multithread.py +++ b/python/fate_arch/tensor/impl/tensor/multithread.py @@ -4,7 +4,7 @@ import torch from ...abc.tensor import ( - FPTensorABC, + FPTensorProtocol, PHECipherABC, PHEDecryptorABC, PHEEncryptorABC, @@ -17,39 +17,6 @@ TYPECT = typing.Union[TYPEFP, "PHETensorLocal"] -# class FPTensorLocal(FPTensorABC): -# """ -# CPU multiple thread backend Local Fixed Presicion Tensor -# """ - -# def __init__(self, block): -# self._block = block - -# def __add__(self, other: TYPEFP) -> "FPTensorLocal": -# return _fp_binary_op(self, other, operator.add, FP_OP_TYPES) - -# def __radd__(self, other: TYPEFP) -> "FPTensorLocal": -# return _fp_binary_op(other, self, operator.add, FP_OP_TYPES) - -# def __sub__(self, other: TYPEFP) -> "FPTensorLocal": -# return _fp_binary_op(self, other, operator.sub, FP_OP_TYPES) - -# def __rsub__(self, other: TYPEFP) -> "FPTensorLocal": -# return _fp_binary_op(other, self, operator.sub, FP_OP_TYPES) - -# def __mul__(self, other: TYPEFP) -> "FPTensorLocal": -# return _fp_binary_op(self, other, operator.mul, FP_OP_TYPES) - -# def __rmul__(self, other: TYPEFP) -> "FPTensorLocal": -# return _fp_binary_op(other, self, operator.mul, FP_OP_TYPES) - -# def __matmul__(self, other: "FPTensorLocal") -> "FPTensorLocal": -# return FPTensorLocal(operator.matmul(self._block, other._block)) - -# def __rmatmul__(self, other: "FPTensorLocal") -> "FPTensorLocal": -# return FPTensorLocal(operator.matmul(other._block, self._block)) - - class PHETensorLocal(PHETensorABC): def __init__(self, block) -> None: """ """ @@ -131,18 +98,6 @@ def keygen( PaillierPHEDecryptorLocal(block_decryptor), ) - -def _fp_binary_op(self, other, func, types): - if type(other) not in types: - return NotImplemented - elif isinstance(other, FPTensorLocal): - return FPTensorLocal(func(self, other)) - elif isinstance(other, (int, float)): - return FPTensorLocal(func(self, other)) - else: - return NotImplemented - - def _phe_binary_op(self, other, func, types): if type(other) not in types: return NotImplemented diff --git a/rust/fate-tensor/benches/base_bench.py b/rust/fate-tensor/benches/base_bench.py index 31b91176d0..576c88707e 100644 --- a/rust/fate-tensor/benches/base_bench.py +++ b/rust/fate-tensor/benches/base_bench.py @@ -1,11 +1,57 @@ -import fate_tensor -import numpy as np +from _pytest.mark import expression import pytest import operator +import os + +import numpy as np import phe +try: + import gmpy2 +except: + raise RuntimeError(f"gmpy2 not installed, lib phe without gmpy2 is slow") + + +def get_num_threads(): + num = int(os.environ.get("NUM_THREADS", 4)) + cpu_count = os.cpu_count() + if cpu_count is not None and cpu_count < num: + raise RuntimeError( + f"num threads {num} larger than num cpu core deteacted, try specify num threads by `NUM_THREADS=xxx pytest ...`" + ) + return num + + +def get_single_thread_keygen(): + from fate_tensor import keygen + + return keygen + + +NUM_THREADS = get_num_threads() + + +def get_multiple_thread_keygen(): + from fate_tensor.par import keygen, set_num_threads + + set_num_threads(NUM_THREADS) + return keygen -class PHESuite: + +# modify this if you want to benchmark your custom packages +BENCH_PACKAGES = { + "cpu_thread[1]": get_single_thread_keygen(), + f"cpu_multiple_thread[{NUM_THREADS}]": get_multiple_thread_keygen(), +} + +sa, sb, sc, sd = ((11, 21), (11, 21), (21, 11), 21) +a = np.random.random(size=sa).astype(dtype=np.float64) - 0.5 +b = np.random.random(size=sb).astype(dtype=np.float64) - 0.5 +c = np.random.random(size=sc).astype(dtype=np.float64) - 0.5 +d = np.random.random(size=sd).astype(dtype=np.float64) - 0.5 + + +class BaselineSuite: def __init__(self, a, b, c, d) -> None: self.a = a self.b = b @@ -54,109 +100,105 @@ def rmatmul_plain_ix1(self): return self.d @ self.ec -class CPUBlockSuite: - _mix = "" - - def __init__(self, a, b, c, d) -> None: +class BenchSuite: + def __init__(self, a, b, c, d, keygen) -> None: self.a = a self.b = b self.c = c self.d = d - self.pk, self.sk = fate_tensor.keygen(1024) + self.pk, self.sk = keygen(1024) self.ea = self.pk.encrypt_f64(self.a) self.eb = self.pk.encrypt_f64(self.b) self.ec = self.pk.encrypt_f64(self.c) self.ed = self.pk.encrypt_f64(self.d) - def mix(self, name): - return f"{name}{self._mix}" - def get(self, name): return getattr(self, name) def encrypt(self): - getattr(self.pk, self.mix("encrypt_f64"))(self.a) + getattr(self.pk, "encrypt_f64")(self.a) def decrypt(self): - getattr(self.sk, self.mix("decrypt_f64"))(self.ea) + getattr(self.sk, "decrypt_f64")(self.ea) def add_cipher(self): - getattr(self.ea, self.mix("add_cipherblock"))(self.eb) + getattr(self.ea, "add_cipherblock")(self.eb) def sub_cipher(self): - getattr(self.ea, self.mix("sub_cipherblock"))(self.eb) + getattr(self.ea, "sub_cipherblock")(self.eb) def add_plain(self): - getattr(self.ea, self.mix("add_plaintext_f64"))(self.b) + getattr(self.ea, "add_plaintext_f64")(self.b) def sub_plain(self): - getattr(self.ea, self.mix("sub_plaintext_f64"))(self.b) + getattr(self.ea, "sub_plaintext_f64")(self.b) def mul_plain(self): - getattr(self.ea, self.mix("mul_plaintext_f64"))(self.b) + getattr(self.ea, "mul_plaintext_f64")(self.b) def matmul_plain_ix2(self): - getattr(self.ea, self.mix("matmul_plaintext_ix2_f64"))(self.c) + getattr(self.ea, "matmul_plaintext_ix2_f64")(self.c) def rmatmul_plain_ix2(self): - getattr(self.ec, self.mix("rmatmul_plaintext_ix2_f64"))(self.a) + getattr(self.ec, "rmatmul_plaintext_ix2_f64")(self.a) def matmul_plain_ix1(self): - getattr(self.ea, self.mix("matmul_plaintext_ix1_f64"))(self.d) + getattr(self.ea, "matmul_plaintext_ix1_f64")(self.d) def rmatmul_plain_ix1(self): - getattr(self.ec, self.mix("rmatmul_plaintext_ix1_f64"))(self.d) - - -class CPUBlockParSuite(CPUBlockSuite): - _mix = "_par" + getattr(self.ec, "rmatmul_plaintext_ix1_f64")(self.d) -class Suites: - def __init__(self, a, b, c, d) -> None: - self.suites = { - "phe": PHESuite(a, b, c, d), - "block": CPUBlockSuite(a, b, c, d), - "block_par": CPUBlockParSuite(a, b, c, d), - } - - def get(self, name): - return self.suites[name] +def get_suites(): + ids = [] + suites = [] + ids.append("baseline") + suites.append(BaselineSuite(a, b, c, d)) + for package_id, keygen in BENCH_PACKAGES.items(): + ids.append(package_id) + suites.append(BenchSuite(a, b, c, d, keygen)) + return ids, suites -@pytest.fixture -def shape(): - return ((11, 21), (11, 21), (21, 11), 21) +ids, suites = get_suites() -@pytest.fixture -def suites(shape): - sa, sb, sc, sd = shape - a = np.random.random(size=sa).astype(dtype=np.float64) - 0.5 - b = np.random.random(size=sb).astype(dtype=np.float64) - 0.5 - c = np.random.random(size=sc).astype(dtype=np.float64) - 0.5 - d = np.random.random(size=sd).astype(dtype=np.float64) - 0.5 - return Suites(a, b, c, d) +def pytest_generate_tests(metafunc): + if "suite" in metafunc.fixturenames: + metafunc.parametrize("suite", suites, ids=ids) def create_tests(func_name): - # @pytest.mark.benchmark(group=func_name) - @pytest.mark.parametrize("name", ["phe", "block", "block_par"]) - def f(name, suites, benchmark): - benchmark(suites.get(name).get(func_name)) + @pytest.mark.benchmark(group=f"{func_name}") + def f(suite, benchmark): + benchmark(suite.get(func_name)) f.__name__ = f"test_{func_name}" return f -test_encrypt = create_tests("encrypt") -test_decrypt = create_tests("decrypt") -test_add_cipher = create_tests("add_cipher") -test_sub_cipher = create_tests("sub_cipher") -test_add_plain = create_tests("add_plain") -test_sub_plain = create_tests("sub_plain") -test_mul_plain = create_tests("mul_plain") -test_matmul_plain_ix2 = create_tests("matmul_plain_ix2") -test_rmatmul_plain_ix2 = create_tests("rmatmul_plain_ix2") -test_matmul_plain_ix1 = create_tests("matmul_plain_ix1") -test_rmatmul_plain_ix1 = create_tests("rmatmul_plain_ix1") +def get_exec_expression(name, *args): + shapes = [] + for shape in args: + if isinstance(shape, int): + shapes.append(f"{shape}") + if isinstance(shape, tuple): + shapes.append(f"{'x'.join(map(str, shape))}") + shape_suffix = "_".join(shapes) + return f'test_{name}_{shape_suffix} = create_tests("{name}")' + + +for name, *shapes in [ + ("encrypt", sa), + ("decrypt", sa), + ("add_cipher", sa, sb), + ("sub_cipher", sa, sb), + ("add_plain", sa, sb), + ("sub_plain", sa, sb), + ("mul_plain", sa, sb), + ("matmul_plain_ix2", sa, sc), + ("rmatmul_plain_ix2", sc, sa), + ("matmul_plain_ix1", sa, sd), + ("rmatmul_plain_ix1", sc, sd), +]: + exec(get_exec_expression(name, *shapes)) diff --git a/rust/fate-tensor/fate_tensor.pyi b/rust/fate-tensor/fate_tensor.pyi deleted file mode 100644 index 057c06669f..0000000000 --- a/rust/fate-tensor/fate_tensor.pyi +++ /dev/null @@ -1,126 +0,0 @@ -import typing - -import numpy as np -import numpy.typing as npt - -class Cipherblock: - def add_cipherblock(self, other: Cipherblock) -> Cipherblock: ... - def add_plaintext_f64(self, other: npt.NDArray[np.float64]) -> Cipherblock: ... - def add_plaintext_f32(self, other: npt.NDArray[np.float32]) -> Cipherblock: ... - def add_plaintext_i64(self, other: npt.NDArray[np.int64]) -> Cipherblock: ... - def add_plaintext_i32(self, other: npt.NDArray[np.int32]) -> Cipherblock: ... - def add_plaintext_scalar_f64(self, other: typing.Union[float, np.float64]) -> Cipherblock: ... - def add_plaintext_scalar_f32(self, other: typing.Union[float, np.float32]) -> Cipherblock: ... - def add_plaintext_scalar_i64(self, other: typing.Union[int, np.int64]) -> Cipherblock: ... - def add_plaintext_scalar_i32(self, other: typing.Union[int, np.int32]) -> Cipherblock: ... - - def sub_cipherblock(self, other: Cipherblock) -> Cipherblock: ... - def sub_plaintext_f64(self, other: npt.NDArray[np.float64]) -> Cipherblock: ... - def sub_plaintext_f32(self, other: npt.NDArray[np.float32]) -> Cipherblock: ... - def sub_plaintext_i64(self, other: npt.NDArray[np.int64]) -> Cipherblock: ... - def sub_plaintext_i32(self, other: npt.NDArray[np.int32]) -> Cipherblock: ... - def sub_plaintext_scalar_f64(self, other: typing.Union[float, np.float64]) -> Cipherblock: ... - def sub_plaintext_scalar_f32(self, other: typing.Union[float, np.float32]) -> Cipherblock: ... - def sub_plaintext_scalar_i64(self, other: typing.Union[int, np.int64]) -> Cipherblock: ... - def sub_plaintext_scalar_i32(self, other: typing.Union[int, np.int32]) -> Cipherblock: ... - - def mul_plaintext_f64(self, other: npt.NDArray[np.float64]) -> Cipherblock: ... - def mul_plaintext_f32(self, other: npt.NDArray[np.float32]) -> Cipherblock: ... - def mul_plaintext_i64(self, other: npt.NDArray[np.int64]) -> Cipherblock: ... - def mul_plaintext_i32(self, other: npt.NDArray[np.int32]) -> Cipherblock: ... - def mul_plaintext_scalar_f64(self, other: typing.Union[float, np.float64]) -> Cipherblock: ... - def mul_plaintext_scalar_f32(self, other: typing.Union[float, np.float32]) -> Cipherblock: ... - def mul_plaintext_scalar_i64(self, other: typing.Union[int, np.int64]) -> Cipherblock: ... - def mul_plaintext_scalar_i32(self, other: typing.Union[int, np.int32]) -> Cipherblock: ... - - def matmul_plaintext_ix2_f64(self, other: npt.NDArray[np.float64]) -> Cipherblock: ... - def matmul_plaintext_ix2_f32(self, other: npt.NDArray[np.float32]) -> Cipherblock: ... - def matmul_plaintext_ix2_i64(self, other: npt.NDArray[np.int64]) -> Cipherblock: ... - def matmul_plaintext_ix2_i32(self, other: npt.NDArray[np.int32]) -> Cipherblock: ... - def matmul_plaintext_ix1_f64(self, other: npt.NDArray[np.float64]) -> Cipherblock: ... - def matmul_plaintext_ix1_f32(self, other: npt.NDArray[np.float32]) -> Cipherblock: ... - def matmul_plaintext_ix1_i64(self, other: npt.NDArray[np.int64]) -> Cipherblock: ... - def matmul_plaintext_ix1_i32(self, other: npt.NDArray[np.int32]) -> Cipherblock: ... - def rmatmul_plaintext_ix2_f64(self, other: npt.NDArray[np.float64]) -> Cipherblock: ... - def rmatmul_plaintext_ix2_f32(self, other: npt.NDArray[np.float32]) -> Cipherblock: ... - def rmatmul_plaintext_ix2_i64(self, other: npt.NDArray[np.int64]) -> Cipherblock: ... - def rmatmul_plaintext_ix2_i32(self, other: npt.NDArray[np.int32]) -> Cipherblock: ... - def rmatmul_plaintext_ix1_f64(self, other: npt.NDArray[np.float64]) -> Cipherblock: ... - def rmatmul_plaintext_ix1_f32(self, other: npt.NDArray[np.float32]) -> Cipherblock: ... - def rmatmul_plaintext_ix1_i64(self, other: npt.NDArray[np.int64]) -> Cipherblock: ... - def rmatmul_plaintext_ix1_i32(self, other: npt.NDArray[np.int32]) -> Cipherblock: ... - - def sum(self) -> Cipherblock: ... - def mean(self) -> Cipherblock: ... - - """parallel""" - def add_cipherblock_par(self, other: Cipherblock) -> Cipherblock: ... - def add_plaintext_f64_par(self, other: npt.NDArray[np.float64]) -> Cipherblock: ... - def add_plaintext_f32_par(self, other: npt.NDArray[np.float32]) -> Cipherblock: ... - def add_plaintext_i64_par(self, other: npt.NDArray[np.int64]) -> Cipherblock: ... - def add_plaintext_scalar_f64_par(self, other: typing.Union[float, np.float64]) -> Cipherblock: ... - def add_plaintext_scalar_f32_par(self, other: typing.Union[float, np.float32]) -> Cipherblock: ... - def add_plaintext_scalar_i64_par(self, other: typing.Union[int, np.int64]) -> Cipherblock: ... - def add_plaintext_scalar_i32_par(self, other: typing.Union[int, np.int32]) -> Cipherblock: ... - - def add_plaintext_i32_par(self, other: npt.NDArray[np.int32]) -> Cipherblock: ... - def sub_cipherblock_par(self, other: Cipherblock) -> Cipherblock: ... - def sub_plaintext_f64_par(self, other: npt.NDArray[np.float64]) -> Cipherblock: ... - def sub_plaintext_f32_par(self, other: npt.NDArray[np.float32]) -> Cipherblock: ... - def sub_plaintext_i64_par(self, other: npt.NDArray[np.int64]) -> Cipherblock: ... - def sub_plaintext_i32_par(self, other: npt.NDArray[np.int32]) -> Cipherblock: ... - def sub_plaintext_scalar_f64_par(self, other: typing.Union[float, np.float64]) -> Cipherblock: ... - def sub_plaintext_scalar_f32_par(self, other: typing.Union[float, np.float32]) -> Cipherblock: ... - def sub_plaintext_scalar_i64_par(self, other: typing.Union[int, np.int64]) -> Cipherblock: ... - def sub_plaintext_scalar_i32_par(self, other: typing.Union[int, np.int32]) -> Cipherblock: ... - - def mul_plaintext_f64_par(self, other: npt.NDArray[np.float64]) -> Cipherblock: ... - def mul_plaintext_f32_par(self, other: npt.NDArray[np.float32]) -> Cipherblock: ... - def mul_plaintext_i64_par(self, other: npt.NDArray[np.int64]) -> Cipherblock: ... - def mul_plaintext_i32_par(self, other: npt.NDArray[np.int32]) -> Cipherblock: ... - def mul_plaintext_scalar_f64_par(self, other: typing.Union[float, np.float64]) -> Cipherblock: ... - def mul_plaintext_scalar_f32_par(self, other: typing.Union[float, np.float32]) -> Cipherblock: ... - def mul_plaintext_scalar_i64_par(self, other: typing.Union[int, np.int64]) -> Cipherblock: ... - def mul_plaintext_scalar_i32_par(self, other: typing.Union[int, np.int32]) -> Cipherblock: ... - - def matmul_plaintext_ix2_f64_par(self, other: npt.NDArray[np.float64]) -> Cipherblock: ... - def matmul_plaintext_ix2_f32_par(self, other: npt.NDArray[np.float32]) -> Cipherblock: ... - def matmul_plaintext_ix2_i64_par(self, other: npt.NDArray[np.int64]) -> Cipherblock: ... - def matmul_plaintext_ix2_i32_par(self, other: npt.NDArray[np.int32]) -> Cipherblock: ... - def matmul_plaintext_ix1_f64_par(self, other: npt.NDArray[np.float64]) -> Cipherblock: ... - def matmul_plaintext_ix1_f32_par(self, other: npt.NDArray[np.float32]) -> Cipherblock: ... - def matmul_plaintext_ix1_i64_par(self, other: npt.NDArray[np.int64]) -> Cipherblock: ... - def matmul_plaintext_ix1_i32_par(self, other: npt.NDArray[np.int32]) -> Cipherblock: ... - def rmatmul_plaintext_ix2_f64_par(self, other: npt.NDArray[np.float64]) -> Cipherblock: ... - def rmatmul_plaintext_ix2_f32_par(self, other: npt.NDArray[np.float32]) -> Cipherblock: ... - def rmatmul_plaintext_ix2_i64_par(self, other: npt.NDArray[np.int64]) -> Cipherblock: ... - def rmatmul_plaintext_ix2_i32_par(self, other: npt.NDArray[np.int32]) -> Cipherblock: ... - def rmatmul_plaintext_ix1_f64_par(self, other: npt.NDArray[np.float64]) -> Cipherblock: ... - def rmatmul_plaintext_ix1_f32_par(self, other: npt.NDArray[np.float32]) -> Cipherblock: ... - def rmatmul_plaintext_ix1_i64_par(self, other: npt.NDArray[np.int64]) -> Cipherblock: ... - def rmatmul_plaintext_ix1_i32_par(self, other: npt.NDArray[np.int32]) -> Cipherblock: ... - - def sum_par(self) -> Cipherblock: ... - def mean_par(self) -> Cipherblock: ... - -class PK: - def encrypt_f64(self, a: npt.NDArray[np.float64]) -> Cipherblock: ... - def encrypt_f32(self, a: npt.NDArray[np.float32]) -> Cipherblock: ... - def encrypt_i64(self, a: npt.NDArray[np.int64]) -> Cipherblock: ... - def encrypt_i32(self, a: npt.NDArray[np.int32]) -> Cipherblock: ... - def encrypt_f64_par(self, a: npt.NDArray[np.float64]) -> Cipherblock: ... - def encrypt_f32_par(self, a: npt.NDArray[np.float32]) -> Cipherblock: ... - def encrypt_i64_par(self, a: npt.NDArray[np.int64]) -> Cipherblock: ... - def encrypt_i32_par(self, a: npt.NDArray[np.int32]) -> Cipherblock: ... - -class SK: - def decrypt_f64(self, a: Cipherblock) -> npt.NDArray[np.float64]: ... - def decrypt_f32(self, a: Cipherblock) -> npt.NDArray[np.float32]: ... - def decrypt_i64(self, a: Cipherblock) -> npt.NDArray[np.int64]: ... - def decrypt_i32(self, a: Cipherblock) -> npt.NDArray[np.int32]: ... - def decrypt_f64_par(self, a: Cipherblock) -> npt.NDArray[np.float64]: ... - def decrypt_f32_par(self, a: Cipherblock) -> npt.NDArray[np.float32]: ... - def decrypt_i64_par(self, a: Cipherblock) -> npt.NDArray[np.int64]: ... - def decrypt_i32_par(self, a: Cipherblock) -> npt.NDArray[np.int32]: ... - -def keygen(bit_size) -> typing.Tuple[PK, SK]:... diff --git a/rust/fate-tensor/fate_tensor/__init__.py b/rust/fate-tensor/fate_tensor/__init__.py new file mode 100644 index 0000000000..1c1df9b607 --- /dev/null +++ b/rust/fate-tensor/fate_tensor/__init__.py @@ -0,0 +1 @@ +from .fate_tensor import * diff --git a/rust/fate-tensor/fate_tensor/__init__.pyi b/rust/fate-tensor/fate_tensor/__init__.pyi new file mode 100644 index 0000000000..a4a8509100 --- /dev/null +++ b/rust/fate-tensor/fate_tensor/__init__.pyi @@ -0,0 +1,69 @@ +import typing + +import numpy as np +import numpy.typing as npt + +class Cipherblock: + def add_cipherblock(self, other: Cipherblock) -> Cipherblock: ... + def add_plaintext_f64(self, other: npt.NDArray[np.float64]) -> Cipherblock: ... + def add_plaintext_f32(self, other: npt.NDArray[np.float32]) -> Cipherblock: ... + def add_plaintext_i64(self, other: npt.NDArray[np.int64]) -> Cipherblock: ... + def add_plaintext_i32(self, other: npt.NDArray[np.int32]) -> Cipherblock: ... + def add_plaintext_scalar_f64(self, other: typing.Union[float, np.float64]) -> Cipherblock: ... + def add_plaintext_scalar_f32(self, other: typing.Union[float, np.float32]) -> Cipherblock: ... + def add_plaintext_scalar_i64(self, other: typing.Union[int, np.int64]) -> Cipherblock: ... + def add_plaintext_scalar_i32(self, other: typing.Union[int, np.int32]) -> Cipherblock: ... + + def sub_cipherblock(self, other: Cipherblock) -> Cipherblock: ... + def sub_plaintext_f64(self, other: npt.NDArray[np.float64]) -> Cipherblock: ... + def sub_plaintext_f32(self, other: npt.NDArray[np.float32]) -> Cipherblock: ... + def sub_plaintext_i64(self, other: npt.NDArray[np.int64]) -> Cipherblock: ... + def sub_plaintext_i32(self, other: npt.NDArray[np.int32]) -> Cipherblock: ... + def sub_plaintext_scalar_f64(self, other: typing.Union[float, np.float64]) -> Cipherblock: ... + def sub_plaintext_scalar_f32(self, other: typing.Union[float, np.float32]) -> Cipherblock: ... + def sub_plaintext_scalar_i64(self, other: typing.Union[int, np.int64]) -> Cipherblock: ... + def sub_plaintext_scalar_i32(self, other: typing.Union[int, np.int32]) -> Cipherblock: ... + + def mul_plaintext_f64(self, other: npt.NDArray[np.float64]) -> Cipherblock: ... + def mul_plaintext_f32(self, other: npt.NDArray[np.float32]) -> Cipherblock: ... + def mul_plaintext_i64(self, other: npt.NDArray[np.int64]) -> Cipherblock: ... + def mul_plaintext_i32(self, other: npt.NDArray[np.int32]) -> Cipherblock: ... + def mul_plaintext_scalar_f64(self, other: typing.Union[float, np.float64]) -> Cipherblock: ... + def mul_plaintext_scalar_f32(self, other: typing.Union[float, np.float32]) -> Cipherblock: ... + def mul_plaintext_scalar_i64(self, other: typing.Union[int, np.int64]) -> Cipherblock: ... + def mul_plaintext_scalar_i32(self, other: typing.Union[int, np.int32]) -> Cipherblock: ... + + def matmul_plaintext_ix2_f64(self, other: npt.NDArray[np.float64]) -> Cipherblock: ... + def matmul_plaintext_ix2_f32(self, other: npt.NDArray[np.float32]) -> Cipherblock: ... + def matmul_plaintext_ix2_i64(self, other: npt.NDArray[np.int64]) -> Cipherblock: ... + def matmul_plaintext_ix2_i32(self, other: npt.NDArray[np.int32]) -> Cipherblock: ... + def matmul_plaintext_ix1_f64(self, other: npt.NDArray[np.float64]) -> Cipherblock: ... + def matmul_plaintext_ix1_f32(self, other: npt.NDArray[np.float32]) -> Cipherblock: ... + def matmul_plaintext_ix1_i64(self, other: npt.NDArray[np.int64]) -> Cipherblock: ... + def matmul_plaintext_ix1_i32(self, other: npt.NDArray[np.int32]) -> Cipherblock: ... + def rmatmul_plaintext_ix2_f64(self, other: npt.NDArray[np.float64]) -> Cipherblock: ... + def rmatmul_plaintext_ix2_f32(self, other: npt.NDArray[np.float32]) -> Cipherblock: ... + def rmatmul_plaintext_ix2_i64(self, other: npt.NDArray[np.int64]) -> Cipherblock: ... + def rmatmul_plaintext_ix2_i32(self, other: npt.NDArray[np.int32]) -> Cipherblock: ... + def rmatmul_plaintext_ix1_f64(self, other: npt.NDArray[np.float64]) -> Cipherblock: ... + def rmatmul_plaintext_ix1_f32(self, other: npt.NDArray[np.float32]) -> Cipherblock: ... + def rmatmul_plaintext_ix1_i64(self, other: npt.NDArray[np.int64]) -> Cipherblock: ... + def rmatmul_plaintext_ix1_i32(self, other: npt.NDArray[np.int32]) -> Cipherblock: ... + + def sum(self) -> Cipherblock: ... + def mean(self) -> Cipherblock: ... + + +class PK: + def encrypt_f64(self, a: npt.NDArray[np.float64]) -> Cipherblock: ... + def encrypt_f32(self, a: npt.NDArray[np.float32]) -> Cipherblock: ... + def encrypt_i64(self, a: npt.NDArray[np.int64]) -> Cipherblock: ... + def encrypt_i32(self, a: npt.NDArray[np.int32]) -> Cipherblock: ... + +class SK: + def decrypt_f64(self, a: Cipherblock) -> npt.NDArray[np.float64]: ... + def decrypt_f32(self, a: Cipherblock) -> npt.NDArray[np.float32]: ... + def decrypt_i64(self, a: Cipherblock) -> npt.NDArray[np.int64]: ... + def decrypt_i32(self, a: Cipherblock) -> npt.NDArray[np.int32]: ... + +def keygen(bit_size: int) -> typing.Tuple[PK, SK]:... diff --git a/rust/fate-tensor/fate_tensor/par/__init__.py b/rust/fate-tensor/fate_tensor/par/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/rust/fate-tensor/fate_tensor/par/__init__.pyi b/rust/fate-tensor/fate_tensor/par/__init__.pyi new file mode 100644 index 0000000000..7de2099ddb --- /dev/null +++ b/rust/fate-tensor/fate_tensor/par/__init__.pyi @@ -0,0 +1,70 @@ +import typing + +import numpy as np +import numpy.typing as npt + +class Cipherblock: + def add_cipherblock(self, other: Cipherblock) -> Cipherblock: ... + def add_plaintext_f64(self, other: npt.NDArray[np.float64]) -> Cipherblock: ... + def add_plaintext_f32(self, other: npt.NDArray[np.float32]) -> Cipherblock: ... + def add_plaintext_i64(self, other: npt.NDArray[np.int64]) -> Cipherblock: ... + def add_plaintext_i32(self, other: npt.NDArray[np.int32]) -> Cipherblock: ... + def add_plaintext_scalar_f64(self, other: typing.Union[float, np.float64]) -> Cipherblock: ... + def add_plaintext_scalar_f32(self, other: typing.Union[float, np.float32]) -> Cipherblock: ... + def add_plaintext_scalar_i64(self, other: typing.Union[int, np.int64]) -> Cipherblock: ... + def add_plaintext_scalar_i32(self, other: typing.Union[int, np.int32]) -> Cipherblock: ... + + def sub_cipherblock(self, other: Cipherblock) -> Cipherblock: ... + def sub_plaintext_f64(self, other: npt.NDArray[np.float64]) -> Cipherblock: ... + def sub_plaintext_f32(self, other: npt.NDArray[np.float32]) -> Cipherblock: ... + def sub_plaintext_i64(self, other: npt.NDArray[np.int64]) -> Cipherblock: ... + def sub_plaintext_i32(self, other: npt.NDArray[np.int32]) -> Cipherblock: ... + def sub_plaintext_scalar_f64(self, other: typing.Union[float, np.float64]) -> Cipherblock: ... + def sub_plaintext_scalar_f32(self, other: typing.Union[float, np.float32]) -> Cipherblock: ... + def sub_plaintext_scalar_i64(self, other: typing.Union[int, np.int64]) -> Cipherblock: ... + def sub_plaintext_scalar_i32(self, other: typing.Union[int, np.int32]) -> Cipherblock: ... + + def mul_plaintext_f64(self, other: npt.NDArray[np.float64]) -> Cipherblock: ... + def mul_plaintext_f32(self, other: npt.NDArray[np.float32]) -> Cipherblock: ... + def mul_plaintext_i64(self, other: npt.NDArray[np.int64]) -> Cipherblock: ... + def mul_plaintext_i32(self, other: npt.NDArray[np.int32]) -> Cipherblock: ... + def mul_plaintext_scalar_f64(self, other: typing.Union[float, np.float64]) -> Cipherblock: ... + def mul_plaintext_scalar_f32(self, other: typing.Union[float, np.float32]) -> Cipherblock: ... + def mul_plaintext_scalar_i64(self, other: typing.Union[int, np.int64]) -> Cipherblock: ... + def mul_plaintext_scalar_i32(self, other: typing.Union[int, np.int32]) -> Cipherblock: ... + + def matmul_plaintext_ix2_f64(self, other: npt.NDArray[np.float64]) -> Cipherblock: ... + def matmul_plaintext_ix2_f32(self, other: npt.NDArray[np.float32]) -> Cipherblock: ... + def matmul_plaintext_ix2_i64(self, other: npt.NDArray[np.int64]) -> Cipherblock: ... + def matmul_plaintext_ix2_i32(self, other: npt.NDArray[np.int32]) -> Cipherblock: ... + def matmul_plaintext_ix1_f64(self, other: npt.NDArray[np.float64]) -> Cipherblock: ... + def matmul_plaintext_ix1_f32(self, other: npt.NDArray[np.float32]) -> Cipherblock: ... + def matmul_plaintext_ix1_i64(self, other: npt.NDArray[np.int64]) -> Cipherblock: ... + def matmul_plaintext_ix1_i32(self, other: npt.NDArray[np.int32]) -> Cipherblock: ... + def rmatmul_plaintext_ix2_f64(self, other: npt.NDArray[np.float64]) -> Cipherblock: ... + def rmatmul_plaintext_ix2_f32(self, other: npt.NDArray[np.float32]) -> Cipherblock: ... + def rmatmul_plaintext_ix2_i64(self, other: npt.NDArray[np.int64]) -> Cipherblock: ... + def rmatmul_plaintext_ix2_i32(self, other: npt.NDArray[np.int32]) -> Cipherblock: ... + def rmatmul_plaintext_ix1_f64(self, other: npt.NDArray[np.float64]) -> Cipherblock: ... + def rmatmul_plaintext_ix1_f32(self, other: npt.NDArray[np.float32]) -> Cipherblock: ... + def rmatmul_plaintext_ix1_i64(self, other: npt.NDArray[np.int64]) -> Cipherblock: ... + def rmatmul_plaintext_ix1_i32(self, other: npt.NDArray[np.int32]) -> Cipherblock: ... + + def sum(self) -> Cipherblock: ... + def mean(self) -> Cipherblock: ... + + +class PK: + def encrypt_f64(self, a: npt.NDArray[np.float64]) -> Cipherblock: ... + def encrypt_f32(self, a: npt.NDArray[np.float32]) -> Cipherblock: ... + def encrypt_i64(self, a: npt.NDArray[np.int64]) -> Cipherblock: ... + def encrypt_i32(self, a: npt.NDArray[np.int32]) -> Cipherblock: ... + +class SK: + def decrypt_f64(self, a: Cipherblock) -> npt.NDArray[np.float64]: ... + def decrypt_f32(self, a: Cipherblock) -> npt.NDArray[np.float32]: ... + def decrypt_i64(self, a: Cipherblock) -> npt.NDArray[np.int64]: ... + def decrypt_i32(self, a: Cipherblock) -> npt.NDArray[np.int32]: ... + +def keygen(bit_size: int) -> typing.Tuple[PK, SK]:... +def set_num_threads(num_threads: int): ... diff --git a/rust/fate-tensor/pyproject.toml b/rust/fate-tensor/pyproject.toml index 92b7d1124e..abfdd0b0ec 100644 --- a/rust/fate-tensor/pyproject.toml +++ b/rust/fate-tensor/pyproject.toml @@ -16,8 +16,8 @@ exclude = [ "**/__pycache__", ".venv/" ] -venvPath = "/Users/sage/projects/rust-tensor" -venv = ".venv" +venvPath = "/Users/sage/MEGA/FATE/" +venv = "venv" reportMissingImports = true reportMissingTypeStubs = false executionEnvironments = [ diff --git a/rust/fate-tensor/src/block/matmul.rs b/rust/fate-tensor/src/block/matmul.rs index 63ca900452..2cc52d6ce5 100644 --- a/rust/fate-tensor/src/block/matmul.rs +++ b/rust/fate-tensor/src/block/matmul.rs @@ -1,6 +1,5 @@ use super::{fixedpoint, Cipherblock, CouldCode}; use ndarray::{ArrayView1, ArrayView2}; -#[cfg(feature = "rayon")] use rayon::prelude::*; /// help function for generic matrix multiply @@ -28,7 +27,6 @@ where /// 1. the output matrix C has shape (m, n) /// 2. C(i,j) = \sum_k A(i,k)B(k,j) /// 3. the function F(i, k, j, v): v += A(i,k)B(k,j) -#[cfg(feature = "rayon")] fn matmul_apply_par(m: usize, s: usize, n: usize, func: F) -> Vec where F: Fn(usize, usize, usize, &mut fixedpoint::CT) -> () + Sync, // (i, k, j, v) -> () @@ -203,7 +201,6 @@ pub fn cipherblock_rmatmul_plaintext_ix2( } } -#[cfg(feature = "rayon")] pub fn cipherblock_matmul_plaintext_ix1_par( lhs: &Cipherblock, rhs: ArrayView1, @@ -219,7 +216,6 @@ pub fn cipherblock_matmul_plaintext_ix1_par( shape: vec![m, n], } } -#[cfg(feature = "rayon")] pub fn cipherblock_matmul_plaintext_ix2_par( lhs: &Cipherblock, rhs: ArrayView2, @@ -235,7 +231,6 @@ pub fn cipherblock_matmul_plaintext_ix2_par( shape: vec![m, n], } } -#[cfg(feature = "rayon")] pub fn cipherblock_rmatmul_plaintext_ix1_par( lhs: ArrayView1, rhs: &Cipherblock, @@ -252,7 +247,6 @@ pub fn cipherblock_rmatmul_plaintext_ix1_par( } } -#[cfg(feature = "rayon")] pub fn cipherblock_rmatmul_plaintext_ix2_par( lhs: ArrayView2, rhs: &Cipherblock, @@ -284,22 +278,18 @@ impl Cipherblock { } // par - #[cfg(feature = "rayon")] pub fn matmul_plaintext_ix1_par(&self, rhs: ArrayView1) -> Cipherblock { cipherblock_matmul_plaintext_ix1_par(self, rhs) } - #[cfg(feature = "rayon")] pub fn rmatmul_plaintext_ix1_par( &self, lhs: ArrayView1, ) -> Cipherblock { cipherblock_rmatmul_plaintext_ix1_par(lhs, self) } - #[cfg(feature = "rayon")] pub fn matmul_plaintext_ix2_par(&self, rhs: ArrayView2) -> Cipherblock { cipherblock_matmul_plaintext_ix2_par(self, rhs) } - #[cfg(feature = "rayon")] pub fn rmatmul_plaintext_ix2_par( &self, lhs: ArrayView2, diff --git a/rust/fate-tensor/src/block/mod.rs b/rust/fate-tensor/src/block/mod.rs index 7c37d93416..303e995979 100644 --- a/rust/fate-tensor/src/block/mod.rs +++ b/rust/fate-tensor/src/block/mod.rs @@ -3,7 +3,6 @@ use std::ops::Index; use super::fixedpoint; use super::fixedpoint::CouldCode; use ndarray::{ArrayD, ArrayViewD}; -#[cfg(feature = "rayon")] use rayon::prelude::*; use serde::{Deserialize, Serialize}; mod matmul; @@ -118,7 +117,6 @@ impl fixedpoint::SK { } } -#[cfg(feature = "rayon")] impl Cipherblock { pub fn agg_par(&self, identity: ID, f: F, op: OP) -> T where @@ -188,7 +186,6 @@ impl Cipherblock { } } -#[cfg(feature = "rayon")] impl fixedpoint::PK { pub fn encrypt_array_par(&self, array: ArrayViewD) -> Cipherblock where @@ -207,7 +204,6 @@ impl fixedpoint::PK { } } -#[cfg(feature = "rayon")] impl fixedpoint::SK { pub fn decrypt_array_par(&self, array: &Cipherblock) -> ArrayD where diff --git a/rust/fate-tensor/src/cb.rs b/rust/fate-tensor/src/cb.rs index f73a75cc2b..08c2a9caf8 100644 --- a/rust/fate-tensor/src/cb.rs +++ b/rust/fate-tensor/src/cb.rs @@ -36,14 +36,6 @@ macro_rules! impl_ops_cipher_scalar { }) } }; - ($name:ident,$fn:expr,$feature:ident) => { - #[cfg(feature = "rayon")] - pub fn $name(&self, other: &fixedpoint::CT) -> Cipherblock { - operation_with_scalar(self, other, |lhs, rhs| { - block::Cipherblock::map_par(lhs, |c| $fn(c, rhs, &lhs.pk)) - }) - } - }; } macro_rules! impl_ops_plaintext_scalar { ($name:ident,$fn:expr) => { @@ -56,17 +48,6 @@ macro_rules! impl_ops_plaintext_scalar { }) } }; - ($name:ident,$fn:expr,$feature:ident) => { - #[cfg(feature = "rayon")] - pub fn $name(&self, other: T) -> Cipherblock - where - T: CouldCode + Sync, - { - operation_with_scalar(self, other, |lhs, rhs| { - block::Cipherblock::map_par(lhs, |c| $fn(c, &rhs.encode(&lhs.pk.coder), &lhs.pk)) - }) - } - }; } macro_rules! impl_ops_cipher { ($name:ident,$fn:expr) => { @@ -76,14 +57,6 @@ macro_rules! impl_ops_cipher { }) } }; - ($name:ident,$fn:expr,$feature:ident) => { - #[cfg(feature = "rayon")] - pub fn $name(&self, other: &Cipherblock) -> Cipherblock { - operation_with_cipherblock(self, other, |lhs, rhs| { - block::Cipherblock::binary_cipherblock_cipherblock_par(lhs, rhs, $fn) - }) - } - }; } macro_rules! impl_ops_plain { ($name:ident,$fn:expr) => { @@ -96,17 +69,6 @@ macro_rules! impl_ops_plain { }) } }; - ($name:ident,$fn:expr,$feature:ident) => { - #[cfg(feature = "rayon")] - pub fn $name(&self, other: ArrayViewD) -> Cipherblock - where - T: fixedpoint::CouldCode + Sync + Send, - { - operation_with_arrayview_dyn(self, other, |lhs, rhs| { - block::Cipherblock::binary_cipherblock_plaintext_par(lhs, rhs, $fn) - }) - } - }; } macro_rules! impl_ops_matmul { ($name:ident, $fn:expr, $oty:ident) => { @@ -114,12 +76,6 @@ macro_rules! impl_ops_matmul { Cipherblock::new($fn(self.unwrap(), other)) } }; - ($name:ident, $fn:expr, $oty:ident, $feature:ident) => { - #[cfg(feature = "rayon")] - pub fn $name(&self, other: $oty) -> Cipherblock { - Cipherblock::new($fn(self.unwrap(), other)) - } - }; } impl Cipherblock { fn new(cb: block::Cipherblock) -> Cipherblock { @@ -161,45 +117,6 @@ impl Cipherblock { ArrayView2 ); - //par - impl_ops_cipher!(add_cb_par, fixedpoint::CT::add, rayon); - impl_ops_plain!(add_plaintext_par, fixedpoint::CT::add_pt, rayon); - impl_ops_cipher_scalar!(add_cipher_scalar_par, fixedpoint::CT::add, rayon); - impl_ops_plaintext_scalar!(add_plaintext_scalar_par, fixedpoint::CT::add_pt, rayon); - - impl_ops_cipher!(sub_cb_par, fixedpoint::CT::sub, rayon); - impl_ops_plain!(sub_plaintext_par, fixedpoint::CT::sub_pt, rayon); - impl_ops_cipher_scalar!(sub_cipher_scalar_par, fixedpoint::CT::add, rayon); - impl_ops_plaintext_scalar!(sub_plaintext_scalar_par, fixedpoint::CT::sub_pt, rayon); - - impl_ops_plain!(mul_plaintext_par, fixedpoint::CT::mul, rayon); - impl_ops_plaintext_scalar!(mul_plaintext_scalar_par, fixedpoint::CT::mul, rayon); - - // matmul - impl_ops_matmul!( - matmul_plaintext_ix1_par, - block::Cipherblock::matmul_plaintext_ix1_par, - ArrayView1, - rayon - ); - impl_ops_matmul!( - rmatmul_plaintext_ix1_par, - block::Cipherblock::rmatmul_plaintext_ix1_par, - ArrayView1, - rayon - ); - impl_ops_matmul!( - matmul_plaintext_ix2_par, - block::Cipherblock::matmul_plaintext_ix2_par, - ArrayView2, - rayon - ); - impl_ops_matmul!( - rmatmul_plaintext_ix2_par, - block::Cipherblock::rmatmul_plaintext_ix2_par, - ArrayView2, - rayon - ); } impl Cipherblock { @@ -224,36 +141,6 @@ impl Cipherblock { shape: vec![1], }) } - - #[cfg(feature = "rayon")] - pub fn sum_cb_par(&self) -> Cipherblock { - let cb = self.unwrap(); - let sum = cb.agg_par( - fixedpoint::CT::zero, - |s, c| s.add(c, &cb.pk), - |s1, s2| s1.add(&s2, &cb.pk), - ); - Cipherblock::new(block::Cipherblock { - pk: cb.pk.clone(), - data: vec![sum], - shape: vec![1], - }) - } - #[cfg(feature = "rayon")] - pub fn mean_cb_par(&self) -> Cipherblock { - let cb = self.unwrap(); - let (s, n) = cb.agg_par( - || (fixedpoint::CT::zero(), 0usize), - |s, c| (s.0.add(c, &cb.pk), s.1 + 1), - |s1, s2| (s1.0.add(&s2.0, &cb.pk), s1.1 + s2.1), - ); - let mean = s.mul(&(1.0f64 / (n as f64)).encode(&cb.pk.coder), &cb.pk); - Cipherblock::new(block::Cipherblock { - pk: cb.pk.clone(), - data: vec![mean], - shape: vec![1], - }) - } } impl SK { @@ -261,22 +148,10 @@ impl SK { let array = a.0.as_ref().unwrap(); self.sk.decrypt_array(array) } - #[cfg(feature = "rayon")] - pub fn decrypt_array_par(&self, a: &Cipherblock) -> ArrayD { - let array = a.0.as_ref().unwrap(); - self.sk.decrypt_array_par(array) - } } impl PK { pub fn encrypt_array(&self, array: ArrayViewD) -> Cipherblock { Cipherblock::new(self.pk.encrypt_array(array)) } - #[cfg(feature = "rayon")] - pub fn encrypt_array_par( - &self, - array: ArrayViewD, - ) -> Cipherblock { - Cipherblock::new(self.pk.encrypt_array_par(array)) - } } diff --git a/rust/fate-tensor/src/lib.rs b/rust/fate-tensor/src/lib.rs index 170bc0cf81..60feeb91c0 100644 --- a/rust/fate-tensor/src/lib.rs +++ b/rust/fate-tensor/src/lib.rs @@ -3,6 +3,7 @@ pub mod cb; pub mod fixedpoint; pub mod math; pub mod paillier; +mod par; use bincode::{deserialize, serialize}; use numpy::{IntoPyArray, PyArrayDyn, PyReadonlyArray1, PyReadonlyArray2, PyReadonlyArrayDyn}; @@ -49,22 +50,6 @@ impl PK { fn encrypt_i32(&self, a: PyReadonlyArrayDyn) -> Cipherblock { self.encrypt_array(a.as_array()) } - #[cfg(feature = "rayon")] - fn encrypt_f64_par(&self, a: PyReadonlyArrayDyn) -> Cipherblock { - self.encrypt_array_par(a.as_array()) - } - #[cfg(feature = "rayon")] - fn encrypt_f32_par(&self, a: PyReadonlyArrayDyn) -> Cipherblock { - self.encrypt_array_par(a.as_array()) - } - #[cfg(feature = "rayon")] - fn encrypt_i64_par(&self, a: PyReadonlyArrayDyn) -> Cipherblock { - self.encrypt_array_par(a.as_array()) - } - #[cfg(feature = "rayon")] - fn encrypt_i32_par(&self, a: PyReadonlyArrayDyn) -> Cipherblock { - self.encrypt_array_par(a.as_array()) - } } /// secret key for paillier system used to encrypt arrays @@ -84,22 +69,6 @@ impl SK { fn decrypt_i32<'py>(&self, a: &Cipherblock, py: Python<'py>) -> &'py PyArrayDyn { self.decrypt_array(a).into_pyarray(py) } - #[cfg(feature = "rayon")] - fn decrypt_f64_par<'py>(&self, a: &Cipherblock, py: Python<'py>) -> &'py PyArrayDyn { - self.decrypt_array_par(a).into_pyarray(py) - } - #[cfg(feature = "rayon")] - fn decrypt_f32_par<'py>(&self, a: &Cipherblock, py: Python<'py>) -> &'py PyArrayDyn { - self.decrypt_array_par(a).into_pyarray(py) - } - #[cfg(feature = "rayon")] - fn decrypt_i64_par<'py>(&self, a: &Cipherblock, py: Python<'py>) -> &'py PyArrayDyn { - self.decrypt_array_par(a).into_pyarray(py) - } - #[cfg(feature = "rayon")] - fn decrypt_i32_par<'py>(&self, a: &Cipherblock, py: Python<'py>) -> &'py PyArrayDyn { - self.decrypt_array_par(a).into_pyarray(py) - } } /// methods for cipherblock @@ -266,192 +235,6 @@ impl Cipherblock { self.sum_cb() } - // rayon - - // add - #[cfg(feature = "rayon")] - pub fn add_cipherblock_par(&self, other: &Cipherblock) -> Cipherblock { - self.add_cb_par(other) - } - #[cfg(feature = "rayon")] - pub fn add_plaintext_f64_par(&self, other: PyReadonlyArrayDyn) -> Cipherblock { - self.add_plaintext_par(other.as_array()) - } - #[cfg(feature = "rayon")] - pub fn add_plaintext_f32_par(&self, other: PyReadonlyArrayDyn) -> Cipherblock { - self.add_plaintext_par(other.as_array()) - } - #[cfg(feature = "rayon")] - pub fn add_plaintext_i64_par(&self, other: PyReadonlyArrayDyn) -> Cipherblock { - self.add_plaintext_par(other.as_array()) - } - #[cfg(feature = "rayon")] - pub fn add_plaintext_i32_par(&self, other: PyReadonlyArrayDyn) -> Cipherblock { - self.add_plaintext_par(other.as_array()) - } - #[cfg(feature = "rayon")] - pub fn add_plaintext_scalar_f64_par(&self, other: f64) -> Cipherblock { - self.add_plaintext_scalar_par(other) - } - #[cfg(feature = "rayon")] - pub fn add_plaintext_scalar_f32_par(&self, other: f32) -> Cipherblock { - self.add_plaintext_scalar_par(other) - } - #[cfg(feature = "rayon")] - pub fn add_plaintext_scalar_i64_par(&self, other: i64) -> Cipherblock { - self.add_plaintext_scalar_par(other) - } - #[cfg(feature = "rayon")] - pub fn add_plaintext_scalar_i32_par(&self, other: i32) -> Cipherblock { - self.add_plaintext_scalar_par(other) - } - - // sub - #[cfg(feature = "rayon")] - pub fn sub_cipherblock_par(&self, other: &Cipherblock) -> Cipherblock { - self.sub_cb_par(other) - } - #[cfg(feature = "rayon")] - pub fn sub_plaintext_f64_par(&self, other: PyReadonlyArrayDyn) -> Cipherblock { - self.sub_plaintext_par(other.as_array()) - } - #[cfg(feature = "rayon")] - pub fn sub_plaintext_f32_par(&self, other: PyReadonlyArrayDyn) -> Cipherblock { - self.sub_plaintext_par(other.as_array()) - } - #[cfg(feature = "rayon")] - pub fn sub_plaintext_i64_par(&self, other: PyReadonlyArrayDyn) -> Cipherblock { - self.sub_plaintext_par(other.as_array()) - } - #[cfg(feature = "rayon")] - pub fn sub_plaintext_i32_par(&self, other: PyReadonlyArrayDyn) -> Cipherblock { - self.sub_plaintext_par(other.as_array()) - } - #[cfg(feature = "rayon")] - pub fn sub_plaintext_scalar_f64_par(&self, other: f64) -> Cipherblock { - self.sub_plaintext_scalar_par(other) - } - #[cfg(feature = "rayon")] - pub fn sub_plaintext_scalar_f32_par(&self, other: f32) -> Cipherblock { - self.sub_plaintext_scalar_par(other) - } - #[cfg(feature = "rayon")] - pub fn sub_plaintext_scalar_i64_par(&self, other: i64) -> Cipherblock { - self.sub_plaintext_scalar_par(other) - } - #[cfg(feature = "rayon")] - pub fn sub_plaintext_scalar_i32_par(&self, other: i32) -> Cipherblock { - self.sub_plaintext_scalar_par(other) - } - - // mul - #[cfg(feature = "rayon")] - pub fn mul_plaintext_f64_par(&self, other: PyReadonlyArrayDyn) -> Cipherblock { - self.mul_plaintext_par(other.as_array()) - } - #[cfg(feature = "rayon")] - pub fn mul_plaintext_f32_par(&self, other: PyReadonlyArrayDyn) -> Cipherblock { - self.mul_plaintext_par(other.as_array()) - } - #[cfg(feature = "rayon")] - pub fn mul_plaintext_i64_par(&self, other: PyReadonlyArrayDyn) -> Cipherblock { - self.mul_plaintext_par(other.as_array()) - } - #[cfg(feature = "rayon")] - pub fn mul_plaintext_i32_par(&self, other: PyReadonlyArrayDyn) -> Cipherblock { - self.mul_plaintext_par(other.as_array()) - } - #[cfg(feature = "rayon")] - pub fn mul_plaintext_scalar_f64_par(&self, other: f64) -> Cipherblock { - self.mul_plaintext_scalar_par(other) - } - #[cfg(feature = "rayon")] - pub fn mul_plaintext_scalar_f32_par(&self, other: f32) -> Cipherblock { - self.mul_plaintext_scalar_par(other) - } - #[cfg(feature = "rayon")] - pub fn mul_plaintext_scalar_i64_par(&self, other: i64) -> Cipherblock { - self.mul_plaintext_scalar_par(other) - } - #[cfg(feature = "rayon")] - pub fn mul_plaintext_scalar_i32_par(&self, other: i32) -> Cipherblock { - self.mul_plaintext_scalar_par(other) - } - - // matmul - #[cfg(feature = "rayon")] - pub fn matmul_plaintext_ix2_f64_par(&self, other: PyReadonlyArray2) -> Cipherblock { - self.matmul_plaintext_ix2_par(other.as_array()) - } - #[cfg(feature = "rayon")] - pub fn matmul_plaintext_ix2_f32_par(&self, other: PyReadonlyArray2) -> Cipherblock { - self.matmul_plaintext_ix2_par(other.as_array()) - } - #[cfg(feature = "rayon")] - pub fn matmul_plaintext_ix2_i64_par(&self, other: PyReadonlyArray2) -> Cipherblock { - self.matmul_plaintext_ix2_par(other.as_array()) - } - #[cfg(feature = "rayon")] - pub fn matmul_plaintext_ix2_i32_par(&self, other: PyReadonlyArray2) -> Cipherblock { - self.matmul_plaintext_ix2_par(other.as_array()) - } - #[cfg(feature = "rayon")] - pub fn rmatmul_plaintext_ix2_f64_par(&self, other: PyReadonlyArray2) -> Cipherblock { - self.rmatmul_plaintext_ix2_par(other.as_array()) - } - #[cfg(feature = "rayon")] - pub fn rmatmul_plaintext_ix2_f32_par(&self, other: PyReadonlyArray2) -> Cipherblock { - self.rmatmul_plaintext_ix2_par(other.as_array()) - } - #[cfg(feature = "rayon")] - pub fn rmatmul_plaintext_ix2_i64_par(&self, other: PyReadonlyArray2) -> Cipherblock { - self.rmatmul_plaintext_ix2_par(other.as_array()) - } - #[cfg(feature = "rayon")] - pub fn rmatmul_plaintext_ix2_i32_par(&self, other: PyReadonlyArray2) -> Cipherblock { - self.rmatmul_plaintext_ix2_par(other.as_array()) - } - #[cfg(feature = "rayon")] - pub fn matmul_plaintext_ix1_f64_par(&self, other: PyReadonlyArray1) -> Cipherblock { - self.matmul_plaintext_ix1_par(other.as_array()) - } - #[cfg(feature = "rayon")] - pub fn matmul_plaintext_ix1_f32_par(&self, other: PyReadonlyArray1) -> Cipherblock { - self.matmul_plaintext_ix1_par(other.as_array()) - } - #[cfg(feature = "rayon")] - pub fn matmul_plaintext_ix1_i64_par(&self, other: PyReadonlyArray1) -> Cipherblock { - self.matmul_plaintext_ix1_par(other.as_array()) - } - #[cfg(feature = "rayon")] - pub fn matmul_plaintext_ix1_i32_par(&self, other: PyReadonlyArray1) -> Cipherblock { - self.matmul_plaintext_ix1_par(other.as_array()) - } - #[cfg(feature = "rayon")] - pub fn rmatmul_plaintext_ix1_f64_par(&self, other: PyReadonlyArray1) -> Cipherblock { - self.rmatmul_plaintext_ix1_par(other.as_array()) - } - #[cfg(feature = "rayon")] - pub fn rmatmul_plaintext_ix1_f32_par(&self, other: PyReadonlyArray1) -> Cipherblock { - self.rmatmul_plaintext_ix1_par(other.as_array()) - } - #[cfg(feature = "rayon")] - pub fn rmatmul_plaintext_ix1_i64_par(&self, other: PyReadonlyArray1) -> Cipherblock { - self.rmatmul_plaintext_ix1_par(other.as_array()) - } - #[cfg(feature = "rayon")] - pub fn rmatmul_plaintext_ix1_i32_par(&self, other: PyReadonlyArray1) -> Cipherblock { - self.rmatmul_plaintext_ix1_par(other.as_array()) - } - // agg - #[cfg(feature = "rayon")] - pub fn sum_par(&self) -> Cipherblock { - self.sum_cb_par() - } - #[cfg(feature = "rayon")] - pub fn mean_par(&self) -> Cipherblock { - self.sum_cb_par() - } } #[pymodule] fn fate_tensor(_py: Python, m: &PyModule) -> PyResult<()> { @@ -459,5 +242,7 @@ fn fate_tensor(_py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_function(wrap_pyfunction!(keygen, m)?)?; + + par::register(_py, m)?; Ok(()) } diff --git a/rust/fate-tensor/src/par/cb.rs b/rust/fate-tensor/src/par/cb.rs new file mode 100644 index 0000000000..33b902dc16 --- /dev/null +++ b/rust/fate-tensor/src/par/cb.rs @@ -0,0 +1,167 @@ +use super::{block, fixedpoint, fixedpoint::CouldCode, Cipherblock, PK, SK}; +use ndarray::{ArrayD, ArrayView1, ArrayView2, ArrayViewD}; + +fn operation_with_arrayview_dyn( + this: &Cipherblock, + other: ArrayViewD, + func: F, +) -> Cipherblock +where + F: Fn(&block::Cipherblock, ArrayViewD) -> block::Cipherblock, +{ + Cipherblock::new(func(this.unwrap(), other)) +} + +fn operation_with_cipherblock(this: &Cipherblock, other: &Cipherblock, func: F) -> Cipherblock +where + F: Fn(&block::Cipherblock, &block::Cipherblock) -> block::Cipherblock, +{ + let a = this.unwrap(); + let b = other.unwrap(); + Cipherblock::new(func(a, b)) +} + +fn operation_with_scalar(this: &Cipherblock, other: T, func: F) -> Cipherblock +where + F: Fn(&block::Cipherblock, T) -> block::Cipherblock, +{ + Cipherblock::new(func(this.unwrap(), other)) +} + +macro_rules! impl_ops_cipher_scalar { + ($name:ident,$fn:expr) => { + pub fn $name(&self, other: &fixedpoint::CT) -> Cipherblock { + operation_with_scalar(self, other, |lhs, rhs| { + block::Cipherblock::map_par(lhs, |c| $fn(c, rhs, &lhs.pk)) + }) + } + }; +} +macro_rules! impl_ops_plaintext_scalar { + ($name:ident,$fn:expr) => { + pub fn $name(&self, other: T) -> Cipherblock + where + T: CouldCode + Sync, + { + operation_with_scalar(self, other, |lhs, rhs| { + block::Cipherblock::map_par(lhs, |c| $fn(c, &rhs.encode(&lhs.pk.coder), &lhs.pk)) + }) + } + }; +} +macro_rules! impl_ops_cipher { + ($name:ident,$fn:expr) => { + pub fn $name(&self, other: &Cipherblock) -> Cipherblock { + operation_with_cipherblock(self, other, |lhs, rhs| { + block::Cipherblock::binary_cipherblock_cipherblock_par(lhs, rhs, $fn) + }) + } + }; +} +macro_rules! impl_ops_plain { + ($name:ident,$fn:expr) => { + pub fn $name(&self, other: ArrayViewD) -> Cipherblock + where + T: fixedpoint::CouldCode + Sync + Send, + { + operation_with_arrayview_dyn(self, other, |lhs, rhs| { + block::Cipherblock::binary_cipherblock_plaintext_par(lhs, rhs, $fn) + }) + } + }; +} +macro_rules! impl_ops_matmul { + ($name:ident, $fn:expr, $oty:ident) => { + pub fn $name(&self, other: $oty) -> Cipherblock { + Cipherblock::new($fn(self.unwrap(), other)) + } + }; +} +impl Cipherblock { + fn new(cb: block::Cipherblock) -> Cipherblock { + Cipherblock(Some(cb)) + } + fn unwrap(&self) -> &block::Cipherblock { + self.0.as_ref().unwrap() + } + impl_ops_cipher!(add_cb, fixedpoint::CT::add); + impl_ops_plain!(add_plaintext, fixedpoint::CT::add_pt); + impl_ops_cipher_scalar!(add_cipher_scalar, fixedpoint::CT::add); + impl_ops_plaintext_scalar!(add_plaintext_scalar, fixedpoint::CT::add_pt); + + impl_ops_cipher!(sub_cb, fixedpoint::CT::sub); + impl_ops_plain!(sub_plaintext, fixedpoint::CT::sub_pt); + impl_ops_cipher_scalar!(sub_cipher_scalar, fixedpoint::CT::add); + impl_ops_plaintext_scalar!(sub_plaintext_scalar, fixedpoint::CT::sub_pt); + + impl_ops_plain!(mul_plaintext, fixedpoint::CT::mul); + impl_ops_plaintext_scalar!(mul_plaintext_scalar, fixedpoint::CT::mul); + + // matmul + impl_ops_matmul!( + matmul_plaintext_ix1, + block::Cipherblock::matmul_plaintext_ix1_par, + ArrayView1 + ); + impl_ops_matmul!( + rmatmul_plaintext_ix1, + block::Cipherblock::rmatmul_plaintext_ix1_par, + ArrayView1 + ); + impl_ops_matmul!( + matmul_plaintext_ix2, + block::Cipherblock::matmul_plaintext_ix2_par, + ArrayView2 + ); + impl_ops_matmul!( + rmatmul_plaintext_ix2, + block::Cipherblock::rmatmul_plaintext_ix2_par, + ArrayView2 + ); +} + +impl Cipherblock { + pub fn sum_cb(&self) -> Cipherblock { + let cb = self.unwrap(); + let sum = cb.agg_par( + fixedpoint::CT::zero, + |s, c| s.add(c, &cb.pk), + |s1, s2| s1.add(&s2, &cb.pk), + ); + Cipherblock::new(block::Cipherblock { + pk: cb.pk.clone(), + data: vec![sum], + shape: vec![1], + }) + } + pub fn mean_cb(&self) -> Cipherblock { + let cb = self.unwrap(); + let (s, n) = cb.agg_par( + || (fixedpoint::CT::zero(), 0usize), + |s, c| (s.0.add(c, &cb.pk), s.1 + 1), + |s1, s2| (s1.0.add(&s2.0, &cb.pk), s1.1 + s2.1), + ); + let mean = s.mul(&(1.0f64 / (n as f64)).encode(&cb.pk.coder), &cb.pk); + Cipherblock::new(block::Cipherblock { + pk: cb.pk.clone(), + data: vec![mean], + shape: vec![1], + }) + } +} + +impl SK { + pub fn decrypt_array(&self, a: &Cipherblock) -> ArrayD { + let array = a.0.as_ref().unwrap(); + self.sk.decrypt_array_par(array) + } +} + +impl PK { + pub fn encrypt_array( + &self, + array: ArrayViewD, + ) -> Cipherblock { + Cipherblock::new(self.pk.encrypt_array_par(array)) + } +} diff --git a/rust/fate-tensor/src/par/mod.rs b/rust/fate-tensor/src/par/mod.rs new file mode 100644 index 0000000000..62ef3e094d --- /dev/null +++ b/rust/fate-tensor/src/par/mod.rs @@ -0,0 +1,235 @@ +use bincode::{deserialize, serialize}; +use numpy::{IntoPyArray, PyArrayDyn, PyReadonlyArray1, PyReadonlyArray2, PyReadonlyArrayDyn}; +use pyo3::prelude::*; +use pyo3::types::PyBytes; +use crate::fixedpoint; +use crate::block; + +mod cb; + +#[pyclass(module = "fate_tensor.par")] +pub struct Cipherblock(Option); + +#[pyclass(module = "fate_tensor.par")] +pub struct PK { + pk: fixedpoint::PK, +} + +#[pyclass(module = "fate_tensor.par")] +pub struct SK { + sk: fixedpoint::SK, +} + +#[pyfunction] +fn keygen(bit_size: u32) -> (PK, SK) { + let (sk, pk) = fixedpoint::keygen(bit_size); + (PK { pk }, SK { sk }) +} + +#[pyfunction] +fn set_num_threads(num_threads: usize) { + rayon::ThreadPoolBuilder::new().num_threads(num_threads).build_global().unwrap(); +} + +#[pymethods] +impl PK { + fn encrypt_f64(&self, a: PyReadonlyArrayDyn) -> Cipherblock { + self.encrypt_array(a.as_array()) + } + fn encrypt_f32(&self, a: PyReadonlyArrayDyn) -> Cipherblock { + self.encrypt_array(a.as_array()) + } + fn encrypt_i64(&self, a: PyReadonlyArrayDyn) -> Cipherblock { + self.encrypt_array(a.as_array()) + } + fn encrypt_i32(&self, a: PyReadonlyArrayDyn) -> Cipherblock { + self.encrypt_array(a.as_array()) + } +} + +#[pymethods] +impl SK { + fn decrypt_f64<'py>(&self, a: &Cipherblock, py: Python<'py>) -> &'py PyArrayDyn { + self.decrypt_array(a).into_pyarray(py) + } + fn decrypt_f32<'py>(&self, a: &Cipherblock, py: Python<'py>) -> &'py PyArrayDyn { + self.decrypt_array(a).into_pyarray(py) + } + fn decrypt_i64<'py>(&self, a: &Cipherblock, py: Python<'py>) -> &'py PyArrayDyn { + self.decrypt_array(a).into_pyarray(py) + } + fn decrypt_i32<'py>(&self, a: &Cipherblock, py: Python<'py>) -> &'py PyArrayDyn { + self.decrypt_array(a).into_pyarray(py) + } +} + +#[pymethods] +impl Cipherblock { + #[new] + fn __new__() -> Self { + Cipherblock(None) + } + pub fn __getstate__(&self, py: Python) -> PyResult { + Ok(PyBytes::new(py, &serialize(&self.0).unwrap()).to_object(py)) + } + pub fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> { + match state.extract::<&PyBytes>(py) { + Ok(s) => { + self.0 = deserialize(s.as_bytes()).unwrap(); + Ok(()) + } + Err(e) => Err(e), + } + } + // add + pub fn add_cipherblock(&self, other: &Cipherblock) -> Cipherblock { + self.add_cb(other) + } + pub fn add_plaintext_f64(&self, other: PyReadonlyArrayDyn) -> Cipherblock { + self.add_plaintext(other.as_array()) + } + pub fn add_plaintext_f32(&self, other: PyReadonlyArrayDyn) -> Cipherblock { + self.add_plaintext(other.as_array()) + } + pub fn add_plaintext_i64(&self, other: PyReadonlyArrayDyn) -> Cipherblock { + self.add_plaintext(other.as_array()) + } + pub fn add_plaintext_i32(&self, other: PyReadonlyArrayDyn) -> Cipherblock { + self.add_plaintext(other.as_array()) + } + pub fn add_plaintext_scalar_f64(&self, other: f64) -> Cipherblock { + self.add_plaintext_scalar(other) + } + pub fn add_plaintext_scalar_f32(&self, other: f32) -> Cipherblock { + self.add_plaintext_scalar(other) + } + pub fn add_plaintext_scalar_i64(&self, other: i64) -> Cipherblock { + self.add_plaintext_scalar(other) + } + pub fn add_plaintext_scalar_i32(&self, other: i32) -> Cipherblock { + self.add_plaintext_scalar(other) + } + + // sub + pub fn sub_cipherblock(&self, other: &Cipherblock) -> Cipherblock { + self.sub_cb(other) + } + pub fn sub_plaintext_f64(&self, other: PyReadonlyArrayDyn) -> Cipherblock { + self.sub_plaintext(other.as_array()) + } + pub fn sub_plaintext_f32(&self, other: PyReadonlyArrayDyn) -> Cipherblock { + self.sub_plaintext(other.as_array()) + } + pub fn sub_plaintext_i64(&self, other: PyReadonlyArrayDyn) -> Cipherblock { + self.sub_plaintext(other.as_array()) + } + pub fn sub_plaintext_i32(&self, other: PyReadonlyArrayDyn) -> Cipherblock { + self.sub_plaintext(other.as_array()) + } + pub fn sub_plaintext_scalar_f64(&self, other: f64) -> Cipherblock { + self.sub_plaintext_scalar(other) + } + pub fn sub_plaintext_scalar_f32(&self, other: f32) -> Cipherblock { + self.sub_plaintext_scalar(other) + } + pub fn sub_plaintext_scalar_i64(&self, other: i64) -> Cipherblock { + self.sub_plaintext_scalar(other) + } + pub fn sub_plaintext_scalar_i32(&self, other: i32) -> Cipherblock { + self.sub_plaintext_scalar(other) + } + + // mul + pub fn mul_plaintext_f64(&self, other: PyReadonlyArrayDyn) -> Cipherblock { + self.mul_plaintext(other.as_array()) + } + pub fn mul_plaintext_f32(&self, other: PyReadonlyArrayDyn) -> Cipherblock { + self.mul_plaintext(other.as_array()) + } + pub fn mul_plaintext_i64(&self, other: PyReadonlyArrayDyn) -> Cipherblock { + self.mul_plaintext(other.as_array()) + } + pub fn mul_plaintext_i32(&self, other: PyReadonlyArrayDyn) -> Cipherblock { + self.mul_plaintext(other.as_array()) + } + pub fn mul_plaintext_scalar_f64(&self, other: f64) -> Cipherblock { + self.mul_plaintext_scalar(other) + } + pub fn mul_plaintext_scalar_f32(&self, other: f32) -> Cipherblock { + self.mul_plaintext_scalar(other) + } + pub fn mul_plaintext_scalar_i64(&self, other: i64) -> Cipherblock { + self.mul_plaintext_scalar(other) + } + pub fn mul_plaintext_scalar_i32(&self, other: i32) -> Cipherblock { + self.mul_plaintext_scalar(other) + } + + // matmul + pub fn matmul_plaintext_ix2_f64(&self, other: PyReadonlyArray2) -> Cipherblock { + self.matmul_plaintext_ix2(other.as_array()) + } + pub fn matmul_plaintext_ix2_f32(&self, other: PyReadonlyArray2) -> Cipherblock { + self.matmul_plaintext_ix2(other.as_array()) + } + pub fn matmul_plaintext_ix2_i64(&self, other: PyReadonlyArray2) -> Cipherblock { + self.matmul_plaintext_ix2(other.as_array()) + } + pub fn matmul_plaintext_ix2_i32(&self, other: PyReadonlyArray2) -> Cipherblock { + self.matmul_plaintext_ix2(other.as_array()) + } + pub fn rmatmul_plaintext_ix2_f64(&self, other: PyReadonlyArray2) -> Cipherblock { + self.rmatmul_plaintext_ix2(other.as_array()) + } + pub fn rmatmul_plaintext_ix2_f32(&self, other: PyReadonlyArray2) -> Cipherblock { + self.rmatmul_plaintext_ix2(other.as_array()) + } + pub fn rmatmul_plaintext_ix2_i64(&self, other: PyReadonlyArray2) -> Cipherblock { + self.rmatmul_plaintext_ix2(other.as_array()) + } + pub fn rmatmul_plaintext_ix2_i32(&self, other: PyReadonlyArray2) -> Cipherblock { + self.rmatmul_plaintext_ix2(other.as_array()) + } + pub fn matmul_plaintext_ix1_f64(&self, other: PyReadonlyArray1) -> Cipherblock { + self.matmul_plaintext_ix1(other.as_array()) + } + pub fn matmul_plaintext_ix1_f32(&self, other: PyReadonlyArray1) -> Cipherblock { + self.matmul_plaintext_ix1(other.as_array()) + } + pub fn matmul_plaintext_ix1_i64(&self, other: PyReadonlyArray1) -> Cipherblock { + self.matmul_plaintext_ix1(other.as_array()) + } + pub fn matmul_plaintext_ix1_i32(&self, other: PyReadonlyArray1) -> Cipherblock { + self.matmul_plaintext_ix1(other.as_array()) + } + pub fn rmatmul_plaintext_ix1_f64(&self, other: PyReadonlyArray1) -> Cipherblock { + self.rmatmul_plaintext_ix1(other.as_array()) + } + pub fn rmatmul_plaintext_ix1_f32(&self, other: PyReadonlyArray1) -> Cipherblock { + self.rmatmul_plaintext_ix1(other.as_array()) + } + pub fn rmatmul_plaintext_ix1_i64(&self, other: PyReadonlyArray1) -> Cipherblock { + self.rmatmul_plaintext_ix1(other.as_array()) + } + pub fn rmatmul_plaintext_ix1_i32(&self, other: PyReadonlyArray1) -> Cipherblock { + self.rmatmul_plaintext_ix1(other.as_array()) + } + // agg + pub fn sum(&self) -> Cipherblock { + self.sum_cb() + } + pub fn mean(&self) -> Cipherblock { + self.sum_cb() + } +} + +pub(crate) fn register(py: Python, m: &PyModule) -> PyResult<()> { + let submodule_par = PyModule::new(py, "par")?; + submodule_par.add_function(wrap_pyfunction!(keygen, submodule_par)?)?; + submodule_par.add_function(wrap_pyfunction!(set_num_threads, submodule_par)?)?; + m.add_submodule(submodule_par)?; + py.import("sys")? + .getattr("modules")? + .set_item("fate_tensor.par", submodule_par)?; + Ok(()) +} diff --git a/rust/fate-tensor/tests/test_base.py b/rust/fate-tensor/tests/test_base.py index 31211e7fd3..7f3c36d6a1 100644 --- a/rust/fate-tensor/tests/test_base.py +++ b/rust/fate-tensor/tests/test_base.py @@ -1,25 +1,43 @@ +import importlib import operator -import fate_tensor +import cachetools import numpy as np import pytest -import cachetools -pk, sk = fate_tensor.keygen(1024) +def get_suites(): + suites = [] + packages = ["fate_tensor", "fate_tensor.par"] + for package in packages: + module = importlib.import_module(package) + suites.append(Suite(module.keygen)) + return suites -def encrypt(fp, par, data): - if par: - return getattr(pk, f"encrypt_{fp}_par")(data) - else: - return getattr(pk, f"encrypt_{fp}")(data) +class Suite: + def __init__(self, keygen) -> None: + self.pk, self.sk = keygen(1024) -def decrypt(fp, par, data): - if par: - return getattr(sk, f"decrypt_{fp}_par")(data) - else: - return getattr(sk, f"decrypt_{fp}")(data) + def encrypt(self, fp, data): + return getattr(self.pk, f"encrypt_{fp}")(data) + + def decrypt(self, fp, data): + return getattr(self.sk, f"decrypt_{fp}")(data) + + def cipher_op(self, ciphertext, op): + return getattr(ciphertext, f"{op.__name__}_cipherblock") + + def plaintest_op(self, ciphertext, op, fp, scalar=False): + if scalar: + return getattr(ciphertext, f"{op.__name__}_plaintext_scalar_{fp}") + else: + return getattr(ciphertext, f"{op.__name__}_plaintext_{fp}") + + +def pytest_generate_tests(metafunc): + if "suite" in metafunc.fixturenames: + metafunc.parametrize("suite", get_suites()) @cachetools.cached({}) @@ -30,7 +48,9 @@ def data(fp, index, shape=(3, 5), scalar=False) -> np.ndarray: if fp == "f32": return np.random.random(shape).astype(np.float32) - 0.5 if fp == "i64": - return np.random.randint(low=2147483648, high=2147483648000, size=shape, dtype=np.int64) + return np.random.randint( + low=2147483648, high=2147483648000, size=shape, dtype=np.int64 + ) if fp == "i32": return np.random.randint(low=-100, high=100, size=shape, dtype=np.int32) else: @@ -39,118 +59,88 @@ def data(fp, index, shape=(3, 5), scalar=False) -> np.ndarray: if fp == "f32": return np.random.random(1).astype(np.float32)[0] - 0.5 if fp == "i64": - return np.random.randint(low=2147483648, high=2147483648000, size=1, dtype=np.int64)[0] + return np.random.randint( + low=2147483648, high=2147483648000, size=1, dtype=np.int64 + )[0] if fp == "i32": return np.random.randint(low=-100, high=100, size=1, dtype=np.int32)[0] -def test_keygen(): - fate_tensor.keygen(1024) - - -def cipher_op(ciphertext, op, par): - if par: - return getattr(ciphertext, f"{op.__name__}_cipherblock_par") - else: - return getattr(ciphertext, f"{op.__name__}_cipherblock") - - -def plaintest_op(ciphertext, op, par, fp, scalar=False): - if par: - if scalar: - return getattr(ciphertext, f"{op.__name__}_plaintext_scalar_{fp}_par") - else: - return getattr(ciphertext, f"{op.__name__}_plaintext_{fp}_par") - else: - if scalar: - return getattr(ciphertext, f"{op.__name__}_plaintext_scalar_{fp}") - else: - return getattr(ciphertext, f"{op.__name__}_plaintext_{fp}") - - -@pytest.mark.parametrize("par", [False, True]) @pytest.mark.parametrize("fp", ["f64", "f32", "i32", "i64"]) -def test_cipher(par, fp): - e = decrypt(fp, par, encrypt(fp, par, data(fp, 0))) +def test_cipher(suite: Suite, fp): + e = suite.decrypt(fp, suite.encrypt(fp, data(fp, 0))) c = data(fp, 0) assert np.isclose(e, c).all() -@pytest.mark.parametrize("par", [False, True]) @pytest.mark.parametrize("fp", ["f64", "f32", "i64", "i32"]) @pytest.mark.parametrize("op", [operator.add, operator.sub]) -def test_cipher_op(par, fp, op): - ea = encrypt(fp, par, data(fp, 0)) - eb = encrypt(fp, par, data(fp, 1)) - result = cipher_op(ea, op, par)(eb) +def test_cipher_op(suite, fp, op): + ea = suite.encrypt(fp, data(fp, 0)) + eb = suite.encrypt(fp, data(fp, 1)) + result = suite.cipher_op(ea, op)(eb) expect = op(data(fp, 0), data(fp, 1)) - diff = decrypt(fp, par, result) - expect + diff = suite.decrypt(fp, result) - expect assert np.isclose(diff, 0).all() -@pytest.mark.parametrize("par", [False, True]) @pytest.mark.parametrize("fp", ["f64", "f32", "i64", "i32"]) @pytest.mark.parametrize("op", [operator.add, operator.sub, operator.mul]) -def test_plaintext_op(par, fp, op): - ea = encrypt(fp, par, data(fp, 0)) +def test_plaintext_op(suite: Suite, fp, op): + ea = suite.encrypt(fp, data(fp, 0)) b = data(fp, 1) - result = plaintest_op(ea, op, par, fp)(b) + result = suite.plaintest_op(ea, op, fp)(b) expect = op(data(fp, 0), data(fp, 1)) - diff = decrypt(fp, par, result) - expect + diff = suite.decrypt(fp, result) - expect assert np.isclose(diff, 0).all() -@pytest.mark.parametrize("par", [False, True]) @pytest.mark.parametrize("fp", ["f64", "f32", "i64", "i32"]) @pytest.mark.parametrize("op", [operator.add, operator.sub, operator.mul]) -def test_plaintext_op_scalar(par, fp, op): - ea = encrypt(fp, par, data(fp, 0)) +def test_plaintext_op_scalar(suite: Suite, fp, op): + ea = suite.encrypt(fp, data(fp, 0)) b = data(fp, 1, scalar=True) - result = plaintest_op(ea, op, par, fp, True)(b) + result = suite.plaintest_op(ea, op, fp, True)(b) expect = op(data(fp, 0), data(fp, 1, scalar=True)) - diff = decrypt(fp, par, result) - expect + diff = suite.decrypt(fp, result) - expect assert np.isclose(diff, 0).all() -@pytest.mark.parametrize("par", [False, True]) @pytest.mark.parametrize("fp", ["f64", "f32", "i64", "i32"]) -def test_matmul_ix2(par, fp): +def test_matmul_ix2(suite: Suite, fp): a = data(fp, 0, (11, 17)) b = data(fp, 0, (17, 5)) - ea = encrypt(fp, par, a) + ea = suite.encrypt(fp, a) eab = getattr(ea, f"matmul_plaintext_ix2_{fp}")(b) - ab = decrypt(fp, par, eab) + ab = suite.decrypt(fp, eab) assert np.isclose(ab, a @ b).all() -@pytest.mark.parametrize("par", [False, True]) @pytest.mark.parametrize("fp", ["f64", "f32", "i64", "i32"]) -def test_matmul_ix1(par, fp): +def test_matmul_ix1(suite: Suite, fp): a = data(fp, 0, (11, 17)) b = data(fp, 0, 17) - ea = encrypt(fp, par, a) + ea = suite.encrypt(fp, a) eab = getattr(ea, f"matmul_plaintext_ix1_{fp}")(b) - ab = decrypt(fp, par, eab) + ab = suite.decrypt(fp, eab) assert np.isclose(ab, (a @ b).reshape(ab.shape)).all() -@pytest.mark.parametrize("par", [False, True]) @pytest.mark.parametrize("fp", ["f64", "f32", "i64", "i32"]) -def test_rmatmul_ix2(par, fp): +def test_rmatmul_ix2(suite: Suite, fp): a = data(fp, 0, (11, 17)) b = data(fp, 0, (17, 5)) - eb = encrypt(fp, par, b) + eb = suite.encrypt(fp, b) reab = getattr(eb, f"rmatmul_plaintext_ix2_{fp}")(a) - rab = decrypt(fp, par, reab) + rab = suite.decrypt(fp, reab) assert np.isclose(rab, a @ b).all() -@pytest.mark.parametrize("par", [False, True]) @pytest.mark.parametrize("fp", ["f64", "f32", "i64", "i32"]) -def test_rmatmul_ix1(par, fp): +def test_rmatmul_ix1(suite: Suite, fp): a = data(fp, 0, 17) b = data(fp, 0, (17, 5)) - eb = encrypt(fp, par, b) + eb = suite.encrypt(fp, b) reab = getattr(eb, f"rmatmul_plaintext_ix1_{fp}")(a) - rab = decrypt(fp, par, reab) + rab = suite.decrypt(fp, reab) assert np.isclose(rab, (a @ b).reshape(rab.shape)).all() From 9d4a908b642bdffd7d1dc44e97dbccaf48e759dd Mon Sep 17 00:00:00 2001 From: weiwee Date: Thu, 21 Jul 2022 03:52:17 -0800 Subject: [PATCH 02/11] refact: use metaclass simplify block/tensor impl Signed-off-by: weiwee --- .../tensor/impl/blocks/_metaclass.py | 363 ++++++++++++++++++ .../tensor/impl/blocks/cpu_paillier_block.py | 50 +++ .../blocks/multithread_cpu_paillier_block.py | 50 +++ .../blocks/rust_paillier_block/__init__.py | 230 ----------- .../tensor/impl/tensor/_metaclass.py | 156 ++++++++ .../tensor/impl/tensor/distributed.py | 4 +- .../tensor/impl/tensor/multithread.py | 114 ------ .../impl/tensor/multithread_cpu_tensor.py | 38 ++ 8 files changed, 659 insertions(+), 346 deletions(-) create mode 100644 python/fate_arch/tensor/impl/blocks/_metaclass.py create mode 100644 python/fate_arch/tensor/impl/blocks/cpu_paillier_block.py create mode 100644 python/fate_arch/tensor/impl/blocks/multithread_cpu_paillier_block.py delete mode 100644 python/fate_arch/tensor/impl/blocks/rust_paillier_block/__init__.py create mode 100644 python/fate_arch/tensor/impl/tensor/_metaclass.py delete mode 100644 python/fate_arch/tensor/impl/tensor/multithread.py create mode 100644 python/fate_arch/tensor/impl/tensor/multithread_cpu_tensor.py diff --git a/python/fate_arch/tensor/impl/blocks/_metaclass.py b/python/fate_arch/tensor/impl/blocks/_metaclass.py new file mode 100644 index 0000000000..9797a26865 --- /dev/null +++ b/python/fate_arch/tensor/impl/blocks/_metaclass.py @@ -0,0 +1,363 @@ +import pickle + +import numpy as np +import torch + +from ...abc.block import ( + PHEBlockABC, + PHEBlockCipherABC, + PHEBlockDecryptorABC, + PHEBlockEncryptorABC, +) + + +def _impl_ops(class_obj, method_name, ops): + def func(self, other): + cb = ops(self._cb, other, class_obj) + if cb is NotImplemented: + return NotImplemented + else: + return class_obj(cb) + + func.__name__ = method_name + return func + + +def _impl_init(): + def __init__(self, cb): + self._cb = cb + + return __init__ + + +def _impl_encryptor_init(): + def __init__(self, pk): + self._pk = pk + + return __init__ + + +def _impl_decryptor_init(): + def __init__(self, sk): + self._sk = sk + + return __init__ + + +def _impl_encrypt(pheblock_cls, fpbloke_cls, encrypt_op): + def encrypt(self, other) -> pheblock_cls: + if isinstance(other, fpbloke_cls): + return pheblock_cls(encrypt_op(self._pk, other.numpy())) + + raise NotImplementedError(f"type {other} not supported") + + return encrypt + + +def _impl_decrypt(pheblock_cls, fpbloke_cls, decrypt_op): + def decrypt(self, other, dtype=np.float64) -> fpbloke_cls: + if isinstance(other, pheblock_cls): + return torch.from_numpy(decrypt_op(self._sk, other._cb, dtype)) + raise NotImplementedError(f"type {other} not supported") + + return decrypt + + +def _impl_serialize(): + def serialize(self) -> bytes: + return pickle.dumps(self._cb) + + return serialize + + +def _impl_keygen(encrypt_cls, decrypt_cls, keygen_op): + @classmethod + def keygen(cls, key_length=1024): + pk, sk = keygen_op(bit_size=key_length) + return (encrypt_cls(pk), decrypt_cls(sk)) + + return keygen + + +def _maybe_setattr(obj, name, value): + if not hasattr(obj, name): + setattr(obj, name, value) + + +def phe_keygen_metaclass(encrypt_cls, decrypt_cls, keygen_op): + class PHEKeygenMetaclass(type): + def __new__(cls, name, bases, dict): + keygen_cls = super().__new__(cls, name, bases, dict) + + setattr( + keygen_cls, "keygen", _impl_keygen(encrypt_cls, decrypt_cls, keygen_op) + ) + return keygen_cls + + return PHEKeygenMetaclass + + +def phe_decryptor_metaclass(pheblock_cls, fpblock_cls): + class PHEDecryptorMetaclass(type): + def __new__(cls, name, bases, dict): + decryptor_cls = super().__new__( + cls, name, bases, dict + ) + + setattr(decryptor_cls, "__init__", _impl_decryptor_init()) + setattr( + decryptor_cls, + "decrypt", + _impl_decrypt( + pheblock_cls, fpblock_cls, PHEDecryptorMetaclass._decrypt_numpy + ), + ) + return decryptor_cls + + @staticmethod + def _decrypt_numpy(sk, cb, dtype): + if dtype == np.float64: + return sk.decrypt_f64(cb) + if dtype == np.float32: + return sk.decrypt_f32(cb) + if dtype == np.int64: + return sk.decrypt_i64(cb) + if dtype == np.int32: + return sk.decrypt_i32(cb) + raise NotImplementedError("dtype = {dtype}") + + return PHEDecryptorMetaclass + + +def phe_encryptor_metaclass(pheblock_cls, fpblock_cls): + class PHEEncryptorMetaclass(type): + def __new__(cls, name, bases, dict): + encryptor_cls = super().__new__( + cls, name, bases, dict + ) + + setattr(encryptor_cls, "__init__", _impl_encryptor_init()) + setattr( + encryptor_cls, + "encrypt", + _impl_encrypt( + pheblock_cls, fpblock_cls, PHEEncryptorMetaclass._encrypt_numpy + ), + ) + return encryptor_cls + + @staticmethod + def _encrypt_numpy(pk, other): + if is_ndarray(other): + if is_nd_float64(other): + return pk.encrypt_f64(other) + if is_nd_float32(other): + return pk.encrypt_f32(other) + if is_nd_int64(other): + return pk.encrypt_i64(other) + if is_nd_int32(other): + return pk.encrypt_i32(other) + raise NotImplementedError(f"type {other} {other.dtype} not supported") + + return PHEEncryptorMetaclass + + +class PHEBlockMetaclass(type): + def __new__(cls, name, bases, dict): + class_obj = super().__new__(cls, name, bases, dict) + + setattr(class_obj, "__init__", _impl_init()) + _maybe_setattr(class_obj, "serialize", _impl_serialize()) + for impl_name, ops in { + "__add__": PHEBlockMetaclass._add, + "__radd__": PHEBlockMetaclass._radd, + "__sub__": PHEBlockMetaclass._sub, + "__rsub__": PHEBlockMetaclass._rsub, + "__mul__": PHEBlockMetaclass._mul, + "__rmul__": PHEBlockMetaclass._rmul, + "__matmul__": PHEBlockMetaclass._matmul, + "__rmatmul__": PHEBlockMetaclass._rmatmul, + }.items(): + _maybe_setattr(class_obj, impl_name, _impl_ops(class_obj, impl_name, ops)) + + return class_obj + + @staticmethod + def _rmatmul(cb, other, class_obj): + if isinstance(other, torch.Tensor): + other = other.numpy() + if isinstance(other, np.ndarray): + if len(other.shape) == 2: + if is_nd_float64(other): + return cb.rmatmul_plaintext_ix2_f64(other) + if is_nd_float32(other): + return cb.rmatmul_plaintext_ix2_f32(other) + if is_nd_int64(other): + return cb.rmatmul_plaintext_ix2_i64(other) + if is_nd_int32(other): + return cb.rmatmul_plaintext_ix2_i32(other) + if len(other.shape) == 1: + if is_nd_float64(other): + return cb.rmatmul_plaintext_ix1_f64(other) + if is_nd_float32(other): + return cb.rmatmul_plaintext_ix1_f32(other) + if is_nd_int64(other): + return cb.rmatmul_plaintext_ix1_i64(other) + if is_nd_int32(other): + return cb.rmatmul_plaintext_ix1_i32(other) + return NotImplemented + + @staticmethod + def _matmul(cb, other, class_obj): + if isinstance(other, torch.Tensor): + other = other.numpy() + if is_ndarray(other): + if len(other.shape) == 2: + if is_nd_float64(other): + return cb.matmul_plaintext_ix2_f64(other) + if is_nd_float32(other): + return cb.matmul_plaintext_ix2_f32(other) + if is_nd_int64(other): + return cb.matmul_plaintext_ix2_i64(other) + if is_nd_int32(other): + return cb.matmul_plaintext_ix2_i32(other) + if len(other.shape) == 1: + if is_nd_float64(other): + return cb.matmul_plaintext_ix1_f64(other) + if is_nd_float32(other): + return cb.matmul_plaintext_ix1_f32(other) + if is_nd_int64(other): + return cb.matmul_plaintext_ix1_i64(other) + if is_nd_int32(other): + return cb.matmul_plaintext_ix1_i32(other) + return NotImplemented + + @staticmethod + def _mul(cb, other, class_obj): + if isinstance(other, torch.Tensor): + other = other.numpy() + if is_ndarray(other): + if is_nd_float64(other): + return cb.mul_plaintext_f64(other) + if is_nd_float32(other): + return cb.mul_plaintext_f32(other) + if is_nd_int64(other): + return cb.mul_plaintext_i64(other) + if is_nd_int32(other): + return cb.mul_plaintext_i32(other) + raise NotImplemented + if is_float(other): + return cb.mul_plaintext_scalar_f64(other) + if is_float32(other): + return cb.mul_plaintext_scalar_f32(other) + if is_int(other): + return cb.mul_plaintext_scalar_i64(other) + if is_int32(other): + return cb.mul_plaintext_scalar_i32(other) + return NotImplemented + + @staticmethod + def _sub(cb, other, class_obj): + if isinstance(other, torch.Tensor): + other = other.numpy() + if is_ndarray(other): + if is_nd_float64(other): + return cb.sub_plaintext_f64(other) + if is_nd_float32(other): + return cb.sub_plaintext_f32(other) + if is_nd_int64(other): + return cb.sub_plaintext_i64(other) + if is_nd_int32(other): + return cb.sub_plaintext_i32(other) + return NotImplemented + + if isinstance(other, class_obj): + return cb.sub_cipherblock(other._cb) + if is_float(other): + return cb.sub_plaintext_scalar_f64(other) + if is_float32(other): + return cb.sub_plaintext_scalar_f32(other) + if is_int(other): + return cb.sub_plaintext_scalar_i64(other) + if is_int32(other): + return cb.sub_plaintext_scalar_i32(other) + + return NotImplemented + + @staticmethod + def _add(cb, other, class_obj): + if isinstance(other, torch.Tensor): + other = other.numpy() + if is_ndarray(other): + if is_nd_float64(other): + return cb.add_plaintext_f64(other) + if is_nd_float32(other): + return cb.add_plaintext_f32(other) + if is_nd_int64(other): + return cb.add_plaintext_i64(other) + if is_nd_int32(other): + return cb.add_plaintext_i32(other) + return NotImplemented + + if isinstance(other, class_obj): + return cb.add_cipherblock(other._cb) + if is_float(other): + return cb.add_plaintext_scalar_f64(other) + if is_float32(other): + return cb.add_plaintext_scalar_f32(other) + if is_int(other): + return cb.add_plaintext_scalar_i64(other) + if is_int32(other): + return cb.add_plaintext_scalar_i32(other) + + return NotImplemented + + @staticmethod + def _radd(cb, other, class_obj): + return PHEBlockMetaclass._add(cb, other, class_obj) + + @staticmethod + def _rsub(cb, other, class_obj): + return PHEBlockMetaclass._add( + PHEBlockMetaclass._mul(cb, -1, class_obj), other, class_obj + ) + + @staticmethod + def _rmul(cb, other, class_obj): + return PHEBlockMetaclass._mul(cb, other, class_obj) + + +def is_ndarray(v): + return isinstance(v, np.ndarray) + + +def is_float(v): + return isinstance(v, (float, np.float64)) + + +def is_float32(v): + return isinstance(v, np.float32) + + +def is_int(v): + return isinstance(v, (int, np.int64)) + + +def is_int32(v): + return isinstance(v, np.int32) + + +def is_nd_float64(v): + return v.dtype == np.float64 + + +def is_nd_float32(v): + return v.dtype == np.float32 + + +def is_nd_int64(v): + return v.dtype == np.int64 + + +def is_nd_int32(v): + return v.dtype == np.int32 diff --git a/python/fate_arch/tensor/impl/blocks/cpu_paillier_block.py b/python/fate_arch/tensor/impl/blocks/cpu_paillier_block.py new file mode 100644 index 0000000000..bddd8ed727 --- /dev/null +++ b/python/fate_arch/tensor/impl/blocks/cpu_paillier_block.py @@ -0,0 +1,50 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +import fate_tensor.par +import torch + +from ._metaclass import ( + PHEBlockMetaclass, + phe_decryptor_metaclass, + phe_encryptor_metaclass, + phe_keygen_metaclass, +) + + +class PaillierBlock(metaclass=PHEBlockMetaclass): + pass + + +class BlockPaillierEncryptor( + metaclass=phe_encryptor_metaclass(PaillierBlock, torch.Tensor) +): + pass + + +class BlockPaillierDecryptor( + metaclass=phe_decryptor_metaclass(PaillierBlock, torch.Tensor) +): + pass + + +class BlockPaillierCipher( + metaclass=phe_keygen_metaclass( + BlockPaillierEncryptor, BlockPaillierDecryptor, fate_tensor.par.keygen + ) +): + pass diff --git a/python/fate_arch/tensor/impl/blocks/multithread_cpu_paillier_block.py b/python/fate_arch/tensor/impl/blocks/multithread_cpu_paillier_block.py new file mode 100644 index 0000000000..8864750a59 --- /dev/null +++ b/python/fate_arch/tensor/impl/blocks/multithread_cpu_paillier_block.py @@ -0,0 +1,50 @@ +# +# Copyright 2019 The FATE Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +import fate_tensor +import torch + +from ._metaclass import ( + PHEBlockMetaclass, + phe_decryptor_metaclass, + phe_encryptor_metaclass, + phe_keygen_metaclass, +) + + +class PaillierBlock(metaclass=PHEBlockMetaclass): + pass + + +class BlockPaillierEncryptor( + metaclass=phe_encryptor_metaclass(PaillierBlock, torch.Tensor) +): + pass + + +class BlockPaillierDecryptor( + metaclass=phe_decryptor_metaclass(PaillierBlock, torch.Tensor) +): + pass + + +class BlockPaillierCipher( + metaclass=phe_keygen_metaclass( + BlockPaillierEncryptor, BlockPaillierDecryptor, fate_tensor.keygen + ) +): + pass diff --git a/python/fate_arch/tensor/impl/blocks/rust_paillier_block/__init__.py b/python/fate_arch/tensor/impl/blocks/rust_paillier_block/__init__.py deleted file mode 100644 index 05aa49300e..0000000000 --- a/python/fate_arch/tensor/impl/blocks/rust_paillier_block/__init__.py +++ /dev/null @@ -1,230 +0,0 @@ -# -# Copyright 2019 The FATE Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import pickle -import typing - -import numpy as np -import torch - -from ....abc.block import ( - PHEBlockABC, - PHEBlockCipherABC, - PHEBlockDecryptorABC, - PHEBlockEncryptorABC, -) - -# maybe need wrap? -FPBlock = torch.Tensor - -# TODO: move numpy related apis to rust side - - -class PaillierBlock(PHEBlockABC): - def __init__(self, cb) -> None: - self._cb = cb - - def create(self, cb): - return PaillierBlock(cb) - - def __add__(self, other) -> "PaillierBlock": - if isinstance(other, torch.Tensor): - other = other.numpy() - if isinstance(other, np.ndarray): - if other.dtype == np.float64: - return self.create(self._cb.add_plaintext_f64(other)) - if other.dtype == np.float32: - return self.create(self._cb.add_plaintext_f32(other)) - if other.dtype == np.int64: - return self.create(self._cb.add_plaintext_i64(other)) - if other.dtype == np.int64: - return self.create(self._cb.add_plaintext_i32(other)) - raise NotImplemented(f"dtype {other.dtype} not supported") - if isinstance(other, PaillierBlock): - return self.create(self._cb.add_cipherblock(other._cb)) - if isinstance(other, (float, np.float64)): - return self.create(self._cb.add_plaintext_scalar_f64(other)) - if isinstance(other, np.float32): - return self.create(self._cb.add_plaintext_scalar_f32(other)) - if isinstance(other, (int, np.int64)): - return self.create(self._cb.add_plaintext_scalar_i64(other)) - if isinstance(other, np.int32): - return self.create(self._cb.add_plaintext_scalar_i32(other)) - raise NotImplemented(f"type {other} not supported") - - def __radd__(self, other) -> "PaillierBlock": - return self.__add__(other) - - def __sub__(self, other) -> "PaillierBlock": - if isinstance(other, torch.Tensor): - other = other.numpy() - if isinstance(other, np.ndarray): - if other.dtype == np.float64: - return self.create(self._cb.sub_plaintext_f64(other)) - if other.dtype == np.float32: - return self.create(self._cb.sub_plaintext_f32(other)) - if other.dtype == np.int64: - return self.create(self._cb.sub_plaintext_i64(other)) - if other.dtype == np.int32: - return self.create(self._cb.sub_plaintext_i32(other)) - raise NotImplemented(f"dtype {other.dtype} not supported") - if isinstance(other, PaillierBlock): - return self.create(self._cb.sub_cipherblock(other._cb)) - if isinstance(other, (float, np.float64)): - return self.create(self._cb.sub_plaintext_scalar_f64(other)) - if isinstance(other, np.float32): - return self.create(self._cb.sub_plaintext_scalar_f32(other)) - if isinstance(other, (int, np.int64)): - return self.create(self._cb.sub_plaintext_scalar_i64(other)) - if isinstance(other, np.int32): - return self.create(self._cb.sub_plaintext_scalar_i32(other)) - raise NotImplemented(f"type {other} not supported") - - def __rsub__(self, other) -> "PaillierBlock": - return self.__mul__(-1).__add__(other) - - def __mul__(self, other) -> "PaillierBlock": - if isinstance(other, torch.Tensor): - other = other.numpy() - if isinstance(other, np.ndarray): - if other.dtype == np.float64: - return self.create(self._cb.mul_plaintext_f64(other)) - if other.dtype == np.float32: - return self.create(self._cb.mul_plaintext_f32(other)) - if other.dtype == np.int64: - return self.create(self._cb.mul_plaintext_i64(other)) - if other.dtype == np.int32: - return self.create(self._cb.mul_plaintext_i32(other)) - raise NotImplemented(f"dtype {other.dtype} not supported") - if isinstance(other, (float, np.float64)): - return self.create(self._cb.mul_plaintext_scalar_f64(other)) - if isinstance(other, np.float32): - return self.create(self._cb.mul_plaintext_scalar_f32(other)) - if isinstance(other, (int, np.int64)): - return self.create(self._cb.mul_plaintext_scalar_i64(other)) - if isinstance(other, np.int32): - return self.create(self._cb.mul_plaintext_scalar_i32(other)) - raise NotImplemented(f"type {other} not supported") - - def __rmul__(self, other) -> "PaillierBlock": - return self.__mul__(other) - - def __matmul__(self, other) -> "PaillierBlock": - if isinstance(other, torch.Tensor): - other = other.numpy() - if isinstance(other, np.ndarray): - if len(other.shape) == 2: - if other.dtype == np.float64: - return self.create(self._cb.matmul_plaintext_ix2_f64(other)) - if other.dtype == np.float32: - return self.create(self._cb.matmul_plaintext_ix2_f32(other)) - if other.dtype == np.int64: - return self.create(self._cb.matmul_plaintext_ix2_i64(other)) - if other.dtype == np.int32: - return self.create(self._cb.matmul_plaintext_ix2_i32(other)) - if len(other.shape) == 1: - if other.dtype == np.float64: - return self.create(self._cb.matmul_plaintext_ix1_f64(other)) - if other.dtype == np.float32: - return self.create(self._cb.matmul_plaintext_ix1_f32(other)) - if other.dtype == np.int64: - return self.create(self._cb.matmul_plaintext_ix1_i64(other)) - if other.dtype == np.int32: - return self.create(self._cb.matmul_plaintext_ix1_i32(other)) - return NotImplemented - - def __rmatmul__(self, other) -> "PaillierBlock": - if isinstance(other, torch.Tensor): - other = other.numpy() - if isinstance(other, np.ndarray): - if len(other.shape) == 2: - if other.dtype == np.float64: - return self.create(self._cb.rmatmul_plaintext_ix2_f64(other)) - if other.dtype == np.float32: - return self.create(self._cb.rmatmul_plaintext_ix2_f32(other)) - if other.dtype == np.int64: - return self.create(self._cb.rmatmul_plaintext_ix2_i64(other)) - if other.dtype == np.int32: - return self.create(self._cb.rmatmul_plaintext_ix2_i32(other)) - if len(other.shape) == 1: - if other.dtype == np.float64: - return self.create(self._cb.rmatmul_plaintext_ix1_f64(other)) - if other.dtype == np.float32: - return self.create(self._cb.rmatmul_plaintext_ix1_f32(other)) - if other.dtype == np.int64: - return self.create(self._cb.rmatmul_plaintext_ix1_i64(other)) - if other.dtype == np.int32: - return self.create(self._cb.rmatmul_plaintext_ix1_i32(other)) - return NotImplemented - - def serialize(self) -> bytes: - return pickle.dumps(self._cb) - - -class BlockPaillierEncryptor(PHEBlockEncryptorABC): - def __init__(self, pk) -> None: - self._pk = pk - - def encrypt(self, other) -> PaillierBlock: - if isinstance(other, FPBlock): - return PaillierBlock(self._encrypt_numpy(other.numpy())) - - raise NotImplementedError(f"type {other} not supported") - - def _encrypt_numpy(self, other): - if isinstance(other, np.ndarray): - if other.dtype == np.float64: - return self._pk.encrypt_f64(other) - if other.dtype == np.float32: - return self._pk.encrypt_f32(other) - if other.dtype == np.int64: - return self._pk.encrypt_i64(other) - if other.dtype == np.int32: - return self._pk.encrypt_i32(other) - raise NotImplementedError(f"type {other} {other.dtype} not supported") - - -class BlockPaillierDecryptor(PHEBlockDecryptorABC): - def __init__(self, sk) -> None: - self._sk = sk - - def decrypt(self, other: PaillierBlock, dtype=np.float64): - return torch.from_numpy(self._decrypt_numpy(other._cb, dtype)) - - def _decrypt_numpy(self, cb, dtype=np.float64): - if dtype == np.float64: - return self._sk.decrypt_f64(cb) - if dtype == np.float32: - return self._sk.decrypt_f32(cb) - if dtype == np.int64: - return self._sk.decrypt_i64(cb) - if dtype == np.int32: - return self._sk.decrypt_i32(cb) - raise NotImplementedError("dtype = {dtype}") - - -class BlockPaillierCipher(PHEBlockCipherABC): - @classmethod - def keygen( - cls, key_length=1024 - ) -> typing.Tuple[BlockPaillierEncryptor, BlockPaillierDecryptor]: - import fate_tensor - - pubkey, prikey = fate_tensor.keygen(bit_size=key_length) - return ( - BlockPaillierEncryptor(pubkey), - BlockPaillierDecryptor(prikey), - ) diff --git a/python/fate_arch/tensor/impl/tensor/_metaclass.py b/python/fate_arch/tensor/impl/tensor/_metaclass.py new file mode 100644 index 0000000000..f27166efab --- /dev/null +++ b/python/fate_arch/tensor/impl/tensor/_metaclass.py @@ -0,0 +1,156 @@ +import typing + +from ...abc.tensor import ( + PHECipherABC, + PHEDecryptorABC, + PHEEncryptorABC, + PHETensorABC, +) + + +def phe_tensor_metaclass(fp_cls): + class PHETensorMetaclass(type): + def __new__(cls, name, bases, dict): + phe_cls = super().__new__(cls, name, bases, dict) + + def __init__(self, block) -> None: + self._block = block + self._is_transpose = False + + setattr(phe_cls, "__init__", __init__) + + @property + def T(self) -> phe_cls: + transposed = phe_cls(self._block) + transposed._is_transpose = not self._is_transpose + return transposed + + setattr(phe_cls, "T", T) + + def serialize(self) -> bytes: + # todo: impl me + ... + + setattr(phe_cls, "serialize", serialize) + + def __add__(self, other): + if isinstance(other, phe_cls): + other = other._block + + if isinstance(other, (phe_cls, fp_cls)): + return phe_cls(self._block + other) + elif isinstance(other, (int, float)): + return phe_cls(self._block + other) + else: + return NotImplemented + + def __radd__(self, other): + return __add__(other, self) + + setattr(phe_cls, "__add__", __add__) + setattr(phe_cls, "__radd__", __radd__) + + def __sub__(self, other): + if isinstance(other, phe_cls): + other = other._block + + if isinstance(other, (phe_cls, fp_cls)): + return phe_cls(self._block - other) + elif isinstance(other, (int, float)): + return phe_cls(self._block - other) + else: + return NotImplemented + + def __rsub__(self, other): + return __sub__(other, self) + + setattr(phe_cls, "__sub__", __sub__) + setattr(phe_cls, "__rsub__", __rsub__) + + def __mul__(self, other): + if isinstance(other, fp_cls): + return phe_cls(self._block * other) + elif isinstance(other, (int, float)): + return phe_cls(self._block * other) + else: + return NotImplemented + + def __rmul__(self, other): + return __mul__(other, self) + + setattr(phe_cls, "__mul__", __mul__) + setattr(phe_cls, "__rmul__", __rmul__) + + def __matmul__(self, other): + if isinstance(other, fp_cls): + return phe_cls(self._block @ other) + return NotImplemented + + def __rmatmul__(self, other): + if isinstance(other, fp_cls): + return phe_cls(other @ self._block) + return NotImplemented + + setattr(phe_cls, "__matmul__", __matmul__) + setattr(phe_cls, "__rmatmul__", __rmatmul__) + + return phe_cls + + return PHETensorMetaclass + + +def phe_tensor_encryptor_metaclass(phe_cls, fp_cls): + class PHETensorEncryptorMetaclass(type): + def __new__(cls, name, bases, dict): + phe_encrypt_cls = super().__new__(cls, name, bases, dict) + + def __init__(self, block_encryptor): + self._block_encryptor = block_encryptor + + def encrypt(self, tensor: fp_cls) -> phe_cls: + return phe_cls(self._block_encryptor.encrypt(tensor)) + + setattr(phe_encrypt_cls, "__init__", __init__) + setattr(phe_encrypt_cls, "encrypt", encrypt) + return phe_encrypt_cls + + return PHETensorEncryptorMetaclass + + +def phe_tensor_decryptor_metaclass(phe_cls, fp_cls): + class PHETensorDecryptorMetaclass(type): + def __new__(cls, name, bases, dict): + phe_decrypt_cls = super().__new__(cls, name, bases, dict) + + def __init__(self, block_decryptor) -> None: + self._block_decryptor = block_decryptor + + def decrypt(self, tensor: phe_cls) -> fp_cls: + return self._block_decryptor.decrypt(tensor._block) + + setattr(phe_decrypt_cls, "__init__", __init__) + setattr(phe_decrypt_cls, "decrypt", decrypt) + return phe_decrypt_cls + + return PHETensorDecryptorMetaclass + + +def phe_tensor_cipher_metaclass( + phe_cls, phe_encrypt_cls, phe_decrypt_cls, block_cipher, +): + class PHETensorCipherMetaclass(type): + def __new__(cls, name, bases, dict): + phe_cipher_cls = super().__new__(cls, name, bases, dict) + + @classmethod + def keygen(cls, **kwargs) -> typing.Tuple[phe_encrypt_cls, phe_decrypt_cls]: + block_encrytor, block_decryptor = block_cipher.keygen(**kwargs) + return ( + phe_encrypt_cls(block_encrytor), + phe_decrypt_cls(block_decryptor), + ) + + setattr(phe_cipher_cls, "keygen", keygen) + return phe_cipher_cls + + return PHETensorCipherMetaclass diff --git a/python/fate_arch/tensor/impl/tensor/distributed.py b/python/fate_arch/tensor/impl/tensor/distributed.py index a7ff528cd6..2326e8f862 100644 --- a/python/fate_arch/tensor/impl/tensor/distributed.py +++ b/python/fate_arch/tensor/impl/tensor/distributed.py @@ -2,7 +2,7 @@ from typing import Union from ...abc.tensor import ( - FPTensorABC, + FPTensorProtocol, PHECipherABC, PHEDecryptorABC, PHEEncryptorABC, @@ -12,7 +12,7 @@ Numeric = typing.Union[int, float] -class FPTensorDistributed(FPTensorABC): +class FPTensorDistributed(FPTensorProtocol): """ Demo of Distributed Fixed Presicion Tensor """ diff --git a/python/fate_arch/tensor/impl/tensor/multithread.py b/python/fate_arch/tensor/impl/tensor/multithread.py deleted file mode 100644 index e932ed033a..0000000000 --- a/python/fate_arch/tensor/impl/tensor/multithread.py +++ /dev/null @@ -1,114 +0,0 @@ -import typing -from typing import Union -import operator -import torch - -from ...abc.tensor import ( - FPTensorProtocol, - PHECipherABC, - PHEDecryptorABC, - PHEEncryptorABC, - PHETensorABC, -) - -FPTensorLocal = torch.Tensor -Numeric = typing.Union[int, float] -TYPEFP = typing.Union[Numeric, "FPTensorLocal"] -TYPECT = typing.Union[TYPEFP, "PHETensorLocal"] - - -class PHETensorLocal(PHETensorABC): - def __init__(self, block) -> None: - """ """ - self._block = block - self._is_transpose = False - - def __add__(self, other: TYPECT) -> "PHETensorLocal": - if isinstance(other, PHETensorLocal): - other = other._block - return _phe_binary_op(self._block, other, operator.add, PHE_OP_TYPES) - - def __radd__(self, other: TYPECT) -> "PHETensorLocal": - if isinstance(other, PHETensorLocal): - other = other._block - return _phe_binary_op(other, self._block, operator.add, PHE_OP_TYPES) - - def __sub__(self, other: TYPECT) -> "PHETensorLocal": - if isinstance(other, PHETensorLocal): - other = other._block - return _phe_binary_op(self._block, other, operator.sub, PHE_OP_TYPES) - - def __rsub__(self, other: TYPECT) -> "PHETensorLocal": - if isinstance(other, PHETensorLocal): - other = other._block - return _phe_binary_op(other, self._block, operator.sub, PHE_OP_TYPES) - - def __mul__(self, other: TYPEFP) -> "PHETensorLocal": - return _phe_binary_op(self._block, other, operator.mul, PHE_OP_PLAIN_TYPES) - - def __rmul__(self, other: TYPEFP) -> "PHETensorLocal": - return _phe_binary_op(other, self._block, operator.mul, PHE_OP_PLAIN_TYPES) - - def __matmul__(self, other: FPTensorLocal) -> "PHETensorLocal": - if isinstance(other, FPTensorLocal): - return PHETensorLocal(operator.matmul(self._block, other)) - return NotImplemented - - def __rmatmul__(self, other: FPTensorLocal) -> "PHETensorLocal": - if isinstance(other, FPTensorLocal): - return PHETensorLocal(operator.matmul(other, self._block)) - return NotImplemented - - def T(self) -> "PHETensorLocal": - transposed = PHETensorLocal(self._block) - transposed._is_transpose = not self._is_transpose - return transposed - - def serialize(self): - # todo: impl me - ... - - -class PaillierPHEEncryptorLocal(PHEEncryptorABC): - def __init__(self, block_encryptor) -> None: - self._block_encryptor = block_encryptor - - def encrypt(self, tensor: FPTensorLocal) -> PHETensorLocal: - return PHETensorLocal(self._block_encryptor.encrypt(tensor)) - - -class PaillierPHEDecryptorLocal(PHEDecryptorABC): - def __init__(self, block_decryptor) -> None: - self._block_decryptor = block_decryptor - - def decrypt(self, tensor: PHETensorLocal) -> FPTensorLocal: - return self._block_decryptor.decrypt(tensor._block) - - -class PaillierPHECipherLocal(PHECipherABC): - @classmethod - def keygen( - cls, **kwargs - ) -> typing.Tuple[PaillierPHEEncryptorLocal, PaillierPHEDecryptorLocal]: - from ..blocks.rust_paillier_block import BlockPaillierCipher - - block_encrytor, block_decryptor = BlockPaillierCipher.keygen(**kwargs) - return ( - PaillierPHEEncryptorLocal(block_encrytor), - PaillierPHEDecryptorLocal(block_decryptor), - ) - -def _phe_binary_op(self, other, func, types): - if type(other) not in types: - return NotImplemented - elif isinstance(other, (PHETensorLocal, FPTensorLocal)): - return PHETensorLocal(func(self, other)) - elif isinstance(other, (int, float)): - return PHETensorLocal(func(self, other)) - else: - return NotImplemented - - -PHE_OP_PLAIN_TYPES = {int, float, FPTensorLocal, PHETensorLocal} -PHE_OP_TYPES = {int, float, FPTensorLocal, PHETensorLocal} -FP_OP_TYPES = {int, float, FPTensorLocal} diff --git a/python/fate_arch/tensor/impl/tensor/multithread_cpu_tensor.py b/python/fate_arch/tensor/impl/tensor/multithread_cpu_tensor.py new file mode 100644 index 0000000000..ffe1fbf26c --- /dev/null +++ b/python/fate_arch/tensor/impl/tensor/multithread_cpu_tensor.py @@ -0,0 +1,38 @@ +from ._metaclass import ( + phe_tensor_metaclass, + phe_tensor_encryptor_metaclass, + phe_tensor_decryptor_metaclass, + phe_tensor_cipher_metaclass, +) + +import torch +from ..blocks.multithread_cpu_paillier_block import BlockPaillierCipher + +FPTensorLocal = torch.Tensor + + +class PHETensorLocal(metaclass=phe_tensor_metaclass(FPTensorLocal)): + ... + + +class PaillierPHEEncryptorLocal( + metaclass=phe_tensor_encryptor_metaclass(PHETensorLocal, FPTensorLocal) +): + ... + + +class PaillierPHEDecryptorLocal( + metaclass=phe_tensor_decryptor_metaclass(PHETensorLocal, FPTensorLocal) +): + ... + + +class PaillierPHECipherLocal( + metaclass=phe_tensor_cipher_metaclass( + PHETensorLocal, + PaillierPHEEncryptorLocal, + PaillierPHEDecryptorLocal, + BlockPaillierCipher, + ) +): + ... From e6e8f8bf51dae96744892e4f93afbddafd017dbe Mon Sep 17 00:00:00 2001 From: weiwee Date: Thu, 21 Jul 2022 03:57:04 -0800 Subject: [PATCH 03/11] feat: add context Signed-off-by: weiwee --- python/fate_arch/tensor/__init__.py | 15 +- python/fate_arch/tensor/_context.py | 264 --------------- python/fate_arch/tensor/_federation.py | 74 ----- python/fate_arch/tensor/_parties.py | 84 +++++ python/fate_arch/tensor/_tensor.py | 418 ++++++++++++++++++++++-- python/fate_arch/tensor/abc/block.py | 10 +- python/fate_arch/tensor/abc/tensor.py | 50 ++- python/federatedml/ml/toy/enterpoint.py | 65 ++-- 8 files changed, 525 insertions(+), 455 deletions(-) delete mode 100644 python/fate_arch/tensor/_context.py delete mode 100644 python/fate_arch/tensor/_federation.py create mode 100644 python/fate_arch/tensor/_parties.py diff --git a/python/fate_arch/tensor/__init__.py b/python/fate_arch/tensor/__init__.py index 9c1f939a7d..2ab55e0f39 100644 --- a/python/fate_arch/tensor/__init__.py +++ b/python/fate_arch/tensor/__init__.py @@ -1,19 +1,20 @@ from ._dataloader import LabeledDataloaderWrapper, UnlabeledDataloaderWrapper -from ._federation import ARBITER, GUEST, HOST -from ._context import Context, CipherKind -from ._tensor import ( - FPTensor, - PHETensor, -) +from ._parties import Parties, PreludeParty +from ._tensor import CipherKind, Context, FPTensor, PHETensor + +ARBITER = PreludeParty.ARBITER +GUEST = PreludeParty.GUEST +HOST = PreludeParty.HOST __all__ = [ "FPTensor", "PHETensor", + "Parties", "ARBITER", "GUEST", "HOST", "Context", "LabeledDataloaderWrapper", "UnlabeledDataloaderWrapper", - "CipherKind" + "CipherKind", ] diff --git a/python/fate_arch/tensor/_context.py b/python/fate_arch/tensor/_context.py deleted file mode 100644 index 92b85a017d..0000000000 --- a/python/fate_arch/tensor/_context.py +++ /dev/null @@ -1,264 +0,0 @@ -from typing import TYPE_CHECKING, Any, List, Tuple -from typing_extensions import Literal -from contextlib import contextmanager - -from ._federation import _Parties -import typing -from enum import Enum - -from fate_arch.common import Party -from fate_arch.federation.transfer_variable import IterationGC -from fate_arch.session import get_session -from .abc.tensor import PHEEncryptorABC, PHEDecryptorABC - -# for better type checking -if TYPE_CHECKING: - from ._tensor import FPTensor, PHETensor - - -class ExcutionState: - def __init__(self, tag) -> None: - self._tag = tag - - def generate_tag(self) -> str: - return self._tag - - -class DefaultState(ExcutionState): - def __init__(self) -> None: - super().__init__("default") - - -class FitState(ExcutionState): - ... - - -class PredictState(ExcutionState): - ... - - -class IterationState(ExcutionState): - def __init__(self, tag: str, index: int = -1) -> None: - self._tag = tag - self._index = index - - def generate_tag(self) -> str: - return self._tag - - -class CipherKind(Enum): - PHE = 1 - PHE_PAILLIER = 2 - - -class Device(Enum): - CPU = 1 - GPU = 2 - FPGA = 3 - - -class Context: - def __init__(self, start_iter_num=-1) -> None: - self._device = None - self._iter_num = start_iter_num - self._push_gc_dict = {} - self._pull_gc_dict = {} - - self._binded_variables = {} - self._current_iter = None - self._flowid = None - - self._execution_state = DefaultState() - - self._cypher_utils = CypherUtils(self) - self._tensor_utils = TensorUtils(self) - - def device_init(self, **kwargs): - self._device = Device.CPU - - def device(self) -> Device: - if self._device is None: - raise RuntimeError(f"init device first") - return self._device - - @property - def cypher_utils(self): - return self._cypher_utils - - @property - def tensor_utils(self): - return self._tensor_utils - - @contextmanager - def create_iter(self, max_iter, template="{i}"): - # cache previous state - previous_state = self._execution_state - current_tag = self.generate_federation_tag() - - def _state_iterator(): - for i in range(max_iter): - # the tags in the iteration need to be distinguishable - template_formated = template.format(i=i) - self._execution_state = IterationState(f"{current_tag}.{template_formated}", i) - yield i - - yield _state_iterator() - # recover state - self._execution_state = previous_state - - def generate_federation_tag(self): - return self._execution_state.generate_tag() - - def remote(self, target: _Parties, key: str, value): - self._push(target.parties, key, value) - return self - - def get(self, source: _Parties, key: str): - return self._pull(source.parties, key)[0] - - def get_multi(self, source: _Parties, key: str) -> List: - return self._pull(source.parties, key) - - def _push(self, parties: List[Party], key, value): - if key not in self._push_gc_dict: - self._push_gc_dict[key] = IterationGC() - get_session().federation.remote( - v=value, - name=key, - tag=self.generate_federation_tag(), - parties=parties, - gc=self._push_gc_dict[key], - ) - - def _pull(self, parties: List[Party], key): - if key not in self._pull_gc_dict: - self._pull_gc_dict[key] = IterationGC() - return get_session().federation.get( - name=key, - tag=self.generate_federation_tag(), - parties=parties, - gc=self._pull_gc_dict[key], - ) - - -class TensorUtils: - """utils for tensor operation such as: - 1. creation: zeros, ones, random - 2. recv from remote: get, get_multi - - Notes: - 1. methods perfix with `phe_` is bound to `PHETensor` - 2. others is bound to `FPTensor` - """ - - def __init__(self, ctx: Context) -> None: - self._ctx = ctx - - def zeros(self, shape) -> "FPTensor": - ... - - def get(self, source: _Parties, key: str) -> "FPTensor": - from ._tensor import FPTensor - - tensor = self._ctx.get(source, key) - if not isinstance(tensor, FPTensor): - raise ValueError( - f"{PHETensor.__name__} expected while {type(tensor).__name__} got" - ) - return tensor - - def get_multi(self, source: _Parties, key: str) -> typing.List["FPTensor"]: - from ._tensor import FPTensor - - tensors = self._ctx.get_multi(source, key) - for tensor in tensors: - if not isinstance(tensor, PHETensor): - raise ValueError( - f"{FPTensor.__name__} expected while {type(tensor).__name__} got" - ) - return tensors - - def phe_get(self, source: _Parties, key: str) -> "PHETensor": - from ._tensor import PHETensor - - tensor = self._ctx.get(source, key) - if not isinstance(tensor, PHETensor): - raise ValueError( - f"{PHETensor.__name__} expected while {type(tensor).__name__} got" - ) - return tensor - - def phe_get_multi(self, source: _Parties, key: str) -> typing.List["PHETensor"]: - from ._tensor import PHETensor - - tensors = self._ctx.get_multi(source, key) - for tensor in tensors: - if not isinstance(tensor, PHETensor): - raise ValueError( - f"{PHETensor.__name__} expected while {type(tensor).__name__} got" - ) - return tensors - - -class CypherUtils: - def __init__(self, ctx: Context) -> None: - self._ctx = ctx - - @typing.overload - def keygen( - self, kind: Literal[CipherKind.PHE], key_length: int - ) -> Tuple["PHEEncryptor", "PHEDecryptor"]: - ... - - @typing.overload - def keygen(self, kind: CipherKind, **kwargs) -> Any: - ... - - def keygen(self, kind, key_length: int, **kwargs): - if kind == CipherKind.PHE or kind == CipherKind.PHE_PAILLIER: - if self._ctx._device == Device.CPU: - from .impl.tensor.multithread import PaillierPHECipherLocal - - encryptor, decryptor = PaillierPHECipherLocal().keygen( - key_length=key_length - ) - return PHEEncryptor(encryptor), PHEDecryptor(decryptor) - else: - raise NotImplementedError(f"keygen for kind `{kind}` is not implemented") - - def phe_get_encryptor(self, source: _Parties, key: str) -> "PHEEncryptor": - encryptor = self._ctx.get(source, key) - ... - - -class PHEEncryptor: - def __init__(self, encryptor: PHEEncryptorABC) -> None: - self._encryptor = encryptor - - def encrypt(self, tensor: "FPTensor"): - from ._tensor import PHETensor - - return PHETensor(tensor._ctx, self._encryptor.encrypt(tensor._tensor)) - - @classmethod - def get(cls, ctx: Context, source: _Parties, key: str) -> "PHEEncryptor": - return PHEEncryptor(ctx.get(source, key)) - - @classmethod - def get_multi( - cls, ctx: Context, source: _Parties, key: str - ) -> List["PHEEncryptor"]: - return [PHEEncryptor(encryptor) for encryptor in ctx.get_multi(source, key)] - - def remote(self, ctx: Context, target: _Parties, key: str): - return ctx.remote(target, key, self._encryptor) - - -class PHEDecryptor: - def __init__(self, decryptor: PHEDecryptorABC) -> None: - self._decryptor = decryptor - - def decrypt(self, tensor: "PHETensor") -> "FPTensor": - from ._tensor import FPTensor - - return FPTensor(tensor._ctx, self._decryptor.decrypt(tensor._tensor)) diff --git a/python/fate_arch/tensor/_federation.py b/python/fate_arch/tensor/_federation.py deleted file mode 100644 index 1f6edc792a..0000000000 --- a/python/fate_arch/tensor/_federation.py +++ /dev/null @@ -1,74 +0,0 @@ -from typing import List, Union - -from fate_arch.common import Party -from fate_arch.session import get_parties - - -def _get_role_parties(role: str): - return get_parties().roles_to_parties([role], strict=False) - - -class _RoleIndexedParty: - def __init__(self, role: str, index: int) -> None: - assert index >= 0, "index should >= 0" - self._role = role - self._index = index - - @property - def party(self) -> Party: - parties = _get_role_parties(self._role) - if 0 <= self._index < len(parties): - return parties[self._index] - raise KeyError( - f"index `{self._index}` out of bound `0 <= index < {len(parties)}`" - ) - - -class _Parties: - def __init__( - self, parties: List[Union[str, Party, _RoleIndexedParty, "_Parties"]] - ) -> None: - self._parties = parties - - def _reverse(self): - self._parties.reverse() - return self - - @property - def parties(self) -> List[Party]: - flatten = [] - for p in self._parties: - if isinstance(p, str) and (p == "guest" or p == "host" or p == "arbiter"): - flatten.extend(_get_role_parties(p)) - elif isinstance(p, Party): - flatten.append(p) - elif isinstance(p, _RoleIndexedParty): - flatten.append(p.party) - elif isinstance(p, _Parties): - flatten.extend(p.parties) - return flatten - - def __add__(self, other) -> "_Parties": - if isinstance(other, Party): - return _Parties([self, other]) - elif isinstance(other, list): - return _Parties([self, *other]) - else: - raise ValueError(f"can't add `{other}`") - - def __radd__(self, other) -> "_Parties": - return self.__add__(other)._reverse() - - -class _Role(_Parties): - def __init__(self, role: str) -> None: - self._role = role - super().__init__([role]) - - def __getitem__(self, key) -> _RoleIndexedParty: - return _RoleIndexedParty(self._role, key) - - -ARBITER = _Role("arbiter") -GUEST = _Role("guest") -HOST = _Role("host") diff --git a/python/fate_arch/tensor/_parties.py b/python/fate_arch/tensor/_parties.py new file mode 100644 index 0000000000..5cb3cfaf81 --- /dev/null +++ b/python/fate_arch/tensor/_parties.py @@ -0,0 +1,84 @@ +import enum +from typing import List + +from fate_arch.common import Party +from fate_arch.session import get_parties + + +class Parties: + def __init__(self, flag) -> None: + self._flag = flag + + def contains_hosts(self) -> bool: + return bool(self._flag & 1) + + def contains_arbiter(self) -> bool: + return bool(self._flag & 2) + + def contains_guest(self) -> bool: + return bool(self._flag & 4) + + def contains_host(self) -> bool: + return bool(self._flag & 8) + + @property + def indexes(self) -> List[int]: + return [i for i, e in enumerate(bin(self._flag)[::-1]) if e == "1"] + + @classmethod + def get_name(cls, i): + if i < 4: + return {0: "HOSTS", 1: "ARBITER", 2: "GUEST", 3: "HOST"}[i] + else: + return f"HOST{i-3}" + + def __or__(self, other): + return Parties(self._flag | other._flag) + + def __ror__(self, other): + return Parties(self._flag | other._flag) + + def __hash__(self) -> int: + return self._flag + + def __eq__(self, o) -> bool: + return self._flag == o._flag + + def __str__(self): + readable = "|".join([self.get_name(i) for i in self.indexes]) + return f"4b}): {readable}>" + + def __repr__(self): + return self.__str__() + + def __getitem__(self, key): + if self._flag == 1 and isinstance(key, int) and key >= 0: + return Parties(1 << (key + 3)) + raise TypeError("not subscriptable") + + def _get_role_parties(self, role: str): + return get_parties().roles_to_parties([role], strict=False) + + def get_parties(self) -> List[Party]: + parties = [] + if self._flag & 2: + parties.extend(self._get_role_parties("arbiter")) + if self._flag & 4: + parties.extend(self._get_role_parties("guest")) + if self._flag & 1: + parties.extend(self._get_role_parties("host")) + else: + host_bit_int = self._flag >> 3 + if host_bit_int: + hosts = self._get_role_parties("host") + for i, e in enumerate(bin(host_bit_int)[::-1]): + if e == "1": + parties.append(hosts[i]) + return parties + + +class PreludeParty(Parties, enum.Flag): + HOSTS = 1 + ARBITER = 2 + GUEST = 4 + HOST = 8 diff --git a/python/fate_arch/tensor/_tensor.py b/python/fate_arch/tensor/_tensor.py index 7e953cd95f..feb30529f5 100644 --- a/python/fate_arch/tensor/_tensor.py +++ b/python/fate_arch/tensor/_tensor.py @@ -1,63 +1,422 @@ -import typing -from typing import Any, Union -from ._federation import _Parties -from ._context import Context, PHEEncryptor, PHEDecryptor -from .abc.tensor import FPTensorABC, PHETensorABC +import json +from contextlib import contextmanager +from enum import Enum +from typing import ( + Any, + Callable, + Generator, + List, + Optional, + Tuple, + TypeVar, + Union, + overload, +) + +import torch +from fate_arch.common import Party +from fate_arch.federation.transfer_variable import IterationGC +from fate_arch.session import get_session +from typing_extensions import Literal + +from ._parties import Parties, PreludeParty +from .abc.tensor import PHEDecryptorABC, PHEEncryptorABC, PHETensorABC + + +class NamespaceState: + def __init__(self, namespace) -> None: + self._namespace = namespace + + def get_namespce(self) -> str: + return self._namespace + + def sub_namespace(self, namespace): + return f"{self._namespace}.{namespace}" + + +class FitState(NamespaceState): + ... + + +class PredictState(NamespaceState): + ... + + +class IterationState(NamespaceState): + ... + + +class CipherKind(Enum): + PHE = 1 + PHE_PAILLIER = 2 + + +class Device(Enum): + CPU = 1 + GPU = 2 + FPGA = 3 + + +T = TypeVar("T") + + +class Future: + """ + get maybe async in future, in this version, + we wrap obj to support explicit typing and check + """ + + def __init__(self, inside) -> None: + self._inside = inside + + def unwrap_tensor(self) -> "FPTensor": + + assert isinstance(self._inside, FPTensor) + return self._inside + + def unwrap_phe_encryptor(self) -> "PHEEncryptor": + assert isinstance(self._inside, PHEEncryptor) + return self._inside + + def unwrap_phe_tensor(self) -> "PHETensor": + + assert isinstance(self._inside, PHETensor) + return self._inside + + def unwrap(self, check: Optional[Callable[[T], bool]] = None) -> T: + if check is not None and not check(self._inside): + raise TypeError(f"`{self._inside}` check failed") + return self._inside + + +class Futures: + def __init__(self, insides) -> None: + self._insides = insides + + def unwrap_tensors(self) -> List["FPTensor"]: + + for t in self._insides: + assert isinstance(t, FPTensor) + return self._insides + + def unwrap_phe_tensors(self) -> List["PHETensor"]: + + for t in self._insides: + assert isinstance(t, PHETensor) + return self._insides + + def unwrap(self, check: Optional[Callable[[T], bool]] = None) -> List[T]: + if check is not None: + for i, t in enumerate(self._insides): + if not check(t): + raise TypeError(f"{i}th element `{self._insides}` check failed") + return self._insides + + +class _ContextInside: + def __init__(self, cpn_input) -> None: + self._device = None + self._push_gc_dict = {} + self._pull_gc_dict = {} + + self._flowid = None + + self._roles = cpn_input.roles + self._job_parameters = cpn_input.job_parameters + self._parameters = cpn_input.parameters + self._flow_feeded_parameters = cpn_input.flow_feeded_parameters + + @property + def is_guest(self): + return self._roles["local"]["role"] == "guest" + + @property + def is_host(self): + return self._roles["local"]["role"] == "host" + + @property + def is_arbiter(self): + return self._roles["local"]["role"] == "arbiter" + + @property + def party(self): + role = self._roles["local"]["role"] + party_id = self._roles["local"]["party_id"] + return Party(role, party_id) + + def get_or_set_push_gc(self, key): + if key not in self._push_gc_dict: + self._push_gc_dict[key] = IterationGC() + return self._push_gc_dict[key] + + def get_or_set_pull_gc(self, key): + if key not in self._push_gc_dict: + self._pull_gc_dict[key] = IterationGC() + return self._pull_gc_dict[key] + + def describe(self): + return dict( + party=self.party, + job_parameters=self._job_parameters, + parameters=self._parameters, + flow_feeded_parameters=self._flow_feeded_parameters, + ) + + +class Context: + def __init__(self, inside: _ContextInside, namespace: str) -> None: + self._inside = inside + self._namespace_state = NamespaceState(namespace) + + @classmethod + def from_cpn_input(cls, cpn_input): + states = _ContextInside(cpn_input) + namespace = "fate" + return Context(states, namespace) + + def describe(self): + return json.dumps( + dict( + states=self._inside.describe(), + ) + ) + + @property + def party(self): + return self._inside.party + + @property + def role(self): + return self.party.role + + @property + def party_id(self): + return self.party.party_id + + @property + def is_guest(self): + return self._inside.is_guest + + @property + def is_host(self): + return self._inside.is_guest + + @property + def is_arbiter(self): + return self._inside.is_guest + + def device_init(self, **kwargs): + self._device = Device.CPU + + def device(self) -> Device: + if self._device is None: + raise RuntimeError(f"init device first") + return self._device + + def current_namespace(self): + return self._namespace_state.get_namespce() + + def push(self, target: Parties, key: str, value): + return self._push(target, key, value) + + def pull( + self, + source: Literal[PreludeParty.GUEST, PreludeParty.HOST, PreludeParty.ARBITER], + key: str, + ) -> Future: + return Future(self._pull(source, key)[0]) + + def pulls(self, source: Parties, key: str) -> Futures: + return Futures(self._pull(source, key)) + + def _push(self, parties: Parties, key, value): + get_session().federation.remote( + v=value, + name=key, + tag=self.current_namespace(), + parties=parties.get_parties(), + gc=self._inside.get_or_set_push_gc(key), + ) + + def _pull(self, parties: Parties, key): + return get_session().federation.get( + name=key, + tag=self.current_namespace(), + parties=parties.get_parties(), + gc=self._inside.get_or_set_pull_gc(key), + ) + + @overload + def keygen( + self, kind: Literal[CipherKind.PHE], key_length: int + ) -> Tuple["PHEEncryptor", "PHEDecryptor"]: + ... + + @overload + def keygen(self, kind: CipherKind, **kwargs) -> Any: + ... + + def keygen(self, kind, key_length: int, **kwargs): + if kind == CipherKind.PHE or kind == CipherKind.PHE_PAILLIER: + if self._device == Device.CPU: + from .impl.tensor.multithread import PaillierPHECipherLocal + + encryptor, decryptor = PaillierPHECipherLocal().keygen( + key_length=key_length + ) + return PHEEncryptor(encryptor), PHEDecryptor(decryptor) + else: + raise NotImplementedError(f"keygen for kind `{kind}` is not implemented") + + def create_tensor(self, tensor: torch.Tensor) -> "FPTensor": + + return FPTensor(self, tensor) + + @contextmanager + def sub_namespace(self, namespace): + """ + into sub_namespace ``, suffix federation namespace with `namespace` + + Examples: + ``` + with ctx.sub_namespace("fit"): + ctx.push(..., trans_key, obj) + + with ctx.sub_namespace("predict"): + ctx.push(..., trans_key, obj2) + ``` + `obj1` and `obj2` are pushed with different namespace + without conflic. + """ + + prev_namespace_state = self._namespace_state + + # into subnamespace + self._namespace_state = NamespaceState( + self._namespace_state.sub_namespace(namespace) + ) + + # return sub_ctx + # ```python + # with ctx.sub_namespace(xxx) as sub_ctx: + # ... + # ``` + # + yield self + + # restore namespace state when leaving with context + self._namespace_state = prev_namespace_state + + @overload + @contextmanager + def iter_namespaces( + self, start: int, stop: int, *, prefix_name="" + ) -> Generator[Generator["Context", None, None], None, None]: + ... + + @overload + @contextmanager + def iter_namespaces( + self, stop: int, *, prefix_name="" + ) -> Generator[Generator["Context", None, None], None, None]: + ... + + @contextmanager + def iter_namespaces(self, *args, prefix_name=""): + assert 0 < len(args) <= 2, "position argument should be 1 or 2" + if len(args) == 1: + start, stop = 0, args[0] + if len(args) == 2: + start, stop = args[0], args[1] + + prev_namespace_state = self._namespace_state + + def _state_iterator() -> Generator[Context, None, None]: + for i in range(start, stop): + # the tags in the iteration need to be distinguishable + template_formated = f"{prefix_name}iter_{i}" + self._namespace_state = IterationState( + prev_namespace_state.sub_namespace(template_formated) + ) + yield self + + # with context returns iterator of Contexts + # namespaec state inside context is changed alone with iterator comsued + yield _state_iterator() + + # restore namespace state when leaving with context + self._namespace_state = prev_namespace_state + + +class PHEEncryptor: + def __init__(self, encryptor: PHEEncryptorABC) -> None: + self._encryptor = encryptor + + def encrypt(self, tensor: "FPTensor"): + + return PHETensor(tensor._ctx, self._encryptor.encrypt(tensor._tensor)) + + +class PHEDecryptor: + def __init__(self, decryptor: PHEDecryptorABC) -> None: + self._decryptor = decryptor + + def decrypt(self, tensor: "PHETensor") -> "FPTensor": + + return FPTensor(tensor._ctx, self._decryptor.decrypt(tensor._tensor)) class FPTensor: - def __init__(self, ctx: Context, tensor: FPTensorABC) -> None: + def __init__(self, ctx: Context, tensor) -> None: self._ctx = ctx self._tensor = tensor def __add__(self, other: Union["FPTensor", float, int]) -> "FPTensor": + if not hasattr(self._tensor, "__add__"): + return NotImplemented return self._binary_op(other, self._tensor.__add__) def __radd__(self, other: Union["FPTensor", float, int]) -> "FPTensor": + if not hasattr(self._tensor, "__radd__"): + return self.__add__(other) return self._binary_op(other, self._tensor.__add__) def __sub__(self, other: Union["FPTensor", float, int]) -> "FPTensor": + if not hasattr(self._tensor, "__sub__"): + return NotImplemented return self._binary_op(other, self._tensor.__sub__) def __rsub__(self, other: Union["FPTensor", float, int]) -> "FPTensor": + if not hasattr(self._tensor, "__rsub__"): + return self.__mul__(-1).__add__(other) return self._binary_op(other, self._tensor.__rsub__) def __mul__(self, other: Union["FPTensor", float, int]) -> "FPTensor": + if not hasattr(self._tensor, "__mul__"): + return NotImplemented return self._binary_op(other, self._tensor.__mul__) def __rmul__(self, other: Union["FPTensor", float, int]) -> "FPTensor": + if not hasattr(self._tensor, "__rmul__"): + return self.__mul__(other) return self._binary_op(other, self._tensor.__rmul__) def __matmul__(self, other: "FPTensor") -> "FPTensor": + if not hasattr(self._tensor, "__matmul__"): + return NotImplemented if isinstance(other, FPTensor): return FPTensor(self._ctx, self._tensor.__matmul__(other._tensor)) else: return NotImplemented def __rmatmul__(self, other: "FPTensor") -> "FPTensor": + if not hasattr(self._tensor, "__rmatmul__"): + return NotImplemented if isinstance(other, FPTensor): return FPTensor(self._ctx, self._tensor.__rmatmul__(other._tensor)) else: return NotImplemented - @typing.overload - def encrypted(self, encryptor: "PHEEncryptor") -> "PHETensor": - ... - - @typing.overload - def encrypted(self, encryptor): - ... - - def encrypted(self, encryptor): - return encryptor.encrypt(self) - - def remote(self, target: _Parties, name: str): - return self._ctx.remote(target, name, self) - - @classmethod - def get(cls, ctx: Context, source: _Parties, name: str) -> "FPTensor": - return ctx.get(source, name) - def _binary_op(self, other, func): if isinstance(other, FPTensor): return FPTensor(self._ctx, func(other._tensor)) @@ -109,24 +468,17 @@ def __rmatmul__(self, other: FPTensor) -> "PHETensor": def T(self) -> "PHETensor": return PHETensor(self._ctx, self._tensor.T()) - @typing.overload + @overload def decrypt(self, decryptor: "PHEDecryptor") -> FPTensor: ... - @typing.overload + @overload def decrypt(self, decryptor) -> Any: ... def decrypt(self, decryptor): return decryptor.decrypt(self) - def remote(self, target: _Parties, name: str): - return self._ctx.remote(target, name, self) - - @classmethod - def get(cls, ctx: Context, source: _Parties, name: str) -> "PHETensor": - return ctx.get(source, name) - def _binary_op(self, other, func): if isinstance(other, (PHETensor, FPTensor)): return PHETensor(self._ctx, func(other._tensor)) diff --git a/python/fate_arch/tensor/abc/block.py b/python/fate_arch/tensor/abc/block.py index 7bb502beb7..5741011dda 100644 --- a/python/fate_arch/tensor/abc/block.py +++ b/python/fate_arch/tensor/abc/block.py @@ -2,7 +2,7 @@ import typing -class FPBlockABC(abc.ABC): +class FPBlockABC: @classmethod def zeors(cls, shape) -> "FPBlockABC": ... @@ -40,7 +40,7 @@ def __rmatmul__(self, other: "FPBlockABC") -> "FPBlockABC": ... -class PHEBlockABC(abc.ABC): +class PHEBlockABC: """Tensor implements Partial Homomorphic Encryption schema: 1. decrypt(encrypt(a) + encrypt(b)) = a + b 2. decrypt(encrypt(a) * b) = a * b @@ -99,19 +99,19 @@ def T(self) -> "PHEBlockABC": ... -class PHEBlockEncryptorABC(abc.ABC): +class PHEBlockEncryptorABC: @abc.abstractmethod def encrypt(self, tensor: FPBlockABC) -> PHEBlockABC: ... -class PHEBlockDecryptorABC(abc.ABC): +class PHEBlockDecryptorABC: @abc.abstractmethod def decrypt(self, tensor: PHEBlockABC) -> FPBlockABC: ... -class PHEBlockCipherABC(abc.ABC): +class PHEBlockCipherABC: @abc.abstractclassmethod def keygen( cls, **kwargs diff --git a/python/fate_arch/tensor/abc/tensor.py b/python/fate_arch/tensor/abc/tensor.py index 278f930740..843a6fad5c 100644 --- a/python/fate_arch/tensor/abc/tensor.py +++ b/python/fate_arch/tensor/abc/tensor.py @@ -1,42 +1,32 @@ import abc import typing +from typing_extensions import Protocol -class FPTensorABC(abc.ABC): - @classmethod - def zeors(cls, shape) -> "FPTensorABC": - ... +class FPTensorProtocol(Protocol): - @abc.abstractmethod - def __add__(self, other: typing.Union["FPTensorABC", float, int]) -> "FPTensorABC": + def __add__(self, other: typing.Union["FPTensorProtocol", float, int]) -> "FPTensorProtocol": ... - @abc.abstractmethod - def __radd__(self, other: typing.Union["FPTensorABC", float, int]) -> "FPTensorABC": + def __radd__(self, other: typing.Union["FPTensorProtocol", float, int]) -> "FPTensorProtocol": ... - @abc.abstractmethod - def __sub__(self, other: typing.Union["FPTensorABC", float, int]) -> "FPTensorABC": + def __sub__(self, other: typing.Union["FPTensorProtocol", float, int]) -> "FPTensorProtocol": ... - @abc.abstractmethod - def __rsub__(self, other: typing.Union["FPTensorABC", float, int]) -> "FPTensorABC": + def __rsub__(self, other: typing.Union["FPTensorProtocol", float, int]) -> "FPTensorProtocol": ... - @abc.abstractmethod - def __mul__(self, other: typing.Union["FPTensorABC", float, int]) -> "FPTensorABC": + def __mul__(self, other: typing.Union["FPTensorProtocol", float, int]) -> "FPTensorProtocol": ... - @abc.abstractmethod - def __rmul__(self, other: typing.Union["FPTensorABC", float, int]) -> "FPTensorABC": + def __rmul__(self, other: typing.Union["FPTensorProtocol", float, int]) -> "FPTensorProtocol": ... - @abc.abstractmethod - def __matmul__(self, other: "FPTensorABC") -> "FPTensorABC": + def __matmul__(self, other: "FPTensorProtocol") -> "FPTensorProtocol": ... - @abc.abstractmethod - def __rmatmul__(self, other: "FPTensorABC") -> "FPTensorABC": + def __rmatmul__(self, other: "FPTensorProtocol") -> "FPTensorProtocol": ... @@ -48,46 +38,46 @@ class PHETensorABC(abc.ABC): @abc.abstractmethod def __add__( - self, other: typing.Union["PHETensorABC", "FPTensorABC", float, int] + self, other: typing.Union["PHETensorABC", "FPTensorProtocol", float, int] ) -> "PHETensorABC": ... @abc.abstractmethod def __radd__( - self, other: typing.Union["PHETensorABC", "FPTensorABC", float, int] + self, other: typing.Union["PHETensorABC", "FPTensorProtocol", float, int] ) -> "PHETensorABC": ... @abc.abstractmethod def __sub__( - self, other: typing.Union["PHETensorABC", "FPTensorABC", float, int] + self, other: typing.Union["PHETensorABC", "FPTensorProtocol", float, int] ) -> "PHETensorABC": ... @abc.abstractmethod def __rsub__( - self, other: typing.Union["PHETensorABC", "FPTensorABC", float, int] + self, other: typing.Union["PHETensorABC", "FPTensorProtocol", float, int] ) -> "PHETensorABC": ... @abc.abstractmethod def __mul__( - self, other: typing.Union["PHETensorABC", "FPTensorABC", float, int] + self, other: typing.Union["PHETensorABC", "FPTensorProtocol", float, int] ) -> "PHETensorABC": ... @abc.abstractmethod def __rmul__( - self, other: typing.Union["PHETensorABC", "FPTensorABC", float, int] + self, other: typing.Union["PHETensorABC", "FPTensorProtocol", float, int] ) -> "PHETensorABC": ... @abc.abstractmethod - def __matmul__(self, other: FPTensorABC) -> "PHETensorABC": + def __matmul__(self, other: FPTensorProtocol) -> "PHETensorABC": ... @abc.abstractmethod - def __rmatmul__(self, other: FPTensorABC) -> "PHETensorABC": + def __rmatmul__(self, other: FPTensorProtocol) -> "PHETensorABC": ... @abc.abstractmethod @@ -101,13 +91,13 @@ def T(self) -> "PHETensorABC": class PHEEncryptorABC(abc.ABC): @abc.abstractmethod - def encrypt(self, tensor: FPTensorABC) -> PHETensorABC: + def encrypt(self, tensor: FPTensorProtocol) -> PHETensorABC: ... class PHEDecryptorABC(abc.ABC): @abc.abstractmethod - def decrypt(self, tensor: PHETensorABC) -> FPTensorABC: + def decrypt(self, tensor: PHETensorABC) -> FPTensorProtocol: ... diff --git a/python/federatedml/ml/toy/enterpoint.py b/python/federatedml/ml/toy/enterpoint.py index 9a02ee0f4a..581a44d1fa 100644 --- a/python/federatedml/ml/toy/enterpoint.py +++ b/python/federatedml/ml/toy/enterpoint.py @@ -1,38 +1,14 @@ import torch -from fate_arch.tensor._context import CipherKind, Context -from fate_arch.tensor import PHETensor, FPTensor, GUEST, HOST +from fate_arch.tensor import GUEST, HOST, CipherKind, Context from federatedml.model_base import ComponentOutput, ModelBase -from federatedml.transfer_variable.base_transfer_variable import BaseTransferVariables from federatedml.util import LOGGER from .params import TensorExampleParam -# noinspection PyAttributeOutsideInit -class TensorExampleTransferVariable(BaseTransferVariables): - def __init__(self, flowid=0): - super().__init__(flowid) - self.guest_cipher = self._create_variable( - name="guest_cipher", src=["guest"], dst=["host"] - ) - self.host_cipher = self._create_variable( - name="host_cipher", src=["host"], dst=["guest"] - ) - self.host_matmul_encrypted = self._create_variable( - name="host_matmul_encrypted", src=["host"], dst=["guest"] - ) - self.host_matmul = self._create_variable( - name="host_matmul", src=["host"], dst=["guest"] - ) - self.guest_matmul_encrypted = self._create_variable( - name="guest_matmul_encrypted", src=["guest"], dst=["host"] - ) - - class TensorExampleGuest(ModelBase): def __init__(self): super(TensorExampleGuest, self).__init__() - self.transfer_inst = TensorExampleTransferVariable() self.model_param = TensorExampleParam() self.data_output = None self.model_output = None @@ -46,7 +22,8 @@ def _init_model(self): self.feature_num = self.model_param.feature_num def run(self, cpn_input): - ctx = Context() + ctx = Context.from_cpn_input(cpn_input) + LOGGER.info(ctx.describe()) ctx.device_init() return self._run(ctx, cpn_input) @@ -56,34 +33,34 @@ def _run(self, ctx: Context, cpn_input): self._init_runtime_parameters(cpn_input) LOGGER.info("begin to make guest data") - self.a = FPTensor(ctx, torch.rand((self.data_num, self.feature_num))) + self.a = ctx.create_tensor(torch.rand((self.data_num, self.feature_num))) LOGGER.info("keygen") - self.pk, self.sk = ctx.cypher_utils.keygen(CipherKind.PHE, 1024) + self.pk, self.sk = ctx.keygen(CipherKind.PHE, 1024) LOGGER.info("encrypt data") self.ea = self.pk.encrypt(self.a) LOGGER.info("share encrypted data to host") - self.ea.remote(HOST, "guest_cipher") + ctx.push(HOST, "guest_cipher", self.ea) LOGGER.info("get encrypted data from host") - self.eb = PHETensor.get(ctx, HOST, "host_cipher") + self.eb = ctx.pull(HOST, "host_cipher").unwrap_phe_tensor() LOGGER.info("begin to get matmul of guest and host") self.es_guest = self.a.T @ self.eb LOGGER.info("send encrypted matmul to host") - self.es_guest.remote(HOST, "guest_matmul_encrypted") + ctx.push(HOST, "guest_matmul_encrypted", self.es_guest) LOGGER.info("receive encrypted matmul from guest") - self.es_host = PHETensor.get(ctx, HOST, "host_matmul_encrypted") + self.es_host = ctx.pull(HOST, "host_matmul_encrypted").unwrap_phe_tensor() LOGGER.info("decrypt matmul") self.s_host = self.sk.decrypt(self.es_host) LOGGER.info("get decrypted matmul") - self.s_guest = FPTensor.get(ctx, HOST, "host_matmul") + self.s_guest = ctx.pull(HOST, "host_matmul").unwrap_tensor() LOGGER.info("assert matmul close") assert torch.allclose(self.s_host._tensor.T, self.s_guest._tensor) @@ -94,7 +71,6 @@ def _run(self, ctx: Context, cpn_input): class TensorExampleHost(ModelBase): def __init__(self): super(TensorExampleHost, self).__init__() - self.transfer_inst = TensorExampleTransferVariable() self.model_param = TensorExampleParam() self.data_output = None self.model_output = None @@ -108,8 +84,9 @@ def _init_model(self): self.feature_num = self.model_param.feature_num def run(self, cpn_input): - ctx = Context() + ctx = Context.from_cpn_input(cpn_input) ctx.device_init() + LOGGER.info(ctx.describe()) return self._run(ctx, cpn_input) def _run(self, ctx: Context, cpn_input): @@ -117,33 +94,37 @@ def _run(self, ctx: Context, cpn_input): self._init_runtime_parameters(cpn_input) LOGGER.info("begin to make host data") - self.b = FPTensor(ctx, torch.rand((self.data_num, self.feature_num))) + self.b = ctx.create_tensor(torch.rand((self.data_num, self.feature_num))) + + with ctx.iter_namespaces(10, prefix_name="tree_") as iteration: + for i, _ in enumerate(iteration): + print(ctx.current_namespace()) LOGGER.info("keygen") - self.pk, self.sk = ctx.cypher_utils.keygen(CipherKind.PHE, 1024) + self.pk, self.sk = ctx.keygen(CipherKind.PHE, 1024) LOGGER.info("begin to encrypt") self.eb = self.pk.encrypt(self.b) LOGGER.info("share encrypted data to guest") - self.eb.remote(GUEST, "host_cipher") + ctx.push(GUEST, "host_cipher", self.eb) LOGGER.info("get encrypted data from guest") - self.ea = PHETensor.get(ctx, GUEST, "guest_cipher") + self.ea = ctx.pull(GUEST, "guest_cipher").unwrap_phe_tensor() LOGGER.info("begin to get matmul of host and guest") self.es_host = self.b.T @ self.ea LOGGER.info("send encrypted matmul to guest") - self.es_host.remote(GUEST, "host_matmul_encrypted") + ctx.push(GUEST, "host_matmul_encrypted", self.es_host) LOGGER.info("get encrypted matmul from guest") - self.es_guest = PHETensor.get(ctx, GUEST, "guest_matmul_encrypted") + self.es_guest = ctx.pull(GUEST, "guest_matmul_encrypted").unwrap_phe_tensor() LOGGER.info("decrypt encrypted matmul from guest") self.s_guest = self.sk.decrypt(self.es_guest) LOGGER.info("send decrypted matmul to guest") - self.s_guest.remote(GUEST, "host_matmul") + ctx.push(GUEST, "host_matmul", self.s_guest) return ComponentOutput(self.save_data(), self.export_model(), self.save_cache()) From 1369bd7a99cce41cbc98ccfb258ed710dacfa2d2 Mon Sep 17 00:00:00 2001 From: weiwee Date: Thu, 21 Jul 2022 04:34:08 -0800 Subject: [PATCH 04/11] refact: rename rust_tensor to rust_paillier Signed-off-by: weiwee --- ...heel.yml => build_rust_paillier_wheel.yml} | 6 ++-- rust/fate-tensor/fate_tensor/__init__.py | 1 - .../rust_paillier}/.projectile | 0 .../rust_paillier}/Cargo.lock | 34 +++++++++---------- .../rust_paillier}/Cargo.toml | 4 +-- .../rust_paillier}/benches/base_bench.py | 7 ++-- .../rust_paillier}/benches/iai_bench.rs | 0 .../rust_paillier}/benches/paillier_bench.rs | 0 .../rust_paillier}/pyproject.toml | 15 +------- .../rust_paillier/rust_paillier/__init__.py | 1 + .../rust_paillier/rust_paillier}/__init__.pyi | 0 .../rust_paillier}/par/__init__.py | 0 .../rust_paillier}/par/__init__.pyi | 0 .../rust_paillier}/src/block/matmul.rs | 0 .../rust_paillier}/src/block/mod.rs | 0 .../rust_paillier}/src/cb.rs | 0 .../rust_paillier}/src/fixedpoint/coder.rs | 0 .../rust_paillier}/src/fixedpoint/frexp.rs | 0 .../rust_paillier}/src/fixedpoint/mod.rs | 0 .../rust_paillier}/src/lib.rs | 8 ++--- .../rust_paillier}/src/math/mod.rs | 0 .../rust_paillier}/src/math/rug/mod.rs | 0 .../rust_paillier}/src/math/rug/ops.rs | 0 .../rust_paillier}/src/math/rug/random.rs | 0 .../rust_paillier}/src/math/rug/serde.rs | 0 .../rust_paillier}/src/paillier/mod.rs | 0 .../rust_paillier}/src/par/cb.rs | 0 .../rust_paillier}/src/par/mod.rs | 8 ++--- .../rust_paillier}/tests/test_base.py | 2 +- 29 files changed, 36 insertions(+), 50 deletions(-) rename .github/workflows/{build_tensor_wheel.yml => build_rust_paillier_wheel.yml} (92%) delete mode 100644 rust/fate-tensor/fate_tensor/__init__.py rename rust/{fate-tensor => tensor/rust_paillier}/.projectile (100%) rename rust/{fate-tensor => tensor/rust_paillier}/Cargo.lock (99%) rename rust/{fate-tensor => tensor/rust_paillier}/Cargo.toml (95%) rename rust/{fate-tensor => tensor/rust_paillier}/benches/base_bench.py (97%) rename rust/{fate-tensor => tensor/rust_paillier}/benches/iai_bench.rs (100%) rename rust/{fate-tensor => tensor/rust_paillier}/benches/paillier_bench.rs (100%) rename rust/{fate-tensor => tensor/rust_paillier}/pyproject.toml (54%) create mode 100644 rust/tensor/rust_paillier/rust_paillier/__init__.py rename rust/{fate-tensor/fate_tensor => tensor/rust_paillier/rust_paillier}/__init__.pyi (100%) rename rust/{fate-tensor/fate_tensor => tensor/rust_paillier/rust_paillier}/par/__init__.py (100%) rename rust/{fate-tensor/fate_tensor => tensor/rust_paillier/rust_paillier}/par/__init__.pyi (100%) rename rust/{fate-tensor => tensor/rust_paillier}/src/block/matmul.rs (100%) rename rust/{fate-tensor => tensor/rust_paillier}/src/block/mod.rs (100%) rename rust/{fate-tensor => tensor/rust_paillier}/src/cb.rs (100%) rename rust/{fate-tensor => tensor/rust_paillier}/src/fixedpoint/coder.rs (100%) rename rust/{fate-tensor => tensor/rust_paillier}/src/fixedpoint/frexp.rs (100%) rename rust/{fate-tensor => tensor/rust_paillier}/src/fixedpoint/mod.rs (100%) rename rust/{fate-tensor => tensor/rust_paillier}/src/lib.rs (98%) rename rust/{fate-tensor => tensor/rust_paillier}/src/math/mod.rs (100%) rename rust/{fate-tensor => tensor/rust_paillier}/src/math/rug/mod.rs (100%) rename rust/{fate-tensor => tensor/rust_paillier}/src/math/rug/ops.rs (100%) rename rust/{fate-tensor => tensor/rust_paillier}/src/math/rug/random.rs (100%) rename rust/{fate-tensor => tensor/rust_paillier}/src/math/rug/serde.rs (100%) rename rust/{fate-tensor => tensor/rust_paillier}/src/paillier/mod.rs (100%) rename rust/{fate-tensor => tensor/rust_paillier}/src/par/cb.rs (100%) rename rust/{fate-tensor => tensor/rust_paillier}/src/par/mod.rs (97%) rename rust/{fate-tensor => tensor/rust_paillier}/tests/test_base.py (98%) diff --git a/.github/workflows/build_tensor_wheel.yml b/.github/workflows/build_rust_paillier_wheel.yml similarity index 92% rename from .github/workflows/build_tensor_wheel.yml rename to .github/workflows/build_rust_paillier_wheel.yml index 574f4dc593..543d082855 100644 --- a/.github/workflows/build_tensor_wheel.yml +++ b/.github/workflows/build_rust_paillier_wheel.yml @@ -1,4 +1,4 @@ -name: Build FATE-CPU-Tensor +name: Build Rust Paillier on: workflow_dispatch: @@ -44,13 +44,13 @@ jobs: with: manylinux: auto command: build - args: --release -o dist -m rust/fate_tensor/Cargo.toml + args: --release -o dist -m rust/tensor/rust_paillier/Cargo.toml - name: macos-maturin if: matrix.os == 'macos' uses: messense/maturin-action@v1 with: command: build - args: --release --no-sdist -o dist -m rust/fate_tensor/Cargo.toml + args: --release --no-sdist -o dist -m rust/tensor/rust_paillier/Cargo.toml - name: Upload wheels uses: actions/upload-artifact@v2 with: diff --git a/rust/fate-tensor/fate_tensor/__init__.py b/rust/fate-tensor/fate_tensor/__init__.py deleted file mode 100644 index 1c1df9b607..0000000000 --- a/rust/fate-tensor/fate_tensor/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .fate_tensor import * diff --git a/rust/fate-tensor/.projectile b/rust/tensor/rust_paillier/.projectile similarity index 100% rename from rust/fate-tensor/.projectile rename to rust/tensor/rust_paillier/.projectile diff --git a/rust/fate-tensor/Cargo.lock b/rust/tensor/rust_paillier/Cargo.lock similarity index 99% rename from rust/fate-tensor/Cargo.lock rename to rust/tensor/rust_paillier/Cargo.lock index 518bad791c..e9be0c7c77 100644 --- a/rust/fate-tensor/Cargo.lock +++ b/rust/tensor/rust_paillier/Cargo.lock @@ -199,23 +199,6 @@ version = "1.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e78d4f1cc4ae33bbfc157ed5d5a5ef3bc29227303d595861deb238fcec4e9457" -[[package]] -name = "fate-tensor" -version = "0.1.0" -dependencies = [ - "bincode", - "criterion", - "iai", - "ndarray", - "numpy", - "pyo3", - "rand", - "rand_core", - "rayon", - "rug", - "serde", -] - [[package]] name = "getrandom" version = "0.2.7" @@ -704,6 +687,23 @@ dependencies = [ "libc", ] +[[package]] +name = "rust_paillier" +version = "0.1.0" +dependencies = [ + "bincode", + "criterion", + "iai", + "ndarray", + "numpy", + "pyo3", + "rand", + "rand_core", + "rayon", + "rug", + "serde", +] + [[package]] name = "rustc_version" version = "0.4.0" diff --git a/rust/fate-tensor/Cargo.toml b/rust/tensor/rust_paillier/Cargo.toml similarity index 95% rename from rust/fate-tensor/Cargo.toml rename to rust/tensor/rust_paillier/Cargo.toml index 552de2f9a2..f16f4cefaf 100644 --- a/rust/fate-tensor/Cargo.toml +++ b/rust/tensor/rust_paillier/Cargo.toml @@ -1,11 +1,11 @@ [package] -name = "fate-tensor" +name = "rust_paillier" version = "0.1.0" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [lib] -name = "fate_tensor" +name = "rust_paillier" crate-type = ["cdylib", "staticlib", "rlib"] bench = false diff --git a/rust/fate-tensor/benches/base_bench.py b/rust/tensor/rust_paillier/benches/base_bench.py similarity index 97% rename from rust/fate-tensor/benches/base_bench.py rename to rust/tensor/rust_paillier/benches/base_bench.py index 576c88707e..977ac90172 100644 --- a/rust/fate-tensor/benches/base_bench.py +++ b/rust/tensor/rust_paillier/benches/base_bench.py @@ -1,4 +1,3 @@ -from _pytest.mark import expression import pytest import operator import os @@ -8,7 +7,7 @@ try: import gmpy2 -except: +except Exception: raise RuntimeError(f"gmpy2 not installed, lib phe without gmpy2 is slow") @@ -23,7 +22,7 @@ def get_num_threads(): def get_single_thread_keygen(): - from fate_tensor import keygen + from rust_paillier import keygen return keygen @@ -32,7 +31,7 @@ def get_single_thread_keygen(): def get_multiple_thread_keygen(): - from fate_tensor.par import keygen, set_num_threads + from rust_paillier.par import keygen, set_num_threads set_num_threads(NUM_THREADS) return keygen diff --git a/rust/fate-tensor/benches/iai_bench.rs b/rust/tensor/rust_paillier/benches/iai_bench.rs similarity index 100% rename from rust/fate-tensor/benches/iai_bench.rs rename to rust/tensor/rust_paillier/benches/iai_bench.rs diff --git a/rust/fate-tensor/benches/paillier_bench.rs b/rust/tensor/rust_paillier/benches/paillier_bench.rs similarity index 100% rename from rust/fate-tensor/benches/paillier_bench.rs rename to rust/tensor/rust_paillier/benches/paillier_bench.rs diff --git a/rust/fate-tensor/pyproject.toml b/rust/tensor/rust_paillier/pyproject.toml similarity index 54% rename from rust/fate-tensor/pyproject.toml rename to rust/tensor/rust_paillier/pyproject.toml index abfdd0b0ec..8bc22f0487 100644 --- a/rust/fate-tensor/pyproject.toml +++ b/rust/tensor/rust_paillier/pyproject.toml @@ -3,23 +3,10 @@ requires = ["maturin>=0.12,<0.13"] build-backend = "maturin" [project] -name = "fate-tensor" +name = "rust_paillier" requires-python = ">=3.6" classifiers = [ "Programming Language :: Rust", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -[tool.pyright] -include = [] -exclude = [ - "**/__pycache__", - ".venv/" - ] -venvPath = "/Users/sage/MEGA/FATE/" -venv = "venv" -reportMissingImports = true -reportMissingTypeStubs = false -executionEnvironments = [ - { root = "."}, -] diff --git a/rust/tensor/rust_paillier/rust_paillier/__init__.py b/rust/tensor/rust_paillier/rust_paillier/__init__.py new file mode 100644 index 0000000000..86fa586d02 --- /dev/null +++ b/rust/tensor/rust_paillier/rust_paillier/__init__.py @@ -0,0 +1 @@ +from .rust_paillier import * diff --git a/rust/fate-tensor/fate_tensor/__init__.pyi b/rust/tensor/rust_paillier/rust_paillier/__init__.pyi similarity index 100% rename from rust/fate-tensor/fate_tensor/__init__.pyi rename to rust/tensor/rust_paillier/rust_paillier/__init__.pyi diff --git a/rust/fate-tensor/fate_tensor/par/__init__.py b/rust/tensor/rust_paillier/rust_paillier/par/__init__.py similarity index 100% rename from rust/fate-tensor/fate_tensor/par/__init__.py rename to rust/tensor/rust_paillier/rust_paillier/par/__init__.py diff --git a/rust/fate-tensor/fate_tensor/par/__init__.pyi b/rust/tensor/rust_paillier/rust_paillier/par/__init__.pyi similarity index 100% rename from rust/fate-tensor/fate_tensor/par/__init__.pyi rename to rust/tensor/rust_paillier/rust_paillier/par/__init__.pyi diff --git a/rust/fate-tensor/src/block/matmul.rs b/rust/tensor/rust_paillier/src/block/matmul.rs similarity index 100% rename from rust/fate-tensor/src/block/matmul.rs rename to rust/tensor/rust_paillier/src/block/matmul.rs diff --git a/rust/fate-tensor/src/block/mod.rs b/rust/tensor/rust_paillier/src/block/mod.rs similarity index 100% rename from rust/fate-tensor/src/block/mod.rs rename to rust/tensor/rust_paillier/src/block/mod.rs diff --git a/rust/fate-tensor/src/cb.rs b/rust/tensor/rust_paillier/src/cb.rs similarity index 100% rename from rust/fate-tensor/src/cb.rs rename to rust/tensor/rust_paillier/src/cb.rs diff --git a/rust/fate-tensor/src/fixedpoint/coder.rs b/rust/tensor/rust_paillier/src/fixedpoint/coder.rs similarity index 100% rename from rust/fate-tensor/src/fixedpoint/coder.rs rename to rust/tensor/rust_paillier/src/fixedpoint/coder.rs diff --git a/rust/fate-tensor/src/fixedpoint/frexp.rs b/rust/tensor/rust_paillier/src/fixedpoint/frexp.rs similarity index 100% rename from rust/fate-tensor/src/fixedpoint/frexp.rs rename to rust/tensor/rust_paillier/src/fixedpoint/frexp.rs diff --git a/rust/fate-tensor/src/fixedpoint/mod.rs b/rust/tensor/rust_paillier/src/fixedpoint/mod.rs similarity index 100% rename from rust/fate-tensor/src/fixedpoint/mod.rs rename to rust/tensor/rust_paillier/src/fixedpoint/mod.rs diff --git a/rust/fate-tensor/src/lib.rs b/rust/tensor/rust_paillier/src/lib.rs similarity index 98% rename from rust/fate-tensor/src/lib.rs rename to rust/tensor/rust_paillier/src/lib.rs index 60feeb91c0..9ba858324a 100644 --- a/rust/fate-tensor/src/lib.rs +++ b/rust/tensor/rust_paillier/src/lib.rs @@ -14,15 +14,15 @@ use pyo3::types::PyBytes; /// /// we need `new` method with zero argument (Option::None) /// for unpickle to work. -#[pyclass(module = "fate_tensor")] +#[pyclass(module = "rust_paillier")] pub struct Cipherblock(Option); -#[pyclass(module = "fate_tensor")] +#[pyclass(module = "rust_paillier")] pub struct PK { pk: fixedpoint::PK, } -#[pyclass(module = "fate_tensor")] +#[pyclass(module = "rust_paillier")] pub struct SK { sk: fixedpoint::SK, } @@ -237,7 +237,7 @@ impl Cipherblock { } #[pymodule] -fn fate_tensor(_py: Python, m: &PyModule) -> PyResult<()> { +fn rust_paillier(_py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/rust/fate-tensor/src/math/mod.rs b/rust/tensor/rust_paillier/src/math/mod.rs similarity index 100% rename from rust/fate-tensor/src/math/mod.rs rename to rust/tensor/rust_paillier/src/math/mod.rs diff --git a/rust/fate-tensor/src/math/rug/mod.rs b/rust/tensor/rust_paillier/src/math/rug/mod.rs similarity index 100% rename from rust/fate-tensor/src/math/rug/mod.rs rename to rust/tensor/rust_paillier/src/math/rug/mod.rs diff --git a/rust/fate-tensor/src/math/rug/ops.rs b/rust/tensor/rust_paillier/src/math/rug/ops.rs similarity index 100% rename from rust/fate-tensor/src/math/rug/ops.rs rename to rust/tensor/rust_paillier/src/math/rug/ops.rs diff --git a/rust/fate-tensor/src/math/rug/random.rs b/rust/tensor/rust_paillier/src/math/rug/random.rs similarity index 100% rename from rust/fate-tensor/src/math/rug/random.rs rename to rust/tensor/rust_paillier/src/math/rug/random.rs diff --git a/rust/fate-tensor/src/math/rug/serde.rs b/rust/tensor/rust_paillier/src/math/rug/serde.rs similarity index 100% rename from rust/fate-tensor/src/math/rug/serde.rs rename to rust/tensor/rust_paillier/src/math/rug/serde.rs diff --git a/rust/fate-tensor/src/paillier/mod.rs b/rust/tensor/rust_paillier/src/paillier/mod.rs similarity index 100% rename from rust/fate-tensor/src/paillier/mod.rs rename to rust/tensor/rust_paillier/src/paillier/mod.rs diff --git a/rust/fate-tensor/src/par/cb.rs b/rust/tensor/rust_paillier/src/par/cb.rs similarity index 100% rename from rust/fate-tensor/src/par/cb.rs rename to rust/tensor/rust_paillier/src/par/cb.rs diff --git a/rust/fate-tensor/src/par/mod.rs b/rust/tensor/rust_paillier/src/par/mod.rs similarity index 97% rename from rust/fate-tensor/src/par/mod.rs rename to rust/tensor/rust_paillier/src/par/mod.rs index 62ef3e094d..db0d0178d5 100644 --- a/rust/fate-tensor/src/par/mod.rs +++ b/rust/tensor/rust_paillier/src/par/mod.rs @@ -7,15 +7,15 @@ use crate::block; mod cb; -#[pyclass(module = "fate_tensor.par")] +#[pyclass(module = "rust_paillier.par")] pub struct Cipherblock(Option); -#[pyclass(module = "fate_tensor.par")] +#[pyclass(module = "rust_paillier.par")] pub struct PK { pk: fixedpoint::PK, } -#[pyclass(module = "fate_tensor.par")] +#[pyclass(module = "rust_paillier.par")] pub struct SK { sk: fixedpoint::SK, } @@ -230,6 +230,6 @@ pub(crate) fn register(py: Python, m: &PyModule) -> PyResult<()> { m.add_submodule(submodule_par)?; py.import("sys")? .getattr("modules")? - .set_item("fate_tensor.par", submodule_par)?; + .set_item("rust_paillier.par", submodule_par)?; Ok(()) } diff --git a/rust/fate-tensor/tests/test_base.py b/rust/tensor/rust_paillier/tests/test_base.py similarity index 98% rename from rust/fate-tensor/tests/test_base.py rename to rust/tensor/rust_paillier/tests/test_base.py index 7f3c36d6a1..539cc18157 100644 --- a/rust/fate-tensor/tests/test_base.py +++ b/rust/tensor/rust_paillier/tests/test_base.py @@ -8,7 +8,7 @@ def get_suites(): suites = [] - packages = ["fate_tensor", "fate_tensor.par"] + packages = ["rust_paillier", "rust_paillier.par"] for package in packages: module = importlib.import_module(package) suites.append(Suite(module.keygen)) From 4d6efd69d5a6f4d98e008f9dd01a3b1d9ecd11d6 Mon Sep 17 00:00:00 2001 From: weiwee Date: Wed, 27 Jul 2022 22:49:25 -0800 Subject: [PATCH 05/11] fix: fix pk/sk serde and distributed tensor Signed-off-by: weiwee --- python/fate_arch/tensor/_federation.py | 7 + python/fate_arch/tensor/_tensor.py | 163 ++++++++++++++---- .../tensor/impl/blocks/cpu_paillier_block.py | 4 +- .../blocks/multithread_cpu_paillier_block.py | 4 +- .../tensor/impl/tensor/distributed.py | 46 ++++- python/federatedml/ml/toy/enterpoint.py | 22 ++- rust/tensor/rust_paillier/src/lib.rs | 20 ++- 7 files changed, 213 insertions(+), 53 deletions(-) create mode 100644 python/fate_arch/tensor/_federation.py diff --git a/python/fate_arch/tensor/_federation.py b/python/fate_arch/tensor/_federation.py new file mode 100644 index 0000000000..3687d00f50 --- /dev/null +++ b/python/fate_arch/tensor/_federation.py @@ -0,0 +1,7 @@ +class FederationDeserializer: + def do_deserialize(self, ctx, party): + ... + + @classmethod + def make_frac_key(cls, base_key, frac_key): + return f"{base_key}__frac__{frac_key}" diff --git a/python/fate_arch/tensor/_tensor.py b/python/fate_arch/tensor/_tensor.py index feb30529f5..d0ee9e8b32 100644 --- a/python/fate_arch/tensor/_tensor.py +++ b/python/fate_arch/tensor/_tensor.py @@ -6,12 +6,14 @@ Callable, Generator, List, + Mapping, Optional, Tuple, TypeVar, Union, overload, ) +import typing import torch from fate_arch.common import Party @@ -19,8 +21,10 @@ from fate_arch.session import get_session from typing_extensions import Literal + from ._parties import Parties, PreludeParty from .abc.tensor import PHEDecryptorABC, PHEEncryptorABC, PHETensorABC +from ._federation import FederationDeserializer class NamespaceState: @@ -55,6 +59,13 @@ class Device(Enum): CPU = 1 GPU = 2 FPGA = 3 + CPU_Intel = 4 + + +class Distributed(Enum): + NONE = 1 + EGGROLL = 2 + SPARK = 3 T = TypeVar("T") @@ -115,7 +126,6 @@ def unwrap(self, check: Optional[Callable[[T], bool]] = None) -> List[T]: class _ContextInside: def __init__(self, cpn_input) -> None: - self._device = None self._push_gc_dict = {} self._pull_gc_dict = {} @@ -126,6 +136,13 @@ def __init__(self, cpn_input) -> None: self._parameters = cpn_input.parameters self._flow_feeded_parameters = cpn_input.flow_feeded_parameters + self._device = Device.CPU + self._distributed = Distributed.EGGROLL + + @property + def device(self): + return self._device + @property def is_guest(self): return self._roles["local"]["role"] == "guest" @@ -175,11 +192,7 @@ def from_cpn_input(cls, cpn_input): return Context(states, namespace) def describe(self): - return json.dumps( - dict( - states=self._inside.describe(), - ) - ) + return json.dumps(dict(states=self._inside.describe(),)) @property def party(self): @@ -205,46 +218,52 @@ def is_host(self): def is_arbiter(self): return self._inside.is_guest - def device_init(self, **kwargs): - self._device = Device.CPU - + @property def device(self) -> Device: - if self._device is None: - raise RuntimeError(f"init device first") - return self._device + return self._inside.device + + @property + def distributed(self) -> Distributed: + return self._inside._distributed def current_namespace(self): return self._namespace_state.get_namespce() def push(self, target: Parties, key: str, value): - return self._push(target, key, value) + return self._push(target.get_parties(), key, value) - def pull( - self, - source: Literal[PreludeParty.GUEST, PreludeParty.HOST, PreludeParty.ARBITER], - key: str, - ) -> Future: - return Future(self._pull(source, key)[0]) + def pull(self, source: Parties, key: str,) -> Future: + return Future(self._pull(source.get_parties(), key)[0]) def pulls(self, source: Parties, key: str) -> Futures: - return Futures(self._pull(source, key)) + return Futures(self._pull(source.get_parties(), key)) - def _push(self, parties: Parties, key, value): - get_session().federation.remote( - v=value, - name=key, - tag=self.current_namespace(), - parties=parties.get_parties(), - gc=self._inside.get_or_set_push_gc(key), - ) + def _push(self, parties: typing.List[Party], key, value): + if hasattr(value, "__federation_hook__"): + value.__federation_hook__(self, key, parties) + else: + get_session().federation.remote( + v=value, + name=key, + tag=self.current_namespace(), + parties=parties, + gc=self._inside.get_or_set_push_gc(key), + ) - def _pull(self, parties: Parties, key): - return get_session().federation.get( + def _pull(self, parties: typing.List[Party], key): + raw_values = get_session().federation.get( name=key, tag=self.current_namespace(), - parties=parties.get_parties(), + parties=parties, gc=self._inside.get_or_set_pull_gc(key), ) + values = [] + for party, raw_value in zip(parties, raw_values): + if isinstance(raw_value, FederationDeserializer): + values.append(raw_value.do_deserialize(self, party)) + else: + values.append(raw_value) + return values @overload def keygen( @@ -257,16 +276,54 @@ def keygen(self, kind: CipherKind, **kwargs) -> Any: ... def keygen(self, kind, key_length: int, **kwargs): + # TODO: exploring expansion eechanisms if kind == CipherKind.PHE or kind == CipherKind.PHE_PAILLIER: - if self._device == Device.CPU: - from .impl.tensor.multithread import PaillierPHECipherLocal + if self.distributed == Distributed.NONE: + if self.device == Device.CPU: + from .impl.tensor.multithread_cpu_tensor import ( + PaillierPHECipherLocal, + ) + + encryptor, decryptor = PaillierPHECipherLocal().keygen( + key_length=key_length + ) + return PHEEncryptor(encryptor), PHEDecryptor(decryptor) + if self.distributed == Distributed.EGGROLL: + if self.device == Device.CPU: + from .impl.tensor.distributed import PaillierPHECipherDistributed + + encryptor, decryptor = PaillierPHECipherDistributed().keygen( + key_length=key_length + ) + return PHEEncryptor(encryptor), PHEDecryptor(decryptor) + + raise NotImplementedError( + f"keygen for kind<{kind}>-distributed<{self.distributed}>-device<{self.device}> is not implemented" + ) - encryptor, decryptor = PaillierPHECipherLocal().keygen( - key_length=key_length - ) - return PHEEncryptor(encryptor), PHEDecryptor(decryptor) + def random_tensor(self, shape, num_partition = 1) -> "FPTensor": + if self.distributed == Distributed.NONE: + return FPTensor(self, torch.rand(shape)) else: - raise NotImplementedError(f"keygen for kind `{kind}` is not implemented") + from fate_arch.session import computing_session + from fate_arch.tensor.impl.tensor.distributed import FPTensorDistributed + + parts = [] + last_dim = shape[-1] + for i in range(num_partition): + if i == num_partition - 1: + parts.append(torch.tensor((*shape[:-1], last_dim))) + else: + parts.append(torch.tensor((*shape[:-1], shape[-1] / num_partition))) + last_dim -= shape[-1] / num_partition + return FPTensor( + self, + FPTensorDistributed( + computing_session.parallelize( + parts, include_key=False, partition=num_partition + ) + ), + ) def create_tensor(self, tensor: torch.Tensor) -> "FPTensor": @@ -429,6 +486,12 @@ def _binary_op(self, other, func): def T(self): return FPTensor(self._ctx, self._tensor.T) + def __federation_hook__(self, ctx, key, parties): + deserializer = FPTensorFederationDeserializer(key) + # 1. remote deserializer with objs + ctx._push(parties, key, deserializer) + # 2. remote table + ctx._push(parties, deserializer.table_key, self._tensor) class PHETensor: def __init__(self, ctx: Context, tensor: PHETensorABC) -> None: @@ -485,3 +548,27 @@ def _binary_op(self, other, func): elif isinstance(other, (int, float)): return PHETensor(self._ctx, func(other)) return NotImplemented + + def __federation_hook__(self, ctx, key, parties): + deserializer = PHETensorFederationDeserializer(key) + # 1. remote deserializer with objs + ctx._push(parties, key, deserializer) + # 2. remote table + ctx._push(parties, deserializer.table_key, self._tensor) + + +class PHETensorFederationDeserializer(FederationDeserializer): + def __init__(self, key) -> None: + self.table_key = self.make_frac_key(key, "tensor") + + def do_deserialize(self, ctx: Context, party: Party) -> PHETensor: + tensor = ctx._pull([party], self.table_key)[0] + return PHETensor(ctx, tensor) + +class FPTensorFederationDeserializer(FederationDeserializer): + def __init__(self, key) -> None: + self.table_key = self.make_frac_key(key, "tensor") + + def do_deserialize(self, ctx: Context, party: Party) -> FPTensor: + tensor = ctx._pull([party], self.table_key)[0] + return FPTensor(ctx, tensor) diff --git a/python/fate_arch/tensor/impl/blocks/cpu_paillier_block.py b/python/fate_arch/tensor/impl/blocks/cpu_paillier_block.py index bddd8ed727..d592762430 100644 --- a/python/fate_arch/tensor/impl/blocks/cpu_paillier_block.py +++ b/python/fate_arch/tensor/impl/blocks/cpu_paillier_block.py @@ -15,7 +15,7 @@ # -import fate_tensor.par +import rust_paillier import torch from ._metaclass import ( @@ -44,7 +44,7 @@ class BlockPaillierDecryptor( class BlockPaillierCipher( metaclass=phe_keygen_metaclass( - BlockPaillierEncryptor, BlockPaillierDecryptor, fate_tensor.par.keygen + BlockPaillierEncryptor, BlockPaillierDecryptor, rust_paillier.keygen ) ): pass diff --git a/python/fate_arch/tensor/impl/blocks/multithread_cpu_paillier_block.py b/python/fate_arch/tensor/impl/blocks/multithread_cpu_paillier_block.py index 8864750a59..1e89f92e5c 100644 --- a/python/fate_arch/tensor/impl/blocks/multithread_cpu_paillier_block.py +++ b/python/fate_arch/tensor/impl/blocks/multithread_cpu_paillier_block.py @@ -15,7 +15,7 @@ # -import fate_tensor +import rust_paillier.par import torch from ._metaclass import ( @@ -44,7 +44,7 @@ class BlockPaillierDecryptor( class BlockPaillierCipher( metaclass=phe_keygen_metaclass( - BlockPaillierEncryptor, BlockPaillierDecryptor, fate_tensor.keygen + BlockPaillierEncryptor, BlockPaillierDecryptor, rust_paillier.par.keygen ) ): pass diff --git a/python/fate_arch/tensor/impl/tensor/distributed.py b/python/fate_arch/tensor/impl/tensor/distributed.py index 2326e8f862..f121d6d1bd 100644 --- a/python/fate_arch/tensor/impl/tensor/distributed.py +++ b/python/fate_arch/tensor/impl/tensor/distributed.py @@ -8,6 +8,8 @@ PHEEncryptorABC, PHETensorABC, ) +from ..._federation import FederationDeserializer +from ..._tensor import Context, Party Numeric = typing.Union[int, float] @@ -74,6 +76,12 @@ def __rmatmul__(self, other: "FPTensorDistributed") -> "FPTensorDistributed": # todo: fix ... + def __federation_hook__(self, ctx, key, parties): + deserializer = FPTensorFederationDeserializer(key) + # 1. remote deserializer with objs + ctx._push(parties, key, deserializer) + # 2. remote table + ctx._push(parties, deserializer.table_key, self._blocks_table) class PHETensorDistributed(PHETensorABC): def __init__(self, blocks_table) -> None: @@ -131,8 +139,14 @@ def T(self) -> "PHETensorDistributed": return transposed def serialize(self): - # TODO: impl me - ... + return self._blocks_table + + def deserialize(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + + def __getstates__(self): + return {"_is_transpose": self._is_transpose} def _binary_op(self, other, func_name): if isinstance(other, (FPTensorDistributed, PHETensorDistributed)): @@ -161,6 +175,12 @@ def _binary_op_limited(self, other, func_name): ) return NotImplemented + def __federation_hook__(self, ctx, key, parties): + deserializer = PHETensorFederationDeserializer(key, self._is_transpose) + # 1. remote deserializer with objs + ctx._push(parties, key, deserializer) + # 2. remote table + ctx._push(parties, deserializer.table_key, self._blocks_table) class PaillierPHEEncryptorDistributed(PHEEncryptorABC): def __init__(self, block_encryptor) -> None: @@ -187,10 +207,30 @@ class PaillierPHECipherDistributed(PHECipherABC): def keygen( cls, **kwargs ) -> typing.Tuple[PaillierPHEEncryptorDistributed, PaillierPHEDecryptorDistributed]: - from ..blocks.python_paillier_block import BlockPaillierCipher + from ..blocks.cpu_paillier_block import BlockPaillierCipher block_encrytor, block_decryptor = BlockPaillierCipher.keygen(**kwargs) return ( PaillierPHEEncryptorDistributed(block_encrytor), PaillierPHEDecryptorDistributed(block_decryptor), ) + +class PHETensorFederationDeserializer(FederationDeserializer): + def __init__(self, key, is_transpose) -> None: + self.table_key = self.make_frac_key(key, "table") + self.is_transpose = is_transpose + + def do_deserialize(self, ctx: Context, party: Party) -> PHETensorDistributed: + table = ctx._pull([party], self.table_key)[0] + tensor = PHETensorDistributed(table) + tensor._is_transpose = self.is_transpose + return tensor + +class FPTensorFederationDeserializer(FederationDeserializer): + def __init__(self, key) -> None: + self.table_key = self.make_frac_key(key, "table") + + def do_deserialize(self, ctx: Context, party: Party) -> FPTensorDistributed: + table = ctx._pull([party], self.table_key)[0] + tensor = FPTensorDistributed(table) + return tensor diff --git a/python/federatedml/ml/toy/enterpoint.py b/python/federatedml/ml/toy/enterpoint.py index 581a44d1fa..39c37ec84d 100644 --- a/python/federatedml/ml/toy/enterpoint.py +++ b/python/federatedml/ml/toy/enterpoint.py @@ -24,7 +24,6 @@ def _init_model(self): def run(self, cpn_input): ctx = Context.from_cpn_input(cpn_input) LOGGER.info(ctx.describe()) - ctx.device_init() return self._run(ctx, cpn_input) def _run(self, ctx: Context, cpn_input): @@ -33,7 +32,7 @@ def _run(self, ctx: Context, cpn_input): self._init_runtime_parameters(cpn_input) LOGGER.info("begin to make guest data") - self.a = ctx.create_tensor(torch.rand((self.data_num, self.feature_num))) + self.a = ctx.random_tensor((self.data_num, self.feature_num)) LOGGER.info("keygen") self.pk, self.sk = ctx.keygen(CipherKind.PHE, 1024) @@ -48,7 +47,9 @@ def _run(self, ctx: Context, cpn_input): self.eb = ctx.pull(HOST, "host_cipher").unwrap_phe_tensor() LOGGER.info("begin to get matmul of guest and host") - self.es_guest = self.a.T @ self.eb + self.es_guest = self.a + self.eb + # LOGGER.info("begin to get matmul of guest and host") + # self.es_guest = self.a.T @ self.eb LOGGER.info("send encrypted matmul to host") ctx.push(HOST, "guest_matmul_encrypted", self.es_guest) @@ -63,7 +64,13 @@ def _run(self, ctx: Context, cpn_input): self.s_guest = ctx.pull(HOST, "host_matmul").unwrap_tensor() LOGGER.info("assert matmul close") - assert torch.allclose(self.s_host._tensor.T, self.s_guest._tensor) + sb = self.s_host._tensor._blocks_table.count() + sa = self.s_guest._tensor._blocks_table.count() + assert sa == sb + a = list(self.s_guest._tensor._blocks_table.collect())[0] + b = list(self.s_host._tensor._blocks_table.collect())[0] + assert torch.allclose(a[1], b[1]) + # assert torch.allclose(self.s_host._tensor.T, self.s_guest._tensor) return ComponentOutput(self.save_data(), self.export_model(), self.save_cache()) @@ -85,7 +92,6 @@ def _init_model(self): def run(self, cpn_input): ctx = Context.from_cpn_input(cpn_input) - ctx.device_init() LOGGER.info(ctx.describe()) return self._run(ctx, cpn_input) @@ -94,7 +100,7 @@ def _run(self, ctx: Context, cpn_input): self._init_runtime_parameters(cpn_input) LOGGER.info("begin to make host data") - self.b = ctx.create_tensor(torch.rand((self.data_num, self.feature_num))) + self.b = ctx.random_tensor((self.data_num, self.feature_num)) with ctx.iter_namespaces(10, prefix_name="tree_") as iteration: for i, _ in enumerate(iteration): @@ -113,7 +119,9 @@ def _run(self, ctx: Context, cpn_input): self.ea = ctx.pull(GUEST, "guest_cipher").unwrap_phe_tensor() LOGGER.info("begin to get matmul of host and guest") - self.es_host = self.b.T @ self.ea + self.es_host = self.b + self.ea + # LOGGER.info("begin to get matmul of host and guest") + # self.es_host = self.b.T @ self.ea LOGGER.info("send encrypted matmul to guest") ctx.push(GUEST, "host_matmul_encrypted", self.es_host) diff --git a/rust/tensor/rust_paillier/src/lib.rs b/rust/tensor/rust_paillier/src/lib.rs index 9ba858324a..d08aa2b4ca 100644 --- a/rust/tensor/rust_paillier/src/lib.rs +++ b/rust/tensor/rust_paillier/src/lib.rs @@ -19,8 +19,26 @@ pub struct Cipherblock(Option); #[pyclass(module = "rust_paillier")] pub struct PK { - pk: fixedpoint::PK, + pk: Option, } +#[pymethods] +impl PK { + #[new] + fn __new__() -> Self { + PK(None) + } + pub fn __getstate__(&self, py: Python) -> PyResult { + Ok(PyBytes::new(py, &serialize(&self.0).unwrap()).to_object(py)) + } + pub fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> { + match state.extract::<&PyBytes>(py) { + Ok(s) => { + self.0 = deserialize(s.as_bytes()).unwrap(); + Ok(()) + } + Err(e) => Err(e), + } + } #[pyclass(module = "rust_paillier")] pub struct SK { From a9f91736c2c11ea9c2ef74810e813af58c24332a Mon Sep 17 00:00:00 2001 From: weiwee Date: Wed, 27 Jul 2022 22:55:17 -0800 Subject: [PATCH 06/11] format: makes pep8 happy Signed-off-by: weiwee --- python/fate_arch/tensor/_tensor.py | 4 +++- python/fate_arch/tensor/impl/tensor/distributed.py | 4 ++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/python/fate_arch/tensor/_tensor.py b/python/fate_arch/tensor/_tensor.py index d0ee9e8b32..aed01120c9 100644 --- a/python/fate_arch/tensor/_tensor.py +++ b/python/fate_arch/tensor/_tensor.py @@ -301,7 +301,7 @@ def keygen(self, kind, key_length: int, **kwargs): f"keygen for kind<{kind}>-distributed<{self.distributed}>-device<{self.device}> is not implemented" ) - def random_tensor(self, shape, num_partition = 1) -> "FPTensor": + def random_tensor(self, shape, num_partition=1) -> "FPTensor": if self.distributed == Distributed.NONE: return FPTensor(self, torch.rand(shape)) else: @@ -493,6 +493,7 @@ def __federation_hook__(self, ctx, key, parties): # 2. remote table ctx._push(parties, deserializer.table_key, self._tensor) + class PHETensor: def __init__(self, ctx: Context, tensor: PHETensorABC) -> None: self._tensor = tensor @@ -565,6 +566,7 @@ def do_deserialize(self, ctx: Context, party: Party) -> PHETensor: tensor = ctx._pull([party], self.table_key)[0] return PHETensor(ctx, tensor) + class FPTensorFederationDeserializer(FederationDeserializer): def __init__(self, key) -> None: self.table_key = self.make_frac_key(key, "tensor") diff --git a/python/fate_arch/tensor/impl/tensor/distributed.py b/python/fate_arch/tensor/impl/tensor/distributed.py index f121d6d1bd..762715bd33 100644 --- a/python/fate_arch/tensor/impl/tensor/distributed.py +++ b/python/fate_arch/tensor/impl/tensor/distributed.py @@ -83,6 +83,7 @@ def __federation_hook__(self, ctx, key, parties): # 2. remote table ctx._push(parties, deserializer.table_key, self._blocks_table) + class PHETensorDistributed(PHETensorABC): def __init__(self, blocks_table) -> None: """ @@ -182,6 +183,7 @@ def __federation_hook__(self, ctx, key, parties): # 2. remote table ctx._push(parties, deserializer.table_key, self._blocks_table) + class PaillierPHEEncryptorDistributed(PHEEncryptorABC): def __init__(self, block_encryptor) -> None: self._block_encryptor = block_encryptor @@ -215,6 +217,7 @@ def keygen( PaillierPHEDecryptorDistributed(block_decryptor), ) + class PHETensorFederationDeserializer(FederationDeserializer): def __init__(self, key, is_transpose) -> None: self.table_key = self.make_frac_key(key, "table") @@ -226,6 +229,7 @@ def do_deserialize(self, ctx: Context, party: Party) -> PHETensorDistributed: tensor._is_transpose = self.is_transpose return tensor + class FPTensorFederationDeserializer(FederationDeserializer): def __init__(self, key) -> None: self.table_key = self.make_frac_key(key, "table") From 416999c7ff59eb897229094b89fa662025eacc68 Mon Sep 17 00:00:00 2001 From: weiwee Date: Tue, 2 Aug 2022 01:48:09 -0800 Subject: [PATCH 07/11] fix: pk/sk pickle and compare Signed-off-by: weiwee --- rust/tensor/rust_paillier/src/cb.rs | 5 +- rust/tensor/rust_paillier/src/lib.rs | 83 ++++++++++++++----- rust/tensor/rust_paillier/src/par/cb.rs | 9 +-- rust/tensor/rust_paillier/src/par/mod.rs | 84 ++++++++++++++++++-- rust/tensor/rust_paillier/tests/test_base.py | 7 ++ 5 files changed, 155 insertions(+), 33 deletions(-) diff --git a/rust/tensor/rust_paillier/src/cb.rs b/rust/tensor/rust_paillier/src/cb.rs index 08c2a9caf8..56f12eef54 100644 --- a/rust/tensor/rust_paillier/src/cb.rs +++ b/rust/tensor/rust_paillier/src/cb.rs @@ -116,7 +116,6 @@ impl Cipherblock { block::Cipherblock::rmatmul_plaintext_ix2, ArrayView2 ); - } impl Cipherblock { @@ -146,12 +145,12 @@ impl Cipherblock { impl SK { pub fn decrypt_array(&self, a: &Cipherblock) -> ArrayD { let array = a.0.as_ref().unwrap(); - self.sk.decrypt_array(array) + self.as_ref().decrypt_array(array) } } impl PK { pub fn encrypt_array(&self, array: ArrayViewD) -> Cipherblock { - Cipherblock::new(self.pk.encrypt_array(array)) + Cipherblock::new(self.as_ref().encrypt_array(array)) } } diff --git a/rust/tensor/rust_paillier/src/lib.rs b/rust/tensor/rust_paillier/src/lib.rs index d08aa2b4ca..117d5f7076 100644 --- a/rust/tensor/rust_paillier/src/lib.rs +++ b/rust/tensor/rust_paillier/src/lib.rs @@ -7,6 +7,7 @@ mod par; use bincode::{deserialize, serialize}; use numpy::{IntoPyArray, PyArrayDyn, PyReadonlyArray1, PyReadonlyArray2, PyReadonlyArrayDyn}; +use pyo3::exceptions::PyTypeError; use pyo3::prelude::*; use pyo3::types::PyBytes; @@ -21,34 +22,33 @@ pub struct Cipherblock(Option); pub struct PK { pk: Option, } -#[pymethods] impl PK { - #[new] - fn __new__() -> Self { - PK(None) + fn new(pk: fixedpoint::PK) -> Self { + Self { pk: Some(pk) } } - pub fn __getstate__(&self, py: Python) -> PyResult { - Ok(PyBytes::new(py, &serialize(&self.0).unwrap()).to_object(py)) - } - pub fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> { - match state.extract::<&PyBytes>(py) { - Ok(s) => { - self.0 = deserialize(s.as_bytes()).unwrap(); - Ok(()) - } - Err(e) => Err(e), - } + fn as_ref(&self) -> &fixedpoint::PK { + self.pk.as_ref().unwrap() } +} #[pyclass(module = "rust_paillier")] pub struct SK { - sk: fixedpoint::SK, + sk: Option, +} + +impl SK { + fn new(sk: fixedpoint::SK) -> Self { + Self { sk: Some(sk) } + } + fn as_ref(&self) -> &fixedpoint::SK { + self.sk.as_ref().unwrap() + } } #[pyfunction] fn keygen(bit_size: u32) -> (PK, SK) { let (sk, pk) = fixedpoint::keygen(bit_size); - (PK { pk }, SK { sk }) + (PK::new(pk), SK::new(sk)) } /// public key for paillier system used to encrypt arrays @@ -56,6 +56,30 @@ fn keygen(bit_size: u32) -> (PK, SK) { /// Notes: we could not use Generics Types or rule macro here, sad. #[pymethods] impl PK { + #[new] + fn __new__() -> Self { + Self { pk: None } + } + pub fn __getstate__(&self, py: Python) -> PyResult { + Ok(PyBytes::new(py, &serialize(self.as_ref()).unwrap()).to_object(py)) + } + pub fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> { + match state.extract::<&PyBytes>(py) { + Ok(s) => { + self.pk = Some(deserialize(s.as_bytes()).unwrap()); + Ok(()) + } + Err(e) => Err(e), + } + } + pub fn __richcmp__(&self, other: &PK, cmp: pyo3::basic::CompareOp) -> PyResult { + match cmp { + pyo3::basic::CompareOp::Eq => Ok(self.as_ref() == other.as_ref()), + _ => Err(PyTypeError::new_err( + "not supported between instances PK and PK", + )), + } + } fn encrypt_f64(&self, a: PyReadonlyArrayDyn) -> Cipherblock { self.encrypt_array(a.as_array()) } @@ -75,6 +99,30 @@ impl PK { /// Notes: we could not use Generics Types or rule macro here, sad. #[pymethods] impl SK { + #[new] + fn __new__() -> Self { + Self { sk: None } + } + pub fn __getstate__(&self, py: Python) -> PyResult { + Ok(PyBytes::new(py, &serialize(self.as_ref()).unwrap()).to_object(py)) + } + pub fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> { + match state.extract::<&PyBytes>(py) { + Ok(s) => { + self.sk = Some(deserialize(s.as_bytes()).unwrap()); + Ok(()) + } + Err(e) => Err(e), + } + } + pub fn __richcmp__(&self, other: &SK, cmp: pyo3::basic::CompareOp) -> PyResult { + match cmp { + pyo3::basic::CompareOp::Eq => Ok(self.as_ref() == other.as_ref()), + _ => Err(PyTypeError::new_err( + "not supported between instances PK and PK", + )), + } + } fn decrypt_f64<'py>(&self, a: &Cipherblock, py: Python<'py>) -> &'py PyArrayDyn { self.decrypt_array(a).into_pyarray(py) } @@ -252,7 +300,6 @@ impl Cipherblock { pub fn mean(&self) -> Cipherblock { self.sum_cb() } - } #[pymodule] fn rust_paillier(_py: Python, m: &PyModule) -> PyResult<()> { diff --git a/rust/tensor/rust_paillier/src/par/cb.rs b/rust/tensor/rust_paillier/src/par/cb.rs index 33b902dc16..38a8193a5d 100644 --- a/rust/tensor/rust_paillier/src/par/cb.rs +++ b/rust/tensor/rust_paillier/src/par/cb.rs @@ -153,15 +153,12 @@ impl Cipherblock { impl SK { pub fn decrypt_array(&self, a: &Cipherblock) -> ArrayD { let array = a.0.as_ref().unwrap(); - self.sk.decrypt_array_par(array) + self.as_ref().decrypt_array_par(array) } } impl PK { - pub fn encrypt_array( - &self, - array: ArrayViewD, - ) -> Cipherblock { - Cipherblock::new(self.pk.encrypt_array_par(array)) + pub fn encrypt_array(&self, array: ArrayViewD) -> Cipherblock { + Cipherblock::new(self.as_ref().encrypt_array_par(array)) } } diff --git a/rust/tensor/rust_paillier/src/par/mod.rs b/rust/tensor/rust_paillier/src/par/mod.rs index db0d0178d5..81cb7cbd3b 100644 --- a/rust/tensor/rust_paillier/src/par/mod.rs +++ b/rust/tensor/rust_paillier/src/par/mod.rs @@ -1,9 +1,10 @@ +use crate::block; +use crate::fixedpoint; use bincode::{deserialize, serialize}; use numpy::{IntoPyArray, PyArrayDyn, PyReadonlyArray1, PyReadonlyArray2, PyReadonlyArrayDyn}; +use pyo3::exceptions::PyTypeError; use pyo3::prelude::*; use pyo3::types::PyBytes; -use crate::fixedpoint; -use crate::block; mod cb; @@ -12,27 +13,71 @@ pub struct Cipherblock(Option); #[pyclass(module = "rust_paillier.par")] pub struct PK { - pk: fixedpoint::PK, + pk: Option, +} +impl PK { + fn new(pk: fixedpoint::PK) -> Self { + Self { pk: Some(pk) } + } + fn as_ref(&self) -> &fixedpoint::PK { + self.pk.as_ref().unwrap() + } } #[pyclass(module = "rust_paillier.par")] pub struct SK { - sk: fixedpoint::SK, + sk: Option, +} + +impl SK { + fn new(sk: fixedpoint::SK) -> Self { + Self { sk: Some(sk) } + } + fn as_ref(&self) -> &fixedpoint::SK { + self.sk.as_ref().unwrap() + } } #[pyfunction] fn keygen(bit_size: u32) -> (PK, SK) { let (sk, pk) = fixedpoint::keygen(bit_size); - (PK { pk }, SK { sk }) + (PK::new(pk), SK::new(sk)) } #[pyfunction] fn set_num_threads(num_threads: usize) { - rayon::ThreadPoolBuilder::new().num_threads(num_threads).build_global().unwrap(); + rayon::ThreadPoolBuilder::new() + .num_threads(num_threads) + .build_global() + .unwrap(); } #[pymethods] impl PK { + #[new] + fn __new__() -> Self { + Self { pk: None } + } + pub fn __getstate__(&self, py: Python) -> PyResult { + Ok(PyBytes::new(py, &serialize(self.as_ref()).unwrap()).to_object(py)) + } + pub fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> { + match state.extract::<&PyBytes>(py) { + Ok(s) => { + self.pk = Some(deserialize(s.as_bytes()).unwrap()); + Ok(()) + } + Err(e) => Err(e), + } + } + pub fn __richcmp__(&self, other: &PK, cmp: pyo3::basic::CompareOp) -> PyResult { + match cmp { + pyo3::basic::CompareOp::Eq => Ok(self.as_ref() == other.as_ref()), + _ => Err(PyTypeError::new_err( + "not supported between instances PK and PK", + )), + } + } fn encrypt_f64(&self, a: PyReadonlyArrayDyn) -> Cipherblock { self.encrypt_array(a.as_array()) } @@ -49,6 +94,30 @@ impl PK { #[pymethods] impl SK { + #[new] + fn __new__() -> Self { + Self { sk: None } + } + pub fn __getstate__(&self, py: Python) -> PyResult { + Ok(PyBytes::new(py, &serialize(self.as_ref()).unwrap()).to_object(py)) + } + pub fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> { + match state.extract::<&PyBytes>(py) { + Ok(s) => { + self.sk = Some(deserialize(s.as_bytes()).unwrap()); + Ok(()) + } + Err(e) => Err(e), + } + } + pub fn __richcmp__(&self, other: &SK, cmp: pyo3::basic::CompareOp) -> PyResult { + match cmp { + pyo3::basic::CompareOp::Eq => Ok(self.as_ref() == other.as_ref()), + _ => Err(PyTypeError::new_err( + "not supported between instances PK and PK", + )), + } + } fn decrypt_f64<'py>(&self, a: &Cipherblock, py: Python<'py>) -> &'py PyArrayDyn { self.decrypt_array(a).into_pyarray(py) } @@ -227,6 +296,9 @@ pub(crate) fn register(py: Python, m: &PyModule) -> PyResult<()> { let submodule_par = PyModule::new(py, "par")?; submodule_par.add_function(wrap_pyfunction!(keygen, submodule_par)?)?; submodule_par.add_function(wrap_pyfunction!(set_num_threads, submodule_par)?)?; + submodule_par.add_class::()?; + submodule_par.add_class::()?; + submodule_par.add_class::()?; m.add_submodule(submodule_par)?; py.import("sys")? .getattr("modules")? diff --git a/rust/tensor/rust_paillier/tests/test_base.py b/rust/tensor/rust_paillier/tests/test_base.py index 539cc18157..4163cb1b76 100644 --- a/rust/tensor/rust_paillier/tests/test_base.py +++ b/rust/tensor/rust_paillier/tests/test_base.py @@ -4,6 +4,7 @@ import cachetools import numpy as np import pytest +import pickle def get_suites(): @@ -66,6 +67,12 @@ def data(fp, index, shape=(3, 5), scalar=False) -> np.ndarray: return np.random.randint(low=-100, high=100, size=1, dtype=np.int32)[0] +@pytest.mark.parametrize("fp", ["f64"]) +def test_serde(suite: Suite, fp): + assert suite.pk == pickle.loads(pickle.dumps(suite.pk)) + assert suite.sk == pickle.loads(pickle.dumps(suite.sk)) + + @pytest.mark.parametrize("fp", ["f64", "f32", "i32", "i64"]) def test_cipher(suite: Suite, fp): e = suite.decrypt(fp, suite.encrypt(fp, data(fp, 0))) From 91686f808e03e677a5b6acd2998879a6ee57b651 Mon Sep 17 00:00:00 2001 From: weiwee Date: Tue, 2 Aug 2022 02:47:18 -0800 Subject: [PATCH 08/11] feat: add shape Signed-off-by: weiwee --- python/fate_arch/tensor/impl/tensor/_metaclass.py | 6 ++++++ rust/tensor/rust_paillier/rust_paillier/__init__.pyi | 2 ++ rust/tensor/rust_paillier/rust_paillier/par/__init__.pyi | 2 ++ rust/tensor/rust_paillier/src/lib.rs | 5 ++++- rust/tensor/rust_paillier/src/par/mod.rs | 4 ++++ 5 files changed, 18 insertions(+), 1 deletion(-) diff --git a/python/fate_arch/tensor/impl/tensor/_metaclass.py b/python/fate_arch/tensor/impl/tensor/_metaclass.py index f27166efab..45b8a4a11f 100644 --- a/python/fate_arch/tensor/impl/tensor/_metaclass.py +++ b/python/fate_arch/tensor/impl/tensor/_metaclass.py @@ -19,6 +19,12 @@ def __init__(self, block) -> None: setattr(phe_cls, "__init__", __init__) + @property + def shape(self): + return self._block.shape + + setattr(phe_cls, "shape", shape) + @property def T(self) -> phe_cls: transposed = phe_cls(self._block) diff --git a/rust/tensor/rust_paillier/rust_paillier/__init__.pyi b/rust/tensor/rust_paillier/rust_paillier/__init__.pyi index a4a8509100..f2c6525150 100644 --- a/rust/tensor/rust_paillier/rust_paillier/__init__.pyi +++ b/rust/tensor/rust_paillier/rust_paillier/__init__.pyi @@ -4,6 +4,8 @@ import numpy as np import numpy.typing as npt class Cipherblock: + @property + def shape(self) -> typing.List[int]: ... def add_cipherblock(self, other: Cipherblock) -> Cipherblock: ... def add_plaintext_f64(self, other: npt.NDArray[np.float64]) -> Cipherblock: ... def add_plaintext_f32(self, other: npt.NDArray[np.float32]) -> Cipherblock: ... diff --git a/rust/tensor/rust_paillier/rust_paillier/par/__init__.pyi b/rust/tensor/rust_paillier/rust_paillier/par/__init__.pyi index 7de2099ddb..1b24c95938 100644 --- a/rust/tensor/rust_paillier/rust_paillier/par/__init__.pyi +++ b/rust/tensor/rust_paillier/rust_paillier/par/__init__.pyi @@ -4,6 +4,8 @@ import numpy as np import numpy.typing as npt class Cipherblock: + @property + def shape(self) -> typing.List[int]: ... def add_cipherblock(self, other: Cipherblock) -> Cipherblock: ... def add_plaintext_f64(self, other: npt.NDArray[np.float64]) -> Cipherblock: ... def add_plaintext_f32(self, other: npt.NDArray[np.float32]) -> Cipherblock: ... diff --git a/rust/tensor/rust_paillier/src/lib.rs b/rust/tensor/rust_paillier/src/lib.rs index 117d5f7076..dfeb5116c8 100644 --- a/rust/tensor/rust_paillier/src/lib.rs +++ b/rust/tensor/rust_paillier/src/lib.rs @@ -158,7 +158,10 @@ impl Cipherblock { Err(e) => Err(e), } } - + #[getter] + pub fn shape(&self) -> Vec { + self.0.as_ref().map(|cb| cb.shape.clone()).unwrap() + } // add pub fn add_cipherblock(&self, other: &Cipherblock) -> Cipherblock { self.add_cb(other) diff --git a/rust/tensor/rust_paillier/src/par/mod.rs b/rust/tensor/rust_paillier/src/par/mod.rs index 81cb7cbd3b..3c85d1bbd2 100644 --- a/rust/tensor/rust_paillier/src/par/mod.rs +++ b/rust/tensor/rust_paillier/src/par/mod.rs @@ -150,6 +150,10 @@ impl Cipherblock { Err(e) => Err(e), } } + #[getter] + pub fn shape(&self) -> Vec { + self.0.as_ref().map(|cb| cb.shape.clone()).unwrap() + } // add pub fn add_cipherblock(&self, other: &Cipherblock) -> Cipherblock { self.add_cb(other) From b16eddad0c8c05ae34abccde843381d7d226abc3 Mon Sep 17 00:00:00 2001 From: weiwee Date: Thu, 4 Aug 2022 22:34:43 -0800 Subject: [PATCH 09/11] feat: add shape Signed-off-by: weiwee --- python/fate_arch/tensor/_tensor.py | 14 +++++++-- python/fate_arch/tensor/abc/tensor.py | 25 +++++++++++----- .../tensor/impl/blocks/_metaclass.py | 14 +++++---- .../tensor/impl/tensor/distributed.py | 30 ++++++++++++------- python/federatedml/ml/toy/enterpoint.py | 2 ++ 5 files changed, 59 insertions(+), 26 deletions(-) diff --git a/python/fate_arch/tensor/_tensor.py b/python/fate_arch/tensor/_tensor.py index aed01120c9..78f626ae0e 100644 --- a/python/fate_arch/tensor/_tensor.py +++ b/python/fate_arch/tensor/_tensor.py @@ -73,7 +73,7 @@ class Distributed(Enum): class Future: """ - get maybe async in future, in this version, + `get` maybe async in future, in this version, we wrap obj to support explicit typing and check """ @@ -312,9 +312,9 @@ def random_tensor(self, shape, num_partition=1) -> "FPTensor": last_dim = shape[-1] for i in range(num_partition): if i == num_partition - 1: - parts.append(torch.tensor((*shape[:-1], last_dim))) + parts.append(torch.rand((*shape[:-1], last_dim))) else: - parts.append(torch.tensor((*shape[:-1], shape[-1] / num_partition))) + parts.append(torch.rand((*shape[:-1], shape[-1] / num_partition))) last_dim -= shape[-1] / num_partition return FPTensor( self, @@ -428,6 +428,10 @@ def __init__(self, ctx: Context, tensor) -> None: self._ctx = ctx self._tensor = tensor + @property + def shape(self): + return self._tensor.shape + def __add__(self, other: Union["FPTensor", float, int]) -> "FPTensor": if not hasattr(self._tensor, "__add__"): return NotImplemented @@ -499,6 +503,10 @@ def __init__(self, ctx: Context, tensor: PHETensorABC) -> None: self._tensor = tensor self._ctx = ctx + @property + def shape(self): + return self._tensor.shape + def __add__(self, other: Union["PHETensor", FPTensor, int, float]) -> "PHETensor": return self._binary_op(other, self._tensor.__add__) diff --git a/python/fate_arch/tensor/abc/tensor.py b/python/fate_arch/tensor/abc/tensor.py index 843a6fad5c..f8f86716b9 100644 --- a/python/fate_arch/tensor/abc/tensor.py +++ b/python/fate_arch/tensor/abc/tensor.py @@ -4,23 +4,34 @@ class FPTensorProtocol(Protocol): - - def __add__(self, other: typing.Union["FPTensorProtocol", float, int]) -> "FPTensorProtocol": + def __add__( + self, other: typing.Union["FPTensorProtocol", float, int] + ) -> "FPTensorProtocol": ... - def __radd__(self, other: typing.Union["FPTensorProtocol", float, int]) -> "FPTensorProtocol": + def __radd__( + self, other: typing.Union["FPTensorProtocol", float, int] + ) -> "FPTensorProtocol": ... - def __sub__(self, other: typing.Union["FPTensorProtocol", float, int]) -> "FPTensorProtocol": + def __sub__( + self, other: typing.Union["FPTensorProtocol", float, int] + ) -> "FPTensorProtocol": ... - def __rsub__(self, other: typing.Union["FPTensorProtocol", float, int]) -> "FPTensorProtocol": + def __rsub__( + self, other: typing.Union["FPTensorProtocol", float, int] + ) -> "FPTensorProtocol": ... - def __mul__(self, other: typing.Union["FPTensorProtocol", float, int]) -> "FPTensorProtocol": + def __mul__( + self, other: typing.Union["FPTensorProtocol", float, int] + ) -> "FPTensorProtocol": ... - def __rmul__(self, other: typing.Union["FPTensorProtocol", float, int]) -> "FPTensorProtocol": + def __rmul__( + self, other: typing.Union["FPTensorProtocol", float, int] + ) -> "FPTensorProtocol": ... def __matmul__(self, other: "FPTensorProtocol") -> "FPTensorProtocol": diff --git a/python/fate_arch/tensor/impl/blocks/_metaclass.py b/python/fate_arch/tensor/impl/blocks/_metaclass.py index 9797a26865..eee79ceb78 100644 --- a/python/fate_arch/tensor/impl/blocks/_metaclass.py +++ b/python/fate_arch/tensor/impl/blocks/_metaclass.py @@ -100,9 +100,7 @@ def __new__(cls, name, bases, dict): def phe_decryptor_metaclass(pheblock_cls, fpblock_cls): class PHEDecryptorMetaclass(type): def __new__(cls, name, bases, dict): - decryptor_cls = super().__new__( - cls, name, bases, dict - ) + decryptor_cls = super().__new__(cls, name, bases, dict) setattr(decryptor_cls, "__init__", _impl_decryptor_init()) setattr( @@ -132,9 +130,7 @@ def _decrypt_numpy(sk, cb, dtype): def phe_encryptor_metaclass(pheblock_cls, fpblock_cls): class PHEEncryptorMetaclass(type): def __new__(cls, name, bases, dict): - encryptor_cls = super().__new__( - cls, name, bases, dict - ) + encryptor_cls = super().__new__(cls, name, bases, dict) setattr(encryptor_cls, "__init__", _impl_encryptor_init()) setattr( @@ -167,6 +163,12 @@ def __new__(cls, name, bases, dict): class_obj = super().__new__(cls, name, bases, dict) setattr(class_obj, "__init__", _impl_init()) + + @property + def shape(self): + return self._cb.shape + + setattr(class_obj, "shape", shape) _maybe_setattr(class_obj, "serialize", _impl_serialize()) for impl_name, ops in { "__add__": PHEBlockMetaclass._add, diff --git a/python/fate_arch/tensor/impl/tensor/distributed.py b/python/fate_arch/tensor/impl/tensor/distributed.py index 762715bd33..84cd69184a 100644 --- a/python/fate_arch/tensor/impl/tensor/distributed.py +++ b/python/fate_arch/tensor/impl/tensor/distributed.py @@ -19,12 +19,19 @@ class FPTensorDistributed(FPTensorProtocol): Demo of Distributed Fixed Presicion Tensor """ - def __init__(self, blocks_table): + def __init__(self, blocks_table, shape=None): """ use table to store blocks in format (blockid, block) """ self._blocks_table = blocks_table + # assume block is verticel aranged + if shape is None: + shapes = list(self._blocks_table.mapValues(lambda cb: cb.shape).collect()) + self.shape = (sum(s[0] for s in shapes), shapes[0][1]) + else: + self.shape = shape + def _binary_op(self, other, func_name): if isinstance(other, FPTensorDistributed): return FPTensorDistributed( @@ -68,11 +75,11 @@ def __rmul__( ) -> "FPTensorDistributed": return self._binary_op(other, "__rmul__") - def __matmul__(self, other: "FPTensorDistributed") -> "FPTensorDistributed": + def __matmul__(self, other: "PHETensorDistributed") -> "PHETensorDistributed": # todo: fix ... - def __rmatmul__(self, other: "FPTensorDistributed") -> "FPTensorDistributed": + def __rmatmul__(self, other: "PHETensorDistributed") -> "FPTensorDistributed": # todo: fix ... @@ -85,13 +92,20 @@ def __federation_hook__(self, ctx, key, parties): class PHETensorDistributed(PHETensorABC): - def __init__(self, blocks_table) -> None: + def __init__(self, blocks_table, shape=None): """ use table to store blocks in format (blockid, encrypted_block) """ self._blocks_table = blocks_table self._is_transpose = False + # assume block is verticel aranged + if shape is None: + shapes = list(self._blocks_table.mapValues(lambda cb: cb.shape).collect()) + self.shape = (sum(s[1][0] for s in shapes), shapes[0][1][1]) + else: + self.shape = shape + def __add__( self, other: Union["PHETensorDistributed", FPTensorDistributed, int, float] ) -> "PHETensorDistributed": @@ -122,15 +136,11 @@ def __rmul__( ) -> "PHETensorDistributed": return self._binary_op_limited(other, "__rmul__") - def __matmul__( - self, other: Union["PHETensorDistributed", FPTensorDistributed, int, float] - ) -> "PHETensorDistributed": + def __matmul__(self, other: FPTensorDistributed) -> "PHETensorDistributed": # TODO: impl me ... - def __rmatmul__( - self, other: Union["PHETensorDistributed", FPTensorDistributed, int, float] - ) -> "PHETensorDistributed": + def __rmatmul__(self, other: FPTensorDistributed) -> "PHETensorDistributed": # TODO: impl me ... diff --git a/python/federatedml/ml/toy/enterpoint.py b/python/federatedml/ml/toy/enterpoint.py index 39c37ec84d..e6a04c9ca1 100644 --- a/python/federatedml/ml/toy/enterpoint.py +++ b/python/federatedml/ml/toy/enterpoint.py @@ -33,12 +33,14 @@ def _run(self, ctx: Context, cpn_input): LOGGER.info("begin to make guest data") self.a = ctx.random_tensor((self.data_num, self.feature_num)) + LOGGER.info(f"shape of a: {self.a.shape}") LOGGER.info("keygen") self.pk, self.sk = ctx.keygen(CipherKind.PHE, 1024) LOGGER.info("encrypt data") self.ea = self.pk.encrypt(self.a) + LOGGER.info(f"shape of ea: {self.ea.shape}") LOGGER.info("share encrypted data to host") ctx.push(HOST, "guest_cipher", self.ea) From 24c4c06328d34214e17a076ab91a4e04bc3891de Mon Sep 17 00:00:00 2001 From: weiwee Date: Mon, 19 Sep 2022 00:29:11 -0800 Subject: [PATCH 10/11] fix: fix random Signed-off-by: weiwee --- python/fate_arch/tensor/_tensor.py | 9 +++-- .../tensor/impl/tensor/_metaclass.py | 11 +++++- .../tensor/impl/tensor/distributed.py | 38 +++++++++++++++++-- python/federatedml/ml/toy/enterpoint.py | 2 - 4 files changed, 49 insertions(+), 11 deletions(-) diff --git a/python/fate_arch/tensor/_tensor.py b/python/fate_arch/tensor/_tensor.py index 78f626ae0e..0fc5dceea1 100644 --- a/python/fate_arch/tensor/_tensor.py +++ b/python/fate_arch/tensor/_tensor.py @@ -309,13 +309,14 @@ def random_tensor(self, shape, num_partition=1) -> "FPTensor": from fate_arch.tensor.impl.tensor.distributed import FPTensorDistributed parts = [] - last_dim = shape[-1] + first_dim_approx = shape[0] // num_partition + last_part_first_dim = shape[0] - (num_partition - 1) * first_dim_approx + assert first_dim_approx > 0 for i in range(num_partition): if i == num_partition - 1: - parts.append(torch.rand((*shape[:-1], last_dim))) + parts.append(torch.rand((last_part_first_dim, *shape[1:],))) else: - parts.append(torch.rand((*shape[:-1], shape[-1] / num_partition))) - last_dim -= shape[-1] / num_partition + parts.append(torch.rand((first_dim_approx, *shape[1:]))) return FPTensor( self, FPTensorDistributed( diff --git a/python/fate_arch/tensor/impl/tensor/_metaclass.py b/python/fate_arch/tensor/impl/tensor/_metaclass.py index 45b8a4a11f..6a4e2c97bc 100644 --- a/python/fate_arch/tensor/impl/tensor/_metaclass.py +++ b/python/fate_arch/tensor/impl/tensor/_metaclass.py @@ -8,10 +8,19 @@ ) +class Local: + @property + def block(self): + ... + + def is_distributed(self): + return False + + def phe_tensor_metaclass(fp_cls): class PHETensorMetaclass(type): def __new__(cls, name, bases, dict): - phe_cls = super().__new__(cls, name, bases, dict) + phe_cls = super().__new__(cls, name, (*bases, Local), dict) def __init__(self, block) -> None: self._block = block diff --git a/python/fate_arch/tensor/impl/tensor/distributed.py b/python/fate_arch/tensor/impl/tensor/distributed.py index 84cd69184a..e430daf331 100644 --- a/python/fate_arch/tensor/impl/tensor/distributed.py +++ b/python/fate_arch/tensor/impl/tensor/distributed.py @@ -1,6 +1,8 @@ import typing from typing import Union +import torch + from ...abc.tensor import ( FPTensorProtocol, PHECipherABC, @@ -10,11 +12,21 @@ ) from ..._federation import FederationDeserializer from ..._tensor import Context, Party +from ....abc._computing import CTableABC Numeric = typing.Union[int, float] -class FPTensorDistributed(FPTensorProtocol): +class Distributed: + @property + def blocks(self) -> CTableABC: + ... + + def is_distributed(self): + return True + + +class FPTensorDistributed(FPTensorProtocol, Distributed): """ Demo of Distributed Fixed Presicion Tensor """ @@ -25,13 +37,17 @@ def __init__(self, blocks_table, shape=None): """ self._blocks_table = blocks_table - # assume block is verticel aranged + # assuming blocks are arranged vertically if shape is None: shapes = list(self._blocks_table.mapValues(lambda cb: cb.shape).collect()) self.shape = (sum(s[0] for s in shapes), shapes[0][1]) else: self.shape = shape + @property + def blocks(self): + return self._blocks_table + def _binary_op(self, other, func_name): if isinstance(other, FPTensorDistributed): return FPTensorDistributed( @@ -45,6 +61,10 @@ def _binary_op(self, other, func_name): ) return NotImplemented + def collect(self): + blocks = sorted(self._blocks_table.collect()) + return torch.cat([pair[1] for pair in blocks]) + def __add__( self, other: Union["FPTensorDistributed", int, float] ) -> "FPTensorDistributed": @@ -76,8 +96,14 @@ def __rmul__( return self._binary_op(other, "__rmul__") def __matmul__(self, other: "PHETensorDistributed") -> "PHETensorDistributed": - # todo: fix - ... + assert self.shape[1] == other.shape[0] + # support one dimension only + assert len(other.shape) == 1 + + def func(cb): + return cb @ other._blocks_table.collect() + + self._blocks_table.mapValues() def __rmatmul__(self, other: "PHETensorDistributed") -> "FPTensorDistributed": # todo: fix @@ -106,6 +132,10 @@ def __init__(self, blocks_table, shape=None): else: self.shape = shape + def collect(self): + blocks = sorted(self._blocks_table.collect()) + return torch.cat([pair[1] for pair in blocks]) + def __add__( self, other: Union["PHETensorDistributed", FPTensorDistributed, int, float] ) -> "PHETensorDistributed": diff --git a/python/federatedml/ml/toy/enterpoint.py b/python/federatedml/ml/toy/enterpoint.py index e6a04c9ca1..39c37ec84d 100644 --- a/python/federatedml/ml/toy/enterpoint.py +++ b/python/federatedml/ml/toy/enterpoint.py @@ -33,14 +33,12 @@ def _run(self, ctx: Context, cpn_input): LOGGER.info("begin to make guest data") self.a = ctx.random_tensor((self.data_num, self.feature_num)) - LOGGER.info(f"shape of a: {self.a.shape}") LOGGER.info("keygen") self.pk, self.sk = ctx.keygen(CipherKind.PHE, 1024) LOGGER.info("encrypt data") self.ea = self.pk.encrypt(self.a) - LOGGER.info(f"shape of ea: {self.ea.shape}") LOGGER.info("share encrypted data to host") ctx.push(HOST, "guest_cipher", self.ea) From 5a5c38e94c6844de09f1911e081f9b6cec3cd24e Mon Sep 17 00:00:00 2001 From: weiwee Date: Mon, 19 Sep 2022 00:30:50 -0800 Subject: [PATCH 11/11] feat: add ops Signed-off-by: weiwee --- .../tensor/impl/tensor/row_distributed.py | 240 ++++++++++++++++++ python/fate_arch/tensor/ops/__init__.py | 3 + 2 files changed, 243 insertions(+) create mode 100644 python/fate_arch/tensor/impl/tensor/row_distributed.py create mode 100644 python/fate_arch/tensor/ops/__init__.py diff --git a/python/fate_arch/tensor/impl/tensor/row_distributed.py b/python/fate_arch/tensor/impl/tensor/row_distributed.py new file mode 100644 index 0000000000..89a19abbb8 --- /dev/null +++ b/python/fate_arch/tensor/impl/tensor/row_distributed.py @@ -0,0 +1,240 @@ +import typing +from typing import Union + +from ...abc.tensor import ( + FPTensorProtocol, + PHECipherABC, + PHEDecryptorABC, + PHEEncryptorABC, + PHETensorABC, +) +from ..._federation import FederationDeserializer +from ..._tensor import Context, Party + +Numeric = typing.Union[int, float] + + +class FPTensorDistributed(FPTensorProtocol): + """ + Demo of Distributed Fixed Presicion Tensor + """ + + def __init__(self, blocks_table): + """ + use table to store blocks in format (blockid, block) + """ + self._blocks_table = blocks_table + + def _binary_op(self, other, func_name): + if isinstance(other, FPTensorDistributed): + return FPTensorDistributed( + other._blocks_table.join( + self._blocks_table, lambda x, y: getattr(x, func_name)(y) + ) + ) + elif isinstance(other, (int, float)): + return FPTensorDistributed( + self._blocks_table.mapValues(lambda x: getattr(x, func_name)(other)) + ) + return NotImplemented + + def __add__( + self, other: Union["FPTensorDistributed", int, float] + ) -> "FPTensorDistributed": + return self._binary_op(other, "__add__") + + def __radd__( + self, other: Union["FPTensorDistributed", int, float] + ) -> "FPTensorDistributed": + return self._binary_op(other, "__radd__") + + def __sub__( + self, other: Union["FPTensorDistributed", int, float] + ) -> "FPTensorDistributed": + return self._binary_op(other, "__sub__") + + def __rsub__( + self, other: Union["FPTensorDistributed", int, float] + ) -> "FPTensorDistributed": + return self._binary_op(other, "__rsub__") + + def __mul__( + self, other: Union["FPTensorDistributed", int, float] + ) -> "FPTensorDistributed": + return self._binary_op(other, "__mul__") + + def __rmul__( + self, other: Union["FPTensorDistributed", int, float] + ) -> "FPTensorDistributed": + return self._binary_op(other, "__rmul__") + + def __matmul__(self, other: "FPTensorDistributed") -> "FPTensorDistributed": + # todo: fix + ... + + def __rmatmul__(self, other: "FPTensorDistributed") -> "FPTensorDistributed": + # todo: fix + ... + + def __federation_hook__(self, ctx, key, parties): + deserializer = FPTensorFederationDeserializer(key) + # 1. remote deserializer with objs + ctx._push(parties, key, deserializer) + # 2. remote table + ctx._push(parties, deserializer.table_key, self._blocks_table) + + +class PHETensorDistributed(PHETensorABC): + def __init__(self, blocks_table) -> None: + """ + use table to store blocks in format (blockid, encrypted_block) + """ + self._blocks_table = blocks_table + self._is_transpose = False + + def __add__( + self, other: Union["PHETensorDistributed", FPTensorDistributed, int, float] + ) -> "PHETensorDistributed": + return self._binary_op(other, "__add__") + + def __radd__( + self, other: Union["PHETensorDistributed", FPTensorDistributed, int, float] + ) -> "PHETensorDistributed": + return self._binary_op(other, "__radd__") + + def __sub__( + self, other: Union["PHETensorDistributed", FPTensorDistributed, int, float] + ) -> "PHETensorDistributed": + return self._binary_op(other, "__sub__") + + def __rsub__( + self, other: Union["PHETensorDistributed", FPTensorDistributed, int, float] + ) -> "PHETensorDistributed": + return self._binary_op(other, "__rsub__") + + def __mul__( + self, other: Union["PHETensorDistributed", FPTensorDistributed, int, float] + ) -> "PHETensorDistributed": + return self._binary_op_limited(other, "__mul__") + + def __rmul__( + self, other: Union["PHETensorDistributed", FPTensorDistributed, int, float] + ) -> "PHETensorDistributed": + return self._binary_op_limited(other, "__rmul__") + + def __matmul__( + self, other: FPTensorDistributed + ) -> "PHETensorDistributed": + # TODO: impl me + ... + + def __rmatmul__( + self, other: FPTensorDistributed + ) -> "PHETensorDistributed": + # TODO: impl me + ... + + def T(self) -> "PHETensorDistributed": + transposed = PHETensorDistributed(self._blocks_table) + transposed._is_transpose = not self._is_transpose + return transposed + + def serialize(self): + return self._blocks_table + + def deserialize(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + + def __getstates__(self): + return {"_is_transpose": self._is_transpose} + + def _binary_op(self, other, func_name): + if isinstance(other, (FPTensorDistributed, PHETensorDistributed)): + return PHETensorDistributed( + self._blocks_table.join( + other._blocks_table, lambda x, y: getattr(x, func_name)(y) + ) + ) + elif isinstance(other, (int, float)): + return PHETensorDistributed( + self._blocks_table.mapValues(lambda x: x.__add__(other)) + ) + + return NotImplemented + + def _binary_op_limited(self, other, func_name): + if isinstance(other, FPTensorDistributed): + return PHETensorDistributed( + self._blocks_table.join( + other._blocks_table, lambda x, y: getattr(x, func_name)(y) + ) + ) + elif isinstance(other, (int, float)): + return PHETensorDistributed( + self._blocks_table.mapValues(lambda x: x.__add__(other)) + ) + return NotImplemented + + def __federation_hook__(self, ctx, key, parties): + deserializer = PHETensorFederationDeserializer(key, self._is_transpose) + # 1. remote deserializer with objs + ctx._push(parties, key, deserializer) + # 2. remote table + ctx._push(parties, deserializer.table_key, self._blocks_table) + + +class PaillierPHEEncryptorDistributed(PHEEncryptorABC): + def __init__(self, block_encryptor) -> None: + self._block_encryptor = block_encryptor + + def encrypt(self, tensor: FPTensorDistributed) -> PHETensorDistributed: + return PHETensorDistributed( + tensor._blocks_table.mapValues(lambda x: self._block_encryptor.encrypt(x)) + ) + + +class PaillierPHEDecryptorDistributed(PHEDecryptorABC): + def __init__(self, block_decryptor) -> None: + self._block_decryptor = block_decryptor + + def decrypt(self, tensor: PHETensorDistributed) -> FPTensorDistributed: + return FPTensorDistributed( + tensor._blocks_table.mapValues(lambda x: self._block_decryptor.decrypt(x)) + ) + + +class PaillierPHECipherDistributed(PHECipherABC): + @classmethod + def keygen( + cls, **kwargs + ) -> typing.Tuple[PaillierPHEEncryptorDistributed, PaillierPHEDecryptorDistributed]: + from ..blocks.cpu_paillier_block import BlockPaillierCipher + + block_encrytor, block_decryptor = BlockPaillierCipher.keygen(**kwargs) + return ( + PaillierPHEEncryptorDistributed(block_encrytor), + PaillierPHEDecryptorDistributed(block_decryptor), + ) + + +class PHETensorFederationDeserializer(FederationDeserializer): + def __init__(self, key, is_transpose) -> None: + self.table_key = self.make_frac_key(key, "table") + self.is_transpose = is_transpose + + def do_deserialize(self, ctx: Context, party: Party) -> PHETensorDistributed: + table = ctx._pull([party], self.table_key)[0] + tensor = PHETensorDistributed(table) + tensor._is_transpose = self.is_transpose + return tensor + + +class FPTensorFederationDeserializer(FederationDeserializer): + def __init__(self, key) -> None: + self.table_key = self.make_frac_key(key, "table") + + def do_deserialize(self, ctx: Context, party: Party) -> FPTensorDistributed: + table = ctx._pull([party], self.table_key)[0] + tensor = FPTensorDistributed(table) + return tensor diff --git a/python/fate_arch/tensor/ops/__init__.py b/python/fate_arch/tensor/ops/__init__.py new file mode 100644 index 0000000000..2144eddcc9 --- /dev/null +++ b/python/fate_arch/tensor/ops/__init__.py @@ -0,0 +1,3 @@ + +def broadcast_matmul(matrix, bc_matrix): + return matrix.blocks.mapValues(lambda cb: cb @ bc_matrix)