From 7f6b1e3061542a78de1056cb6d10bf9afc74439c Mon Sep 17 00:00:00 2001
From: Andrew Gazelka
Date: Wed, 20 Nov 2024 16:55:15 -0800
Subject: [PATCH] [FEAT] connect: `createDataFrame` (WIP) help needed

---
Review notes (ignored by `git am`):
- `local_relation` is still a stub that always returns an error, so
  tests/connect/test_create_df.py is expected to fail until it is implemented.
- `RelType::LocalRelation(l)` is passed by value on the assumption that the
  payload is not boxed (like `Range`) -- confirm against the generated
  `spark_connect` code.
- Spark's `toDF` renames columns positionally; `to_df` currently builds
  `col(name)` expressions, which presumably look up existing columns by that
  name -- confirm the intended semantics.
- The `index` blob ids below are those of the original revision of this patch.

 .../src/translation/logical_plan.rs           | 11 ++++-
 .../logical_plan/local_relation.rs            | 21 ++++++++++
 .../src/translation/logical_plan/to_df.rs     | 41 +++++++++++++++++++
 tests/connect/test_create_df.py               | 13 ++++++
 4 files changed, 85 insertions(+), 1 deletion(-)
 create mode 100644 src/daft-connect/src/translation/logical_plan/local_relation.rs
 create mode 100644 src/daft-connect/src/translation/logical_plan/to_df.rs
 create mode 100644 tests/connect/test_create_df.py

diff --git a/src/daft-connect/src/translation/logical_plan.rs b/src/daft-connect/src/translation/logical_plan.rs
index 93c9e9bd4a..bcdf7e67f8 100644
--- a/src/daft-connect/src/translation/logical_plan.rs
+++ b/src/daft-connect/src/translation/logical_plan.rs
@@ -3,11 +3,16 @@ use eyre::{bail, Context};
 use spark_connect::{relation::RelType, Limit, Relation};
 use tracing::warn;
 
-use crate::translation::logical_plan::{aggregate::aggregate, project::project, range::range};
+use crate::translation::logical_plan::{
+    aggregate::aggregate, local_relation::local_relation, project::project, range::range,
+    to_df::to_df,
+};
 
 mod aggregate;
+mod local_relation;
 mod project;
 mod range;
+mod to_df;
 
 pub fn to_logical_plan(relation: Relation) -> eyre::Result<LogicalPlanBuilder> {
     if let Some(common) = relation.common {
@@ -25,6 +30,10 @@ pub fn to_logical_plan(relation: Relation) -> eyre::Result<LogicalPlanBuilder> {
         RelType::Aggregate(a) => {
             aggregate(*a).wrap_err("Failed to apply aggregate to logical plan")
         }
+        RelType::ToDf(t) => to_df(*t).wrap_err("Failed to apply to_df to logical plan"),
+        RelType::LocalRelation(l) => {
+            local_relation(l).wrap_err("Failed to apply local_relation to logical plan")
+        }
         plan => bail!("Unsupported relation type: {plan:?}"),
     }
 }
diff --git a/src/daft-connect/src/translation/logical_plan/local_relation.rs b/src/daft-connect/src/translation/logical_plan/local_relation.rs
new file mode 100644
index 0000000000..3841e0afe4
--- /dev/null
+++ b/src/daft-connect/src/translation/logical_plan/local_relation.rs
@@ -0,0 +1,21 @@
+use daft_logical_plan::LogicalPlanBuilder;
+use eyre::bail;
+
+/// Translates a Spark Connect `LocalRelation` into a Daft logical plan.
+///
+/// # Errors
+///
+/// Always fails for now: decoding the relation's payload is not implemented yet.
+pub fn local_relation(
+    local_relation: spark_connect::LocalRelation,
+) -> eyre::Result<LogicalPlanBuilder> {
+    // Every field is named (no `..`) so that a field added to the message later is a
+    // compile error here instead of being silently ignored.
+    let spark_connect::LocalRelation {
+        data: _data,
+        schema: _schema,
+    } = local_relation;
+
+    // TODO: build the plan from `data` and `schema`.
+    bail!("LocalRelation is not yet supported")
+}
diff --git a/src/daft-connect/src/translation/logical_plan/to_df.rs b/src/daft-connect/src/translation/logical_plan/to_df.rs
new file mode 100644
index 0000000000..91f02de488
--- /dev/null
+++ b/src/daft-connect/src/translation/logical_plan/to_df.rs
@@ -0,0 +1,41 @@
+use daft_logical_plan::LogicalPlanBuilder;
+use eyre::{bail, WrapErr};
+
+use crate::translation::to_logical_plan;
+
+/// Translates a Spark Connect `ToDf` relation into a Daft logical plan.
+///
+/// The input relation is translated first; every entry of `column_names` is then
+/// turned into a column expression and applied with `with_columns`.
+///
+/// NOTE(review): Spark's `toDF` renames columns positionally, while `col(name)`
+/// presumably looks up an existing column called `name` -- confirm the semantics.
+///
+/// # Errors
+///
+/// Fails if `input` is unset, if the input relation cannot be translated, or if
+/// applying the columns fails.
+pub fn to_df(to_df: spark_connect::ToDf) -> eyre::Result<LogicalPlanBuilder> {
+    let spark_connect::ToDf {
+        input,
+        column_names,
+    } = to_df;
+
+    let Some(input) = input else {
+        bail!("Input is required")
+    };
+
+    // `*input` moves the relation into `to_logical_plan`, so the error context must
+    // not refer to `input` afterwards.
+    let plan = to_logical_plan(*input)
+        .wrap_err("Failed to translate relation to logical plan")?;
+
+    let columns: Vec<_> = column_names
+        .iter()
+        .map(|name| daft_dsl::col(name.as_str()))
+        .collect();
+
+    let plan = plan.with_columns(columns)?;
+
+    Ok(plan)
+}
diff --git a/tests/connect/test_create_df.py b/tests/connect/test_create_df.py
new file mode 100644
index 0000000000..95da8e33be
--- /dev/null
+++ b/tests/connect/test_create_df.py
@@ -0,0 +1,13 @@
+from __future__ import annotations
+
+
+def test_create_df(spark_session):
+    # Create a DataFrame with duplicate values
+    data = [(1,), (2,), (2,), (3,), (3,), (3,)]
+    df = spark_session.createDataFrame(data, ["value"])
+
+    # Collect and verify results
+    result = df.collect()
+
+    # Verify the DataFrame has the expected number of rows and values
+    assert sorted([row.value for row in result]) == [1, 2, 2, 3, 3, 3]