From 955540695c1f2365c48677e073d4fb6cdcf8254b Mon Sep 17 00:00:00 2001 From: Quentin Dawans Date: Tue, 19 Nov 2024 11:07:26 +0100 Subject: [PATCH] CI: Implement mypy pre-commit check --- .github/workflows/ci.yml | 11 ++- .pre-commit-config.yaml | 7 ++ pyproject.toml | 36 ++++++++ quixstreams/app.py | 34 ++++--- quixstreams/checkpointing/checkpoint.py | 6 +- quixstreams/context.py | 7 +- quixstreams/kafka/configuration.py | 9 +- quixstreams/kafka/consumer.py | 21 ++++- quixstreams/kafka/producer.py | 4 +- quixstreams/platforms/quix/config.py | 2 +- quixstreams/platforms/quix/topic_manager.py | 4 +- quixstreams/processing/context.py | 10 +-- quixstreams/processing/pausing.py | 4 +- quixstreams/rowconsumer.py | 26 +++--- quixstreams/rowproducer.py | 8 +- quixstreams/sinks/core/influxdb3.py | 2 +- quixstreams/sources/base/exceptions.py | 4 +- quixstreams/sources/base/manager.py | 23 +++-- quixstreams/sources/base/source.py | 23 +++-- quixstreams/sources/core/csv.py | 4 +- quixstreams/sources/core/kafka/checkpoint.py | 11 +-- quixstreams/sources/core/kafka/kafka.py | 88 ++++++++++++------- quixstreams/sources/core/kafka/quix.py | 16 +--- quixstreams/state/recovery.py | 4 +- requirements-mypy.txt | 5 ++ .../test_sources/test_core/test_kafka.py | 8 +- 26 files changed, 246 insertions(+), 131 deletions(-) create mode 100644 requirements-mypy.txt diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 00ccb5047..de7633ffc 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -24,7 +24,13 @@ jobs: - name: Setup Python uses: actions/setup-python@v4 with: - python-version: 3.9 + python-version: 3.12 + - name: Update pip + run: | + python -m pip install -U pip + - name: Install requirements + run: | + python -m pip install -U -r requirements.txt -r tests/requirements.txt -r requirements-mypy.txt - uses: pre-commit/action@v3.0.1 test: @@ -49,8 +55,7 @@ jobs: python -m pip install -U pip - name: Install requirements run: | - python -m pip install -U -r tests/requirements.txt - python -m pip install -U -r requirements.txt + python -m pip install -U -r requirements.txt -r tests/requirements.txt - name: Run tests run: | python -m pytest -v --log-cli-level=ERROR diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 54fe36a9a..a38814fb7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,3 +11,10 @@ repos: entry: python conda/requirements.py language: python files: ^(requirements\.txt|pyproject\.toml)$ + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.13.0 + hooks: + - id: mypy + args: [] + language: system + files: ^quixstreams/ diff --git a/pyproject.toml b/pyproject.toml index 1cb40618c..5413e7afe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,3 +95,39 @@ log_cli_level = "INFO" log_cli_format = "[%(levelname)s] %(name)s: %(message)s" # Custom markers markers = ["timeit"] + +[[tool.mypy.overrides]] +module = "confluent_kafka.*" +ignore_missing_imports = true + +[[tool.mypy.overrides]] +module = "quixstreams.core.*" +ignore_errors = true + +[[tool.mypy.overrides]] +module = "quixstreams.dataframe.*" +ignore_errors = true + +[[tool.mypy.overrides]] +module = "quixstreams.models.*" +ignore_errors = true + +[[tool.mypy.overrides]] +module = "quixstreams.platforms.*" +ignore_errors = true + +[[tool.mypy.overrides]] +module = "quixstreams.sinks.community.*" +ignore_errors = true + +[[tool.mypy.overrides]] +module = "quixstreams.sources.community.*" +ignore_errors = true + +[[tool.mypy.overrides]] +module = 
"quixstreams.state.*" +ignore_errors = true + +[[tool.mypy.overrides]] +module = "quixstreams.rowproducer.*" +ignore_errors = true diff --git a/quixstreams/app.py b/quixstreams/app.py index 0e639c4ef..8c75b89e3 100644 --- a/quixstreams/app.py +++ b/quixstreams/app.py @@ -6,10 +6,11 @@ import time import warnings from pathlib import Path -from typing import Callable, List, Literal, Optional, Tuple, Type, Union +from typing import Callable, List, Literal, Optional, Protocol, Tuple, Type, Union from confluent_kafka import TopicPartition from pydantic import AliasGenerator, Field +from pydantic_settings import BaseSettings as PydanticBaseSettings from pydantic_settings import PydanticBaseSettingsSource, SettingsConfigDict from typing_extensions import Self @@ -60,6 +61,17 @@ _default_max_poll_interval_ms = 300000 +class TopicManagerFactory(Protocol): + def __call__( + self, + topic_admin: TopicAdmin, + consumer_group: str, + timeout: float = 30, + create_timeout: float = 60, + auto_create_topics: bool = True, + ) -> TopicManager: ... + + class Application: """ The main Application class. @@ -205,19 +217,21 @@ def __init__( producer_extra_config = producer_extra_config or {} consumer_extra_config = consumer_extra_config or {} + state_dir = Path(state_dir) + # We can't use os.getenv as defaults (and have testing work nicely) # since it evaluates getenv when the function is defined. # In general this is just a most robust approach. broker_address = broker_address or os.getenv("Quix__Broker__Address") quix_sdk_token = quix_sdk_token or os.getenv("Quix__Sdk__Token") - consumer_group = consumer_group or os.getenv( - "Quix__Consumer_Group", "quixstreams-default" - ) + + if not consumer_group: + consumer_group = os.getenv("Quix__Consumer_Group", "quixstreams-default") if broker_address: # If broker_address is passed to the app it takes priority over any quix configuration self._is_quix_app = False - self._topic_manager_factory = TopicManager + self._topic_manager_factory: TopicManagerFactory = TopicManager if isinstance(broker_address, str): broker_address = ConnectionConfig(bootstrap_servers=broker_address) else: @@ -249,7 +263,6 @@ def __init__( QuixTopicManager, quix_config_builder=quix_config_builder ) # Check if the state dir points to the mounted PVC while running on Quix - state_dir = Path(state_dir) check_state_dir(state_dir=state_dir) quix_app_config = quix_config_builder.get_application_config(consumer_group) @@ -487,12 +500,13 @@ def dataframe( :param source: a `quixstreams.sources` "BaseSource" instance :return: `StreamingDataFrame` object """ - if not source and not topic: - raise ValueError("one of `source` or `topic` is required") - if source: + if source is not None: topic = self.add_source(source, topic) + if topic is None: + raise ValueError("one of `source` or `topic` is required") + sdf = StreamingDataFrame( topic=topic, topic_manager=self._topic_manager, @@ -1012,7 +1026,7 @@ class ApplicationConfig(BaseSettings): @classmethod def settings_customise_sources( cls, - settings_cls: Type[BaseSettings], + settings_cls: Type[PydanticBaseSettings], init_settings: PydanticBaseSettingsSource, env_settings: PydanticBaseSettingsSource, dotenv_settings: PydanticBaseSettingsSource, diff --git a/quixstreams/checkpointing/checkpoint.py b/quixstreams/checkpointing/checkpoint.py index c44915d15..05719788f 100644 --- a/quixstreams/checkpointing/checkpoint.py +++ b/quixstreams/checkpointing/checkpoint.py @@ -5,7 +5,7 @@ from confluent_kafka import KafkaException, TopicPartition -from 
quixstreams.kafka import Consumer +from quixstreams.kafka import BaseConsumer from quixstreams.processing.pausing import PausingManager from quixstreams.rowproducer import RowProducer from quixstreams.sinks import SinkManager @@ -48,7 +48,7 @@ def __init__( # processed offsets within the checkpoint self._starting_tp_offsets: Dict[Tuple[str, int], int] = {} # A mapping of <(topic, partition, store_name): PartitionTransaction> - self._store_transactions: Dict[(str, int, str), PartitionTransaction] = {} + self._store_transactions: Dict[Tuple[str, int, str], PartitionTransaction] = {} # Passing zero or lower will flush the checkpoint after each processed message self._commit_interval = max(commit_interval, 0) @@ -123,7 +123,7 @@ def __init__( self, commit_interval: float, producer: RowProducer, - consumer: Consumer, + consumer: BaseConsumer, state_manager: StateStoreManager, sink_manager: SinkManager, pausing_manager: PausingManager, diff --git a/quixstreams/context.py b/quixstreams/context.py index 2eaaaa415..2de9d3be9 100644 --- a/quixstreams/context.py +++ b/quixstreams/context.py @@ -11,7 +11,9 @@ "copy_context", ) -_current_message_context = ContextVar("current_message_context") +_current_message_context: ContextVar[Optional[MessageContext]] = ContextVar( + "current_message_context" +) class MessageContextNotSetError(QuixException): ... @@ -48,7 +50,7 @@ def alter_context(value): _current_message_context.set(context) -def message_context() -> MessageContext: +def message_context() -> Optional[MessageContext]: """ Get a MessageContext for the current message, which houses most of the message metadata, like: @@ -74,6 +76,5 @@ def message_context() -> MessageContext: """ try: return _current_message_context.get() - except LookupError: raise MessageContextNotSetError("Message context is not set") diff --git a/quixstreams/kafka/configuration.py b/quixstreams/kafka/configuration.py index a3ad4b009..0637cb1de 100644 --- a/quixstreams/kafka/configuration.py +++ b/quixstreams/kafka/configuration.py @@ -3,7 +3,12 @@ import pydantic from pydantic import AliasChoices, Field, SecretStr from pydantic.functional_validators import BeforeValidator -from pydantic_settings import PydanticBaseSettingsSource +from pydantic_settings import ( + BaseSettings as PydanticBaseSettings, +) +from pydantic_settings import ( + PydanticBaseSettingsSource, +) from typing_extensions import Annotated, Self from quixstreams.utils.settings import BaseSettings @@ -93,7 +98,7 @@ class ConnectionConfig(BaseSettings): @classmethod def settings_customise_sources( cls, - settings_cls: Type[BaseSettings], + settings_cls: Type[PydanticBaseSettings], init_settings: PydanticBaseSettingsSource, env_settings: PydanticBaseSettingsSource, dotenv_settings: PydanticBaseSettingsSource, diff --git a/quixstreams/kafka/consumer.py b/quixstreams/kafka/consumer.py index ec98d640e..1f5f6842b 100644 --- a/quixstreams/kafka/consumer.py +++ b/quixstreams/kafka/consumer.py @@ -1,7 +1,7 @@ import functools import logging import typing -from typing import Callable, List, Optional, Tuple, Union +from typing import Any, Callable, List, Optional, Tuple, Union from confluent_kafka import ( Consumer as ConfluentConsumer, @@ -18,6 +18,7 @@ from .configuration import ConnectionConfig __all__ = ( + "BaseConsumer", "Consumer", "AutoOffsetReset", "RebalancingCallback", @@ -64,7 +65,7 @@ def wrapper(*args, **kwargs): return wrapper -class Consumer: +class BaseConsumer: def __init__( self, broker_address: Union[str, ConnectionConfig], @@ -147,7 +148,7 @@ def 
poll(self, timeout: Optional[float] = None) -> Optional[Message]: """ return self._consumer.poll(timeout=timeout if timeout is not None else -1) - def subscribe( + def _subscribe( self, topics: List[str], on_assign: Optional[RebalancingCallback] = None, @@ -302,7 +303,8 @@ def commit( raise ValueError( 'Parameters "message" and "offsets" are mutually exclusive' ) - kwargs = { + + kwargs: dict[str, Any] = { "asynchronous": asynchronous, } if offsets is not None: @@ -559,3 +561,14 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): self.close() + + +class Consumer(BaseConsumer): + def subscribe( + self, + topics: List[str], + on_assign: Optional[RebalancingCallback] = None, + on_revoke: Optional[RebalancingCallback] = None, + on_lost: Optional[RebalancingCallback] = None, + ): + return super()._subscribe(topics, on_assign, on_revoke, on_lost) diff --git a/quixstreams/kafka/producer.py b/quixstreams/kafka/producer.py index 39a83c83b..3012b3cdc 100644 --- a/quixstreams/kafka/producer.py +++ b/quixstreams/kafka/producer.py @@ -46,7 +46,7 @@ def __init__( logger: logging.Logger = logger, error_callback: Callable[[KafkaError], None] = _default_error_cb, extra_config: Optional[dict] = None, - flush_timeout: Optional[int] = None, + flush_timeout: Optional[float] = None, ): """ A wrapper around `confluent_kafka.Producer`. @@ -190,7 +190,7 @@ def __init__( logger: logging.Logger = logger, error_callback: Callable[[KafkaError], None] = _default_error_cb, extra_config: Optional[dict] = None, - flush_timeout: Optional[int] = None, + flush_timeout: Optional[float] = None, ): super().__init__( broker_address=broker_address, diff --git a/quixstreams/platforms/quix/config.py b/quixstreams/platforms/quix/config.py index 987211c68..9c42ebdf9 100644 --- a/quixstreams/platforms/quix/config.py +++ b/quixstreams/platforms/quix/config.py @@ -77,7 +77,7 @@ class QuixApplicationConfig: librdkafka_connection_config: ConnectionConfig librdkafka_extra_config: dict - consumer_group: Optional[str] = None + consumer_group: str class QuixKafkaConfigsBuilder: diff --git a/quixstreams/platforms/quix/topic_manager.py b/quixstreams/platforms/quix/topic_manager.py index f201e83bc..753a58685 100644 --- a/quixstreams/platforms/quix/topic_manager.py +++ b/quixstreams/platforms/quix/topic_manager.py @@ -22,8 +22,8 @@ class QuixTopicManager(TopicManager): # Default topic params # Set these to None to use defaults defined in Quix Cloud - default_num_partitions = None - default_replication_factor = None + default_num_partitions: None = None + default_replication_factor: None = None # Max topic name length for the new topics _max_topic_name_len = 249 diff --git a/quixstreams/processing/context.py b/quixstreams/processing/context.py index 69da8ffd5..fc57f6f1b 100644 --- a/quixstreams/processing/context.py +++ b/quixstreams/processing/context.py @@ -52,7 +52,7 @@ def store_offset(self, topic: str, partition: int, offset: int): :param partition: partition number :param offset: message offset """ - self._checkpoint.store_offset(topic=topic, partition=partition, offset=offset) + self.checkpoint.store_offset(topic=topic, partition=partition, offset=offset) def init_checkpoint(self): """ @@ -79,13 +79,13 @@ def commit_checkpoint(self, force: bool = False): :param force: if `True`, commit the Checkpoint before its expiration deadline. 
""" - if self._checkpoint.expired() or force: - if self._checkpoint.empty(): - self._checkpoint.close() + if self.checkpoint.expired() or force: + if self.checkpoint.empty(): + self.checkpoint.close() else: logger.debug(f"Committing a checkpoint; forced={force}") start = time.monotonic() - self._checkpoint.commit() + self.checkpoint.commit() elapsed = round(time.monotonic() - start, 2) logger.debug( f"Committed a checkpoint; forced={force}, time_elapsed={elapsed}s" diff --git a/quixstreams/processing/pausing.py b/quixstreams/processing/pausing.py index 5ea34513f..9ee810cf1 100644 --- a/quixstreams/processing/pausing.py +++ b/quixstreams/processing/pausing.py @@ -5,7 +5,7 @@ from confluent_kafka import TopicPartition -from quixstreams.kafka import Consumer +from quixstreams.kafka import BaseConsumer logger = logging.getLogger(__name__) @@ -20,7 +20,7 @@ class PausingManager: _paused_tps: Dict[Tuple[str, int], float] - def __init__(self, consumer: Consumer): + def __init__(self, consumer: BaseConsumer): self._consumer = consumer self._paused_tps = {} self._next_resume_at = _MAX_FLOAT diff --git a/quixstreams/rowconsumer.py b/quixstreams/rowconsumer.py index ebf41238f..a9bea2016 100644 --- a/quixstreams/rowconsumer.py +++ b/quixstreams/rowconsumer.py @@ -5,7 +5,7 @@ from .error_callbacks import ConsumerErrorCallback, default_on_consumer_error from .exceptions import PartitionAssignmentError -from .kafka import AutoOffsetReset, ConnectionConfig, Consumer +from .kafka import AutoOffsetReset, BaseConsumer, ConnectionConfig from .kafka.consumer import RebalancingCallback from .kafka.exceptions import KafkaConsumerException from .models import Row, Topic @@ -16,14 +16,16 @@ __all__ = ("RowConsumer",) -class RowConsumer(Consumer): +class RowConsumer(BaseConsumer): def __init__( self, broker_address: Union[str, ConnectionConfig], consumer_group: str, auto_offset_reset: AutoOffsetReset, auto_commit_enable: bool = True, - on_commit: Callable[[Optional[KafkaError], List[TopicPartition]], None] = None, + on_commit: Optional[ + Callable[[Optional[KafkaError], List[TopicPartition]], None] + ] = None, extra_config: Optional[dict] = None, on_error: Optional[ConsumerErrorCallback] = None, ): @@ -64,9 +66,7 @@ def __init__( on_commit=on_commit, extra_config=extra_config, ) - self._on_error: Optional[ConsumerErrorCallback] = ( - on_error or default_on_consumer_error - ) + self._on_error: ConsumerErrorCallback = on_error or default_on_consumer_error self._topics: Mapping[str, Topic] = {} def subscribe( @@ -95,15 +95,15 @@ def subscribe( """ topics_map = {t.name: t for t in topics} topics_names = list(topics_map.keys()) - super().subscribe( + super()._subscribe( topics=topics_names, on_assign=on_assign, on_revoke=on_revoke, on_lost=on_lost, ) - self._topics = {t.name: t for t in topics} + self._topics = topics_map - def poll_row(self, timeout: float = None) -> Union[Row, List[Row], None]: + def poll_row(self, timeout: Optional[float] = None) -> Union[Row, List[Row], None]: """ Consumes a single message and deserialize it to Row or a list of Rows. 
@@ -122,11 +122,11 @@ def poll_row(self, timeout: float = None) -> Union[Row, List[Row], None]: except Exception as exc: to_suppress = self._on_error(exc, None, logger) if to_suppress: - return + return None raise if msg is None: - return + return None topic_name = msg.topic() try: @@ -139,9 +139,9 @@ def poll_row(self, timeout: float = None) -> Union[Row, List[Row], None]: return row_or_rows except IgnoreMessage: # Deserializer decided to ignore the message - return + return None except Exception as exc: to_suppress = self._on_error(exc, msg, logger) if to_suppress: - return + return None raise diff --git a/quixstreams/rowproducer.py b/quixstreams/rowproducer.py index ec6008e60..511d1e7d0 100644 --- a/quixstreams/rowproducer.py +++ b/quixstreams/rowproducer.py @@ -93,13 +93,13 @@ class RowProducer: def __init__( self, broker_address: Union[str, ConnectionConfig], - extra_config: dict = None, + extra_config: Optional[dict] = None, on_error: Optional[ProducerErrorCallback] = None, flush_timeout: Optional[float] = None, transactional: bool = False, ): if transactional: - self._producer = TransactionalProducer( + self._producer: Producer = TransactionalProducer( broker_address=broker_address, extra_config=extra_config, flush_timeout=flush_timeout, @@ -111,9 +111,7 @@ def __init__( flush_timeout=flush_timeout, ) - self._on_error: Optional[ProducerErrorCallback] = ( - on_error or default_on_producer_error - ) + self._on_error: ProducerErrorCallback = on_error or default_on_producer_error self._tp_offsets: Dict[Tuple[str, int], int] = {} self._error: Optional[KafkaError] = None self._active_transaction = False diff --git a/quixstreams/sinks/core/influxdb3.py b/quixstreams/sinks/core/influxdb3.py index 677edb641..28f50e7b7 100644 --- a/quixstreams/sinks/core/influxdb3.py +++ b/quixstreams/sinks/core/influxdb3.py @@ -31,7 +31,7 @@ def __init__( fields_keys: Iterable[str] = (), tags_keys: Iterable[str] = (), time_key: Optional[str] = None, - time_precision: WritePrecision = WritePrecision.MS, + time_precision: WritePrecision = WritePrecision.MS, # type: ignore include_metadata_tags: bool = False, batch_size: int = 1000, enable_gzip: bool = True, diff --git a/quixstreams/sources/base/exceptions.py b/quixstreams/sources/base/exceptions.py index 14ead25c9..2c3a865af 100644 --- a/quixstreams/sources/base/exceptions.py +++ b/quixstreams/sources/base/exceptions.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional if TYPE_CHECKING: from .manager import SourceProcess @@ -12,7 +12,7 @@ class SourceException(Exception): """ def __init__(self, process: "SourceProcess") -> None: - self.pid: int = process.pid + self.pid: Optional[int] = process.pid self.process: SourceProcess = process self.exitcode = self.process.exitcode diff --git a/quixstreams/sources/base/manager.py b/quixstreams/sources/base/manager.py index b5a653935..d69ef66dd 100644 --- a/quixstreams/sources/base/manager.py +++ b/quixstreams/sources/base/manager.py @@ -1,8 +1,9 @@ import logging import signal import threading +from multiprocessing.context import SpawnProcess from pickle import PicklingError -from typing import List +from typing import TYPE_CHECKING, List from confluent_kafka import OFFSET_BEGINNING @@ -20,8 +21,13 @@ logger = logging.getLogger(__name__) +if TYPE_CHECKING: + process = SpawnProcess +else: + process = multiprocessing.Process -class SourceProcess(multiprocessing.Process): + +class SourceProcess(process): """ An implementation of the Source subprocess. 
@@ -38,7 +44,7 @@ def __init__(
         producer: RowProducer,
         consumer: RowConsumer,
         topic_manager: TopicManager,
-    ):
+    ) -> None:
         super().__init__()
         self.topic = topic
         self.source = source
@@ -85,8 +91,7 @@ def run(self) -> None:
         configure_logging(self._loglevel, str(self.source), pid=True)
         logger.info("Starting source")

-        configuration = {"topic": self.topic, "producer": self._producer}
-
+        configuration = {}
         if isinstance(self.source, StatefulSource):
             try:
                 configuration["store_partition"] = self._recover_state(self.source)
@@ -95,7 +100,9 @@
                 self._report_exception(err)
                 return

-        self.source.configure(**configuration)
+        self.source.configure(
+            topic=self.topic, producer=self._producer, **configuration
+        )
         logger.info("Source started")

         try:
@@ -184,7 +191,7 @@ def _report_exception(self, err: BaseException) -> None:
     def start(self) -> None:
         logger.info("Starting source %s", self.source)
         self._started = True
-        return super().start()
+        super().start()

     def raise_for_error(self) -> None:
         """
@@ -240,7 +247,7 @@ class SourceManager:
     Sources run in their own separate processes; pay attention to cross-process communication
     """

-    def __init__(self):
+    def __init__(self) -> None:
         self.processes: List[SourceProcess] = []

     def register(
diff --git a/quixstreams/sources/base/source.py b/quixstreams/sources/base/source.py
index b9c8753c6..8a8e0eb90 100644
--- a/quixstreams/sources/base/source.py
+++ b/quixstreams/sources/base/source.py
@@ -81,7 +81,7 @@ def main():
     # time in seconds the application will wait for the source to stop.
     shutdown_timeout: float = 10

-    def __init__(self):
+    def __init__(self) -> None:
         self._producer: Optional[RowProducer] = None
         self._producer_topic: Optional[Topic] = None

@@ -95,7 +95,15 @@ def configure(self, topic: Topic, producer: RowProducer, **kwargs) -> None:
         self._producer_topic = topic

     @property
-    def producer_topic(self):
+    def producer(self) -> RowProducer:
+        if self._producer is None:
+            raise RuntimeError("source not configured")
+        return self._producer
+
+    @property
+    def producer_topic(self) -> Topic:
+        if self._producer_topic is None:
+            raise RuntimeError("source not configured")
         return self._producer_topic

     @abstractmethod
@@ -214,9 +222,8 @@ def stop(self) -> None:
         It sets the `running` property to `False`.
         """
         self._running = False
-        super().stop()

-    def start(self):
+    def start(self) -> None:
         """
         This method is triggered in the subprocess when the source is started.

@@ -252,7 +259,7 @@ def serialize(

         :return: `quixstreams.models.messages.KafkaMessage`
         """
-        return self._producer_topic.serialize(
+        return self.producer_topic.serialize(
             key=key, value=value, headers=headers, timestamp_ms=timestamp_ms
         )

@@ -270,8 +277,8 @@ def produce(
         Produce a message to the configured source topic in Kafka.
""" - self._producer.produce( - topic=self._producer_topic.name, + self.producer.produce( + topic=self.producer_topic.name, value=value, key=key, headers=headers, @@ -293,7 +300,7 @@ def flush(self, timeout: Optional[float] = None) -> None: :raises CheckpointProducerTimeout: if any message fails to produce before the timeout """ logger.debug("Flushing source") - unproduced_msg_count = self._producer.flush(timeout) + unproduced_msg_count = self.producer.flush(timeout) if unproduced_msg_count > 0: raise CheckpointProducerTimeout( f"'{unproduced_msg_count}' messages failed to be produced before the producer flush timeout" diff --git a/quixstreams/sources/core/csv.py b/quixstreams/sources/core/csv.py index e4fc5bd9d..8def52a60 100644 --- a/quixstreams/sources/core/csv.py +++ b/quixstreams/sources/core/csv.py @@ -2,7 +2,7 @@ import logging import time from pathlib import Path -from typing import AnyStr, Callable, Optional, Union +from typing import Callable, Optional, Union from quixstreams.models.topics import Topic from quixstreams.sources.base import Source @@ -15,7 +15,7 @@ def __init__( self, path: Union[str, Path], name: str, - key_extractor: Optional[Callable[[dict], AnyStr]] = None, + key_extractor: Optional[Callable[[dict], Union[str, bytes]]] = None, timestamp_extractor: Optional[Callable[[dict], int]] = None, delay: float = 0, shutdown_timeout: float = 10, diff --git a/quixstreams/sources/core/kafka/checkpoint.py b/quixstreams/sources/core/kafka/checkpoint.py index 659d0fcb0..81d5bc0e6 100644 --- a/quixstreams/sources/core/kafka/checkpoint.py +++ b/quixstreams/sources/core/kafka/checkpoint.py @@ -7,8 +7,8 @@ CheckpointConsumerCommitError, CheckpointProducerTimeout, ) +from quixstreams.kafka.consumer import BaseConsumer from quixstreams.models.topics import Topic -from quixstreams.rowconsumer import Consumer from quixstreams.rowproducer import RowProducer @@ -21,7 +21,7 @@ def __init__( self, producer: RowProducer, producer_topic: Topic, - consumer: Consumer, + consumer: BaseConsumer, commit_interval: float, commit_every: int = 0, flush_timeout: float = 10, @@ -83,6 +83,7 @@ def _commit(self, offsets: List[TopicPartition]): ) else: partitions = self._consumer.commit(offsets=offsets, asynchronous=False) - for partition in partitions: - if partition.error: - raise CheckpointConsumerCommitError(partition.error) + if partitions: + for partition in partitions: + if partition.error: + raise CheckpointConsumerCommitError(partition.error) diff --git a/quixstreams/sources/core/kafka/kafka.py b/quixstreams/sources/core/kafka/kafka.py index 43825ca81..398337b12 100644 --- a/quixstreams/sources/core/kafka/kafka.py +++ b/quixstreams/sources/core/kafka/kafka.py @@ -57,11 +57,11 @@ def __init__( app_config: "ApplicationConfig", topic: str, broker_address: Union[str, ConnectionConfig], - auto_offset_reset: AutoOffsetReset = "latest", + auto_offset_reset: Optional[AutoOffsetReset] = "latest", consumer_extra_config: Optional[dict] = None, consumer_poll_timeout: Optional[float] = None, shutdown_timeout: float = 10, - on_consumer_error: Optional[ConsumerErrorCallback] = default_on_consumer_error, + on_consumer_error: ConsumerErrorCallback = default_on_consumer_error, value_deserializer: DeserializerType = "json", key_deserializer: DeserializerType = "bytes", ) -> None: @@ -131,6 +131,36 @@ def target_consumer_group(self): consumer_group = f"{self._config.consumer_group_prefix}-{consumer_group}" return consumer_group + @property + def checkpoint(self) -> Checkpoint: + if self._checkpoint is None: + raise 
RuntimeError("source not started") + return self._checkpoint + + @property + def source_cluster_consumer(self) -> Consumer: + if self._source_cluster_consumer is None: + raise RuntimeError("source not started") + return self._source_cluster_consumer + + @property + def target_cluster_consumer(self) -> Consumer: + if self._target_cluster_consumer is None: + raise RuntimeError("source not started") + return self._target_cluster_consumer + + @property + def source_cluster_admin(self) -> TopicAdmin: + if self._source_cluster_admin is None: + raise RuntimeError("source not started") + return self._source_cluster_admin + + @property + def target_cluster_admin(self) -> TopicAdmin: + if self._target_cluster_admin is None: + raise RuntimeError("source not started") + return self._target_cluster_admin + def run(self) -> None: logger.info( f'Starting the source "{self.name}" with the config: ' @@ -178,11 +208,9 @@ def run(self) -> None: on_revoke=self.on_revoke, ) - super().run() - self.init_checkpoint() while self._running: - self._producer.poll() + self.producer.poll() msg = self.poll_source() if msg is None: continue @@ -192,7 +220,7 @@ def run(self) -> None: def produce_message(self, msg: Message): topic_name, partition, offset = msg.topic(), msg.partition(), msg.offset() - self._checkpoint.store_offset(topic_name, partition, offset) + self.checkpoint.store_offset(topic_name, partition, offset) self.produce( value=msg.value(), key=msg.key(), @@ -203,37 +231,35 @@ def produce_message(self, msg: Message): def poll_source(self) -> Optional[Message]: try: - msg = self._source_cluster_consumer.poll( - timeout=self._consumer_poll_timeout - ) + msg = self.source_cluster_consumer.poll(timeout=self._consumer_poll_timeout) except Exception as exc: if self._on_consumer_error(exc, None, logger): - return + return None raise if msg is None: - return + return None try: if err := msg.error(): raise KafkaConsumerException(error=err) except Exception as exc: if self._on_consumer_error(exc, msg, logger): - return + return None raise return msg def commit_checkpoint(self, force: bool = False) -> None: - if not self._checkpoint.expired() and not force: + if not self.checkpoint.expired() and not force: return - if self._checkpoint.empty(): - self._checkpoint.close() + if self.checkpoint.empty(): + self.checkpoint.close() else: logger.debug("Committing checkpoint") start = time.monotonic() - self._checkpoint.commit() + self.checkpoint.commit() elapsed = round(time.monotonic() - start, 2) logger.debug(f"Checkpoint commited in {elapsed}s") @@ -241,9 +267,9 @@ def commit_checkpoint(self, force: bool = False) -> None: def init_checkpoint(self) -> None: self._checkpoint = Checkpoint( - producer=self._producer, - producer_topic=self._producer_topic, - consumer=self._target_cluster_consumer, + producer=self.producer, + producer_topic=self.producer_topic, + consumer=self.target_cluster_consumer, commit_every=self._config.commit_every, commit_interval=self._config.commit_interval, flush_timeout=self._flush_timeout, @@ -251,7 +277,7 @@ def init_checkpoint(self) -> None: ) def _validate_topics(self) -> None: - source_topic_config = self._source_cluster_admin.inspect_topics( + source_topic_config = self.source_cluster_admin.inspect_topics( topic_names=[self._topic], timeout=self._config.request_timeout ).get(self._topic) @@ -262,17 +288,17 @@ def _validate_topics(self) -> None: "source topic %s configuration: %s", self._topic, source_topic_config ) - target_topic_config = self._target_cluster_admin.inspect_topics( - 
topic_names=[self._producer_topic.name], + target_topic_config = self.target_cluster_admin.inspect_topics( + topic_names=[self.producer_topic.name], timeout=self._config.request_timeout, - ).get(self._producer_topic.name) + ).get(self.producer_topic.name) if target_topic_config is None: - raise ValueError(f"Destination topic {self._producer_topic.name} not found") + raise ValueError(f"Destination topic {self.producer_topic.name} not found") logger.debug( "destination topic %s configuration: %s", - self._producer_topic.name, + self.producer_topic.name, target_topic_config, ) @@ -286,11 +312,11 @@ def _target_cluster_offsets( ) -> Dict[int, int]: partitions = [ TopicPartition( - topic=self._producer_topic.name, partition=partition.partition + topic=self.producer_topic.name, partition=partition.partition ) for partition in partitions ] - partitions_commited = self._target_cluster_consumer.committed( + partitions_commited = self.target_cluster_consumer.committed( partitions, timeout=self._config.request_timeout ) a = {partition.partition: partition.offset for partition in partitions_commited} @@ -307,11 +333,11 @@ def on_assign(self, _, source_partitions: List[TopicPartition]) -> None: partition.partition, ) - self._source_cluster_consumer.incremental_assign(source_partitions) + self.source_cluster_consumer.incremental_assign(source_partitions) def on_revoke(self, *_) -> None: if self._failed: - self._checkpoint.close() + self.checkpoint.close() else: self.commit_checkpoint(force=True) @@ -324,8 +350,8 @@ def stop(self) -> None: def cleanup(self, failed: bool) -> None: self._failed = failed - self._source_cluster_consumer.close() - self._target_cluster_consumer.close() + self.source_cluster_consumer.close() + self.target_cluster_consumer.close() def default_topic(self) -> Topic: admin = TopicAdmin( diff --git a/quixstreams/sources/core/kafka/quix.py b/quixstreams/sources/core/kafka/quix.py index 87f1520c5..7483a68f7 100644 --- a/quixstreams/sources/core/kafka/quix.py +++ b/quixstreams/sources/core/kafka/quix.py @@ -3,8 +3,8 @@ from quixstreams.error_callbacks import ConsumerErrorCallback, default_on_consumer_error from quixstreams.kafka import AutoOffsetReset from quixstreams.models.serializers import DeserializerType -from quixstreams.models.topics import Topic, TopicConfig -from quixstreams.platforms.quix import QuixKafkaConfigsBuilder, QuixTopicManager +from quixstreams.models.topics import Topic +from quixstreams.platforms.quix import QuixKafkaConfigsBuilder from quixstreams.platforms.quix.api import QuixPortalApiService from .kafka import KafkaReplicatorSource @@ -59,7 +59,7 @@ def __init__( consumer_extra_config: Optional[dict] = None, consumer_poll_timeout: Optional[float] = None, shutdown_timeout: float = 10, - on_consumer_error: Optional[ConsumerErrorCallback] = default_on_consumer_error, + on_consumer_error: ConsumerErrorCallback = default_on_consumer_error, value_deserializer: DeserializerType = "json", key_deserializer: DeserializerType = "bytes", ) -> None: @@ -86,15 +86,7 @@ def __init__( ) quix_topic = self._quix_config.convert_topic_response( - self._quix_config.get_or_create_topic( - Topic( - name=topic, - config=TopicConfig( - num_partitions=QuixTopicManager.default_num_partitions, - replication_factor=QuixTopicManager.default_replication_factor, - ), - ) - ) + self._quix_config.get_or_create_topic(Topic(name=topic)) ) consumer_extra_config.update(self._quix_config.librdkafka_extra_config) diff --git a/quixstreams/state/recovery.py b/quixstreams/state/recovery.py index 
6d7f37af2..b54bcbdc8 100644 --- a/quixstreams/state/recovery.py +++ b/quixstreams/state/recovery.py @@ -3,7 +3,7 @@ from confluent_kafka import TopicPartition as ConfluentPartition -from quixstreams.kafka import Consumer +from quixstreams.kafka import BaseConsumer from quixstreams.models import ConfluentKafkaMessageProto, Topic from quixstreams.models.topics import TopicConfig, TopicManager from quixstreams.models.types import MessageHeadersMapping @@ -229,7 +229,7 @@ class RecoveryManager: Recovery is attempted from the `Application` after any new partition assignment. """ - def __init__(self, consumer: Consumer, topic_manager: TopicManager): + def __init__(self, consumer: BaseConsumer, topic_manager: TopicManager): self._running = False self._consumer = consumer self._topic_manager = topic_manager diff --git a/requirements-mypy.txt b/requirements-mypy.txt new file mode 100644 index 000000000..e2ccc0d0a --- /dev/null +++ b/requirements-mypy.txt @@ -0,0 +1,5 @@ +mypy==1.13.0 +mypy-extensions==1.0.0 +types-jsonschema==4.23.0.20240813 +types-protobuf==5.28.3.20241030 +types-requests==2.32.0.20241016 diff --git a/tests/test_quixstreams/test_sources/test_core/test_kafka.py b/tests/test_quixstreams/test_sources/test_core/test_kafka.py index b7491ad71..2e252bc80 100644 --- a/tests/test_quixstreams/test_sources/test_core/test_kafka.py +++ b/tests/test_quixstreams/test_sources/test_core/test_kafka.py @@ -341,13 +341,11 @@ def mock_get_or_create(self, topic, timeout=None): "id": f"{workspace_id}-{topic.name}", "name": topic.name, "configuration": { - "partitions": topic.config.num_partitions, - "replicationFactor": topic.config.replication_factor, + "partitions": 1, + "replicationFactor": 1, "retentionInMinutes": 1, "retentionInBytes": 1, - "cleanupPolicy": topic.config.extra_config.get( - "cleanup.policy", "Delete" - ), + "cleanupPolicy": "Delete", }, }
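
Note on the hook configuration: with `language: system`, pre-commit runs mypy from the ambient Python environment rather than building an isolated hook environment, which is why the lint job now installs requirements.txt, tests/requirements.txt, and requirements-mypy.txt before pre-commit/action executes — mypy needs the real dependencies on hand to follow imports. The mypy pin in requirements-mypy.txt matches the mirror's `rev: v1.13.0`, keeping local and CI runs on the same version. The `files: ^quixstreams/` filter scopes the hook to the library sources, and the `[[tool.mypy.overrides]]` blocks with `ignore_errors = true` carve out the not-yet-annotated subpackages so the check can land incrementally; deleting an override later re-enables checking for that subpackage. Reproducing the CI check locally should amount to installing the same requirements and invoking `pre-commit run mypy --all-files`.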
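Two typing patterns recur throughout the diff and are worth spelling out. First, `TopicManagerFactory` is a callback `Protocol`: anything callable with a matching signature conforms, so both the bare `TopicManager` class and the `functools.partial(QuixTopicManager, quix_config_builder=...)` branch type-check against the same annotation. Second, attributes that only exist after `configure()` or `start()` (`_producer`, `_producer_topic`, `_checkpoint`, the cluster consumers and admins) stay `Optional`, and a guard property narrows them to a non-Optional type at a single raise site. A minimal, self-contained sketch of both patterns, using hypothetical stand-in names (`Widget`, `WidgetFactory`, `Pipeline`) rather than the real quixstreams classes:

    from functools import partial
    from typing import Optional, Protocol


    class Widget:
        def __init__(self, name: str, retries: int = 3) -> None:
            self.name = name
            self.retries = retries


    class WidgetFactory(Protocol):
        # Callback Protocol: any callable with this signature conforms,
        # including the Widget class itself and functools.partial objects.
        def __call__(self, name: str, retries: int = 3) -> Widget: ...


    class Pipeline:
        def __init__(self, factory: WidgetFactory) -> None:
            self.factory = factory
            # Unavailable until configure() runs, hence Optional.
            self._widget: Optional[Widget] = None

        def configure(self, name: str) -> None:
            self._widget = self.factory(name)

        @property
        def widget(self) -> Widget:
            # Guard property: one raise site narrows Optional[Widget] to
            # Widget, sparing every caller a repeated None check.
            if self._widget is None:
                raise RuntimeError("pipeline not configured")
            return self._widget


    pipeline = Pipeline(factory=Widget)  # a class object satisfies the Protocol
    quix_like = Pipeline(factory=partial(Widget, retries=5))  # so does a partial
    pipeline.configure("example")
    print(pipeline.widget.name)

Passing the class object itself, or a `functools.partial` over a subclass, is exactly how `Application.__init__` satisfies `TopicManagerFactory` in the plain-broker and Quix branches respectively.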