Implement stateful functions for StreamingDataFrames
- Set default state dir to `state`
- Add `stateful` parameter to `StreamingDataFrame.apply()`
- Add `TransactionState` class that implements a very limited key-value storage interface to be passed to stateful functions
daniil-quix committed Oct 25, 2023
1 parent 53ba93b commit ac5dab5
Showing 11 changed files with 264 additions and 97 deletions.
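Taken together, these changes mean a function passed to `StreamingDataFrame.apply()` can receive a second `State` argument when `stateful=True` is set. A minimal usage sketch follows; the broker address, topic, consumer group, message fields, and counting logic are illustrative assumptions, and the `Application`/`Topic`/`run()` signatures are assumed from the surrounding codebase rather than shown in this diff:

```python
from streamingdataframes.app import Application
from streamingdataframes.models import Topic
from streamingdataframes.state import State

# Placeholder broker/group/topic names; Topic construction is simplified
# (serializer options omitted).
app = Application(broker_address="localhost:9092", consumer_group="example")
sdf = app.dataframe(topics_in=[Topic("sensor-readings")])


def count_events(value: dict, state: State):
    # With stateful=True, state keys are prefixed with the message key,
    # so each Kafka key keeps its own counter.
    total = state.get("count", default=0) + 1
    state.set("count", total)
    value["count"] = total  # modified in place; returning None keeps the same Row


sdf = sdf.apply(count_events, stateful=True)
app.run(sdf)  # run() signature assumed, not part of this diff
```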
49 changes: 21 additions & 28 deletions src/StreamingDataFrames/streamingdataframes/app.py
@@ -1,9 +1,8 @@
import contextlib
import logging
from itertools import chain
from typing import Optional, List, Callable

from confluent_kafka import TopicPartition
from typing import Optional, List, Callable
from typing_extensions import Self

from .dataframe import StreamingDataFrame
@@ -45,7 +44,7 @@ def __init__(
partitioner: Partitioner = "murmur2",
consumer_extra_config: Optional[dict] = None,
producer_extra_config: Optional[dict] = None,
state_dir: Optional[str] = None,
state_dir: str = "state",
rocksdb_options: Optional[RocksDBOptionsType] = None,
on_consumer_error: Optional[ConsumerErrorCallback] = None,
on_processing_error: Optional[ProcessingErrorCallback] = None,
@@ -87,9 +86,8 @@ def __init__(
will be passed to `confluent_kafka.Consumer` as is.
:param producer_extra_config: A dictionary with additional options that
will be passed to `confluent_kafka.Producer` as is.
:param state_dir: path to the application state directory, optional.
It should be passed if the application uses stateful operations, otherwise
the exception will be raised.
:param state_dir: path to the application state directory.
Default - ".state".
:param rocksdb_options: RocksDB options.
If `None`, the default options will be used.
:param consumer_poll_timeout: timeout for `RowConsumer.poll()`. Default - 1.0s
@@ -132,12 +130,11 @@ def __init__(
self._on_message_processed = on_message_processed
self._quix_config_builder: Optional[QuixKafkaConfigsBuilder] = None
self._state_manager: Optional[StateStoreManager] = None
if state_dir:
self._state_manager = StateStoreManager(
group_id=consumer_group,
state_dir=state_dir,
rocksdb_options=rocksdb_options,
)
self._state_manager = StateStoreManager(
group_id=consumer_group,
state_dir=state_dir,
rocksdb_options=rocksdb_options,
)

def set_quix_config_builder(self, config_builder: QuixKafkaConfigsBuilder):
self._quix_config_builder = config_builder
@@ -152,7 +149,7 @@ def Quix(
partitioner: Partitioner = "murmur2",
consumer_extra_config: Optional[dict] = None,
producer_extra_config: Optional[dict] = None,
state_dir: Optional[str] = None,
state_dir: str = "state",
rocksdb_options: Optional[RocksDBOptionsType] = None,
on_consumer_error: Optional[ConsumerErrorCallback] = None,
on_processing_error: Optional[ProcessingErrorCallback] = None,
@@ -190,15 +187,16 @@ def Quix(
will be passed to `confluent_kafka.Consumer` as is.
:param producer_extra_config: A dictionary with additional options that
will be passed to `confluent_kafka.Producer` as is.
:param state_dir: path to the application state directory, optional.
It should be passed if the application uses stateful operations, otherwise
the exception will be raised.
:param state_dir: path to the application state directory.
Default - ".state".
:param rocksdb_options: RocksDB options.
If `None`, the default options will be used.
:param consumer_poll_timeout: timeout for `RowConsumer.poll()`. Default - 1.0s
:param producer_poll_timeout: timeout for `RowProducer.poll()`. Default - 0s.
:param on_message_processed: a callback triggered when message is successfully
processed.
:param quix_config_builder: instance of `QuixKafkaConfigsBuilder` to be used
instead of the default one.
To handle errors, `Application` accepts callbacks triggered when exceptions
occur on different stages of stream processing.
@@ -305,7 +303,7 @@ def dataframe(
to be used as input topics.
:return: `StreamingDataFrame` object
"""
sdf = StreamingDataFrame(topics_in=topics_in)
sdf = StreamingDataFrame(topics_in=topics_in, state_manager=self._state_manager)
sdf.consumer = self._consumer
sdf.producer = self._producer
return sdf
@@ -316,10 +314,6 @@ def stop(self):
"""
self._running = False

@property
def is_stateful(self) -> bool:
return bool(self._state_manager and self._state_manager.stores)

def _quix_runtime_init(self):
"""
Do some runtime setup only applicable to an Application.Quix instance
@@ -348,8 +342,7 @@ def run(
exit_stack = contextlib.ExitStack()
exit_stack.enter_context(self._producer)
exit_stack.enter_context(self._consumer)
if self.is_stateful:
exit_stack.enter_context(self._state_manager)
exit_stack.enter_context(self._state_manager)

exit_stack.callback(
lambda *_: logger.debug("Closing Kafka consumers & producers")
@@ -388,8 +381,8 @@ def run(
first_row.offset,
)

if self.is_stateful:
# Store manager has stores registered, starting a transaction
if self._state_manager.stores:
# Store manager has stores registered, starting a state transaction
state_transaction = self._state_manager.start_store_transaction(
topic=topic_name, partition=partition, offset=offset
)
@@ -430,7 +423,7 @@ def _on_assign(self, _, topic_partitions: List[TopicPartition]):
:param topic_partitions: list of `TopicPartition` from Kafka
"""
if self.is_stateful:
if self._state_manager.stores:
logger.info(f"Rebalancing: assigning state store partitions")
for tp in topic_partitions:
self._state_manager.on_partition_assign(tp)
@@ -439,7 +432,7 @@ def _on_revoke(self, _, topic_partitions: List[TopicPartition]):
"""
Revoke partitions from consumer and state
"""
if self.is_stateful:
if self._state_manager.stores:
logger.info(f"Rebalancing: revoking state store partitions")
for tp in topic_partitions:
self._state_manager.on_partition_revoke(tp)
@@ -448,7 +441,7 @@ def _on_lost(self, _, topic_partitions: List[TopicPartition]):
"""
Dropping lost partitions from consumer and state
"""
if self.is_stateful:
if self._state_manager.stores:
logger.info(f"Rebalancing: dropping lost state store partitions")
for tp in topic_partitions:
self._state_manager.on_partition_lost(tp)
63 changes: 50 additions & 13 deletions src/StreamingDataFrames/streamingdataframes/dataframe/dataframe.py
@@ -1,6 +1,12 @@
import uuid
from typing import (
Optional,
Callable,
Union,
List,
Mapping,
)

from typing import Optional, Callable, Union, List, Mapping
from typing_extensions import Self, TypeAlias

from .column import Column, OpValue
@@ -9,8 +15,12 @@
from ..models import Row, Topic
from ..rowconsumer import RowConsumerProto
from ..rowproducer import RowProducerProto
from ..state import State, StateStoreManager

RowApplier: TypeAlias = Callable[[Row], Optional[Union[Row, List[Row]]]]
ApplyFunc: TypeAlias = Callable[[dict], Optional[Union[dict, List[dict]]]]
StatefulApplyFunc: TypeAlias = Callable[
[dict, State], Optional[Union[dict, List[dict]]]
]

__all__ = ("StreamingDataFrame",)

@@ -27,17 +37,31 @@ def setitem(k: str, v: Union[Column, OpValue], row: Row) -> Row:

def apply(
row: Row,
func: Callable[[dict], Optional[Union[dict, List[dict]]]],
func: Union[ApplyFunc, StatefulApplyFunc],
expand: bool = False,
state_manager: Optional[StateStoreManager] = None,
) -> Union[Row, List[Row]]:
result = func(row.value)
if result is None and isinstance(row.value, dict): # assume edited in-place
# Providing state to the function if state_manager is passed
if state_manager is not None:
transaction = state_manager.get_store_transaction()
# Prefix all the state keys by the message key
with transaction.with_prefix(prefix=row.key):
# Pass a State object with an interface limited to the key updates only
result = func(row.value, transaction.state)
else:
result = func(row.value)

if result is None and isinstance(row.value, dict):
# Function returned None, assume it changed the incoming dict in-place
return row
if isinstance(result, dict):
# Function returned dict, assume it's a new value for the Row
row.value = result
return row
if isinstance(result, list):
if expand:
# Function returned a list and `expand=True` - treat each item in the list
# as a new Row object downstream
return [row.clone(value=r) for r in result]
raise InvalidApplyResultType(
"Returning 'list' types is not allowed unless 'expand=True' is passed"
@@ -50,7 +74,8 @@ def apply(
class StreamingDataFrame:
"""
Allows you to define transformations on a kafka message as if it were a Pandas
DataFrame. Currently implements a small subset of the Pandas interface, along with
DataFrame.
Currently, it implements a small subset of the Pandas interface, along with
some differences/accommodations for kafka-specific functionality.
A `StreamingDataFrame` expects to interact with a QuixStreams `Row`, which is
@@ -66,9 +91,9 @@ class StreamingDataFrame:
print(df.process(row_obj))
Note that just like Pandas, you can "filter" out rows with your operations, like:
```
df = df[df['column_b'] >= 5]
```
If a processing step nulls the Row in some way, all further processing on that
row (including kafka operations, besides committing) will be skipped.
@@ -102,6 +127,8 @@ class StreamingDataFrame:
def __init__(
self,
topics_in: List[Topic],
state_manager: StateStoreManager,
# TODO: Do we need these params?
_pipeline: Pipeline = None,
_id: str = None,
):
@@ -113,18 +140,29 @@ def __init__(
raise ValueError("Topic Input list cannot be empty")
self._topics_in = {t.name: t for t in topics_in}
self._topics_out = {}
self._state_manager = state_manager

def apply(
self,
func: Callable[[dict], Optional[Union[dict, List[dict]]]],
func: Union[ApplyFunc, StatefulApplyFunc],
expand: bool = False,
stateful: bool = False,
) -> Self:
"""
Apply a user-defined function that where the `Row.value` is the expected input.
Apply a custom function with `Row.value` as the expected input.
It should either return a new dict, a list of dicts (with expand=True), or
modifying a dict in-place (i.e. returning None).
The function should either return a new dict, a list of dicts (with expand=True), or
None (to modify a dict in-place)
"""
if stateful:
# Register the default store for each input topic
for topic in self._topics_in.values():
self._state_manager.register_store(topic_name=topic.name)
return self._apply(
lambda row: apply(
row, func, expand=expand, state_manager=self._state_manager
)
)
return self._apply(lambda row: apply(row, func, expand=expand))

def process(self, row: Row) -> Optional[Union[Row, List[Row]]]:
@@ -135,7 +173,6 @@ def process(self, row: Row) -> Optional[Union[Row, List[Row]]]:
"""
return self._pipeline.process(row)

# TODO: maybe we should just allow list(Topics) as well (in many spots actually)
def to_topic(self, topic: Topic):
"""
Produce a row to a desired topic.
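The module-level `apply` wrapper above interprets the user function's return value in three ways: `None` means the incoming dict was edited in place, a `dict` becomes the new `Row.value`, and a `list` is only allowed with `expand=True`, producing one downstream `Row` per item. A short sketch of the three conventions, assuming `sdf` is a `StreamingDataFrame` from `Application.dataframe()` and the field names are made up:

```python
def enrich(value: dict):
    value["seen"] = True  # edit in place; returning None keeps the same Row


def replace(value: dict) -> dict:
    return {"total": value["a"] + value["b"]}  # returned dict becomes the new Row.value


def split(value: dict) -> list:
    return [{"item": i} for i in value["items"]]  # one Row per item when expand=True


sdf = sdf.apply(enrich)
sdf = sdf.apply(replace)
sdf = sdf.apply(split, expand=True)  # a list without expand=True raises InvalidApplyResultType
```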
@@ -14,6 +14,7 @@
PartitionTransaction,
StorePartition,
)
from ..state import TransactionState
from .exceptions import (
StateTransactionError,
NestedPrefixError,
@@ -27,6 +28,7 @@
"RocksDBPartitionTransaction",
)


logger = logging.getLogger(__name__)

_sentinel = object()
@@ -288,6 +290,11 @@ def __init__(
self._completed = False
self._dumps = dumps
self._loads = loads
self._state = TransactionState(transaction=self)

@property
def state(self) -> TransactionState:
return self._state

@contextlib.contextmanager
def with_prefix(self, prefix: Any = b"") -> Self:
49 changes: 49 additions & 0 deletions src/StreamingDataFrames/streamingdataframes/state/state.py
@@ -0,0 +1,49 @@
from typing import Any, Optional

from .types import State, PartitionTransaction


class TransactionState(State):
def __init__(self, transaction: PartitionTransaction):
"""
Simple key-value state to be passed to `StreamingDataFrame` functions
:param transaction: instance of `PartitionTransaction`
"""
self._transaction = transaction

def get(self, key: Any, default: Any = None) -> Optional[Any]:
"""
Get the value for key if key is present in the state, else default
:param key: key
:param default: default value to return if the key is not found
:return: value or None if the key is not found and `default` is not provided
"""
return self._transaction.get(key=key, default=default)

def set(self, key: Any, value: Any):
"""
Set value for the key.
:param key: key
:param value: value
"""
return self._transaction.set(key=key, value=value)

def delete(self, key: Any):
"""
Delete value for the key.
This function always returns `None`, even if value is not found.
:param key: key
"""
return self._transaction.delete(key=key)

def exists(self, key: Any) -> bool:
"""
Check if the key exists in state.
:param key: key
:return: True if key exists, False otherwise
"""

return self._transaction.exists(key=key)
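Every `TransactionState` method delegates directly to the underlying `PartitionTransaction`, so the interface available inside a stateful function is exactly `get`, `set`, `delete`, and `exists`. A rough sketch exercising them; the deduplication logic and field names are illustrative, not part of this commit:

```python
from streamingdataframes.state import State


def deduplicate(value: dict, state: State):
    event_id = value["event_id"]  # assumed field name
    if state.exists(event_id):
        value["duplicate"] = True  # already seen for this message key
    else:
        state.set(event_id, True)
        value["duplicate"] = False
    if value.get("reset"):  # entries can also be removed explicitly
        state.delete(event_id)
```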
7 changes: 7 additions & 0 deletions src/StreamingDataFrames/streamingdataframes/state/types.py
@@ -135,6 +135,13 @@ class PartitionTransaction(State):
"get", "set", "delete" and "exists" on a single storage partition.
"""

@property
def state(self) -> State:
"""
An instance of State to be provided to `StreamingDataFrame` functions
:return:
"""

@property
def failed(self) -> bool:
"""
1 change: 0 additions & 1 deletion src/StreamingDataFrames/tests/conftest.py
@@ -13,7 +13,6 @@
"tests.test_dataframes.test_models.fixtures",
"tests.test_dataframes.test_platforms.test_quix.fixtures",
"tests.test_dataframes.test_state.test_rocksdb.fixtures",
"tests.test_dataframes.test_state.fixtures",
]

KafkaContainer = namedtuple("KafkaContainer", ("broker_address",))