Skip to content

Commit

Permalink
[FEAT] connect: add df.{intersection,union}
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewgazelka committed Dec 5, 2024
1 parent 9739bb6 commit 8efe7d7
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/daft-connect/src/translation/logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@ use tracing::warn;

use crate::translation::logical_plan::{
aggregate::aggregate, local_relation::local_relation, project::project, range::range,
to_df::to_df, with_columns::with_columns,
set_op::set_op, to_df::to_df, with_columns::with_columns,
};

mod aggregate;
mod local_relation;
mod project;
mod range;
mod set_op;
mod to_df;
mod with_columns;

Expand Down Expand Up @@ -59,6 +60,7 @@ pub fn to_logical_plan(relation: Relation) -> eyre::Result<Plan> {
RelType::LocalRelation(l) => {
local_relation(l).wrap_err("Failed to apply local_relation to logical plan")
}
RelType::SetOp(s) => set_op(*s).wrap_err("Failed to apply set_op to logical plan"),
plan => bail!("Unsupported relation type: {plan:?}"),
}
}
Expand Down
62 changes: 62 additions & 0 deletions src/daft-connect/src/translation/logical_plan/set_op.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
use eyre::{bail, Context};
use spark_connect::set_operation::SetOpType;
use tracing::warn;

use crate::translation::{to_logical_plan, Plan};

pub fn set_op(set_op: spark_connect::SetOperation) -> eyre::Result<Plan> {
let spark_connect::SetOperation {
left_input,
right_input,
set_op_type,
is_all,
by_name,
allow_missing_columns,
} = set_op;

let Some(left_input) = left_input else {
bail!("Left input is required");

Check warning on line 18 in src/daft-connect/src/translation/logical_plan/set_op.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/logical_plan/set_op.rs#L18

Added line #L18 was not covered by tests
};

let Some(right_input) = right_input else {
bail!("Right input is required");

Check warning on line 22 in src/daft-connect/src/translation/logical_plan/set_op.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/logical_plan/set_op.rs#L22

Added line #L22 was not covered by tests
};

let set_op = SetOpType::try_from(set_op_type)
.wrap_err_with(|| format!("Invalid set operation type: {set_op_type}"))?;

if let Some(by_name) = by_name {
warn!("Ignoring by_name: {by_name}");
}

Check warning on line 30 in src/daft-connect/src/translation/logical_plan/set_op.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/logical_plan/set_op.rs#L30

Added line #L30 was not covered by tests

if let Some(allow_missing_columns) = allow_missing_columns {
warn!("Ignoring allow_missing_columns: {allow_missing_columns}");
}

Check warning on line 34 in src/daft-connect/src/translation/logical_plan/set_op.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/logical_plan/set_op.rs#L34

Added line #L34 was not covered by tests

let mut left = to_logical_plan(*left_input)?;
let right = to_logical_plan(*right_input)?;

left.psets.extend(right.psets);

let is_all = is_all.unwrap_or(false);

let builder = match set_op {
SetOpType::Unspecified => {
bail!("Unspecified set operation is not supported");

Check warning on line 45 in src/daft-connect/src/translation/logical_plan/set_op.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/logical_plan/set_op.rs#L45

Added line #L45 was not covered by tests
}
SetOpType::Intersect => left
.builder
.intersect(&right.builder, is_all)
.wrap_err("Failed to apply intersect to logical plan"),
SetOpType::Union => left
.builder
.union(&right.builder, is_all)
.wrap_err("Failed to apply union to logical plan"),
SetOpType::Except => {
bail!("Except set operation is not supported");

Check warning on line 56 in src/daft-connect/src/translation/logical_plan/set_op.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/logical_plan/set_op.rs#L56

Added line #L56 was not covered by tests
}
}?;

Check warning on line 58 in src/daft-connect/src/translation/logical_plan/set_op.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/logical_plan/set_op.rs#L58

Added line #L58 was not covered by tests

// we merged left and right psets
Ok(Plan::new(builder, left.psets))
}
21 changes: 21 additions & 0 deletions tests/connect/test_intersection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from __future__ import annotations


def test_intersection(spark_session):
# Create ranges using Spark - with overlap
range1 = spark_session.range(7) # Creates DataFrame with numbers 0 to 6
range2 = spark_session.range(3, 10) # Creates DataFrame with numbers 3 to 9

# Intersect the two ranges
intersected = range1.intersect(range2)

# Collect results
results = intersected.collect()

# Verify the DataFrame has expected values
# Intersection should only include overlapping values once
assert len(results) == 4, "DataFrame should have 4 rows (overlapping values 3,4,5,6)"

# Check that all expected values are present
values = [row.id for row in results]
assert sorted(values) == [3, 4, 5, 6], "Values should match expected overlapping sequence"
36 changes: 36 additions & 0 deletions tests/connect/test_union.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from __future__ import annotations


def test_union(spark_session):
# Create ranges using Spark - with overlap
range1 = spark_session.range(7) # Creates DataFrame with numbers 0 to 6
range2 = spark_session.range(3, 10) # Creates DataFrame with numbers 3 to 9

# Union the two ranges
unioned = range1.union(range2)

# Collect results
results = unioned.collect()

# Verify the DataFrame has expected values
# Union includes duplicates, so length should be sum of both ranges
assert len(results) == 14, "DataFrame should have 14 rows (7 + 7)"

# Check that all expected values are present, including duplicates
values = [row.id for row in results]
assert sorted(values) == [
0,
1,
2,
3,
3,
4,
4,
5,
5,
6,
6,
7,
8,
9,
], "Values should match expected sequence with duplicates"

0 comments on commit 8efe7d7

Please sign in to comment.