From 55fc032ba1ebe77c108443eac50defcd561ba30a Mon Sep 17 00:00:00 2001 From: Conor Kennedy <32619800+Vince7778@users.noreply.github.com> Date: Mon, 8 Jul 2024 11:22:35 -0700 Subject: [PATCH] [BUG] Fix anti-join on different column names (#2477) Closes #2475. A bit of a bandaid fix. The issue was, while selecting the join strategy, for checking whether sort-merge join works, the joined columns need to be primitive types. However, in the plan we only have access to the output schema (as far as I know, maybe we can extract it from the children?). In a semi- or anti-join, the right column does not end up in the output schema, so the primitive type check panicked while looking for it. This issue does not appear if the left column has the same name as the right. My fix is that, since we only support SMJ for inner joins, I move the join type check before the primitive check. This will need to be fixed later if we support SMJ for anti-joins, so I added a comment. --- .../src/physical_planner/translate.rs | 6 +- tests/dataframe/test_joins.py | 59 +++++++++++++++++++ tests/table/test_joins.py | 20 +++++++ 3 files changed, 83 insertions(+), 2 deletions(-) diff --git a/src/daft-plan/src/physical_planner/translate.rs b/src/daft-plan/src/physical_planner/translate.rs index 019bb5e50c..c4280480b9 100644 --- a/src/daft-plan/src/physical_planner/translate.rs +++ b/src/daft-plan/src/physical_planner/translate.rs @@ -504,6 +504,8 @@ pub(super) fn translate_single_logical_node( is_right_hash_partitioned || is_right_sort_partitioned }; let join_strategy = join_strategy.unwrap_or_else(|| { + // This method will panic if called with columns that aren't in the output schema, + // which is possible for anti- and semi-joins. let is_primitive = |exprs: &Vec| { exprs.iter().map(|e| e.name()).all(|col| { let dtype = &output_schema.get_field(col).unwrap().dtype; @@ -529,10 +531,10 @@ pub(super) fn translate_single_logical_node( // TODO(Clark): Also do a sort-merge join if a downstream op needs the table to be sorted on the join key. // TODO(Clark): Look into defaulting to sort-merge join over hash join under more input partitioning setups. // TODO(Kevin): Support sort-merge join for other types of joins. - } else if is_primitive(left_on) + } else if *join_type == JoinType::Inner + && is_primitive(left_on) && is_primitive(right_on) && (is_left_sort_partitioned || is_right_sort_partitioned) - && *join_type == JoinType::Inner && (!is_larger_partitioned || (left_is_larger && is_left_sort_partitioned || !left_is_larger && is_right_sort_partitioned)) diff --git a/tests/dataframe/test_joins.py b/tests/dataframe/test_joins.py index b416319637..e2e8cc962e 100644 --- a/tests/dataframe/test_joins.py +++ b/tests/dataframe/test_joins.py @@ -777,3 +777,62 @@ def test_join_semi_anti(join_strategy, join_type, expected, make_df, repartition assert sort_arrow_table(pa.Table.from_pydict(daft_df.to_pydict()), "id") == sort_arrow_table( pa.Table.from_pydict(expected), "id" ) + + +@pytest.mark.parametrize("repartition_nparts", [1, 2, 4]) +@pytest.mark.parametrize( + "join_strategy", + [None, "hash", "sort_merge", "sort_merge_aligned_boundaries", "broadcast"], + indirect=True, +) +@pytest.mark.parametrize( + "join_type,expected", + [ + ( + "semi", + { + "id_left": [2, 3], + "values_left": ["b1", "c1"], + }, + ), + ( + "anti", + { + "id_left": [1, None], + "values_left": ["a1", "d1"], + }, + ), + ], +) +def test_join_semi_anti_different_names(join_strategy, join_type, expected, make_df, repartition_nparts): + skip_invalid_join_strategies(join_strategy, join_type) + + daft_df1 = make_df( + { + "id_left": [1, 2, 3, None], + "values_left": ["a1", "b1", "c1", "d1"], + }, + repartition=repartition_nparts, + ) + daft_df2 = make_df( + { + "id_right": [2, 2, 3, 4], + "values_right": ["a2", "b2", "c2", "d2"], + }, + repartition=repartition_nparts, + ) + daft_df = ( + daft_df1.with_column("id_left", daft_df1["id_left"].cast(DataType.int64())) + .join( + daft_df2, + left_on="id_left", + right_on="id_right", + how=join_type, + strategy=join_strategy, + ) + .sort(["id_left", "values_left"]) + ).select("id_left", "values_left") + + assert sort_arrow_table(pa.Table.from_pydict(daft_df.to_pydict()), "id_left") == sort_arrow_table( + pa.Table.from_pydict(expected), "id_left" + ) diff --git a/tests/table/test_joins.py b/tests/table/test_joins.py index 2ddd4e82d7..5664950121 100644 --- a/tests/table/test_joins.py +++ b/tests/table/test_joins.py @@ -282,3 +282,23 @@ def test_table_join_single_column_name_null(join_impl) -> None: result_sorted = result_table.sort([col("x")]) assert result_sorted.get_column("y").to_pylist() == [] assert result_sorted.get_column("right.y").to_pylist() == [] + + +def test_table_join_anti() -> None: + left_table = MicroPartition.from_pydict({"x": [1, 2, 3, 4], "y": [3, 4, 5, 6]}) + right_table = MicroPartition.from_pydict({"x": [2, 3, 5]}) + + result_table = left_table.hash_join(right_table, left_on=[col("x")], right_on=[col("x")], how=JoinType.Anti) + assert result_table.column_names() == ["x", "y"] + result_sorted = result_table.sort([col("x")]) + assert result_sorted.get_column("y").to_pylist() == [3, 6] + + +def test_table_join_anti_different_names() -> None: + left_table = MicroPartition.from_pydict({"x": [1, 2, 3, 4], "y": [3, 4, 5, 6]}) + right_table = MicroPartition.from_pydict({"z": [2, 3, 5]}) + + result_table = left_table.hash_join(right_table, left_on=[col("x")], right_on=[col("z")], how=JoinType.Anti) + assert result_table.column_names() == ["x", "y"] + result_sorted = result_table.sort([col("x")]) + assert result_sorted.get_column("y").to_pylist() == [3, 6]