From fecfc7dac353a7e4308fc8f47cfe44055e0033d7 Mon Sep 17 00:00:00 2001
From: Andrew Gazelka
Date: Tue, 19 Nov 2024 23:15:27 -0800
Subject: [PATCH] [FEAT] connect: add `df.filter`

---
 .../src/translation/logical_plan.rs           |  8 +++++--
 .../src/translation/logical_plan/filter.rs    | 22 +++++++++++++++++++
 tests/connect/test_filter.py                  | 19 ++++++++++++++++
 3 files changed, 47 insertions(+), 2 deletions(-)
 create mode 100644 src/daft-connect/src/translation/logical_plan/filter.rs
 create mode 100644 tests/connect/test_filter.py

diff --git a/src/daft-connect/src/translation/logical_plan.rs b/src/daft-connect/src/translation/logical_plan.rs
index c05ca41e20..a425c600ea 100644
--- a/src/daft-connect/src/translation/logical_plan.rs
+++ b/src/daft-connect/src/translation/logical_plan.rs
@@ -5,11 +5,12 @@ use spark_connect::{relation::RelType, Limit, Relation};
 use tracing::warn;
 
 use crate::translation::logical_plan::{
-    aggregate::aggregate, local_relation::local_relation, project::project, range::range,
-    read::read, to_df::to_df, with_columns::with_columns,
+    aggregate::aggregate, filter::filter, local_relation::local_relation, project::project,
+    range::range, read::read, to_df::to_df, with_columns::with_columns,
 };
 
 mod aggregate;
+mod filter;
 mod local_relation;
 mod project;
 mod range;
@@ -59,6 +60,9 @@ pub async fn to_logical_plan(relation: Relation) -> eyre::Result<Plan> {
         RelType::Project(p) => project(*p)
             .await
             .wrap_err("Failed to apply project to logical plan"),
+        RelType::Filter(f) => filter(*f)
+            .await
+            .wrap_err("Failed to apply filter to logical plan"),
         RelType::Aggregate(a) => aggregate(*a)
             .await
             .wrap_err("Failed to apply aggregate to logical plan"),
diff --git a/src/daft-connect/src/translation/logical_plan/filter.rs b/src/daft-connect/src/translation/logical_plan/filter.rs
new file mode 100644
index 0000000000..6879464abc
--- /dev/null
+++ b/src/daft-connect/src/translation/logical_plan/filter.rs
@@ -0,0 +1,22 @@
+use eyre::bail;
+
+use crate::translation::{to_daft_expr, to_logical_plan, Plan};
+
+pub async fn filter(filter: spark_connect::Filter) -> eyre::Result<Plan> {
+    let spark_connect::Filter { input, condition } = filter;
+
+    let Some(input) = input else {
+        bail!("input is required");
+    };
+
+    let Some(condition) = condition else {
+        bail!("condition is required");
+    };
+
+    let condition = to_daft_expr(&condition)?;
+
+    let mut plan = Box::pin(to_logical_plan(*input)).await?;
+    plan.builder = plan.builder.filter(condition)?;
+
+    Ok(plan)
+}
diff --git a/tests/connect/test_filter.py b/tests/connect/test_filter.py
new file mode 100644
index 0000000000..1586c7e7b5
--- /dev/null
+++ b/tests/connect/test_filter.py
@@ -0,0 +1,19 @@
+from __future__ import annotations
+
+from pyspark.sql.functions import col
+
+
+def test_filter(spark_session):
+    # Create DataFrame from range(10)
+    df = spark_session.range(10)
+
+    # Filter for values less than 5
+    df_filtered = df.filter(col("id") < 5)
+
+    # Verify the schema is unchanged after filter
+    assert df_filtered.schema == df.schema, "Schema should be unchanged after filter"
+
+    # Verify the filtered data is correct
+    df_filtered_pandas = df_filtered.toPandas()
+    assert len(df_filtered_pandas) == 5, "Should have 5 rows after filtering < 5"
+    assert all(df_filtered_pandas["id"] < 5), "All values should be less than 5"