Skip to content

Commit

Permalink
[FEAT]: Dataframe.filter method (#2853)
Browse files Browse the repository at this point in the history
closes #2846

Note, this does not rename 'where' to 'filter' but instead just adds an
alias. We can revisit at a later date if we want to deprecate the
`where` method.
  • Loading branch information
universalmind303 authored Sep 17, 2024
1 parent 07e92f6 commit 6766955
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 0 deletions.
17 changes: 17 additions & 0 deletions daft/dataframe/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1310,6 +1310,23 @@ def exclude(self, *names: str) -> "DataFrame":
builder = self._builder.exclude(list(names))
return DataFrame(builder)

@DataframePublicAPI
def filter(self, predicate: Union[Expression, str]) -> "DataFrame":
"""Filters rows via a predicate expression, similar to SQL ``WHERE``.
Alias for daft.DataFrame.where.
.. seealso::
:meth:`.where(predicate) <DataFrame.where>`
Args:
predicate (Expression): expression that keeps row if evaluates to True.
Returns:
DataFrame: Filtered DataFrame.
"""
return self.where(predicate)

@DataframePublicAPI
def where(self, predicate: Union[Expression, str]) -> "DataFrame":
"""Filters rows via a predicate expression, similar to SQL ``WHERE``.
Expand Down
1 change: 1 addition & 0 deletions docs/source/api_docs/dataframe.rst
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ Filtering Rows
:toctree: doc_gen/dataframe_methods

DataFrame.distinct
DataFrame.filter
DataFrame.where
DataFrame.limit
DataFrame.sample
Expand Down
9 changes: 9 additions & 0 deletions tests/dataframe/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,12 @@ def test_filter_sql() -> None:
expected = {"x": [3], "y": [6], "z": [9]}

assert df == expected


def test_filter_alias_for_where() -> None:
df = daft.from_pydict({"x": [1, 2, 3], "y": [4, 5, 6], "z": [7, 9, 9]})

expected = df.where("z = 9 AND y > 5").collect().to_pydict()
actual = df.filter("z = 9 AND y > 5").collect().to_pydict()

assert actual == expected

0 comments on commit 6766955

Please sign in to comment.