mypy: make quixstreams.models.* pass type checks
quentin-quix committed Dec 4, 2024
1 parent 01de03e commit ed463c7
Showing 12 changed files with 117 additions and 74 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -114,14 +114,14 @@ ignore_missing_imports = true
 module = [
     "quixstreams.sinks.community.*",
     "quixstreams.sources.community.*",
+    "quixstreams.models.serializers.quix.*",
 ]
 ignore_errors = true
 
 [[tool.mypy.overrides]]
 module = [
     "quixstreams.core.*",
     "quixstreams.dataframe.*",
-    "quixstreams.models.*",
     "quixstreams.platforms.*",
     "quixstreams.rowproducer.*"
 ]
4 changes: 2 additions & 2 deletions quixstreams/models/messagecontext.py
@@ -21,7 +21,7 @@ class MessageContext:
     def __init__(
         self,
         topic: str,
-        partition: int,
+        partition: Optional[int],
         offset: int,
         size: int,
         leader_epoch: Optional[int] = None,
@@ -37,7 +37,7 @@ def topic(self) -> str:
         return self._topic
 
     @property
-    def partition(self) -> int:
+    def partition(self) -> Optional[int]:
         return self._partition
 
     @property
2 changes: 1 addition & 1 deletion quixstreams/models/rows.py
@@ -32,7 +32,7 @@ def topic(self) -> str:
         return self.context.topic
 
     @property
-    def partition(self) -> int:
+    def partition(self) -> Optional[int]:
         return self.context.partition
 
     @property
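Widening `partition` to `Optional[int]` on both `MessageContext` and `Row` matches confluent-kafka's own stubs, where `Message.partition()` returns `Optional[int]`. A minimal sketch of the state the new annotation legitimizes (values are illustrative):

    from quixstreams.models.messagecontext import MessageContext

    # confluent-kafka types Message.partition() as Optional[int], so a context
    # built from a raw message may legitimately carry partition=None.
    ctx = MessageContext(topic="input", partition=None, offset=0, size=42)
    assert ctx.partition is None  # valid under the Optional[int] annotation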
38 changes: 23 additions & 15 deletions quixstreams/models/serializers/avro.py
@@ -13,7 +13,7 @@
 from fastavro import parse_schema, schemaless_reader, schemaless_writer
 from fastavro.types import Schema
 
-from .base import Deserializer, SerializationContext, Serializer
+from .base import DeserializationContext, Deserializer, SerializationContext, Serializer
 from .exceptions import SerializationError
 from .schema_registry import (
     SchemaRegistryClientConfig,
@@ -148,6 +148,12 @@ def __init__(
         )
 
         super().__init__()
+
+        if schema is None and schema_registry_client_config is None:
+            raise TypeError(
+                "One of `schema` or `schema_registry_client_config` is required"
+            )
+
         self._schema = parse_schema(schema) if schema else None
         self._reader_schema = parse_schema(reader_schema) if reader_schema else None
         self._return_record_name = return_record_name
@@ -167,24 +173,26 @@ def __init__(
         )
 
     def __call__(
-        self, value: bytes, ctx: SerializationContext
+        self, value: bytes, ctx: DeserializationContext
     ) -> Union[Iterable[Mapping], Mapping]:
         if self._schema_registry_deserializer is not None:
             try:
                 return self._schema_registry_deserializer(value, ctx)
             except (SchemaRegistryError, _SerializationError, EOFError) as exc:
                 raise SerializationError(str(exc)) from exc
+        elif self._schema is not None:
+            try:
+                return schemaless_reader(  # type: ignore
+                    BytesIO(value),
+                    self._schema,
+                    reader_schema=self._reader_schema,
+                    return_record_name=self._return_record_name,
+                    return_record_name_override=self._return_record_name_override,
+                    return_named_type=self._return_named_type,
+                    return_named_type_override=self._return_named_type_override,
+                    handle_unicode_errors=self._handle_unicode_errors,
+                )
+            except EOFError as exc:
+                raise SerializationError(str(exc)) from exc
 
-        try:
-            return schemaless_reader(
-                BytesIO(value),
-                self._schema,
-                reader_schema=self._reader_schema,
-                return_record_name=self._return_record_name,
-                return_record_name_override=self._return_record_name_override,
-                return_named_type=self._return_named_type,
-                return_named_type_override=self._return_named_type_override,
-                handle_unicode_errors=self._handle_unicode_errors,
-            )
-        except EOFError as exc:
-            raise SerializationError(str(exc)) from exc
+        raise SerializationError("no schema found")
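With the new constructor guard, a misconfigured `AvroDeserializer` fails at construction time instead of at the first message. A short usage sketch (the record schema is illustrative):

    from quixstreams.models.serializers.avro import AvroDeserializer

    # Neither `schema` nor `schema_registry_client_config`: fails fast.
    try:
        AvroDeserializer()
    except TypeError as exc:
        print(exc)  # One of `schema` or `schema_registry_client_config` is required

    # With a local schema, values decode via fastavro's schemaless_reader.
    deserializer = AvroDeserializer(
        schema={
            "type": "record",
            "name": "Example",
            "fields": [{"name": "value", "type": "int"}],
        }
    )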
33 changes: 27 additions & 6 deletions quixstreams/models/serializers/base.py
@@ -1,5 +1,5 @@
 import abc
-from typing import Any, Union
+from typing import Any, Optional, Union
 
 from confluent_kafka.serialization import (
     MessageField,
@@ -9,10 +9,11 @@
 )
 from typing_extensions import Literal, TypeAlias
 
-from ..types import HeadersMapping, KafkaHeaders
+from ..types import Headers, HeadersMapping, KafkaHeaders
 
 __all__ = (
     "SerializationContext",
+    "DeserializationContext",
     "MessageField",
     "Deserializer",
     "Serializer",
@@ -23,18 +24,38 @@
 
 class SerializationContext(_SerializationContext):
     """
-    Provides additional context for message serialization/deserialization.
+    Provides additional context for message serialization.
 
-    Every `Serializer` and `Deserializer` receives an instance of `SerializationContext`
+    Every `Serializer` receives an instance of `SerializationContext`
     """
 
     __slots__ = ("topic", "field", "headers")
 
     def __init__(
         self,
         topic: str,
-        field: MessageField,
-        headers: KafkaHeaders = None,
+        field: str,
+        headers: Optional[Headers] = None,
     ) -> None:
         self.topic = topic
         self.field = field
         self.headers = headers
+
+
+class DeserializationContext(_SerializationContext):
+    """
+    Provides additional context for message deserialization.
+
+    Every `Deserializer` receives an instance of `DeserializationContext`
+    """
+
+    __slots__ = ("topic", "field", "headers")
+
+    def __init__(
+        self,
+        topic: str,
+        field: str,
+        headers: KafkaHeaders,
+    ) -> None:
+        self.topic = topic
+        self.field = field
+        self.headers = headers
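Splitting the context class in two lets mypy model each direction honestly: serializers accept user-supplied headers (any `Headers` shape, or none), while deserializers always receive whatever `KafkaHeaders` the consumed message carried. A brief sketch of both, with illustrative topic and header values:

    from confluent_kafka.serialization import MessageField

    from quixstreams.models.serializers.base import (
        DeserializationContext,
        SerializationContext,
    )

    # Producer side: headers are optional.
    ser_ctx = SerializationContext(topic="input", field=MessageField.VALUE)

    # Consumer side: headers arrive in confluent-kafka's
    # list-of-tuples form (or None) straight off the message.
    deser_ctx = DeserializationContext(
        topic="input",
        field=MessageField.VALUE,
        headers=[("source", b"sensor-1")],
    )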
14 changes: 8 additions & 6 deletions quixstreams/models/serializers/json.py
@@ -1,5 +1,5 @@
 import json
-from typing import Any, Callable, Iterable, Mapping, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, Optional, Union
 
 from confluent_kafka.schema_registry import SchemaRegistryClient, SchemaRegistryError
 from confluent_kafka.schema_registry.json_schema import (
@@ -10,7 +10,6 @@
 )
 from confluent_kafka.serialization import SerializationError as _SerializationError
 from jsonschema import Draft202012Validator, ValidationError
-from jsonschema.protocols import Validator
 
 from quixstreams.utils.json import (
     dumps as default_dumps,
@@ -19,13 +18,16 @@
     loads as default_loads,
 )
 
-from .base import Deserializer, SerializationContext, Serializer
+from .base import DeserializationContext, Deserializer, SerializationContext, Serializer
 from .exceptions import SerializationError
 from .schema_registry import (
     SchemaRegistryClientConfig,
     SchemaRegistrySerializationConfig,
 )
 
+if TYPE_CHECKING:
+    from jsonschema.validators import _Validator
+
 __all__ = ("JSONSerializer", "JSONDeserializer")
@@ -34,7 +36,7 @@ def __init__(
         self,
         dumps: Callable[[Any], Union[str, bytes]] = default_dumps,
         schema: Optional[Mapping] = None,
-        validator: Optional[Validator] = None,
+        validator: Optional["_Validator"] = None,
         schema_registry_client_config: Optional[SchemaRegistryClientConfig] = None,
         schema_registry_serialization_config: Optional[
             SchemaRegistrySerializationConfig
@@ -121,7 +123,7 @@ def __init__(
         self,
         loads: Callable[[Union[bytes, bytearray]], Any] = default_loads,
         schema: Optional[Mapping] = None,
-        validator: Optional[Validator] = None,
+        validator: Optional["_Validator"] = None,
         schema_registry_client_config: Optional[SchemaRegistryClientConfig] = None,
     ):
         """
@@ -164,7 +166,7 @@ def __init__(
         )
 
     def __call__(
-        self, value: bytes, ctx: SerializationContext
+        self, value: bytes, ctx: DeserializationContext
     ) -> Union[Iterable[Mapping], Mapping]:
         if self._schema_registry_deserializer is not None:
             try:
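`_Validator` is a private name in `jsonschema`, so it is imported only under `TYPE_CHECKING` and referenced via string annotations; nothing changes for callers at runtime, who keep passing validator instances. A short usage sketch (the schema is illustrative):

    from jsonschema import Draft202012Validator

    from quixstreams.models.serializers.json import JSONDeserializer

    schema = {"type": "object", "properties": {"id": {"type": "integer"}}}
    deserializer = JSONDeserializer(validator=Draft202012Validator(schema))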
10 changes: 5 additions & 5 deletions quixstreams/models/serializers/protobuf.py
@@ -1,4 +1,4 @@
-from typing import Dict, Iterable, Mapping, Optional, Union
+from typing import Dict, Iterable, Mapping, Optional, Type, Union
 
 from confluent_kafka.schema_registry import SchemaRegistryClient, SchemaRegistryError
 from confluent_kafka.schema_registry.protobuf import (
@@ -11,7 +11,7 @@
 from google.protobuf.json_format import MessageToDict, ParseDict, ParseError
 from google.protobuf.message import DecodeError, EncodeError, Message
 
-from .base import Deserializer, SerializationContext, Serializer
+from .base import DeserializationContext, Deserializer, SerializationContext, Serializer
 from .exceptions import SerializationError
 from .schema_registry import (
     SchemaRegistryClientConfig,
@@ -24,7 +24,7 @@
 class ProtobufSerializer(Serializer):
     def __init__(
         self,
-        msg_type: Message,
+        msg_type: Type[Message],
         deterministic: bool = False,
         ignore_unknown_fields: bool = False,
         schema_registry_client_config: Optional[SchemaRegistryClientConfig] = None,
@@ -110,7 +110,7 @@ def __call__(
 class ProtobufDeserializer(Deserializer):
     def __init__(
         self,
-        msg_type: Message,
+        msg_type: Type[Message],
         use_integers_for_enums: bool = False,
         preserving_proto_field_name: bool = False,
         to_dict: bool = True,
@@ -168,7 +168,7 @@ def __init__(
         )
 
     def __call__(
-        self, value: bytes, ctx: SerializationContext
+        self, value: bytes, ctx: DeserializationContext
     ) -> Union[Iterable[Mapping], Mapping, Message]:
         if self._schema_registry_deserializer is not None:
             try:
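`Type[Message]` documents what callers were already doing: passing the protoc-generated message class itself, not an instance. A sketch assuming a hypothetical compiled module (`example_pb2` is not part of this diff):

    from quixstreams.models.serializers.protobuf import (
        ProtobufDeserializer,
        ProtobufSerializer,
    )

    # `example_pb2` is a hypothetical protoc-generated module, assumed here.
    from example_pb2 import ExampleRecord

    serializer = ProtobufSerializer(msg_type=ExampleRecord)  # the class, not ExampleRecord()
    deserializer = ProtobufDeserializer(msg_type=ExampleRecord)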
28 changes: 3 additions & 25 deletions quixstreams/models/topics/admin.py
@@ -7,10 +7,7 @@
 from confluent_kafka.admin import (
     AdminClient,
     ConfigResource,
-    KafkaException,  # type: ignore
-)
-from confluent_kafka.admin import (
-    NewTopic as ConfluentTopic,  # type: ignore
+    KafkaException,
 )
 from confluent_kafka.admin import (
     TopicMetadata as ConfluentTopicMetadata,
@@ -26,25 +23,6 @@
 __all__ = ("TopicAdmin",)
 
 
-def convert_topic_list(topics: List[Topic]) -> List[ConfluentTopic]:
-    """
-    Converts `Topic`s to `ConfluentTopic`s as required for Confluent's
-    `AdminClient.create_topic()`.
-
-    :param topics: list of `Topic`s
-    :return: list of confluent_kafka `ConfluentTopic`s
-    """
-    return [
-        ConfluentTopic(
-            topic=topic.name,
-            num_partitions=topic.config.num_partitions,
-            replication_factor=topic.config.replication_factor,
-            config=topic.config.extra_config,
-        )
-        for topic in topics
-    ]
-
-
 def confluent_topic_config(topic: str) -> ConfigResource:
     return ConfigResource(2, topic)
@@ -207,12 +185,12 @@ def create_topics(
         for topic in topics_to_create:
             logger.info(
                 f'Creating a new topic "{topic.name}" '
-                f'with config: "{topic.config.as_dict()}"'
+                f'with config: "{topic.config.as_dict() if topic.config is not None else {}}"'
             )
 
         self._finalize_create(
             self.admin_client.create_topics(
-                convert_topic_list(topics_to_create),
+                [topic.as_confluent_topic() for topic in topics_to_create],
                 request_timeout=timeout,
            ),
             finalize_timeout=finalize_timeout,
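The deleted `convert_topic_list()` helper gives way to a `Topic.as_confluent_topic()` method, which keeps the `Optional` config handling next to the data it guards. That method is not shown in this diff; a hedged sketch of what it plausibly does, modeled on the deleted helper:

    from confluent_kafka.admin import NewTopic

    # Assumed shape only: the real method lives on Topic elsewhere in the codebase.
    def as_confluent_topic(self) -> NewTopic:
        if self.config is None:
            raise RuntimeError(f"No configuration found for topic {self.name}")
        return NewTopic(
            topic=self.name,
            num_partitions=self.config.num_partitions,
            replication_factor=self.config.replication_factor,
            config=self.config.extra_config,
        )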
12 changes: 11 additions & 1 deletion quixstreams/models/topics/manager.py
@@ -40,7 +40,7 @@ class TopicManager:
     # Default topic params
     default_num_partitions = 1
     default_replication_factor = 1
-    default_extra_config = {}
+    default_extra_config: dict[str, str] = {}
 
     # Max topic name length for the new topics
     _max_topic_name_len = 255
@@ -211,6 +211,9 @@ def _get_source_topic_config(
             topic_name
         ] or deepcopy(self._non_changelog_topics[topic_name].config)
 
+        if topic_config is None:
+            raise RuntimeError(f"No configuration can be found for topic {topic_name}")
+
         # Copy only certain configuration values from original topic
         if extras_imports:
             topic_config.extra_config = {
@@ -475,10 +478,17 @@ def validate_all_topics(self, timeout: Optional[float] = None):
         for source_name in self._non_changelog_topics.keys():
             source_cfg = actual_configs[source_name]
+            if source_cfg is None:
+                raise TopicNotFoundError(f"Topic {source_name} not found on the broker")
 
             # For any changelog topics, validate the amount of partitions and
             # replication factor match with the source topic
             for changelog in self.changelog_topics.get(source_name, {}).values():
                 changelog_cfg = actual_configs[changelog.name]
+                if changelog_cfg is None:
+                    raise TopicNotFoundError(
+                        f"Topic {changelog_cfg} not found on the broker"
+                    )
 
                 if changelog_cfg.num_partitions != source_cfg.num_partitions:
                     raise TopicConfigurationMismatch(
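Both new guards follow one pattern: an explicit `is None` check raises a precise error at runtime and, as a side effect, narrows the `Optional` type so mypy accepts the attribute access that follows. Distilled into a standalone sketch (the `require` helper is hypothetical):

    from typing import Optional, TypeVar

    T = TypeVar("T")

    def require(name: str, value: Optional[T]) -> T:
        # After this check, mypy narrows `value` from Optional[T] to T.
        if value is None:
            raise RuntimeError(f"Topic {name} not found on the broker")
        return value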
