[FEAT] Support intersect as a DataFrame API #3134

Merged 2 commits on Nov 12, 2024
1 change: 1 addition & 0 deletions daft/daft/__init__.pyi
@@ -1742,6 +1742,7 @@ class LogicalPlanBuilder:
join_suffix: str | None = None,
) -> LogicalPlanBuilder: ...
def concat(self, other: LogicalPlanBuilder) -> LogicalPlanBuilder: ...
def intersect(self, other: LogicalPlanBuilder, is_all: bool) -> LogicalPlanBuilder: ...
def add_monotonically_increasing_id(self, column_name: str | None) -> LogicalPlanBuilder: ...
def table_write(
self,
30 changes: 30 additions & 0 deletions daft/dataframe/dataframe.py
@@ -2474,6 +2474,36 @@ def pivot(
builder = self._builder.pivot(group_by_expr, pivot_col_expr, value_col_expr, agg_expr, names)
return DataFrame(builder)

@DataframePublicAPI
def intersect(self, other: "DataFrame") -> "DataFrame":
"""Returns the intersection of two DataFrames.

Example:
>>> import daft
>>> df1 = daft.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]})
>>> df2 = daft.from_pydict({"a": [1, 2, 3], "b": [4, 8, 6]})
>>> df1.intersect(df2).collect()
╭───────┬───────╮
│ a ┆ b │
│ --- ┆ --- │
│ Int64 ┆ Int64 │
╞═══════╪═══════╡
│ 1 ┆ 4 │
├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ 3 ┆ 6 │
╰───────┴───────╯
<BLANKLINE>
(Showing first 2 of 2 rows)

Args:
other (DataFrame): DataFrame to intersect with

Returns:
DataFrame: DataFrame with the intersection of the two DataFrames
"""
builder = self._builder.intersect(other._builder)
return DataFrame(builder)

def _materialize_results(self) -> None:
        """Materializes the results for this DataFrame and holds a pointer to the results."""
context = get_context()
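Worth noting about the `DataFrame.intersect` API above: columns are matched positionally (not by name), the result keeps the left-hand column names, and duplicates are dropped (set semantics). A minimal plain-Python sketch of these semantics — an illustration only, not the actual daft implementation; `intersect_rows` is a hypothetical helper:

```python
# Hypothetical model of DataFrame.intersect's distinct semantics:
# match columns by position, keep the left-hand names, drop duplicates.
def intersect_rows(lhs: dict, rhs: dict) -> dict:
    lhs_rows = list(zip(*lhs.values()))  # rows of the left table, in order
    rhs_rows = set(zip(*rhs.values()))   # rows of the right table, as a set
    seen, out = set(), []
    for row in lhs_rows:
        # keep a row if it appears on both sides and hasn't been emitted yet
        if row in rhs_rows and row not in seen:
            seen.add(row)
            out.append(row)
    cols = list(lhs.keys())              # result keeps the lhs column names
    return {c: [row[i] for row in out] for i, c in enumerate(cols)}

print(intersect_rows({"foo": [1, 2, 2, 3]}, {"bar": [2, 3, 3]}))
# {'foo': [2, 3]}
```

Python's `None == None` being true also mirrors the null-safe matching that the PR's null-handling test exercises.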
4 changes: 4 additions & 0 deletions daft/logical/builder.py
@@ -273,6 +273,10 @@ def concat(self, other: LogicalPlanBuilder) -> LogicalPlanBuilder:  # type: ignore
builder = self._builder.concat(other._builder)
return LogicalPlanBuilder(builder)

def intersect(self, other: LogicalPlanBuilder) -> LogicalPlanBuilder:
builder = self._builder.intersect(other._builder, False)
return LogicalPlanBuilder(builder)

def add_monotonically_increasing_id(self, column_name: str | None) -> LogicalPlanBuilder:
builder = self._builder.add_monotonically_increasing_id(column_name)
return LogicalPlanBuilder(builder)
11 changes: 11 additions & 0 deletions src/daft-logical-plan/src/builder.rs
@@ -506,6 +506,13 @@ impl LogicalPlanBuilder {
Ok(self.with_new_plan(logical_plan))
}

pub fn intersect(&self, other: &Self, is_all: bool) -> DaftResult<Self> {
let logical_plan: LogicalPlan =
ops::Intersect::try_new(self.plan.clone(), other.plan.clone(), is_all)?
.to_optimized_join()?;
Ok(self.with_new_plan(logical_plan))
}
Comment on lines +509 to +514

universalmind303 (Contributor) commented on Nov 8, 2024:

Instead of immediately converting it to a join, I think we should defer this until the translate stage.

If we want to have a concept of a logical Intersect as discussed here, then I don't think we should immediately convert it to a logical join, but only convert it once, to a physical join, during the logical -> physical translation.

The PR author (Contributor) replied:

> I think we should instead defer this until the translate stage.

That's definitely an option, and DuckDB does exactly that: https://github.com/duckdb/duckdb/blob/a2dce8b1c9fa6039c82e9a32bfcc4c49b03ca871/src/execution/physical_plan/plan_set_operation.cpp#L53. The problem, as you stated in #3241 (comment), is that it makes the optimization rules more complex: many rules would have to be amended to be set-operation aware.

On the other hand, Spark doesn't defer this to the translate stage; its intersect/except operators are rewritten during the optimization phase, which keeps that phase simpler (at least for set-operation handling).

From my perspective, I agree that we should convert set operations to joins early to avoid complicating the optimization phase for now. I just want to decouple the logic into a separate func/struct and make sure it's extensible for the long-term plan. I can drop these operations if that doesn't sound right to you. WDYT?

universalmind303 (Contributor) replied:

OK, let's go with what's here now. We can always revisit if needed 😅

pub fn add_monotonically_increasing_id(&self, column_name: Option<&str>) -> DaftResult<Self> {
let logical_plan: LogicalPlan =
ops::MonotonicallyIncreasingId::new(self.plan.clone(), column_name).into();
@@ -945,6 +952,10 @@ impl PyLogicalPlanBuilder {
Ok(self.builder.concat(&other.builder)?.into())
}

pub fn intersect(&self, other: &Self, is_all: bool) -> DaftResult<Self> {
Ok(self.builder.intersect(&other.builder, is_all)?.into())
}

pub fn add_monotonically_increasing_id(&self, column_name: Option<&str>) -> PyResult<Self> {
Ok(self
.builder
9 changes: 9 additions & 0 deletions src/daft-logical-plan/src/logical_plan.rs
@@ -25,6 +25,7 @@ pub enum LogicalPlan {
Aggregate(Aggregate),
Pivot(Pivot),
Concat(Concat),
Intersect(Intersect),
Join(Join),
Sink(Sink),
Sample(Sample),
@@ -58,6 +59,7 @@ impl LogicalPlan {
Self::Aggregate(Aggregate { output_schema, .. }) => output_schema.clone(),
Self::Pivot(Pivot { output_schema, .. }) => output_schema.clone(),
Self::Concat(Concat { input, .. }) => input.schema(),
Self::Intersect(Intersect { lhs, .. }) => lhs.schema(),
Self::Join(Join { output_schema, .. }) => output_schema.clone(),
Self::Sink(Sink { schema, .. }) => schema.clone(),
Self::Sample(Sample { input, .. }) => input.schema(),
@@ -162,6 +164,7 @@ impl LogicalPlan {
.collect();
vec![left, right]
}
Self::Intersect(_) => vec![IndexSet::new(), IndexSet::new()],
Self::Source(_) => todo!(),
Self::Sink(_) => todo!(),
}
@@ -183,6 +186,7 @@
Self::Pivot(..) => "Pivot",
Self::Concat(..) => "Concat",
Self::Join(..) => "Join",
Self::Intersect(..) => "Intersect",
Self::Sink(..) => "Sink",
Self::Sample(..) => "Sample",
Self::MonotonicallyIncreasingId(..) => "MonotonicallyIncreasingId",
@@ -205,6 +209,7 @@ impl LogicalPlan {
Self::Aggregate(aggregate) => aggregate.multiline_display(),
Self::Pivot(pivot) => pivot.multiline_display(),
Self::Concat(_) => vec!["Concat".to_string()],
Self::Intersect(inner) => inner.multiline_display(),
Self::Join(join) => join.multiline_display(),
Self::Sink(sink) => sink.multiline_display(),
Self::Sample(sample) => {
@@ -231,6 +236,7 @@
Self::Concat(Concat { input, other }) => vec![input, other],
Self::Join(Join { left, right, .. }) => vec![left, right],
Self::Sink(Sink { input, .. }) => vec![input],
Self::Intersect(Intersect { lhs, rhs, .. }) => vec![lhs, rhs],
Self::Sample(Sample { input, .. }) => vec![input],
Self::MonotonicallyIncreasingId(MonotonicallyIncreasingId { input, .. }) => {
vec![input]
@@ -259,11 +265,13 @@ impl LogicalPlan {
Self::Unpivot(Unpivot {ids, values, variable_name, value_name, output_schema, ..}) => Self::Unpivot(Unpivot { input: input.clone(), ids: ids.clone(), values: values.clone(), variable_name: variable_name.clone(), value_name: value_name.clone(), output_schema: output_schema.clone() }),
Self::Sample(Sample {fraction, with_replacement, seed, ..}) => Self::Sample(Sample::new(input.clone(), *fraction, *with_replacement, *seed)),
Self::Concat(_) => panic!("Concat ops should never have only one input, but got one"),
Self::Intersect(_) => panic!("Intersect ops should never have only one input, but got one"),
Self::Join(_) => panic!("Join ops should never have only one input, but got one"),
},
[input1, input2] => match self {
Self::Source(_) => panic!("Source nodes don't have children, with_new_children() should never be called for Source ops"),
Self::Concat(_) => Self::Concat(Concat::try_new(input1.clone(), input2.clone()).unwrap()),
Self::Intersect(inner) => Self::Intersect(Intersect::try_new(input1.clone(), input2.clone(), inner.is_all).unwrap()),
Self::Join(Join { left_on, right_on, null_equals_nulls, join_type, join_strategy, .. }) => Self::Join(Join::try_new(
input1.clone(),
input2.clone(),
@@ -360,6 +368,7 @@ impl_from_data_struct_for_logical_plan!(Distinct);
impl_from_data_struct_for_logical_plan!(Aggregate);
impl_from_data_struct_for_logical_plan!(Pivot);
impl_from_data_struct_for_logical_plan!(Concat);
impl_from_data_struct_for_logical_plan!(Intersect);
impl_from_data_struct_for_logical_plan!(Join);
impl_from_data_struct_for_logical_plan!(Sink);
impl_from_data_struct_for_logical_plan!(Sample);
2 changes: 2 additions & 0 deletions src/daft-logical-plan/src/ops/mod.rs
@@ -11,6 +11,7 @@ mod pivot;
mod project;
mod repartition;
mod sample;
mod set_operations;
mod sink;
mod sort;
mod source;
@@ -29,6 +30,7 @@ pub use pivot::Pivot;
pub use project::Project;
pub use repartition::Repartition;
pub use sample::Sample;
pub use set_operations::Intersect;
pub use sink::Sink;
pub use sort::Sort;
pub use source::Source;
109 changes: 109 additions & 0 deletions src/daft-logical-plan/src/ops/set_operations.rs
@@ -0,0 +1,109 @@
use std::sync::Arc;

use common_error::DaftError;
use daft_core::join::JoinType;
use daft_dsl::col;
use snafu::ResultExt;

use crate::{logical_plan, logical_plan::CreationSnafu, LogicalPlan};

#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Intersect {
// Upstream nodes.
pub lhs: Arc<LogicalPlan>,
pub rhs: Arc<LogicalPlan>,
pub is_all: bool,
}

impl Intersect {
pub(crate) fn try_new(
lhs: Arc<LogicalPlan>,
rhs: Arc<LogicalPlan>,
is_all: bool,
) -> logical_plan::Result<Self> {
let lhs_schema = lhs.schema();
let rhs_schema = rhs.schema();
if lhs_schema.len() != rhs_schema.len() {
return Err(DaftError::SchemaMismatch(format!(
                "Both plans must have the same number of fields to intersect, \
                but got [lhs: {} vs. rhs: {}], lhs schema: {}, rhs schema: {}",
lhs_schema.len(),
rhs_schema.len(),
lhs_schema,
rhs_schema
)))
.context(CreationSnafu);
}
// lhs and rhs should have the same type for each field to intersect
if lhs_schema
.fields
.values()
.zip(rhs_schema.fields.values())
.any(|(l, r)| l.dtype != r.dtype)
{
return Err(DaftError::SchemaMismatch(format!(
"Both plans' schemas should have the same type for each field to intersect, \
but got lhs schema: {}, rhs schema: {}",
lhs_schema, rhs_schema
)))
.context(CreationSnafu);
}
Ok(Self { lhs, rhs, is_all })
}

    /// intersect distinct can be represented as a semi join + distinct:
    /// the following intersect operator:
    /// ```sql
    /// select a1, a2 from t1 intersect select b1, b2 from t2
    /// ```
    /// is the same as:
    /// ```sql
    /// select distinct a1, a2 from t1 left semi join t2
    /// on t1.a1 = t2.b1 and t1.a2 = t2.b2
    /// ```
    /// (with null-safe equality on the join keys)
    /// TODO: Move this logic to the logical optimization rules
pub(crate) fn to_optimized_join(&self) -> logical_plan::Result<LogicalPlan> {
if self.is_all {
Err(logical_plan::Error::CreationError {
source: DaftError::InternalError("intersect all is not supported yet".to_string()),
})
} else {
let left_on = self
.lhs
.schema()
.fields
.keys()
.map(|k| col(k.clone()))
.collect();
let right_on = self
.rhs
.schema()
.fields
.keys()
.map(|k| col(k.clone()))
.collect();
let join = logical_plan::Join::try_new(
self.lhs.clone(),
self.rhs.clone(),
left_on,
right_on,
Some(vec![true; self.lhs.schema().fields.len()]),
JoinType::Semi,
None,
None,
None,
);
join.map(|j| logical_plan::Distinct::new(j.into()).into())
}
}

pub fn multiline_display(&self) -> Vec<String> {
let mut res = vec![];
if self.is_all {
res.push("Intersect All:".to_string());
} else {
res.push("Intersect:".to_string());
}
res
}
}
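The equivalence claimed in the doc comment above — intersect distinct equals a left semi join on every column followed by distinct — can be spot-checked with a small plain-Python model. This is a sketch, not daft code, under the assumption that null-safe equality behaves like Python's `None == None`:

```python
def intersect_distinct(lhs, rhs):
    # Naive INTERSECT (distinct): rows present on both sides,
    # deduplicated while preserving first-occurrence order.
    rhs_set = set(rhs)
    return list(dict.fromkeys(r for r in lhs if r in rhs_set))

def semi_join_then_distinct(lhs, rhs):
    rhs_set = set(rhs)
    joined = [r for r in lhs if r in rhs_set]  # left semi join on all columns
    return list(dict.fromkeys(joined))         # distinct

t1 = [(1, 4), (2, 5), (2, 5), (3, 6), (None, 7)]
t2 = [(2, 5), (3, 6), (None, 7), (8, 9)]
assert intersect_distinct(t1, t2) == semi_join_then_distinct(t1, t2)
```

The `None`-containing row matching on both sides mirrors the `null_equals_nulls: Some(vec![true; ...])` argument passed to `Join::try_new` in the rewrite.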
@@ -489,6 +489,11 @@ impl PushDownProjection
// since Distinct implicitly requires all parent columns.
Ok(Transformed::no(plan))
}
LogicalPlan::Intersect(_) => {
// Cannot push down past an Intersect,
// since Intersect implicitly requires all parent columns.
Ok(Transformed::no(plan))
}
LogicalPlan::Pivot(_) | LogicalPlan::MonotonicallyIncreasingId(_) => {
// Cannot push down past a Pivot/MonotonicallyIncreasingId because it changes the schema.
Ok(Transformed::no(plan))
3 changes: 3 additions & 0 deletions src/daft-physical-plan/src/physical_planner/translate.rs
@@ -737,6 +737,9 @@ pub(super) fn translate_single_logical_node(
.arced(),
)
}
LogicalPlan::Intersect(_) => Err(DaftError::InternalError(
"Intersect should already be optimized away".to_string(),
)),
}
}

47 changes: 47 additions & 0 deletions tests/dataframe/test_intersect.py
@@ -0,0 +1,47 @@
from __future__ import annotations

import daft
from daft import col


def test_simple_intersect(make_df):
df1 = make_df({"foo": [1, 2, 3]})
df2 = make_df({"bar": [2, 3, 4]})
result = df1.intersect(df2)
assert result.to_pydict() == {"foo": [2, 3]}


def test_intersect_with_duplicate(make_df):
df1 = make_df({"foo": [1, 2, 2, 3]})
df2 = make_df({"bar": [2, 3, 3]})
result = df1.intersect(df2)
assert result.to_pydict() == {"foo": [2, 3]}


def test_self_intersect(make_df):
df = make_df({"foo": [1, 2, 3]})
result = df.intersect(df).sort(by="foo")
assert result.to_pydict() == {"foo": [1, 2, 3]}


def test_intersect_empty(make_df):
df1 = make_df({"foo": [1, 2, 3]})
df2 = make_df({"bar": []}).select(col("bar").cast(daft.DataType.int64()))
result = df1.intersect(df2)
assert result.to_pydict() == {"foo": []}


def test_intersect_with_nulls(make_df):
    df1 = make_df({"foo": [1, 2, None]})
    df1_without_null = make_df({"foo": [1, 2]})
    df2 = make_df({"bar": [2, 3, None]})
    df2_without_null = make_df({"bar": [2, 3]})

    result = df1.intersect(df2)
    assert result.to_pydict() == {"foo": [2, None]}

    result = df1_without_null.intersect(df2)
    assert result.to_pydict() == {"foo": [2]}

    result = df1.intersect(df2_without_null)
    assert result.to_pydict() == {"foo": [2]}