Skip to content

Commit

Permalink
[FEAT] Support intersect as a DataFrame API
Browse files Browse the repository at this point in the history
  • Loading branch information
advancedxy committed Oct 28, 2024
1 parent 8b16405 commit 51da893
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 1 deletion.
67 changes: 66 additions & 1 deletion daft/dataframe/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
Any,
Callable,
Dict,
Generator,
Iterable,
Iterator,
List,
Expand All @@ -36,7 +37,7 @@
from daft.dataframe.preview import DataFramePreview
from daft.datatype import DataType
from daft.errors import ExpressionTypeError
from daft.expressions import Expression, ExpressionsProjection, col, lit
from daft.expressions import Expression, ExpressionsProjection, col, lit, zero_lit
from daft.logical.builder import LogicalPlanBuilder
from daft.runners.partitioning import LocalPartitionSet, PartitionCacheEntry, PartitionSet
from daft.table import MicroPartition
Expand Down Expand Up @@ -2457,6 +2458,70 @@ def pivot(
builder = self._builder.pivot(group_by_expr, pivot_col_expr, value_col_expr, agg_expr, names)
return DataFrame(builder)

@DataframePublicAPI
def intersect(self, other: "DataFrame") -> "DataFrame":
"""Returns the intersection of two DataFrames.
Example:
>>> import daft
>>> df1 = daft.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]})
>>> df2 = daft.from_pydict({"a": [1, 2, 3], "b": [4, 8, 6]})
>>> df1.intersect(df2).collect()
╭───────┬───────╮
│ a ┆ b │
│ --- ┆ --- │
│ Int64 ┆ Int64 │
╞═══════╪═══════╡
│ 1 ┆ 4 │
├╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
│ 3 ┆ 6 │
╰───────┴───────╯
<BLANKLINE>
(Showing first 2 of 2 rows)
Args:
other (DataFrame): DataFrame to intersect with
Returns:
DataFrame: DataFrame with the intersection of the two DataFrames
"""
# TODO: we may relax the schema check to be more flexible in the future
if self.schema() != other.schema():
raise ValueError(
f"DataFrames must have the same schema to intersect, expected: {self.schema()}, got: {other.schema()}"
)
"""
The intersect operation could be rewrote as a semi join operation.
```
SELECT a1, a2 FROM t1 INTERSECT SELECT b1, b2 FROM t2
--->
SELECT distinct a1, a2 FROM t1 LEFT SEMI JOIN t2 ON a1 <> b1 AND a2 <> b2
```
Note: the join condition should be null safe equal, which is effectively the same as
```
a1 <> b1
---> the same as
nvl(a1, zero_val) = nvl(b1, zero_val) AND is_null(a1) = is_null(b1)
```
"""

def to_null_safe_equal_join_keys(schema: "Schema") -> Generator:
field_names = schema.column_names()
for name in field_names:
field = schema[name]
# TODO: expr name should be updated automatically
yield col(name).fill_null(zero_lit(field.dtype)).alias(name + "__with_zero__")
yield col(name).is_null().alias(name + "__is_null__")

semi_join = self.join(
other,
how="semi",
left_on=list(to_null_safe_equal_join_keys(self.schema())),
right_on=list(to_null_safe_equal_join_keys(other.schema())),
)
distinct_table = semi_join.distinct()
return DataFrame(distinct_table._builder)

def _materialize_results(self) -> None:
"""Materializes the results of for this DataFrame and hold a pointer to the results."""
context = get_context()
Expand Down
47 changes: 47 additions & 0 deletions tests/dataframe/test_intersect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from __future__ import annotations

import daft
from daft import col


def test_simple_intersect(make_df):
df1 = make_df({"foo": [1, 2, 3]})
df2 = make_df({"foo": [2, 3, 4]})
result = df1.intersect(df2)
assert result.to_pydict() == {"foo": [2, 3]}


def test_intersect_with_duplicate(make_df):
df1 = make_df({"foo": [1, 2, 2, 3]})
df2 = make_df({"foo": [2, 3, 3]})
result = df1.intersect(df2)
assert result.to_pydict() == {"foo": [2, 3]}


def test_self_intersect(make_df):
df = make_df({"foo": [1, 2, 3]})
result = df.intersect(df).sort(by="foo")
assert result.to_pydict() == {"foo": [1, 2, 3]}


def test_intersect_empty(make_df):
df1 = make_df({"foo": [1, 2, 3]})
df2 = make_df({"foo": []}).select(col("foo").cast(daft.DataType.int64()))
result = df1.intersect(df2)
assert result.to_pydict() == {"foo": []}


def test_intersect_with_nulls(make_df):
df1 = make_df({"foo": [1, 2, None]})
df1_without_mull = make_df({"foo": [1, 2]})
df2 = make_df({"foo": [2, 3, None]})
df2_without_null = make_df({"foo": [2, 3]})

result = df1.intersect(df2)
assert result.to_pydict() == {"foo": [2, None]}

result = df1_without_mull.intersect(df2)
assert result.to_pydict() == {"foo": [2]}

result = df1.intersect(df2_without_null)
assert result.to_pydict() == {"foo": [2]}

0 comments on commit 51da893

Please sign in to comment.