mypy: make quixstreams.models.* pass type checks (#673)
quentin-quix authored Dec 6, 2024
1 parent dc66d90 commit 808dd79
Showing 19 changed files with 167 additions and 88 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -114,14 +114,14 @@ ignore_missing_imports = true
module = [
"quixstreams.sinks.community.*",
"quixstreams.sources.community.*",
"quixstreams.models.serializers.quix.*",
]
ignore_errors = true

[[tool.mypy.overrides]]
module = [
"quixstreams.core.*",
"quixstreams.dataframe.*",
"quixstreams.models.*",
"quixstreams.platforms.*",
"quixstreams.rowproducer.*"
]
6 changes: 3 additions & 3 deletions quixstreams/error_callbacks.py
@@ -1,18 +1,18 @@
import logging
from typing import Callable, Optional

from .models import ConfluentKafkaMessageProto, Row
from .models import RawConfluentKafkaMessageProto, Row

ProcessingErrorCallback = Callable[[Exception, Optional[Row], logging.Logger], bool]
ConsumerErrorCallback = Callable[
[Exception, Optional[ConfluentKafkaMessageProto], logging.Logger], bool
[Exception, Optional[RawConfluentKafkaMessageProto], logging.Logger], bool
]
ProducerErrorCallback = Callable[[Exception, Optional[Row], logging.Logger], bool]


def default_on_consumer_error(
exc: Exception,
message: Optional[ConfluentKafkaMessageProto],
message: Optional[RawConfluentKafkaMessageProto],
logger: logging.Logger,
):
topic, partition, offset = None, None, None
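
The renamed protocol distinguishes raw poll results, which may still carry a broker error, from messages that have already been checked. A minimal sketch of that split, assuming Protocol-based definitions; the real classes live in quixstreams.models.types and may differ in detail:

from typing import Optional, Protocol

class RawConfluentKafkaMessageProto(Protocol):
    # A message as returned by poll(): error() may still report a broker error.
    def error(self) -> Optional[object]: ...
    def topic(self) -> Optional[str]: ...
    def partition(self) -> Optional[int]: ...
    def offset(self) -> Optional[int]: ...

class SuccessfulConfluentKafkaMessageProto(Protocol):
    # A message that has already passed an error check.
    def error(self) -> None: ...
    def topic(self) -> str: ...
    def partition(self) -> int: ...
    def offset(self) -> int: ...
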
19 changes: 17 additions & 2 deletions quixstreams/kafka/consumer.py
@@ -1,7 +1,7 @@
import functools
import logging
import typing
from typing import Any, Callable, List, Optional, Tuple, Union
from typing import Any, Callable, List, Optional, Tuple, Union, cast

from confluent_kafka import (
Consumer as ConfluentConsumer,
@@ -14,8 +14,13 @@
from confluent_kafka.admin import ClusterMetadata, GroupMetadata

from quixstreams.exceptions import KafkaPartitionError, PartitionAssignmentError
from quixstreams.models.types import (
RawConfluentKafkaMessageProto,
SuccessfulConfluentKafkaMessageProto,
)

from .configuration import ConnectionConfig
from .exceptions import KafkaConsumerException

__all__ = (
"BaseConsumer",
@@ -65,6 +70,14 @@ def wrapper(*args, **kwargs):
return wrapper


def raise_for_msg_error(
msg: RawConfluentKafkaMessageProto,
) -> SuccessfulConfluentKafkaMessageProto:
if msg.error():
raise KafkaConsumerException(error=msg.error())
return cast(SuccessfulConfluentKafkaMessageProto, msg)


class BaseConsumer:
def __init__(
self,
@@ -129,7 +142,9 @@ def __init__(
}
self._inner_consumer: Optional[ConfluentConsumer] = None

def poll(self, timeout: Optional[float] = None) -> Optional[Message]:
def poll(
self, timeout: Optional[float] = None
) -> Optional[RawConfluentKafkaMessageProto]:
"""
Consumes a single message, calls callbacks and returns events.
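
A usage sketch of the new helper: narrow a possibly-errored poll() result before reading it. The broker address, group id, and topic name are placeholders, not from the commit.

from confluent_kafka import Consumer

from quixstreams.kafka.consumer import raise_for_msg_error

consumer = Consumer({"bootstrap.servers": "localhost:9092", "group.id": "example"})
consumer.subscribe(["example-topic"])

msg = consumer.poll(timeout=1.0)
if msg is not None:
    # Raises KafkaConsumerException if msg.error() is set; otherwise the
    # return value is typed as a successful message.
    ok_msg = raise_for_msg_error(msg)
    print(ok_msg.topic(), ok_msg.partition(), ok_msg.offset())
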
2 changes: 1 addition & 1 deletion quixstreams/models/rows.py
@@ -19,7 +19,7 @@ def __init__(
key: Optional[Any],
timestamp: int,
context: MessageContext,
headers: KafkaHeaders = None,
headers: KafkaHeaders,
):
self.value = value
self.key = key
34 changes: 21 additions & 13 deletions quixstreams/models/serializers/avro.py
@@ -148,6 +148,12 @@ def __init__(
)

super().__init__()

if schema is None and schema_registry_client_config is None:
raise TypeError(
"One of `schema` or `schema_registry_client_config` is required"
)

self._schema = parse_schema(schema) if schema else None
self._reader_schema = parse_schema(reader_schema) if reader_schema else None
self._return_record_name = return_record_name
@@ -174,17 +180,19 @@ def __call__(
return self._schema_registry_deserializer(value, ctx)
except (SchemaRegistryError, _SerializationError, EOFError) as exc:
raise SerializationError(str(exc)) from exc
elif self._schema is not None:
try:
return schemaless_reader( # type: ignore
BytesIO(value),
self._schema,
reader_schema=self._reader_schema,
return_record_name=self._return_record_name,
return_record_name_override=self._return_record_name_override,
return_named_type=self._return_named_type,
return_named_type_override=self._return_named_type_override,
handle_unicode_errors=self._handle_unicode_errors,
)
except EOFError as exc:
raise SerializationError(str(exc)) from exc

try:
return schemaless_reader(
BytesIO(value),
self._schema,
reader_schema=self._reader_schema,
return_record_name=self._return_record_name,
return_record_name_override=self._return_record_name_override,
return_named_type=self._return_named_type,
return_named_type_override=self._return_named_type_override,
handle_unicode_errors=self._handle_unicode_errors,
)
except EOFError as exc:
raise SerializationError(str(exc)) from exc
raise SerializationError("no schema found")
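
The new guard makes the schema requirement explicit at construction time. A sketch, assuming a local fastavro-parseable record schema; the error message is the one added in this commit:

from quixstreams.models.serializers.avro import AvroDeserializer

schema = {
    "type": "record",
    "name": "Example",
    "fields": [{"name": "value", "type": "string"}],
}

deserializer = AvroDeserializer(schema=schema)  # a local schema satisfies the guard

try:
    AvroDeserializer()  # neither `schema` nor `schema_registry_client_config`
except TypeError as exc:
    print(exc)  # "One of `schema` or `schema_registry_client_config` is required"
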
6 changes: 3 additions & 3 deletions quixstreams/models/serializers/base.py
@@ -9,7 +9,7 @@
)
from typing_extensions import Literal, TypeAlias

from ..types import HeadersMapping, KafkaHeaders
from ..types import Headers, HeadersMapping, KafkaHeaders

__all__ = (
"SerializationContext",
@@ -33,8 +33,8 @@ class SerializationContext(_SerializationContext):
def __init__(
self,
topic: str,
field: MessageField,
headers: KafkaHeaders = None,
field: str,
headers: Union[KafkaHeaders, Headers] = None,
) -> None:
self.topic = topic
self.field = field
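
With the relaxed annotations, field is a plain string and headers may be either a mapping or a sequence of key/value pairs. A small sketch with illustrative topic and header values:

from quixstreams.models.serializers.base import SerializationContext

ctx_value = SerializationContext(topic="orders", field="value", headers={"source": b"api"})
ctx_key = SerializationContext(topic="orders", field="key", headers=[("source", b"api")])
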
10 changes: 6 additions & 4 deletions quixstreams/models/serializers/json.py
@@ -1,5 +1,5 @@
import json
from typing import Any, Callable, Iterable, Mapping, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, Optional, Union

from confluent_kafka.schema_registry import SchemaRegistryClient, SchemaRegistryError
from confluent_kafka.schema_registry.json_schema import (
@@ -10,7 +10,6 @@
)
from confluent_kafka.serialization import SerializationError as _SerializationError
from jsonschema import Draft202012Validator, ValidationError
from jsonschema.protocols import Validator

from quixstreams.utils.json import (
dumps as default_dumps,
@@ -26,6 +25,9 @@
SchemaRegistrySerializationConfig,
)

if TYPE_CHECKING:
from jsonschema.validators import _Validator

__all__ = ("JSONSerializer", "JSONDeserializer")


@@ -34,7 +36,7 @@ def __init__(
self,
dumps: Callable[[Any], Union[str, bytes]] = default_dumps,
schema: Optional[Mapping] = None,
validator: Optional[Validator] = None,
validator: Optional["_Validator"] = None,
schema_registry_client_config: Optional[SchemaRegistryClientConfig] = None,
schema_registry_serialization_config: Optional[
SchemaRegistrySerializationConfig
@@ -121,7 +123,7 @@ def __init__(
self,
loads: Callable[[Union[bytes, bytearray]], Any] = default_loads,
schema: Optional[Mapping] = None,
validator: Optional[Validator] = None,
validator: Optional["_Validator"] = None,
schema_registry_client_config: Optional[SchemaRegistryClientConfig] = None,
):
"""
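
Passing a concrete jsonschema validator is unchanged at runtime; the commit only swaps the annotation used for type checking. A sketch with an illustrative schema:

from jsonschema import Draft202012Validator

from quixstreams.models.serializers.json import JSONDeserializer, JSONSerializer

schema = {"type": "object", "properties": {"id": {"type": "integer"}}}

serializer = JSONSerializer(validator=Draft202012Validator(schema))
deserializer = JSONDeserializer(validator=Draft202012Validator(schema))
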
6 changes: 3 additions & 3 deletions quixstreams/models/serializers/protobuf.py
@@ -1,4 +1,4 @@
from typing import Dict, Iterable, Mapping, Optional, Union
from typing import Dict, Iterable, Mapping, Optional, Type, Union

from confluent_kafka.schema_registry import SchemaRegistryClient, SchemaRegistryError
from confluent_kafka.schema_registry.protobuf import (
@@ -24,7 +24,7 @@
class ProtobufSerializer(Serializer):
def __init__(
self,
msg_type: Message,
msg_type: Type[Message],
deterministic: bool = False,
ignore_unknown_fields: bool = False,
schema_registry_client_config: Optional[SchemaRegistryClientConfig] = None,
@@ -110,7 +110,7 @@ def __call__(
class ProtobufDeserializer(Deserializer):
def __init__(
self,
msg_type: Message,
msg_type: Type[Message],
use_integers_for_enums: bool = False,
preserving_proto_field_name: bool = False,
to_dict: bool = True,
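
The annotation change means msg_type expects the generated message class rather than an instance. A sketch using a well-known protobuf type purely for illustration:

from google.protobuf.timestamp_pb2 import Timestamp

from quixstreams.models.serializers.protobuf import (
    ProtobufDeserializer,
    ProtobufSerializer,
)

serializer = ProtobufSerializer(msg_type=Timestamp)    # a class matches Type[Message]
deserializer = ProtobufDeserializer(msg_type=Timestamp)
# ProtobufSerializer(msg_type=Timestamp()) would still run, but an instance no
# longer matches the declared Type[Message] annotation under mypy.
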
28 changes: 3 additions & 25 deletions quixstreams/models/topics/admin.py
@@ -7,10 +7,7 @@
from confluent_kafka.admin import (
AdminClient,
ConfigResource,
KafkaException, # type: ignore
)
from confluent_kafka.admin import (
NewTopic as ConfluentTopic, # type: ignore
KafkaException,
)
from confluent_kafka.admin import (
TopicMetadata as ConfluentTopicMetadata,
@@ -26,25 +23,6 @@
__all__ = ("TopicAdmin",)


def convert_topic_list(topics: List[Topic]) -> List[ConfluentTopic]:
"""
Converts `Topic`s to `ConfluentTopic`s as required for Confluent's
`AdminClient.create_topic()`.
:param topics: list of `Topic`s
:return: list of confluent_kafka `ConfluentTopic`s
"""
return [
ConfluentTopic(
topic=topic.name,
num_partitions=topic.config.num_partitions,
replication_factor=topic.config.replication_factor,
config=topic.config.extra_config,
)
for topic in topics
]


def confluent_topic_config(topic: str) -> ConfigResource:
return ConfigResource(2, topic)

@@ -207,12 +185,12 @@ def create_topics(
for topic in topics_to_create:
logger.info(
f'Creating a new topic "{topic.name}" '
f'with config: "{topic.config.as_dict()}"'
f'with config: "{topic.config.as_dict() if topic.config is not None else {}}"'
)

self._finalize_create(
self.admin_client.create_topics(
convert_topic_list(topics_to_create),
[topic.as_newtopic() for topic in topics_to_create],
request_timeout=timeout,
),
finalize_timeout=finalize_timeout,
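
The removed helper built confluent_kafka NewTopic objects field by field; the commit moves that mapping onto the Topic object itself (topic.as_newtopic()). A sketch of the equivalent construction, reusing the field names from the removed helper; the function name is illustrative:

from confluent_kafka.admin import NewTopic

def to_new_topic(name: str, num_partitions: int, replication_factor: int,
                 extra_config: dict) -> NewTopic:
    # Mirrors what convert_topic_list() used to do for each Topic.
    return NewTopic(
        topic=name,
        num_partitions=num_partitions,
        replication_factor=replication_factor,
        config=extra_config,
    )

print(to_new_topic("orders", 3, 1, {"cleanup.policy": "compact"}))
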
17 changes: 15 additions & 2 deletions quixstreams/models/topics/manager.py
@@ -40,7 +40,7 @@ class TopicManager:
# Default topic params
default_num_partitions = 1
default_replication_factor = 1
default_extra_config = {}
default_extra_config: dict[str, str] = {}

# Max topic name length for the new topics
_max_topic_name_len = 255
@@ -207,9 +207,15 @@ def _get_source_topic_config(
:return: a TopicConfig
"""

topic_config = self._admin.inspect_topics([topic_name], timeout=timeout)[
topic_name
] or deepcopy(self._non_changelog_topics[topic_name].config)
]
if topic_config is None and topic_name in self._non_changelog_topics:
topic_config = deepcopy(self._non_changelog_topics[topic_name].config)

if topic_config is None:
raise RuntimeError(f"No configuration can be found for topic {topic_name}")

# Copy only certain configuration values from original topic
if extras_imports:
@@ -475,10 +481,17 @@ def validate_all_topics(self, timeout: Optional[float] = None):

for source_name in self._non_changelog_topics.keys():
source_cfg = actual_configs[source_name]
if source_cfg is None:
raise TopicNotFoundError(f"Topic {source_name} not found on the broker")

# For any changelog topics, validate the amount of partitions and
# replication factor match with the source topic
for changelog in self.changelog_topics.get(source_name, {}).values():
changelog_cfg = actual_configs[changelog.name]
if changelog_cfg is None:
raise TopicNotFoundError(
f"Topic {changelog_cfg} not found on the broker"
)

if changelog_cfg.num_partitions != source_cfg.num_partitions:
raise TopicConfigurationMismatch(
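
The manager changes replace silent or-chaining with explicit None checks. A simplified sketch of the resulting lookup-with-fallback pattern (names shortened; the real code consults the admin client and the cached non-changelog topics):

from copy import deepcopy
from typing import Optional

def resolve_config(name: str, broker_configs: dict, cached_topics: dict) -> dict:
    # Prefer the configuration reported by the broker.
    config: Optional[dict] = broker_configs.get(name)
    # Fall back to the locally cached topic configuration, if any.
    if config is None and name in cached_topics:
        config = deepcopy(cached_topics[name])
    # Fail loudly instead of letting a None leak into later attribute access.
    if config is None:
        raise RuntimeError(f"No configuration can be found for topic {name}")
    return config
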