-
Notifications
You must be signed in to change notification settings - Fork 174
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[BUG] Fix join op names and join key definition (#2631)
This simplifies our logic for joins by making the right side renaming an explicit project before the join during logical planning. This allows code further in the pipeline to simply assume that the column names are preserved during the join and that the only matching column names between the left and right sides are the join keys. This PR also moves the revised `infer_join_schema` function into `daft-dsl` so that it can be used both in planning and execution. Moreover, this PR changes join behavior slightly, by only merging join keys from both sides if they are exact column expressions. For example, before `df1.join(df2, left_on="a", right_on=col("a") + col("b"), how="outer")` would have created a single column "a" with both the left and right values, but now does not, and instead splits it into "a" and "right.a". This does not fix the filter pushdown bug with joins, which will be tackled in a later PR. Also resolves #1294
- Loading branch information
1 parent
d809072
commit df09bd4
Showing
17 changed files
with
586 additions
and
553 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
use std::sync::Arc; | ||
|
||
use common_error::{DaftError, DaftResult}; | ||
use daft_core::{ | ||
schema::{Schema, SchemaRef}, | ||
JoinType, | ||
}; | ||
use indexmap::IndexSet; | ||
|
||
use crate::{Expr, ExprRef}; | ||
|
||
/// Get the columns between the two sides of the join that should be merged in the order of the join keys. | ||
/// Join keys should only be merged if they are column expressions. | ||
pub fn get_common_join_keys<'a>( | ||
left_on: &'a [ExprRef], | ||
right_on: &'a [ExprRef], | ||
) -> impl Iterator<Item = &'a Arc<str>> { | ||
left_on.iter().zip(right_on.iter()).filter_map(|(l, r)| { | ||
if let (Expr::Column(l_name), Expr::Column(r_name)) = (&**l, &**r) | ||
&& l_name == r_name | ||
{ | ||
Some(l_name) | ||
} else { | ||
None | ||
} | ||
}) | ||
} | ||
|
||
/// Infer the schema of a join operation | ||
/// | ||
/// This function assumes that the only common field names between the left and right schemas are the join fields, | ||
/// which is valid because the right columns are renamed during the construction of a join logical operation. | ||
pub fn infer_join_schema( | ||
left_schema: &SchemaRef, | ||
right_schema: &SchemaRef, | ||
left_on: &[ExprRef], | ||
right_on: &[ExprRef], | ||
how: JoinType, | ||
) -> DaftResult<SchemaRef> { | ||
if left_on.len() != right_on.len() { | ||
return Err(DaftError::ValueError(format!( | ||
"Length of left_on does not match length of right_on for Join {} vs {}", | ||
left_on.len(), | ||
right_on.len() | ||
))); | ||
} | ||
|
||
if matches!(how, JoinType::Anti | JoinType::Semi) { | ||
Ok(left_schema.clone()) | ||
} else { | ||
let common_join_keys: IndexSet<_> = get_common_join_keys(left_on, right_on) | ||
.map(|k| k.to_string()) | ||
.collect(); | ||
|
||
// common join fields, then unique left fields, then unique right fields | ||
let fields: Vec<_> = common_join_keys | ||
.iter() | ||
.map(|name| { | ||
left_schema | ||
.get_field(name) | ||
.expect("Common join key should exist in left schema") | ||
}) | ||
.chain(left_schema.fields.iter().filter_map(|(name, field)| { | ||
if common_join_keys.contains(name) { | ||
None | ||
} else { | ||
Some(field) | ||
} | ||
})) | ||
.chain(right_schema.fields.iter().filter_map(|(name, field)| { | ||
if common_join_keys.contains(name) { | ||
None | ||
} else if left_schema.fields.contains_key(name) { | ||
unreachable!("Right schema should have renamed columns") | ||
} else { | ||
Some(field) | ||
} | ||
})) | ||
.cloned() | ||
.collect(); | ||
|
||
Ok(Schema::new(fields)?.into()) | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use crate::col; | ||
|
||
use super::*; | ||
|
||
#[test] | ||
fn test_get_common_join_keys() { | ||
let left_on: &[ExprRef] = &[ | ||
col("a"), | ||
col("b_left"), | ||
col("c").alias("c_new"), | ||
col("d").alias("d_new"), | ||
col("e").add(col("f")), | ||
]; | ||
|
||
let right_on: &[ExprRef] = &[ | ||
col("a"), | ||
col("b_right"), | ||
col("c"), | ||
col("d").alias("d_new"), | ||
col("e"), | ||
]; | ||
|
||
let common_join_keys = get_common_join_keys(left_on, right_on) | ||
.map(|k| k.to_string()) | ||
.collect::<Vec<_>>(); | ||
|
||
assert_eq!(common_join_keys, vec!["a"]); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.