Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: improve context and tensor implemention #4133

Open
wants to merge 11 commits into
base: feature/heterogeneous-acceleration-architecture
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: Build FATE-CPU-Tensor
name: Build Rust Paillier

on:
workflow_dispatch:
Expand Down Expand Up @@ -44,13 +44,13 @@ jobs:
with:
manylinux: auto
command: build
args: --release -o dist -m rust/fate_tensor/Cargo.toml
args: --release -o dist -m rust/tensor/rust_paillier/Cargo.toml
- name: macos-maturin
if: matrix.os == 'macos'
uses: messense/maturin-action@v1
with:
command: build
args: --release --no-sdist -o dist -m rust/fate_tensor/Cargo.toml
args: --release --no-sdist -o dist -m rust/tensor/rust_paillier/Cargo.toml
- name: Upload wheels
uses: actions/upload-artifact@v2
with:
Expand Down
15 changes: 8 additions & 7 deletions python/fate_arch/tensor/__init__.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
from ._dataloader import LabeledDataloaderWrapper, UnlabeledDataloaderWrapper
from ._federation import ARBITER, GUEST, HOST
from ._context import Context, CipherKind
from ._tensor import (
FPTensor,
PHETensor,
)
from ._parties import Parties, PreludeParty
from ._tensor import CipherKind, Context, FPTensor, PHETensor

ARBITER = PreludeParty.ARBITER
GUEST = PreludeParty.GUEST
HOST = PreludeParty.HOST

__all__ = [
"FPTensor",
"PHETensor",
"Parties",
"ARBITER",
"GUEST",
"HOST",
"Context",
"LabeledDataloaderWrapper",
"UnlabeledDataloaderWrapper",
"CipherKind"
"CipherKind",
]
264 changes: 0 additions & 264 deletions python/fate_arch/tensor/_context.py

This file was deleted.

79 changes: 6 additions & 73 deletions python/fate_arch/tensor/_federation.py
Original file line number Diff line number Diff line change
@@ -1,74 +1,7 @@
from typing import List, Union
class FederationDeserializer:
def do_deserialize(self, ctx, party):
...

from fate_arch.common import Party
from fate_arch.session import get_parties


def _get_role_parties(role: str):
return get_parties().roles_to_parties([role], strict=False)


class _RoleIndexedParty:
def __init__(self, role: str, index: int) -> None:
assert index >= 0, "index should >= 0"
self._role = role
self._index = index

@property
def party(self) -> Party:
parties = _get_role_parties(self._role)
if 0 <= self._index < len(parties):
return parties[self._index]
raise KeyError(
f"index `{self._index}` out of bound `0 <= index < {len(parties)}`"
)


class _Parties:
def __init__(
self, parties: List[Union[str, Party, _RoleIndexedParty, "_Parties"]]
) -> None:
self._parties = parties

def _reverse(self):
self._parties.reverse()
return self

@property
def parties(self) -> List[Party]:
flatten = []
for p in self._parties:
if isinstance(p, str) and (p == "guest" or p == "host" or p == "arbiter"):
flatten.extend(_get_role_parties(p))
elif isinstance(p, Party):
flatten.append(p)
elif isinstance(p, _RoleIndexedParty):
flatten.append(p.party)
elif isinstance(p, _Parties):
flatten.extend(p.parties)
return flatten

def __add__(self, other) -> "_Parties":
if isinstance(other, Party):
return _Parties([self, other])
elif isinstance(other, list):
return _Parties([self, *other])
else:
raise ValueError(f"can't add `{other}`")

def __radd__(self, other) -> "_Parties":
return self.__add__(other)._reverse()


class _Role(_Parties):
def __init__(self, role: str) -> None:
self._role = role
super().__init__([role])

def __getitem__(self, key) -> _RoleIndexedParty:
return _RoleIndexedParty(self._role, key)


ARBITER = _Role("arbiter")
GUEST = _Role("guest")
HOST = _Role("host")
@classmethod
def make_frac_key(cls, base_key, frac_key):
return f"{base_key}__frac__{frac_key}"
Loading