Skip to content

Commit

Permalink
Add more comparators methods to Column (#207)
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-quix authored Oct 28, 2023
1 parent b7a3143 commit a2af330
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 33 deletions.
79 changes: 56 additions & 23 deletions src/StreamingDataFrames/streamingdataframes/dataframe/column.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import operator
from typing import Optional, Any, Callable, Union
from typing import Optional, Any, Callable, Container

from typing_extensions import Self, TypeAlias

from ..models import Row

OpValue: TypeAlias = Union[int, float, bool]
ColumnValue: TypeAlias = Union[int, float, bool, list, dict]
ColumnApplier: TypeAlias = Callable[[ColumnValue], OpValue]
ColumnApplier: TypeAlias = Callable[[Any], Any]

__all__ = ("Column", "OpValue", "ColumnValue", "ColumnApplier")
__all__ = ("Column", "ColumnApplier")


def invert(value):
Expand All @@ -27,14 +26,14 @@ def __init__(
self.col_name = col_name
self._eval_func = _eval_func if _eval_func else lambda row: row[self.col_name]

def _operation(self, other: Any, op: Callable[[OpValue, OpValue], OpValue]) -> Self:
return Column(
def _operation(self, other: Any, op: Callable[[Any, Any], Any]) -> Self:
return self.__class__(
_eval_func=lambda x: op(
self.eval(x), other.eval(x) if isinstance(other, Column) else other
),
)

def eval(self, row: Row) -> ColumnValue:
def eval(self, row: Row) -> Any:
"""
Execute all the functions accumulated on this Column.
Expand All @@ -55,44 +54,78 @@ def apply(self, func: ColumnApplier) -> Self:
"""
return Column(_eval_func=lambda x: func(self.eval(x)))

def __and__(self, other):
def isin(self, other: Container) -> Self:
return self._operation(other, lambda a, b: operator.contains(b, a))

def contains(self, other: Any) -> Self:
return self._operation(other, operator.contains)

def is_(self, other: Any) -> Self:
"""
Check if column value refers to the same object as `other`
:param other: object to check for "is"
:return:
"""
return self._operation(other, operator.is_)

def isnot(self, other: Any) -> Self:
"""
Check if column value refers to the same object as `other`
:param other: object to check for "is"
:return:
"""
return self._operation(other, operator.is_not)

def isnull(self) -> Self:
"""
Check if column value is None
"""
return self._operation(None, operator.is_)

def notnull(self) -> Self:
"""
Check if column value is not None
"""
return self._operation(None, operator.is_not)

def __and__(self, other: Any) -> Self:
return self._operation(other, operator.and_)

def __or__(self, other):
def __or__(self, other: Any) -> Self:
return self._operation(other, operator.or_)

def __mod__(self, other):
def __mod__(self, other: Any) -> Self:
return self._operation(other, operator.mod)

def __add__(self, other):
def __add__(self, other: Any) -> Self:
return self._operation(other, operator.add)

def __sub__(self, other):
def __sub__(self, other: Any) -> Self:
return self._operation(other, operator.sub)

def __mul__(self, other):
def __mul__(self, other: Any) -> Self:
return self._operation(other, operator.mul)

def __truediv__(self, other):
def __truediv__(self, other: Any) -> Self:
return self._operation(other, operator.truediv)

def __eq__(self, other):
def __eq__(self, other: Any) -> Self:
return self._operation(other, operator.eq)

def __ne__(self, other):
def __ne__(self, other: Any) -> Self:
return self._operation(other, operator.ne)

def __lt__(self, other):
def __lt__(self, other: Any) -> Self:
return self._operation(other, operator.lt)

def __le__(self, other):
def __le__(self, other: Any) -> Self:
return self._operation(other, operator.le)

def __gt__(self, other):
def __gt__(self, other: Any) -> Self:
return self._operation(other, operator.gt)

def __ge__(self, other):
def __ge__(self, other: Any) -> Self:
return self._operation(other, operator.ge)

def __invert__(self):
return Column(_eval_func=lambda x: invert(self.eval(x)))
def __invert__(self) -> Self:
return self.__class__(_eval_func=lambda x: invert(self.eval(x)))
14 changes: 4 additions & 10 deletions src/StreamingDataFrames/streamingdataframes/dataframe/dataframe.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,9 @@
import uuid
from typing import (
Optional,
Callable,
Union,
List,
Mapping,
)
from typing import Optional, Callable, Union, List, Mapping, Any

from typing_extensions import Self, TypeAlias

from .column import Column, OpValue
from .column import Column
from .exceptions import InvalidApplyResultType
from .pipeline import Pipeline
from ..models import Row, Topic
Expand All @@ -30,7 +24,7 @@ def subset(keys: List[str], row: Row) -> Row:
return row


def setitem(k: str, v: Union[Column, OpValue], row: Row) -> Row:
def setitem(k: str, v: Any, row: Row) -> Row:
row[k] = v.eval(row) if isinstance(v, Column) else v
return row

Expand Down Expand Up @@ -223,7 +217,7 @@ def producer(self) -> RowProducerProto:
def producer(self, producer: RowProducerProto):
self._real_producer = producer

def __setitem__(self, key: str, value: Union[Column, OpValue, str]):
def __setitem__(self, key: str, value: Any):
self._apply(lambda row: setitem(key, value, row))

def __getitem__(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import pytest

from streamingdataframes.dataframe.column import Column


Expand Down Expand Up @@ -211,3 +213,74 @@ def test_invert_bool_from_inequalities(self, row_factory):
result = ~(Column("x") <= Column("y"))
assert isinstance(result, Column)
assert result.eval(msg_value) is False

@pytest.mark.parametrize(
"value, other, expected",
[
({"x": 1}, [1, 2, 3], True),
({"x": 1}, [], False),
({"x": 1}, {1: 456}, True),
],
)
def test_isin(self, row_factory, value, other, expected):
row = row_factory(value)
assert Column("x").isin(other).eval(row) == expected

@pytest.mark.parametrize(
"value, other, expected",
[
({"x": [1, 2, 3]}, 1, True),
({"x": [1, 2, 3]}, 5, False),
({"x": "abc"}, "a", True),
({"x": {"y": "z"}}, "y", True),
],
)
def test_contains(self, row_factory, value, other, expected):
row = row_factory(value)
assert Column("x").contains(other).eval(row) == expected

@pytest.mark.parametrize(
"value, expected",
[
({"x": None}, True),
({"x": [1, 2, 3]}, False),
],
)
def test_isnull(self, row_factory, value, expected):
row = row_factory(value)
assert Column("x").isnull().eval(row) == expected

@pytest.mark.parametrize(
"value, expected",
[
({"x": None}, False),
({"x": [1, 2, 3]}, True),
],
)
def test_notnull(self, row_factory, value, expected):
row = row_factory(value)
assert Column("x").notnull().eval(row) == expected

@pytest.mark.parametrize(
"value, other, expected",
[
({"x": [1, 2, 3]}, None, False),
({"x": None}, None, True),
({"x": 1}, 1, True),
],
)
def test_is_(self, row_factory, value, other, expected):
row = row_factory(value)
assert Column("x").is_(other).eval(row) == expected

@pytest.mark.parametrize(
"value, other, expected",
[
({"x": [1, 2, 3]}, None, True),
({"x": None}, None, False),
({"x": 1}, 1, False),
],
)
def test_isnot(self, row_factory, value, other, expected):
row = row_factory(value)
assert Column("x").isnot(other).eval(row) == expected

0 comments on commit a2af330

Please sign in to comment.