From dab006f971e4fd2fbd979fa27611eccc5a7d2668 Mon Sep 17 00:00:00 2001 From: Cory Grinstead Date: Tue, 19 Nov 2024 14:23:11 -0600 Subject: [PATCH] [BUG]: tbl alias with join (#3333) closes https://github.com/Eventual-Inc/Daft/issues/3311 --- src/daft-sql/src/planner.rs | 9 +++++++-- tests/sql/test_joins.py | 34 ++++++++++++++++++++++++++++++++-- 2 files changed, 39 insertions(+), 4 deletions(-) diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs index f450aca35f..86dd232011 100644 --- a/src/daft-sql/src/planner.rs +++ b/src/daft-sql/src/planner.rs @@ -1025,7 +1025,6 @@ impl SQLPlanner { table_not_found_err!(table_name); }; let right_schema = table_rel.inner.schema(); - let schema = rel.inner.schema(); let keys = schema.fields.keys(); let right_schema = if let Some(exclude) = &wildcard_opts.opt_exclude { @@ -1033,6 +1032,7 @@ impl SQLPlanner { } else { right_schema }; + let columns = right_schema .fields .keys() @@ -1041,12 +1041,17 @@ impl SQLPlanner { .clone() .any(|s| s.starts_with(&table_name) && s.ends_with(field)) { - col(format!("{}.{}", table_name, field)).alias(field.as_ref()) + if table_name == rel.get_name() { + col(field.clone()) + } else { + col(format!("{}.{}", &table_name, field)).alias(field.as_ref()) + } } else { col(field.clone()) } }) .collect::>(); + Ok(columns) } } diff --git a/tests/sql/test_joins.py b/tests/sql/test_joins.py index e91cee427b..d7d23ffad8 100644 --- a/tests/sql/test_joins.py +++ b/tests/sql/test_joins.py @@ -126,13 +126,43 @@ def test_joins_with_duplicate_columns(): "b.y = a.x", ], ) -def test_join_qualifiers(join_condition): +@pytest.mark.parametrize("selection", ["*", "a.*, b.y, b.score", "a.x, a.val, b.*", "a.x, a.val, b.y, b.score"]) +def test_join_qualifiers(join_condition, selection): a = daft.from_pydict({"x": [1, None], "val": [10, 20]}) b = daft.from_pydict({"y": [1, None], "score": [0.1, 0.2]}) catalog = SQLCatalog({"a": a, "b": b}) - df_sql = daft.sql(f"select * from a join b on {join_condition}", catalog).to_pydict() + df_sql = daft.sql(f"select {selection} from a join b on {join_condition}", catalog).to_pydict() + + expected = {"x": [1], "val": [10], "y": [1], "score": [0.1]} + + assert df_sql == expected + + +@pytest.mark.parametrize( + "join_condition", + [ + "x = y", + "x = b1.y", + "y = x", + "y = a1.x", + "a1.x = y", + "a1.x = b1.y", + "b1.y = x", + "b1.y = a1.x", + ], +) +@pytest.mark.parametrize( + "selection", ["*", "a1.*, b1.y, b.score", "a1.x, a1.val, b1.*", "a1.x, a1.val, b1.y, b1.score"] +) +def test_join_qualifiers_with_alias(join_condition, selection): + a = daft.from_pydict({"x": [1, None], "val": [10, 20]}) + b = daft.from_pydict({"y": [1, None], "score": [0.1, 0.2]}) + + catalog = SQLCatalog({"a": a, "b": b}) + + df_sql = daft.sql(f"select {selection} from a as a1 join b as b1 on {join_condition}", catalog).to_pydict() expected = {"x": [1], "val": [10], "y": [1], "score": [0.1]}