Skip to content

Commit

Permalink
CI: Implement mypy pre-commit check
Browse files Browse the repository at this point in the history
  • Loading branch information
quentin-quix committed Nov 25, 2024
1 parent 45ee8a1 commit 9555406
Show file tree
Hide file tree
Showing 26 changed files with 246 additions and 131 deletions.
11 changes: 8 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,13 @@ jobs:
- name: Setup Python
uses: actions/setup-python@v4
with:
python-version: 3.9
python-version: 3.12
- name: Update pip
run: |
python -m pip install -U pip
- name: Install requirements
run: |
python -m pip install -U -r requirements.txt -r tests/requirements.txt -r requirements-mypy.txt
- uses: pre-commit/[email protected]

test:
Expand All @@ -49,8 +55,7 @@ jobs:
python -m pip install -U pip
- name: Install requirements
run: |
python -m pip install -U -r tests/requirements.txt
python -m pip install -U -r requirements.txt
python -m pip install -U -r requirements.txt -r tests/requirements.txt
- name: Run tests
run: |
python -m pytest -v --log-cli-level=ERROR
7 changes: 7 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,10 @@ repos:
entry: python conda/requirements.py
language: python
files: ^(requirements\.txt|pyproject\.toml)$
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.13.0
hooks:
- id: mypy
args: []
language: system
files: ^quixstreams/
36 changes: 36 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,39 @@ log_cli_level = "INFO"
log_cli_format = "[%(levelname)s] %(name)s: %(message)s"
# Custom markers
markers = ["timeit"]

[[tool.mypy.overrides]]
module = "confluent_kafka.*"
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = "quixstreams.core.*"
ignore_errors = true

[[tool.mypy.overrides]]
module = "quixstreams.dataframe.*"
ignore_errors = true

[[tool.mypy.overrides]]
module = "quixstreams.models.*"
ignore_errors = true

[[tool.mypy.overrides]]
module = "quixstreams.platforms.*"
ignore_errors = true

[[tool.mypy.overrides]]
module = "quixstreams.sinks.community.*"
ignore_errors = true

[[tool.mypy.overrides]]
module = "quixstreams.sources.community.*"
ignore_errors = true

[[tool.mypy.overrides]]
module = "quixstreams.state.*"
ignore_errors = true

[[tool.mypy.overrides]]
module = "quixstreams.rowproducer.*"
ignore_errors = true
34 changes: 24 additions & 10 deletions quixstreams/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
import time
import warnings
from pathlib import Path
from typing import Callable, List, Literal, Optional, Tuple, Type, Union
from typing import Callable, List, Literal, Optional, Protocol, Tuple, Type, Union

from confluent_kafka import TopicPartition
from pydantic import AliasGenerator, Field
from pydantic_settings import BaseSettings as PydanticBaseSettings
from pydantic_settings import PydanticBaseSettingsSource, SettingsConfigDict
from typing_extensions import Self

Expand Down Expand Up @@ -60,6 +61,17 @@
_default_max_poll_interval_ms = 300000


class TopicManagerFactory(Protocol):
    """
    Callback protocol for a callable that produces a `TopicManager`.

    It exists so that an attribute holding "something that builds a
    TopicManager" can be given a single static type. In this file the
    `TopicManager` class itself is assigned to an attribute annotated with
    this protocol.
    """

    def __call__(
        self,
        topic_admin: TopicAdmin,
        consumer_group: str,
        timeout: float = 30,
        create_timeout: float = 60,
        auto_create_topics: bool = True,
    ) -> TopicManager: ...


class Application:
"""
The main Application class.
Expand Down Expand Up @@ -205,19 +217,21 @@ def __init__(
producer_extra_config = producer_extra_config or {}
consumer_extra_config = consumer_extra_config or {}

state_dir = Path(state_dir)

# We can't use os.getenv as defaults (and have testing work nicely)
# since it evaluates getenv when the function is defined.
# In general this is just a more robust approach.
broker_address = broker_address or os.getenv("Quix__Broker__Address")
quix_sdk_token = quix_sdk_token or os.getenv("Quix__Sdk__Token")
consumer_group = consumer_group or os.getenv(
"Quix__Consumer_Group", "quixstreams-default"
)

if not consumer_group:
consumer_group = os.getenv("Quix__Consumer_Group", "quixstreams-default")

if broker_address:
# If broker_address is passed to the app it takes priority over any quix configuration
self._is_quix_app = False
self._topic_manager_factory = TopicManager
self._topic_manager_factory: TopicManagerFactory = TopicManager
if isinstance(broker_address, str):
broker_address = ConnectionConfig(bootstrap_servers=broker_address)
else:
Expand Down Expand Up @@ -249,7 +263,6 @@ def __init__(
QuixTopicManager, quix_config_builder=quix_config_builder
)
# Check if the state dir points to the mounted PVC while running on Quix
state_dir = Path(state_dir)
check_state_dir(state_dir=state_dir)
quix_app_config = quix_config_builder.get_application_config(consumer_group)

Expand Down Expand Up @@ -487,12 +500,13 @@ def dataframe(
:param source: a `quixstreams.sources` "BaseSource" instance
:return: `StreamingDataFrame` object
"""
if not source and not topic:
raise ValueError("one of `source` or `topic` is required")

if source:
if source is not None:
topic = self.add_source(source, topic)

if topic is None:
raise ValueError("one of `source` or `topic` is required")

sdf = StreamingDataFrame(
topic=topic,
topic_manager=self._topic_manager,
Expand Down Expand Up @@ -1012,7 +1026,7 @@ class ApplicationConfig(BaseSettings):
@classmethod
def settings_customise_sources(
cls,
settings_cls: Type[BaseSettings],
settings_cls: Type[PydanticBaseSettings],
init_settings: PydanticBaseSettingsSource,
env_settings: PydanticBaseSettingsSource,
dotenv_settings: PydanticBaseSettingsSource,
Expand Down
6 changes: 3 additions & 3 deletions quixstreams/checkpointing/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from confluent_kafka import KafkaException, TopicPartition

from quixstreams.kafka import Consumer
from quixstreams.kafka import BaseConsumer
from quixstreams.processing.pausing import PausingManager
from quixstreams.rowproducer import RowProducer
from quixstreams.sinks import SinkManager
Expand Down Expand Up @@ -48,7 +48,7 @@ def __init__(
# processed offsets within the checkpoint
self._starting_tp_offsets: Dict[Tuple[str, int], int] = {}
# A mapping of <(topic, partition, store_name): PartitionTransaction>
self._store_transactions: Dict[(str, int, str), PartitionTransaction] = {}
self._store_transactions: Dict[Tuple[str, int, str], PartitionTransaction] = {}
# Passing zero or lower will flush the checkpoint after each processed message
self._commit_interval = max(commit_interval, 0)

Expand Down Expand Up @@ -123,7 +123,7 @@ def __init__(
self,
commit_interval: float,
producer: RowProducer,
consumer: Consumer,
consumer: BaseConsumer,
state_manager: StateStoreManager,
sink_manager: SinkManager,
pausing_manager: PausingManager,
Expand Down
7 changes: 4 additions & 3 deletions quixstreams/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
"copy_context",
)

_current_message_context = ContextVar("current_message_context")
_current_message_context: ContextVar[Optional[MessageContext]] = ContextVar(
"current_message_context"
)


class MessageContextNotSetError(QuixException): ...
Expand Down Expand Up @@ -48,7 +50,7 @@ def alter_context(value):
_current_message_context.set(context)


def message_context() -> MessageContext:
def message_context() -> Optional[MessageContext]:
"""
Get a MessageContext for the current message, which houses most of the message
metadata, like:
Expand All @@ -74,6 +76,5 @@ def message_context() -> MessageContext:
"""
try:
return _current_message_context.get()

except LookupError:
raise MessageContextNotSetError("Message context is not set")
9 changes: 7 additions & 2 deletions quixstreams/kafka/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@
import pydantic
from pydantic import AliasChoices, Field, SecretStr
from pydantic.functional_validators import BeforeValidator
from pydantic_settings import PydanticBaseSettingsSource
from pydantic_settings import (
BaseSettings as PydanticBaseSettings,
)
from pydantic_settings import (
PydanticBaseSettingsSource,
)
from typing_extensions import Annotated, Self

from quixstreams.utils.settings import BaseSettings
Expand Down Expand Up @@ -93,7 +98,7 @@ class ConnectionConfig(BaseSettings):
@classmethod
def settings_customise_sources(
cls,
settings_cls: Type[BaseSettings],
settings_cls: Type[PydanticBaseSettings],
init_settings: PydanticBaseSettingsSource,
env_settings: PydanticBaseSettingsSource,
dotenv_settings: PydanticBaseSettingsSource,
Expand Down
21 changes: 17 additions & 4 deletions quixstreams/kafka/consumer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import functools
import logging
import typing
from typing import Callable, List, Optional, Tuple, Union
from typing import Any, Callable, List, Optional, Tuple, Union

from confluent_kafka import (
Consumer as ConfluentConsumer,
Expand All @@ -18,6 +18,7 @@
from .configuration import ConnectionConfig

__all__ = (
"BaseConsumer",
"Consumer",
"AutoOffsetReset",
"RebalancingCallback",
Expand Down Expand Up @@ -64,7 +65,7 @@ def wrapper(*args, **kwargs):
return wrapper


class Consumer:
class BaseConsumer:
def __init__(
self,
broker_address: Union[str, ConnectionConfig],
Expand Down Expand Up @@ -147,7 +148,7 @@ def poll(self, timeout: Optional[float] = None) -> Optional[Message]:
"""
return self._consumer.poll(timeout=timeout if timeout is not None else -1)

def subscribe(
def _subscribe(
self,
topics: List[str],
on_assign: Optional[RebalancingCallback] = None,
Expand Down Expand Up @@ -302,7 +303,8 @@ def commit(
raise ValueError(
'Parameters "message" and "offsets" are mutually exclusive'
)
kwargs = {

kwargs: dict[str, Any] = {
"asynchronous": asynchronous,
}
if offsets is not None:
Expand Down Expand Up @@ -559,3 +561,14 @@ def __enter__(self):

def __exit__(self, exc_type, exc_val, exc_tb):
self.close()


class Consumer(BaseConsumer):
    """
    `BaseConsumer` subclass that exposes a public `subscribe()` method.
    """

    def subscribe(
        self,
        topics: List[str],
        on_assign: Optional[RebalancingCallback] = None,
        on_revoke: Optional[RebalancingCallback] = None,
        on_lost: Optional[RebalancingCallback] = None,
    ):
        """
        Delegate to `BaseConsumer._subscribe()`.

        All arguments are forwarded positionally and unchanged, and whatever
        `_subscribe()` returns is returned to the caller.

        :param topics: forwarded as the first argument of `_subscribe()`
        :param on_assign: optional rebalancing callback, forwarded as-is
        :param on_revoke: optional rebalancing callback, forwarded as-is
        :param on_lost: optional rebalancing callback, forwarded as-is
        """
        return super()._subscribe(topics, on_assign, on_revoke, on_lost)
4 changes: 2 additions & 2 deletions quixstreams/kafka/producer.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(
logger: logging.Logger = logger,
error_callback: Callable[[KafkaError], None] = _default_error_cb,
extra_config: Optional[dict] = None,
flush_timeout: Optional[int] = None,
flush_timeout: Optional[float] = None,
):
"""
A wrapper around `confluent_kafka.Producer`.
Expand Down Expand Up @@ -190,7 +190,7 @@ def __init__(
logger: logging.Logger = logger,
error_callback: Callable[[KafkaError], None] = _default_error_cb,
extra_config: Optional[dict] = None,
flush_timeout: Optional[int] = None,
flush_timeout: Optional[float] = None,
):
super().__init__(
broker_address=broker_address,
Expand Down
2 changes: 1 addition & 1 deletion quixstreams/platforms/quix/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class QuixApplicationConfig:

librdkafka_connection_config: ConnectionConfig
librdkafka_extra_config: dict
consumer_group: Optional[str] = None
consumer_group: str


class QuixKafkaConfigsBuilder:
Expand Down
4 changes: 2 additions & 2 deletions quixstreams/platforms/quix/topic_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ class QuixTopicManager(TopicManager):

# Default topic params
# Set these to None to use defaults defined in Quix Cloud
default_num_partitions = None
default_replication_factor = None
default_num_partitions: None = None
default_replication_factor: None = None

# Max topic name length for the new topics
_max_topic_name_len = 249
Expand Down
10 changes: 5 additions & 5 deletions quixstreams/processing/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def store_offset(self, topic: str, partition: int, offset: int):
:param partition: partition number
:param offset: message offset
"""
self._checkpoint.store_offset(topic=topic, partition=partition, offset=offset)
self.checkpoint.store_offset(topic=topic, partition=partition, offset=offset)

def init_checkpoint(self):
"""
Expand All @@ -79,13 +79,13 @@ def commit_checkpoint(self, force: bool = False):
:param force: if `True`, commit the Checkpoint before its expiration deadline.
"""
if self._checkpoint.expired() or force:
if self._checkpoint.empty():
self._checkpoint.close()
if self.checkpoint.expired() or force:
if self.checkpoint.empty():
self.checkpoint.close()
else:
logger.debug(f"Committing a checkpoint; forced={force}")
start = time.monotonic()
self._checkpoint.commit()
self.checkpoint.commit()
elapsed = round(time.monotonic() - start, 2)
logger.debug(
f"Committed a checkpoint; forced={force}, time_elapsed={elapsed}s"
Expand Down
4 changes: 2 additions & 2 deletions quixstreams/processing/pausing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from confluent_kafka import TopicPartition

from quixstreams.kafka import Consumer
from quixstreams.kafka import BaseConsumer

logger = logging.getLogger(__name__)

Expand All @@ -20,7 +20,7 @@ class PausingManager:

_paused_tps: Dict[Tuple[str, int], float]

def __init__(self, consumer: Consumer):
def __init__(self, consumer: BaseConsumer):
self._consumer = consumer
self._paused_tps = {}
self._next_resume_at = _MAX_FLOAT
Expand Down
Loading

0 comments on commit 9555406

Please sign in to comment.