Skip to content

Commit

Permalink
Add Dataflows + tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lucasvanmol committed Dec 9, 2024
1 parent ac902aa commit 71656fb
Show file tree
Hide file tree
Showing 5 changed files with 463 additions and 10 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ Documentation is done with [Google style docstrings](https://google.github.io/st
It can be generated using [pdoc](https://pdoc.dev/docs/pdoc.html):

```
pdoc src/cascade
pdoc --mermaid src/cascade
```


171 changes: 169 additions & 2 deletions src/cascade/dataflow/dataflow.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,169 @@
class Dataflow:
pass
from abc import ABC
from dataclasses import dataclass, field
from typing import Any, ClassVar, List, Optional, Union


@dataclass
class InitClass:
    """Marker method type denoting a call to a class's `__init__`.

    It carries no data; its presence alone identifies the kind of call.
    """

@dataclass
class InvokeMethod:
    """Method type denoting an invocation of the method identified by `method_name`."""

    method_name: str
    """Identifier of the method to invoke."""

@dataclass
class Node(ABC):
    """Base class for Nodes.

    Every instance is given a unique `id` on construction, taken from a
    single counter that is shared by `Node` and all of its subclasses.
    """

    id: int = field(init=False)
    """This node's unique id."""

    # Declared as a ClassVar so that `dataclass` treats it as shared class
    # state instead of a per-instance field; as a regular field it would leak
    # into `fields()`, `asdict()` and the generated `__eq__`.
    _id_counter: ClassVar[int] = 0

    def __post_init__(self):
        # Always read/write through `Node` (not `type(self)`) so that every
        # subclass draws from the same sequence and ids never collide.
        self.id = Node._id_counter
        Node._id_counter += 1

@dataclass
class OpNode(Node):
    """A `Dataflow` node that stands for one method call on a `StatefulOperator`.

    The same `StatefulOperator` may be referenced by several nodes of a
    `Dataflow`; `cls` tells which operator this particular node belongs to.
    """

    cls: Any
    """Reference to the `StatefulOperator` this node belongs to."""

    method_type: Union[InitClass, InvokeMethod]
    """Kind of call this node performs: construction or a named method."""

@dataclass
class MergeNode(Node):
    """A `Dataflow` node that stands for a merge operator.

    A merge aggregates what arrives on its incoming edges and emits it as a
    list on its outgoing edge. How that is actually carried out is up to the
    runtime.
    """

@dataclass
class Edge:
    """A directed connection between two nodes of the Dataflow graph."""

    from_node: Node
    """Node the edge starts at."""

    to_node: Node
    """Node the edge points to."""

class DataFlow:
    """A directed graph made of `OpNode`s, `MergeNode`s and the `Edge`s between them.

    Nodes are tracked by their `id`: `nodes` maps an id to the node object and
    `adjacency_list` maps an id to the ids of its outgoing neighbors.

    Example Usage
    -------------

    Consider two entities, `User` and `Item`, and a method `User.buy_items(item1, item2)`.
    The resulting method could be created into the following Dataflow graph.

    ```mermaid
    flowchart TD;
        user1[User.buy_items_0]
        item1[Item.get_price]
        item2[Item.get_price]
        user2[User.buy_items_1]
        merge{Merge}
        user1-- item1_key -->item1;
        user1-- item2_key -->item2;
        item1-- item1_price -->merge;
        item2-- item2_price -->merge;
        merge-- [item1_price, item2_price] -->user2;
    ```

    In code, one would write:

    ```py
    df = DataFlow("user.buy_items")
    n0 = OpNode(User, InvokeMethod("buy_items_0"))
    n1 = OpNode(Item, InvokeMethod("get_price"))
    n2 = OpNode(Item, InvokeMethod("get_price"))
    n3 = MergeNode()
    n4 = OpNode(User, InvokeMethod("buy_items_1"))
    df.add_edge(Edge(n0, n1))
    df.add_edge(Edge(n0, n2))
    df.add_edge(Edge(n1, n3))
    df.add_edge(Edge(n2, n3))
    df.add_edge(Edge(n3, n4))
    ```
    """

    def __init__(self, name):
        # node id -> node object
        self.nodes = {}
        # node id -> ids of the nodes reachable over one outgoing edge
        self.adjacency_list = {}
        self.name = name

    def add_node(self, node: "Node"):
        """Register `node` in the graph; a node whose id is already known is left untouched."""
        node_id = node.id
        if node_id in self.adjacency_list:
            return
        self.adjacency_list[node_id] = []
        self.nodes[node_id] = node

    def add_edge(self, edge: "Edge"):
        """Insert `edge` into the graph, registering either endpoint first if it is not yet known."""
        source, sink = edge.from_node, edge.to_node
        self.add_node(source)
        self.add_node(sink)
        self.adjacency_list[source.id].append(sink.id)

    def get_neighbors(self, node: "Node") -> List["Node"]:
        """Return the nodes reachable from `node` over one outgoing edge.

        A node that was never added to the graph yields an empty list.
        """
        neighbor_ids = self.adjacency_list.get(node.id, [])
        return [self.nodes[neighbor_id] for neighbor_id in neighbor_ids]

@dataclass
class Event:
    """An Event is an object that travels through the Dataflow graph."""

    target: 'Node'
    """The Node that this Event wants to go to."""

    key_stack: list[str]
    """The keys this event is concerned with.
    The top of the stack, i.e. `key_stack[-1]`, should always correspond to a key
    on the StatefulOperator of `target.cls` if `target` is an `OpNode`."""

    args: List[Any]
    kwargs: dict[str, Any]
    """The args and kwargs to be passed to the `target`.
    If `target` is an `OpNode` this corresponds to the method args/kwargs."""

    dataflow: Optional['DataFlow']
    """The Dataflow that this event is a part of. If None, it won't propagate.
    This might be removed in the future in favour of a routing operator."""

    _id: Optional[int] = None
    """Unique ID for this event. Except in `propogate`, this `_id` should not be set."""

    # Shared class-level counter. Declared as a ClassVar so that `dataclass`
    # does not turn it into a per-instance field (which would leak into
    # `fields()`, `asdict()` and the generated `__eq__`).
    _id_counter: ClassVar[int] = 0

    def __post_init__(self):
        # Only events created without an explicit `_id` draw a fresh one;
        # events produced by `propogate` keep the id of the event they came from.
        if self._id is None:
            self._id = Event._id_counter
            Event._id_counter += 1

    def propogate(self, key_stack, args, kwargs) -> list['Event']:
        """Propagate this event through the Dataflow.

        Args:
            key_stack: The key stack after the target has handled the event.
                Its top element is either a single key or a list with one key
                per outgoing neighbor. The list passed in is not modified.
            args: Positional arguments for the follow-up event(s).
            kwargs: Keyword arguments for the follow-up event(s).

        Returns:
            `[self]`, with `args`/`kwargs` updated in place, when the event
            cannot travel further (no dataflow, empty `key_stack`, or the
            target has no outgoing neighbors). Otherwise one new `Event` per
            (neighbor, key) pair, each sharing this event's `_id`.
        """
        targets = []
        if self.dataflow is not None and key_stack:
            targets = self.dataflow.get_neighbors(self.target)

        if not targets:
            # End of the line: this event itself carries the result.
            self.args = args
            self.kwargs = kwargs
            return [self]

        # An event with multiple targets should have the same number of keys
        # in a list on top of its key stack. Split off the top without
        # mutating the caller's list, which is frequently `self.key_stack`.
        *remaining, top = key_stack
        keys = top if isinstance(top, list) else [top]

        # NOTE: `zip` stops at the shorter of `targets` and `keys`, so a
        # mismatch between the two silently yields fewer events.
        return [
            Event(target, remaining + [key], args, kwargs, self.dataflow, _id=self._id)
            for target, key in zip(targets, keys)
        ]
99 changes: 99 additions & 0 deletions src/cascade/dataflow/operator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from typing import Any, Generic, Protocol, Type, TypeVar
from cascade.dataflow.dataflow import InvokeMethod

# Type variable standing for the user-defined class wrapped by a `StatefulOperator`.
T = TypeVar('T')


class MethodCall(Generic[T], Protocol):
    """Structural type describing the signature of a compiled method.

    Any function of the following shape satisfies it:

    ```py
    def my_compiled_method(*args: Any, state: T, key_stack: list[str], **kwargs: Any) -> Any:
        ...
    ```

    `T` is the Python class whose instance is handed in as `state`; the
    function may read and modify it. The `key_stack` is expected to be updated
    by the function as well: a terminal function should pop a key off the
    `key_stack`, whereas a function that calls other functions should push the
    correct key(s) onto it.

    NOTE(review): an earlier description of this protocol spoke of a returned
    `(value, state, key_stack)` tuple, while the declared return type is `Any`
    - confirm which contract is intended.
    """

    def __call__(self, *args: Any, state: T, key_stack: list[str], **kwargs: Any) -> Any:
        """@private"""
        ...


class StatefulOperator(Generic[T]):
    """Wrapper around a user-defined Python class and its compiled methods.

    A StatefulOperator handles incoming events, such as
    `cascade.dataflow.dataflow.InitClass` and `cascade.dataflow.dataflow.InvokeMethod`.
    It is built from a class `cls` and a collection of `methods` that maps a
    method identifier (str) to a Python function.

    Importantly, these functions are "stateless" in the sense that they are
    not methods; they read and modify the underlying class `T` through a
    state variable instead, see `handle_invoke_method`.
    """

    def __init__(self, cls: Type[T], methods: dict[str, MethodCall[T]]):
        """Build the operator from a class and the compiled versions of its methods.

        Typically, a class could be comprised of split and non-split methods.
        Take the following example:

        ```py
        class User:
            def __init__(self, key: str, balance: int):
                self.key = key
                self.balance = balance

            def get_balance(self) -> int:
                return self.balance

            def buy_item(self, item: Item) -> bool:
                self.balance -= item.get_price()
                return self.balance >= 0
        ```

        Here, the class could be turned into a StatefulOperator as follows:

        ```py
        def user_get_balance(*, state: User, key_stack: list[str]):
            key_stack.pop()
            return state.balance

        def user_buy_item_0(item_key: str, *, state: User, key_stack: list[str]):
            key_stack.append(item_key)

        def user_buy_item_1(item_get_price: int, *, state: User, key_stack: list[str]):
            state.balance -= item_get_price
            return state.balance >= 0

        op = StatefulOperator(
            User,
            {
                "buy_item": user_buy_item_0,
                "get_balance": user_get_balance,
                "buy_item_1": user_buy_item_1
            })
        ```
        """
        self._cls = cls
        # Lookup table from method identifier to compiled function; ideally
        # built once and then reused for every invocation.
        self._methods = methods

    def handle_init_class(self, *args, **kwargs) -> T:
        """Instantiate the wrapped class, forwarding all arguments to its constructor."""
        instance = self._cls(*args, **kwargs)
        return instance

    def handle_invoke_method(self, method: InvokeMethod, *args, state: T, key_stack: list[str], **kwargs) -> tuple[Any, T, list[str]]:
        """Run the compiled function registered under `method.method_name`.

        The `cascade.dataflow.dataflow.InvokeMethod` object must carry an
        identifier that exists among the compiled functions; an unknown
        identifier surfaces as a `KeyError` from the lookup. `state` and
        `key_stack` are handed to the function as keyword arguments and may be
        modified by it. Whatever the function returns is returned unchanged.
        """
        compiled = self._methods[method.method_name]
        return compiled(*args, state=state, key_stack=key_stack, **kwargs)

111 changes: 111 additions & 0 deletions src/cascade/dataflow/test_dataflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
from typing import Any
from cascade.dataflow.dataflow import DataFlow, Edge, Event, InvokeMethod, MergeNode, OpNode

class DummyUser:
    """Minimal user entity for the tests: a key and a balance."""

    def __init__(self, key: str, balance: int):
        self.key: str = key
        self.balance: int = balance

    def buy_item(self, item: 'DummyItem') -> bool:
        """Deduct the item's price from the balance and report whether it stayed non-negative."""
        # The price is bound to a local first, mirroring the SSA form that the
        # split (compiled) variants of this method are derived from.
        price = item.get_price()
        self.balance = self.balance - price
        return self.balance >= 0

def buy_item_0_compiled(item_key, *, state: DummyUser, key_stack: list[str]) -> Any:
    """First half of the split `DummyUser.buy_item`.

    Pushes `item_key` onto `key_stack` and hands the same key back to the
    caller. `state` is accepted for signature compatibility and is not used.
    """
    key_stack.append(item_key)
    return item_key

def buy_item_1_compiled(item_price: int, *, state: DummyUser, key_stack: list[str]) -> Any:
    """Second half of the split `DummyUser.buy_item`.

    Pops the top key off `key_stack`, charges `item_price` to `state.balance`
    and returns whether the balance is still non-negative.
    """
    key_stack.pop()
    state.balance = state.balance - item_price
    return state.balance >= 0

class DummyItem:
    """Minimal item entity for the tests: a key and a price."""

    def __init__(self, key: str, price: int):
        self.key: str = key
        self.price: int = price

    def get_price(self) -> int:
        """Return this item's price."""
        return self.price

def get_price_compiled(*args, state: DummyItem, key_stack: list[str]) -> Any:
    """Compiled form of `DummyItem.get_price`.

    Positional arguments are accepted but ignored. As this is a final
    function, it pops the top key off `key_stack` before returning
    `state.price`.
    """
    key_stack.pop()
    price = state.price
    return price

################## TESTS #######################

# Module-level fixture instances passed as `state` to the compiled methods in the tests below.
user = DummyUser("user", 100)
item = DummyItem("fork", 5)

def test_simple_df_propogation():
    """Walk one event by hand through the linear flow buy_item_0 -> get_price -> buy_item_1.

    After every step the compiled method is called manually and the event is
    propagated, checking the next target and the shape of the key stack.
    """
    df = DataFlow("user.buy_item")
    n1 = OpNode(DummyUser, InvokeMethod("buy_item_0"))
    n2 = OpNode(DummyItem, InvokeMethod("get_price"))
    n3 = OpNode(DummyUser, InvokeMethod("buy_item_1"))
    df.add_edge(Edge(n1, n2))
    df.add_edge(Edge(n2, n3))

    event = Event(n1, ["user"], ["fork"], {}, df)

    # Manually propogate
    item_key = buy_item_0_compiled(event.args, state=user, key_stack=event.key_stack)
    next_event = event.propogate(event.key_stack, item_key, None)

    assert len(next_event) == 1
    assert isinstance(next_event[0].target, OpNode)
    assert next_event[0].target.cls == DummyItem
    assert next_event[0].key_stack == ["user", "fork"]
    event = next_event[0]

    item_price = get_price_compiled(event.args, state=item, key_stack=event.key_stack)
    next_event = event.propogate(event.key_stack, item_price, None)

    assert len(next_event) == 1
    assert isinstance(next_event[0].target, OpNode)
    assert next_event[0].target.cls == DummyUser
    assert next_event[0].key_stack == ["user"]
    event = next_event[0]

    positive_balance = buy_item_1_compiled(event.args, state=user, key_stack=event.key_stack)
    # The result of the final step used to be computed but never checked.
    assert positive_balance
    next_event = event.propogate(event.key_stack, None, None)
    assert next_event[0].key_stack == []


def test_merge_df_propogation():
    """Propagate an event through a fan-out / merge flow and check every hop's target."""
    df = DataFlow("user.buy_2_items")
    n0 = OpNode(DummyUser, InvokeMethod("buy_2_items_0"))
    n1 = OpNode(DummyItem, InvokeMethod("get_price"))
    n2 = OpNode(DummyItem, InvokeMethod("get_price"))
    n3 = MergeNode()
    n4 = OpNode(DummyUser, InvokeMethod("buy_2_items_1"))
    for source, sink in ((n0, n1), (n0, n2), (n1, n3), (n2, n3), (n3, n4)):
        df.add_edge(Edge(source, sink))

    # User with key "foo" buys items with keys "fork" and "spoon"
    event = Event(n0, ["foo"], ["fork", "spoon"], {}, df)

    # Propogate the event (without actually doing any calculation)
    # Normally, the key_stack should've been updated by the runtime here:
    key_stack = ["foo", ["fork", "spoon"]]
    fanned_out = event.propogate(key_stack, None, None)

    assert len(fanned_out) == 2
    event1, event2 = fanned_out
    assert isinstance(event1.target, OpNode)
    assert isinstance(event2.target, OpNode)
    assert event1.target.cls == DummyItem
    assert event2.target.cls == DummyItem

    # Both item events must arrive at the merge node.
    merged = event1.propogate(event1.key_stack, None, None)
    assert len(merged) == 1
    assert isinstance(merged[0].target, MergeNode)

    merged = event2.propogate(event2.key_stack, None, None)
    assert len(merged) == 1
    assert isinstance(merged[0].target, MergeNode)

    # From the merge node the event continues to the second half of the user method.
    at_merge = merged[0]
    final_event = at_merge.propogate(at_merge.key_stack, None, None)
    assert final_event[0].target == n4
Loading

0 comments on commit 71656fb

Please sign in to comment.