diff --git a/pyproject.toml b/pyproject.toml
index 4f4cd5f62..c6beab872 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -114,6 +114,7 @@ ignore_missing_imports = true
 module = [
     "quixstreams.sinks.community.*",
     "quixstreams.sources.community.*",
+    "quixstreams.models.serializers.quix.*",
 ]
 ignore_errors = true
 
@@ -121,7 +122,6 @@ ignore_errors = true
 module = [
     "quixstreams.core.*",
     "quixstreams.dataframe.*",
-    "quixstreams.models.*",
     "quixstreams.platforms.*",
     "quixstreams.rowproducer.*"
 ]
diff --git a/quixstreams/models/messagecontext.py b/quixstreams/models/messagecontext.py
index 351fe9157..0a09f305b 100644
--- a/quixstreams/models/messagecontext.py
+++ b/quixstreams/models/messagecontext.py
@@ -21,7 +21,7 @@ class MessageContext:
     def __init__(
         self,
         topic: str,
-        partition: int,
+        partition: Optional[int],
         offset: int,
         size: int,
         leader_epoch: Optional[int] = None,
@@ -37,7 +37,7 @@ def topic(self) -> str:
         return self._topic
 
     @property
-    def partition(self) -> int:
+    def partition(self) -> Optional[int]:
         return self._partition
 
     @property
diff --git a/quixstreams/models/rows.py b/quixstreams/models/rows.py
index 7633a9a33..73741ea38 100644
--- a/quixstreams/models/rows.py
+++ b/quixstreams/models/rows.py
@@ -32,7 +32,7 @@ def topic(self) -> str:
         return self.context.topic
 
     @property
-    def partition(self) -> int:
+    def partition(self) -> Optional[int]:
         return self.context.partition
 
     @property
diff --git a/quixstreams/models/serializers/avro.py b/quixstreams/models/serializers/avro.py
index 712b64a03..87cd9da16 100644
--- a/quixstreams/models/serializers/avro.py
+++ b/quixstreams/models/serializers/avro.py
@@ -13,7 +13,7 @@
 from fastavro import parse_schema, schemaless_reader, schemaless_writer
 from fastavro.types import Schema
 
-from .base import Deserializer, SerializationContext, Serializer
+from .base import DeserializationContext, Deserializer, SerializationContext, Serializer
 from .exceptions import SerializationError
 from .schema_registry import (
     SchemaRegistryClientConfig,
@@ -148,6 +148,12 @@ def __init__(
         )
 
         super().__init__()
+
+        if schema is None and schema_registry_client_config is None:
+            raise TypeError(
+                "One of `schema` or `schema_registry_client_config` is required"
+            )
+
         self._schema = parse_schema(schema) if schema else None
         self._reader_schema = parse_schema(reader_schema) if reader_schema else None
         self._return_record_name = return_record_name
@@ -167,24 +173,26 @@ def __init__(
         )
 
     def __call__(
-        self, value: bytes, ctx: SerializationContext
+        self, value: bytes, ctx: DeserializationContext
     ) -> Union[Iterable[Mapping], Mapping]:
         if self._schema_registry_deserializer is not None:
             try:
                 return self._schema_registry_deserializer(value, ctx)
             except (SchemaRegistryError, _SerializationError, EOFError) as exc:
                 raise SerializationError(str(exc)) from exc
+        elif self._schema is not None:
+            try:
+                return schemaless_reader(  # type: ignore
+                    BytesIO(value),
+                    self._schema,
+                    reader_schema=self._reader_schema,
+                    return_record_name=self._return_record_name,
+                    return_record_name_override=self._return_record_name_override,
+                    return_named_type=self._return_named_type,
+                    return_named_type_override=self._return_named_type_override,
+                    handle_unicode_errors=self._handle_unicode_errors,
+                )
+            except EOFError as exc:
+                raise SerializationError(str(exc)) from exc
-
-        try:
-            return schemaless_reader(
-                BytesIO(value),
-                self._schema,
-                reader_schema=self._reader_schema,
-                return_record_name=self._return_record_name,
-                return_record_name_override=self._return_record_name_override,
-                return_named_type=self._return_named_type,
-                return_named_type_override=self._return_named_type_override,
-                handle_unicode_errors=self._handle_unicode_errors,
-            )
-        except EOFError as exc:
-            raise SerializationError(str(exc)) from exc
+
+        raise SerializationError("no schema found")
diff --git a/quixstreams/models/serializers/base.py b/quixstreams/models/serializers/base.py
index 64a49f414..326362736 100644
--- a/quixstreams/models/serializers/base.py
+++ b/quixstreams/models/serializers/base.py
@@ -1,5 +1,5 @@
 import abc
-from typing import Any, Union
+from typing import Any, Optional, Union
 
 from confluent_kafka.serialization import (
     MessageField,
@@ -9,10 +9,11 @@
 )
 from typing_extensions import Literal, TypeAlias
 
-from ..types import HeadersMapping, KafkaHeaders
+from ..types import Headers, HeadersMapping, KafkaHeaders
 
 __all__ = (
     "SerializationContext",
+    "DeserializationContext",
     "MessageField",
     "Deserializer",
     "Serializer",
@@ -23,9 +24,9 @@
 
 class SerializationContext(_SerializationContext):
     """
-    Provides additional context for message serialization/deserialization.
+    Provides additional context for message serialization.
 
-    Every `Serializer` and `Deserializer` receives an instance of `SerializationContext`
+    Every `Serializer` receives an instance of `SerializationContext`
     """
 
     __slots__ = ("topic", "field", "headers")
@@ -33,8 +34,28 @@ class SerializationContext(_SerializationContext):
     def __init__(
         self,
         topic: str,
-        field: MessageField,
-        headers: KafkaHeaders = None,
+        field: str,
+        headers: Optional[Headers] = None,
+    ) -> None:
+        self.topic = topic
+        self.field = field
+        self.headers = headers
+
+
+class DeserializationContext(_SerializationContext):
+    """
+    Provides additional context for message deserialization.
+
+    Every `Deserializer` receives an instance of `DeserializationContext`
+    """
+
+    __slots__ = ("topic", "field", "headers")
+
+    def __init__(
+        self,
+        topic: str,
+        field: str,
+        headers: KafkaHeaders,
     ) -> None:
         self.topic = topic
         self.field = field
diff --git a/quixstreams/models/serializers/json.py b/quixstreams/models/serializers/json.py
index 1646217a2..f6290b17f 100644
--- a/quixstreams/models/serializers/json.py
+++ b/quixstreams/models/serializers/json.py
@@ -1,5 +1,5 @@
 import json
-from typing import Any, Callable, Iterable, Mapping, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, Optional, Union
 
 from confluent_kafka.schema_registry import SchemaRegistryClient, SchemaRegistryError
 from confluent_kafka.schema_registry.json_schema import (
@@ -10,7 +10,6 @@
 )
 from confluent_kafka.serialization import SerializationError as _SerializationError
 from jsonschema import Draft202012Validator, ValidationError
-from jsonschema.protocols import Validator
 
 from quixstreams.utils.json import (
     dumps as default_dumps,
@@ -19,13 +18,16 @@
     loads as default_loads,
 )
 
-from .base import Deserializer, SerializationContext, Serializer
+from .base import DeserializationContext, Deserializer, SerializationContext, Serializer
 from .exceptions import SerializationError
 from .schema_registry import (
     SchemaRegistryClientConfig,
     SchemaRegistrySerializationConfig,
 )
 
+if TYPE_CHECKING:
+    from jsonschema.validators import _Validator
+
 __all__ = ("JSONSerializer", "JSONDeserializer")
 
 
@@ -34,7 +36,7 @@ def __init__(
         self,
         dumps: Callable[[Any], Union[str, bytes]] = default_dumps,
         schema: Optional[Mapping] = None,
-        validator: Optional[Validator] = None,
+        validator: Optional["_Validator"] = None,
         schema_registry_client_config: Optional[SchemaRegistryClientConfig] = None,
         schema_registry_serialization_config: Optional[
             SchemaRegistrySerializationConfig
@@ -121,7 +123,7 @@ def __init__(
         self,
         loads: Callable[[Union[bytes, bytearray]], Any] = default_loads,
         schema: Optional[Mapping] = None,
-        validator: Optional[Validator] = None,
+        validator: Optional["_Validator"] = None,
         schema_registry_client_config: Optional[SchemaRegistryClientConfig] = None,
     ):
         """
@@ -164,7 +166,7 @@
         )
 
     def __call__(
-        self, value: bytes, ctx: SerializationContext
+        self, value: bytes, ctx: DeserializationContext
     ) -> Union[Iterable[Mapping], Mapping]:
         if self._schema_registry_deserializer is not None:
             try:
diff --git a/quixstreams/models/serializers/protobuf.py b/quixstreams/models/serializers/protobuf.py
index bc5c5646e..78781bdf5 100644
--- a/quixstreams/models/serializers/protobuf.py
+++ b/quixstreams/models/serializers/protobuf.py
@@ -1,4 +1,4 @@
-from typing import Dict, Iterable, Mapping, Optional, Union
+from typing import Dict, Iterable, Mapping, Optional, Type, Union
 
 from confluent_kafka.schema_registry import SchemaRegistryClient, SchemaRegistryError
 from confluent_kafka.schema_registry.protobuf import (
@@ -11,7 +11,7 @@
 from google.protobuf.json_format import MessageToDict, ParseDict, ParseError
 from google.protobuf.message import DecodeError, EncodeError, Message
 
-from .base import Deserializer, SerializationContext, Serializer
+from .base import DeserializationContext, Deserializer, SerializationContext, Serializer
 from .exceptions import SerializationError
 from .schema_registry import (
     SchemaRegistryClientConfig,
@@ -24,7 +24,7 @@
 class ProtobufSerializer(Serializer):
     def __init__(
         self,
-        msg_type: Message,
+        msg_type: Type[Message],
         deterministic: bool = False,
         ignore_unknown_fields: bool = False,
         schema_registry_client_config: Optional[SchemaRegistryClientConfig] = None,
@@ -110,7 +110,7 @@
 class ProtobufDeserializer(Deserializer):
     def __init__(
         self,
-        msg_type: Message,
+        msg_type: Type[Message],
         use_integers_for_enums: bool = False,
         preserving_proto_field_name: bool = False,
         to_dict: bool = True,
@@ -168,7 +168,7 @@
         )
 
     def __call__(
-        self, value: bytes, ctx: SerializationContext
+        self, value: bytes, ctx: DeserializationContext
     ) -> Union[Iterable[Mapping], Mapping, Message]:
         if self._schema_registry_deserializer is not None:
             try:
diff --git a/quixstreams/models/topics/admin.py b/quixstreams/models/topics/admin.py
index 7719a05a8..17bbc3708 100644
--- a/quixstreams/models/topics/admin.py
+++ b/quixstreams/models/topics/admin.py
@@ -7,10 +7,7 @@
 from confluent_kafka.admin import (
     AdminClient,
     ConfigResource,
-    KafkaException,  # type: ignore
-)
-from confluent_kafka.admin import (
-    NewTopic as ConfluentTopic,  # type: ignore
+    KafkaException,
 )
 from confluent_kafka.admin import (
     TopicMetadata as ConfluentTopicMetadata,
@@ -26,25 +23,6 @@
 __all__ = ("TopicAdmin",)
 
 
-def convert_topic_list(topics: List[Topic]) -> List[ConfluentTopic]:
-    """
-    Converts `Topic`s to `ConfluentTopic`s as required for Confluent's
-    `AdminClient.create_topic()`.
-
-    :param topics: list of `Topic`s
-    :return: list of confluent_kafka `ConfluentTopic`s
-    """
-    return [
-        ConfluentTopic(
-            topic=topic.name,
-            num_partitions=topic.config.num_partitions,
-            replication_factor=topic.config.replication_factor,
-            config=topic.config.extra_config,
-        )
-        for topic in topics
-    ]
-
-
 def confluent_topic_config(topic: str) -> ConfigResource:
     return ConfigResource(2, topic)
 
@@ -207,12 +185,12 @@ def create_topics(
         for topic in topics_to_create:
             logger.info(
                 f'Creating a new topic "{topic.name}" '
-                f'with config: "{topic.config.as_dict()}"'
+                f'with config: "{topic.config.as_dict() if topic.config is not None else {}}"'
             )
 
         self._finalize_create(
             self.admin_client.create_topics(
-                convert_topic_list(topics_to_create),
+                [topic.as_confluent_topic() for topic in topics_to_create],
                 request_timeout=timeout,
             ),
             finalize_timeout=finalize_timeout,
diff --git a/quixstreams/models/topics/manager.py b/quixstreams/models/topics/manager.py
index fbaba21d5..9ef58df33 100644
--- a/quixstreams/models/topics/manager.py
+++ b/quixstreams/models/topics/manager.py
@@ -40,7 +40,7 @@ class TopicManager:
     # Default topic params
     default_num_partitions = 1
     default_replication_factor = 1
-    default_extra_config = {}
+    default_extra_config: dict[str, str] = {}
 
     # Max topic name length for the new topics
     _max_topic_name_len = 255
@@ -211,6 +211,9 @@ def _get_source_topic_config(
             topic_name
         ] or deepcopy(self._non_changelog_topics[topic_name].config)
 
+        if topic_config is None:
+            raise RuntimeError(f"No configuration can be found for topic {topic_name}")
+
         # Copy only certain configuration values from original topic
         if extras_imports:
             topic_config.extra_config = {
@@ -475,10 +478,17 @@ def validate_all_topics(self, timeout: Optional[float] = None):
         for source_name in self._non_changelog_topics.keys():
             source_cfg = actual_configs[source_name]
 
+            if source_cfg is None:
+                raise TopicNotFoundError(f"Topic {source_name} not found on the broker")
+
             # For any changelog topics, validate the amount of partitions and
             # replication factor match with the source topic
             for changelog in self.changelog_topics.get(source_name, {}).values():
                 changelog_cfg = actual_configs[changelog.name]
+                if changelog_cfg is None:
+                    raise TopicNotFoundError(
+                        f"Topic {changelog_cfg} not found on the broker"
+                    )
 
                 if changelog_cfg.num_partitions != source_cfg.num_partitions:
                     raise TopicConfigurationMismatch(
diff --git a/quixstreams/models/topics/topic.py b/quixstreams/models/topics/topic.py
index f07f7d56e..6d3b82a95 100644
--- a/quixstreams/models/topics/topic.py
+++ b/quixstreams/models/topics/topic.py
@@ -2,6 +2,8 @@
 import logging
 from typing import Any, Callable, List, Optional, Union
 
+from confluent_kafka.admin import NewTopic as ConfluentTopic
+
 from quixstreams.models.messagecontext import MessageContext
 from quixstreams.models.messages import KafkaMessage
 from quixstreams.models.rows import Row
@@ -10,6 +12,7 @@
     SERIALIZERS,
     BytesDeserializer,
     BytesSerializer,
+    DeserializationContext,
     Deserializer,
     DeserializerIsNotProvidedError,
     DeserializerType,
@@ -48,7 +51,7 @@ class TopicConfig:
 
     num_partitions: int
     replication_factor: int
-    extra_config: dict = dataclasses.field(default_factory=dict)
+    extra_config: dict[str, str] = dataclasses.field(default_factory=dict)
 
     def as_dict(self):
         return dataclasses.asdict(self)
@@ -137,6 +140,23 @@ def __clone__(
             timestamp_extractor=timestamp_extractor or self._timestamp_extractor,
         )
 
+    def as_confluent_topic(self) -> ConfluentTopic:
+        """
+        Converts `Topic`s to `NewTopic`s as required for Confluent's
+        `AdminClient.create_topic()`.
+
+        :return: confluent_kafka `NewTopic`s
+        """
+        if self.config is None:
+            return ConfluentTopic(topic=self.name)
+
+        return ConfluentTopic(
+            topic=self.name,
+            num_partitions=self.config.num_partitions,
+            replication_factor=self.config.replication_factor,
+            config=self.config.extra_config,
+        )
+
     def row_serialize(self, row: Row, key: Any) -> KafkaMessage:
         """
         Serialize Row to a Kafka message structure
@@ -200,7 +220,7 @@ def row_deserialize(
         if (key_bytes := message.key()) is None:
             key_deserialized = None
         else:
-            key_ctx = SerializationContext(
+            key_ctx = DeserializationContext(
                 topic=message.topic(), field=MessageField.KEY, headers=headers
             )
             key_deserialized = self._key_deserializer(value=key_bytes, ctx=key_ctx)
@@ -208,7 +228,7 @@
         if (value_bytes := message.value()) is None:
             value_deserialized = None
         else:
-            value_ctx = SerializationContext(
+            value_ctx = DeserializationContext(
                 topic=message.topic(), field=MessageField.VALUE, headers=headers
             )
             try:
@@ -223,7 +243,7 @@
                     message.partition(),
                     message.offset(),
                 )
-                return
+                return None
 
         timestamp_type, timestamp_ms = message.timestamp()
         message_context = MessageContext(
@@ -234,7 +254,7 @@
             leader_epoch=message.leader_epoch(),
         )
 
-        if self._value_deserializer.split_values:
+        if value_deserialized is not None and self._value_deserializer.split_values:
             # The expected value from this serializer is Iterable and each item
             # should be processed as a separate message
             rows = []
@@ -302,7 +322,7 @@ def serialize(
     def deserialize(self, message: ConfluentKafkaMessageProto):
         if (key := message.key()) is not None:
             if self._key_deserializer:
-                key_ctx = SerializationContext(
+                key_ctx = DeserializationContext(
                     topic=message.topic(),
                     field=MessageField.KEY,
                     headers=message.headers(),
@@ -314,7 +334,7 @@ def deserialize(self, message: ConfluentKafkaMessageProto):
                 )
         if (value := message.value()) is not None:
             if self._value_deserializer:
-                value_ctx = SerializationContext(
+                value_ctx = DeserializationContext(
                     topic=message.topic(),
                     field=MessageField.VALUE,
                     headers=message.headers(),
diff --git a/quixstreams/models/topics/utils.py b/quixstreams/models/topics/utils.py
index d3d8452cf..4064a4740 100644
--- a/quixstreams/models/topics/utils.py
+++ b/quixstreams/models/topics/utils.py
@@ -1,5 +1,8 @@
+from typing import List
+
 from quixstreams.models.types import (
     HeadersMapping,
+    HeadersTuple,
     HeadersTuples,
     KafkaHeaders,
 )
@@ -25,12 +28,12 @@ def merge_headers(original: KafkaHeaders, other: HeadersMapping) -> HeadersTuple
     # Make a shallow copy of "other" to pop keys from it
     other = other.copy()
 
-    new_headers = []
+    new_headers: List[HeadersTuple] = []
     # Iterate over original headers and put them to a new list with values from
     # the "other" dict if the key is found
     for header, value in original:
         if header in other:
-            value = other.pop(header)
+            continue
         new_headers.append((header, value))
     # Append the new headers to the list
     new_headers.extend(other.items())
diff --git a/quixstreams/models/types.py b/quixstreams/models/types.py
index f34816521..f4b7e9659 100644
--- a/quixstreams/models/types.py
+++ b/quixstreams/models/types.py
@@ -1,4 +1,4 @@
-from typing import List, Mapping, Optional, Sequence, Tuple, Union
+from typing import List, Optional, Sequence, Tuple, Union
 
 from typing_extensions import Protocol
 
@@ -6,8 +6,9 @@
 MessageValue = Union[str, bytes]
 
 HeadersValue = Union[str, bytes]
-HeadersMapping = Mapping[str, HeadersValue]
-HeadersTuples = Sequence[Tuple[str, HeadersValue]]
+HeadersMapping = dict[str, HeadersValue]
+HeadersTuple = Tuple[str, HeadersValue]
+HeadersTuples = Sequence[HeadersTuple]
 Headers = Union[HeadersTuples, HeadersMapping]
 
 KafkaHeaders = Optional[List[Tuple[str, bytes]]]