From 07f6b2c82ad5f8a288e6c48db58b23e1d105c558 Mon Sep 17 00:00:00 2001 From: Andrew Gazelka Date: Thu, 19 Dec 2024 10:54:46 -0800 Subject: [PATCH] feat(connect): `with_columns_renamed` (#3386) --- .../src/translation/logical_plan.rs | 5 ++ .../logical_plan/with_columns_renamed.rs | 48 +++++++++++++++++++ tests/connect/test_with_columns_renamed.py | 24 ++++++++++ 3 files changed, 77 insertions(+) create mode 100644 src/daft-connect/src/translation/logical_plan/with_columns_renamed.rs create mode 100644 tests/connect/test_with_columns_renamed.py diff --git a/src/daft-connect/src/translation/logical_plan.rs b/src/daft-connect/src/translation/logical_plan.rs index 439f5bd551..5bf831756e 100644 --- a/src/daft-connect/src/translation/logical_plan.rs +++ b/src/daft-connect/src/translation/logical_plan.rs @@ -27,6 +27,7 @@ mod range; mod read; mod to_df; mod with_columns; +mod with_columns_renamed; pub struct SparkAnalyzer<'a> { pub psets: &'a InMemoryPartitionSetCache, @@ -110,6 +111,10 @@ impl SparkAnalyzer<'_> { self.local_relation(plan_id, l) .wrap_err("Failed to apply local_relation to logical plan") } + RelType::WithColumnsRenamed(w) => self + .with_columns_renamed(*w) + .await + .wrap_err("Failed to apply with_columns_renamed to logical plan"), RelType::Read(r) => read::read(r) .await .wrap_err("Failed to apply read to logical plan"), diff --git a/src/daft-connect/src/translation/logical_plan/with_columns_renamed.rs b/src/daft-connect/src/translation/logical_plan/with_columns_renamed.rs new file mode 100644 index 0000000000..856a7214fc --- /dev/null +++ b/src/daft-connect/src/translation/logical_plan/with_columns_renamed.rs @@ -0,0 +1,48 @@ +use daft_dsl::col; +use daft_logical_plan::LogicalPlanBuilder; +use eyre::{bail, Context}; + +use crate::translation::SparkAnalyzer; + +impl SparkAnalyzer<'_> { + pub async fn with_columns_renamed( + &self, + with_columns_renamed: spark_connect::WithColumnsRenamed, + ) -> eyre::Result<LogicalPlanBuilder> { + let spark_connect::WithColumnsRenamed {
+ input, + rename_columns_map, + renames, + } = with_columns_renamed; + + let Some(input) = input else { + bail!("Input is required"); + }; + + let plan = Box::pin(self.to_logical_plan(*input)).await?; + + // todo: let's implement this directly into daft + + // Convert the rename mappings into expressions + let rename_exprs = if !rename_columns_map.is_empty() { + // Use rename_columns_map if provided (legacy format) + rename_columns_map + .into_iter() + .map(|(old_name, new_name)| col(old_name.as_str()).alias(new_name.as_str())) + .collect() + } else { + // Use renames if provided (new format) + renames + .into_iter() + .map(|rename| col(rename.col_name.as_str()).alias(rename.new_col_name.as_str())) + .collect() + }; + + // Apply the rename expressions to the plan + let plan = plan + .select(rename_exprs) + .wrap_err("Failed to apply rename expressions to logical plan")?; + + Ok(plan) + } +} diff --git a/tests/connect/test_with_columns_renamed.py b/tests/connect/test_with_columns_renamed.py new file mode 100644 index 0000000000..124f142ca2 --- /dev/null +++ b/tests/connect/test_with_columns_renamed.py @@ -0,0 +1,24 @@ +from __future__ import annotations + + +def test_with_columns_renamed(spark_session): + # Test withColumnRenamed + df = spark_session.range(5) + renamed_df = df.withColumnRenamed("id", "number") + + collected = renamed_df.collect() + assert len(collected) == 5 + assert "number" in renamed_df.columns + assert "id" not in renamed_df.columns + assert [row["number"] for row in collected] == list(range(5)) + + # todo: this edge case is a spark connect bug; it will only send rename of id -> character over protobuf + # # Test withColumnsRenamed + # df = spark_session.range(2) + # renamed_df = df.withColumnsRenamed({"id": "number", "id": "character"}) + # + # collected = renamed_df.collect() + # assert len(collected) == 2 + # assert set(renamed_df.columns) == {"number", "character"} + # assert "id" not in renamed_df.columns + # assert [(row["number"], 
row["character"]) for row in collected] == [(0, 0), (1, 1)]