-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
ac902aa
commit 71656fb
Showing
5 changed files
with
463 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,169 @@ | ||
class Dataflow:
    """Empty placeholder class.

    NOTE(review): the hunk header `@@ -1,2 +1,169 @@` suggests these two lines
    are the removed side of the diff, superseded by `DataFlow` - confirm
    before relying on this name.
    """
from abc import ABC
from dataclasses import dataclass, field
from typing import Any, ClassVar, List, Optional, Union
|
||
|
||
@dataclass
class InitClass:
    """A method type corresponding to an `__init__` call.

    Carries no data of its own; it is one of the two alternatives accepted by
    `OpNode.method_type` (the other being `InvokeMethod`).
    """
|
||
@dataclass
class InvokeMethod:
    """A method invocation, identified by the name of the underlying method."""

    method_name: str
    """Identifier of the method to invoke."""
|
||
@dataclass
class Node(ABC):
    """Base class for Nodes.

    Every instance (including instances of subclasses) receives a unique,
    monotonically increasing `id` taken from a single counter stored on
    `Node` itself.
    """

    id: int = field(init=False)
    """This node's unique id."""

    # Declared as a `ClassVar` so the dataclass machinery treats it as shared
    # class state rather than a per-instance field: it is not copied onto
    # every instance and takes no part in `fields()`, `__eq__` or `__repr__`.
    _id_counter: ClassVar[int] = 0

    def __post_init__(self):
        # Assign a unique ID from the class-level counter. `Node` is named
        # explicitly so subclasses share one counter instead of shadowing it.
        self.id = Node._id_counter
        Node._id_counter += 1
|
||
@dataclass
class OpNode(Node):
    """A node in a `Dataflow` corresponding to a method call of a `StatefulOperator`.

    A `Dataflow` may reference the same `StatefulOperator` multiple times.
    The `StatefulOperator` that this node belongs to is referenced by `cls`.
    """

    cls: Any
    """Reference to the `StatefulOperator` this node belongs to."""

    method_type: Union[InitClass, InvokeMethod]
    """Whether this node constructs the class or invokes one of its methods."""
|
||
@dataclass
class MergeNode(Node):
    """A node in a `Dataflow` corresponding to a merge operator.

    It will aggregate incoming edges and output them as a list to the outgoing edge.
    The actual implementation is runtime-dependent.
    """
|
||
@dataclass
class Edge:
    """An Edge in the Dataflow graph, leading from `from_node` to `to_node`."""

    from_node: Node
    """The node this edge leaves from."""

    to_node: Node
    """The node this edge arrives at."""
|
||
class DataFlow:
    """A Dataflow is a graph consisting of `OpNode`s, `MergeNode`s, and `Edge`s.

    Example Usage
    -------------

    Consider two entities, `User` and `Item`, and a method `User.buy_items(item1, item2)`.
    The resulting method could be turned into the following Dataflow graph.

    ```mermaid
    flowchart TD;
        user1[User.buy_items_0]
        item1[Item.get_price]
        item2[Item.get_price]
        user2[User.buy_items_1]
        merge{Merge}
        user1-- item1_key -->item1;
        user1-- item2_key -->item2;
        item1-- item1_price -->merge;
        item2-- item2_price -->merge;
        merge-- [item1_price, item2_price] -->user2;
    ```

    In code, one would write:

    ```py
    df = DataFlow("user.buy_items")
    n0 = OpNode(User, InvokeMethod("buy_items_0"))
    n1 = OpNode(Item, InvokeMethod("get_price"))
    n2 = OpNode(Item, InvokeMethod("get_price"))
    n3 = MergeNode()
    n4 = OpNode(User, InvokeMethod("buy_items_1"))
    df.add_edge(Edge(n0, n1))
    df.add_edge(Edge(n0, n2))
    df.add_edge(Edge(n1, n3))
    df.add_edge(Edge(n2, n3))
    df.add_edge(Edge(n3, n4))
    ```
    """

    def __init__(self, name):
        self.name = name
        # Maps a node id to the ids of its outgoing neighbours, in insertion order.
        self.adjacency_list: dict[int, list[int]] = {}
        # Maps a node id back to the node object itself.
        self.nodes: dict[int, Node] = {}

    def add_node(self, node: Node):
        """Add a node to the Dataflow graph if it doesn't already exist."""
        if node.id not in self.adjacency_list:
            self.adjacency_list[node.id] = []
            self.nodes[node.id] = node

    def add_edge(self, edge: Edge):
        """Add an edge to the Dataflow graph.

        Nodes that don't exist will be added to the graph automatically.
        Adding the same edge twice records the neighbour twice.
        """
        self.add_node(edge.from_node)
        self.add_node(edge.to_node)
        self.adjacency_list[edge.from_node.id].append(edge.to_node.id)

    def get_neighbors(self, node: Node) -> List[Node]:
        """Get the outgoing neighbors of this `Node`.

        Returns an empty list for a node that is not part of the graph.
        """
        neighbor_ids = self.adjacency_list.get(node.id, [])
        return [self.nodes[neighbor_id] for neighbor_id in neighbor_ids]
|
||
@dataclass
class Event:
    """An Event is an object that travels through the Dataflow graph."""

    target: 'Node'
    """The Node that this Event wants to go to."""

    key_stack: list[str]
    """The keys this event is concerned with.
    The top of the stack, i.e. `key_stack[-1]`, should always correspond to a key
    on the StatefulOperator of `target.cls` if `target` is an `OpNode`."""

    args: List[Any]
    kwargs: dict[str, Any]
    """The args and kwargs to be passed to the `target`.
    If `target` is an `OpNode` this corresponds to the method args/kwargs."""

    dataflow: Optional['DataFlow']
    """The Dataflow that this event is a part of. If None, it won't propagate.
    This might be removed in the future in favour of a routing operator."""

    _id: Optional[int] = None
    """Unique ID for this event. Except in `propogate`, this `id` should not be set."""

    # Shared counter for fresh event ids. A `ClassVar` keeps it out of the
    # generated `__init__`/`__eq__`/`fields()` and off every instance.
    _id_counter: ClassVar[int] = 0

    def __post_init__(self):
        if self._id is None:
            # Assign a unique ID from the class-level counter
            self._id = Event._id_counter
            Event._id_counter += 1

    def propogate(self, key_stack, args, kwargs) -> list['Event']:
        """Propagate this event through the Dataflow.

        If the event has no dataflow, `key_stack` is empty, or the current
        target has no outgoing neighbours, the event is terminal: `args` and
        `kwargs` are stored on this event and `[self]` is returned.

        Otherwise the top of `key_stack` is taken as the key (or list of keys)
        for the next hop and one new `Event` per (neighbour, key) pair is
        returned. Each new event carries `args`/`kwargs`, the remaining stack
        plus its own key, and this event's `_id`. Neighbours and keys are
        paired with `zip`, so surplus entries on either side are dropped.

        The passed-in `key_stack` is not modified.
        """
        has_route = self.dataflow is not None and key_stack
        targets = self.dataflow.get_neighbors(self.target) if has_route else []

        if not targets:
            # Nowhere left to go: deliver the results on this very event.
            self.args = args
            self.kwargs = kwargs
            return [self]

        # An event with multiple targets should have the same number of keys
        # in a list on top of its key stack. Unpack rather than `pop()` so the
        # caller's list is left untouched.
        *remaining, keys = key_stack
        if not isinstance(keys, list):
            keys = [keys]

        return [
            Event(target, remaining + [key], args, kwargs, self.dataflow, _id=self._id)
            for target, key in zip(targets, keys)
        ]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
from typing import Any, Generic, Protocol, Type, TypeVar | ||
from cascade.dataflow.dataflow import InvokeMethod | ||
|
||
T = TypeVar('T')


class MethodCall(Generic[T], Protocol):
    """A helper class for type-safety of method signature for compiled methods.

    It corresponds to functions with the following signature:

    ```py
    def my_compiled_method(*args: Any, state: T, key_stack: list[str], **kwargs: Any) -> Any:
        ...
    ```

    `T` corresponds to a Python class, which, if modified, should be returned as the 2nd item in the tuple.
    The first item in the returned tuple corresponds to the actual return value of the function.
    The third item in the tuple corresponds to the `key_stack` which should be updated accordingly.
    Notably, a terminal function should pop a key off the `key_stack`, whereas a function that calls
    other functions should push the correct key(s) onto the `key_stack`.
    """

    def __call__(self, *args: Any, state: T, key_stack: list[str], **kwargs: Any) -> Any:
        """@private"""
        ...
|
||
|
||
class StatefulOperator(Generic[T]):
    """An abstraction for a user-defined python class.

    A StatefulOperator handles incoming events, such as `cascade.dataflow.dataflow.InitClass` and `cascade.dataflow.dataflow.InvokeMethod`.
    It is created using a class `cls` and a collection of `methods`.

    These methods map a method identifier (str) to a python function.
    Importantly, these functions are "stateless" in the sense that they are not methods,
    instead reading and modifying the underlying class `T` through a state variable, see `handle_invoke_method`.
    """

    def __init__(self, cls: Type[T], methods: dict[str, MethodCall[T]]):
        """Create the StatefulOperator from a class and its compiled methods.

        Typically, a class could be comprised of split and non-split methods. Take the following example:

        ```py
        class User:
            def __init__(self, key: str, balance: int):
                self.key = key
                self.balance = balance

            def get_balance(self) -> int:
                return self.balance

            def buy_item(self, item: Item) -> bool:
                self.balance -= item.get_price()
                return self.balance >= 0
        ```

        Here, the class could be turned into a StatefulOperator as follows:

        ```py
        def user_get_balance(*, state: User, key_stack: list[str]):
            key_stack.pop()
            return state.balance

        def user_buy_item_0(item_key: str, *, state: User, key_stack: list[str]):
            key_stack.append(item_key)

        def user_buy_item_1(item_get_price: int, *, state: User, key_stack: list[str]):
            state.balance -= item_get_price
            return state.balance >= 0

        op = StatefulOperator(
            User,
            {
                "buy_item": user_buy_item_0,
                "get_balance": user_get_balance,
                "buy_item_1": user_buy_item_1
            })
        ```
        """
        # methods maps function name to a function. Ideally this is done once in the object
        self._methods = methods
        self._cls = cls

    def handle_init_class(self, *args, **kwargs) -> T:
        """Create an instance of the underlying class. Equivalent to `T.__init__(*args, **kwargs)`."""
        return self._cls(*args, **kwargs)

    def handle_invoke_method(self, method: InvokeMethod, *args, state: T, key_stack: list[str], **kwargs) -> tuple[Any, T, list[str]]:
        """Invoke the method of the underlying class.

        The `cascade.dataflow.dataflow.InvokeMethod` object must contain a method identifier
        that exists on the underlying compiled class functions.

        The state `T` and key_stack are passed along to the function, and may be modified.
        Whatever the compiled function returns is returned unchanged.

        NOTE(review): the return annotation promises a 3-tuple, but the
        documented example functions return bare values - confirm the
        intended contract.

        Raises:
            KeyError: if no compiled function is registered under `method.method_name`.
        """
        try:
            compiled = self._methods[method.method_name]
        except KeyError:
            # Same exception type as the bare lookup, but names the operator
            # and the missing method so the failure can be diagnosed.
            raise KeyError(
                f"{self._cls.__name__} has no compiled method named {method.method_name!r}"
            ) from None
        return compiled(*args, state=state, key_stack=key_stack, **kwargs)
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
from typing import Any | ||
from cascade.dataflow.dataflow import DataFlow, Edge, Event, InvokeMethod, MergeNode, OpNode | ||
|
||
class DummyUser:
    """Minimal user entity with a key and a balance, used as a test fixture."""

    def __init__(self, key: str, balance: int):
        self.key: str = key
        self.balance: int = balance

    def buy_item(self, item: 'DummyItem') -> bool:
        """Deduct the item's price from the balance.

        Returns whether the balance is still non-negative afterwards; the
        deduction happens either way.
        """
        item_price = item.get_price()  # SSA
        self.balance -= item_price
        return self.balance >= 0
|
||
def buy_item_0_compiled(item_key, *, state: DummyUser, key_stack: list[str]) -> Any:
    """First half of the split `DummyUser.buy_item`.

    Pushes `item_key` onto `key_stack` and returns it. `state` is accepted
    but not used.
    """
    key_stack.append(item_key)
    return item_key
|
||
def buy_item_1_compiled(item_price: int, *, state: DummyUser, key_stack: list[str]) -> Any:
    """Second half of the split `DummyUser.buy_item`.

    Pops the top key off `key_stack`, deducts `item_price` from
    `state.balance`, and returns whether the balance is still non-negative.
    """
    key_stack.pop()
    state.balance -= item_price
    return state.balance >= 0
|
||
class DummyItem:
    """Minimal item entity with a key and a price, used as a test fixture."""

    def __init__(self, key: str, price: int):
        self.key: str = key
        self.price: int = price

    def get_price(self) -> int:
        """Return this item's price."""
        return self.price
|
||
def get_price_compiled(*args, state: DummyItem, key_stack: list[str]) -> Any:
    """Compiled form of `DummyItem.get_price`.

    Pops the top key off `key_stack` and returns `state.price`. Positional
    arguments are accepted and ignored.
    """
    key_stack.pop()  # final function
    return state.price
|
||
################## TESTS ####################### | ||
|
||
# Module-level fixtures shared by the tests below. They are mutable:
# a test that runs a compiled method against `user` changes its balance.
user = DummyUser("user", 100)
item = DummyItem("fork", 5)
|
||
def test_simple_df_propogation():
    """Walk one event through the linear dataflow buy_item_0 -> get_price -> buy_item_1."""
    df = DataFlow("user.buy_item")
    n1 = OpNode(DummyUser, InvokeMethod("buy_item_0"))
    n2 = OpNode(DummyItem, InvokeMethod("get_price"))
    n3 = OpNode(DummyUser, InvokeMethod("buy_item_1"))
    df.add_edge(Edge(n1, n2))
    df.add_edge(Edge(n2, n3))

    event = Event(n1, ["user"], ["fork"], {}, df)

    # Manually propagate.
    # `event.args` is passed as a single positional argument, so the key
    # pushed onto the stack is the list ["fork"]; `propogate` accepts a list
    # of keys on top of the stack and pairs it with the single neighbour.
    item_key = buy_item_0_compiled(event.args, state=user, key_stack=event.key_stack)
    next_event = event.propogate(event.key_stack, item_key, None)

    assert len(next_event) == 1
    assert isinstance(next_event[0].target, OpNode)
    assert next_event[0].target.cls == DummyItem
    assert next_event[0].key_stack == ["user", "fork"]
    event = next_event[0]

    item_price = get_price_compiled(event.args, state=item, key_stack=event.key_stack)
    next_event = event.propogate(event.key_stack, item_price, None)

    assert len(next_event) == 1
    assert isinstance(next_event[0].target, OpNode)
    assert next_event[0].target.cls == DummyUser
    assert next_event[0].key_stack == ["user"]
    event = next_event[0]

    positive_balance = buy_item_1_compiled(event.args, state=user, key_stack=event.key_stack)
    # The price (5) is well below the starting balance (100).
    assert positive_balance

    # The key stack is now empty, so the event is terminal.
    next_event = event.propogate(event.key_stack, None, None)
    assert next_event[0].key_stack == []
|
||
|
||
def test_merge_df_propogation():
    """Fan one event out to two item nodes and follow both through the merge node."""
    df = DataFlow("user.buy_2_items")
    n0 = OpNode(DummyUser, InvokeMethod("buy_2_items_0"))
    n1 = OpNode(DummyItem, InvokeMethod("get_price"))
    n2 = OpNode(DummyItem, InvokeMethod("get_price"))
    n3 = MergeNode()
    n4 = OpNode(DummyUser, InvokeMethod("buy_2_items_1"))
    df.add_edge(Edge(n0, n1))
    df.add_edge(Edge(n0, n2))
    df.add_edge(Edge(n1, n3))
    df.add_edge(Edge(n2, n3))
    df.add_edge(Edge(n3, n4))

    # User with key "foo" buys items with keys "fork" and "spoon"
    event = Event(n0, ["foo"], ["fork", "spoon"], {}, df)

    # Propagate the event (without actually doing any calculation)
    # Normally, the key_stack should've been updated by the runtime here:
    key_stack = ["foo", ["fork", "spoon"]]
    next_event = event.propogate(key_stack, None, None)

    assert len(next_event) == 2
    assert isinstance(next_event[0].target, OpNode)
    assert isinstance(next_event[1].target, OpNode)
    assert next_event[0].target.cls == DummyItem
    assert next_event[1].target.cls == DummyItem
    # Each branch receives its own key on top of the shared remainder ...
    assert next_event[0].key_stack == ["foo", "fork"]
    assert next_event[1].key_stack == ["foo", "spoon"]
    # ... and keeps the id of the event it was derived from.
    assert next_event[0]._id == event._id
    assert next_event[1]._id == event._id

    event1, event2 = next_event
    next_event = event1.propogate(event1.key_stack, None, None)
    assert len(next_event) == 1
    assert isinstance(next_event[0].target, MergeNode)

    next_event = event2.propogate(event2.key_stack, None, None)
    assert len(next_event) == 1
    assert isinstance(next_event[0].target, MergeNode)

    final_event = next_event[0].propogate(next_event[0].key_stack, None, None)
    assert len(final_event) == 1
    assert final_event[0].target == n4
Oops, something went wrong.