Skip to content

Commit

Permalink
[BUG] Fix anti-join on different column names (#2477)
Browse files Browse the repository at this point in the history
Closes #2475.

A bit of a bandaid fix. The issue was, while selecting the join
strategy, for checking whether sort-merge join works, the joined columns
need to be primitive types. However, in the plan we only have access to
the output schema (as far as I know, maybe we can extract it from the
children?). In a semi- or anti-join, the right column does not end up in
the output schema, so the primitive type check panicked while looking
for it. This issue does not appear if the left column has the same name
as the right.

My fix is that, since we only support SMJ for inner joins, I move the
join type check before the primitive check. This will need to be fixed
later if we support SMJ for anti-joins, so I added a comment.
  • Loading branch information
Vince7778 authored Jul 8, 2024
1 parent e5ff7d7 commit 55fc032
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 2 deletions.
6 changes: 4 additions & 2 deletions src/daft-plan/src/physical_planner/translate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,8 @@ pub(super) fn translate_single_logical_node(
is_right_hash_partitioned || is_right_sort_partitioned
};
let join_strategy = join_strategy.unwrap_or_else(|| {
// This method will panic if called with columns that aren't in the output schema,
// which is possible for anti- and semi-joins.
let is_primitive = |exprs: &Vec<ExprRef>| {
exprs.iter().map(|e| e.name()).all(|col| {
let dtype = &output_schema.get_field(col).unwrap().dtype;
Expand All @@ -529,10 +531,10 @@ pub(super) fn translate_single_logical_node(
// TODO(Clark): Also do a sort-merge join if a downstream op needs the table to be sorted on the join key.
// TODO(Clark): Look into defaulting to sort-merge join over hash join under more input partitioning setups.
// TODO(Kevin): Support sort-merge join for other types of joins.
} else if is_primitive(left_on)
} else if *join_type == JoinType::Inner
&& is_primitive(left_on)
&& is_primitive(right_on)
&& (is_left_sort_partitioned || is_right_sort_partitioned)
&& *join_type == JoinType::Inner
&& (!is_larger_partitioned
|| (left_is_larger && is_left_sort_partitioned
|| !left_is_larger && is_right_sort_partitioned))
Expand Down
59 changes: 59 additions & 0 deletions tests/dataframe/test_joins.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,3 +777,62 @@ def test_join_semi_anti(join_strategy, join_type, expected, make_df, repartition
assert sort_arrow_table(pa.Table.from_pydict(daft_df.to_pydict()), "id") == sort_arrow_table(
pa.Table.from_pydict(expected), "id"
)


@pytest.mark.parametrize("repartition_nparts", [1, 2, 4])
@pytest.mark.parametrize(
"join_strategy",
[None, "hash", "sort_merge", "sort_merge_aligned_boundaries", "broadcast"],
indirect=True,
)
@pytest.mark.parametrize(
"join_type,expected",
[
(
"semi",
{
"id_left": [2, 3],
"values_left": ["b1", "c1"],
},
),
(
"anti",
{
"id_left": [1, None],
"values_left": ["a1", "d1"],
},
),
],
)
def test_join_semi_anti_different_names(join_strategy, join_type, expected, make_df, repartition_nparts):
skip_invalid_join_strategies(join_strategy, join_type)

daft_df1 = make_df(
{
"id_left": [1, 2, 3, None],
"values_left": ["a1", "b1", "c1", "d1"],
},
repartition=repartition_nparts,
)
daft_df2 = make_df(
{
"id_right": [2, 2, 3, 4],
"values_right": ["a2", "b2", "c2", "d2"],
},
repartition=repartition_nparts,
)
daft_df = (
daft_df1.with_column("id_left", daft_df1["id_left"].cast(DataType.int64()))
.join(
daft_df2,
left_on="id_left",
right_on="id_right",
how=join_type,
strategy=join_strategy,
)
.sort(["id_left", "values_left"])
).select("id_left", "values_left")

assert sort_arrow_table(pa.Table.from_pydict(daft_df.to_pydict()), "id_left") == sort_arrow_table(
pa.Table.from_pydict(expected), "id_left"
)
20 changes: 20 additions & 0 deletions tests/table/test_joins.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,3 +282,23 @@ def test_table_join_single_column_name_null(join_impl) -> None:
result_sorted = result_table.sort([col("x")])
assert result_sorted.get_column("y").to_pylist() == []
assert result_sorted.get_column("right.y").to_pylist() == []


def test_table_join_anti() -> None:
left_table = MicroPartition.from_pydict({"x": [1, 2, 3, 4], "y": [3, 4, 5, 6]})
right_table = MicroPartition.from_pydict({"x": [2, 3, 5]})

result_table = left_table.hash_join(right_table, left_on=[col("x")], right_on=[col("x")], how=JoinType.Anti)
assert result_table.column_names() == ["x", "y"]
result_sorted = result_table.sort([col("x")])
assert result_sorted.get_column("y").to_pylist() == [3, 6]


def test_table_join_anti_different_names() -> None:
left_table = MicroPartition.from_pydict({"x": [1, 2, 3, 4], "y": [3, 4, 5, 6]})
right_table = MicroPartition.from_pydict({"z": [2, 3, 5]})

result_table = left_table.hash_join(right_table, left_on=[col("x")], right_on=[col("z")], how=JoinType.Anti)
assert result_table.column_names() == ["x", "y"]
result_sorted = result_table.sort([col("x")])
assert result_sorted.get_column("y").to_pylist() == [3, 6]

0 comments on commit 55fc032

Please sign in to comment.