From 926ee0bd557fe0caa404a0a262efe8b23f2c88e6 Mon Sep 17 00:00:00 2001
From: Andrew Gazelka
Date: Tue, 19 Nov 2024 22:34:21 -0800
Subject: [PATCH] [FEAT] connect: add drop support

---
 .../src/translation/logical_plan.rs          |  6 ++-
 .../src/translation/logical_plan/drop.rs     | 39 +++++++++++++++++++
 tests/connect/test_drop.py                   | 17 ++++++++
 3 files changed, 60 insertions(+), 2 deletions(-)
 create mode 100644 src/daft-connect/src/translation/logical_plan/drop.rs
 create mode 100644 tests/connect/test_drop.py

diff --git a/src/daft-connect/src/translation/logical_plan.rs b/src/daft-connect/src/translation/logical_plan.rs
index bf95827649..ae3b99f590 100644
--- a/src/daft-connect/src/translation/logical_plan.rs
+++ b/src/daft-connect/src/translation/logical_plan.rs
@@ -8,11 +8,12 @@ use spark_connect::{relation::RelType, Limit, Relation};
 use tracing::warn;
 
 use crate::translation::logical_plan::{
-    aggregate::aggregate, local_relation::local_relation, project::project, range::range,
-    to_df::to_df, with_columns::with_columns,
+    aggregate::aggregate, drop::drop, local_relation::local_relation, project::project,
+    range::range, to_df::to_df, with_columns::with_columns,
 };
 
 mod aggregate;
+mod drop;
 mod local_relation;
 mod project;
 mod range;
@@ -57,6 +58,7 @@ pub fn to_logical_plan(relation: Relation) -> eyre::Result<Plan> {
         RelType::LocalRelation(l) => {
             local_relation(l).wrap_err("Failed to apply local_relation to logical plan")
         }
+        RelType::Drop(d) => drop(*d).wrap_err("Failed to apply drop to logical plan"),
         plan => bail!("Unsupported relation type: {plan:?}"),
     }
 }
diff --git a/src/daft-connect/src/translation/logical_plan/drop.rs b/src/daft-connect/src/translation/logical_plan/drop.rs
new file mode 100644
index 0000000000..4de38f1768
--- /dev/null
+++ b/src/daft-connect/src/translation/logical_plan/drop.rs
@@ -0,0 +1,39 @@
+use eyre::bail;
+
+use crate::translation::{to_logical_plan, Plan};
+
+pub fn drop(drop: spark_connect::Drop) -> eyre::Result<Plan> {
+    let spark_connect::Drop {
+        input,
+        columns,
+        column_names,
+    } = drop;
+
+    let Some(input) = input else {
+        bail!("input is required");
+    };
+
+    if !columns.is_empty() {
+        bail!("columns is not supported; use column_names instead");
+    }
+
+    let mut plan = to_logical_plan(*input)?;
+
+    // Get all column names from the schema
+    let all_columns = plan.builder.schema().names();
+
+    // Create a set of columns to drop for efficient lookup
+    let columns_to_drop: std::collections::HashSet<_> = column_names.iter().collect();
+
+    // Create expressions for all columns except the ones being dropped
+    let to_select = all_columns
+        .iter()
+        .filter(|col_name| !columns_to_drop.contains(*col_name))
+        .map(|col_name| daft_dsl::col(col_name.clone()))
+        .collect();
+
+    // Use select to keep only the columns we want
+    plan.builder = plan.builder.select(to_select)?;
+
+    Ok(plan)
+}
diff --git a/tests/connect/test_drop.py b/tests/connect/test_drop.py
new file mode 100644
index 0000000000..635f79c1d2
--- /dev/null
+++ b/tests/connect/test_drop.py
@@ -0,0 +1,17 @@
+from __future__ import annotations
+
+
+def test_drop(spark_session):
+    # Create DataFrame from range(10)
+    df = spark_session.range(10)
+
+    # Drop the 'id' column
+    df_dropped = df.drop("id")
+
+    # Verify the drop was successful
+    assert "id" not in df_dropped.columns, "Column 'id' should be dropped"
+    assert len(df_dropped.columns) == len(df.columns) - 1, "Should have one less column after drop"
+
+    # Verify the DataFrame has no columns after dropping its only column
+    assert len(df_dropped.toPandas().columns) == 0, "DataFrame should have no columns after dropping 'id'"
+    assert df_dropped.count() == df.count(), "Row count should be unchanged after drop"