diff --git a/src/daft-connect/src/translation/logical_plan.rs b/src/daft-connect/src/translation/logical_plan.rs index c05ca41e20..59bdc41d4c 100644 --- a/src/daft-connect/src/translation/logical_plan.rs +++ b/src/daft-connect/src/translation/logical_plan.rs @@ -6,7 +6,7 @@ use tracing::warn; use crate::translation::logical_plan::{ aggregate::aggregate, local_relation::local_relation, project::project, range::range, - read::read, to_df::to_df, with_columns::with_columns, + read::read, set_op::set_op, to_df::to_df, with_columns::with_columns, }; mod aggregate; @@ -14,6 +14,7 @@ mod local_relation; mod project; mod range; mod read; +mod set_op; mod to_df; mod with_columns; @@ -74,6 +75,9 @@ pub async fn to_logical_plan(relation: Relation) -> eyre::Result { RelType::Read(r) => read(r) .await .wrap_err("Failed to apply read to logical plan"), + RelType::SetOp(s) => set_op(*s) + .await + .wrap_err("Failed to apply set_op to logical plan"), plan => bail!("Unsupported relation type: {plan:?}"), } } diff --git a/src/daft-connect/src/translation/logical_plan/set_op.rs b/src/daft-connect/src/translation/logical_plan/set_op.rs new file mode 100644 index 0000000000..066e4e44c5 --- /dev/null +++ b/src/daft-connect/src/translation/logical_plan/set_op.rs @@ -0,0 +1,65 @@ +use eyre::{bail, Context}; +use spark_connect::set_operation::SetOpType; +use tracing::warn; + +use crate::translation::{to_logical_plan, Plan}; + +pub async fn set_op(set_op: spark_connect::SetOperation) -> eyre::Result { + let spark_connect::SetOperation { + left_input, + right_input, + set_op_type, + is_all, + by_name, + allow_missing_columns, + } = set_op; + + let Some(left_input) = left_input else { + bail!("Left input is required"); + }; + + let Some(right_input) = right_input else { + bail!("Right input is required"); + }; + + let set_op = SetOpType::try_from(set_op_type) + .wrap_err_with(|| format!("Invalid set operation type: {set_op_type}"))?; + + if let Some(by_name) = by_name { + warn!("Ignoring by_name: {by_name}"); + } + + if let Some(allow_missing_columns) = allow_missing_columns { + warn!("Ignoring allow_missing_columns: {allow_missing_columns}"); + } + + let mut left = Box::pin(to_logical_plan(*left_input)).await?; + let right = Box::pin(to_logical_plan(*right_input)).await?; + + left.psets.partitions.extend(right.psets.partitions); + + let is_all = is_all.unwrap_or(false); + + let builder = match set_op { + SetOpType::Unspecified => { + bail!("Unspecified set operation is not supported"); + } + SetOpType::Intersect => left + .builder + .intersect(&right.builder, is_all) + .wrap_err("Failed to apply intersect to logical plan"), + SetOpType::Union => left + .builder + .union(&right.builder, is_all) + .wrap_err("Failed to apply union to logical plan"), + SetOpType::Except => { + bail!("Except set operation is not supported"); + } + }?; + + // we merged left and right psets + Ok(Plan { + builder, + psets: left.psets, + }) +} diff --git a/tests/connect/test_intersection.py b/tests/connect/test_intersection.py new file mode 100644 index 0000000000..200f391f39 --- /dev/null +++ b/tests/connect/test_intersection.py @@ -0,0 +1,21 @@ +from __future__ import annotations + + +def test_intersection(spark_session): + # Create ranges using Spark - with overlap + range1 = spark_session.range(7) # Creates DataFrame with numbers 0 to 6 + range2 = spark_session.range(3, 10) # Creates DataFrame with numbers 3 to 9 + + # Intersect the two ranges + intersected = range1.intersect(range2) + + # Collect results + results = intersected.collect() + + # Verify the DataFrame has expected values + # Intersection should only include overlapping values once + assert len(results) == 4, "DataFrame should have 4 rows (overlapping values 3,4,5,6)" + + # Check that all expected values are present + values = [row.id for row in results] + assert sorted(values) == [3, 4, 5, 6], "Values should match expected overlapping sequence" diff --git a/tests/connect/test_union.py b/tests/connect/test_union.py new file mode 100644 index 0000000000..34157fd2c1 --- /dev/null +++ b/tests/connect/test_union.py @@ -0,0 +1,36 @@ +from __future__ import annotations + + +def test_union(spark_session): + # Create ranges using Spark - with overlap + range1 = spark_session.range(7) # Creates DataFrame with numbers 0 to 6 + range2 = spark_session.range(3, 10) # Creates DataFrame with numbers 3 to 9 + + # Union the two ranges + unioned = range1.union(range2) + + # Collect results + results = unioned.collect() + + # Verify the DataFrame has expected values + # Union includes duplicates, so length should be sum of both ranges + assert len(results) == 14, "DataFrame should have 14 rows (7 + 7)" + + # Check that all expected values are present, including duplicates + values = [row.id for row in results] + assert sorted(values) == [ + 0, + 1, + 2, + 3, + 3, + 4, + 4, + 5, + 5, + 6, + 6, + 7, + 8, + 9, + ], "Values should match expected sequence with duplicates"