From dd5e05c6fc620962860ce415211a815663765a53 Mon Sep 17 00:00:00 2001 From: Andrew Gazelka Date: Wed, 20 Nov 2024 01:25:34 -0800 Subject: [PATCH] [FEAT] connect: `df.join` --- .../src/translation/logical_plan.rs | 6 +- .../src/translation/logical_plan/join.rs | 76 +++++++++++++++++++ tests/connect/test_join.py | 26 +++++++ 3 files changed, 107 insertions(+), 1 deletion(-) create mode 100644 src/daft-connect/src/translation/logical_plan/join.rs create mode 100644 tests/connect/test_join.py diff --git a/src/daft-connect/src/translation/logical_plan.rs b/src/daft-connect/src/translation/logical_plan.rs index 93c9e9bd4a..380d95b26d 100644 --- a/src/daft-connect/src/translation/logical_plan.rs +++ b/src/daft-connect/src/translation/logical_plan.rs @@ -3,9 +3,12 @@ use eyre::{bail, Context}; use spark_connect::{relation::RelType, Limit, Relation}; use tracing::warn; -use crate::translation::logical_plan::{aggregate::aggregate, project::project, range::range}; +use crate::translation::logical_plan::{ + aggregate::aggregate, join::join, project::project, range::range, +}; mod aggregate; +mod join; mod project; mod range; @@ -25,6 +28,7 @@ pub fn to_logical_plan(relation: Relation) -> eyre::Result { RelType::Aggregate(a) => { aggregate(*a).wrap_err("Failed to apply aggregate to logical plan") } + RelType::Join(j) => join(*j).wrap_err("Failed to apply join to logical plan"), plan => bail!("Unsupported relation type: {plan:?}"), } } diff --git a/src/daft-connect/src/translation/logical_plan/join.rs b/src/daft-connect/src/translation/logical_plan/join.rs new file mode 100644 index 0000000000..f4761f68ce --- /dev/null +++ b/src/daft-connect/src/translation/logical_plan/join.rs @@ -0,0 +1,76 @@ +use daft_logical_plan::LogicalPlanBuilder; +use eyre::{bail, WrapErr}; +use spark_connect::join::JoinType; +use tracing::warn; + +use crate::translation::to_logical_plan; + +pub fn join(join: spark_connect::Join) -> eyre::Result { + let spark_connect::Join { + left, + right, + join_condition, + join_type, + using_columns, + join_data_type, + } = join; + + let Some(left) = left else { + bail!("Left side of join is required"); + }; + + let Some(right) = right else { + bail!("Right side of join is required"); + }; + + if let Some(join_condition) = join_condition { + bail!("Join conditions are not yet supported; use using_columns (join keys) instead; got {join_condition:?}"); + } + + let join_type = JoinType::try_from(join_type) + .wrap_err_with(|| format!("Invalid join type: {join_type:?}"))?; + + let join_type = to_daft_join_type(join_type)?; + + let using_columns_exprs: Vec<_> = using_columns + .iter() + .map(|s| daft_dsl::col(s.as_str())) + .collect(); + + if let Some(join_data_type) = join_data_type { + warn!("Ignoring join data type {join_data_type:?} for join; not yet implemented"); + } + + let left = to_logical_plan(*left)?; + let right = to_logical_plan(*right)?; + + Ok(left.join( + &right, + // join_conditions.clone(), // todo(correctness): is this correct? + // join_conditions, // todo(correctness): is this correct? + using_columns_exprs.clone(), + using_columns_exprs, + join_type, + None, + None, + None, + false, // todo(correctness): we want join keys or not + )?) +} + +fn to_daft_join_type(join_type: JoinType) -> eyre::Result { + match join_type { + JoinType::Unspecified => { + bail!("Join type must be specified; got Unspecified") + } + JoinType::Inner => Ok(daft_core::join::JoinType::Inner), + JoinType::FullOuter => { + bail!("Full outer joins not yet supported") // todo(completeness): add support for full outer joins if it is not already implemented + } + JoinType::LeftOuter => Ok(daft_core::join::JoinType::Left), // todo(correctness): is this correct? + JoinType::RightOuter => Ok(daft_core::join::JoinType::Right), + JoinType::LeftAnti => Ok(daft_core::join::JoinType::Anti), // todo(correctness): is this correct? + JoinType::LeftSemi => bail!("Left semi joins not yet supported"), // todo(completeness): add support for left semi joins if it is not already implemented + JoinType::Cross => bail!("Cross joins not yet supported"), // todo(completeness): add support for cross joins if it is not already implemented + } +} diff --git a/tests/connect/test_join.py b/tests/connect/test_join.py new file mode 100644 index 0000000000..45e66c769b --- /dev/null +++ b/tests/connect/test_join.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +from pyspark.sql.functions import col + + +def test_join(spark_session): + # Create two DataFrames with overlapping IDs + df1 = spark_session.range(5) + df2 = spark_session.range(3, 7) + + # Perform inner join on 'id' column + joined_df = df1.join(df2, "id", "inner") + + # Verify join results using collect() + joined_ids = {row.id for row in joined_df.select("id").collect()} + assert joined_ids == {3, 4}, "Inner join should only contain IDs 3 and 4" + + # Test left outer join + left_joined_df = df1.join(df2, "id", "left") + left_joined_ids = {row.id for row in left_joined_df.select("id").collect()} + assert left_joined_ids == {0, 1, 2, 3, 4}, "Left join should keep all rows from left DataFrame" + + # Test right outer join + right_joined_df = df1.join(df2, "id", "right") + right_joined_ids = {row.id for row in right_joined_df.select("id").collect()} + assert right_joined_ids == {3, 4, 5, 6}, "Right join should keep all rows from right DataFrame"