From 56fdf871a2c62fe540a5b092748f9063e48ed181 Mon Sep 17 00:00:00 2001 From: Quentin Dawans Date: Mon, 22 Jul 2024 10:40:33 +0200 Subject: [PATCH] Avro serdes support --- pyproject.toml | 9 +- quixstreams/models/serializers/avro.py | 111 ++++++++++++++++++ tests/requirements.txt | 1 + .../test_models/test_serializers.py | 31 +++++ 4 files changed, 151 insertions(+), 1 deletion(-) create mode 100644 quixstreams/models/serializers/avro.py diff --git a/pyproject.toml b/pyproject.toml index 12c7959e8..b868c6c72 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "quixstreams" -dynamic = ["version", "dependencies", "optional-dependencies"] +dynamic = ["version", "dependencies"] description = "Python library for building stream processing applications with Apache Kafka" license = {file = "LICENSE"} readme = "README.md" @@ -26,6 +26,13 @@ classifiers = [ [project.urls] Homepage = "https://github.com/quixio/quix-streams" +[project.optional-dependencies] +all = [ + "fastavro>=1.8,<2.0" +] + +avro = ["fastavro>=1.8,<2.0"] + [tool.setuptools.packages.find] include = ["quixstreams*"] exclude = ["tests*", "docs*", "examples*"] diff --git a/quixstreams/models/serializers/avro.py b/quixstreams/models/serializers/avro.py new file mode 100644 index 000000000..2387c472c --- /dev/null +++ b/quixstreams/models/serializers/avro.py @@ -0,0 +1,111 @@ +from typing import Union, Mapping, Optional, Any, Iterable + +from io import BytesIO + +from fastavro import schemaless_reader, schemaless_writer, parse_schema +from fastavro.types import Schema + +from .base import Serializer, Deserializer, SerializationContext +from .exceptions import SerializationError + +__all__ = ("AvroSerializer", "AvroDeserializer") + + +class AvroSerializer(Serializer): + def __init__( + self, + schema: Schema, + strict: bool = False, + strict_allow_default: bool = False, + disable_tuple_notation: bool = False, + ): + """ + Serializer that returns data in Avro format. + + For more information see fastavro [schemaless_writer](https://fastavro.readthedocs.io/en/latest/writer.html#fastavro._write_py.schemaless_writer) method. + + :param schema: The avro schema. + :param strict: If set to True, an error will be raised if records do not contain exactly the same fields that the schema states. + Default - `False` + :param strict_allow_default: If set to True, an error will be raised if records do not contain exactly the same fields that the schema states unless it is a missing field that has a default value in the schema. + Default - `False` + :param disable_tuple_notation: If set to True, tuples will not be treated as a special case. Therefore, using a tuple to indicate the type of a record will not work. + Default - `False` + """ + self._schema = parse_schema(schema) + self._strict = strict + self._strict_allow_default = strict_allow_default + self._disable_tuple_notation = disable_tuple_notation + + def __call__(self, value: Any, ctx: SerializationContext) -> bytes: + data = BytesIO() + + with BytesIO() as data: + try: + schemaless_writer( + data, + self._schema, + value, + strict=self._strict, + strict_allow_default=self._strict_allow_default, + disable_tuple_notation=self._disable_tuple_notation, + ) + except ValueError as exc: + raise SerializationError(str(exc)) from exc + + return data.getvalue() + + +class AvroDeserializer(Deserializer): + def __init__( + self, + schema: Schema, + reader_schema: Optional[Schema] = None, + return_record_name: bool = False, + return_record_name_override: bool = False, + return_named_type: bool = False, + return_named_type_override: bool = False, + handle_unicode_errors: str = "strict", + ): + """ + Deserializer that parses data from Avro. + + For more information see fastavro [schemaless_reader](https://fastavro.readthedocs.io/en/latest/reader.html#fastavro._read_py.schemaless_reader) method. + + :param schema: The Avro schema. + :param reader_schema: If the schema has changed since being written then the new schema can be given to allow for schema migration. + Default - `None` + :param return_record_name: If true, when reading a union of records, the result will be a tuple where the first value is the name of the record and the second value is the record itself. + Default - `False` + :param return_record_name_override: If true, this will modify the behavior of return_record_name so that the record name is only returned for unions where there is more than one record. For unions that only have one record, this option will make it so that the record is returned by itself, not a tuple with the name. + Default - `False` + :param return_named_type: If true, when reading a union of named types, the result will be a tuple where the first value is the name of the type and the second value is the record itself NOTE: Using this option will ignore return_record_name and return_record_name_override. + Default - `False` + :param return_named_type_override: If true, this will modify the behavior of return_named_type so that the named type is only returned for unions where there is more than one named type. For unions that only have one named type, this option will make it so that the named type is returned by itself, not a tuple with the name. + Default - `False` + :param handle_unicode_errors: Should be set to a valid string that can be used in the errors argument of the string decode() function. + Default - `"strict"` + """ + super().__init__() + self._schema = parse_schema(schema) + self._reader_schema = parse_schema(reader_schema) if reader_schema else None + self._return_record_name = return_record_name + self._return_record_name_override = return_record_name_override + self._return_named_type = return_named_type + self._return_named_type_override = return_named_type_override + + def __call__( + self, value: bytes, ctx: SerializationContext + ) -> Union[Iterable[Mapping], Mapping]: + try: + return schemaless_reader( + BytesIO(value), + self._schema, + reader_schema=self._reader_schema, + return_record_name=self._return_record_name, + return_record_name_override=self._return_record_name_override, + return_named_type=self._return_named_type, + return_named_type_override=self._return_named_type_override, + ) + except EOFError as exc: + raise SerializationError(str(exc)) from exc diff --git a/tests/requirements.txt b/tests/requirements.txt index ae7babc63..65b64e66c 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -3,3 +3,4 @@ testcontainers==4.5.1; python_version >= '3.9' pytest requests>=2.32 docker>=7.1.0 # Required to use requests>=2.32 +fastavro>=1.8,<2.0 \ No newline at end of file diff --git a/tests/test_quixstreams/test_models/test_serializers.py b/tests/test_quixstreams/test_models/test_serializers.py index ce3012bd4..d3d090a2d 100644 --- a/tests/test_quixstreams/test_models/test_serializers.py +++ b/tests/test_quixstreams/test_models/test_serializers.py @@ -18,6 +18,17 @@ ) from .utils import int_to_bytes, float_to_bytes +from quixstreams.models.serializers.avro import AvroDeserializer, AvroSerializer + +AVRO_TEST_SCHEMA = { + "type": "record", + "name": "testschema", + "fields": [ + {"name": "name", "type": "string"}, + {"name": "id", "type": "int", "default": 0}, + ], +} + dummy_context = SerializationContext(topic="topic") @@ -34,6 +45,12 @@ class TestSerializers: (BytesSerializer(), b"abc", b"abc"), (JSONSerializer(), {"a": 123}, b'{"a":123}'), (JSONSerializer(), [1, 2, 3], b"[1,2,3]"), + ( + AvroSerializer(AVRO_TEST_SCHEMA), + {"name": "foo", "id": 123}, + b"\x06foo\xf6\x01", + ), + (AvroSerializer(AVRO_TEST_SCHEMA), {"name": "foo"}, b"\x06foo\x00"), ], ) def test_serialize_success(self, serializer: Serializer, value, expected): @@ -50,6 +67,9 @@ def test_serialize_success(self, serializer: Serializer, value, expected): (StringSerializer(), {"a": 123}), (JSONSerializer(), object()), (JSONSerializer(), complex(1, 2)), + (AvroSerializer(AVRO_TEST_SCHEMA), {"foo": "foo", "id": 123}), + (AvroSerializer(AVRO_TEST_SCHEMA), {"id": 123}), + (AvroSerializer(AVRO_TEST_SCHEMA, strict=True), {"name": "foo"}), ], ) def test_serialize_error(self, serializer: Serializer, value): @@ -70,6 +90,16 @@ class TestDeserializers: (BytesDeserializer(), b"123123", b"123123"), (JSONDeserializer(), b"123123", 123123), (JSONDeserializer(), b'{"a":"b"}', {"a": "b"}), + ( + AvroDeserializer(AVRO_TEST_SCHEMA), + b"\x06foo\xf6\x01", + {"name": "foo", "id": 123}, + ), + ( + AvroDeserializer(AVRO_TEST_SCHEMA), + b"\x06foo\x00", + {"name": "foo", "id": 0}, + ), ], ) def test_deserialize_no_column_name_success( @@ -84,6 +114,7 @@ def test_deserialize_no_column_name_success( (IntegerDeserializer(), b'{"abc": "abc"}'), (DoubleDeserializer(), b"abc"), (JSONDeserializer(), b"{"), + (AvroDeserializer(AVRO_TEST_SCHEMA), b"\x26foo\x00"), ], ) def test_deserialize_error(self, deserializer: Deserializer, value):