From 1380cd4562914cfe3282ecec00a3cc8937c53b5b Mon Sep 17 00:00:00 2001 From: Andrew Gazelka Date: Wed, 20 Nov 2024 03:27:48 -0800 Subject: [PATCH] [FEAT] (WIP) connect: createDataFrame --- .../src/translation/logical_plan.rs | 7 +++++- .../src/translation/logical_plan/to_df.rs | 25 +++++++++++++++++++ tests/connect/test_create_df.py | 12 +++++++++ tests/connect/test_distinct.py | 17 +++++++++++++ 4 files changed, 60 insertions(+), 1 deletion(-) create mode 100644 src/daft-connect/src/translation/logical_plan/to_df.rs create mode 100644 tests/connect/test_create_df.py create mode 100644 tests/connect/test_distinct.py diff --git a/src/daft-connect/src/translation/logical_plan.rs b/src/daft-connect/src/translation/logical_plan.rs index 947e0cd0d3..b840407961 100644 --- a/src/daft-connect/src/translation/logical_plan.rs +++ b/src/daft-connect/src/translation/logical_plan.rs @@ -3,11 +3,15 @@ use eyre::{bail, Context}; use spark_connect::{relation::RelType, Relation}; use tracing::warn; -use crate::translation::logical_plan::{aggregate::aggregate, project::project, range::range}; +use crate::translation::logical_plan::{ + aggregate::aggregate, project::project, range::range, to_df::to_df, +}; mod aggregate; mod project; mod range; +mod to_df; +mod pub fn to_logical_plan(relation: Relation) -> eyre::Result { if let Some(common) = relation.common { @@ -24,6 +28,7 @@ pub fn to_logical_plan(relation: Relation) -> eyre::Result { RelType::Aggregate(a) => { aggregate(*a).wrap_err("Failed to apply aggregate to logical plan") } + RelType::ToDf(t) => to_df(*t).wrap_err("Failed to apply to_df to logical plan"), plan => bail!("Unsupported relation type: {plan:?}"), } } diff --git a/src/daft-connect/src/translation/logical_plan/to_df.rs b/src/daft-connect/src/translation/logical_plan/to_df.rs new file mode 100644 index 0000000000..24949d14ba --- /dev/null +++ b/src/daft-connect/src/translation/logical_plan/to_df.rs @@ -0,0 +1,25 @@ +use eyre::{bail, WrapErr}; + +use crate::translation::to_logical_plan; + +pub fn to_df(to_df: spark_connect::ToDf) -> eyre::Result { + let spark_connect::ToDf { + input, + column_names, + } = to_df; + + let Some(input) = input else { + bail!("Input is required"); + }; + + let plan = to_logical_plan(*input) + .wrap_err_with(|| format!("Failed to translate relation to logical plan: {input:?}"))?; + + let column_names: Vec<_> = column_names.iter().map(daft_dsl::col).collect(); + + let plan = plan + .with_columns(column_names) + .wrap_err_with(|| format!("Failed to add columns to logical plan: {column_names:?}"))?; + + Ok(plan) +} diff --git a/tests/connect/test_create_df.py b/tests/connect/test_create_df.py new file mode 100644 index 0000000000..de2315a84d --- /dev/null +++ b/tests/connect/test_create_df.py @@ -0,0 +1,12 @@ +from __future__ import annotations + + +def test_create_df(spark_session): + # Create simple DataFrame + data = [(1,), (2,), (3,)] + df = spark_session.createDataFrame(data, ["id"]) + + # Convert to pandas + df_pandas = df.toPandas() + assert len(df_pandas) == 3, "DataFrame should have 3 rows" + assert list(df_pandas["id"]) == [1, 2, 3], "DataFrame should contain expected values" diff --git a/tests/connect/test_distinct.py b/tests/connect/test_distinct.py new file mode 100644 index 0000000000..3531f7c528 --- /dev/null +++ b/tests/connect/test_distinct.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +from pyspark.sql.functions import col + + +def test_distinct(spark_session): + # Create DataFrame with duplicates + data = [(1,), (1,), (2,), (2,), (3,)] + df = spark_session.createDataFrame(data, ["id"]) + + # Get distinct rows + df_distinct = df.distinct() + + # Verify distinct operation removed duplicates + df_distinct_pandas = df_distinct.toPandas() + assert len(df_distinct_pandas) == 3, "Distinct should remove duplicates" + assert set(df_distinct_pandas["id"]) == {1, 2, 3}, "Distinct values should be preserved"