Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(connect): add df.{intersection,union} #3373

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/daft-connect/src/translation/logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ mod local_relation;
mod project;
mod range;
mod read;
mod set_op;
mod to_df;
mod with_columns;

Expand Down Expand Up @@ -129,6 +130,9 @@ impl SparkAnalyzer<'_> {
.await
.wrap_err("Failed to show string")
}
RelType::SetOp(s) => set_op(*s)
.await
.wrap_err("Failed to apply set_op to logical plan"),
plan => bail!("Unsupported relation type: {plan:?}"),
}
}
Expand Down
65 changes: 65 additions & 0 deletions src/daft-connect/src/translation/logical_plan/set_op.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
use eyre::{bail, Context};
use spark_connect::set_operation::SetOpType;
use tracing::warn;

use crate::translation::{to_logical_plan, Plan};

pub async fn set_op(set_op: spark_connect::SetOperation) -> eyre::Result<Plan> {
let spark_connect::SetOperation {
left_input,
right_input,
set_op_type,
is_all,
by_name,
allow_missing_columns,
} = set_op;

let Some(left_input) = left_input else {
bail!("Left input is required");
};

let Some(right_input) = right_input else {
bail!("Right input is required");
};

let set_op = SetOpType::try_from(set_op_type)
.wrap_err_with(|| format!("Invalid set operation type: {set_op_type}"))?;

if let Some(by_name) = by_name {
warn!("Ignoring by_name: {by_name}");
}

if let Some(allow_missing_columns) = allow_missing_columns {
warn!("Ignoring allow_missing_columns: {allow_missing_columns}");
}

let mut left = Box::pin(to_logical_plan(*left_input)).await?;
let right = Box::pin(to_logical_plan(*right_input)).await?;

left.psets.partitions.extend(right.psets.partitions);

let is_all = is_all.unwrap_or(false);

let builder = match set_op {
SetOpType::Unspecified => {
bail!("Unspecified set operation is not supported");
}
SetOpType::Intersect => left
.builder
.intersect(&right.builder, is_all)
.wrap_err("Failed to apply intersect to logical plan"),
SetOpType::Union => left
.builder
.union(&right.builder, is_all)
.wrap_err("Failed to apply union to logical plan"),
SetOpType::Except => {
bail!("Except set operation is not supported");
}
}?;

// we merged left and right psets
Ok(Plan {
builder,
psets: left.psets,
})
}
21 changes: 21 additions & 0 deletions tests/connect/test_intersection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from __future__ import annotations


def test_intersection(spark_session):
    """Intersecting two overlapping ranges yields each shared id exactly once."""
    # Overlapping ranges: ids 0..6 and ids 3..9.
    lower = spark_session.range(7)
    upper = spark_session.range(3, 10)

    # intersect() keeps only rows present in both sides, de-duplicated.
    results = lower.intersect(upper).collect()

    assert len(results) == 4, "DataFrame should have 4 rows (overlapping values 3,4,5,6)"

    # The shared ids are exactly 3, 4, 5, 6 — once each.
    observed = sorted(row.id for row in results)
    assert observed == [3, 4, 5, 6], "Values should match expected overlapping sequence"
36 changes: 36 additions & 0 deletions tests/connect/test_union.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from __future__ import annotations


def test_union(spark_session):
    """Unioning two overlapping ranges keeps duplicates (UNION ALL semantics)."""
    # Overlapping ranges: ids 0..6 and ids 3..9.
    first = spark_session.range(7)
    second = spark_session.range(3, 10)

    # Spark's union() concatenates rows without de-duplicating.
    results = first.union(second).collect()

    assert len(results) == 14, "DataFrame should have 14 rows (7 + 7)"

    # Sorted concatenation of both ranges, duplicates 3..6 included.
    expected = sorted(list(range(7)) + list(range(3, 10)))
    observed = sorted(row.id for row in results)
    assert observed == expected, "Values should match expected sequence with duplicates"
Loading