diff --git a/quixstreams/app.py b/quixstreams/app.py
index bb909799a..0e639c4ef 100644
--- a/quixstreams/app.py
+++ b/quixstreams/app.py
@@ -217,7 +217,7 @@ def __init__(
         if broker_address:
             # If broker_address is passed to the app it takes priority over any quix configuration
             self._is_quix_app = False
-            topic_manager_factory = TopicManager
+            self._topic_manager_factory = TopicManager
             if isinstance(broker_address, str):
                 broker_address = ConnectionConfig(bootstrap_servers=broker_address)
         else:
@@ -245,7 +245,7 @@ def __init__(
                 f"{quix_app_source} detected; "
                 f"the application will connect to Quix Cloud brokers"
             )
-            topic_manager_factory = functools.partial(
+            self._topic_manager_factory = functools.partial(
                 QuixTopicManager, quix_config_builder=quix_config_builder
            )
             # Check if the state dir points to the mounted PVC while running on Quix
@@ -288,30 +288,12 @@ def __init__(
         self._on_message_processed = on_message_processed
         self._on_processing_error = on_processing_error or default_on_processing_error

-        self._consumer = RowConsumer(
-            broker_address=self._config.broker_address,
-            consumer_group=self._config.consumer_group,
-            auto_offset_reset=self._config.auto_offset_reset,
-            auto_commit_enable=False,  # Disable auto commit and manage commits manually
-            extra_config=self._config.consumer_extra_config,
-            on_error=on_consumer_error,
-        )
+        self._consumer = self._get_rowconsumer(on_error=on_consumer_error)
         self._producer = self._get_rowproducer(on_error=on_producer_error)
         self._running = False
         self._failed = False

-        if not topic_manager:
-            topic_manager = topic_manager_factory(
-                topic_admin=TopicAdmin(
-                    broker_address=self._config.broker_address,
-                    extra_config=self._config.producer_extra_config,
-                ),
-                consumer_group=self._config.consumer_group,
-                timeout=self._config.request_timeout,
-                create_timeout=self._config.topic_create_timeout,
-                auto_create_topics=self._config.auto_create_topics,
-            )
-        self._topic_manager = topic_manager
+        self._topic_manager = topic_manager or self._get_topic_manager()

         producer = None
         recovery_manager = None
@@ -369,6 +351,23 @@ def Quix(cls, *args, **kwargs):
                 '"Quix__Sdk__Token" environment variable'
             )

+    def _get_topic_manager(self) -> TopicManager:
+        """
+        Create a TopicManager using the application config.
+
+        Used to create the application topic manager as well as the source topic managers.
+        """
+        return self._topic_manager_factory(
+            topic_admin=TopicAdmin(
+                broker_address=self._config.broker_address,
+                extra_config=self._config.producer_extra_config,
+            ),
+            consumer_group=self._config.consumer_group,
+            timeout=self._config.request_timeout,
+            create_timeout=self._config.topic_create_timeout,
+            auto_create_topics=self._config.auto_create_topics,
+        )
+
     def topic(
         self,
         name: str,
@@ -579,6 +578,24 @@ def get_producer(self) -> Producer:
             extra_config=self._config.producer_extra_config,
         )

+    def _get_rowconsumer(
+        self, on_error: Optional[ConsumerErrorCallback] = None
+    ) -> RowConsumer:
+        """
+        Create a RowConsumer using the application config.
+
+        Used to create the application consumer as well as the source consumers.
+        """
+        return RowConsumer(
+            broker_address=self._config.broker_address,
+            consumer_group=self._config.consumer_group,
+            auto_offset_reset=self._config.auto_offset_reset,
+            auto_commit_enable=False,  # Disable auto commit and manage commits manually
+            extra_config=self._config.consumer_extra_config,
+            on_error=on_error,
+        )
+
     def get_consumer(self, auto_commit_enable: bool = True) -> Consumer:
         """
         Create and return a pre-configured Consumer instance.
@@ -650,9 +667,13 @@ def add_source(self, source: BaseSource, topic: Optional[Topic] = None) -> Topic
         if not topic:
             topic = self._topic_manager.register(source.default_topic())

-        producer = self._get_rowproducer(transactional=False)
-        source.configure(topic, producer)
-        self._source_manager.register(source)
+        self._source_manager.register(
+            source,
+            topic,
+            self._get_rowproducer(transactional=False),
+            self._get_rowconsumer(),
+            self._get_topic_manager(),
+        )
         return topic

     def run(self, dataframe: Optional[StreamingDataFrame] = None):
@@ -879,7 +900,8 @@ def _on_assign(self, _, topic_partitions: List[TopicPartition]):
                     stored_offsets = [
                         offset
                         for offset in (
-                            store_tp.get_processed_offset() for store_tp in store_partitions
+                            store_tp.get_processed_offset()
+                            for store_tp in store_partitions.values()
                         )
                         if offset is not None
                     ]
diff --git a/quixstreams/models/topics/manager.py b/quixstreams/models/topics/manager.py
index f8d97f653..fbaba21d5 100644
--- a/quixstreams/models/topics/manager.py
+++ b/quixstreams/models/topics/manager.py
@@ -46,7 +46,7 @@ class TopicManager:
     _max_topic_name_len = 255

     _groupby_extra_config_imports_defaults = {"retention.bytes", "retention.ms"}
-    _changelog_extra_config_defaults = {"cleanup.policy": "compact"}
+    _changelog_extra_config_override = {"cleanup.policy": "compact"}
     _changelog_extra_config_imports_defaults = {"retention.bytes", "retention.ms"}

     def __init__(
@@ -67,7 +67,7 @@ def __init__(
         self._consumer_group = consumer_group
         self._topics: Dict[str, Topic] = {}
         self._repartition_topics: Dict[str, Topic] = {}
-        self._changelog_topics: Dict[str, Dict[str, Topic]] = {}
+        self._changelog_topics: Dict[Optional[str], Dict[str, Topic]] = {}
         self._timeout = timeout
         self._create_timeout = create_timeout
         self._auto_create_topics = auto_create_topics
@@ -101,7 +101,7 @@ def repartition_topics(self) -> Dict[str, Topic]:
         return self._repartition_topics

     @property
-    def changelog_topics(self) -> Dict[str, Dict[str, Topic]]:
+    def changelog_topics(self) -> Dict[Optional[str], Dict[str, Topic]]:
         """
         Note: `Topic`s are the changelogs.
@@ -152,7 +152,7 @@ def _format_nested_name(self, topic_name: str) -> str:
     def _internal_name(
         self,
         topic_type: Literal["changelog", "repartition"],
-        topic_name: str,
+        topic_name: Optional[str],
         suffix: str,
     ) -> str:
         """
@@ -163,13 +163,19 @@ def _internal_name(
         The internal format is <{TYPE}__{GROUP}--{NAME}--{SUFFIX}>

         :param topic_type: topic type, added as prefix (changelog, repartition)
-        :param topic_name: name of consumed topic (app input topic)
+        :param topic_name: name of the consumed topic (app input topic), if any
         :param suffix: a unique descriptor related to topic type, added as suffix

         :return: formatted topic name
         """
-        nested_name = self._format_nested_name(topic_name)
-        return f"{topic_type}__{'--'.join([self._consumer_group, nested_name, suffix])}"
+        if topic_name is None:
+            parts = [self._consumer_group, suffix]
+        else:
+            nested_name = self._format_nested_name(topic_name)
+            parts = [self._consumer_group, nested_name, suffix]
+
+        return f"{topic_type}__{'--'.join(parts)}"

     def _create_topics(
         self, topics: List[Topic], timeout: float, create_timeout: float
@@ -341,13 +347,14 @@ def repartition_topic(

     def changelog_topic(
         self,
-        topic_name: str,
+        topic_name: Optional[str],
         store_name: str,
+        config: Optional[TopicConfig] = None,
         timeout: Optional[float] = None,
     ) -> Topic:
         """
-        Performs all the logic necessary to generate a changelog topic based on a
-        "source topic" (aka input/consumed topic).
+        Performs all the logic necessary to generate a changelog topic based on an
+        optional "source topic" (aka input/consumed topic).

         Its main goal is to ensure partition counts of the to-be generated changelog
         match the source topic, and ensure the changelog topic is compacted. Also
@@ -366,23 +373,32 @@ def changelog_topic(
             > NOTE: normally contain any prefixes added by TopicManager.topic()
         :param store_name: name of the store this changelog belongs to
             (default, rolling10s, etc.)
+        :param config: the changelog topic configuration. Defaults to the `topic_name`
+            configuration, or to the TopicManager defaults if `topic_name` is None
         :param timeout: config lookup timeout (seconds); Default 30

         :return: `Topic` object (which is also stored on the TopicManager)
         """
-        source_topic_config = self._get_source_topic_config(
-            topic_name,
-            extras_imports=self._changelog_extra_config_imports_defaults,
-            timeout=timeout if timeout is not None else self._timeout,
-        )
-        source_topic_config.extra_config.update(self._changelog_extra_config_defaults)
-
-        changelog_config = self.topic_config(
-            num_partitions=source_topic_config.num_partitions,
-            replication_factor=source_topic_config.replication_factor,
-            extra_config=source_topic_config.extra_config,
-        )
+        if config is None:
+            if topic_name is None:
+                config = self.topic_config(
+                    num_partitions=self.default_num_partitions,
+                    replication_factor=self.default_replication_factor,
+                )
+            else:
+                source_topic_config = self._get_source_topic_config(
+                    topic_name,
+                    extras_imports=self._changelog_extra_config_imports_defaults,
+                    timeout=timeout if timeout is not None else self._timeout,
+                )
+
+                config = self.topic_config(
+                    num_partitions=source_topic_config.num_partitions,
+                    replication_factor=source_topic_config.replication_factor,
+                    extra_config=source_topic_config.extra_config,
+                )
+
+        # always enforce the changelog config overrides (e.g. compaction)
+        config.extra_config.update(self._changelog_extra_config_override)

         topic = self._finalize_topic(
             Topic(
@@ -391,7 +407,7 @@ def changelog_topic(
                 value_serializer="bytes",
                 key_deserializer="bytes",
                 value_deserializer="bytes",
-                config=changelog_config,
+                config=config,
             )
         )
         self._changelog_topics.setdefault(topic_name, {})[store_name] = topic
diff --git a/quixstreams/sources/__init__.py b/quixstreams/sources/__init__.py
index 236dafba2..7d96446ac 100644
--- a/quixstreams/sources/__init__.py
+++ b/quixstreams/sources/__init__.py
@@ -1,4 +1,11 @@
-from .base import BaseSource, Source, SourceException, SourceManager, multiprocessing
+from .base import (
+    BaseSource,
+    Source,
+    SourceException,
+    SourceManager,
+    StatefulSource,
+    multiprocessing,
+)
 from .core.csv import CSVSource
 from .core.kafka import KafkaReplicatorSource, QuixEnvironmentSource

@@ -11,4 +18,5 @@
     "Source",
     "SourceException",
     "SourceManager",
+    "StatefulSource",
 ]
diff --git a/quixstreams/sources/base/__init__.py b/quixstreams/sources/base/__init__.py
index f10a2a46d..da66de5f9 100644
--- a/quixstreams/sources/base/__init__.py
+++ b/quixstreams/sources/base/__init__.py
@@ -1,7 +1,7 @@
 from .exceptions import SourceException
 from .manager import SourceManager
 from .multiprocessing import multiprocessing
-from .source import BaseSource, Source
+from .source import BaseSource, Source, StatefulSource

 __all__ = (
     "Source",
@@ -9,4 +9,5 @@
     "multiprocessing",
     "SourceManager",
     "SourceException",
+    "StatefulSource",
 )
diff --git a/quixstreams/sources/base/manager.py b/quixstreams/sources/base/manager.py
index ba5c994ad..b5a653935 100644
--- a/quixstreams/sources/base/manager.py
+++ b/quixstreams/sources/base/manager.py
@@ -4,12 +4,19 @@
 from pickle import PicklingError
 from typing import List

+from confluent_kafka import OFFSET_BEGINNING
+
 from quixstreams.logging import LOGGER_NAME, configure_logging
-from quixstreams.models import Topic
+from quixstreams.models import Topic, TopicManager
+from quixstreams.models.topics import TopicConfig
+from quixstreams.rowconsumer import RowConsumer
+from quixstreams.rowproducer import RowProducer
+from quixstreams.state import RecoveryManager, StateStoreManager, StorePartition
+from quixstreams.state.memory import MemoryStore

 from .exceptions import SourceException
 from .multiprocessing import multiprocessing
-from .source import BaseSource
+from .source import BaseSource, StatefulSource

 logger = logging.getLogger(__name__)

@@ -24,14 +31,27 @@ class SourceProcess(multiprocessing.Process):
     Some methods are designed to be used from the parent process, and others from
     the child process.
     """

-    def __init__(self, source):
+    def __init__(
+        self,
+        source: BaseSource,
+        topic: Topic,
+        producer: RowProducer,
+        consumer: RowConsumer,
+        topic_manager: TopicManager,
+    ):
         super().__init__()
-        self.source: BaseSource = source
+        self.topic = topic
+        self.source = source

         self._exceptions: List[Exception] = []
         self._started = False
         self._stopping = False

+        self._topic_manager = topic_manager
+        self._consumer = consumer
+        self._producer = producer
+
         # copy parent process log level to the child process
         self._loglevel = logging.getLogger(LOGGER_NAME).level

@@ -39,7 +59,7 @@ def __init__(self, source):
         self._rpipe, self._wpipe = multiprocessing.Pipe(duplex=False)

     @property
-    def started(self):
+    def started(self) -> bool:
         return self._started

     # --- CHILD PROCESS METHODS --- #

@@ -63,13 +83,27 @@ def run(self) -> None:
         self._started = True
         self._setup_signal_handlers()
         configure_logging(self._loglevel, str(self.source), pid=True)
-        logger.info("Source started")
+        logger.info("Starting source")
+
+        configuration = {"topic": self.topic, "producer": self._producer}
+
+        if isinstance(self.source, StatefulSource):
+            try:
+                configuration["store_partition"] = self._recover_state(self.source)
+            except BaseException as err:
+                logger.exception("Error in source")
+                self._report_exception(err)
+                return
+
+        self.source.configure(**configuration)
+
+        logger.info("Source started")
         try:
             self.source.start()
         except BaseException as err:
             logger.exception("Error in source")
             self._report_exception(err)
+            return

         logger.info("Source completed")
         threadcount = threading.active_count()
@@ -79,6 +113,45 @@ def run(self) -> None:
             "s" if threadcount > 1 else "",
         )

+    def _recover_state(self, source: StatefulSource) -> StorePartition:
+        """
+        Recover the state from the changelog topic and return the assigned partition
+
+        For stateful sources only.
+ """ + recovery_manager = RecoveryManager( + consumer=self._consumer, + topic_manager=self._topic_manager, + ) + + state_manager = StateStoreManager( + producer=self._producer, recovery_manager=recovery_manager + ) + + state_manager.register_store( + topic_name=None, + store_name=source.store_name, + store_type=MemoryStore, + topic_config=TopicConfig( + num_partitions=source.store_partitions_count, + replication_factor=self._topic_manager.default_replication_factor, + ), + ) + + self._topic_manager.create_all_topics() + self._topic_manager.validate_all_topics() + + store_partitions = state_manager.on_partition_assign( + topic=None, + partition=source.assigned_store_partition, + committed_offset=OFFSET_BEGINNING, + ) + + if state_manager.recovery_required: + state_manager.do_recovery() + + return store_partitions[source.store_name] + def _stop(self, signum, _): """ Stop the source execution. @@ -108,7 +181,7 @@ def _report_exception(self, err: BaseException) -> None: # --- PARENT PROCESS METHODS --- # - def start(self) -> "SourceProcess": + def start(self) -> None: logger.info("Starting source %s", self.source) self._started = True return super().start() @@ -170,20 +243,31 @@ class SourceManager: def __init__(self): self.processes: List[SourceProcess] = [] - def register(self, source: BaseSource): + def register( + self, + source: BaseSource, + topic, + producer, + consumer, + topic_manager, + ) -> SourceProcess: """ Register a new source in the manager. Each source need to already be configured, can't reuse a topic and must be unique """ - if not source.configured: - raise ValueError("Accepts configured Source only") - if source.producer_topic in self.topics: - raise ValueError(f"topic '{source.producer_topic.name}' already in use") + if topic in self.topics: + raise ValueError(f'Topic name "{topic.name}" is already in use') elif source in self.sources: - raise ValueError(f"source '{source}' already registered") - - process = SourceProcess(source) + raise ValueError(f'Source "{source}" is already registered') + + process = SourceProcess( + source=source, + topic=topic, + producer=producer, + consumer=consumer, + topic_manager=topic_manager, + ) self.processes.append(process) return process @@ -193,7 +277,7 @@ def sources(self) -> List[BaseSource]: @property def topics(self) -> List[Topic]: - return [process.source.producer_topic for process in self.processes] + return [process.topic for process in self.processes] def start_sources(self) -> None: for process in self.processes: diff --git a/quixstreams/sources/base/source.py b/quixstreams/sources/base/source.py index e9c952656..b9c8753c6 100644 --- a/quixstreams/sources/base/source.py +++ b/quixstreams/sources/base/source.py @@ -7,13 +7,11 @@ from quixstreams.models.topics import Topic from quixstreams.models.types import Headers from quixstreams.rowproducer import RowProducer +from quixstreams.state import PartitionTransaction, State, StorePartition logger = logging.getLogger(__name__) -__all__ = ( - "BaseSource", - "Source", -) +__all__ = ("BaseSource", "Source", "StatefulSource") class BaseSource(ABC): @@ -86,21 +84,15 @@ def main(): def __init__(self): self._producer: Optional[RowProducer] = None self._producer_topic: Optional[Topic] = None - self._configured: bool = False - def configure(self, topic: Topic, producer: RowProducer) -> None: + def configure(self, topic: Topic, producer: RowProducer, **kwargs) -> None: """ - This method is triggered when the source is registered to the Application. 
+        This method is triggered before the source is started.

-        It configures the source's Kafka producer and the topic it will produce to.
+        It configures the source's Kafka producer, the topic it will produce to,
+        and any optional dependencies.
         """
         self._producer = producer
         self._producer_topic = topic
-        self._configured = True
-
-    @property
-    def configured(self):
-        return self._configured

     @property
     def producer_topic(self):
@@ -143,9 +135,10 @@ class Source(BaseSource):

     Example:

     ```python
-    from quixstreams import Application
     import random
+    import time

+    from quixstreams import Application
     from quixstreams.sources import Source
@@ -155,6 +148,7 @@
         def run(self):
             while self.running:
                 number = random.randint(0, 100)
                 serialized = self._producer_topic.serialize(value=number)
                 self.produce(key=str(number), value=serialized.value)
+                time.sleep(0.5)


     def main():
@@ -320,3 +314,144 @@ def default_topic(self) -> Topic:

     def __repr__(self):
         return self.name
+
+
+class StatefulSource(Source):
+    """
+    A `Source` class for custom sources that need state.
+
+    Subclasses are responsible for flushing, by calling `flush`, at reasonable intervals.
+
+    Example:
+
+    ```python
+    import random
+    import time
+
+    from quixstreams import Application
+    from quixstreams.sources import StatefulSource
+
+
+    class RandomNumbersSource(StatefulSource):
+        def run(self):
+            i = 0
+            while self.running:
+                previous = self.state.get("number", 0)
+                current = random.randint(0, 100)
+                self.state.set("number", current)
+
+                serialized = self._producer_topic.serialize(value=current + previous)
+                self.produce(key=str(current), value=serialized.value)
+                time.sleep(0.5)
+
+                # flush the state every 10 messages
+                i += 1
+                if i % 10 == 0:
+                    self.flush()
+
+
+    def main():
+        app = Application(broker_address="localhost:9092")
+        source = RandomNumbersSource(name="random-source")
+
+        sdf = app.dataframe(source=source)
+        sdf.print(metadata=True)
+
+        app.run()
+
+
+    if __name__ == "__main__":
+        main()
+    ```
+    """
+
+    def __init__(self, name: str, shutdown_timeout: float = 10) -> None:
+        """
+        :param name: The source unique name. It is used to generate the topic configuration.
+        :param shutdown_timeout: Time in seconds the application waits for the source to gracefully shut down.
+        """
+        super().__init__(name, shutdown_timeout)
+        self._store_partition: Optional[StorePartition] = None
+        self._store_transaction: Optional[PartitionTransaction] = None
+        self._store_state: Optional[State] = None
+
+    def configure(
+        self,
+        topic: Topic,
+        producer: RowProducer,
+        *,
+        store_partition: Optional[StorePartition] = None,
+        **kwargs,
+    ) -> None:
+        """
+        This method is triggered before the source is started.
+
+        It configures the source's Kafka producer, the topic it will produce to,
+        and the store partition.
+        """
+        super().configure(topic=topic, producer=producer)
+        self._store_partition = store_partition
+        self._store_transaction = None
+        self._store_state = None
+
+    @property
+    def store_partitions_count(self) -> int:
+        """
+        Count of store partitions.
+
+        Used to configure the number of partitions in the changelog topic.
+        """
+        return 1
+
+    @property
+    def assigned_store_partition(self) -> int:
+        """
+        The store partition assigned to this instance
+        """
+        return 0
+
+    @property
+    def store_name(self) -> str:
+        """
+        The source store name
+        """
+        return f"source-{self.name}"
+
+    @property
+    def state(self) -> State:
+        """
+        Access the `State` of the source.
+
+        The `State` lifecycle is tied to the store transaction. A transaction is
+        only valid until the next `.flush()` call. If no valid transaction exists,
+        a new one is created.
+
+        Important: after each `.flush()` call, a previously returned instance is
+        invalidated and cannot be used; the property must be called again.
+        """
+        if self._store_partition is None:
+            raise RuntimeError("source is not configured")
+
+        if self._store_transaction is None:
+            self._store_transaction = self._store_partition.begin()
+
+        if self._store_state is None:
+            self._store_state = self._store_transaction.as_state()
+
+        return self._store_state
+
+    def flush(self, timeout: Optional[float] = None) -> None:
+        """
+        This method commits the state and flushes the producer.
+
+        It ensures the state is published to the changelog topic and all messages
+        are successfully delivered to Kafka.
+
+        :param float timeout: time to attempt flushing (seconds).
+            None uses the producer default; -1 waits indefinitely. Default: None
+
+        :raises CheckpointProducerTimeout: if any message fails to produce before
+            the timeout
+        """
+        if self._store_transaction:
+            self._store_transaction.prepare(None)
+            self._store_transaction.flush()
+            self._store_transaction = None
+            self._store_state = None
+
+        super().flush(timeout)
diff --git a/quixstreams/state/base/transaction.py b/quixstreams/state/base/transaction.py
index 11bfd71cd..7fa9dc6d2 100644
--- a/quixstreams/state/base/transaction.py
+++ b/quixstreams/state/base/transaction.py
@@ -377,7 +377,7 @@ def exists(self, key: Any, prefix: bytes, cf_name: str = "default") -> bool:
         return self._partition.exists(key_serialized, cf_name=cf_name)

     @validate_transaction_status(PartitionTransactionStatus.STARTED)
-    def prepare(self, processed_offset: int):
+    def prepare(self, processed_offset: Optional[int]):
         """
         Produce changelog messages to the changelog topic for all changes accumulated
         in this transaction and prepare transaction to flush its state to the state
@@ -399,7 +399,7 @@ def prepare(self, processed_offset: int):
             self._status = PartitionTransactionStatus.FAILED
             raise

-    def _prepare(self, processed_offset: int):
+    def _prepare(self, processed_offset: Optional[int]):
         if self._changelog_producer is None:
             return
diff --git a/quixstreams/state/manager.py b/quixstreams/state/manager.py
index 24f24cba2..e4edf4b05 100644
--- a/quixstreams/state/manager.py
+++ b/quixstreams/state/manager.py
@@ -1,8 +1,9 @@
 import logging
 import shutil
 from pathlib import Path
-from typing import Dict, List, Optional, Type, Union
+from typing import Dict, Optional, Type, Union

+from quixstreams.models.topics import TopicConfig
 from quixstreams.rowproducer import RowProducer

 from .base import Store, StorePartition
@@ -38,21 +39,24 @@ class StateStoreManager:

     def __init__(
         self,
-        group_id: str,
-        state_dir: Union[str, Path],
+        group_id: Optional[str] = None,
+        state_dir: Optional[Union[str, Path]] = None,
         rocksdb_options: Optional[RocksDBOptionsType] = None,
         producer: Optional[RowProducer] = None,
         recovery_manager: Optional[RecoveryManager] = None,
         default_store_type: StoreTypes = RocksDBStore,
     ):
-        self._state_dir = (Path(state_dir) / group_id).absolute()
+        self._state_dir = (Path(state_dir) / group_id).absolute() if state_dir else None
         self._rocksdb_options = rocksdb_options
-        self._stores: Dict[str, Dict[str, Store]] = {}
+        self._stores: Dict[Optional[str], Dict[str, Store]] = {}
         self._producer = producer
         self._recovery_manager = recovery_manager
         self._default_store_type = default_store_type

     def _init_state_dir(self):
+        if self._state_dir is None:
+            return
+
         logger.info(f'Initializing state directory at "{self._state_dir}"')
         if self._state_dir.exists():
             if not self._state_dir.is_dir():
@@ -66,7 +70,7 @@ def _init_state_dir(self):
         logger.debug(f'Created state directory at "{self._state_dir}"')

     @property
-    def stores(self) -> Dict[str, Dict[str, Store]]:
+    def stores(self) -> Dict[Optional[str], Dict[str, Store]]:
         """
         Map of registered state stores
         :return: dict in format {topic: {store_name: store}}
@@ -124,7 +128,10 @@ def get_store(
         return store

     def _setup_changelogs(
-        self, topic_name: str, store_name: str
+        self,
+        topic_name: Optional[str],
+        store_name: str,
+        topic_config: Optional[TopicConfig] = None,
     ) -> ChangelogProducerFactory:
         if self._recovery_manager:
             logger.debug(
                 f'Registering changelog topics for store "{store_name}" '
                 f'(topic "{topic_name}")'
             )
             changelog_topic = self._recovery_manager.register_changelog(
-                topic_name=topic_name,
-                store_name=store_name,
+                topic_name=topic_name, store_name=store_name, topic_config=topic_config
             )
             return ChangelogProducerFactory(
                 changelog_name=changelog_topic.name,
@@ -142,9 +148,10 @@ def register_store(
         self,
-        topic_name: str,
+        topic_name: Optional[str],
         store_name: str = DEFAULT_STATE_STORE_NAME,
         store_type: Optional[StoreTypes] = None,
+        topic_config: Optional[TopicConfig] = None,
     ):
         """
         Register a state store to be managed by StateStoreManager.
@@ -160,7 +167,9 @@
             Default to StateStoreManager `default_store_type`
         """
         if self._stores.get(topic_name, {}).get(store_name) is None:
-            changelog_producer_factory = self._setup_changelogs(topic_name, store_name)
+            changelog_producer_factory = self._setup_changelogs(
+                topic_name, store_name, topic_config=topic_config
+            )

             store_type = store_type or self.default_store_type
             if store_type == RocksDBStore:
@@ -224,8 +233,8 @@ def clear_stores(self):
         shutil.rmtree(self._state_dir)

     def on_partition_assign(
-        self, topic: str, partition: int, committed_offset: int
-    ) -> List[StorePartition]:
+        self, topic: Optional[str], partition: int, committed_offset: int
+    ) -> Dict[str, StorePartition]:
         """
         Assign store partitions for each registered store for the given
-        `TopicPartition` and return a list of assigned `StorePartition` objects.
+        `TopicPartition` and return a dict of assigned `StorePartition` objects
+        keyed by store name.
@@ -247,7 +256,7 @@
             committed_offset=committed_offset,
             store_partitions=store_partitions,
         )
-        return list(store_partitions.values())
+        return store_partitions

     def on_partition_revoke(self, topic: str, partition: int):
         """
diff --git a/quixstreams/state/recovery.py b/quixstreams/state/recovery.py
index 642efa604..6d7f37af2 100644
--- a/quixstreams/state/recovery.py
+++ b/quixstreams/state/recovery.py
@@ -5,7 +5,7 @@
 from quixstreams.kafka import Consumer
 from quixstreams.models import ConfluentKafkaMessageProto, Topic
-from quixstreams.models.topics import TopicManager
+from quixstreams.models.topics import TopicConfig, TopicManager
 from quixstreams.models.types import MessageHeadersMapping
 from quixstreams.rowproducer import RowProducer
 from quixstreams.state.base import StorePartition
@@ -261,7 +261,12 @@ def recovering(self) -> bool:
         """
         return self.has_assignments and self._running

-    def register_changelog(self, topic_name: str, store_name: str) -> Topic:
+    def register_changelog(
+        self,
+        topic_name: Optional[str],
+        store_name: str,
+        topic_config: Optional[TopicConfig] = None,
+    ) -> Topic:
         """
         Register a changelog Topic with the TopicManager.
@@ -271,6 +276,7 @@ def register_changelog(self, topic_name: str, store_name: str) -> Topic:
         return self._topic_manager.changelog_topic(
             topic_name=topic_name,
             store_name=store_name,
+            config=topic_config,
         )

     def do_recovery(self):
@@ -297,7 +303,7 @@ def do_recovery(self):

     def _generate_recovery_partitions(
         self,
-        topic_name: str,
+        topic_name: Optional[str],
         partition_num: int,
         store_partitions: Dict[str, StorePartition],
         committed_offset: int,
@@ -328,7 +334,7 @@ def _generate_recovery_partitions(

     def assign_partition(
         self,
-        topic: str,
+        topic: Optional[str],
         partition: int,
         committed_offset: int,
         store_partitions: Dict[str, StorePartition],
diff --git a/quixstreams/state/types.py b/quixstreams/state/types.py
index 3a099b95e..937a496ef 100644
--- a/quixstreams/state/types.py
+++ b/quixstreams/state/types.py
@@ -128,7 +128,7 @@ def prepared(self) -> bool:
         """
         ...

-    def prepare(self, processed_offset: int):
+    def prepare(self, processed_offset: Optional[int]):
         """
         Produce changelog messages to the changelog topic for all changes accumulated
-        in this transaction and prepare transcation to flush its state to the state
+        in this transaction and prepare the transaction to flush its state to the state
diff --git a/tests/test_quixstreams/test_app.py b/tests/test_quixstreams/test_app.py
index 886c3df9a..a0f89a76d 100644
--- a/tests/test_quixstreams/test_app.py
+++ b/tests/test_quixstreams/test_app.py
@@ -33,7 +33,7 @@
 from quixstreams.state import State
 from quixstreams.state.manager import SUPPORTED_STORES
 from quixstreams.state.rocksdb import RocksDBStore
-from tests.utils import DummySink, DummySource
+from tests.utils import DummySink, DummySource, DummyStatefulSource


 def _stop_app_on_future(app: Application, future: Future, timeout: float):
@@ -1425,7 +1425,7 @@ def test_on_assign_topic_offset_behind_warning(
         # Do some change to probe the Writebatch
         tx.set("key", "value", prefix=b"__key__")
         tx.flush(processed_offset=9999)
-        assert state_partitions[partition_num].get_processed_offset() == 9999
+        assert state_partitions["default"].get_processed_offset() == 9999

         # Define some stateful function so the App assigns store partitions
         done = Future()
@@ -1857,7 +1857,7 @@ def validate_state(stores):
             state_manager.register_store(topic.name, store_name)
             partition = state_manager.on_partition_assign(
                 topic=topic.name, partition=0, committed_offset=committed_offset
-            )[0]
+            )["default"]

             with partition.begin() as tx:
                 _validate_transaction_state(tx)
@@ -2270,6 +2270,42 @@ def test_source_with_error(
         assert isinstance(exc.value.__cause__, RuntimeError)
         assert str(exc.value.__cause__) == f"test {raise_is} error"

+    def test_stateful_source(self, app_factory, executor):
+        def _run_app(source, done):
+            app = app_factory(
+                auto_offset_reset="earliest",
+            )
+
+            sdf = app.dataframe(source=source)
+            executor.submit(self.wait_finished, app, done, 15.0)
+
+            # The app stops on source error
+            try:
+                app.run(sdf)
+            finally:
+                # shutdown the thread waiting for exit
+                done.set()
+
+        source_name = str(uuid.uuid4())
+
+        finished = multiprocessing.Event()
+        source = DummyStatefulSource(
+            name=source_name,
+            values=range(self.MESSAGES_COUNT),
+            finished=finished,
+            state_key="test",
+        )
+        _run_app(source, finished)
+
+        finished = multiprocessing.Event()
+        source = DummyStatefulSource(
+            name=source_name,
+            values=range(self.MESSAGES_COUNT),
+            finished=finished,
+            state_key="test",
+            assert_state_value=self.MESSAGES_COUNT - 1,
+        )
+        _run_app(source, finished)
+

 class TestApplicationMultipleSdf:
     def test_multiple_sdfs(
diff --git a/tests/test_quixstreams/test_sources/test_base/test_manager.py b/tests/test_quixstreams/test_sources/test_base/test_manager.py
index 0eaf8201a..e6c4657cd 100644
--- a/tests/test_quixstreams/test_sources/test_base/test_manager.py
+++ b/tests/test_quixstreams/test_sources/test_base/test_manager.py
@@ -19,30 +19,23 @@ def test_register(self):
         topic1 = Topic("topic1", None)
         topic2 = Topic("topic2", None)

-        with pytest.raises(ValueError):
-            manager.register(source1)
-
-        source1.configure(topic1, None)
-        manager.register(source1)
+        manager.register(source1, topic1, None, None, None)

         # registering the same source twice fails
         with pytest.raises(ValueError):
-            manager.register(source1)
+            manager.register(source1, topic2, None, None, None)

-        source2.configure(topic1, None)
         # registering a source with the same topic fails
         with pytest.raises(ValueError):
-            manager.register(source2)
+            manager.register(source2, topic1, None, None, None)

-        source2.configure(topic2, None)
-        manager.register(source2)
+        manager.register(source2, topic2, None, None, None)

     def test_is_alives(self):
         manager = SourceManager()
         source = DummySource()

-        source.configure(Topic("topic", None), None)
-        manager.register(source)
+        manager.register(source, Topic("topic", None), None, None, None)

         assert not manager.is_alive()

@@ -56,8 +49,7 @@ def test_is_alives_kill_source(self):
         manager = SourceManager()
         source = DummySource()

-        source.configure(Topic("topic", None), None)
-        process = manager.register(source)
+        process = manager.register(source, Topic("topic", None), None, None, None)

         assert not manager.is_alive()

@@ -80,8 +72,7 @@ def test_terminate_source(self):
         manager = SourceManager()
         source = DummySource()

-        source.configure(Topic("topic", None), None)
-        process = manager.register(source)
+        process = manager.register(source, Topic("topic", None), None, None, None)

         assert not manager.is_alive()

@@ -107,8 +98,7 @@ def test_raise_for_error(self, when, exitcode, pickleable):
             error_in=when, pickeable_error=pickleable, finished=finished
         )

-        source.configure(Topic("topic", None), None)
-        process = manager.register(source)
+        process = manager.register(source, Topic("topic", None), None, None, None)

         # never raise when not started
         manager.raise_for_error()
diff --git a/tests/test_quixstreams/test_state/test_recovery/test_recovery_manager.py b/tests/test_quixstreams/test_state/test_recovery/test_recovery_manager.py
index 2f223c72d..ab4afcccb 100644
--- a/tests/test_quixstreams/test_state/test_recovery/test_recovery_manager.py
+++ b/tests/test_quixstreams/test_state/test_recovery/test_recovery_manager.py
@@ -20,7 +20,9 @@ def test_register_changelog(self, recovery_manager_factory):
         with patch.object(TopicManager, "changelog_topic") as make_changelog:
             recovery_manager.register_changelog(topic_name=topic, store_name=store_name)

-        make_changelog.assert_called_with(topic_name=topic, store_name=store_name)
+        make_changelog.assert_called_with(
+            topic_name=topic, store_name=store_name, config=None
+        )

     def test_assign_partition_invalid_offset(
         self,
diff --git a/tests/utils.py b/tests/utils.py
index d62ef6462..db1176183 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -2,14 +2,14 @@
 import threading
 import time
 import uuid
-from typing import Any, List, Optional, Tuple, Union
+from typing import Any, Iterable, List, Optional, Tuple, Union

 from confluent_kafka import OFFSET_INVALID

 from quixstreams.sinks import BatchingSink
 from quixstreams.sinks.base import SinkBatch
 from quixstreams.sinks.base.item import SinkItem
-from quixstreams.sources import Source
+from quixstreams.sources import Source, StatefulSource

 DEFAULT_TIMEOUT = 10.0
@@ -127,10 +127,10 @@ def __init__(
         name: Optional[str] = None,
         values: Optional[List[Any]] = None,
         finished: threading.Event = None,
-        error_in: Union[str, List[str]] = None,
+        error_in: Optional[Union[str, List[str]]] = None,
         pickeable_error: bool = True,
     ) -> None:
-        super().__init__(name or str(uuid.uuid4()), 1)
+        super().__init__(name or str(uuid.uuid4()), 10)

         self.key = "dummy"
         self.values = values or []
@@ -139,9 +139,7 @@ def __init__(
         self.pickeable_error = pickeable_error

     def run(self):
-        for value in self.values:
-            msg = self.serialize(key=self.key, value=value)
-            self.produce(value=msg.value, key=msg.key)
+        self._produce()

         if "run" in self.error_in:
             self.error("test run error")
@@ -152,6 +150,11 @@ def run(self):
         while self.running:
             time.sleep(0.1)

+    def _produce(self):
+        for value in self.values:
+            msg = self.serialize(key=self.key, value=value)
+            self.produce(value=msg.value, key=msg.key)
+
     def cleanup(self, failed):
         if "cleanup" in self.error_in:
             self.error("test cleanup error")
@@ -169,6 +172,34 @@ def error(self, msg):
             raise UnpickleableError(msg)


+class DummyStatefulSource(DummySource, StatefulSource):
+    def __init__(
+        self,
+        name: Optional[str] = None,
+        values: Optional[Iterable[Any]] = None,
+        finished: threading.Event = None,
+        error_in: Optional[Union[str, List[str]]] = None,
+        pickeable_error: bool = True,
+        state_key: str = "",
+        assert_state_value: Any = None,
+    ) -> None:
+        super().__init__(name, values, finished, error_in, pickeable_error)
+        self._state_key = state_key
+        self._assert_state_value = assert_state_value
+
+    def run(self):
+        if self._assert_state_value is not None:
+            assert self._assert_state_value == self.state.get(self._state_key)
+
+        super().run()
+
+    def _produce(self):
+        for value in self.values:
+            msg = self.serialize(key=self.key, value=value)
+            self.produce(value=msg.value, key=msg.key)
+            self.state.set(self._state_key, value)
+
+        self.flush()
+
+
 class UnpickleableError(Exception):
     def __init__(self, *args: object) -> None:
         # threading.Lock can't be pickled
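
Beyond the docstring example embedded in `StatefulSource` above, here is a minimal end-to-end sketch of how the pieces introduced by this patch fit together: the source checkpoints its position via `self.state` and `self.flush()`, and after a restart it resumes from the recovered value, which is the same pattern `DummyStatefulSource` exercises in the tests. The broker address and all names (`CounterSource`, `"counter-source"`, the `"current"` key) are illustrative and not part of the patch.

```python
import time

from quixstreams import Application
from quixstreams.sources import StatefulSource


class CounterSource(StatefulSource):
    """Produce an ever-increasing counter, resuming from the last
    checkpointed value after a restart."""

    def run(self):
        # The value survives restarts: the store partition is recovered
        # from the changelog topic before run() is called.
        current = self.state.get("current", 0)
        while self.running:
            serialized = self.serialize(key="counter", value=current)
            self.produce(key=serialized.key, value=serialized.value)
            self.state.set("current", current + 1)
            current += 1
            # Commit the state transaction and flush the producer.
            self.flush()
            time.sleep(0.5)


def main():
    # Broker address is an assumption for this sketch
    app = Application(broker_address="localhost:9092")
    source = CounterSource(name="counter-source")

    sdf = app.dataframe(source=source)
    sdf.print(metadata=True)

    app.run()


if __name__ == "__main__":
    main()
```

Note that `flush()` both commits the state transaction (producing the changes to the changelog topic) and flushes the producer, so any `State` instance obtained before the call is invalidated and must be re-fetched through the `state` property afterwards.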