Skip to content

Commit

Permalink
remove everything related to generating column_names during deserialization
Browse files Browse the repository at this point in the history
  • Loading branch information
tim-quix committed Jun 28, 2024
1 parent 835ef27 commit ea3d071
Show file tree
Hide file tree
Showing 8 changed files with 27 additions and 178 deletions.
11 changes: 1 addition & 10 deletions quixstreams/models/serializers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,10 @@ def to_confluent_ctx(self, field: MessageField) -> _SerializationContext:


class Deserializer(abc.ABC):
def __init__(self, column_name: Optional[str] = None, *args, **kwargs):
def __init__(self, *args, **kwargs):
"""
A base class for all Deserializers
:param column_name: if provided, the deserialized value will be wrapped into
dictionary with `column_name` as a key.
"""
self.column_name = column_name

@property
def split_values(self) -> bool:
Expand All @@ -62,11 +58,6 @@ def split_values(self) -> bool:
"""
return False

def _to_dict(self, value: Any) -> Union[Any, dict]:
if self.column_name:
return {self.column_name: value}
return value

@abc.abstractmethod
def __call__(self, *args, **kwargs) -> Any: ...

Expand Down
8 changes: 2 additions & 6 deletions quixstreams/models/serializers/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,25 +35,21 @@ def _to_json(self, value: Any):
class JSONDeserializer(Deserializer):
def __init__(
self,
column_name: Optional[str] = None,
loads: Callable[[Union[bytes, bytearray]], Any] = default_loads,
):
"""
Deserializer that parses data from JSON
:param column_name: if provided, the deserialized value will be wrapped into
dictionary with `column_name` as a key.
:param loads: function to parse json from bytes.
Default - :py:func:`quixstreams.utils.json.loads`.
"""
super().__init__(column_name=column_name)
super().__init__()
self._loads = loads

def __call__(
self, value: bytes, ctx: SerializationContext
) -> Union[Iterable[Mapping], Mapping]:
try:
deserialized = self._loads(value)
return self._to_dict(deserialized)
return self._loads(value)
except (ValueError, TypeError) as exc:
raise SerializationError(str(exc)) from exc
11 changes: 4 additions & 7 deletions quixstreams/models/serializers/quix.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,16 +79,13 @@ class QuixDeserializer(JSONDeserializer):

def __init__(
self,
column_name: Optional[str] = None,
loads: Callable[[Union[bytes, bytearray]], Any] = default_loads,
):
"""
:param column_name: if provided, the deserialized value will be wrapped into
dictionary with `column_name` as a key.
:param loads: function to parse json from bytes.
Default - :py:func:`quixstreams.utils.json.loads`.
"""
super().__init__(column_name=column_name, loads=loads)
super().__init__(loads=loads)
self._deserializers = {
QModelKey.TIMESERIESDATA: self.deserialize_timeseries,
QModelKey.PARAMETERDATA: self.deserialize_timeseries,
Expand Down Expand Up @@ -148,7 +145,7 @@ def deserialize_timeseries(
row_value["Tags"] = {tag: next(values) for tag, values in tags}

row_value[Q_TIMESTAMP_KEY] = timestamp_ns
yield self._to_dict(row_value)
yield row_value

def deserialize(
self, model_key: str, value: Union[List[Mapping], Mapping]
Expand All @@ -163,11 +160,11 @@ def deserialize(
return self._deserializers[model_key](value)

def deserialize_event_data(self, value: Mapping) -> Iterable[Mapping]:
yield self._to_dict(self._parse_event_data(value))
yield self._parse_event_data(value)

def deserialize_event_data_list(self, value: List[Mapping]) -> Iterable[Mapping]:
for item in value:
yield self._to_dict(self._parse_event_data(item))
yield self._parse_event_data(item)

def _parse_event_data(self, value: Mapping) -> Mapping:
if not isinstance(value, Mapping):
Expand Down
27 changes: 11 additions & 16 deletions quixstreams/models/serializers/simple_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,8 @@ class BytesDeserializer(Deserializer):
A deserializer to bypass bytes without any changes
"""

def __call__(
self, value: bytes, ctx: SerializationContext
) -> Union[bytes, Mapping[str, bytes]]:
return self._to_dict(value)
def __call__(self, value: bytes, ctx: SerializationContext) -> bytes:
return value


class BytesSerializer(Serializer):
Expand All @@ -62,23 +60,22 @@ def __call__(self, value: bytes, ctx: SerializationContext) -> bytes:


class StringDeserializer(Deserializer):
def __init__(self, column_name: Optional[str] = None, codec: str = "utf_8"):
def __init__(self, codec: str = "utf_8"):
"""
Deserializes bytes to strings using the specified encoding.
:param codec: string encoding
A wrapper around `confluent_kafka.serialization.StringDeserializer`.
"""
super().__init__(column_name=column_name)
super().__init__()
self._codec = codec
self._deserializer = _StringDeserializer(codec=self._codec)

@_wrap_serialization_error
def __call__(
self, value: bytes, ctx: SerializationContext
) -> Union[str, Mapping[str, str]]:
deserialized = self._deserializer(value=value)
return self._to_dict(deserialized)
return self._deserializer(value=value)


class IntegerDeserializer(Deserializer):
Expand All @@ -88,16 +85,15 @@ class IntegerDeserializer(Deserializer):
A wrapper around `confluent_kafka.serialization.IntegerDeserializer`.
"""

def __init__(self, column_name: Optional[str] = None):
super().__init__(column_name=column_name)
def __init__(self):
super().__init__()
self._deserializer = _IntegerDeserializer()

@_wrap_serialization_error
def __call__(
self, value: bytes, ctx: SerializationContext
) -> Union[int, Mapping[str, int]]:
deserialized = self._deserializer(value=value)
return self._to_dict(deserialized)
return self._deserializer(value=value)


class DoubleDeserializer(Deserializer):
Expand All @@ -107,16 +103,15 @@ class DoubleDeserializer(Deserializer):
A wrapper around `confluent_kafka.serialization.DoubleDeserializer`.
"""

def __init__(self, column_name: Optional[str] = None):
super().__init__(column_name=column_name)
def __init__(self):
super().__init__()
self._deserializer = _DoubleDeserializer()

@_wrap_serialization_error
def __call__(
self, value: bytes, ctx: SerializationContext
) -> Union[float, Mapping[str, float]]:
deserialized = self._deserializer(value=value)
return self._to_dict(deserialized)
return self._deserializer(value=value)


class StringSerializer(Serializer):
Expand Down
9 changes: 3 additions & 6 deletions tests/test_quixstreams/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,10 @@ def on_message_processed(topic_, partition, offset):
on_message_processed=on_message_processed,
)

column_name = "root"
partition_num = 0
topic_in = app.topic(
str(uuid.uuid4()),
value_deserializer=JSONDeserializer(column_name=column_name),
value_deserializer=JSONDeserializer(),
)
topic_out = app.topic(
str(uuid.uuid4()),
Expand Down Expand Up @@ -178,7 +177,7 @@ def on_message_processed(topic_, partition, offset):
for row in rows_out:
assert row.topic == topic_out.name
assert row.key == data["key"]
assert row.value == {column_name: loads(data["value"].decode())}
assert row.value == loads(data["value"].decode())
assert row.timestamp == timestamp_ms
assert row.headers == headers

Expand Down Expand Up @@ -240,9 +239,7 @@ def count_and_fail(_):
def test_run_consumer_error_raised(self, app_factory, executor):
# Set "auto_offset_reset" to "error" to simulate errors in Consumer
app = app_factory(auto_offset_reset="error")
topic = app.topic(
str(uuid.uuid4()), value_deserializer=JSONDeserializer(column_name="root")
)
topic = app.topic(str(uuid.uuid4()), value_deserializer=JSONDeserializer())
sdf = app.dataframe(topic)

# Stop app after 10s if nothing failed
Expand Down
99 changes: 0 additions & 99 deletions tests/test_quixstreams/test_models/test_quix_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,66 +261,6 @@ def test_deserialize_timeseries_timestamp_field_clash(
)
)

@pytest.mark.parametrize("as_legacy", [False, True])
def test_deserialize_timeseries_with_column_name_success(
self, quix_timeseries_factory, as_legacy
):
message = quix_timeseries_factory(
binary={"param1": [b"1", None], "param2": [None, b"1"]},
strings={"param3": [1, None], "param4": [None, 1.1]},
numeric={"param5": ["1", None], "param6": [None, "a"], "param7": ["", ""]},
tags={"tag1": ["value1", "value2"], "tag2": ["value3", "value4"]},
timestamps=[1234567890, 1234567891],
as_legacy=as_legacy,
)

expected = [
{
"root": {
"param1": b"1",
"param2": None,
"param3": 1,
"param4": None,
"param5": "1",
"param6": None,
"param7": "",
"Tags": {"tag1": "value1", "tag2": "value3"},
"Timestamp": 1234567890,
}
},
{
"root": {
"param1": None,
"param2": b"1",
"param3": None,
"param4": 1.1,
"param5": None,
"param6": "a",
"param7": "",
"Tags": {"tag1": "value2", "tag2": "value4"},
"Timestamp": 1234567891,
}
},
]

deserializer = QuixDeserializer(column_name="root")
rows = list(
deserializer(
value=message.value(),
ctx=SerializationContext(
topic=message.topic(),
headers=message.headers(),
),
)
)
assert len(rows) == len(expected)
for item, row in zip(expected, rows):
assert "root" in row
value = row["root"]
item = row["root"]
for key in item:
assert item[key] == value[key]

@pytest.mark.parametrize("as_legacy", [False, True])
def test_deserialize_eventdata_success(
self, quix_eventdata_factory, quix_eventdata_params_factory, as_legacy
Expand Down Expand Up @@ -381,45 +321,6 @@ def test_deserialize_eventdata_list_success(
assert row["Value"] == params.value
assert row["Tags"] == params.tags

@pytest.mark.parametrize("as_legacy", [False, True])
def test_deserialize_event_data_with_column(
self,
quix_eventdata_list_factory,
quix_eventdata_params_factory,
as_legacy,
):
event_params = [
quix_eventdata_params_factory(
id="test",
value={"blabla": 123},
tags={"tag1": "1"},
timestamp=1234567790,
),
quix_eventdata_params_factory(
id="test2",
value={"blabla2": 1234},
tags={"tag2": "2"},
timestamp=1234567891,
),
]
message = quix_eventdata_list_factory(params=event_params, as_legacy=as_legacy)

deserializer = QuixDeserializer(column_name="root")
rows = list(
deserializer(
value=message.value(),
ctx=SerializationContext(topic="test", headers=message.headers()),
)
)
assert len(rows) == 2
for row, params in zip(rows, event_params):
assert "root" in row
row = row["root"]
assert row["Timestamp"] == params.timestamp
assert row["Id"] == params.id
assert row["Value"] == params.value
assert row["Tags"] == params.tags


class TestQuixTimeseriesSerializer:
def test_serialize_dict_success(self):
Expand Down
26 changes: 0 additions & 26 deletions tests/test_quixstreams/test_models/test_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,32 +77,6 @@ def test_deserialize_no_column_name_success(
):
assert deserializer(value, ctx=dummy_context) == expected

@pytest.mark.parametrize(
"deserializer, value, expected",
[
(
IntegerDeserializer("value"),
int_to_bytes(123),
{"value": 123},
),
(DoubleDeserializer("value"), float_to_bytes(123), {"value": 123.0}),
(DoubleDeserializer("value"), float_to_bytes(123.123), {"value": 123.123}),
(StringDeserializer("value"), b"abc", {"value": "abc"}),
(
StringDeserializer("value", codec="cp1251"),
"abc".encode("cp1251"),
{"value": "abc"},
),
(BytesDeserializer("value"), b"123123", {"value": b"123123"}),
(JSONDeserializer("value"), b"123123", {"value": 123123}),
(JSONDeserializer("value"), b'{"a":"b"}', {"value": {"a": "b"}}),
],
)
def test_deserialize_with_column_name_success(
self, deserializer: Deserializer, value, expected
):
assert deserializer(value, ctx=dummy_context) == expected

@pytest.mark.parametrize(
"deserializer, value",
[
Expand Down
Loading

0 comments on commit ea3d071

Please sign in to comment.