From 0aa67fbdf4ee782efb5fc77fa95a2b53dbc85260 Mon Sep 17 00:00:00 2001 From: Andrew Gazelka Date: Wed, 20 Nov 2024 21:16:19 -0800 Subject: [PATCH 1/2] [FEAT] connect: `with_columns_renamed` --- .../src/translation/logical_plan.rs | 6 +++ .../logical_plan/with_columns_renamed.rs | 45 +++++++++++++++++++ tests/connect/test_with_columns_renamed.py | 24 ++++++++++ 3 files changed, 75 insertions(+) create mode 100644 src/daft-connect/src/translation/logical_plan/with_columns_renamed.rs create mode 100644 tests/connect/test_with_columns_renamed.py diff --git a/src/daft-connect/src/translation/logical_plan.rs b/src/daft-connect/src/translation/logical_plan.rs index 439f5bd551..df0ae6c2a4 100644 --- a/src/daft-connect/src/translation/logical_plan.rs +++ b/src/daft-connect/src/translation/logical_plan.rs @@ -18,6 +18,8 @@ use futures::TryStreamExt; use spark_connect::{relation::RelType, Limit, Relation, ShowString}; use tracing::warn; +use crate::translation::logical_plan::with_columns_renamed::with_columns_renamed; + mod aggregate; mod drop; mod filter; @@ -27,6 +29,7 @@ mod range; mod read; mod to_df; mod with_columns; +mod with_columns_renamed; pub struct SparkAnalyzer<'a> { pub psets: &'a InMemoryPartitionSetCache, @@ -110,6 +113,9 @@ impl SparkAnalyzer<'_> { self.local_relation(plan_id, l) .wrap_err("Failed to apply local_relation to logical plan") } + RelType::WithColumnsRenamed(w) => with_columns_renamed(*w) + .await + .wrap_err("Failed to apply with_columns_renamed to logical plan"), RelType::Read(r) => read::read(r) .await .wrap_err("Failed to apply read to logical plan"), diff --git a/src/daft-connect/src/translation/logical_plan/with_columns_renamed.rs b/src/daft-connect/src/translation/logical_plan/with_columns_renamed.rs new file mode 100644 index 0000000000..01c6493974 --- /dev/null +++ b/src/daft-connect/src/translation/logical_plan/with_columns_renamed.rs @@ -0,0 +1,45 @@ +use daft_dsl::col; +use eyre::{bail, Context}; + +use crate::translation::Plan; + +pub async fn with_columns_renamed( + with_columns_renamed: spark_connect::WithColumnsRenamed, +) -> eyre::Result { + let spark_connect::WithColumnsRenamed { + input, + rename_columns_map, + renames, + } = with_columns_renamed; + + let Some(input) = input else { + bail!("Input is required"); + }; + + let mut plan = Box::pin(crate::translation::to_logical_plan(*input)).await?; + + // todo: let's implement this directly into daft + + // Convert the rename mappings into expressions + let rename_exprs = if !rename_columns_map.is_empty() { + // Use rename_columns_map if provided (legacy format) + rename_columns_map + .into_iter() + .map(|(old_name, new_name)| col(old_name.as_str()).alias(new_name.as_str())) + .collect() + } else { + // Use renames if provided (new format) + renames + .into_iter() + .map(|rename| col(rename.col_name.as_str()).alias(rename.new_col_name.as_str())) + .collect() + }; + + // Apply the rename expressions to the plan + plan.builder = plan + .builder + .select(rename_exprs) + .wrap_err("Failed to apply rename expressions to logical plan")?; + + Ok(plan) +} diff --git a/tests/connect/test_with_columns_renamed.py b/tests/connect/test_with_columns_renamed.py new file mode 100644 index 0000000000..124f142ca2 --- /dev/null +++ b/tests/connect/test_with_columns_renamed.py @@ -0,0 +1,24 @@ +from __future__ import annotations + + +def test_with_columns_renamed(spark_session): + # Test withColumnRenamed + df = spark_session.range(5) + renamed_df = df.withColumnRenamed("id", "number") + + collected = renamed_df.collect() + assert len(collected) == 5 + assert "number" in renamed_df.columns + assert "id" not in renamed_df.columns + assert [row["number"] for row in collected] == list(range(5)) + + # todo: this edge case is a spark connect bug; it will only send rename of id -> character over protobuf + # # Test withColumnsRenamed + # df = spark_session.range(2) + # renamed_df = df.withColumnsRenamed({"id": "number", "id": "character"}) + # + # collected = renamed_df.collect() + # assert len(collected) == 2 + # assert set(renamed_df.columns) == {"number", "character"} + # assert "id" not in renamed_df.columns + # assert [(row["number"], row["character"]) for row in collected] == [(0, 0), (1, 1)] From ce15fe093b3e874b40d989cfe10d89daa970e9d8 Mon Sep 17 00:00:00 2001 From: Andrew Gazelka Date: Thu, 19 Dec 2024 09:10:58 -0800 Subject: [PATCH 2/2] slight restructure --- .../src/translation/logical_plan.rs | 5 +- .../logical_plan/with_columns_renamed.rs | 85 ++++++++++--------- 2 files changed, 46 insertions(+), 44 deletions(-) diff --git a/src/daft-connect/src/translation/logical_plan.rs b/src/daft-connect/src/translation/logical_plan.rs index df0ae6c2a4..5bf831756e 100644 --- a/src/daft-connect/src/translation/logical_plan.rs +++ b/src/daft-connect/src/translation/logical_plan.rs @@ -18,8 +18,6 @@ use futures::TryStreamExt; use spark_connect::{relation::RelType, Limit, Relation, ShowString}; use tracing::warn; -use crate::translation::logical_plan::with_columns_renamed::with_columns_renamed; - mod aggregate; mod drop; mod filter; @@ -113,7 +111,8 @@ impl SparkAnalyzer<'_> { self.local_relation(plan_id, l) .wrap_err("Failed to apply local_relation to logical plan") } - RelType::WithColumnsRenamed(w) => with_columns_renamed(*w) + RelType::WithColumnsRenamed(w) => self + .with_columns_renamed(*w) .await .wrap_err("Failed to apply with_columns_renamed to logical plan"), RelType::Read(r) => read::read(r) diff --git a/src/daft-connect/src/translation/logical_plan/with_columns_renamed.rs b/src/daft-connect/src/translation/logical_plan/with_columns_renamed.rs index 01c6493974..856a7214fc 100644 --- a/src/daft-connect/src/translation/logical_plan/with_columns_renamed.rs +++ b/src/daft-connect/src/translation/logical_plan/with_columns_renamed.rs @@ -1,45 +1,48 @@ use daft_dsl::col; +use daft_logical_plan::LogicalPlanBuilder; use eyre::{bail, Context}; -use crate::translation::Plan; - -pub async fn with_columns_renamed( - with_columns_renamed: spark_connect::WithColumnsRenamed, -) -> eyre::Result { - let spark_connect::WithColumnsRenamed { - input, - rename_columns_map, - renames, - } = with_columns_renamed; - - let Some(input) = input else { - bail!("Input is required"); - }; - - let mut plan = Box::pin(crate::translation::to_logical_plan(*input)).await?; - - // todo: let's implement this directly into daft - - // Convert the rename mappings into expressions - let rename_exprs = if !rename_columns_map.is_empty() { - // Use rename_columns_map if provided (legacy format) - rename_columns_map - .into_iter() - .map(|(old_name, new_name)| col(old_name.as_str()).alias(new_name.as_str())) - .collect() - } else { - // Use renames if provided (new format) - renames - .into_iter() - .map(|rename| col(rename.col_name.as_str()).alias(rename.new_col_name.as_str())) - .collect() - }; - - // Apply the rename expressions to the plan - plan.builder = plan - .builder - .select(rename_exprs) - .wrap_err("Failed to apply rename expressions to logical plan")?; - - Ok(plan) +use crate::translation::SparkAnalyzer; + +impl SparkAnalyzer<'_> { + pub async fn with_columns_renamed( + &self, + with_columns_renamed: spark_connect::WithColumnsRenamed, + ) -> eyre::Result { + let spark_connect::WithColumnsRenamed { + input, + rename_columns_map, + renames, + } = with_columns_renamed; + + let Some(input) = input else { + bail!("Input is required"); + }; + + let plan = Box::pin(self.to_logical_plan(*input)).await?; + + // todo: let's implement this directly into daft + + // Convert the rename mappings into expressions + let rename_exprs = if !rename_columns_map.is_empty() { + // Use rename_columns_map if provided (legacy format) + rename_columns_map + .into_iter() + .map(|(old_name, new_name)| col(old_name.as_str()).alias(new_name.as_str())) + .collect() + } else { + // Use renames if provided (new format) + renames + .into_iter() + .map(|rename| col(rename.col_name.as_str()).alias(rename.new_col_name.as_str())) + .collect() + }; + + // Apply the rename expressions to the plan + let plan = plan + .select(rename_exprs) + .wrap_err("Failed to apply rename expressions to logical plan")?; + + Ok(plan) + } }