From e069c2dab00db68861a44c7355eb46a6063b2792 Mon Sep 17 00:00:00 2001 From: siddharth-gulia Date: Wed, 29 May 2024 11:27:15 +0530 Subject: [PATCH] [BUG]Fix missing columns after join (#2321) Closes #2300 --- src/daft-plan/src/logical_ops/join.rs | 5 ++++- tests/dataframe/test_joins.py | 17 +++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/src/daft-plan/src/logical_ops/join.rs b/src/daft-plan/src/logical_ops/join.rs index a55cce32d2..6d3c57b8a4 100644 --- a/src/daft-plan/src/logical_ops/join.rs +++ b/src/daft-plan/src/logical_ops/join.rs @@ -75,13 +75,16 @@ impl Join { // but contains bug https://github.com/Eventual-Inc/Daft/issues/1294 let output_schema = { let left_join_keys = left_on.iter().map(|e| e.name()).collect::>(); + let right_join_keys = right_on.iter().map(|e| e.name()).collect::>(); let left_schema = &left.schema().fields; let fields = left_schema .iter() .map(|(_, field)| field) .cloned() .chain(right.schema().fields.iter().filter_map(|(rname, rfield)| { - if left_join_keys.contains(rname.as_str()) { + if left_join_keys.contains(rname.as_str()) + && right_join_keys.contains(rname.as_str()) + { right_input_mapping.insert(rname.clone(), rname.clone()); None } else if left_schema.contains_key(rname) { diff --git a/tests/dataframe/test_joins.py b/tests/dataframe/test_joins.py index 8aea55def7..0c47513cd6 100644 --- a/tests/dataframe/test_joins.py +++ b/tests/dataframe/test_joins.py @@ -31,6 +31,23 @@ def test_invalid_join_strategies(make_df): df.join(df, on="A", strategy="broadcast", how="outer") +def test_columns_after_join(make_df): + df1 = make_df( + { + "A": [1, 2, 3], + }, + ) + + df2 = make_df({"A": [1, 2, 3], "B": [1, 2, 3]}) + + joined_df1 = df1.join(df2, left_on="A", right_on="B") + joined_df2 = df1.join(df2, left_on="A", right_on="A") + + assert set(joined_df1.schema().column_names()) == set(["A", "B", "right.A"]) + + assert set(joined_df2.schema().column_names()) == set(["A", "B"]) + + @pytest.mark.parametrize("n_partitions", [1, 2, 4]) @pytest.mark.parametrize( "join_strategy",