From 5cb56f0ed37c7af0f07527878ebab49867fb9f6c Mon Sep 17 00:00:00 2001 From: Daniil Gusev Date: Thu, 26 Oct 2023 23:26:07 +0200 Subject: [PATCH] Add more comparators methods to Column --- .../streamingdataframes/dataframe/column.py | 79 +++++++++++++------ .../dataframe/dataframe.py | 8 +- .../test_dataframe/test_column.py | 73 +++++++++++++++++ 3 files changed, 133 insertions(+), 27 deletions(-) diff --git a/src/StreamingDataFrames/streamingdataframes/dataframe/column.py b/src/StreamingDataFrames/streamingdataframes/dataframe/column.py index 507b71286..818a2e3c1 100644 --- a/src/StreamingDataFrames/streamingdataframes/dataframe/column.py +++ b/src/StreamingDataFrames/streamingdataframes/dataframe/column.py @@ -1,14 +1,13 @@ import operator -from typing import Optional, Any, Callable, Union +from typing import Optional, Any, Callable, Container + from typing_extensions import Self, TypeAlias from ..models import Row -OpValue: TypeAlias = Union[int, float, bool] -ColumnValue: TypeAlias = Union[int, float, bool, list, dict] -ColumnApplier: TypeAlias = Callable[[ColumnValue], OpValue] +ColumnApplier: TypeAlias = Callable[[Any], Any] -__all__ = ("Column", "OpValue", "ColumnValue", "ColumnApplier") +__all__ = ("Column", "ColumnApplier") def invert(value): @@ -27,14 +26,14 @@ def __init__( self.col_name = col_name self._eval_func = _eval_func if _eval_func else lambda row: row[self.col_name] - def _operation(self, other: Any, op: Callable[[OpValue, OpValue], OpValue]) -> Self: - return Column( + def _operation(self, other: Any, op: Callable[[Any, Any], Any]) -> Self: + return self.__class__( _eval_func=lambda x: op( self.eval(x), other.eval(x) if isinstance(other, Column) else other ), ) - def eval(self, row: Row) -> ColumnValue: + def eval(self, row: Row) -> Any: """ Execute all the functions accumulated on this Column. @@ -55,44 +54,78 @@ def apply(self, func: ColumnApplier) -> Self: """ return Column(_eval_func=lambda x: func(self.eval(x))) - def __and__(self, other): + def isin(self, other: Container) -> Self: + return self._operation(other, lambda a, b: operator.contains(b, a)) + + def contains(self, other: Any) -> Self: + return self._operation(other, operator.contains) + + def is_(self, other: Any) -> Self: + """ + Check if column value refers to the same object as `other` + :param other: object to check for "is" + :return: + """ + return self._operation(other, operator.is_) + + def isnot(self, other: Any) -> Self: + """ + Check if column value refers to the same object as `other` + :param other: object to check for "is" + :return: + """ + return self._operation(other, operator.is_not) + + def isnull(self) -> Self: + """ + Check if column value is None + """ + return self._operation(None, operator.is_) + + def notnull(self) -> Self: + """ + Check if column value is not None + """ + return self._operation(None, operator.is_not) + + def __and__(self, other: Any) -> Self: return self._operation(other, operator.and_) - def __or__(self, other): + def __or__(self, other: Any) -> Self: return self._operation(other, operator.or_) - def __mod__(self, other): + def __mod__(self, other: Any) -> Self: return self._operation(other, operator.mod) - def __add__(self, other): + def __add__(self, other: Any) -> Self: return self._operation(other, operator.add) - def __sub__(self, other): + def __sub__(self, other: Any) -> Self: return self._operation(other, operator.sub) - def __mul__(self, other): + def __mul__(self, other: Any) -> Self: return self._operation(other, operator.mul) - def __truediv__(self, other): + def __truediv__(self, other: Any) -> Self: return self._operation(other, operator.truediv) - def __eq__(self, other): + def __eq__(self, other: Any) -> Self: return self._operation(other, operator.eq) - def __ne__(self, other): + def __ne__(self, other: Any) -> Self: return self._operation(other, operator.ne) - def __lt__(self, other): + def __lt__(self, other: Any) -> Self: return self._operation(other, operator.lt) - def __le__(self, other): + def __le__(self, other: Any) -> Self: return self._operation(other, operator.le) - def __gt__(self, other): + def __gt__(self, other: Any) -> Self: return self._operation(other, operator.gt) - def __ge__(self, other): + def __ge__(self, other: Any) -> Self: return self._operation(other, operator.ge) - def __invert__(self): - return Column(_eval_func=lambda x: invert(self.eval(x))) + def __invert__(self) -> Self: + return self.__class__(_eval_func=lambda x: invert(self.eval(x))) diff --git a/src/StreamingDataFrames/streamingdataframes/dataframe/dataframe.py b/src/StreamingDataFrames/streamingdataframes/dataframe/dataframe.py index 73926722d..b26987fd3 100644 --- a/src/StreamingDataFrames/streamingdataframes/dataframe/dataframe.py +++ b/src/StreamingDataFrames/streamingdataframes/dataframe/dataframe.py @@ -1,9 +1,9 @@ import uuid +from typing import Optional, Callable, Union, List, Mapping, Any -from typing import Optional, Callable, Union, List, Mapping from typing_extensions import Self, TypeAlias -from .column import Column, OpValue +from .column import Column from .exceptions import InvalidApplyResultType from .pipeline import Pipeline from ..models import Row, Topic @@ -20,7 +20,7 @@ def subset(keys: List[str], row: Row) -> Row: return row -def setitem(k: str, v: Union[Column, OpValue], row: Row) -> Row: +def setitem(k: str, v: Any, row: Row) -> Row: row[k] = v.eval(row) if isinstance(v, Column) else v return row @@ -189,7 +189,7 @@ def producer(self) -> RowProducerProto: def producer(self, producer: RowProducerProto): self._real_producer = producer - def __setitem__(self, key: str, value: Union[Column, OpValue, str]): + def __setitem__(self, key: str, value: Any): self._apply(lambda row: setitem(key, value, row)) def __getitem__( diff --git a/src/StreamingDataFrames/tests/test_dataframes/test_dataframe/test_column.py b/src/StreamingDataFrames/tests/test_dataframes/test_dataframe/test_column.py index dc01b2a6e..ce7eda96a 100644 --- a/src/StreamingDataFrames/tests/test_dataframes/test_dataframe/test_column.py +++ b/src/StreamingDataFrames/tests/test_dataframes/test_dataframe/test_column.py @@ -1,3 +1,5 @@ +import pytest + from streamingdataframes.dataframe.column import Column @@ -211,3 +213,74 @@ def test_invert_bool_from_inequalities(self, row_factory): result = ~(Column("x") <= Column("y")) assert isinstance(result, Column) assert result.eval(msg_value) is False + + @pytest.mark.parametrize( + "value, other, expected", + [ + ({"x": 1}, [1, 2, 3], True), + ({"x": 1}, [], False), + ({"x": 1}, {1: 456}, True), + ], + ) + def test_isin(self, row_factory, value, other, expected): + row = row_factory(value) + assert Column("x").isin(other).eval(row) == expected + + @pytest.mark.parametrize( + "value, other, expected", + [ + ({"x": [1, 2, 3]}, 1, True), + ({"x": [1, 2, 3]}, 5, False), + ({"x": "abc"}, "a", True), + ({"x": {"y": "z"}}, "y", True), + ], + ) + def test_contains(self, row_factory, value, other, expected): + row = row_factory(value) + assert Column("x").contains(other).eval(row) == expected + + @pytest.mark.parametrize( + "value, expected", + [ + ({"x": None}, True), + ({"x": [1, 2, 3]}, False), + ], + ) + def test_isnull(self, row_factory, value, expected): + row = row_factory(value) + assert Column("x").isnull().eval(row) == expected + + @pytest.mark.parametrize( + "value, expected", + [ + ({"x": None}, False), + ({"x": [1, 2, 3]}, True), + ], + ) + def test_notnull(self, row_factory, value, expected): + row = row_factory(value) + assert Column("x").notnull().eval(row) == expected + + @pytest.mark.parametrize( + "value, other, expected", + [ + ({"x": [1, 2, 3]}, None, False), + ({"x": None}, None, True), + ({"x": 1}, 1, True), + ], + ) + def test_is_(self, row_factory, value, other, expected): + row = row_factory(value) + assert Column("x").is_(other).eval(row) == expected + + @pytest.mark.parametrize( + "value, other, expected", + [ + ({"x": [1, 2, 3]}, None, True), + ({"x": None}, None, False), + ({"x": 1}, 1, False), + ], + ) + def test_isnot(self, row_factory, value, other, expected): + row = row_factory(value) + assert Column("x").isnot(other).eval(row) == expected