From 6de169c32c1f01037fe403993f09cdeaac6dbda3 Mon Sep 17 00:00:00 2001 From: advancedxy <807537+advancedxy@users.noreply.github.com> Date: Mon, 28 Oct 2024 15:02:23 +0800 Subject: [PATCH 1/2] [FEAT] Support intersect as a DataFrame API --- daft/daft/__init__.pyi | 1 + daft/dataframe/dataframe.py | 30 ++++++ daft/logical/builder.py | 4 + src/daft-logical-plan/src/builder.rs | 11 +++ src/daft-logical-plan/src/ops/mod.rs | 2 + .../src/ops/set_operations.rs | 99 +++++++++++++++++++ tests/dataframe/test_intersect.py | 47 +++++++++ 7 files changed, 194 insertions(+) create mode 100644 src/daft-logical-plan/src/ops/set_operations.rs create mode 100644 tests/dataframe/test_intersect.py diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi index 6dad8b6f56..5c11235a6b 100644 --- a/daft/daft/__init__.pyi +++ b/daft/daft/__init__.pyi @@ -1742,6 +1742,7 @@ class LogicalPlanBuilder: join_suffix: str | None = None, ) -> LogicalPlanBuilder: ... def concat(self, other: LogicalPlanBuilder) -> LogicalPlanBuilder: ... + def intersect(self, other: LogicalPlanBuilder, is_all: bool) -> LogicalPlanBuilder: ... def add_monotonically_increasing_id(self, column_name: str | None) -> LogicalPlanBuilder: ... def table_write( self, diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py index 52f0f7458e..7c9ffee0ea 100644 --- a/daft/dataframe/dataframe.py +++ b/daft/dataframe/dataframe.py @@ -2474,6 +2474,36 @@ def pivot( builder = self._builder.pivot(group_by_expr, pivot_col_expr, value_col_expr, agg_expr, names) return DataFrame(builder) + @DataframePublicAPI + def intersect(self, other: "DataFrame") -> "DataFrame": + """Returns the intersection of two DataFrames. 
+ + Example: + >>> import daft + >>> df1 = daft.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]}) + >>> df2 = daft.from_pydict({"a": [1, 2, 3], "b": [4, 8, 6]}) + >>> df1.intersect(df2).collect() + ╭───────┬───────╮ + │ a ┆ b │ + │ --- ┆ --- │ + │ Int64 ┆ Int64 │ + ╞═══════╪═══════╡ + │ 1 ┆ 4 │ + ├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤ + │ 3 ┆ 6 │ + ╰───────┴───────╯ + + (Showing first 2 of 2 rows) + + Args: + other (DataFrame): DataFrame to intersect with + + Returns: + DataFrame: DataFrame with the intersection of the two DataFrames + """ + builder = self._builder.intersect(other._builder) + return DataFrame(builder) + def _materialize_results(self) -> None: """Materializes the results of for this DataFrame and hold a pointer to the results.""" context = get_context() diff --git a/daft/logical/builder.py b/daft/logical/builder.py index d9354e0801..3d676ebc66 100644 --- a/daft/logical/builder.py +++ b/daft/logical/builder.py @@ -273,6 +273,10 @@ def concat(self, other: LogicalPlanBuilder) -> LogicalPlanBuilder: # type: igno builder = self._builder.concat(other._builder) return LogicalPlanBuilder(builder) + def intersect(self, other: LogicalPlanBuilder) -> LogicalPlanBuilder: + builder = self._builder.intersect(other._builder, False) + return LogicalPlanBuilder(builder) + def add_monotonically_increasing_id(self, column_name: str | None) -> LogicalPlanBuilder: builder = self._builder.add_monotonically_increasing_id(column_name) return LogicalPlanBuilder(builder) diff --git a/src/daft-logical-plan/src/builder.rs b/src/daft-logical-plan/src/builder.rs index 4a8d9d2cfc..9b5228c0c4 100644 --- a/src/daft-logical-plan/src/builder.rs +++ b/src/daft-logical-plan/src/builder.rs @@ -506,6 +506,13 @@ impl LogicalPlanBuilder { Ok(self.with_new_plan(logical_plan)) } + pub fn intersect(&self, other: &Self, is_all: bool) -> DaftResult { + let logical_plan: LogicalPlan = + ops::Intersect::try_new(self.plan.clone(), other.plan.clone(), is_all)? 
+ .to_optimized_join()?; + Ok(self.with_new_plan(logical_plan)) + } + pub fn add_monotonically_increasing_id(&self, column_name: Option<&str>) -> DaftResult { let logical_plan: LogicalPlan = ops::MonotonicallyIncreasingId::new(self.plan.clone(), column_name).into(); @@ -945,6 +952,10 @@ impl PyLogicalPlanBuilder { Ok(self.builder.concat(&other.builder)?.into()) } + pub fn intersect(&self, other: &Self, is_all: bool) -> DaftResult { + Ok(self.builder.intersect(&other.builder, is_all)?.into()) + } + pub fn add_monotonically_increasing_id(&self, column_name: Option<&str>) -> PyResult { Ok(self .builder diff --git a/src/daft-logical-plan/src/ops/mod.rs b/src/daft-logical-plan/src/ops/mod.rs index 339589deea..ec0e47e0e7 100644 --- a/src/daft-logical-plan/src/ops/mod.rs +++ b/src/daft-logical-plan/src/ops/mod.rs @@ -11,6 +11,7 @@ mod pivot; mod project; mod repartition; mod sample; +mod set_operations; mod sink; mod sort; mod source; @@ -29,6 +30,7 @@ pub use pivot::Pivot; pub use project::Project; pub use repartition::Repartition; pub use sample::Sample; +pub use set_operations::Intersect; pub use sink::Sink; pub use sort::Sort; pub use source::Source; diff --git a/src/daft-logical-plan/src/ops/set_operations.rs b/src/daft-logical-plan/src/ops/set_operations.rs new file mode 100644 index 0000000000..1225093fce --- /dev/null +++ b/src/daft-logical-plan/src/ops/set_operations.rs @@ -0,0 +1,99 @@ +use std::sync::Arc; + +use common_error::DaftError; +use daft_core::join::JoinType; +use daft_dsl::col; +use snafu::ResultExt; + +use crate::{logical_plan, logical_plan::CreationSnafu, LogicalPlan}; + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct Intersect { + // Upstream nodes. 
+ pub lhs: Arc, + pub rhs: Arc, + pub is_all: bool, +} + +impl Intersect { + pub(crate) fn try_new( + lhs: Arc, + rhs: Arc, + is_all: bool, + ) -> logical_plan::Result { + let lhs_schema = lhs.schema(); + let rhs_schema = rhs.schema(); + if lhs_schema.len() != rhs_schema.len() { + return Err(DaftError::SchemaMismatch(format!( + "Both plans must have the same num of fields to intersect, \ + but got[lhs: {} v.s rhs: {}], lhs schema: {}, rhs schema: {}", + lhs_schema.len(), + rhs_schema.len(), + lhs_schema, + rhs_schema + ))) + .context(CreationSnafu); + } + // lhs and rhs should have the same type for each field to intersect + if lhs_schema + .fields + .values() + .zip(rhs_schema.fields.values()) + .any(|(l, r)| l.dtype != r.dtype) + { + return Err(DaftError::SchemaMismatch(format!( + "Both plans' schemas should have the same type for each field to intersect, \ + but got lhs schema: {}, rhs schema: {}", + lhs_schema, rhs_schema + ))) + .context(CreationSnafu); + } + Ok(Self { lhs, rhs, is_all }) + } + + /// intersect distinct could be represented as a semi join + distinct + /// the following intersect operator: + /// ```sql + /// select a1, a2 from t1 intersect select b1, b2 from t2 + /// ``` + /// is the same as: + /// ```sql + /// select distinct a1, a2 from t1 left semi join t2 + /// on t1.a1 <=> t2.b1 and t1.a2 <=> t2.b2 + /// ``` + /// TODO: Move this logic to logical optimization rules + pub(crate) fn to_optimized_join(&self) -> logical_plan::Result { + if self.is_all { + Err(logical_plan::Error::CreationError { + source: DaftError::InternalError("intersect all is not supported yet".to_string()), + }) + } else { + let left_on = self + .lhs + .schema() + .fields + .keys() + .map(|k| col(k.clone())) + .collect(); + let right_on = self + .rhs + .schema() + .fields + .keys() + .map(|k| col(k.clone())) + .collect(); + let join = logical_plan::Join::try_new( + self.lhs.clone(), + self.rhs.clone(), + left_on, + right_on, + Some(vec![true; 
self.lhs.schema().fields.len()]), + JoinType::Semi, + None, + None, + None, + ); + join.map(|j| logical_plan::Distinct::new(j.into()).into()) + } + } +} diff --git a/tests/dataframe/test_intersect.py b/tests/dataframe/test_intersect.py new file mode 100644 index 0000000000..24330ec84e --- /dev/null +++ b/tests/dataframe/test_intersect.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +import daft +from daft import col + + +def test_simple_intersect(make_df): + df1 = make_df({"foo": [1, 2, 3]}) + df2 = make_df({"bar": [2, 3, 4]}) + result = df1.intersect(df2) + assert result.to_pydict() == {"foo": [2, 3]} + + +def test_intersect_with_duplicate(make_df): + df1 = make_df({"foo": [1, 2, 2, 3]}) + df2 = make_df({"bar": [2, 3, 3]}) + result = df1.intersect(df2) + assert result.to_pydict() == {"foo": [2, 3]} + + +def test_self_intersect(make_df): + df = make_df({"foo": [1, 2, 3]}) + result = df.intersect(df).sort(by="foo") + assert result.to_pydict() == {"foo": [1, 2, 3]} + + +def test_intersect_empty(make_df): + df1 = make_df({"foo": [1, 2, 3]}) + df2 = make_df({"bar": []}).select(col("bar").cast(daft.DataType.int64())) + result = df1.intersect(df2) + assert result.to_pydict() == {"foo": []} + + +def test_intersect_with_nulls(make_df): + df1 = make_df({"foo": [1, 2, None]}) + df1_without_mull = make_df({"foo": [1, 2]}) + df2 = make_df({"bar": [2, 3, None]}) + df2_without_null = make_df({"bar": [2, 3]}) + + result = df1.intersect(df2) + assert result.to_pydict() == {"foo": [2, None]} + + result = df1_without_mull.intersect(df2) + assert result.to_pydict() == {"foo": [2]} + + result = df1.intersect(df2_without_null) + assert result.to_pydict() == {"foo": [2]} From dcddfb5db37588fcf98abe6c93a6bda7e0f47dec Mon Sep 17 00:00:00 2001 From: advancedxy <807537+advancedxy@users.noreply.github.com> Date: Mon, 4 Nov 2024 11:08:37 +0800 Subject: [PATCH 2/2] [CHORE] Make Intersect a LogicalPlan node --- src/daft-logical-plan/src/logical_plan.rs | 9 +++++++++ 
src/daft-logical-plan/src/ops/set_operations.rs | 10 ++++++++++ .../src/optimization/rules/push_down_projection.rs | 5 +++++ .../src/physical_planner/translate.rs | 3 +++ 4 files changed, 27 insertions(+) diff --git a/src/daft-logical-plan/src/logical_plan.rs b/src/daft-logical-plan/src/logical_plan.rs index 6c5bce7c66..4e2f270ce3 100644 --- a/src/daft-logical-plan/src/logical_plan.rs +++ b/src/daft-logical-plan/src/logical_plan.rs @@ -25,6 +25,7 @@ pub enum LogicalPlan { Aggregate(Aggregate), Pivot(Pivot), Concat(Concat), + Intersect(Intersect), Join(Join), Sink(Sink), Sample(Sample), @@ -58,6 +59,7 @@ impl LogicalPlan { Self::Aggregate(Aggregate { output_schema, .. }) => output_schema.clone(), Self::Pivot(Pivot { output_schema, .. }) => output_schema.clone(), Self::Concat(Concat { input, .. }) => input.schema(), + Self::Intersect(Intersect { lhs, .. }) => lhs.schema(), Self::Join(Join { output_schema, .. }) => output_schema.clone(), Self::Sink(Sink { schema, .. }) => schema.clone(), Self::Sample(Sample { input, .. }) => input.schema(), @@ -162,6 +164,7 @@ impl LogicalPlan { .collect(); vec![left, right] } + Self::Intersect(_) => vec![IndexSet::new(), IndexSet::new()], Self::Source(_) => todo!(), Self::Sink(_) => todo!(), } @@ -183,6 +186,7 @@ impl LogicalPlan { Self::Pivot(..) => "Pivot", Self::Concat(..) => "Concat", Self::Join(..) => "Join", + Self::Intersect(..) => "Intersect", Self::Sink(..) => "Sink", Self::Sample(..) => "Sample", Self::MonotonicallyIncreasingId(..) 
=> "MonotonicallyIncreasingId", @@ -205,6 +209,7 @@ impl LogicalPlan { Self::Aggregate(aggregate) => aggregate.multiline_display(), Self::Pivot(pivot) => pivot.multiline_display(), Self::Concat(_) => vec!["Concat".to_string()], + Self::Intersect(inner) => inner.multiline_display(), Self::Join(join) => join.multiline_display(), Self::Sink(sink) => sink.multiline_display(), Self::Sample(sample) => { @@ -231,6 +236,7 @@ impl LogicalPlan { Self::Concat(Concat { input, other }) => vec![input, other], Self::Join(Join { left, right, .. }) => vec![left, right], Self::Sink(Sink { input, .. }) => vec![input], + Self::Intersect(Intersect { lhs, rhs, .. }) => vec![lhs, rhs], Self::Sample(Sample { input, .. }) => vec![input], Self::MonotonicallyIncreasingId(MonotonicallyIncreasingId { input, .. }) => { vec![input] @@ -259,11 +265,13 @@ impl LogicalPlan { Self::Unpivot(Unpivot {ids, values, variable_name, value_name, output_schema, ..}) => Self::Unpivot(Unpivot { input: input.clone(), ids: ids.clone(), values: values.clone(), variable_name: variable_name.clone(), value_name: value_name.clone(), output_schema: output_schema.clone() }), Self::Sample(Sample {fraction, with_replacement, seed, ..}) => Self::Sample(Sample::new(input.clone(), *fraction, *with_replacement, *seed)), Self::Concat(_) => panic!("Concat ops should never have only one input, but got one"), + Self::Intersect(_) => panic!("Intersect ops should never have only one input, but got one"), Self::Join(_) => panic!("Join ops should never have only one input, but got one"), }, [input1, input2] => match self { Self::Source(_) => panic!("Source nodes don't have children, with_new_children() should never be called for Source ops"), Self::Concat(_) => Self::Concat(Concat::try_new(input1.clone(), input2.clone()).unwrap()), + Self::Intersect(inner) => Self::Intersect(Intersect::try_new(input1.clone(), input2.clone(), inner.is_all).unwrap()), Self::Join(Join { left_on, right_on, null_equals_nulls, join_type, join_strategy, .. 
}) => Self::Join(Join::try_new( input1.clone(), input2.clone(), @@ -360,6 +368,7 @@ impl_from_data_struct_for_logical_plan!(Distinct); impl_from_data_struct_for_logical_plan!(Aggregate); impl_from_data_struct_for_logical_plan!(Pivot); impl_from_data_struct_for_logical_plan!(Concat); +impl_from_data_struct_for_logical_plan!(Intersect); impl_from_data_struct_for_logical_plan!(Join); impl_from_data_struct_for_logical_plan!(Sink); impl_from_data_struct_for_logical_plan!(Sample); diff --git a/src/daft-logical-plan/src/ops/set_operations.rs b/src/daft-logical-plan/src/ops/set_operations.rs index 1225093fce..9db43f56b2 100644 --- a/src/daft-logical-plan/src/ops/set_operations.rs +++ b/src/daft-logical-plan/src/ops/set_operations.rs @@ -96,4 +96,14 @@ impl Intersect { join.map(|j| logical_plan::Distinct::new(j.into()).into()) } } + + pub fn multiline_display(&self) -> Vec { + let mut res = vec![]; + if self.is_all { + res.push("Intersect All:".to_string()); + } else { + res.push("Intersect:".to_string()); + } + res + } } diff --git a/src/daft-logical-plan/src/optimization/rules/push_down_projection.rs b/src/daft-logical-plan/src/optimization/rules/push_down_projection.rs index be3abef5f2..749c39e620 100644 --- a/src/daft-logical-plan/src/optimization/rules/push_down_projection.rs +++ b/src/daft-logical-plan/src/optimization/rules/push_down_projection.rs @@ -489,6 +489,11 @@ impl PushDownProjection { // since Distinct implicitly requires all parent columns. Ok(Transformed::no(plan)) } + LogicalPlan::Intersect(_) => { + // Cannot push down past an Intersect, + // since Intersect implicitly requires all parent columns. + Ok(Transformed::no(plan)) + } LogicalPlan::Pivot(_) | LogicalPlan::MonotonicallyIncreasingId(_) => { // Cannot push down past a Pivot/MonotonicallyIncreasingId because it changes the schema. 
Ok(Transformed::no(plan)) diff --git a/src/daft-physical-plan/src/physical_planner/translate.rs b/src/daft-physical-plan/src/physical_planner/translate.rs index 56d29b154e..1a097c0871 100644 --- a/src/daft-physical-plan/src/physical_planner/translate.rs +++ b/src/daft-physical-plan/src/physical_planner/translate.rs @@ -737,6 +737,9 @@ pub(super) fn translate_single_logical_node( .arced(), ) } + LogicalPlan::Intersect(_) => Err(DaftError::InternalError( + "Intersect should already be optimized away".to_string(), + )), } }