diff --git a/src/daft-plan/src/physical_planner/translate.rs b/src/daft-plan/src/physical_planner/translate.rs index 019bb5e50c..c4280480b9 100644 --- a/src/daft-plan/src/physical_planner/translate.rs +++ b/src/daft-plan/src/physical_planner/translate.rs @@ -504,6 +504,8 @@ pub(super) fn translate_single_logical_node( is_right_hash_partitioned || is_right_sort_partitioned }; let join_strategy = join_strategy.unwrap_or_else(|| { + // This method will panic if called with columns that aren't in the output schema, + // which is possible for anti- and semi-joins. let is_primitive = |exprs: &Vec| { exprs.iter().map(|e| e.name()).all(|col| { let dtype = &output_schema.get_field(col).unwrap().dtype; @@ -529,10 +531,10 @@ pub(super) fn translate_single_logical_node( // TODO(Clark): Also do a sort-merge join if a downstream op needs the table to be sorted on the join key. // TODO(Clark): Look into defaulting to sort-merge join over hash join under more input partitioning setups. // TODO(Kevin): Support sort-merge join for other types of joins. - } else if is_primitive(left_on) + } else if *join_type == JoinType::Inner + && is_primitive(left_on) && is_primitive(right_on) && (is_left_sort_partitioned || is_right_sort_partitioned) - && *join_type == JoinType::Inner && (!is_larger_partitioned || (left_is_larger && is_left_sort_partitioned || !left_is_larger && is_right_sort_partitioned)) diff --git a/tests/dataframe/test_joins.py b/tests/dataframe/test_joins.py index b416319637..e2e8cc962e 100644 --- a/tests/dataframe/test_joins.py +++ b/tests/dataframe/test_joins.py @@ -777,3 +777,62 @@ def test_join_semi_anti(join_strategy, join_type, expected, make_df, repartition assert sort_arrow_table(pa.Table.from_pydict(daft_df.to_pydict()), "id") == sort_arrow_table( pa.Table.from_pydict(expected), "id" ) + + +@pytest.mark.parametrize("repartition_nparts", [1, 2, 4]) +@pytest.mark.parametrize( + "join_strategy", + [None, "hash", "sort_merge", "sort_merge_aligned_boundaries", "broadcast"], + indirect=True, +) +@pytest.mark.parametrize( + "join_type,expected", + [ + ( + "semi", + { + "id_left": [2, 3], + "values_left": ["b1", "c1"], + }, + ), + ( + "anti", + { + "id_left": [1, None], + "values_left": ["a1", "d1"], + }, + ), + ], +) +def test_join_semi_anti_different_names(join_strategy, join_type, expected, make_df, repartition_nparts): + skip_invalid_join_strategies(join_strategy, join_type) + + daft_df1 = make_df( + { + "id_left": [1, 2, 3, None], + "values_left": ["a1", "b1", "c1", "d1"], + }, + repartition=repartition_nparts, + ) + daft_df2 = make_df( + { + "id_right": [2, 2, 3, 4], + "values_right": ["a2", "b2", "c2", "d2"], + }, + repartition=repartition_nparts, + ) + daft_df = ( + daft_df1.with_column("id_left", daft_df1["id_left"].cast(DataType.int64())) + .join( + daft_df2, + left_on="id_left", + right_on="id_right", + how=join_type, + strategy=join_strategy, + ) + .sort(["id_left", "values_left"]) + ).select("id_left", "values_left") + + assert sort_arrow_table(pa.Table.from_pydict(daft_df.to_pydict()), "id_left") == sort_arrow_table( + pa.Table.from_pydict(expected), "id_left" + ) diff --git a/tests/table/test_joins.py b/tests/table/test_joins.py index 2ddd4e82d7..5664950121 100644 --- a/tests/table/test_joins.py +++ b/tests/table/test_joins.py @@ -282,3 +282,23 @@ def test_table_join_single_column_name_null(join_impl) -> None: result_sorted = result_table.sort([col("x")]) assert result_sorted.get_column("y").to_pylist() == [] assert result_sorted.get_column("right.y").to_pylist() == [] + + +def test_table_join_anti() -> None: + left_table = MicroPartition.from_pydict({"x": [1, 2, 3, 4], "y": [3, 4, 5, 6]}) + right_table = MicroPartition.from_pydict({"x": [2, 3, 5]}) + + result_table = left_table.hash_join(right_table, left_on=[col("x")], right_on=[col("x")], how=JoinType.Anti) + assert result_table.column_names() == ["x", "y"] + result_sorted = result_table.sort([col("x")]) + assert result_sorted.get_column("y").to_pylist() == [3, 6] + + +def test_table_join_anti_different_names() -> None: + left_table = MicroPartition.from_pydict({"x": [1, 2, 3, 4], "y": [3, 4, 5, 6]}) + right_table = MicroPartition.from_pydict({"z": [2, 3, 5]}) + + result_table = left_table.hash_join(right_table, left_on=[col("x")], right_on=[col("z")], how=JoinType.Anti) + assert result_table.column_names() == ["x", "y"] + result_sorted = result_table.sort([col("x")]) + assert result_sorted.get_column("y").to_pylist() == [3, 6]