From 71656fbbc3d5b7d48cd32d256fe89fb2934f8de0 Mon Sep 17 00:00:00 2001
From: lucasvanmol
Date: Mon, 9 Dec 2024 17:51:10 +0100
Subject: [PATCH] Add Dataflows + tests

---
 README.md                             |   2 +-
 src/cascade/dataflow/dataflow.py      | 171 +++++++++++++++++++++++++-
 src/cascade/dataflow/operator.py      |  99 +++++++++++++++
 src/cascade/dataflow/test_dataflow.py | 111 +++++++++++++++++
 src/cascade/runtime/flink_runtime.py  |  90 ++++++++++++--
 5 files changed, 463 insertions(+), 10 deletions(-)
 create mode 100644 src/cascade/dataflow/operator.py
 create mode 100644 src/cascade/dataflow/test_dataflow.py

diff --git a/README.md b/README.md
index 1be047a..4be243e 100644
--- a/README.md
+++ b/README.md
@@ -14,7 +14,7 @@ Documentation is done with [Google style docstrings](https://google.github.io/st
 It can be generated using [pdoc](https://pdoc.dev/docs/pdoc.html):
 
 ```
-pdoc src/cascade
+pdoc --mermaid src/cascade
 ```
 
diff --git a/src/cascade/dataflow/dataflow.py b/src/cascade/dataflow/dataflow.py
index b801a24..5ad3d57 100644
--- a/src/cascade/dataflow/dataflow.py
+++ b/src/cascade/dataflow/dataflow.py
@@ -1,2 +1,169 @@
-class Dataflow:
-    pass
\ No newline at end of file
+from abc import ABC
+from dataclasses import dataclass, field
+from typing import Any, List, Optional, Union
+
+
+@dataclass
+class InitClass:
+    """A method type corresponding to an `__init__` call."""
+    pass
+
+@dataclass
+class InvokeMethod:
+    """A method invocation of the underlying method identifier."""
+    method_name: str
+
+@dataclass
+class Node(ABC):
+    """Base class for Nodes."""
+    id: int = field(init=False)
+    """This node's unique id."""
+
+    _id_counter: int = field(init=False, default=0, repr=False)
+
+    def __post_init__(self):
+        # Assign a unique ID from the class-level counter
+        self.id = Node._id_counter
+        Node._id_counter += 1
+
+@dataclass
+class OpNode(Node):
+    """A node in a `Dataflow` corresponding to a method call of a `StatefulOperator`.
+
+    A `Dataflow` may reference the same `StatefulOperator` multiple times.
+    The `StatefulOperator` that this node belongs to is referenced by `cls`."""
+    cls: Any
+    method_type: Union[InitClass, InvokeMethod]
+
+@dataclass
+class MergeNode(Node):
+    """A node in a `Dataflow` corresponding to a merge operator.
+
+    It will aggregate incoming edges and output them as a list to the outgoing edge.
+    Their actual implementation is runtime-dependent."""
+    pass
+
+@dataclass
+class Edge():
+    """An Edge in the Dataflow graph."""
+    from_node: Node
+    to_node: Node
+
+class DataFlow:
+    """A Dataflow is a graph consisting of `OpNode`s, `MergeNode`s, and `Edge`s.
+
+    Example Usage
+    -------------
+
+    Consider two entities, `User` and `Item`, and a method `User.buy_items(item1, item2)`.
+    The resulting method could be compiled into the following Dataflow graph.
+
+    ```mermaid
+    flowchart TD;
+    user1[User.buy_items_0]
+    item1[Item.get_price]
+    item2[Item.get_price]
+    user2[User.buy_items_1]
+    merge{Merge}
+    user1-- item1_key -->item1;
+    user1-- item2_key -->item2;
+    item1-- item1_price -->merge;
+    item2-- item2_price -->merge;
+    merge-- [item1_price, item2_price] -->user2;
+    ```
+
+    In code, one would write:
+
+    ```py
+    df = DataFlow("user.buy_items")
+    n0 = OpNode(User, InvokeMethod("buy_items_0"))
+    n1 = OpNode(Item, InvokeMethod("get_price"))
+    n2 = OpNode(Item, InvokeMethod("get_price"))
+    n3 = MergeNode()
+    n4 = OpNode(User, InvokeMethod("buy_items_1"))
+    df.add_edge(Edge(n0, n1))
+    df.add_edge(Edge(n0, n2))
+    df.add_edge(Edge(n1, n3))
+    df.add_edge(Edge(n2, n3))
+    df.add_edge(Edge(n3, n4))
+    ```
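+
+    The graph can then be traversed with `get_neighbors`. A small usage sketch, reusing
+    the nodes defined above:
+
+    ```py
+    assert df.get_neighbors(n0) == [n1, n2]
+    assert df.get_neighbors(n3) == [n4]
+    ```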
+    """
+    def __init__(self, name):
+        self.name = name
+        self.adjacency_list = {}
+        self.nodes = {}
+
+    def add_node(self, node: Node):
+        """Add a node to the Dataflow graph if it doesn't already exist."""
+        if node.id not in self.adjacency_list:
+            self.adjacency_list[node.id] = []
+            self.nodes[node.id] = node
+
+    def add_edge(self, edge: Edge):
+        """Add an edge to the Dataflow graph. Nodes that don't exist yet will be added to the graph automatically."""
+        self.add_node(edge.from_node)
+        self.add_node(edge.to_node)
+        self.adjacency_list[edge.from_node.id].append(edge.to_node.id)
+
+    def get_neighbors(self, node: Node) -> List[Node]:
+        """Get the outgoing neighbors of this `Node`."""
+        return [self.nodes[id] for id in self.adjacency_list.get(node.id, [])]
+
+@dataclass
+class Event():
+    """An Event is an object that travels through the Dataflow graph."""
+
+    target: 'Node'
+    """The Node that this Event wants to go to."""
+
+    key_stack: list[str]
+    """The keys this event is concerned with.
+    The top of the stack, i.e. `key_stack[-1]`, should always correspond to a key
+    on the StatefulOperator of `target.cls` if `target` is an `OpNode`."""
+
+    args: List[Any]
+    kwargs: dict[str, Any]
+    """The args and kwargs to be passed to the `target`.
+    If `target` is an `OpNode` this corresponds to the method args/kwargs."""
+
+    dataflow: Optional['DataFlow']
+    """The Dataflow that this event is a part of. If None, the event won't propagate.
+    This might be removed in the future in favour of a routing operator."""
+
+    _id: int = field(default=None)
+    """Unique ID for this event. Except in `propogate`, this `_id` should not be set."""
+
+    _id_counter: int = field(init=False, default=0, repr=False)
+
+    def __post_init__(self):
+        if self._id is None:
+            # Assign a unique ID from the class-level counter
+            self._id = Event._id_counter
+            Event._id_counter += 1
+
+    def propogate(self, key_stack, args, kwargs) -> list['Event']:
+        """Propagate this event through the Dataflow.
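+
+        A minimal sketch of how propagation is driven (using the hypothetical `User` and
+        `Item` operators from the `DataFlow` example above; the caller is expected to have
+        pushed the next key onto the key stack before calling this):
+
+        ```py
+        df = DataFlow("user.buy_item")
+        n1 = OpNode(User, InvokeMethod("buy_item_0"))
+        n2 = OpNode(Item, InvokeMethod("get_price"))
+        df.add_edge(Edge(n1, n2))
+
+        event = Event(n1, ["user_key"], [], {}, df)
+        # the compiled method for n1 pushes the item's key onto the key stack
+        next_events = event.propogate(["user_key", "item_key"], [], {})
+        assert next_events[0].target == n2
+        assert next_events[0].key_stack == ["user_key", "item_key"]
+        ```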
+        """
+        if self.dataflow is None or len(key_stack) == 0:
+            self.args = args
+            self.kwargs = kwargs
+            return [self]
+
+        targets = self.dataflow.get_neighbors(self.target)
+
+        if len(targets) == 0:
+            self.args = args
+            self.kwargs = kwargs
+            return [self]
+        else:
+            # An event with multiple targets should have the same number of keys in a list on top of its key stack
+            keys = key_stack.pop()
+            if not isinstance(keys, list):
+                keys = [keys]
+            return [Event(
+                        target,
+                        key_stack + [key],
+                        args,
+                        kwargs,
+                        self.dataflow,
+                        _id=self._id)
+
+                    for target, key in zip(targets, keys)]
\ No newline at end of file
diff --git a/src/cascade/dataflow/operator.py b/src/cascade/dataflow/operator.py
new file mode 100644
index 0000000..8bd3449
--- /dev/null
+++ b/src/cascade/dataflow/operator.py
@@ -0,0 +1,99 @@
+from typing import Any, Generic, Protocol, Type, TypeVar
+from cascade.dataflow.dataflow import InvokeMethod
+
+T = TypeVar('T')
+
+
+class MethodCall(Generic[T], Protocol):
+    """A helper class for type-safety of the method signature of compiled methods.
+
+    It corresponds to functions with the following signature:
+
+    ```py
+    def my_compiled_method(*args: Any, state: T, key_stack: list[str], **kwargs: Any) -> Any:
+        ...
+    ```
+
+    `T` corresponds to a Python class which, if modified, should be returned as the second item in the tuple.
+
+    The first item in the returned tuple corresponds to the actual return value of the function.
+
+    The third item in the tuple corresponds to the `key_stack`, which should be updated accordingly.
+    Notably, a terminal function should pop a key off the `key_stack`, whereas a function that calls
+    other functions should push the correct key(s) onto the `key_stack`.
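+
+    For example, a terminal compiled function could look like the sketch below (`Item` is a
+    hypothetical user class with a `price` attribute; see `StatefulOperator` for a fuller example):
+
+    ```py
+    def item_get_price(*args: Any, state: Item, key_stack: list[str], **kwargs: Any) -> Any:
+        key_stack.pop()     # terminal function: pop a key instead of pushing new ones
+        return state.price
+    ```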
+    """
+
+    def __call__(self, *args: Any, state: T, key_stack: list[str], **kwargs: Any) -> Any: ...
+    """@private"""
+
+
+class StatefulOperator(Generic[T]):
+    """An abstraction for a user-defined Python class.
+
+    A StatefulOperator handles incoming events, such as `cascade.dataflow.dataflow.InitClass` and `cascade.dataflow.dataflow.InvokeMethod`.
+    It is created using a class `cls` and a collection of `methods`.
+
+    These methods map a method identifier (str) to a python function.
+    Importantly, these functions are "stateless" in the sense that they are not methods,
+    instead reading and modifying the underlying class `T` through a state variable, see `handle_invoke_method`.
+    """
+    def __init__(self, cls: Type[T], methods: dict[str, MethodCall[T]]):
+        """Create the StatefulOperator from a class and its compiled methods.
+
+        Typically, a class is composed of split and non-split methods. Take the following example:
+
+        ```py
+        class User:
+            def __init__(self, key: str, balance: int):
+                self.key = key
+                self.balance = balance
+
+            def get_balance(self) -> int:
+                return self.balance
+
+            def buy_item(self, item: Item) -> bool:
+                self.balance -= item.get_price()
+                return self.balance >= 0
+        ```
+
+        Here, the class could be turned into a StatefulOperator as follows:
+
+        ```py
+        def user_get_balance(*, state: User, key_stack: list[str]):
+            key_stack.pop()
+            return state.balance
+
+        def user_buy_item_0(item_key: str, *, state: User, key_stack: list[str]):
+            key_stack.append(item_key)
+
+        def user_buy_item_1(item_get_price: int, *, state: User, key_stack: list[str]):
+            state.balance -= item_get_price
+            return state.balance >= 0
+
+        op = StatefulOperator(
+            User,
+            {
+                "buy_item": user_buy_item_0,
+                "get_balance": user_get_balance,
+                "buy_item_1": user_buy_item_1
+            })
+        ```
+        """
+        # methods maps function name to a function. Ideally this is done once in the object
+        self._methods = methods
+        self._cls = cls
+
+
+    def handle_init_class(self, *args, **kwargs) -> T:
+        """Create an instance of the underlying class. Equivalent to `T.__init__(*args, **kwargs)`."""
+        return self._cls(*args, **kwargs)
+
+    def handle_invoke_method(self, method: InvokeMethod, *args, state: T, key_stack: list[str], **kwargs) -> tuple[Any, T, list[str]]:
+        """Invoke the method of the underlying class.
+
+        The `cascade.dataflow.dataflow.InvokeMethod` object must contain a method identifier
+        that exists in the underlying compiled class functions.
+
+        The state `T` and `key_stack` are passed along to the function, and may be modified.
+        """
+        return self._methods[method.method_name](*args, state=state, key_stack=key_stack, **kwargs)
+
\ No newline at end of file
diff --git a/src/cascade/dataflow/test_dataflow.py b/src/cascade/dataflow/test_dataflow.py
new file mode 100644
index 0000000..556e8f0
--- /dev/null
+++ b/src/cascade/dataflow/test_dataflow.py
@@ -0,0 +1,111 @@
+from typing import Any
+from cascade.dataflow.dataflow import DataFlow, Edge, Event, InvokeMethod, MergeNode, OpNode
+
+class DummyUser:
+    def __init__(self, key: str, balance: int):
+        self.key: str = key
+        self.balance: int = balance
+
+    def buy_item(self, item: 'DummyItem') -> bool:
+        item_price = item.get_price()  # SSA
+        self.balance -= item_price
+        return self.balance >= 0
+
+def buy_item_0_compiled(item_key, *, state: DummyUser, key_stack: list[str]) -> Any:
+    key_stack.append(item_key)
+    return item_key
+
+def buy_item_1_compiled(item_price: int, *, state: DummyUser, key_stack: list[str]) -> Any:
+    key_stack.pop()
+    state.balance -= item_price
+    return state.balance >= 0
+
+class DummyItem:
+    def __init__(self, key: str, price: int):
+        self.key: str = key
+        self.price: int = price
+
+    def get_price(self) -> int:
+        return self.price
+
+def get_price_compiled(*args, state: DummyItem, key_stack: list[str]) -> Any:
+    key_stack.pop()  # final function
+    return state.price
+
+################## TESTS #######################
+
+user = DummyUser("user", 100)
+item = DummyItem("fork", 5)
+
+def test_simple_df_propogation():
+    df = DataFlow("user.buy_item")
+    n1 = OpNode(DummyUser, InvokeMethod("buy_item_0"))
+    n2 = OpNode(DummyItem, InvokeMethod("get_price"))
+    n3 = OpNode(DummyUser, InvokeMethod("buy_item_1"))
+    df.add_edge(Edge(n1, n2))
+    df.add_edge(Edge(n2, n3))
+
+    event = Event(n1, ["user"], ["fork"], {}, df)
+
+    # Manually propagate
+    item_key = buy_item_0_compiled(event.args, state=user, key_stack=event.key_stack)
+    next_event = event.propogate(event.key_stack, item_key, None)
+
+    assert len(next_event) == 1
+    assert isinstance(next_event[0].target, OpNode)
+    assert next_event[0].target.cls == DummyItem
+    assert next_event[0].key_stack == ["user", "fork"]
+    event = next_event[0]
+
+    item_price = get_price_compiled(event.args, state=item, key_stack=event.key_stack)
+    next_event = event.propogate(event.key_stack, item_price, None)
+
+    assert len(next_event) == 1
+    assert isinstance(next_event[0].target, OpNode)
+    assert next_event[0].target.cls == DummyUser
+    assert next_event[0].key_stack == ["user"]
+    event = next_event[0]
+
+    positive_balance = buy_item_1_compiled(event.args, state=user, key_stack=event.key_stack)
+    next_event = event.propogate(event.key_stack, None, None)
+    assert next_event[0].key_stack == []
+
+
+def test_merge_df_propogation():
+    df = DataFlow("user.buy_2_items")
+    n0 = OpNode(DummyUser, InvokeMethod("buy_2_items_0"))
+    n1 = OpNode(DummyItem, InvokeMethod("get_price"))
+    n2 = OpNode(DummyItem, InvokeMethod("get_price"))
+    n3 = MergeNode()
+    n4 = OpNode(DummyUser, InvokeMethod("buy_2_items_1"))
+    df.add_edge(Edge(n0, n1))
+    df.add_edge(Edge(n0, n2))
+    df.add_edge(Edge(n1, n3))
+    df.add_edge(Edge(n2, n3))
+    df.add_edge(Edge(n3, n4))
+
+    # User with key "foo" buys items with keys "fork" and "spoon"
+    event = Event(n0, ["foo"], ["fork", "spoon"], {}, df)
+
+    # Propagate the event (without actually doing any calculation)
+    # Normally, the key_stack should've been updated by the runtime here:
+    key_stack = ["foo", ["fork", "spoon"]]
+    next_event = event.propogate(key_stack, None, None)
+
+    assert len(next_event) == 2
+    assert isinstance(next_event[0].target, OpNode)
+    assert isinstance(next_event[1].target, OpNode)
+    assert next_event[0].target.cls == DummyItem
+    assert next_event[1].target.cls == DummyItem
+
+    event1, event2 = next_event
+    next_event = event1.propogate(event1.key_stack, None, None)
+    assert len(next_event) == 1
+    assert isinstance(next_event[0].target, MergeNode)
+
+    next_event = event2.propogate(event2.key_stack, None, None)
+    assert len(next_event) == 1
+    assert isinstance(next_event[0].target, MergeNode)
+
+    final_event = next_event[0].propogate(next_event[0].key_stack, None, None)
+    assert final_event[0].target == n4
diff --git a/src/cascade/runtime/flink_runtime.py b/src/cascade/runtime/flink_runtime.py
index 40e2c2b..21579da 100644
--- a/src/cascade/runtime/flink_runtime.py
+++ b/src/cascade/runtime/flink_runtime.py
@@ -1,11 +1,78 @@
 import os
 from pyflink.common.typeinfo import Types, get_gateway
 from pyflink.common import Configuration, DeserializationSchema, SerializationSchema
-from pyflink.datastream.functions import KeyedProcessFunction, RuntimeContext, ValueState
+from pyflink.datastream.functions import KeyedProcessFunction, RuntimeContext, ValueState, ValueStateDescriptor
 from pyflink.datastream.connectors.kafka import FlinkKafkaConsumer
 from pyflink.datastream import StreamExecutionEnvironment
 import pickle
+from cascade.dataflow.dataflow import Event, InitClass, InvokeMethod, OpNode
+from cascade.dataflow.operator import StatefulOperator
 from confluent_kafka import Producer
+import logging
+
+logger = logging.getLogger(__name__)
+logger.setLevel(level=logging.DEBUG)
+console_handler = logging.StreamHandler()
+formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+console_handler.setFormatter(formatter)
+logger.addHandler(console_handler)
+
+class FlinkOperator(KeyedProcessFunction):
+    """Wraps a `cascade.dataflow.operator.StatefulOperator` in a KeyedProcessFunction so that it can run in Flink.
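+
+    A rough wiring sketch (the `events` datastream and `user_op` operator here are
+    hypothetical; the surrounding pipeline setup is omitted):
+
+    ```py
+    user_stream = events.key_by(lambda e: e.key_stack[-1])
+    results = user_stream.process(FlinkOperator(user_op))
+    ```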
+    """
+    def __init__(self, operator: StatefulOperator) -> None:
+        self.state: ValueState = None  # type: ignore (expect state to be initialised on .open())
+        self.operator = operator
+
+
+    def open(self, runtime_context: RuntimeContext):
+        descriptor = ValueStateDescriptor("state", Types.PICKLED_BYTE_ARRAY())
+        self.state: ValueState = runtime_context.get_state(descriptor)
+
+    def process_element(self, event: Event, ctx: KeyedProcessFunction.Context):
+        key_stack = event.key_stack
+        assert(isinstance(event.target, OpNode))  # should be handled by filters on this FlinkOperator
+        logger.debug(f"FlinkOperator {event.target.cls.__name__}[{ctx.get_current_key()}]: Processing: {event}")
+        if isinstance(event.target.method_type, InitClass):
+            result = self.operator.handle_init_class(*event.args, **event.kwargs)
+            # Pop this key from the key stack so that we exit
+            key_stack.pop()
+            self.state.update(pickle.dumps(result))
+        elif isinstance(event.target.method_type, InvokeMethod):
+            state = pickle.loads(self.state.value())
+            result = self.operator.handle_invoke_method(event.target.method_type, *event.args, state=state, key_stack=key_stack, **event.kwargs)
+
+            # TODO: check if state actually needs to be updated
+            if state is not None:
+                self.state.update(pickle.dumps(state))
+
+        new_events = event.propogate(key_stack, [result], {})
+        logger.debug(f"FlinkOperator {event.target.cls.__name__}[{ctx.get_current_key()}]: Propagated {len(new_events)} new Events")
+        yield from new_events
+
+
+class FlinkMergeOperator(KeyedProcessFunction):
+    """Flink implementation of a merge operator."""
+    def __init__(self) -> None:
+        self.other: ValueState = None  # type: ignore (expect state to be initialised on .open())
+
+    def open(self, runtime_context: RuntimeContext):
+        descriptor = ValueStateDescriptor("merge_state", Types.PICKLED_BYTE_ARRAY())
+        self.other = runtime_context.get_state(descriptor)
+
+    def process_element(self, event: Event, ctx: KeyedProcessFunction.Context):
+        other_args = self.other.value()
+        logger.debug(f"FlinkMergeOp [{ctx.get_current_key()}]: Processing: {event}")
+        if other_args is None:
+            logger.debug(f"FlinkMergeOp [{ctx.get_current_key()}]: Saving merge value: {event.args}")
+            self.other.update(event.args)
+        else:
+            self.other.clear()
+            merged_args = [*event.args, *other_args]
+            logger.debug(f"FlinkMergeOp [{ctx.get_current_key()}]: Yielding merge values: {merged_args}")
+            new_event = event.propogate(event.key_stack, merged_args, {})
+            yield from new_event
+
 
 class ByteSerializer(SerializationSchema, DeserializationSchema):
     """A custom serializer which maps bytes to bytes.
@@ -50,29 +117,38 @@ def __init__(self):
             self, j_deserialization_schema=j_byte_string_schema
         )
 
-class FlinkOperator(KeyedProcessFunction):
-    pass
-
 INPUT_TOPIC = "input-topic"
 """@private"""
 
 class FlinkRuntime():
+    """A Runtime that runs Dataflows on Flink."""
     def __init__(self):
         self.env: StreamExecutionEnvironment = None
         self.producer: Producer = None
 
-    def _initialise(self, kafka_broker="localhost:9092"):
+    def _initialise(self, kafka_broker="localhost:9092", bundle_time=1, bundle_size=5):
         config = Configuration()
         # Add the Flink Web UI at http://localhost:8081
         config.set_string("rest.port", "8081")
-        self.env = StreamExecutionEnvironment.get_execution_environment()
+
+        # Sets the waiting timeout (in milliseconds) before processing a bundle for Python user-defined function execution.
+        # The timeout defines how long the elements of a bundle will be buffered before being processed.
+        # Lower timeouts lead to lower tail latencies, but may affect throughput.
+        config.set_integer("python.fn-execution.bundle.time", bundle_time)
+
+        # The maximum number of elements to include in a bundle for Python user-defined function execution.
+        # The elements are processed asynchronously: one bundle of elements is processed before the next bundle is processed.
+        # A larger value can improve throughput, but at the cost of more memory usage and higher latency.
+        config.set_integer("python.fn-execution.bundle.size", bundle_size)
+
+        self.env = StreamExecutionEnvironment.get_execution_environment(config)
         kafka_jar = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'bin/flink-sql-connector-kafka-3.3.0-1.20.jar')
         serializer_jar = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'bin/flink-kafka-bytes-serializer.jar')
         if os.name == 'nt':
-            self.env.add_jars(f"file:///{kafka_jar}",f"file://{serializer_jar}")
+            self.env.add_jars(f"file:///{kafka_jar}",f"file:///{serializer_jar}")
         else:
             self.env.add_jars(f"file://{kafka_jar}",f"file://{serializer_jar}")