Skip to content

Commit

Permalink
[FEAT] connect: add df.filter
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewgazelka committed Dec 9, 2024
1 parent 6390afa commit abfed9f
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 2 deletions.
6 changes: 4 additions & 2 deletions src/daft-connect/src/translation/logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@ use spark_connect::{relation::RelType, Limit, Relation};
use tracing::warn;

use crate::translation::logical_plan::{
aggregate::aggregate, local_relation::local_relation, project::project, range::range,
to_df::to_df, with_columns::with_columns,
aggregate::aggregate, filter::filter, local_relation::local_relation, project::project,
range::range, to_df::to_df, with_columns::with_columns,
};

mod aggregate;
mod filter;
mod local_relation;
mod project;
mod range;
Expand Down Expand Up @@ -54,6 +55,7 @@ pub fn to_logical_plan(relation: Relation) -> eyre::Result<Plan> {
RelType::Limit(l) => limit(*l).wrap_err("Failed to apply limit to logical plan"),
RelType::Range(r) => range(r).wrap_err("Failed to apply range to logical plan"),
RelType::Project(p) => project(*p).wrap_err("Failed to apply project to logical plan"),
RelType::Filter(f) => filter(*f).wrap_err("Failed to apply filter to logical plan"),
RelType::Aggregate(a) => {
aggregate(*a).wrap_err("Failed to apply aggregate to logical plan")
}
Expand Down
22 changes: 22 additions & 0 deletions src/daft-connect/src/translation/logical_plan/filter.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
use eyre::bail;

use crate::translation::{to_daft_expr, to_logical_plan, Plan};

pub fn filter(filter: spark_connect::Filter) -> eyre::Result<Plan> {
let spark_connect::Filter { input, condition } = filter;

let Some(input) = input else {
bail!("input is required");

Check warning on line 9 in src/daft-connect/src/translation/logical_plan/filter.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/logical_plan/filter.rs#L9

Added line #L9 was not covered by tests
};

let Some(condition) = condition else {
bail!("condition is required");

Check warning on line 13 in src/daft-connect/src/translation/logical_plan/filter.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/logical_plan/filter.rs#L13

Added line #L13 was not covered by tests
};

let condition = to_daft_expr(&condition)?;

let mut plan = to_logical_plan(*input)?;
plan.builder = plan.builder.filter(condition)?;

Ok(plan)
}
19 changes: 19 additions & 0 deletions tests/connect/test_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from __future__ import annotations

from pyspark.sql.functions import col


def test_filter(spark_session):
# Create DataFrame from range(10)
df = spark_session.range(10)

# Filter for values less than 5
df_filtered = df.filter(col("id") < 5)

# Verify the schema is unchanged after filter
assert df_filtered.schema == df.schema, "Schema should be unchanged after filter"

# Verify the filtered data is correct
df_filtered_pandas = df_filtered.toPandas()
assert len(df_filtered_pandas) == 5, "Should have 5 rows after filtering < 5"
assert all(df_filtered_pandas["id"] < 5), "All values should be less than 5"

0 comments on commit abfed9f

Please sign in to comment.