Skip to content

Commit

Permalink
[BUG] Fix join op names and join key definition (#2631)
Browse files Browse the repository at this point in the history
This simplifies our logic for joins by making the right side renaming an
explicit project before the join during logical planning. This allows
code further in the pipeline to simply assume that the column names are
preserved during the join and that the only matching column names
between the left and right sides are the join keys. This PR also moves
the revised `infer_join_schema` function into `daft-dsl` so that it can
be used both in planning and execution.

Moreover, this PR slightly changes join behavior: join keys from the two
sides are now merged only when both are exact column expressions. For
example, `df1.join(df2, left_on="a", right_on=col("a") + col("b"),
how="outer")` previously produced a single column "a" containing both the
left and right values; it now keeps them separate, as "a" and
"right.a".

This does not fix the filter pushdown bug with joins, which will be
tackled in a later PR.

Also resolves #1294
  • Loading branch information
kevinzwang authored Aug 13, 2024
1 parent d809072 commit df09bd4
Show file tree
Hide file tree
Showing 17 changed files with 586 additions and 553 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions src/daft-dsl/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ common-resource-request = {path = "../common/resource-request", default-features
common-treenode = {path = "../common/treenode", default-features = false}
daft-core = {path = "../daft-core", default-features = false}
daft-sketch = {path = "../daft-sketch", default-features = false}
indexmap = {workspace = true}
itertools = {workspace = true}
log = {workspace = true}
pyo3 = {workspace = true, optional = true}
Expand Down
116 changes: 116 additions & 0 deletions src/daft-dsl/src/join.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
use std::sync::Arc;

use common_error::{DaftError, DaftResult};
use daft_core::{
schema::{Schema, SchemaRef},
JoinType,
};
use indexmap::IndexSet;

use crate::{Expr, ExprRef};

/// Get the columns between the two sides of the join that should be merged in the order of the join keys.
/// Join keys should only be merged if they are column expressions.
pub fn get_common_join_keys<'a>(
    left_on: &'a [ExprRef],
    right_on: &'a [ExprRef],
) -> impl Iterator<Item = &'a Arc<str>> {
    // Walk the join-key pairs in order; a pair contributes a common key
    // only when both sides are bare column references with the same name.
    // Aliased or compound expressions never merge.
    left_on
        .iter()
        .zip(right_on.iter())
        .filter_map(|(left, right)| match (left.as_ref(), right.as_ref()) {
            (Expr::Column(l_name), Expr::Column(r_name)) if l_name == r_name => Some(l_name),
            _ => None,
        })
}

/// Infer the schema of a join operation
///
/// This function assumes that the only common field names between the left and right schemas are the join fields,
/// which is valid because the right columns are renamed during the construction of a join logical operation.
pub fn infer_join_schema(
    left_schema: &SchemaRef,
    right_schema: &SchemaRef,
    left_on: &[ExprRef],
    right_on: &[ExprRef],
    how: JoinType,
) -> DaftResult<SchemaRef> {
    // The two key lists must pair up one-to-one.
    if left_on.len() != right_on.len() {
        return Err(DaftError::ValueError(format!(
            "Length of left_on does not match length of right_on for Join {} vs {}",
            left_on.len(),
            right_on.len()
        )));
    }

    // Anti/semi joins only filter the left side, so its schema passes through.
    if matches!(how, JoinType::Anti | JoinType::Semi) {
        return Ok(left_schema.clone());
    }

    let common_join_keys: IndexSet<_> = get_common_join_keys(left_on, right_on)
        .map(|k| k.to_string())
        .collect();

    // Output field order: merged join keys first (in join-key order),
    // then the remaining left fields, then the remaining right fields.
    let mut fields = Vec::with_capacity(left_schema.fields.len() + right_schema.fields.len());

    for name in &common_join_keys {
        let field = left_schema
            .get_field(name)
            .expect("Common join key should exist in left schema");
        fields.push(field.clone());
    }

    for (name, field) in &left_schema.fields {
        if !common_join_keys.contains(name) {
            fields.push(field.clone());
        }
    }

    for (name, field) in &right_schema.fields {
        if common_join_keys.contains(name) {
            continue;
        }
        // Planning renames right-side columns, so any non-key name collision
        // here would mean the invariant in the doc comment was violated.
        if left_schema.fields.contains_key(name) {
            unreachable!("Right schema should have renamed columns")
        }
        fields.push(field.clone());
    }

    Ok(Schema::new(fields)?.into())
}

#[cfg(test)]
mod tests {
    use super::*;

    use crate::col;

    #[test]
    fn test_get_common_join_keys() {
        // Only the first pair is two identical bare column references; the
        // others differ by name, are aliased, or are compound expressions,
        // so none of them should be reported as common keys.
        let left_on: &[ExprRef] = &[
            col("a"),
            col("b_left"),
            col("c").alias("c_new"),
            col("d").alias("d_new"),
            col("e").add(col("f")),
        ];

        let right_on: &[ExprRef] = &[
            col("a"),
            col("b_right"),
            col("c"),
            col("d").alias("d_new"),
            col("e"),
        ];

        let keys: Vec<_> = get_common_join_keys(left_on, right_on)
            .map(|k| k.to_string())
            .collect();

        assert_eq!(keys, vec!["a"]);
    }
}
1 change: 1 addition & 0 deletions src/daft-dsl/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
mod arithmetic;
mod expr;
pub mod functions;
pub mod join;
mod lit;
pub mod optimization;
#[cfg(feature = "python")]
Expand Down
28 changes: 5 additions & 23 deletions src/daft-local-execution/src/sinks/hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@ use futures::{stream, StreamExt};
use tracing::info_span;

use super::blocking_sink::{BlockingSink, BlockingSinkStatus};
use daft_table::{
infer_join_schema_mapper, GrowableTable, JoinOutputMapper, ProbeTable, ProbeTableBuilder, Table,
};
use daft_table::{GrowableTable, ProbeTable, ProbeTableBuilder, Table};

enum HashJoinState {
Building {
Expand Down Expand Up @@ -65,7 +63,7 @@ impl HashJoinState {
panic!("add_tables can only be used during the Building Phase")
}
}
fn finalize(&mut self, join_mapper: &JoinOutputMapper) -> DaftResult<()> {
fn finalize(&mut self) -> DaftResult<()> {
if let Self::Building {
probe_table_builder,
tables,
Expand All @@ -74,14 +72,10 @@ impl HashJoinState {
{
let ptb = std::mem::take(probe_table_builder).expect("should be set in building mode");
let pt = ptb.build();
let mapped_tables = tables
.iter()
.map(|t| join_mapper.map_left(t))
.collect::<DaftResult<Vec<_>>>()?;

*self = Self::Probing {
probe_table: Arc::new(pt),
tables: Arc::new(mapped_tables),
tables: Arc::new(tables.clone()),
};
Ok(())
} else {
Expand All @@ -93,7 +87,6 @@ impl HashJoinState {
pub(crate) struct HashJoinOperator {
right_on: Vec<ExprRef>,
_join_type: JoinType,
join_mapper: Arc<JoinOutputMapper>,
join_state: HashJoinState,
}

Expand Down Expand Up @@ -126,9 +119,6 @@ impl HashJoinOperator {
)?
.into();

let join_mapper =
infer_join_schema_mapper(left_schema, right_schema, &left_on, &right_on, join_type)?;

let left_on = left_on
.into_iter()
.zip(key_schema.fields.values())
Expand All @@ -143,7 +133,6 @@ impl HashJoinOperator {
Ok(Self {
right_on,
_join_type: join_type,
join_mapper: Arc::new(join_mapper),
join_state: HashJoinState::new(&key_schema, left_on)?,
})
}
Expand All @@ -162,7 +151,6 @@ impl HashJoinOperator {
probe_table: probe_table.clone(),
tables: tables.clone(),
right_on: self.right_on.clone(),
join_mapper: self.join_mapper.clone(),
})
} else {
panic!("can't call as_intermediate_op when not in probing state")
Expand All @@ -174,7 +162,6 @@ struct HashJoinProber {
probe_table: Arc<ProbeTable>,
tables: Arc<Vec<Table>>,
right_on: Vec<ExprRef>,
join_mapper: Arc<JoinOutputMapper>,
}

impl IntermediateOperator for HashJoinProber {
Expand All @@ -192,13 +179,8 @@ impl IntermediateOperator for HashJoinProber {

let right_input_tables = input.get_tables()?;

let right_tables = right_input_tables
.iter()
.map(|t| self.join_mapper.map_right(t))
.collect::<DaftResult<Vec<_>>>()?;

let mut right_growable =
GrowableTable::new(&right_tables.iter().collect::<Vec<_>>(), false, 20)?;
GrowableTable::new(&right_input_tables.iter().collect::<Vec<_>>(), false, 20)?;

drop(_growables);
{
Expand Down Expand Up @@ -237,7 +219,7 @@ impl BlockingSink for HashJoinOperator {
Ok(BlockingSinkStatus::NeedMoreInput)
}
fn finalize(&mut self) -> DaftResult<()> {
self.join_state.finalize(&self.join_mapper)?;
self.join_state.finalize()?;
Ok(())
}
fn as_source(&mut self) -> &mut dyn Source {
Expand Down
8 changes: 5 additions & 3 deletions src/daft-micropartition/src/ops/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ use std::sync::Arc;

use common_error::DaftResult;
use daft_core::{array::ops::DaftCompare, join::JoinType};
use daft_dsl::ExprRef;
use daft_dsl::{join::infer_join_schema, ExprRef};
use daft_io::IOStatsContext;
use daft_table::{infer_join_schema, Table};
use daft_table::Table;

use crate::micropartition::MicroPartition;

Expand All @@ -29,12 +29,14 @@ impl MicroPartition {
| (JoinType::Inner, _, 0)
| (JoinType::Left, 0, _)
| (JoinType::Right, _, 0)
| (JoinType::Outer, 0, 0) => {
| (JoinType::Outer, 0, 0)
| (JoinType::Semi, 0, _) => {
return Ok(Self::empty(Some(join_schema)));
}
_ => {}
}

// TODO(Kevin): short circuits are also possible for other join types
if how == JoinType::Inner {
let tv = match (&self.statistics, &right.statistics) {
(_, None) => TruthValue::Maybe,
Expand Down
2 changes: 1 addition & 1 deletion src/daft-plan/src/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ Filter2["Filter: col(first_name) == lit('hello')"]
Join3["Join: Type = Inner
Strategy = Auto
On = col(id)
Output schema = text#Utf8, id#Int32, first_name#Utf8, last_name#Utf8"]
Output schema = id#Int32, text#Utf8, first_name#Utf8, last_name#Utf8"]
Filter4["Filter: col(id) == lit(1)"]
Source5["PlaceHolder:
Source ID = 0
Expand Down
Loading

0 comments on commit df09bd4

Please sign in to comment.