diff --git a/src/StreamingDataFrames/streamingdataframes/app.py b/src/StreamingDataFrames/streamingdataframes/app.py index e4da34d09..46ff617d8 100644 --- a/src/StreamingDataFrames/streamingdataframes/app.py +++ b/src/StreamingDataFrames/streamingdataframes/app.py @@ -17,9 +17,12 @@ from .platforms.quix import QuixKafkaConfigsBuilder from .rowconsumer import RowConsumer from .rowproducer import RowProducer +from .state import StateStoreManager +from .state.rocksdb import RocksDBOptionsType __all__ = ("Application",) + logger = logging.getLogger(__name__) MessageProcessedCallback = Callable[[str, int, int], None] @@ -35,6 +38,8 @@ def __init__( partitioner: Partitioner = "murmur2", consumer_extra_config: Optional[dict] = None, producer_extra_config: Optional[dict] = None, + state_dir: Optional[str] = None, + rocksdb_options: Optional[RocksDBOptionsType] = None, on_consumer_error: Optional[ConsumerErrorCallback] = None, on_processing_error: Optional[ProcessingErrorCallback] = None, on_producer_error: Optional[ProducerErrorCallback] = None, @@ -75,6 +80,11 @@ def __init__( will be passed to `confluent_kafka.Consumer` as is. :param producer_extra_config: A dictionary with additional options that will be passed to `confluent_kafka.Producer` as is. + :param state_dir: path to the application state directory, optional. + It should be passed if the application uses stateful operations, otherwise + the exception will be raised. + :param rocksdb_options: RocksDB options. + If `None`, the default options will be used. :param consumer_poll_timeout: timeout for `RowConsumer.poll()`. Default - 1.0s :param producer_poll_timeout: timeout for `RowProducer.poll()`. Default - 0s. :param on_message_processed: a callback triggered when message is successfully @@ -114,6 +124,13 @@ def __init__( self._on_processing_error = on_processing_error or default_on_processing_error self._on_message_processed = on_message_processed self._quix_config_builder: Optional[QuixKafkaConfigsBuilder] = None + self._state_manager: Optional[StateStoreManager] = None + if state_dir: + self._state_manager = StateStoreManager( + group_id=consumer_group, + state_dir=state_dir, + rocksdb_options=rocksdb_options, + ) def set_quix_config_builder(self, config_builder: QuixKafkaConfigsBuilder): self._quix_config_builder = config_builder @@ -128,6 +145,8 @@ def Quix( partitioner: Partitioner = "murmur2", consumer_extra_config: Optional[dict] = None, producer_extra_config: Optional[dict] = None, + state_dir: Optional[str] = None, + rocksdb_options: Optional[RocksDBOptionsType] = None, on_consumer_error: Optional[ConsumerErrorCallback] = None, on_processing_error: Optional[ProcessingErrorCallback] = None, on_producer_error: Optional[ProducerErrorCallback] = None, @@ -163,6 +182,11 @@ def Quix( will be passed to `confluent_kafka.Consumer` as is. :param producer_extra_config: A dictionary with additional options that will be passed to `confluent_kafka.Producer` as is. + :param state_dir: path to the application state directory, optional. + It should be passed if the application uses stateful operations, otherwise + the exception will be raised. + :param rocksdb_options: RocksDB options. + If `None`, the default options will be used. :param consumer_poll_timeout: timeout for `RowConsumer.poll()`. Default - 1.0s :param producer_poll_timeout: timeout for `RowProducer.poll()`. Default - 0s. 
:param on_message_processed: a callback triggered when message is successfully @@ -207,6 +231,8 @@ def Quix( on_message_processed=on_message_processed, consumer_poll_timeout=consumer_poll_timeout, producer_poll_timeout=producer_poll_timeout, + state_dir=state_dir, + rocksdb_options=rocksdb_options, ) # Inject Quix config builder to use it in other methods app.set_quix_config_builder(quix_config_builder) @@ -270,6 +296,10 @@ def stop(self): """ self._running = False + @property + def is_stateful(self) -> bool: + return bool(self._state_manager and self._state_manager.stores) + def run( self, dataframe: StreamingDataFrame, @@ -279,11 +309,14 @@ def run( :param dataframe: instance of `StreamingDataFrame` """ - logger.debug("Starting application") + logger.info("Start processing of the streaming dataframe") exit_stack = contextlib.ExitStack() exit_stack.enter_context(self._producer) exit_stack.enter_context(self._consumer) + if self.is_stateful: + exit_stack.enter_context(self._state_manager) + exit_stack.callback( lambda *_: logger.debug("Closing Kafka consumers & producers") ) @@ -293,7 +326,12 @@ def run( logger.info("Start processing of the streaming dataframe") # Subscribe to topics in Kafka and start polling - self._consumer.subscribe(list(dataframe.topics_in.values())) + self._consumer.subscribe( + list(dataframe.topics_in.values()), + on_assign=self._on_assign, + on_revoke=self._on_revoke, + on_lost=self._on_lost, + ) # Start polling Kafka for messages and callbacks self._running = True while self._running: @@ -310,22 +348,32 @@ def run( continue first_row = rows[0] - - for row in rows: - try: - dataframe.process(row=row) - except Exception as exc: - # TODO: This callback might be triggered because of Producer - # errors too because they happen within ".process()" - to_suppress = self._on_processing_error(exc, row, logger) - if not to_suppress: - raise - topic_name, partition, offset = ( first_row.topic, first_row.partition, first_row.offset, ) + + if self.is_stateful: + # Store manager has stores registered, starting a transaction + state_transaction = self._state_manager.start_store_transaction( + topic=topic_name, partition=partition, offset=offset + ) + else: + # The application is stateless, use noop transaction + state_transaction = contextlib.nullcontext() + + with state_transaction: + for row in rows: + try: + dataframe.process(row=row) + except Exception as exc: + # TODO: This callback might be triggered because of Producer + # errors too because they happen within ".process()" + to_suppress = self._on_processing_error(exc, row, logger) + if not to_suppress: + raise + # Store the message offset after it's successfully processed self._consumer.store_offsets( offsets=[ @@ -341,3 +389,32 @@ def run( self._on_message_processed(topic_name, partition, offset) logger.info("Stop processing of the streaming dataframe") + + def _on_assign(self, _, topic_partitions: List[TopicPartition]): + """ + Assign new topic partitions to consumer and state. 
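Note: the updated `run()` loop above wraps each batch of rows in either a real multi-store transaction or a no-op context manager, so stateless applications are unaffected. A minimal standalone sketch of that pattern (the helper name is illustrative, not part of the codebase):

```python
import contextlib

def processing_context(state_manager, topic: str, partition: int, offset: int):
    # Mirror of the branch in Application.run(): use a real multi-store
    # transaction when stores are registered, otherwise a no-op context.
    if state_manager is not None and state_manager.stores:
        # Changes are flushed only if the whole batch processes without errors
        return state_manager.start_store_transaction(
            topic=topic, partition=partition, offset=offset
        )
    return contextlib.nullcontext()
```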
+ + :param topic_partitions: list of `TopicPartition` from Kafka + """ + if self.is_stateful: + logger.info(f"Rebalancing: assigning state store partitions") + for tp in topic_partitions: + self._state_manager.on_partition_assign(tp) + + def _on_revoke(self, _, topic_partitions: List[TopicPartition]): + """ + Revoke partitions from consumer and state + """ + if self.is_stateful: + logger.info(f"Rebalancing: revoking state store partitions") + for tp in topic_partitions: + self._state_manager.on_partition_revoke(tp) + + def _on_lost(self, _, topic_partitions: List[TopicPartition]): + """ + Dropping lost partitions from consumer and state + """ + if self.is_stateful: + logger.info(f"Rebalancing: dropping lost state store partitions") + for tp in topic_partitions: + self._state_manager.on_partition_lost(tp) diff --git a/src/StreamingDataFrames/streamingdataframes/exceptions/__init__.py b/src/StreamingDataFrames/streamingdataframes/exceptions/__init__.py index 9b5ed21c9..a572a9c98 100644 --- a/src/StreamingDataFrames/streamingdataframes/exceptions/__init__.py +++ b/src/StreamingDataFrames/streamingdataframes/exceptions/__init__.py @@ -1 +1,2 @@ from .base import * +from .assignment import * diff --git a/src/StreamingDataFrames/streamingdataframes/exceptions/assignment.py b/src/StreamingDataFrames/streamingdataframes/exceptions/assignment.py new file mode 100644 index 000000000..18609a9e4 --- /dev/null +++ b/src/StreamingDataFrames/streamingdataframes/exceptions/assignment.py @@ -0,0 +1,14 @@ +from .base import QuixException + +__all__ = ("PartitionAssignmentError", "KafkaPartitionError") + + +class PartitionAssignmentError(QuixException): + """ + Error happened during partition rebalancing. + Raised from `on_assign`, `on_revoke` and `on_lost` callbacks + """ + + +class KafkaPartitionError(QuixException): + ... diff --git a/src/StreamingDataFrames/streamingdataframes/kafka/consumer.py b/src/StreamingDataFrames/streamingdataframes/kafka/consumer.py index 9d5bbf03d..703a09f8e 100644 --- a/src/StreamingDataFrames/streamingdataframes/kafka/consumer.py +++ b/src/StreamingDataFrames/streamingdataframes/kafka/consumer.py @@ -11,6 +11,8 @@ ) from confluent_kafka.admin import ClusterMetadata +from streamingdataframes.exceptions import PartitionAssignmentError, KafkaPartitionError + __all__ = ( "Consumer", "AutoOffsetReset", @@ -47,6 +49,22 @@ def _default_on_commit_cb( on_commit(error, partitions) +def _wrap_assignment_errors(func): + """ + Wrap exceptions raised from "on_assign", "on_revoke" and "on_lost" callbacks + into `PartitionAssignmentError` + """ + + @functools.wraps(func) + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except Exception as exc: + raise PartitionAssignmentError("Error during partition assignment") from exc + + return wrapper + + class Consumer: def __init__( self, @@ -158,8 +176,16 @@ def subscribe( assigned or revoked. 
""" + @_wrap_assignment_errors def _on_assign_wrapper(consumer: Consumer, partitions: List[TopicPartition]): for partition in partitions: + if partition.error: + raise KafkaPartitionError( + f"Kafka partition error " + f'(topic "{partition.topic}", ' + f'partition "{partition.partition}"): ' + f"{partition.error}" + ) logger.debug( "Assigned partition to a consumer", extra={"topic": partition.topic, "partition": partition.partition}, @@ -167,8 +193,16 @@ def _on_assign_wrapper(consumer: Consumer, partitions: List[TopicPartition]): if on_assign is not None: on_assign(consumer, partitions) + @_wrap_assignment_errors def _on_revoke_wrapper(consumer: Consumer, partitions: List[TopicPartition]): for partition in partitions: + if partition.error: + raise KafkaPartitionError( + f"Kafka partition error " + f'(topic "{partition.topic}", ' + f'partition "{partition.partition}"): ' + f"{partition.error}" + ) logger.debug( "Revoking partition from a consumer", extra={"topic": partition.topic, "partition": partition.partition}, @@ -176,12 +210,20 @@ def _on_revoke_wrapper(consumer: Consumer, partitions: List[TopicPartition]): if on_revoke is not None: on_revoke(consumer, partitions) + @_wrap_assignment_errors def _on_lost_wrapper(consumer: Consumer, partitions: List[TopicPartition]): for partition in partitions: logger.debug( "Consumer lost a partition", extra={"topic": partition.topic, "partition": partition.partition}, ) + if partition.error: + raise KafkaPartitionError( + f"Kafka partition error " + f'(topic "{partition.topic}", ' + f'partition "{partition.partition}"): ' + f"{partition.error}" + ) if on_lost is not None: on_lost(consumer, partitions) diff --git a/src/StreamingDataFrames/streamingdataframes/rowconsumer.py b/src/StreamingDataFrames/streamingdataframes/rowconsumer.py index f994ed3cb..77f981d7e 100644 --- a/src/StreamingDataFrames/streamingdataframes/rowconsumer.py +++ b/src/StreamingDataFrames/streamingdataframes/rowconsumer.py @@ -5,7 +5,7 @@ from typing_extensions import Protocol from .error_callbacks import ConsumerErrorCallback, default_on_consumer_error -from .exceptions import QuixException +from .exceptions import QuixException, PartitionAssignmentError from .kafka import Consumer, AssignmentStrategy, AutoOffsetReset from .kafka.consumer import RebalancingCallback from .models import Topic, Row @@ -159,6 +159,9 @@ def poll_row(self, timeout: float = None) -> Union[Row, List[Row], None]: """ try: msg = self.poll(timeout=timeout) + except PartitionAssignmentError: + # Always propagate errors happened during assignment + raise except Exception as exc: to_suppress = self._on_error(exc, None, logger) if to_suppress: diff --git a/src/StreamingDataFrames/streamingdataframes/state/__init__.py b/src/StreamingDataFrames/streamingdataframes/state/__init__.py index e69de29bb..a763ac5f5 100644 --- a/src/StreamingDataFrames/streamingdataframes/state/__init__.py +++ b/src/StreamingDataFrames/streamingdataframes/state/__init__.py @@ -0,0 +1,2 @@ +from .manager import * +from .types import * diff --git a/src/StreamingDataFrames/streamingdataframes/state/exceptions.py b/src/StreamingDataFrames/streamingdataframes/state/exceptions.py new file mode 100644 index 000000000..c7c534c1a --- /dev/null +++ b/src/StreamingDataFrames/streamingdataframes/state/exceptions.py @@ -0,0 +1,13 @@ +from streamingdataframes.exceptions import QuixException + + +class PartitionNotAssignedError(QuixException): + ... + + +class StoreNotRegisteredError(QuixException): + ... 
+ + +class InvalidStoreTransactionStateError(QuixException): + ... diff --git a/src/StreamingDataFrames/streamingdataframes/state/manager.py b/src/StreamingDataFrames/streamingdataframes/state/manager.py new file mode 100644 index 000000000..596e11a8e --- /dev/null +++ b/src/StreamingDataFrames/streamingdataframes/state/manager.py @@ -0,0 +1,274 @@ +import contextlib +import logging +from pathlib import Path +from typing import List, Dict, Optional, Iterator + +from streamingdataframes.types import TopicPartition +from .exceptions import ( + StoreNotRegisteredError, + InvalidStoreTransactionStateError, +) +from .rocksdb import RocksDBStore, RocksDBOptionsType +from .types import ( + Store, + PartitionTransaction, + StorePartition, +) + +__all__ = ("StateStoreManager",) + +logger = logging.getLogger(__name__) + +_DEFAULT_STATE_STORE_NAME = "default" + + +class StateStoreManager: + """ + Class for managing state stores and partitions. + + StateStoreManager is responsible for: + - reacting to rebalance callbacks + - managing the individual state stores + - providing access to store transactions + """ + + def __init__( + self, + group_id: str, + state_dir: str, + rocksdb_options: Optional[RocksDBOptionsType] = None, + ): + self._group_id = group_id + self._state_dir = (Path(state_dir) / group_id).absolute() + self._rocksdb_options = rocksdb_options + self._stores: Dict[str, Dict[str, Store]] = {} + self._transaction: Optional[_MultiStoreTransaction] = None + + def _init_state_dir(self): + logger.info(f'Initializing state directory at "{self._state_dir}"') + if self._state_dir.exists(): + if not self._state_dir.is_dir(): + raise FileExistsError( + f'Path "{self._state_dir}" already exists, ' + f"but it is not a directory" + ) + logger.info(f'State directory already exists at "{self._state_dir}"') + else: + self._state_dir.mkdir(parents=True) + logger.info(f'Created state directory at "{self._state_dir}"') + + @property + def stores(self) -> Dict[str, Dict[str, Store]]: + """ + Map of registered state stores + :return: dict in format {topic: {store_name: store}} + """ + return self._stores + + def get_store( + self, topic: str, store_name: str = _DEFAULT_STATE_STORE_NAME + ) -> Store: + """ + Get a store for given name and topic + :param topic: topic name + :param store_name: store name + :return: instance of `Store` + """ + store = self._stores.get(topic, {}).get(store_name) + if store is None: + raise StoreNotRegisteredError( + f'Store "{store_name}" (topic "{topic}") is not registered' + ) + return store + + def register_store( + self, topic_name: str, store_name: str = _DEFAULT_STATE_STORE_NAME + ): + """ + Register a state store to be managed by StateStoreManager. + + During processing, the StateStoreManager will react to rebalancing callbacks + and assign/revoke the partitions for registered stores. + + Each store can be registered only once for each topic. + + :param topic_name: topic name + :param store_name: store name + """ + store = self._stores.get(topic_name, {}).get(store_name) + if store is None: + self._stores.setdefault(topic_name, {})[store_name] = RocksDBStore( + name=store_name, + topic=topic_name, + base_dir=str(self._state_dir), + options=self._rocksdb_options, + ) + + def on_partition_assign(self, tp: TopicPartition) -> List[StorePartition]: + """ + Assign store partitions for each registered store for the given `TopicPartition` + and return a list of assigned `StorePartition` objects. 
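Note: a minimal usage sketch of `StateStoreManager` outside the `Application`, assuming a writable local state directory; the `TP` class below is a hypothetical stand-in for `confluent_kafka.TopicPartition` and only carries the attributes declared by the `TopicPartition` protocol:

```python
from streamingdataframes.state import StateStoreManager

class TP:
    # Hypothetical stand-in for confluent_kafka.TopicPartition
    def __init__(self, topic: str, partition: int, offset: int = 0):
        self.topic = topic
        self.partition = partition
        self.offset = offset

manager = StateStoreManager(group_id="my-group", state_dir="/tmp/state")
manager.init()                                    # creates "/tmp/state/my-group"
manager.register_store("input-topic", "default")  # one RocksDBStore per (topic, store)

# On rebalance the Application calls these hooks for every topic partition;
# assigning opens a RocksDB partition on disk under the state directory.
manager.on_partition_assign(TP("input-topic", 0))
store = manager.get_store("input-topic", "default")
assert 0 in store.partitions

manager.on_partition_revoke(TP("input-topic", 0))
manager.close()
```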
+ + :param tp: `TopicPartition` from Kafka consumer + :return: list of assigned `StorePartition` + """ + + store_partitions = [] + for store in self._stores.get(tp.topic, {}).values(): + store_partitions.append(store.assign_partition(tp.partition)) + return store_partitions + + def on_partition_revoke(self, tp: TopicPartition): + """ + Revoke store partitions for each registered store for the given `TopicPartition` + + :param tp: `TopicPartition` from Kafka consumer + """ + for store in self._stores.get(tp.topic, {}).values(): + store.revoke_partition(tp.partition) + + def on_partition_lost(self, tp: TopicPartition): + """ + Revoke and close store partitions for each registered store for the given + `TopicPartition` + + :param tp: `TopicPartition` from Kafka consumer + """ + for store in self._stores.get(tp.topic, {}).values(): + store.revoke_partition(tp.partition) + + def init(self): + """ + Initialize `StateStoreManager` and create a store directory + :return: + """ + self._init_state_dir() + + def close(self): + """ + Close all registered stores + """ + for topic_stores in self._stores.values(): + for store in topic_stores.values(): + store.close() + + def get_store_transaction( + self, store_name: str = _DEFAULT_STATE_STORE_NAME + ) -> PartitionTransaction: + """ + Get active `PartitionTransaction` for the store + :param store_name: + :return: + """ + if self._transaction is None: + raise InvalidStoreTransactionStateError( + "Store transaction is not started yet" + ) + return self._transaction.get_store_transaction(store_name=store_name) + + @contextlib.contextmanager + def start_store_transaction( + self, topic: str, partition: int, offset: int + ) -> Iterator["_MultiStoreTransaction"]: + """ + Starting the multi-store transaction for the Kafka message. + + This transaction will keep track of all used stores and flush them in the end. + If any exception is catched during this transaction, none of them + will be flushed as a best effort to keep stores consistent in "at-least-once" setting. + + There can be only one active transaction at a time. Starting a new transaction + before the end of the current one will fail. + + + :param topic: message topic + :param partition: message partition + :param offset: message offset + """ + if not self._stores.get(topic): + raise StoreNotRegisteredError( + f'Topic "{topic}" does not have stores registered' + ) + + if self._transaction is not None: + raise InvalidStoreTransactionStateError( + "Another transaction is already in progress" + ) + self._transaction = _MultiStoreTransaction( + manager=self, topic=topic, partition=partition, offset=offset + ) + try: + yield self._transaction + self._transaction.flush() + finally: + self._transaction = None + + def __enter__(self): + self.init() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + +class _MultiStoreTransaction: + """ + A transaction-like class to manage flushing of multiple state partitions for each + processed message. 
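Note: per-message usage of the multi-store transaction, mirroring what `Application.run()` does for each batch (this continues the sketch above, with `"input-topic"` registered and partition 0 assigned):

```python
with manager.start_store_transaction(topic="input-topic", partition=0, offset=10):
    tx = manager.get_store_transaction("default")
    tx.set("total", tx.get("total", 0) + 1)
# On a clean exit the used transactions are flushed together with offset 10;
# if an exception escapes the "with" block, nothing is written.
```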
+ + It is responsible for: + - Keeping track of actual DBTransactions for the individiual stores + - Flushing of the opened transactions in the end + + """ + + def __init__( + self, manager: "StateStoreManager", topic: str, partition: int, offset: int + ): + self._manager = manager + self._transactions: Dict[str, PartitionTransaction] = {} + self._topic = topic + self._partition = partition + self._offset = offset + + def get_store_transaction( + self, store_name: str = _DEFAULT_STATE_STORE_NAME + ) -> PartitionTransaction: + """ + Get a PartitionTransaction for the given store + + It will return already started transaction if there's one. + + :param store_name: store name + :return: instance of `PartitionTransaction` + """ + transaction = self._transactions.get(store_name) + if transaction is not None: + return transaction + + store = self._manager.get_store(topic=self._topic, store_name=store_name) + transaction = store.start_partition_transaction(partition=self._partition) + self._transactions[store_name] = transaction + return transaction + + def flush(self): + """ + Flush all `PartitionTransaction` instances for each registered store and + save the last processed offset for each partition. + + Empty transactions without any updates will not be flushed. + + If there are any failed transactions, no transactions will be flushed + to keep the stores consistent. + """ + for store_name, transaction in self._transactions.items(): + if transaction.failed: + logger.warning( + f'Detected failed transaction for store "{store_name}" ' + f'(topic "{self._topic}" partition "{self._partition}" ' + f'offset "{self._offset}), state transactions will not be flushed"' + ) + return + + for transaction in self._transactions.values(): + transaction.maybe_flush(offset=self._offset) diff --git a/src/StreamingDataFrames/streamingdataframes/state/rocksdb/__init__.py b/src/StreamingDataFrames/streamingdataframes/state/rocksdb/__init__.py index 4e125834e..7bda6512f 100644 --- a/src/StreamingDataFrames/streamingdataframes/state/rocksdb/__init__.py +++ b/src/StreamingDataFrames/streamingdataframes/state/rocksdb/__init__.py @@ -1,3 +1,5 @@ -from .stores import * from .exceptions import * from .options import * +from .partition import * +from .store import * +from .types import * diff --git a/src/StreamingDataFrames/streamingdataframes/state/rocksdb/options.py b/src/StreamingDataFrames/streamingdataframes/state/rocksdb/options.py index 51f180918..3dfbaa198 100644 --- a/src/StreamingDataFrames/streamingdataframes/state/rocksdb/options.py +++ b/src/StreamingDataFrames/streamingdataframes/state/rocksdb/options.py @@ -1,15 +1,13 @@ import dataclasses -from typing import Optional, Mapping, Literal +from typing import Optional, Mapping import rocksdict from rocksdict import DBCompressionType -from .types import RocksDBOptionsProto +from .types import RocksDBOptionsType, CompressionType __all__ = ("RocksDBOptions",) -CompressionType = Literal["none", "snappy", "zlib", "bz2", "lz4", "lz4hc", "zstd"] - COMPRESSION_TYPES: Mapping[CompressionType, DBCompressionType] = { "none": DBCompressionType.none(), "snappy": DBCompressionType.snappy(), @@ -22,7 +20,7 @@ @dataclasses.dataclass(frozen=True) -class RocksDBOptions(RocksDBOptionsProto): +class RocksDBOptions(RocksDBOptionsType): """ Common RocksDB database options. 
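Note: `rocksdb_options` accepts any object implementing `RocksDBOptionsType` and is forwarded to every store partition. A hedged example, assuming the remaining `RocksDBOptions` fields keep their defaults:

```python
from streamingdataframes.app import Application
from streamingdataframes.state.rocksdb import RocksDBOptions

options = RocksDBOptions(compression_type="zstd")  # one of the CompressionType literals

app = Application(
    broker_address="localhost:9092",
    consumer_group="my-group",
    state_dir="state",           # enables the StateStoreManager
    rocksdb_options=options,     # applied to every RocksDBStorePartition
)
```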
diff --git a/src/StreamingDataFrames/streamingdataframes/state/rocksdb/stores.py b/src/StreamingDataFrames/streamingdataframes/state/rocksdb/partition.py similarity index 72% rename from src/StreamingDataFrames/streamingdataframes/state/rocksdb/stores.py rename to src/StreamingDataFrames/streamingdataframes/state/rocksdb/partition.py index b054562cc..b08875938 100644 --- a/src/StreamingDataFrames/streamingdataframes/state/rocksdb/stores.py +++ b/src/StreamingDataFrames/streamingdataframes/state/rocksdb/partition.py @@ -1,23 +1,30 @@ import contextlib import functools import logging +import struct import time from typing import Any, Union, Optional import rocksdict from typing_extensions import Self +from streamingdataframes.state.types import ( + DumpsFunc, + LoadsFunc, + PartitionTransaction, + StorePartition, +) from .exceptions import ( StateTransactionError, NestedPrefixError, ) from .options import RocksDBOptions from .serialization import serialize, deserialize, serialize_key -from .types import RocksDBOptionsProto, DumpsFunc, LoadsFunc +from .types import RocksDBOptionsType __all__ = ( - "RocksDBStorage", - "TransactionStore", + "RocksDBStorePartition", + "RocksDBPartitionTransaction", ) logger = logging.getLogger(__name__) @@ -28,11 +35,21 @@ _DEFAULT_PREFIX = b"" +_PROCESSED_OFFSET_KEY = b"__topic_offset__" + + +def _int_to_int64_bytes(value: int) -> bytes: + return struct.pack(">q", value) + + +def _int_from_int64_bytes(value: bytes) -> int: + return struct.unpack(">q", value)[0] -class RocksDBStorage: + +class RocksDBStorePartition(StorePartition): """ - A base class to for accessing state in RocksDB. - It represents a single partition of the state. + A base class to access state in RocksDB. + It represents a single RocksDB database. Responsibilities: 1. Managing access to the RocksDB instance @@ -40,7 +57,7 @@ class RocksDBStorage: 3. Flushing WriteBatches to the RocksDB It opens the RocksDB on `__init__`. If the db is locked by another process, - it will retry opening according to `open_max_retries` and `open_retry_backoff` + it will retry according to `open_max_retries` and `open_retry_backoff`. :param path: an absolute path to the RocksDB folder :param options: RocksDB options. If `None`, the default options will be used. @@ -56,7 +73,7 @@ class RocksDBStorage: def __init__( self, path: str, - options: Optional[RocksDBOptionsProto] = None, + options: Optional[RocksDBOptionsType] = None, open_max_retries: int = 10, open_retry_backoff: float = 3.0, dumps: Optional[DumpsFunc] = None, @@ -68,16 +85,18 @@ def __init__( self._open_retry_backoff = open_retry_backoff self._dumps = dumps self._loads = loads - self._db = self._open_db() + self._db = self._init_db() - def begin(self) -> "TransactionStore": + def begin(self) -> "RocksDBPartitionTransaction": """ - Create a new `TransactionStore` object. - Using `TransactionStore` is a recommended way for accessing the data. + Create a new `RocksDBTransaction` object. + Using `RocksDBTransaction` is a recommended way for accessing the data. 
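Note: a standalone restatement of the two `struct` helpers introduced above. The processed offset is stored under the reserved key as a big-endian signed 64-bit integer, so it is always exactly 8 bytes and bypasses the JSON dumps/loads path used for regular values:

```python
import struct

def int_to_int64_bytes(value: int) -> bytes:
    return struct.pack(">q", value)

def int_from_int64_bytes(value: bytes) -> int:
    return struct.unpack(">q", value)[0]

assert int_to_int64_bytes(1234) == b"\x00\x00\x00\x00\x00\x00\x04\xd2"
assert int_from_int64_bytes(int_to_int64_bytes(1234)) == 1234
```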
- :return: an instance of `TransactionStore` + :return: an instance of `RocksDBTransaction` """ - return TransactionStore(storage=self, dumps=self._dumps, loads=self._loads) + return RocksDBPartitionTransaction( + partition=self, dumps=self._dumps, loads=self._loads + ) def write(self, batch: rocksdict.WriteBatch): """ @@ -105,13 +124,22 @@ def exists(self, key: bytes) -> bool: """ return key in self._db + def get_processed_offset(self) -> Optional[int]: + """ + Get last processed offset for the given partition + :return: offset or `None` if there's no processed offset yet + """ + offset_bytes = self._db.get(_PROCESSED_OFFSET_KEY) + if offset_bytes is not None: + return _int_from_int64_bytes(offset_bytes) + def close(self): """ Close the underlying RocksDB """ - logger.info(f'Closing db partition on "{self._path}"') + logger.debug(f'Closing rocksdb partition on "{self._path}"') self._db.close() - logger.info(f'Successfully closed db partition on "{self._path}"') + logger.debug(f'Closed rocksdb partition on "{self._path}"') @property def path(self) -> str: @@ -132,11 +160,15 @@ def destroy(cls, path: str): """ rocksdict.Rdict.destroy(path=path) # noqa - def _open_db(self) -> rocksdict.Rdict: + def _init_db(self) -> rocksdict.Rdict: + db = self._open_rocksdb() + return db + + def _open_rocksdb(self) -> rocksdict.Rdict: attempt = 1 while True: - logger.info( - f'Opening db partition on "{self._path}" attempt={attempt}', + logger.debug( + f'Opening rocksdb partition on "{self._path}" attempt={attempt}', ) try: db = rocksdict.Rdict( @@ -144,8 +176,8 @@ def _open_db(self) -> rocksdict.Rdict: options=self._options.to_options(), access_type=rocksdict.AccessType.read_write(), # noqa ) - logger.info( - f'Successfully opened db partition on "{self._path}"', + logger.debug( + f'Successfully opened rocksdb partition on "{self._path}"', ) return db except Exception as exc: @@ -157,7 +189,7 @@ def _open_db(self) -> rocksdict.Rdict: raise logger.warning( - f"Failed to open db partition, cannot acquire a lock. " + f"Failed to open rocksdb partition, cannot acquire a lock. " f"Retrying in {self._open_retry_backoff}sec." ) @@ -173,11 +205,12 @@ def __exit__(self, exc_type, exc_val, exc_tb): def _validate_transaction_state(func): """ - Check that the state of `TransactionStore` is valid before calling a method + Check that the state of `RocksDBTransaction` is valid before calling a method """ @functools.wraps(func) - def wrapper(self: "TransactionStore", *args, **kwargs): + def wrapper(*args, **kwargs): + self: "RocksDBPartitionTransaction" = args[0] if self.failed: raise StateTransactionError( "Transaction is failed, create a new one to proceed" @@ -187,53 +220,51 @@ def wrapper(self: "TransactionStore", *args, **kwargs): "Transaction is already finished, create a new one to proceed" ) - return func(self, *args, **kwargs) + return func(*args, **kwargs) return wrapper -class TransactionStore: +class RocksDBPartitionTransaction(PartitionTransaction): """ - A transaction-based store to perform simple key-value operations like - "get", "set", "delete" and "exists". - It is supposed to be used by stateful operators in Streaming DataFrames. - + A transaction class to perform simple key-value operations like + "get", "set", "delete" and "exists" on a single RocksDB partition. Serialization ************* - `TransactionStore` automatically serializes keys and values to JSON. + `RocksDBTransaction` automatically serializes keys and values to JSON. 
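Note: a sketch of using a single `RocksDBStorePartition` and its transaction directly; normally `RocksDBStore.assign_partition()` builds the path and opens it. The path below is hypothetical and must point to a writable location:

```python
from streamingdataframes.state.rocksdb import RocksDBStorePartition

partition = RocksDBStorePartition(path="/tmp/state/my-group/default/input-topic/0")
assert partition.get_processed_offset() is None    # nothing flushed yet

with partition.begin() as tx:                      # flushed on clean exit
    # Prefix keys by message key, as StreamingDataFrame is expected to do
    with tx.with_prefix("customer-1"):
        tx.set("total", tx.get("total", 0) + 1)
    with tx.with_prefix("customer-2"):
        assert tx.get("total") is None             # a separate key namespace

partition.close()
```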
Prefixing ********* - `TransactionStore` allows to set prefixes for the keys in the given code block + `RocksDBTransaction` allows to set prefixes for the keys in the given code block using :meth:`with_prefix()` context manager. Normally, `StreamingDataFrame` class will use message keys as prefixes in order to namespace the stored keys across different messages. Transactional properties ************************ - `TransactionStore` uses a combination of in-memory update cache + `RocksDBTransaction` uses a combination of in-memory update cache and RocksDB's WriteBatch in order to accumulate all the state mutations in a single batch, flush them atomically, and allow the updates be visible within the transaction before it's flushed (aka "read-your-own-writes" problem). - A single transaction can span across multiple incoming messages, - If any mutation fails during the transaction - (e.g. we failed to write the updates to the RocksDB), the whole transaction is now - invalid and cannot be used anymore. - In this case, a new `TransactionStore` should be created. + (e.g. we failed to write the updates to the RocksDB), the whole transaction + will be marked as failed and cannot be used anymore. + In this case, a new `RocksDBTransaction` should be created. - The `TransactionStore` also cannot be used after it's flushed, and a new - instance of `TransactionStore` should be used instead. + `RocksDBTransaction` can be used only once. - :param storage: instance of `StateStorage` to be used for accessing + :param partition: instance of `RocksDBStatePartition` to be used for accessing the underlying RocksDB - + :param dumps: a function to serialize data to bytes, optional. + By default, `json.dumps` will be used. + :param loads: a function to deserialize data from bytes, optional. + By default, `json.loads` will be used. """ __slots__ = ( - "_storage", + "_partition", "_update_cache", "_batch", "_prefix", @@ -245,11 +276,11 @@ class TransactionStore: def __init__( self, - storage: RocksDBStorage, + partition: RocksDBStorePartition, dumps: Optional[DumpsFunc] = None, loads: Optional[LoadsFunc] = None, ): - self._storage = storage + self._partition = partition self._update_cache = {} self._batch = rocksdict.WriteBatch(raw_mode=True) self._prefix = _DEFAULT_PREFIX @@ -261,7 +292,7 @@ def __init__( @contextlib.contextmanager def with_prefix(self, prefix: Any = b"") -> Self: """ - Prefix all the keys in the given scope. + A context manager set the prefix for all keys in the scope. Normally, it's called by Streaming DataFrames engine to ensure that every message key is stored separately. @@ -312,7 +343,7 @@ def get(self, key: Any, default: Any = None) -> Optional[Any]: return self._deserialize_value(cached) # The value is not found in cache, check the db - stored = self._storage.get(key_serialized, _sentinel) + stored = self._partition.get(key_serialized, _sentinel) if stored is not _sentinel: return self._deserialize_value(stored) return default @@ -346,7 +377,6 @@ def delete(self, key: Any): It first deletes the key from the update cache. :param key: a JSON-deserializable key. - :return: """ key_serialized = self._serialize_key(key) try: @@ -370,7 +400,7 @@ def exists(self, key: Any) -> bool: key_serialized = self._serialize_key(key) if key_serialized in self._update_cache: return True - return self._storage.exists(key_serialized) + return self._partition.exists(key_serialized) @property def completed(self) -> bool: @@ -378,7 +408,7 @@ def completed(self) -> bool: Check if the transaction is completed. 
It doesn't indicate whether transaction is successful or not. - Use `TransactionStore.failed` for that. + Use `RocksDBTransaction.failed` for that. The completed transaction should not be re-used. @@ -399,20 +429,27 @@ def failed(self) -> bool: return self._failed @_validate_transaction_state - def _flush(self): - # TODO: Does Flush sometimes need to be called manually? + def maybe_flush(self, offset: Optional[int] = None): """ Flush the recent updates to the database and empty the update cache. It writes the WriteBatch to RocksDB and marks itself as finished. If writing fails, the transaction will be also marked as "failed" and cannot be used anymore. - """ + .. note:: If no keys have been modified during the transaction + (i.e no "set" or "delete" have been called at least once), it will + not flush ANY data to the database including the offset in order to optimize + I/O. + + :param offset: offset of the last processed message, optional. + """ try: # Don't write batches if this transaction doesn't change any keys if len(self._batch): - self._storage.write(self._batch) + if offset is not None: + self._batch.put(_PROCESSED_OFFSET_KEY, _int_to_int64_bytes(offset)) + self._partition.write(self._batch) except Exception: self._failed = True raise @@ -432,5 +469,5 @@ def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): - if not self._failed: - self._flush() + if exc_val is None and not self._failed: + self.maybe_flush() diff --git a/src/StreamingDataFrames/streamingdataframes/state/rocksdb/serialization.py b/src/StreamingDataFrames/streamingdataframes/state/rocksdb/serialization.py index e4cc48e09..e393badf1 100644 --- a/src/StreamingDataFrames/streamingdataframes/state/rocksdb/serialization.py +++ b/src/StreamingDataFrames/streamingdataframes/state/rocksdb/serialization.py @@ -1,8 +1,8 @@ import json from typing import Any, Optional +from streamingdataframes.state.types import DumpsFunc, LoadsFunc from .exceptions import StateSerializationError -from .types import DumpsFunc, LoadsFunc __all__ = ( "serialize", diff --git a/src/StreamingDataFrames/streamingdataframes/state/rocksdb/store.py b/src/StreamingDataFrames/streamingdataframes/state/rocksdb/store.py new file mode 100644 index 000000000..89d12dc8e --- /dev/null +++ b/src/StreamingDataFrames/streamingdataframes/state/rocksdb/store.py @@ -0,0 +1,171 @@ +import logging +from pathlib import Path +from typing import Dict, Optional + +from streamingdataframes.state.exceptions import PartitionNotAssignedError +from streamingdataframes.state.types import DumpsFunc, LoadsFunc, Store +from .partition import ( + RocksDBStorePartition, + RocksDBPartitionTransaction, +) +from .types import RocksDBOptionsType + +logger = logging.getLogger(__name__) + +__all__ = ("RocksDBStore",) + + +class RocksDBStore(Store): + def __init__( + self, + name: str, + topic: str, + base_dir: str, + options: Optional[RocksDBOptionsType] = None, + open_max_retries: int = 10, + open_retry_backoff: float = 3.0, + dumps: Optional[DumpsFunc] = None, + loads: Optional[LoadsFunc] = None, + ): + """ + RocksDB-based state store. + + It keeps track of individual store partitions and provides access to the + partitions' transactions. + + :param name: a unique store name + :param topic: a topic name for this store + :param base_dir: path to a directory with the state + :param options: RocksDB options. If `None`, the default options will be used. + :param open_max_retries: number of times to retry opening the database + if it's locked by another process. 
To disable retrying, pass 0. + :param open_retry_backoff: number of seconds to wait between each retry. + :param dumps: the function used to serialize keys & values to bytes in + transactions. Default - `json.dumps` + :param loads: the function used to deserialize keys & values from bytes + to objects in transactions. Default - `json.loads`. + """ + self._name = name + self._topic = topic + self._partitions_dir = Path(base_dir).absolute() / self._name / self._topic + self._transactions: Dict[int, RocksDBPartitionTransaction] = {} + self._partitions: Dict[int, RocksDBStorePartition] = {} + self._options = options + self._dumps = dumps + self._loads = loads + self._open_max_retries = open_max_retries + self._open_retry_backoff = open_retry_backoff + + @property + def topic(self) -> str: + """ + Store topic name + """ + return self._topic + + @property + def name(self) -> str: + """ + Store name + """ + return self._name + + @property + def partitions(self) -> Dict[int, RocksDBStorePartition]: + """ + Mapping of assigned store partitions + """ + return self._partitions + + def assign_partition(self, partition: int) -> RocksDBStorePartition: + """ + Open and assign store partition. + + If the partition is already assigned, it will not re-open it and return + the existing partition instead. + + :param partition: partition number + :return: instance of`RocksDBStorePartition` + """ + if partition in self._partitions: + logger.debug( + f'Partition "{partition}" for store "{self._name}" ' + f'(topic "{self._topic}") ' + f"is already assigned" + ) + return self._partitions[partition] + + path = str((self._partitions_dir / str(partition)).absolute()) + store_partition = RocksDBStorePartition( + path=path, + options=self._options, + dumps=self._dumps, + loads=self._loads, + open_max_retries=self._open_max_retries, + open_retry_backoff=self._open_retry_backoff, + ) + + self._partitions[partition] = store_partition + logger.debug( + f'Assigned partition "{partition}" ' + f'for store "{self._name}" (topic "{self._topic}")' + ) + return store_partition + + def revoke_partition(self, partition: int): + """ + Revoke and close the assigned store partition. + + If the partition is not assigned, it will log the message and return. + + :param partition: partition number + """ + store_partition = self._partitions.get(partition) + if store_partition is None: + logger.debug( + f'Partition for store "{self._name}" (topic "{self._topic}") ' + f"is not assigned" + ) + return + + store_partition.close() + self._partitions.pop(partition) + logger.debug(f'Revoked partition "{partition}" for store "{self._name}"') + + def start_partition_transaction( + self, partition: int + ) -> RocksDBPartitionTransaction: + """ + Start a new partition transaction. + + `RocksDBPartitionTransaction` is the primary interface for working with data in + the underlying RocksDB. + + :param partition: partition number + :return: instance of `RocksDBPartitionTransaction` + """ + if partition not in self._partitions: + # Requested partition has not been assigned. 
Something went completely wrong + raise PartitionNotAssignedError( + f'Partition "{partition}" is not assigned ' + f'to the store "{self._name}" (topic "{self._topic}")' + ) + + store_partition = self._partitions[partition] + return store_partition.begin() + + def close(self): + """ + Close the store and revoke all assigned partitions + """ + logger.debug(f'Closing store "{self._name}" (topic "{self._topic}")') + partitions = list(self._partitions.keys()) + for partition in partitions: + self.revoke_partition(partition) + logger.debug(f'Closed store "{self._name}" (topic "{self._topic}")') + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() diff --git a/src/StreamingDataFrames/streamingdataframes/state/rocksdb/types.py b/src/StreamingDataFrames/streamingdataframes/state/rocksdb/types.py index 5107ad2ba..863729070 100644 --- a/src/StreamingDataFrames/streamingdataframes/state/rocksdb/types.py +++ b/src/StreamingDataFrames/streamingdataframes/state/rocksdb/types.py @@ -1,11 +1,19 @@ -from typing import Callable, Any, Protocol +from typing import Protocol, Optional, Literal import rocksdict -DumpsFunc = Callable[[Any], bytes] -LoadsFunc = Callable[[bytes], Any] +CompressionType = Literal["none", "snappy", "zlib", "bz2", "lz4", "lz4hc", "zstd"] -class RocksDBOptionsProto(Protocol): +class RocksDBOptionsType(Protocol): + write_buffer_size: int + target_file_size_base: int + max_write_buffer_number: int + block_cache_size: int + enable_pipelined_write: bool + compression_type: CompressionType + wal_dir: Optional[str] + db_log_dir: Optional[str] + def to_options(self) -> rocksdict.Options: ... diff --git a/src/StreamingDataFrames/streamingdataframes/state/types.py b/src/StreamingDataFrames/streamingdataframes/state/types.py new file mode 100644 index 000000000..d8bb9dbd6 --- /dev/null +++ b/src/StreamingDataFrames/streamingdataframes/state/types.py @@ -0,0 +1,177 @@ +from typing import Protocol, Any, Optional, Iterator, Callable, Dict + +from typing_extensions import Self + +DumpsFunc = Callable[[Any], bytes] +LoadsFunc = Callable[[bytes], Any] + + +class Store(Protocol): + """ + Abstract state store. + + It keeps track of individual store partitions and provides access to the + partitions' transactions. + """ + + @property + def topic(self) -> str: + """ + Topic name + """ + + @property + def name(self) -> str: + """ + Store name + """ + + @property + def partitions(self) -> Dict[int, "StorePartition"]: + """ + Mapping of assigned store partitions + :return: dict of "{partition: }" + """ + ... + + def assign_partition(self, partition: int) -> "StorePartition": + """ + Assign new store partition + + :param partition: partition number + :return: instance of `StorePartition` + """ + ... + + def revoke_partition(self, partition: int): + """ + Revoke assigned store partition + + :param partition: partition number + """ + + ... + + def start_partition_transaction( + self, partition: int + ) -> Optional["PartitionTransaction"]: + """ + Start a new partition transaction. + + `PartitionTransaction` is the primary interface for working with data in Stores. + :param partition: partition number + :return: instance of `PartitionTransaction` + """ + + def close(self): + """ + Close store and revoke all store partitions + """ + + def __enter__(self): + ... + + def __exit__(self, exc_type, exc_val, exc_tb): + ... + + +class StorePartition(Protocol): + """ + A base class to access state in the underlying storage. 
+ It represents a single instance of some storage (e.g. a single database for + the persistent storage). + + """ + + def begin(self) -> "PartitionTransaction": + """ + State new `PartitionTransaction` + """ + + def get_processed_offset(self) -> Optional[int]: + ... + + +class State(Protocol): + """ + Primary interface for working with key-value state data from `StreamingDataFrame` + """ + + def get(self, key: Any, default: Any = None) -> Optional[Any]: + """ + Get the value for key if key is present in the state, else default + + :param key: key + :param default: default value to return if the key is not found + :return: value or None if the key is not found and `default` is not provided + """ + + def set(self, key: Any, value: Any): + """ + Set value for the key. + :param key: key + :param value: value + """ + + def delete(self, key: Any): + """ + Delete value for the key. + + This function always returns `None`, even if value is not found. + :param key: key + """ + + def exists(self, key: Any) -> bool: + """ + Check if the key exists in state. + :param key: key + :return: True if key exists, False otherwise + """ + + +class PartitionTransaction(State): + """ + A transaction class to perform simple key-value operations like + "get", "set", "delete" and "exists" on a single storage partition. + """ + + @property + def failed(self) -> bool: + """ + Return `True` if transaction failed to update data at some point. + + Failed transactions cannot be re-used. + :return: bool + """ + + @property + def completed(self) -> bool: + """ + Return `True` if transaction is completed. + + Completed transactions cannot be re-used. + :return: bool + """ + ... + + def with_prefix(self, prefix: Any = b"") -> Iterator[Self]: + """ + A context manager set the prefix for all keys in the scope. + + Normally, it's called by `StreamingDataFrame` internals to ensure that every + message key is stored separately. + :param prefix: key prefix + :return: context maager + """ + + def maybe_flush(self, offset: Optional[int] = None): + """ + Flush the recent updates and last processed offset to the storage. + :param offset: offset of the last processed message, optional. + """ + + def __enter__(self): + ... + + def __exit__(self, exc_type, exc_val, exc_tb): + ... 
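Note: the protocols above let stateful user code be written against `State` alone, regardless of the storage backend. A sketch of such a callback (automatic state injection into `StreamingDataFrame` functions is not part of this diff yet, see the TODOs in the tests):

```python
from streamingdataframes.state.types import State

def count_messages(value: dict, state: State) -> dict:
    # Works with any State implementation, e.g. a RocksDBPartitionTransaction
    total = state.get("total", 0) + 1
    state.set("total", total)
    return {**value, "total": total}
```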
diff --git a/src/StreamingDataFrames/streamingdataframes/types.py b/src/StreamingDataFrames/streamingdataframes/types.py new file mode 100644 index 000000000..6a2688269 --- /dev/null +++ b/src/StreamingDataFrames/streamingdataframes/types.py @@ -0,0 +1,7 @@ +from typing import Protocol + + +class TopicPartition(Protocol): + topic: str + partition: int + offset: int diff --git a/src/StreamingDataFrames/tests/conftest.py b/src/StreamingDataFrames/tests/conftest.py index df4220587..137170c54 100644 --- a/src/StreamingDataFrames/tests/conftest.py +++ b/src/StreamingDataFrames/tests/conftest.py @@ -13,6 +13,7 @@ "tests.test_dataframes.test_models.fixtures", "tests.test_dataframes.test_platforms.test_quix.fixtures", "tests.test_dataframes.test_state.test_rocksdb.fixtures", + "tests.test_dataframes.test_state.fixtures", ] KafkaContainer = namedtuple("KafkaContainer", ("broker_address",)) diff --git a/src/StreamingDataFrames/tests/test_dataframes/fixtures.py b/src/StreamingDataFrames/tests/test_dataframes/fixtures.py index 6f62055ff..859211b0b 100644 --- a/src/StreamingDataFrames/tests/test_dataframes/fixtures.py +++ b/src/StreamingDataFrames/tests/test_dataframes/fixtures.py @@ -237,6 +237,7 @@ def factory(value, topic="input-topic", key=b"key", headers=None) -> Row: @pytest.fixture() def app_factory(kafka_container, random_consumer_group): def factory( + consumer_group: Optional[str] = None, auto_offset_reset: AutoOffsetReset = "latest", consumer_extra_config: Optional[dict] = None, producer_extra_config: Optional[dict] = None, @@ -244,10 +245,11 @@ def factory( on_producer_error: Optional[ProducerErrorCallback] = None, on_processing_error: Optional[ProcessingErrorCallback] = None, on_message_processed: Optional[MessageProcessedCallback] = None, + state_dir: Optional[str] = None, ) -> Application: return Application( broker_address=kafka_container.broker_address, - consumer_group=random_consumer_group, + consumer_group=consumer_group or random_consumer_group, auto_offset_reset=auto_offset_reset, consumer_extra_config=consumer_extra_config, producer_extra_config=producer_extra_config, @@ -255,6 +257,7 @@ def factory( on_producer_error=on_producer_error, on_processing_error=on_processing_error, on_message_processed=on_message_processed, + state_dir=state_dir, ) return factory diff --git a/src/StreamingDataFrames/tests/test_dataframes/test_app.py b/src/StreamingDataFrames/tests/test_dataframes/test_app.py index 0d7267151..e20cb7b39 100644 --- a/src/StreamingDataFrames/tests/test_dataframes/test_app.py +++ b/src/StreamingDataFrames/tests/test_dataframes/test_app.py @@ -1,10 +1,12 @@ import time +import uuid from concurrent.futures import Future from json import loads, dumps from unittest.mock import patch, create_autospec import pytest from confluent_kafka import KafkaException, TopicPartition +from tests.utils import TopicPartitionStub from streamingdataframes.app import Application from streamingdataframes.models import ( @@ -380,3 +382,204 @@ def test_topic_init_name_is_prefixed(self, kafka_container): initial_topic_name = "input_topic" topic = app.topic(initial_topic_name, value_deserializer=JSONDeserializer()) assert topic.name == f"{workspace_id}-{initial_topic_name}" + + +class TestApplicationWithState: + def test_run_stateful_success( + self, + app_factory, + producer, + topic_factory, + executor, + state_manager_factory, + tmp_path, + ): + """ + Test that StreamingDataFrame processes 3 messages from Kafka and updates + the counter in the state store + """ + + consumer_group = 
str(uuid.uuid4()) + state_dir = (tmp_path / "state").absolute() + app = app_factory( + consumer_group=consumer_group, + auto_offset_reset="earliest", + state_dir=state_dir, + ) + + topic_in_name, _ = topic_factory() + + topic_in = app.topic(topic_in_name, value_deserializer=JSONDeserializer()) + state_manager = app._state_manager + state_manager.register_store(topic_in.name, "default") + + # TODO: Use stateful functions after they're implemented + # Define a function that counts incoming Rows using state + def count(_): + state = state_manager.get_store_transaction("default") + total = state.get("total", 0) + total += 1 + state.set("total", total) + if total == total_messages: + total_consumed.set_result(total) + + df = app.dataframe(topics_in=[topic_in]) + df.apply(count) + + total_messages = 3 + # Produce messages to the topic and flush + data = {"key": b"key", "value": dumps({"key": "value"})} + with producer: + for _ in range(total_messages): + producer.produce(topic_in_name, **data) + + total_consumed = Future() + + # Stop app when the future is resolved + executor.submit(_stop_app_on_future, app, total_consumed, 10.0) + app.run(df) + + # Check that the values are actually in the DB + state_manager = state_manager_factory( + group_id=consumer_group, state_dir=state_dir + ) + state_manager.register_store(topic_in.name, "default") + state_manager.on_partition_assign( + TopicPartitionStub(topic=topic_in.name, partition=0) + ) + store = state_manager.get_store(topic=topic_in.name, store_name="default") + with store.start_partition_transaction(partition=0) as tx: + assert tx.get("total") == total_consumed.result() + + def test_run_stateful_processing_fails( + self, + app_factory, + producer, + topic_factory, + executor, + state_manager_factory, + tmp_path, + ): + consumer_group = str(uuid.uuid4()) + state_dir = (tmp_path / "state").absolute() + app = app_factory( + consumer_group=consumer_group, + auto_offset_reset="earliest", + state_dir=state_dir, + ) + + topic_in_name, _ = topic_factory() + + topic_in = app.topic(topic_in_name, value_deserializer=JSONDeserializer()) + state_manager = app._state_manager + state_manager.register_store(topic_in.name, "default") + + # TODO: Use stateful functions after they're implemented + # Define a function that counts incoming Rows using state + def count(_): + state = state_manager.get_store_transaction("default") + total = state.get("total", 0) + total += 1 + state.set("total", total) + + failed = Future() + + def fail(_): + failed.set_result(True) + raise ValueError("test") + + df = app.dataframe(topics_in=[topic_in]) + df.apply(count) + df.apply(fail) + + total_messages = 3 + # Produce messages to the topic and flush + data = {"key": b"key", "value": dumps({"key": "value"})} + with producer: + for _ in range(total_messages): + producer.produce(topic_in_name, **data) + + # Stop app when the future is resolved + executor.submit(_stop_app_on_future, app, failed, 10.0) + with pytest.raises(ValueError): + app.run(df) + + # Ensure that nothing was committed to the DB + state_manager = state_manager_factory( + group_id=consumer_group, state_dir=state_dir + ) + state_manager.register_store(topic_in.name, "default") + state_manager.on_partition_assign( + TopicPartitionStub(topic=topic_in.name, partition=0) + ) + store = state_manager.get_store(topic=topic_in.name, store_name="default") + with store.start_partition_transaction(partition=0) as tx: + assert tx.get("total") is None + + def test_run_stateful_suppress_processing_errors( + self, + app_factory, + 
producer, + topic_factory, + executor, + state_manager_factory, + tmp_path, + ): + consumer_group = str(uuid.uuid4()) + state_dir = (tmp_path / "state").absolute() + app = app_factory( + consumer_group=consumer_group, + auto_offset_reset="earliest", + state_dir=state_dir, + # Suppress errors during message processing + on_processing_error=lambda *args: True, + ) + + topic_in_name, _ = topic_factory() + + topic_in = app.topic(topic_in_name, value_deserializer=JSONDeserializer()) + state_manager = app._state_manager + state_manager.register_store(topic_in.name, "default") + + # TODO: Use stateful functions after they're implemented + # Define a function that counts incoming Rows using state + def count(_): + state = state_manager.get_store_transaction("default") + total = state.get("total", 0) + total += 1 + state.set("total", total) + if total == total_messages: + total_consumed.set_result(total) + + def fail(_): + raise ValueError("test") + + df = app.dataframe(topics_in=[topic_in]) + df.apply(count) + df.apply(fail) + + total_messages = 3 + # Produce messages to the topic and flush + data = {"key": b"key", "value": dumps({"key": "value"})} + with producer: + for _ in range(total_messages): + producer.produce(topic_in_name, **data) + + total_consumed = Future() + + # Stop app when the future is resolved + executor.submit(_stop_app_on_future, app, total_consumed, 10.0) + # Run the application + app.run(df) + + # Ensure that data is committed to the DB + state_manager = state_manager_factory( + group_id=consumer_group, state_dir=state_dir + ) + state_manager.register_store(topic_in.name, "default") + state_manager.on_partition_assign( + TopicPartitionStub(topic=topic_in.name, partition=0) + ) + store = state_manager.get_store(topic=topic_in.name, store_name="default") + with store.start_partition_transaction(partition=0) as tx: + assert tx.get("total") == total_consumed.result() diff --git a/src/StreamingDataFrames/tests/test_dataframes/test_rowconsumer.py b/src/StreamingDataFrames/tests/test_dataframes/test_rowconsumer.py index a2835418d..6cf0b8437 100644 --- a/src/StreamingDataFrames/tests/test_dataframes/test_rowconsumer.py +++ b/src/StreamingDataFrames/tests/test_dataframes/test_rowconsumer.py @@ -1,6 +1,7 @@ import pytest from confluent_kafka import KafkaError, TopicPartition +from streamingdataframes.exceptions import PartitionAssignmentError from streamingdataframes.models import ( Deserializer, IgnoreMessage, @@ -162,3 +163,24 @@ def on_error(exc, *args): row = consumer.poll_row(10.0) assert row is None assert suppressed + + def test_poll_row_kafka_error_suppress_except_partition_assignment( + self, row_consumer_factory, topic_json_serdes_factory, producer + ): + topic = topic_json_serdes_factory() + + def on_error(*_): + return True + + def on_assign(*_): + raise ValueError("Test") + + with row_consumer_factory( + auto_offset_reset="error", + on_error=on_error, + ) as consumer, producer: + producer.produce(topic.name, key=b"key", value=b"value") + producer.flush() + consumer.subscribe([topic], on_assign=on_assign) + with pytest.raises(PartitionAssignmentError): + consumer.poll_row(10.0) diff --git a/src/StreamingDataFrames/tests/test_dataframes/test_state/fixtures.py b/src/StreamingDataFrames/tests/test_dataframes/test_state/fixtures.py new file mode 100644 index 000000000..550338e80 --- /dev/null +++ b/src/StreamingDataFrames/tests/test_dataframes/test_state/fixtures.py @@ -0,0 +1,26 @@ +import uuid +from typing import Optional + +import pytest + +from streamingdataframes.state 
import StateStoreManager + + +@pytest.fixture() +def state_manager_factory(tmp_path): + def factory( + group_id: Optional[str] = None, state_dir: Optional[str] = None + ) -> StateStoreManager: + group_id = group_id or str(uuid.uuid4()) + state_dir = state_dir or str(uuid.uuid4()) + return StateStoreManager(group_id=group_id, state_dir=str(tmp_path / state_dir)) + + return factory + + +@pytest.fixture() +def state_manager(state_manager_factory) -> StateStoreManager: + manager = state_manager_factory() + manager.init() + yield manager + manager.close() diff --git a/src/StreamingDataFrames/tests/test_dataframes/test_state/test_manager.py b/src/StreamingDataFrames/tests/test_dataframes/test_state/test_manager.py new file mode 100644 index 000000000..41ce316b1 --- /dev/null +++ b/src/StreamingDataFrames/tests/test_dataframes/test_state/test_manager.py @@ -0,0 +1,188 @@ +import contextlib +import uuid +from unittest.mock import patch + +import pytest +import rocksdict +from tests.utils import TopicPartitionStub + +from streamingdataframes.state.exceptions import ( + StoreNotRegisteredError, + InvalidStoreTransactionStateError, +) + + +class TestStateStoreManager: + def test_init_close(self, state_manager_factory): + with state_manager_factory(): + ... + + def test_init_state_dir_exists_success(self, state_manager_factory, tmp_path): + group_id = str(uuid.uuid4()) + base_dir_path = tmp_path / "state" + base_dir_path.mkdir(parents=True) + (base_dir_path / group_id).mkdir() + + with state_manager_factory(group_id=group_id, state_dir=str(base_dir_path)): + ... + + def test_init_state_dir_exists_not_a_dir_fails( + self, state_manager_factory, tmp_path + ): + group_id = str(uuid.uuid4()) + base_dir_path = tmp_path / "state" + base_dir_path.mkdir() + (base_dir_path / group_id).touch() + + with pytest.raises(FileExistsError): + with state_manager_factory(group_id=group_id, state_dir=str(base_dir_path)): + ... 
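Note: `tests/utils.py` (with `TopicPartitionStub`) is not included in this diff; a dataclass like the following would satisfy the `TopicPartition` protocol the tests exercise against `StateStoreManager`:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class TopicPartitionStub:
    topic: str
    partition: int
    offset: int = 0
```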
+ + def test_rebalance_partitions_stores_not_registered(self, state_manager): + tp = TopicPartitionStub("topic", 0) + # It's ok to rebalance partitions when there are no stores registered + state_manager.on_partition_assign(tp) + state_manager.on_partition_revoke(tp) + state_manager.on_partition_lost(tp) + + def test_assign_revoke_partitions_stores_registered(self, state_manager): + state_manager.register_store("topic1", store_name="store1") + state_manager.register_store("topic1", store_name="store2") + state_manager.register_store("topic2", store_name="store1") + + stores_list = [s for d in state_manager.stores.values() for s in d.values()] + assert len(stores_list) == 3 + + partitions = [ + TopicPartitionStub("topic1", 0), + TopicPartitionStub("topic2", 0), + ] + + store_partitions = [] + for tp in partitions: + store_partitions.extend(state_manager.on_partition_assign(tp)) + assert len(store_partitions) == 3 + + assert len(state_manager.get_store("topic1", "store1").partitions) == 1 + assert len(state_manager.get_store("topic1", "store2").partitions) == 1 + assert len(state_manager.get_store("topic2", "store1").partitions) == 1 + + for tp in partitions: + state_manager.on_partition_revoke(tp) + + assert not state_manager.get_store("topic1", "store1").partitions + assert not state_manager.get_store("topic1", "store2").partitions + assert not state_manager.get_store("topic2", "store1").partitions + + def test_assign_lose_partitions_stores_registered(self, state_manager): + state_manager.register_store("topic1", store_name="store1") + state_manager.register_store("topic1", store_name="store2") + state_manager.register_store("topic2", store_name="store1") + + stores_list = [s for d in state_manager.stores.values() for s in d.values()] + assert len(stores_list) == 3 + + partitions = [ + TopicPartitionStub("topic1", 0), + TopicPartitionStub("topic2", 0), + ] + + for tp in partitions: + state_manager.on_partition_assign(tp) + assert len(state_manager.get_store("topic1", "store1").partitions) == 1 + assert len(state_manager.get_store("topic1", "store2").partitions) == 1 + assert len(state_manager.get_store("topic2", "store1").partitions) == 1 + + for tp in partitions: + state_manager.on_partition_lost(tp) + + assert not state_manager.get_store("topic1", "store1").partitions + assert not state_manager.get_store("topic1", "store2").partitions + assert not state_manager.get_store("topic2", "store1").partitions + + def test_register_store_twice(self, state_manager): + state_manager.register_store("topic", "store") + state_manager.register_store("topic", "store") + + def test_get_store_not_registered(self, state_manager): + with pytest.raises(StoreNotRegisteredError): + state_manager.get_store("topic", "store") + + def test_store_transaction_success(self, state_manager): + state_manager.register_store("topic", "store") + tp = TopicPartitionStub("topic", 0) + state_manager.on_partition_assign(tp) + + store = state_manager.get_store("topic", "store") + store_partition = store.partitions[0] + + assert store_partition.get_processed_offset() is None + + with state_manager.start_store_transaction("topic", partition=0, offset=1): + tx = state_manager.get_store_transaction("store") + tx.set("some_key", "some_value") + + state_manager.on_partition_assign(tp) + + store = state_manager.get_store("topic", "store") + store_partition = store.partitions[0] + + assert store_partition.get_processed_offset() == 1 + + def test_store_transaction_no_flush_on_exception(self, state_manager): + 
state_manager.register_store("topic", "store") + state_manager.on_partition_assign(TopicPartitionStub("topic", 0)) + store = state_manager.get_store("topic", "store") + + with contextlib.suppress(Exception): + with state_manager.start_store_transaction("topic", partition=0, offset=1): + tx = state_manager.get_store_transaction("store") + tx.set("some_key", "some_value") + raise ValueError() + + store_partition = store.partitions[0] + assert store_partition.get_processed_offset() is None + + def test_store_transaction_no_flush_if_partition_transaction_failed( + self, state_manager + ): + """ + Ensure that no PartitionTransactions are flushed to the DB if + any of them fails + """ + state_manager.register_store("topic", "store1") + state_manager.register_store("topic", "store2") + state_manager.on_partition_assign(TopicPartitionStub("topic", 0)) + store1 = state_manager.get_store("topic", "store1") + store2 = state_manager.get_store("topic", "store2") + + with state_manager.start_store_transaction("topic", partition=0, offset=1): + tx_store1 = state_manager.get_store_transaction("store1") + tx_store2 = state_manager.get_store_transaction("store2") + # Simulate exception in one of the transactions + with contextlib.suppress(ValueError), patch.object( + rocksdict.WriteBatch, "put", side_effect=ValueError("test") + ): + tx_store1.set("some_key", "some_value") + tx_store2.set("some_key", "some_value") + + assert store1.partitions[0].get_processed_offset() is None + assert store2.partitions[0].get_processed_offset() is None + + def test_get_store_transaction_store_not_registered_fails(self, state_manager): + with pytest.raises(StoreNotRegisteredError): + with state_manager.start_store_transaction("topic", 0, 0): + ... + + def test_get_store_transaction_not_started(self, state_manager): + with pytest.raises(InvalidStoreTransactionStateError): + state_manager.get_store_transaction("store") + + def test_start_store_transaction_already_started(self, state_manager): + state_manager.register_store("topic", "store") + with state_manager.start_store_transaction("topic", partition=0, offset=0): + with pytest.raises(InvalidStoreTransactionStateError): + with state_manager.start_store_transaction( + "topic", partition=0, offset=0 + ): + ... 
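# --- Editor's note: illustrative sketch, not part of the patch above. ---
# The pattern verified by the transaction tests: changes made inside
# start_store_transaction(topic, partition, offset) are flushed, and the
# processed offset is recorded, only if the wrapped block succeeds. The
# sketch assumes a store named "store" was registered for `topic` and the
# partition was assigned via on_partition_assign(), as in
# test_store_transaction_success; names are placeholders.
from streamingdataframes.state import StateStoreManager


def process_one(manager: StateStoreManager, topic: str, partition: int, offset: int):
    with manager.start_store_transaction(topic, partition=partition, offset=offset):
        tx = manager.get_store_transaction("store")
        count = tx.get("count", 0)
        tx.set("count", count + 1)
    # After the block exits successfully, the partition's processed offset
    # reflects the offset passed to start_store_transaction()
    store = manager.get_store(topic, "store")
    assert store.partitions[partition].get_processed_offset() == offset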
diff --git a/src/StreamingDataFrames/tests/test_dataframes/test_state/test_rocksdb/fixtures.py b/src/StreamingDataFrames/tests/test_dataframes/test_state/test_rocksdb/fixtures.py index 95be585e0..31d16039a 100644 --- a/src/StreamingDataFrames/tests/test_dataframes/test_state/test_rocksdb/fixtures.py +++ b/src/StreamingDataFrames/tests/test_dataframes/test_state/test_rocksdb/fixtures.py @@ -1,13 +1,16 @@ +import uuid from typing import Optional import pytest -from streamingdataframes.state.rocksdb import RocksDBStorage, RocksDBOptions -from streamingdataframes.state.rocksdb.serialization import DumpsFunc, LoadsFunc +from streamingdataframes.state.rocksdb import RocksDBStore +from streamingdataframes.state.rocksdb.options import RocksDBOptions +from streamingdataframes.state.rocksdb.partition import RocksDBStorePartition +from streamingdataframes.state.types import DumpsFunc, LoadsFunc @pytest.fixture() -def rocksdb_storage_factory(tmp_path): +def rocksdb_partition_factory(tmp_path): def factory( name: str = "db", options: Optional[RocksDBOptions] = None, @@ -15,9 +18,9 @@ def factory( open_retry_backoff: float = 3.0, dumps: Optional[DumpsFunc] = None, loads: Optional[LoadsFunc] = None, - ) -> RocksDBStorage: + ) -> RocksDBStorePartition: path = (tmp_path / name).as_posix() - return RocksDBStorage( + return RocksDBStorePartition( path, options=options, open_max_retries=open_max_retries, @@ -30,7 +33,23 @@ def factory( @pytest.fixture() -def rocksdb_storage(rocksdb_storage_factory) -> RocksDBStorage: - storage = rocksdb_storage_factory() - yield storage - storage.close() +def rocksdb_partition(rocksdb_partition_factory) -> RocksDBStorePartition: + partition = rocksdb_partition_factory() + yield partition + partition.close() + + +@pytest.fixture() +def rocksdb_store_factory(tmp_path): + def factory(topic: Optional[str] = None, name: str = "default") -> RocksDBStore: + topic = topic or str(uuid.uuid4()) + return RocksDBStore(topic=topic, name=name, base_dir=str(tmp_path)) + + return factory + + +@pytest.fixture() +def rocksdb_store(rocksdb_store_factory) -> RocksDBStore: + store = rocksdb_store_factory() + yield store + store.close() diff --git a/src/StreamingDataFrames/tests/test_dataframes/test_state/test_rocksdb/test_stores.py b/src/StreamingDataFrames/tests/test_dataframes/test_state/test_rocksdb/test_partition.py similarity index 62% rename from src/StreamingDataFrames/tests/test_dataframes/test_state/test_rocksdb/test_stores.py rename to src/StreamingDataFrames/tests/test_dataframes/test_state/test_rocksdb/test_partition.py index 9d8bee600..351a5efb9 100644 --- a/src/StreamingDataFrames/tests/test_dataframes/test_state/test_rocksdb/test_stores.py +++ b/src/StreamingDataFrames/tests/test_dataframes/test_state/test_rocksdb/test_partition.py @@ -12,7 +12,7 @@ from streamingdataframes.state.rocksdb import ( StateSerializationError, StateTransactionError, - RocksDBStorage, + RocksDBStorePartition, NestedPrefixError, RocksDBOptions, ) @@ -44,13 +44,13 @@ ] -class TestRocksDBStorage: - def test_open_db_close(self, rocksdb_storage_factory): - with rocksdb_storage_factory(): +class TestRocksDBStorePartition: + def test_open_db_close(self, rocksdb_partition_factory): + with rocksdb_partition_factory(): ... 
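# --- Editor's note: illustrative sketch, not part of the patch above. ---
# RocksDBStore keeps one RocksDB instance per assigned partition under
# base_dir, which is what the renamed fixtures and test_store.py exercise.
# A rough usage sketch; the topic name and temp directory are placeholders.
import tempfile
import uuid

from streamingdataframes.state.rocksdb import RocksDBStore

store = RocksDBStore(topic=str(uuid.uuid4()), name="default", base_dir=tempfile.mkdtemp())
store.assign_partition(0)                      # opens the partition's RocksDB
with store.start_partition_transaction(0) as tx:
    tx.set("key", "value")                     # flushed when the block exits
store.revoke_partition(0)                      # closes that partition's DB
store.close()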
- def test_open_db_locked_retries(self, rocksdb_storage_factory, executor): - db1 = rocksdb_storage_factory("db") + def test_open_db_locked_retries(self, rocksdb_partition_factory, executor): + db1 = rocksdb_partition_factory("db") def _close_db(): time.sleep(3) @@ -58,79 +58,79 @@ def _close_db(): executor.submit(_close_db) - rocksdb_storage_factory("db", open_max_retries=10, open_retry_backoff=1) + rocksdb_partition_factory("db", open_max_retries=10, open_retry_backoff=1) - def test_open_db_locked_no_retries_fails(self, rocksdb_storage_factory, executor): - _ = rocksdb_storage_factory("db") + def test_open_db_locked_no_retries_fails(self, rocksdb_partition_factory, executor): + _ = rocksdb_partition_factory("db") with pytest.raises(Exception): - rocksdb_storage_factory("db", open_max_retries=0) + rocksdb_partition_factory("db", open_max_retries=0) def test_open_db_locked_retries_exhausted_fails( - self, rocksdb_storage_factory, executor + self, rocksdb_partition_factory, executor ): - _ = rocksdb_storage_factory("db") + _ = rocksdb_partition_factory("db") with pytest.raises(Exception): - rocksdb_storage_factory("db", open_max_retries=3, open_retry_backoff=1) + rocksdb_partition_factory("db", open_max_retries=3, open_retry_backoff=1) - def test_open_arbitrary_exception_fails(self, rocksdb_storage_factory): + def test_open_arbitrary_exception_fails(self, rocksdb_partition_factory): err = Exception("some exception") with patch.object(Rdict, "__init__", side_effect=err): with pytest.raises(Exception) as raised: - rocksdb_storage_factory() + rocksdb_partition_factory() assert str(raised.value) == "some exception" - def test_get_db_closed_fails(self, rocksdb_storage_factory): - storage = rocksdb_storage_factory() + def test_get_db_closed_fails(self, rocksdb_partition_factory): + storage = rocksdb_partition_factory() storage.close() with pytest.raises(Exception): storage.get(b"key") - def test_get_key_doesnt_exist(self, rocksdb_storage): - assert rocksdb_storage.get(b"key") is None + def test_get_key_doesnt_exist(self, rocksdb_partition): + assert rocksdb_partition.get(b"key") is None - def test_destroy(self, rocksdb_storage_factory): - with rocksdb_storage_factory() as storage: + def test_destroy(self, rocksdb_partition_factory): + with rocksdb_partition_factory() as storage: path = storage.path - RocksDBStorage.destroy(path) + RocksDBStorePartition.destroy(path) - def test_custom_options(self, rocksdb_storage_factory, tmp_path): + def test_custom_options(self, rocksdb_partition_factory, tmp_path): """ Pass custom "logs_dir" to Rdict and ensure it exists and has some files """ logs_dir = Path(tmp_path / "db" / "logs") options = RocksDBOptions(db_log_dir=logs_dir.as_posix()) - with rocksdb_storage_factory(options=options): + with rocksdb_partition_factory(options=options): assert logs_dir.is_dir() assert len(list(logs_dir.rglob("*"))) == 1 -class TestTransactionStore: - def test_transaction_complete(self, rocksdb_storage): - with rocksdb_storage.begin() as tx: +class TestRocksDBPartitionTransaction: + def test_transaction_complete(self, rocksdb_partition): + with rocksdb_partition.begin() as tx: ... assert tx.completed - def test_transaction_doesnt_write_empty_batch(self, rocksdb_storage): + def test_transaction_doesnt_write_empty_batch(self, rocksdb_partition): """ Test that transaction doesn't call "StateStore.write()" if the internal WriteBatch is empty (i.e. no keys were updated during the transaction). 
Writing empty batches costs more than doing """ - with patch.object(RocksDBStorage, "write") as mocked: - with rocksdb_storage.begin() as tx: + with patch.object(RocksDBStorePartition, "write") as mocked: + with rocksdb_partition.begin() as tx: tx.get("key") assert not mocked.called - def test_delete_key_doesnt_exist(self, rocksdb_storage): - with rocksdb_storage.begin() as tx: + def test_delete_key_doesnt_exist(self, rocksdb_partition): + with rocksdb_partition.begin() as tx: tx.delete("key") @pytest.mark.parametrize( @@ -141,8 +141,8 @@ def test_delete_key_doesnt_exist(self, rocksdb_storage): "value", TEST_VALUES, ) - def test_get_key_exists_cached(self, key, value, rocksdb_storage): - with rocksdb_storage.begin() as tx: + def test_get_key_exists_cached(self, key, value, rocksdb_partition): + with rocksdb_partition.begin() as tx: tx.set(key, value) stored = tx.get(key) assert stored == value @@ -155,46 +155,46 @@ def test_get_key_exists_cached(self, key, value, rocksdb_storage): "value", TEST_VALUES, ) - def test_get_key_exists_no_cache(self, key, value, rocksdb_storage): - with rocksdb_storage.begin() as tx: + def test_get_key_exists_no_cache(self, key, value, rocksdb_partition): + with rocksdb_partition.begin() as tx: tx.set(key, value) - with rocksdb_storage.begin() as tx: + with rocksdb_partition.begin() as tx: stored = tx.get(key, value) assert stored == value - def test_get_key_doesnt_exist_default(self, rocksdb_storage): - with rocksdb_storage.begin() as tx: + def test_get_key_doesnt_exist_default(self, rocksdb_partition): + with rocksdb_partition.begin() as tx: value = tx.get("key", default=123) assert value == 123 - def test_delete_key_cached(self, rocksdb_storage): - with rocksdb_storage.begin() as tx: + def test_delete_key_cached(self, rocksdb_partition): + with rocksdb_partition.begin() as tx: tx.set("key", "value") assert tx.get("key") == "value" tx.delete("key") assert tx.get("key") is None - def test_delete_key_no_cache(self, rocksdb_storage): - with rocksdb_storage.begin() as tx: + def test_delete_key_no_cache(self, rocksdb_partition): + with rocksdb_partition.begin() as tx: tx.set("key", "value") assert tx.get("key") == "value" - with rocksdb_storage.begin() as tx: + with rocksdb_partition.begin() as tx: tx.delete("key") - with rocksdb_storage.begin() as tx: + with rocksdb_partition.begin() as tx: assert tx.get("key") is None - def test_key_exists_cached(self, rocksdb_storage): - with rocksdb_storage.begin() as tx: + def test_key_exists_cached(self, rocksdb_partition): + with rocksdb_partition.begin() as tx: tx.set("key", "value") assert tx.exists("key") assert not tx.exists("key123") - def test_key_exists_no_cache(self, rocksdb_storage): - with rocksdb_storage.begin() as tx: + def test_key_exists_no_cache(self, rocksdb_partition): + with rocksdb_partition.begin() as tx: tx.set("key", "value") - with rocksdb_storage.begin() as tx: + with rocksdb_partition.begin() as tx: assert tx.exists("key") assert not tx.exists("key123") @@ -207,18 +207,18 @@ def test_key_exists_no_cache(self, rocksdb_storage): (datetime.utcnow(), "string"), ], ) - def test_set_serialization_error(self, key, value, rocksdb_storage): - with rocksdb_storage.begin() as tx: + def test_set_serialization_error(self, key, value, rocksdb_partition): + with rocksdb_partition.begin() as tx: with pytest.raises(StateSerializationError): tx.set(key, value) @pytest.mark.parametrize("key", [object(), b"somebytes", datetime.utcnow()]) - def test_delete_serialization_error(self, key, rocksdb_storage): - with 
rocksdb_storage.begin() as tx: + def test_delete_serialization_error(self, key, rocksdb_partition): + with rocksdb_partition.begin() as tx: with pytest.raises(StateSerializationError): tx.delete(key) - def test_get_deserialization_error(self, rocksdb_storage): + def test_get_deserialization_error(self, rocksdb_partition): bytes_ = secrets.token_bytes(10) string_ = "string" @@ -227,42 +227,42 @@ def test_get_deserialization_error(self, rocksdb_storage): batch.put(bytes_, serialize(string_)) # Set valid key and non-deserializable value batch.put(serialize(string_), bytes_) - rocksdb_storage.write(batch) + rocksdb_partition.write(batch) - with rocksdb_storage.begin() as tx: + with rocksdb_partition.begin() as tx: with pytest.raises(StateSerializationError): tx.get(string_) with pytest.raises(StateSerializationError): tx.get(bytes_) @pytest.mark.parametrize("prefix", TEST_PREFIXES) - def test_set_key_with_prefix_no_cache(self, prefix, rocksdb_storage): - with rocksdb_storage.begin() as tx: + def test_set_key_with_prefix_no_cache(self, prefix, rocksdb_partition): + with rocksdb_partition.begin() as tx: with tx.with_prefix(prefix): tx.set("key", "value") - with rocksdb_storage.begin() as tx: + with rocksdb_partition.begin() as tx: with tx.with_prefix(prefix): assert tx.get("key") == "value" - with rocksdb_storage.begin() as tx: + with rocksdb_partition.begin() as tx: assert tx.get("key") is None @pytest.mark.parametrize("prefix", TEST_PREFIXES) - def test_delete_key_with_prefix_no_cache(self, prefix, rocksdb_storage): - with rocksdb_storage.begin() as tx: + def test_delete_key_with_prefix_no_cache(self, prefix, rocksdb_partition): + with rocksdb_partition.begin() as tx: with tx.with_prefix(prefix): tx.set("key", "value") - with rocksdb_storage.begin() as tx: + with rocksdb_partition.begin() as tx: with tx.with_prefix(prefix): assert tx.get("key") == "value" - with rocksdb_storage.begin() as tx: + with rocksdb_partition.begin() as tx: with tx.with_prefix(prefix): tx.delete("key") - with rocksdb_storage.begin() as tx: + with rocksdb_partition.begin() as tx: with tx.with_prefix(prefix): assert tx.get("key") is None @@ -273,7 +273,7 @@ def test_delete_key_with_prefix_no_cache(self, prefix, rocksdb_storage): lambda tx: tx.delete("key"), ], ) - def test_update_key_failed_transaction_failed(self, operation, rocksdb_storage): + def test_update_key_failed_transaction_failed(self, operation, rocksdb_partition): """ Test that if the update operation (set or delete) fails the transaction is marked as failed and cannot be re-used anymore. @@ -281,7 +281,7 @@ def test_update_key_failed_transaction_failed(self, operation, rocksdb_storage): with patch.object( rocksdict.WriteBatch, "put", side_effect=ValueError("test") ), patch.object(rocksdict.WriteBatch, "delete", side_effect=ValueError("test")): - with rocksdb_storage.begin() as tx: + with rocksdb_partition.begin() as tx: with contextlib.suppress(ValueError): operation(tx=tx) @@ -301,22 +301,24 @@ def test_update_key_failed_transaction_failed(self, operation, rocksdb_storage): tx.exists("key") with pytest.raises(StateTransactionError): - tx._flush() + tx.maybe_flush() assert not tx.completed - def test_flush_failed_transaction_failed(self, rocksdb_storage): + def test_flush_failed_transaction_failed(self, rocksdb_partition): """ - Test that if the "StateStore.write()" fails the transaction is also marked + Test that if the "maybe_flush()" fails the transaction is also marked as failed and cannot be re-used anymore. 
""" - with patch.object(RocksDBStorage, "write", side_effect=ValueError("test")): - with rocksdb_storage.begin() as tx: + with patch.object( + RocksDBStorePartition, "write", side_effect=ValueError("test") + ): + with rocksdb_partition.begin() as tx: tx.set("key", "value") with contextlib.suppress(ValueError): - tx._flush() + tx.maybe_flush() assert tx.failed @@ -335,18 +337,27 @@ def test_flush_failed_transaction_failed(self, rocksdb_storage): assert tx.completed - def test_nested_prefixes_fail(self, rocksdb_storage): - tx = rocksdb_storage.begin() + def test_transaction_not_flushed_on_error(self, rocksdb_partition): + with contextlib.suppress(ValueError): + with rocksdb_partition.begin() as tx: + tx.set("key", "value") + raise ValueError("test") + + with rocksdb_partition.begin() as tx: + assert tx.get("key") is None + + def test_nested_prefixes_fail(self, rocksdb_partition): + tx = rocksdb_partition.begin() with pytest.raises(NestedPrefixError): with tx.with_prefix("prefix"): with tx.with_prefix("prefix"): ... - def test_custom_dumps_loads(self, rocksdb_storage_factory): + def test_custom_dumps_loads(self, rocksdb_partition_factory): key = secrets.token_bytes(10) value = secrets.token_bytes(10) - with rocksdb_storage_factory(loads=lambda v: v, dumps=lambda v: v) as db: + with rocksdb_partition_factory(loads=lambda v: v, dumps=lambda v: v) as db: with db.begin() as tx: tx.set(key, value) diff --git a/src/StreamingDataFrames/tests/test_dataframes/test_state/test_rocksdb/test_store.py b/src/StreamingDataFrames/tests/test_dataframes/test_state/test_rocksdb/test_store.py new file mode 100644 index 000000000..0dfdf9017 --- /dev/null +++ b/src/StreamingDataFrames/tests/test_dataframes/test_state/test_rocksdb/test_store.py @@ -0,0 +1,46 @@ +import pytest + +from streamingdataframes.state.exceptions import PartitionNotAssignedError + + +class TestRocksDBStore: + def test_open_close(self, rocksdb_store_factory): + with rocksdb_store_factory(): + pass + + def test_assign_revoke_partition(self, rocksdb_store): + # Assign a partition to the store + rocksdb_store.assign_partition(0) + assert rocksdb_store.partitions[0] + # Revoke partition + rocksdb_store.revoke_partition(0) + assert 0 not in rocksdb_store.partitions + # Assign partition again + rocksdb_store.assign_partition(0) + + def test_assign_partition_twice(self, rocksdb_store): + rocksdb_store.assign_partition(0) + rocksdb_store.assign_partition(0) + + def test_revoke_partition_not_assigned(self, rocksdb_store): + rocksdb_store.revoke_partition(0) + + def test_create_transaction(self, rocksdb_store): + rocksdb_store.assign_partition(0) + with rocksdb_store.start_partition_transaction(0) as tx: + tx.set("key", "value") + rocksdb_store.revoke_partition(0) + + # Assign partition again and check the value + rocksdb_store.assign_partition(0) + with rocksdb_store.start_partition_transaction(0) as tx: + assert tx.get("key") == "value" + + def test_get_transaction_partition_not_assigned(self, rocksdb_store): + with pytest.raises(PartitionNotAssignedError): + rocksdb_store.start_partition_transaction(0) + + rocksdb_store.assign_partition(0) + rocksdb_store.revoke_partition(0) + with pytest.raises(PartitionNotAssignedError): + rocksdb_store.start_partition_transaction(0) diff --git a/src/StreamingDataFrames/tests/utils.py b/src/StreamingDataFrames/tests/utils.py index 7b4d0b42c..a1ceeb476 100644 --- a/src/StreamingDataFrames/tests/utils.py +++ b/src/StreamingDataFrames/tests/utils.py @@ -1,5 +1,8 @@ +import dataclasses import time +from confluent_kafka 
import OFFSET_INVALID + DEFAULT_TIMEOUT = 10.0 @@ -21,3 +24,10 @@ def __bool__(self): if expired: raise TimeoutError("Timeout expired") return True + + +@dataclasses.dataclass +class TopicPartitionStub: + topic: str + partition: int + offset: int = OFFSET_INVALID
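# --- Editor's note: illustrative sketch, not part of the patch above. ---
# TopicPartitionStub above stands in for confluent_kafka.TopicPartition in the
# rebalance callbacks (on_partition_assign / on_partition_revoke /
# on_partition_lost). Below is a rough sketch of the low-level
# RocksDBStorePartition transaction API that test_partition.py exercises,
# including key prefixes; the database path is a placeholder.
import os
import tempfile

from streamingdataframes.state.rocksdb import RocksDBStorePartition

partition = RocksDBStorePartition(os.path.join(tempfile.mkdtemp(), "db"))

with partition.begin() as tx:
    tx.set("key", "value")
    assert tx.exists("key")
    with tx.with_prefix("user-1"):      # keys written here are namespaced
        tx.set("key", "prefixed-value")

with partition.begin() as tx:
    assert tx.get("key") == "value"               # un-prefixed key
    with tx.with_prefix("user-1"):
        assert tx.get("key") == "prefixed-value"  # prefixed key is separate
    tx.delete("key")

partition.close()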