diff --git a/quixstreams/models/serializers/base.py b/quixstreams/models/serializers/base.py index 609de16a8..0a8813667 100644 --- a/quixstreams/models/serializers/base.py +++ b/quixstreams/models/serializers/base.py @@ -45,14 +45,10 @@ def to_confluent_ctx(self, field: MessageField) -> _SerializationContext: class Deserializer(abc.ABC): - def __init__(self, column_name: Optional[str] = None, *args, **kwargs): + def __init__(self, *args, **kwargs): """ A base class for all Deserializers - - :param column_name: if provided, the deserialized value will be wrapped into - dictionary with `column_name` as a key. """ - self.column_name = column_name @property def split_values(self) -> bool: @@ -62,11 +58,6 @@ def split_values(self) -> bool: """ return False - def _to_dict(self, value: Any) -> Union[Any, dict]: - if self.column_name: - return {self.column_name: value} - return value - @abc.abstractmethod def __call__(self, *args, **kwargs) -> Any: ... diff --git a/quixstreams/models/serializers/json.py b/quixstreams/models/serializers/json.py index 187c1cee8..0a5a824e1 100644 --- a/quixstreams/models/serializers/json.py +++ b/quixstreams/models/serializers/json.py @@ -35,25 +35,21 @@ def _to_json(self, value: Any): class JSONDeserializer(Deserializer): def __init__( self, - column_name: Optional[str] = None, loads: Callable[[Union[bytes, bytearray]], Any] = default_loads, ): """ Deserializer that parses data from JSON - :param column_name: if provided, the deserialized value will be wrapped into - dictionary with `column_name` as a key. :param loads: function to parse json from bytes. Default - :py:func:`quixstreams.utils.json.loads`. """ - super().__init__(column_name=column_name) + super().__init__() self._loads = loads def __call__( self, value: bytes, ctx: SerializationContext ) -> Union[Iterable[Mapping], Mapping]: try: - deserialized = self._loads(value) - return self._to_dict(deserialized) + return self._loads(value) except (ValueError, TypeError) as exc: raise SerializationError(str(exc)) from exc diff --git a/quixstreams/models/serializers/quix.py b/quixstreams/models/serializers/quix.py index 2253f1b37..e9081b623 100644 --- a/quixstreams/models/serializers/quix.py +++ b/quixstreams/models/serializers/quix.py @@ -79,16 +79,13 @@ class QuixDeserializer(JSONDeserializer): def __init__( self, - column_name: Optional[str] = None, loads: Callable[[Union[bytes, bytearray]], Any] = default_loads, ): """ - :param column_name: if provided, the deserialized value will be wrapped into - dictionary with `column_name` as a key. :param loads: function to parse json from bytes. Default - :py:func:`quixstreams.utils.json.loads`. """ - super().__init__(column_name=column_name, loads=loads) + super().__init__(loads=loads) self._deserializers = { QModelKey.TIMESERIESDATA: self.deserialize_timeseries, QModelKey.PARAMETERDATA: self.deserialize_timeseries, @@ -148,7 +145,7 @@ def deserialize_timeseries( row_value["Tags"] = {tag: next(values) for tag, values in tags} row_value[Q_TIMESTAMP_KEY] = timestamp_ns - yield self._to_dict(row_value) + yield row_value def deserialize( self, model_key: str, value: Union[List[Mapping], Mapping] @@ -163,11 +160,11 @@ def deserialize( return self._deserializers[model_key](value) def deserialize_event_data(self, value: Mapping) -> Iterable[Mapping]: - yield self._to_dict(self._parse_event_data(value)) + yield self._parse_event_data(value) def deserialize_event_data_list(self, value: List[Mapping]) -> Iterable[Mapping]: for item in value: - yield self._to_dict(self._parse_event_data(item)) + yield self._parse_event_data(item) def _parse_event_data(self, value: Mapping) -> Mapping: if not isinstance(value, Mapping): diff --git a/quixstreams/models/serializers/simple_types.py b/quixstreams/models/serializers/simple_types.py index 846fa2f44..c63ad5f3e 100644 --- a/quixstreams/models/serializers/simple_types.py +++ b/quixstreams/models/serializers/simple_types.py @@ -46,10 +46,8 @@ class BytesDeserializer(Deserializer): A deserializer to bypass bytes without any changes """ - def __call__( - self, value: bytes, ctx: SerializationContext - ) -> Union[bytes, Mapping[str, bytes]]: - return self._to_dict(value) + def __call__(self, value: bytes, ctx: SerializationContext) -> bytes: + return value class BytesSerializer(Serializer): @@ -62,14 +60,14 @@ def __call__(self, value: bytes, ctx: SerializationContext) -> bytes: class StringDeserializer(Deserializer): - def __init__(self, column_name: Optional[str] = None, codec: str = "utf_8"): + def __init__(self, codec: str = "utf_8"): """ Deserializes bytes to strings using the specified encoding. :param codec: string encoding A wrapper around `confluent_kafka.serialization.StringDeserializer`. """ - super().__init__(column_name=column_name) + super().__init__() self._codec = codec self._deserializer = _StringDeserializer(codec=self._codec) @@ -77,8 +75,7 @@ def __init__(self, column_name: Optional[str] = None, codec: str = "utf_8"): def __call__( self, value: bytes, ctx: SerializationContext ) -> Union[str, Mapping[str, str]]: - deserialized = self._deserializer(value=value) - return self._to_dict(deserialized) + return self._deserializer(value=value) class IntegerDeserializer(Deserializer): @@ -88,16 +85,15 @@ class IntegerDeserializer(Deserializer): A wrapper around `confluent_kafka.serialization.IntegerDeserializer`. """ - def __init__(self, column_name: Optional[str] = None): - super().__init__(column_name=column_name) + def __init__(self): + super().__init__() self._deserializer = _IntegerDeserializer() @_wrap_serialization_error def __call__( self, value: bytes, ctx: SerializationContext ) -> Union[int, Mapping[str, int]]: - deserialized = self._deserializer(value=value) - return self._to_dict(deserialized) + return self._deserializer(value=value) class DoubleDeserializer(Deserializer): @@ -107,16 +103,15 @@ class DoubleDeserializer(Deserializer): A wrapper around `confluent_kafka.serialization.DoubleDeserializer`. """ - def __init__(self, column_name: Optional[str] = None): - super().__init__(column_name=column_name) + def __init__(self): + super().__init__() self._deserializer = _DoubleDeserializer() @_wrap_serialization_error def __call__( self, value: bytes, ctx: SerializationContext ) -> Union[float, Mapping[str, float]]: - deserialized = self._deserializer(value=value) - return self._to_dict(deserialized) + return self._deserializer(value=value) class StringSerializer(Serializer): diff --git a/tests/test_quixstreams/test_app.py b/tests/test_quixstreams/test_app.py index 9d83c2465..571a9a310 100644 --- a/tests/test_quixstreams/test_app.py +++ b/tests/test_quixstreams/test_app.py @@ -121,11 +121,10 @@ def on_message_processed(topic_, partition, offset): on_message_processed=on_message_processed, ) - column_name = "root" partition_num = 0 topic_in = app.topic( str(uuid.uuid4()), - value_deserializer=JSONDeserializer(column_name=column_name), + value_deserializer=JSONDeserializer(), ) topic_out = app.topic( str(uuid.uuid4()), @@ -178,7 +177,7 @@ def on_message_processed(topic_, partition, offset): for row in rows_out: assert row.topic == topic_out.name assert row.key == data["key"] - assert row.value == {column_name: loads(data["value"].decode())} + assert row.value == loads(data["value"].decode()) assert row.timestamp == timestamp_ms assert row.headers == headers @@ -240,9 +239,7 @@ def count_and_fail(_): def test_run_consumer_error_raised(self, app_factory, executor): # Set "auto_offset_reset" to "error" to simulate errors in Consumer app = app_factory(auto_offset_reset="error") - topic = app.topic( - str(uuid.uuid4()), value_deserializer=JSONDeserializer(column_name="root") - ) + topic = app.topic(str(uuid.uuid4()), value_deserializer=JSONDeserializer()) sdf = app.dataframe(topic) # Stop app after 10s if nothing failed diff --git a/tests/test_quixstreams/test_models/test_quix_serializers.py b/tests/test_quixstreams/test_models/test_quix_serializers.py index 72427e33b..c0cd3a651 100644 --- a/tests/test_quixstreams/test_models/test_quix_serializers.py +++ b/tests/test_quixstreams/test_models/test_quix_serializers.py @@ -261,66 +261,6 @@ def test_deserialize_timeseries_timestamp_field_clash( ) ) - @pytest.mark.parametrize("as_legacy", [False, True]) - def test_deserialize_timeseries_with_column_name_success( - self, quix_timeseries_factory, as_legacy - ): - message = quix_timeseries_factory( - binary={"param1": [b"1", None], "param2": [None, b"1"]}, - strings={"param3": [1, None], "param4": [None, 1.1]}, - numeric={"param5": ["1", None], "param6": [None, "a"], "param7": ["", ""]}, - tags={"tag1": ["value1", "value2"], "tag2": ["value3", "value4"]}, - timestamps=[1234567890, 1234567891], - as_legacy=as_legacy, - ) - - expected = [ - { - "root": { - "param1": b"1", - "param2": None, - "param3": 1, - "param4": None, - "param5": "1", - "param6": None, - "param7": "", - "Tags": {"tag1": "value1", "tag2": "value3"}, - "Timestamp": 1234567890, - } - }, - { - "root": { - "param1": None, - "param2": b"1", - "param3": None, - "param4": 1.1, - "param5": None, - "param6": "a", - "param7": "", - "Tags": {"tag1": "value2", "tag2": "value4"}, - "Timestamp": 1234567891, - } - }, - ] - - deserializer = QuixDeserializer(column_name="root") - rows = list( - deserializer( - value=message.value(), - ctx=SerializationContext( - topic=message.topic(), - headers=message.headers(), - ), - ) - ) - assert len(rows) == len(expected) - for item, row in zip(expected, rows): - assert "root" in row - value = row["root"] - item = row["root"] - for key in item: - assert item[key] == value[key] - @pytest.mark.parametrize("as_legacy", [False, True]) def test_deserialize_eventdata_success( self, quix_eventdata_factory, quix_eventdata_params_factory, as_legacy @@ -381,45 +321,6 @@ def test_deserialize_eventdata_list_success( assert row["Value"] == params.value assert row["Tags"] == params.tags - @pytest.mark.parametrize("as_legacy", [False, True]) - def test_deserialize_event_data_with_column( - self, - quix_eventdata_list_factory, - quix_eventdata_params_factory, - as_legacy, - ): - event_params = [ - quix_eventdata_params_factory( - id="test", - value={"blabla": 123}, - tags={"tag1": "1"}, - timestamp=1234567790, - ), - quix_eventdata_params_factory( - id="test2", - value={"blabla2": 1234}, - tags={"tag2": "2"}, - timestamp=1234567891, - ), - ] - message = quix_eventdata_list_factory(params=event_params, as_legacy=as_legacy) - - deserializer = QuixDeserializer(column_name="root") - rows = list( - deserializer( - value=message.value(), - ctx=SerializationContext(topic="test", headers=message.headers()), - ) - ) - assert len(rows) == 2 - for row, params in zip(rows, event_params): - assert "root" in row - row = row["root"] - assert row["Timestamp"] == params.timestamp - assert row["Id"] == params.id - assert row["Value"] == params.value - assert row["Tags"] == params.tags - class TestQuixTimeseriesSerializer: def test_serialize_dict_success(self): diff --git a/tests/test_quixstreams/test_models/test_serializers.py b/tests/test_quixstreams/test_models/test_serializers.py index 0f51435d9..ce3012bd4 100644 --- a/tests/test_quixstreams/test_models/test_serializers.py +++ b/tests/test_quixstreams/test_models/test_serializers.py @@ -77,32 +77,6 @@ def test_deserialize_no_column_name_success( ): assert deserializer(value, ctx=dummy_context) == expected - @pytest.mark.parametrize( - "deserializer, value, expected", - [ - ( - IntegerDeserializer("value"), - int_to_bytes(123), - {"value": 123}, - ), - (DoubleDeserializer("value"), float_to_bytes(123), {"value": 123.0}), - (DoubleDeserializer("value"), float_to_bytes(123.123), {"value": 123.123}), - (StringDeserializer("value"), b"abc", {"value": "abc"}), - ( - StringDeserializer("value", codec="cp1251"), - "abc".encode("cp1251"), - {"value": "abc"}, - ), - (BytesDeserializer("value"), b"123123", {"value": b"123123"}), - (JSONDeserializer("value"), b"123123", {"value": 123123}), - (JSONDeserializer("value"), b'{"a":"b"}', {"value": {"a": "b"}}), - ], - ) - def test_deserialize_with_column_name_success( - self, deserializer: Deserializer, value, expected - ): - assert deserializer(value, ctx=dummy_context) == expected - @pytest.mark.parametrize( "deserializer, value", [ diff --git a/tests/test_quixstreams/test_models/test_topics/test_topics.py b/tests/test_quixstreams/test_models/test_topics/test_topics.py index 739759cd9..ef639a0ca 100644 --- a/tests/test_quixstreams/test_models/test_topics/test_topics.py +++ b/tests/test_quixstreams/test_models/test_topics/test_topics.py @@ -40,8 +40,6 @@ def __call__(self, value: bytes, ctx: SerializationContext): deserialized = self._deserializer(value=value) if not deserialized % 3: raise IgnoreMessage("Ignore numbers divisible by 3") - if self.column_name: - return {self.column_name: deserialized} return deserialized @@ -51,11 +49,11 @@ class TestTopic: [ ( IntegerDeserializer(), - IntegerDeserializer("column"), + IntegerDeserializer(), int_to_bytes(1), int_to_bytes(2), 1, - {"column": 2}, + 2, ), ( DoubleDeserializer(), @@ -75,11 +73,11 @@ class TestTopic: ), ( DoubleDeserializer(), - JSONDeserializer(column_name="root"), + JSONDeserializer(), float_to_bytes(1.1), json.dumps({"key": "value"}).encode(), 1.1, - {"root": {"key": "value"}}, + {"key": "value"}, ), ( BytesDeserializer(), @@ -194,13 +192,13 @@ def test_row_list_deserialize_success( def test_row_deserialize_ignorevalueerror_raised(self, topic_manager_topic_factory): topic = topic_manager_topic_factory( - value_deserializer=IgnoreDivisibleBy3Deserializer(column_name="value"), + value_deserializer=IgnoreDivisibleBy3Deserializer(), ) row = topic.row_deserialize( message=ConfluentKafkaMessageStub(key=b"key", value=int_to_bytes(4)) ) assert row - assert row.value == {"value": 4} + assert row.value == 4 row = topic.row_deserialize( message=ConfluentKafkaMessageStub(key=b"key", value=int_to_bytes(3))