From 67669558b7aad66d40aac6a42623b1c3a0d002fa Mon Sep 17 00:00:00 2001 From: Cory Grinstead Date: Tue, 17 Sep 2024 10:54:52 -0500 Subject: [PATCH] [FEAT]: Dataframe.filter method (#2853) closes https://github.com/Eventual-Inc/Daft/issues/2846 Note, this does not rename 'where' to 'filter' but instead just adds an alias. We can revisit at a later date if we want to deprecate the `where` method. --- daft/dataframe/dataframe.py | 17 +++++++++++++++++ docs/source/api_docs/dataframe.rst | 1 + tests/dataframe/test_filter.py | 9 +++++++++ 3 files changed, 27 insertions(+) diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py index dd281f314f..9ea9fff5f9 100644 --- a/daft/dataframe/dataframe.py +++ b/daft/dataframe/dataframe.py @@ -1310,6 +1310,23 @@ def exclude(self, *names: str) -> "DataFrame": builder = self._builder.exclude(list(names)) return DataFrame(builder) + @DataframePublicAPI + def filter(self, predicate: Union[Expression, str]) -> "DataFrame": + """Filters rows via a predicate expression, similar to SQL ``WHERE``. + + Alias for daft.DataFrame.where. + + .. seealso:: + :meth:`.where(predicate) <DataFrame.where>` + + Args: + predicate (Expression): expression that keeps row if evaluates to True. + + Returns: + DataFrame: Filtered DataFrame. + """ + return self.where(predicate) + @DataframePublicAPI def where(self, predicate: Union[Expression, str]) -> "DataFrame": """Filters rows via a predicate expression, similar to SQL ``WHERE``. 
diff --git a/docs/source/api_docs/dataframe.rst b/docs/source/api_docs/dataframe.rst index 065e539e25..f93f052742 100644 --- a/docs/source/api_docs/dataframe.rst +++ b/docs/source/api_docs/dataframe.rst @@ -58,6 +58,7 @@ Filtering Rows :toctree: doc_gen/dataframe_methods DataFrame.distinct + DataFrame.filter DataFrame.where DataFrame.limit DataFrame.sample diff --git a/tests/dataframe/test_filter.py b/tests/dataframe/test_filter.py index ed8d8fb33f..3844b7ed84 100644 --- a/tests/dataframe/test_filter.py +++ b/tests/dataframe/test_filter.py @@ -41,3 +41,12 @@ def test_filter_sql() -> None: expected = {"x": [3], "y": [6], "z": [9]} assert df == expected + + +def test_filter_alias_for_where() -> None: + df = daft.from_pydict({"x": [1, 2, 3], "y": [4, 5, 6], "z": [7, 9, 9]}) + + expected = df.where("z = 9 AND y > 5").collect().to_pydict() + actual = df.filter("z = 9 AND y > 5").collect().to_pydict() + + assert actual == expected