diff --git a/pyproject.toml b/pyproject.toml index 12c7959e8..8b19808eb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "quixstreams" -dynamic = ["version", "dependencies", "optional-dependencies"] +dynamic = ["version", "dependencies"] description = "Python library for building stream processing applications with Apache Kafka" license = {file = "LICENSE"} readme = "README.md" @@ -23,6 +23,15 @@ classifiers = [ "Programming Language :: Python :: 3", ] +[project.optional-dependencies] +all = [ + "protobuf" +] + +protobuf = [ + "protobuf" +] + [project.urls] Homepage = "https://github.com/quixio/quix-streams" diff --git a/quixstreams/models/serializers/protobuf.py b/quixstreams/models/serializers/protobuf.py new file mode 100644 index 000000000..ab8496a87 --- /dev/null +++ b/quixstreams/models/serializers/protobuf.py @@ -0,0 +1,56 @@ +from typing import Union, Mapping, Iterable, Dict + +from .base import Serializer, Deserializer, SerializationContext +from .exceptions import SerializationError + +from google.protobuf.message import Message, DecodeError, EncodeError +from google.protobuf.json_format import MessageToDict, ParseDict, ParseError + +__all__ = ("ProtobufSerializer", "ProtobufDeserializer") + + +class ProtobufSerializer(Serializer): + def __init__( + self, + msg_type: Message, + ): + """ + Serializer that returns data in protobuf format. + + :param msg_type: protobuf message class. + """ + super().__init__() + self._msg_type = msg_type + + def __call__(self, value: Dict, ctx: SerializationContext) -> Union[str, bytes]: + msg = self._msg_type() + + try: + return ParseDict(value, msg).SerializeToString(deterministic=True) + except (EncodeError, ParseError) as exc: + raise SerializationError(str(exc)) from exc + + +class ProtobufDeserializer(Deserializer): + def __init__( + self, + msg_type: Message, + ): + """ + Deserializer that parses protobuf data. + + :param msg_type: protobuf message class. + """ + super().__init__() + self._msg_type = msg_type + + def __call__( + self, value: bytes, ctx: SerializationContext + ) -> Union[Iterable[Mapping], Mapping]: + msg = self._msg_type() + + try: + msg.ParseFromString(value) + return MessageToDict(msg) + except DecodeError as exc: + raise SerializationError(str(exc)) from exc diff --git a/tests/requirements.txt b/tests/requirements.txt index ae7babc63..d5b42b16f 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -3,3 +3,4 @@ testcontainers==4.5.1; python_version >= '3.9' pytest requests>=2.32 docker>=7.1.0 # Required to use requests>=2.32 +protobuf>=5.27.2 diff --git a/tests/test_quixstreams/test_models/protobuf/test.proto b/tests/test_quixstreams/test_models/protobuf/test.proto new file mode 100644 index 000000000..d74a61084 --- /dev/null +++ b/tests/test_quixstreams/test_models/protobuf/test.proto @@ -0,0 +1,8 @@ +syntax = "proto3"; + +package test; + +message Test { + string name = 1; + int32 id = 2; +} diff --git a/tests/test_quixstreams/test_models/protobuf/test_pb2.py b/tests/test_quixstreams/test_models/protobuf/test_pb2.py new file mode 100644 index 000000000..e2930418c --- /dev/null +++ b/tests/test_quixstreams/test_models/protobuf/test_pb2.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: tests/test_quixstreams/test_models/protobuf/test.proto +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database + +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n6tests/test_quixstreams/test_models/protobuf/test.proto\x12\x04test" \n\x04Test\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\n\n\x02id\x18\x02 \x01(\x05\x62\x06proto3' +) + + +_TEST = DESCRIPTOR.message_types_by_name["Test"] +Test = _reflection.GeneratedProtocolMessageType( + "Test", + (_message.Message,), + { + "DESCRIPTOR": _TEST, + "__module__": "tests.test_quixstreams.test_models.protobuf.test_pb2", + # @@protoc_insertion_point(class_scope:test.Test) + }, +) +_sym_db.RegisterMessage(Test) + +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + _TEST._serialized_start = 64 + _TEST._serialized_end = 96 +# @@protoc_insertion_point(module_scope) diff --git a/tests/test_quixstreams/test_models/test_serializers.py b/tests/test_quixstreams/test_models/test_serializers.py index ce3012bd4..1d0f5e88e 100644 --- a/tests/test_quixstreams/test_models/test_serializers.py +++ b/tests/test_quixstreams/test_models/test_serializers.py @@ -16,8 +16,14 @@ DoubleDeserializer, StringDeserializer, ) +from quixstreams.models.serializers.protobuf import ( + ProtobufSerializer, + ProtobufDeserializer, +) from .utils import int_to_bytes, float_to_bytes +from .protobuf.test_pb2 import Test + dummy_context = SerializationContext(topic="topic") @@ -34,6 +40,10 @@ class TestSerializers: (BytesSerializer(), b"abc", b"abc"), (JSONSerializer(), {"a": 123}, b'{"a":123}'), (JSONSerializer(), [1, 2, 3], b"[1,2,3]"), + (ProtobufSerializer(Test), {}, b""), + (ProtobufSerializer(Test), {"id": 3}, b"\x10\x03"), + (ProtobufSerializer(Test), {"name": "foo", "id": 2}, b"\n\x03foo\x10\x02"), + (ProtobufSerializer(Test), {"name": "foo"}, b"\n\x03foo"), ], ) def test_serialize_success(self, serializer: Serializer, value, expected): @@ -50,6 +60,7 @@ def test_serialize_success(self, serializer: Serializer, value, expected): (StringSerializer(), {"a": 123}), (JSONSerializer(), object()), (JSONSerializer(), complex(1, 2)), + (ProtobufSerializer(Test), {"bar": 3}), ], ) def test_serialize_error(self, serializer: Serializer, value): @@ -70,6 +81,14 @@ class TestDeserializers: (BytesDeserializer(), b"123123", b"123123"), (JSONDeserializer(), b"123123", 123123), (JSONDeserializer(), b'{"a":"b"}', {"a": "b"}), + ( + ProtobufDeserializer(Test), + b"\n\x03foo\x10\x02", + {"name": "foo", "id": 2}, + ), + (ProtobufDeserializer(Test), b"\n\x03foo", {"name": "foo"}), + (ProtobufDeserializer(Test), b"\x10\x03", {"id": 3}), + (ProtobufDeserializer(Test), b"", {}), ], ) def test_deserialize_no_column_name_success( @@ -84,6 +103,7 @@ def test_deserialize_no_column_name_success( (IntegerDeserializer(), b'{"abc": "abc"}'), (DoubleDeserializer(), b"abc"), (JSONDeserializer(), b"{"), + (ProtobufDeserializer(Test), b"\n\x03foo\x10\x02\x13"), ], ) def test_deserialize_error(self, deserializer: Deserializer, value):