Skip to content

Commit

Permalink
[BUG]: tbl alias with join (#3333)
Browse files Browse the repository at this point in the history
closes #3311
  • Loading branch information
universalmind303 authored Nov 19, 2024
1 parent 1b84250 commit dab006f
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 4 deletions.
9 changes: 7 additions & 2 deletions src/daft-sql/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1025,14 +1025,14 @@ impl SQLPlanner {
table_not_found_err!(table_name);
};
let right_schema = table_rel.inner.schema();
let schema = rel.inner.schema();
let keys = schema.fields.keys();

let right_schema = if let Some(exclude) = &wildcard_opts.opt_exclude {
Arc::new(wildcard_exclude(right_schema, exclude)?)
} else {
right_schema
};

let columns = right_schema
.fields
.keys()
Expand All @@ -1041,12 +1041,17 @@ impl SQLPlanner {
.clone()
.any(|s| s.starts_with(&table_name) && s.ends_with(field))
{
col(format!("{}.{}", table_name, field)).alias(field.as_ref())
if table_name == rel.get_name() {
col(field.clone())
} else {
col(format!("{}.{}", &table_name, field)).alias(field.as_ref())
}
} else {
col(field.clone())
}
})
.collect::<Vec<_>>();

Ok(columns)
}
}
Expand Down
34 changes: 32 additions & 2 deletions tests/sql/test_joins.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,13 +126,43 @@ def test_joins_with_duplicate_columns():
"b.y = a.x",
],
)
def test_join_qualifiers(join_condition):
@pytest.mark.parametrize("selection", ["*", "a.*, b.y, b.score", "a.x, a.val, b.*", "a.x, a.val, b.y, b.score"])
def test_join_qualifiers(join_condition, selection):
a = daft.from_pydict({"x": [1, None], "val": [10, 20]})
b = daft.from_pydict({"y": [1, None], "score": [0.1, 0.2]})

catalog = SQLCatalog({"a": a, "b": b})

df_sql = daft.sql(f"select * from a join b on {join_condition}", catalog).to_pydict()
df_sql = daft.sql(f"select {selection} from a join b on {join_condition}", catalog).to_pydict()

expected = {"x": [1], "val": [10], "y": [1], "score": [0.1]}

assert df_sql == expected


@pytest.mark.parametrize(
"join_condition",
[
"x = y",
"x = b1.y",
"y = x",
"y = a1.x",
"a1.x = y",
"a1.x = b1.y",
"b1.y = x",
"b1.y = a1.x",
],
)
@pytest.mark.parametrize(
"selection", ["*", "a1.*, b1.y, b.score", "a1.x, a1.val, b1.*", "a1.x, a1.val, b1.y, b1.score"]
)
def test_join_qualifiers_with_alias(join_condition, selection):
a = daft.from_pydict({"x": [1, None], "val": [10, 20]})
b = daft.from_pydict({"y": [1, None], "score": [0.1, 0.2]})

catalog = SQLCatalog({"a": a, "b": b})

df_sql = daft.sql(f"select {selection} from a as a1 join b as b1 on {join_condition}", catalog).to_pydict()

expected = {"x": [1], "val": [10], "y": [1], "score": [0.1]}

Expand Down

0 comments on commit dab006f

Please sign in to comment.