From df09bd489b66fd30ae9c8c0bf841cb6ef4af2ea5 Mon Sep 17 00:00:00 2001
From: Kev Wang
Date: Mon, 12 Aug 2024 17:02:50 -0700
Subject: [PATCH] [BUG] Fix join op names and join key definition (#2631)

This simplifies our join logic by making the renaming of right-side columns an
explicit projection before the join during logical planning. Code further down
the pipeline can then assume that column names are preserved through the join
and that the only column names shared between the left and right sides are the
join keys.

This PR also moves the revised `infer_join_schema` function into `daft-dsl` so
that it can be used in both planning and execution.

Additionally, this PR changes join behavior slightly: join keys from the two
sides are now merged into a single output column only if both are exact column
expressions. For example, `df1.join(df2, left_on="a", right_on=col("a") +
col("b"), how="outer")` previously created a single column "a" containing both
the left and right values; it now produces two columns, "a" and "right.a".

This does not fix the filter pushdown bug with joins, which will be tackled in
a later PR.

Also resolves #1294
---
 Cargo.lock                                |   1 +
 src/daft-dsl/Cargo.toml                   |   1 +
 src/daft-dsl/src/join.rs                  | 116 +++++++
 src/daft-dsl/src/lib.rs                   |   1 +
 .../src/sinks/hash_join.rs                |  28 +-
 src/daft-micropartition/src/ops/join.rs   |   8 +-
 src/daft-plan/src/display.rs              |   2 +-
 src/daft-plan/src/logical_ops/join.rs     | 147 +++++---
 .../rules/push_down_filter.rs             |  13 +-
 .../rules/push_down_projection.rs         | 118 +++----
 .../src/physical_planner/translate.rs     |  18 +-
 src/daft-table/src/lib.rs                 |   1 -
 src/daft-table/src/ops/joins/hash_join.rs |  76 ++---
 src/daft-table/src/ops/joins/mod.rs       | 321 ++----------------
 src/daft-table/src/ops/mod.rs             |   2 -
 tests/dataframe/test_joins.py             | 239 +++++++++++++
 tests/table/test_joins.py                 |  47 +--
 17 files changed, 586 insertions(+), 553 deletions(-)
 create mode 100644 src/daft-dsl/src/join.rs

diff --git a/Cargo.lock b/Cargo.lock
index 8a5a096bab..d3171fecec 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1712,6 +1712,7 @@ dependencies = [
  "common-treenode",
  "daft-core",
  "daft-sketch",
+ "indexmap 2.3.0",
  "itertools 0.11.0",
  "log",
  "pyo3",
diff --git a/src/daft-dsl/Cargo.toml b/src/daft-dsl/Cargo.toml
index edd7bd8391..e2999b82a3 100644
--- a/src/daft-dsl/Cargo.toml
+++ b/src/daft-dsl/Cargo.toml
@@ -6,6 +6,7 @@ common-resource-request = {path = "../common/resource-request", default-features = false}
 common-treenode = {path = "../common/treenode", default-features = false}
 daft-core = {path = "../daft-core", default-features = false}
 daft-sketch = {path = "../daft-sketch", default-features = false}
+indexmap = {workspace = true}
 itertools = {workspace = true}
 log = {workspace = true}
 pyo3 = {workspace = true, optional = true}
diff --git a/src/daft-dsl/src/join.rs b/src/daft-dsl/src/join.rs
new file mode 100644
index 0000000000..26e9220c7a
--- /dev/null
+++ b/src/daft-dsl/src/join.rs
@@ -0,0 +1,116 @@
+use std::sync::Arc;
+
+use common_error::{DaftError, DaftResult};
+use daft_core::{
+    schema::{Schema, SchemaRef},
+    JoinType,
+};
+use indexmap::IndexSet;
+
+use crate::{Expr, ExprRef};
+
+/// Get the columns from the two sides of the join that should be merged, in the order of the join keys.
+/// Join keys should only be merged if they are column expressions.
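+///
+/// For example (illustrative, not from the original doc): with `left_on = [col("a"), col("b")]`
+/// and `right_on = [col("a"), col("b").alias("b2")]`, only `"a"` is a common join key,
+/// since the second pair matches by name on the left but the right side is an alias
+/// expression rather than a bare column reference.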
+pub fn get_common_join_keys<'a>(
+    left_on: &'a [ExprRef],
+    right_on: &'a [ExprRef],
+) -> impl Iterator<Item = &'a Arc<str>> {
+    left_on.iter().zip(right_on.iter()).filter_map(|(l, r)| {
+        if let (Expr::Column(l_name), Expr::Column(r_name)) = (&**l, &**r)
+            && l_name == r_name
+        {
+            Some(l_name)
+        } else {
+            None
+        }
+    })
+}
+
+/// Infer the schema of a join operation
+///
+/// This function assumes that the only common field names between the left and right schemas are the join fields,
+/// which is valid because the right columns are renamed during the construction of a join logical operation.
+pub fn infer_join_schema(
+    left_schema: &SchemaRef,
+    right_schema: &SchemaRef,
+    left_on: &[ExprRef],
+    right_on: &[ExprRef],
+    how: JoinType,
+) -> DaftResult<SchemaRef> {
+    if left_on.len() != right_on.len() {
+        return Err(DaftError::ValueError(format!(
+            "Length of left_on does not match length of right_on for Join {} vs {}",
+            left_on.len(),
+            right_on.len()
+        )));
+    }
+
+    if matches!(how, JoinType::Anti | JoinType::Semi) {
+        Ok(left_schema.clone())
+    } else {
+        let common_join_keys: IndexSet<_> = get_common_join_keys(left_on, right_on)
+            .map(|k| k.to_string())
+            .collect();
+
+        // common join fields, then unique left fields, then unique right fields
+        let fields: Vec<_> = common_join_keys
+            .iter()
+            .map(|name| {
+                left_schema
+                    .get_field(name)
+                    .expect("Common join key should exist in left schema")
+            })
+            .chain(left_schema.fields.iter().filter_map(|(name, field)| {
+                if common_join_keys.contains(name) {
+                    None
+                } else {
+                    Some(field)
+                }
+            }))
+            .chain(right_schema.fields.iter().filter_map(|(name, field)| {
+                if common_join_keys.contains(name) {
+                    None
+                } else if left_schema.fields.contains_key(name) {
+                    unreachable!("Right schema should have renamed columns")
+                } else {
+                    Some(field)
+                }
+            }))
+            .cloned()
+            .collect();
+
+        Ok(Schema::new(fields)?.into())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::col;
+
+    use super::*;
+
+    #[test]
+    fn test_get_common_join_keys() {
+        let left_on: &[ExprRef] = &[
+            col("a"),
+            col("b_left"),
+            col("c").alias("c_new"),
+            col("d").alias("d_new"),
+            col("e").add(col("f")),
+        ];
+
+        let right_on: &[ExprRef] = &[
+            col("a"),
+            col("b_right"),
+            col("c"),
+            col("d").alias("d_new"),
+            col("e"),
+        ];
+
+        let common_join_keys = get_common_join_keys(left_on, right_on)
+            .map(|k| k.to_string())
+            .collect::<Vec<_>>();
+
+        assert_eq!(common_join_keys, vec!["a"]);
+    }
+}
diff --git a/src/daft-dsl/src/lib.rs b/src/daft-dsl/src/lib.rs
index 9c0c9c00e9..557cf3a160 100644
--- a/src/daft-dsl/src/lib.rs
+++ b/src/daft-dsl/src/lib.rs
@@ -4,6 +4,7 @@
 mod arithmetic;
 mod expr;
 pub mod functions;
+pub mod join;
 mod lit;
 pub mod optimization;
 #[cfg(feature = "python")]
diff --git a/src/daft-local-execution/src/sinks/hash_join.rs b/src/daft-local-execution/src/sinks/hash_join.rs
index 82dbee00f5..992767c95b 100644
--- a/src/daft-local-execution/src/sinks/hash_join.rs
+++ b/src/daft-local-execution/src/sinks/hash_join.rs
@@ -21,9 +21,7 @@ use futures::{stream, StreamExt};
 use tracing::info_span;
 
 use super::blocking_sink::{BlockingSink, BlockingSinkStatus};
-use daft_table::{
-    infer_join_schema_mapper, GrowableTable, JoinOutputMapper, ProbeTable, ProbeTableBuilder, Table,
-};
+use daft_table::{GrowableTable, ProbeTable, ProbeTableBuilder, Table};
 
 enum HashJoinState {
     Building {
@@ -65,7 +63,7 @@ impl HashJoinState {
             panic!("add_tables can only be used during the Building Phase")
         }
     }
-    fn finalize(&mut self, join_mapper: &JoinOutputMapper) -> DaftResult<()> {
+    fn finalize(&mut self) -> DaftResult<()> {
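+        // Builds the probe table from the accumulated left-side tables and switches to
+        // the Probing state. The tables no longer need remapping here, since right-side
+        // column renaming is now handled by a projection during logical planning.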
         if let Self::Building {
             probe_table_builder,
             tables,
             ..
         } = self
         {
             let ptb = std::mem::take(probe_table_builder).expect("should be set in building mode");
             let pt = ptb.build();
-            let mapped_tables = tables
-                .iter()
-                .map(|t| join_mapper.map_left(t))
-                .collect::<DaftResult<Vec<_>>>()?;
 
             *self = Self::Probing {
                 probe_table: Arc::new(pt),
-                tables: Arc::new(mapped_tables),
+                tables: Arc::new(tables.clone()),
             };
             Ok(())
         } else {
@@ -93,7 +87,6 @@ impl HashJoinState {
 pub(crate) struct HashJoinOperator {
     right_on: Vec<ExprRef>,
     _join_type: JoinType,
-    join_mapper: Arc<JoinOutputMapper>,
     join_state: HashJoinState,
 }
@@ -126,9 +119,6 @@ impl HashJoinOperator {
         )?
         .into();
 
-        let join_mapper =
-            infer_join_schema_mapper(left_schema, right_schema, &left_on, &right_on, join_type)?;
-
         let left_on = left_on
             .into_iter()
             .zip(key_schema.fields.values())
@@ -143,7 +133,6 @@ impl HashJoinOperator {
         Ok(Self {
             right_on,
             _join_type: join_type,
-            join_mapper: Arc::new(join_mapper),
             join_state: HashJoinState::new(&key_schema, left_on)?,
         })
     }
@@ -162,7 +151,6 @@ impl HashJoinOperator {
                 probe_table: probe_table.clone(),
                 tables: tables.clone(),
                 right_on: self.right_on.clone(),
-                join_mapper: self.join_mapper.clone(),
             })
         } else {
             panic!("can't call as_intermediate_op when not in probing state")
@@ -174,7 +162,6 @@ struct HashJoinProber {
     probe_table: Arc<ProbeTable>,
     tables: Arc<Vec<Table>>,
     right_on: Vec<ExprRef>,
-    join_mapper: Arc<JoinOutputMapper>,
 }
 
 impl IntermediateOperator for HashJoinProber {
@@ -192,13 +179,8 @@ impl IntermediateOperator for HashJoinProber {
 
         let right_input_tables = input.get_tables()?;
 
-        let right_tables = right_input_tables
-            .iter()
-            .map(|t| self.join_mapper.map_right(t))
-            .collect::<DaftResult<Vec<_>>>()?;
-
         let mut right_growable =
-            GrowableTable::new(&right_tables.iter().collect::<Vec<_>>(), false, 20)?;
+            GrowableTable::new(&right_input_tables.iter().collect::<Vec<_>>(), false, 20)?;
 
         drop(_growables);
         {
@@ -237,7 +219,7 @@ impl BlockingSink for HashJoinOperator {
         Ok(BlockingSinkStatus::NeedMoreInput)
     }
     fn finalize(&mut self) -> DaftResult<()> {
-        self.join_state.finalize(&self.join_mapper)?;
+        self.join_state.finalize()?;
         Ok(())
     }
     fn as_source(&mut self) -> &mut dyn Source {
diff --git a/src/daft-micropartition/src/ops/join.rs b/src/daft-micropartition/src/ops/join.rs
index 8cc5c3d078..eda430c3fd 100644
--- a/src/daft-micropartition/src/ops/join.rs
+++ b/src/daft-micropartition/src/ops/join.rs
@@ -2,9 +2,9 @@ use std::sync::Arc;
 
 use common_error::DaftResult;
 use daft_core::{array::ops::DaftCompare, join::JoinType};
-use daft_dsl::ExprRef;
+use daft_dsl::{join::infer_join_schema, ExprRef};
 use daft_io::IOStatsContext;
-use daft_table::{infer_join_schema, Table};
+use daft_table::Table;
 
 use crate::micropartition::MicroPartition;
 
@@ -29,12 +29,14 @@ impl MicroPartition {
             | (JoinType::Inner, _, 0)
             | (JoinType::Left, 0, _)
             | (JoinType::Right, _, 0)
-            | (JoinType::Outer, 0, 0) => {
+            | (JoinType::Outer, 0, 0)
+            | (JoinType::Semi, 0, _) => {
                 return Ok(Self::empty(Some(join_schema)));
             }
             _ => {}
         }
 
+        // TODO(Kevin): short circuits are also possible for other join types
         if how == JoinType::Inner {
             let tv = match (&self.statistics, &right.statistics) {
                 (_, None) => TruthValue::Maybe,
diff --git a/src/daft-plan/src/display.rs b/src/daft-plan/src/display.rs
index d215c93c77..1630002d43 100644
--- a/src/daft-plan/src/display.rs
+++ b/src/daft-plan/src/display.rs
@@ -147,7 +147,7 @@ Filter2["Filter: col(first_name) == lit('hello')"]
 Join3["Join: Type = Inner
 Strategy = Auto
 On = col(id)
-Output schema = text#Utf8, id#Int32, first_name#Utf8, last_name#Utf8"]
+Output schema = id#Int32, text#Utf8, first_name#Utf8, last_name#Utf8"]
 Filter4["Filter: col(id) == lit(1)"]
 Source5["PlaceHolder:
 Source ID = 0
diff --git a/src/daft-plan/src/logical_ops/join.rs b/src/daft-plan/src/logical_ops/join.rs
index 523f408c52..648dcac754 100644
--- a/src/daft-plan/src/logical_ops/join.rs
+++ b/src/daft-plan/src/logical_ops/join.rs
@@ -1,16 +1,25 @@
-use std::{collections::HashSet, sync::Arc};
+use std::{
+    collections::{HashMap, HashSet},
+    sync::Arc,
+};
 
 use common_error::DaftError;
 use daft_core::{
     join::{JoinStrategy, JoinType},
-    schema::{hash_index_map, Schema, SchemaRef},
+    schema::{Schema, SchemaRef},
     DataType,
 };
-use daft_dsl::{resolve_exprs, ExprRef};
+use daft_dsl::{
+    col,
+    join::{get_common_join_keys, infer_join_schema},
+    optimization::replace_columns_with_expressions,
+    resolve_exprs, Expr, ExprRef,
+};
 use itertools::Itertools;
 use snafu::ResultExt;
 
 use crate::{
+    logical_ops::Project,
     logical_plan::{self, CreationSnafu},
     LogicalPlan,
 };
@@ -26,10 +35,6 @@ pub struct Join {
     pub join_type: JoinType,
     pub join_strategy: Option<JoinStrategy>,
     pub output_schema: SchemaRef,
-
-    // Joins may rename columns from the right input; this struct tracks those renames.
-    // Output name -> Original name
-    pub right_input_mapping: indexmap::IndexMap<String, String>,
 }
 
@@ -41,7 +46,6 @@ impl std::hash::Hash for Join {
         std::hash::Hash::hash(&self.join_type, state);
         std::hash::Hash::hash(&self.join_strategy, state);
         std::hash::Hash::hash(&self.output_schema, state);
-        state.write_u64(hash_index_map(&self.right_input_mapping))
     }
 }
 
@@ -70,46 +74,101 @@ impl Join {
             }
         }
-        let mut right_input_mapping = indexmap::IndexMap::new();
-        // Schema inference ported from existing behaviour for parity,
-        // but contains bug https://github.com/Eventual-Inc/Daft/issues/1294
-        let output_schema = {
-            let left_join_keys = left_on.iter().map(|e| e.name()).collect::<HashSet<_>>();
-            let right_join_keys = right_on.iter().map(|e| e.name()).collect::<HashSet<_>>();
-            let left_schema = &left.schema().fields;
-            let fields = left_schema
-                .iter()
-                .map(|(_, field)| field)
-                .cloned()
-                .chain(right.schema().fields.iter().filter_map(|(rname, rfield)| {
-                    if (left_join_keys.contains(rname.as_str())
-                        && right_join_keys.contains(rname.as_str()))
-                        || matches!(join_type, JoinType::Anti | JoinType::Semi)
-                    {
-                        right_input_mapping.insert(rname.clone(), rname.clone());
-                        None
-                    } else if left_schema.contains_key(rname) {
-                        let new_name = format!("right.{}", rname);
-                        right_input_mapping.insert(new_name.clone(), rname.clone());
-                        Some(rfield.rename(new_name))
-                    } else {
-                        right_input_mapping.insert(rname.clone(), rname.clone());
-                        Some(rfield.clone())
-                    }
-                }))
-                .collect::<Vec<_>>();
-            Schema::new(fields).context(CreationSnafu)?.into()
-        };
-        Ok(Self {
-            left,
-            right,
-            left_on,
-            right_on,
-            join_type,
-            join_strategy,
-            output_schema,
-            right_input_mapping,
-        })
+
+        if matches!(join_type, JoinType::Anti | JoinType::Semi) {
+            // The output schema is the same as the left input schema for anti and semi joins.
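+            // (Anti and semi joins emit only left-side rows, so no right-side columns
+            // need to be renamed or merged.)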
+
+            let output_schema = left.schema().clone();
+
+            Ok(Self {
+                left,
+                right,
+                left_on,
+                right_on,
+                join_type,
+                join_strategy,
+                output_schema,
+            })
+        } else {
+            let common_join_keys: HashSet<_> =
+                get_common_join_keys(left_on.as_slice(), right_on.as_slice())
+                    .map(|k| k.to_string())
+                    .collect();
+
+            let left_names = left.schema().names();
+            let right_names = right.schema().names();
+
+            let mut names_so_far: HashSet<String> = HashSet::from_iter(left_names);
+
+            // rename right columns that have the same name as left columns and are not join keys
+            // old_name -> new_name
+            let right_rename_mapping: HashMap<_, _> = right_names
+                .iter()
+                .filter_map(|name| {
+                    if !names_so_far.contains(name) || common_join_keys.contains(name) {
+                        None
+                    } else {
+                        let mut new_name = name.clone();
+                        while names_so_far.contains(&new_name) {
+                            new_name = format!("right.{}", new_name);
+                        }
+                        names_so_far.insert(new_name.clone());
+
+                        Some((name.clone(), new_name))
+                    }
+                })
+                .collect();
+
+            let (right, right_on) = if right_rename_mapping.is_empty() {
+                (right, right_on)
+            } else {
+                // projection to update the right side with the new column names
+                let new_right_projection: Vec<_> = right_names
+                    .iter()
+                    .map(|name| {
+                        if let Some(new_name) = right_rename_mapping.get(name) {
+                            Expr::Alias(col(name.clone()), new_name.clone().into()).into()
+                        } else {
+                            col(name.clone())
+                        }
+                    })
+                    .collect();
+
+                let new_right: LogicalPlan = Project::try_new(right, new_right_projection)?.into();
+
+                let right_on_replace_map = right_rename_mapping
+                    .iter()
+                    .map(|(old_name, new_name)| (old_name.clone(), col(new_name.clone())))
+                    .collect::<HashMap<_, _>>();
+
+                // change any column references in the right_on expressions to the new column names
+                let new_right_on = right_on
+                    .into_iter()
+                    .map(|expr| replace_columns_with_expressions(expr, &right_on_replace_map))
+                    .collect::<Vec<_>>();
+
+                (new_right.into(), new_right_on)
+            };
+
+            let output_schema = infer_join_schema(
+                &left.schema(),
+                &right.schema(),
+                &left_on,
+                &right_on,
+                join_type,
+            )
+            .context(CreationSnafu)?;
+
+            Ok(Self {
+                left,
+                right,
+                left_on,
+                right_on,
+                join_type,
+                join_strategy,
+                output_schema,
+            })
+        }
     }
 
     pub fn multiline_display(&self) -> Vec<String> {
diff --git a/src/daft-plan/src/logical_optimization/rules/push_down_filter.rs b/src/daft-plan/src/logical_optimization/rules/push_down_filter.rs
index ac5d859e2d..badec59a89 100644
--- a/src/daft-plan/src/logical_optimization/rules/push_down_filter.rs
+++ b/src/daft-plan/src/logical_optimization/rules/push_down_filter.rs
@@ -602,7 +602,6 @@ mod tests {
     #[rstest]
     fn filter_commutes_with_join(
         #[values(false, true)] push_into_left_scan: bool,
-        #[values(false, true)] push_into_right_scan: bool,
     ) -> DaftResult<()> {
         let scan_op =
            dummy_scan_operator(vec![
                Field::new("a", DataType::Int64),
                Field::new("b", DataType::Utf8),
            ]);
         let left_scan_plan = dummy_scan_node_with_pushdowns(
             scan_op.clone(),
             Pushdowns::default().with_limit(if push_into_left_scan { None } else { Some(1) }),
         );
-        let right_scan_plan = dummy_scan_node_with_pushdowns(
-            scan_op.clone(),
-            Pushdowns::default().with_limit(if push_into_right_scan { None } else { Some(1) }),
-        );
+        let right_scan_plan =
+            dummy_scan_node_with_pushdowns(scan_op.clone(), Pushdowns::default().with_limit(None));
         let join_on = vec![col("b")];
         let pred = col("a").lt(lit(2));
         let plan = left_scan_plan
@@ -636,11 +633,7 @@ mod tests {
         } else {
             left_scan_plan.filter(pred.clone())?
         };
-        let expected_right_filter_scan = if push_into_right_scan {
-            dummy_scan_node_with_pushdowns(scan_op, Pushdowns::default().with_filters(Some(pred)))
-        } else {
-            right_scan_plan.filter(pred)?
-        };
+        let expected_right_filter_scan = right_scan_plan;
         let expected = expected_left_filter_scan
             .join(
                 &expected_right_filter_scan,
diff --git a/src/daft-plan/src/logical_optimization/rules/push_down_projection.rs b/src/daft-plan/src/logical_optimization/rules/push_down_projection.rs
index 01882e3abe..6fcbe7aa30 100644
--- a/src/daft-plan/src/logical_optimization/rules/push_down_projection.rs
+++ b/src/daft-plan/src/logical_optimization/rules/push_down_projection.rs
@@ -10,7 +10,7 @@ use indexmap::IndexSet;
 use crate::{
     logical_ops::{ActorPoolProject, Aggregate, Join, Pivot, Project, Source},
     source_info::SourceInfo,
-    LogicalPlan,
+    LogicalPlan, LogicalPlanRef,
 };
 
 use super::{ApplyOrder, OptimizerRule, Transformed};
@@ -352,7 +352,7 @@ impl PushDownProjection {
             }
             LogicalPlan::Join(join) => {
                 // Get required columns from projection and both upstreams.
-                let [projection_required_columns] = &plan.required_columns()[..] else {
+                let [projection_dependencies] = &plan.required_columns()[..] else {
                     panic!()
                 };
                 let [left_dependencies, right_dependencies] = &upstream_plan.required_columns()[..]
                 else {
                     panic!()
                 };
 
-                let left_upstream_names = join
-                    .left
-                    .schema()
-                    .names()
-                    .iter()
-                    .cloned()
-                    .collect::<IndexSet<_>>();
-                let right_upstream_names = join
-                    .right
-                    .schema()
-                    .names()
-                    .iter()
-                    .cloned()
-                    .collect::<IndexSet<_>>();
-
-                let right_combined_dependencies = projection_required_columns
-                    .iter()
-                    .filter_map(|colname| join.right_input_mapping.get(colname))
-                    .chain(right_dependencies.iter())
-                    .cloned()
-                    .collect::<IndexSet<_>>();
-
-                let left_combined_dependencies = projection_required_columns
-                    .iter()
-                    .filter_map(|colname| left_upstream_names.get(colname))
-                    .chain(left_dependencies.iter())
-                    // We also have to keep any name conflict columns referenced by the right side.
-                    // E.g. if the user wants "right.c", left must also provide "c", or "right.c" disappears.
-                    // This is mostly an artifact of https://github.com/Eventual-Inc/Daft/issues/1303
-                    .chain(
-                        right_combined_dependencies
-                            .iter()
-                            .filter_map(|rname| left_upstream_names.get(rname)),
-                    )
-                    .cloned()
-                    .collect::<IndexSet<_>>();
+                /// For one side of the join, see if a non-vacuous pushdown is possible.
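+                /// A pushdown is non-vacuous when the columns needed from this side
+                /// (its own dependencies plus the projection dependencies it can supply)
+                /// form a strict subset of its schema.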
+                fn maybe_project_upstream_input(
+                    side: &LogicalPlanRef,
+                    side_dependencies: &IndexSet<String>,
+                    projection_dependencies: &IndexSet<String>,
+                ) -> DaftResult<Transformed<LogicalPlanRef>> {
+                    let schema = side.schema();
+                    let upstream_names: IndexSet<String> = schema.fields.keys().cloned().collect();
+
+                    let combined_dependencies: IndexSet<_> = side_dependencies
+                        .union(
+                            &upstream_names
+                                .intersection(projection_dependencies)
+                                .cloned()
+                                .collect::<IndexSet<_>>(),
+                        )
+                        .cloned()
+                        .collect();
 
-                // For each upstream, see if a non-vacuous pushdown is possible.
-                let maybe_new_left_upstream: Option<Arc<LogicalPlan>> = {
-                    if left_combined_dependencies.len() < left_upstream_names.len() {
-                        let pushdown_column_exprs: Vec<ExprRef> = left_combined_dependencies
+                    if combined_dependencies.len() < upstream_names.len() {
+                        let pushdown_column_exprs: Vec<ExprRef> = combined_dependencies
                             .into_iter()
-                            .map(col)
-                            .collect::<Vec<_>>();
+                            .map(|d| col(d.to_string()))
+                            .collect();
                         let new_project: LogicalPlan =
-                            Project::try_new(join.left.clone(), pushdown_column_exprs)?.into();
-                        Some(new_project.into())
+                            Project::try_new(side.clone(), pushdown_column_exprs)?.into();
+                        Ok(Transformed::Yes(new_project.into()))
                     } else {
-                        None
+                        Ok(Transformed::No(side.clone()))
                     }
-                };
+                }
 
-                let maybe_new_right_upstream: Option<Arc<LogicalPlan>> = {
-                    if right_combined_dependencies.len() < right_upstream_names.len() {
-                        let pushdown_column_exprs: Vec<ExprRef> = right_combined_dependencies
-                            .into_iter()
-                            .map(col)
-                            .collect::<Vec<_>>();
-                        let new_project: LogicalPlan =
-                            Project::try_new(join.right.clone(), pushdown_column_exprs)?.into();
-                        Some(new_project.into())
-                    } else {
-                        None
-                    }
-                };
+                let new_left_upstream = maybe_project_upstream_input(
+                    &join.left,
+                    left_dependencies,
+                    projection_dependencies,
+                )?;
+                let new_right_upstream = maybe_project_upstream_input(
+                    &join.right,
+                    right_dependencies,
+                    projection_dependencies,
+                )?;
+
+                if new_left_upstream.is_no() && new_right_upstream.is_no() {
+                    Ok(Transformed::No(plan))
+                } else {
+                    // If either pushdown is possible, create a new Join node.
+                    let new_join = upstream_plan.with_new_children(&[
+                        new_left_upstream.unwrap().clone(),
+                        new_right_upstream.unwrap().clone(),
+                    ]);
 
-                // If either pushdown is possible, create a new Join node.
-                if maybe_new_left_upstream.is_some() || maybe_new_right_upstream.is_some() {
-                    let new_left_upstream = maybe_new_left_upstream.unwrap_or(join.left.clone());
-                    let new_right_upstream = maybe_new_right_upstream.unwrap_or(join.right.clone());
-                    let new_join =
-                        upstream_plan.with_new_children(&[new_left_upstream, new_right_upstream]);
                     let new_plan = Arc::new(plan.with_new_children(&[new_join.into()]));
+
                     // Retry optimization now that the upstream node is different.
                     let new_plan = self
                         .try_optimize(new_plan.clone())?
                         .or(Transformed::Yes(new_plan));
+
                     Ok(new_plan)
-                } else {
-                    Ok(Transformed::No(plan))
                 }
             }
             LogicalPlan::Distinct(_) => {
diff --git a/src/daft-plan/src/physical_planner/translate.rs b/src/daft-plan/src/physical_planner/translate.rs
index 7f5c012afd..b0a530e992 100644
--- a/src/daft-plan/src/physical_planner/translate.rs
+++ b/src/daft-plan/src/physical_planner/translate.rs
@@ -423,11 +423,12 @@ pub(super) fn translate_single_logical_node(
             Ok(PhysicalPlan::Concat(Concat::new(input_physical, other_physical)).arced())
         }
         LogicalPlan::Join(LogicalJoin {
+            left,
+            right,
             left_on,
             right_on,
             join_type,
             join_strategy,
-            output_schema,
             ..
        }) => {
            let mut right_physical = physical_children.pop().expect("requires 1 inputs");
@@ -496,11 +497,9 @@ pub(super) fn translate_single_logical_node(
                is_right_hash_partitioned || is_right_sort_partitioned
            };
            let join_strategy = join_strategy.unwrap_or_else(|| {
-                // This method will panic if called with columns that aren't in the output schema,
-                // which is possible for anti- and semi-joins.
-                let is_primitive = |exprs: &Vec<ExprRef>| {
-                    exprs.iter().map(|e| e.name()).all(|col| {
-                        let dtype = &output_schema.get_field(col).unwrap().dtype;
+                fn keys_are_primitive(on: &[ExprRef], schema: &SchemaRef) -> bool {
+                    on.iter().all(|expr| {
+                        let dtype = expr.get_type(schema).unwrap();
                        dtype.is_integer()
                            || dtype.is_floating()
                            || matches!(
                                dtype,
                                DataType::Utf8 | DataType::Binary | DataType::Boolean
                            )
                    })
-                };
+                }
+
                // If larger table is not already partitioned on the join key AND the smaller table is under broadcast size threshold AND we are not broadcasting the side we are outer joining by, use broadcast join.
                if !is_larger_partitioned
                    && let Some(smaller_size_bytes) = smaller_size_bytes
@@ -524,8 +523,8 @@ pub(super) fn translate_single_logical_node(
                // TODO(Clark): Look into defaulting to sort-merge join over hash join under more input partitioning setups.
                // TODO(Kevin): Support sort-merge join for other types of joins.
            } else if *join_type == JoinType::Inner
-                && is_primitive(left_on)
-                && is_primitive(right_on)
+                && keys_are_primitive(left_on, &left.schema())
+                && keys_are_primitive(right_on, &right.schema())
                && (is_left_sort_partitioned || is_right_sort_partitioned)
                && (!is_larger_partitioned
                    || (left_is_larger && is_left_sort_partitioned
diff --git a/src/daft-table/src/lib.rs b/src/daft-table/src/lib.rs
index 548fbdff52..37253a7615 100644
--- a/src/daft-table/src/lib.rs
+++ b/src/daft-table/src/lib.rs
@@ -26,7 +26,6 @@ mod ops;
 mod probe_table;
 
 pub use growable::GrowableTable;
-pub use ops::{infer_join_schema, infer_join_schema_mapper, JoinOutputMapper};
 pub use probe_table::{ProbeTable, ProbeTableBuilder};
diff --git a/src/daft-table/src/ops/joins/hash_join.rs b/src/daft-table/src/ops/joins/hash_join.rs
index 8eb25793e1..1aedebe509 100644
--- a/src/daft-table/src/ops/joins/hash_join.rs
+++ b/src/daft-table/src/ops/joins/hash_join.rs
@@ -6,9 +6,12 @@ use daft_core::{
     datatypes::{BooleanArray, UInt64Array},
     DataType, IntoSeries, JoinType,
 };
-use daft_dsl::ExprRef;
+use daft_dsl::{
+    join::{get_common_join_keys, infer_join_schema},
+    ExprRef,
+};
 
-use crate::{infer_join_schema, Table};
+use crate::Table;
 
 use common_error::DaftResult;
 use daft_core::array::ops::as_arrow::AsArrow;
@@ -88,8 +91,10 @@ pub(super) fn hash_inner_join(
         }
     };
 
+    let common_join_keys: Vec<_> = get_common_join_keys(left_on, right_on).collect();
+
     let mut join_series = left
-        .get_columns(lkeys.column_names().as_slice())?
+        .get_columns(common_join_keys.as_slice())?
         .take(&lidx)?
         .columns;
 
@@ -97,8 +102,7 @@ pub(super) fn hash_inner_join(
     drop(rkeys);
 
     let num_rows = lidx.len();
-    join_series =
-        add_non_join_key_columns(left, right, lidx, ridx, left_on, right_on, join_series)?;
+    join_series = add_non_join_key_columns(left, right, lidx, ridx, join_series)?;
 
     Table::new_with_size(join_schema, join_series, num_rows)
 }
@@ -194,24 +198,17 @@ pub(super) fn hash_left_right_join(
         (lkeys, rkeys, lidx, ridx)
     };
 
+    let common_join_keys = get_common_join_keys(left_on, right_on);
+
     let mut join_series = if left_side {
-        left.get_columns(lkeys.column_names().as_slice())?
+        left.get_columns(common_join_keys.collect::<Vec<_>>().as_slice())?
             .take(&lidx)?
             .columns
     } else {
-        lkeys
-            .column_names()
-            .iter()
-            .zip(rkeys.column_names().iter())
-            .map(|(l, r)| {
-                let join_col = if l == r {
-                    let col_dtype = left.get_column(l)?.data_type();
-                    right.get_column(r)?.take(&ridx)?.cast(col_dtype)?
-                } else {
-                    left.get_column(l)?.take(&lidx)?
-                };
-
-                Ok(join_col)
+        common_join_keys
+            .map(|name| {
+                let col_dtype = &left.schema.get_field(name)?.dtype;
+                right.get_column(name)?.take(&ridx)?.cast(col_dtype)
             })
             .collect::<DaftResult<Vec<_>>>()?
     };
 
@@ -220,8 +217,7 @@ pub(super) fn hash_left_right_join(
     drop(rkeys);
 
     let num_rows = lidx.len();
-    join_series =
-        add_non_join_key_columns(left, right, lidx, ridx, left_on, right_on, join_series)?;
+    join_series = add_non_join_key_columns(left, right, lidx, ridx, join_series)?;
 
     Table::new_with_size(join_schema, join_series, num_rows)
 }
@@ -406,12 +402,11 @@ pub(super) fn hash_outer_join(
         }
     };
 
-    let mut join_series = if lkeys
-        .column_names()
-        .iter()
-        .zip(rkeys.column_names().iter())
-        .any(|(l, r)| l == r)
-    {
+    let common_join_keys: Vec<_> = get_common_join_keys(left_on, right_on).collect();
+
+    let mut join_series = if common_join_keys.is_empty() {
+        vec![]
+    } else {
         let join_key_predicate = BooleanArray::from((
             "join_key_predicate",
             arrow2::array::BooleanArray::from_trusted_len_values_iter(
@@ -427,33 +422,22 @@ pub(super) fn hash_outer_join(
         ))
         .into_series();
 
-        lkeys
-            .column_names()
-            .iter()
-            .zip(rkeys.column_names().iter())
-            .map(|(l, r)| {
-                if l == r {
-                    let lcol = left.get_column(l)?.take(&lidx)?;
-                    let rcol = right.get_column(r)?.take(&ridx)?;
-
-                    lcol.if_else(&rcol, &join_key_predicate)
-                } else {
-                    left.get_column(l)?.take(&lidx)
-                }
+        common_join_keys
+            .into_iter()
+            .map(|name| {
+                let lcol = left.get_column(name)?.take(&lidx)?;
+                let rcol = right.get_column(name)?.take(&ridx)?;
+
+                lcol.if_else(&rcol, &join_key_predicate)
             })
             .collect::<DaftResult<Vec<_>>>()?
-    } else {
-        left.get_columns(lkeys.column_names().as_slice())?
-            .take(&lidx)?
-            .columns
     };
 
     drop(lkeys);
     drop(rkeys);
 
     let num_rows = lidx.len();
-    join_series =
-        add_non_join_key_columns(left, right, lidx, ridx, left_on, right_on, join_series)?;
+    join_series = add_non_join_key_columns(left, right, lidx, ridx, join_series)?;
 
     Table::new_with_size(join_schema, join_series, num_rows)
 }
diff --git a/src/daft-table/src/ops/joins/mod.rs b/src/daft-table/src/ops/joins/mod.rs
index 1dea890b62..d7ef796b56 100644
--- a/src/daft-table/src/ops/joins/mod.rs
+++ b/src/daft-table/src/ops/joins/mod.rs
@@ -1,18 +1,14 @@
-use std::{
-    collections::{HashMap, HashSet},
-    sync::Arc,
-};
+use std::collections::HashSet;
 
 use daft_core::{
-    array::growable::make_growable,
-    datatypes::Field,
-    schema::{Schema, SchemaRef},
-    utils::supertype::try_get_supertype,
-    JoinType, Series,
+    array::growable::make_growable, utils::supertype::try_get_supertype, JoinType, Series,
 };
 
 use common_error::{DaftError, DaftResult};
-use daft_dsl::ExprRef;
+use daft_dsl::{
+    join::{get_common_join_keys, infer_join_schema},
+    ExprRef,
+};
 use hash_join::hash_semi_anti_join;
 
 use crate::Table;
@@ -45,301 +41,35 @@ fn match_types_for_tables(left: &Table, right: &Table) -> DaftResult<(Table, Tab
     ))
 }
 
-pub fn infer_join_schema(
-    left: &SchemaRef,
-    right: &SchemaRef,
-    left_on: &[ExprRef],
-    right_on: &[ExprRef],
-    how: JoinType,
-) -> DaftResult<SchemaRef> {
-    if left_on.len() != right_on.len() {
-        return Err(DaftError::ValueError(format!(
-            "Length of left_on does not match length of right_on for Join {} vs {}",
-            left_on.len(),
-            right_on.len()
-        )));
-    }
-    if matches!(how, JoinType::Anti | JoinType::Semi) {
-        return Ok(left.clone());
-    }
-
-    let lfields = left_on
-        .iter()
-        .map(|e| e.to_field(left))
-        .collect::<DaftResult<Vec<_>>>()?;
-    let rfields = right_on
-        .iter()
-        .map(|e| e.to_field(right))
-        .collect::<DaftResult<Vec<_>>>()?;
-
-    // Left Join Keys are first
-    let mut join_fields = lfields
-        .iter()
-        .map(|f| left.get_field(&f.name).cloned())
-        .collect::<DaftResult<Vec<_>>>()?;
-
-    let left_names = lfields.iter().map(|e| e.name.as_str());
-    let right_names = rfields.iter().map(|e| e.name.as_str());
-
-    let mut names_so_far = HashSet::new();
-
-    join_fields.iter().for_each(|f| {
-        names_so_far.insert(f.name.clone());
-    });
-
-    // Then Add Left Table non-join-key columns
-    for field in left.fields.values() {
-        if names_so_far.contains(&field.name) {
-            continue;
-        } else {
-            join_fields.push(field.clone());
-            names_so_far.insert(field.name.clone());
-        }
-    }
-
-    let zipped_names: Vec<_> = left_names.zip(right_names).collect();
-    let right_to_left_keys: HashMap<&str, &str> = HashMap::from_iter(zipped_names.iter().copied());
-
-    // Then Add Right Table non-join-key columns
-
-    for field in right.fields.values() {
-        // Skip fields if they were used in the join and have the same name as the corresponding left field
-        match right_to_left_keys.get(field.name.as_str()) {
-            Some(val) if val.eq(&field.name.as_str()) => {
-                continue;
-            }
-            _ => (),
-        }
-
-        let mut curr_name = field.name.clone();
-        while names_so_far.contains(&curr_name) {
-            curr_name = "right.".to_string() + curr_name.as_str();
-        }
-        join_fields.push(field.rename(curr_name.clone()));
-        names_so_far.insert(curr_name.clone());
-    }
-    let schema = Schema::new(join_fields)?;
-    Ok(Arc::new(schema))
-}
-
-struct JoinOutputColumn {
-    pub is_right: bool,
-    pub index: usize,
-    pub field: Field,
-}
-
-pub struct JoinOutputMapper {
-    mapping: Vec<JoinOutputColumn>,
-    left_schema: SchemaRef,
-    right_schema: SchemaRef,
-}
-
-impl JoinOutputMapper {
-    fn try_new(mapping: Vec<JoinOutputColumn>) -> DaftResult<Self> {
-        let left_schema = Schema::new(
-            mapping
-                .iter()
-                .filter(|jc| !jc.is_right)
-                .map(|jc| jc.field.clone())
-                .collect::<Vec<_>>(),
-        )?
-        .into();
-        let right_schema = Schema::new(
-            mapping
-                .iter()
-                .filter(|jc| jc.is_right)
-                .map(|jc| jc.field.clone())
-                .collect::<Vec<_>>(),
-        )?
-        .into();
-        Ok(JoinOutputMapper {
-            mapping,
-            left_schema,
-            right_schema,
-        })
-    }
-
-    pub fn map_left(&self, table: &Table) -> DaftResult<Table> {
-        let out = self
-            .mapping
-            .iter()
-            .filter(|jc| !jc.is_right)
-            .map(|jc| DaftResult::Ok(table.get_column_by_index(jc.index)?.rename(&jc.field.name)))
-            .collect::<DaftResult<Vec<_>>>()?;
-
-        Table::new_with_size(self.left_schema.clone(), out, table.num_rows)
-    }
-
-    pub fn map_right(&self, table: &Table) -> DaftResult<Table> {
-        let out = self
-            .mapping
-            .iter()
-            .filter(|jc| jc.is_right)
-            .map(|jc| DaftResult::Ok(table.get_column_by_index(jc.index)?.rename(&jc.field.name)))
-            .collect::<DaftResult<Vec<_>>>()?;
-
-        Table::new_with_size(self.right_schema.clone(), out, table.num_rows)
-    }
-}
-
-pub fn infer_join_schema_mapper(
-    left: &SchemaRef,
-    right: &SchemaRef,
-    left_on: &[ExprRef],
-    right_on: &[ExprRef],
-    how: JoinType,
-) -> DaftResult<JoinOutputMapper> {
-    if left_on.len() != right_on.len() {
-        return Err(DaftError::ValueError(format!(
-            "Length of left_on does not match length of right_on for Join {} vs {}",
-            left_on.len(),
-            right_on.len()
-        )));
-    }
-    if matches!(how, JoinType::Anti | JoinType::Semi) {
-        return JoinOutputMapper::try_new(
-            left.fields
-                .values()
-                .enumerate()
-                .map(|(i, f)| JoinOutputColumn {
-                    is_right: false,
-                    index: i,
-                    field: f.clone(),
-                })
-                .collect(),
-        );
-    }
-    let mut result = vec![];
-    let mut names_so_far = HashSet::new();
-
-    let lfields = left_on
-        .iter()
-        .map(|e| e.to_field(left))
-        .collect::<DaftResult<Vec<_>>>()?;
-    let rfields = right_on
-        .iter()
-        .map(|e| e.to_field(right))
-        .collect::<DaftResult<Vec<_>>>()?;
-
-    for lf in lfields.iter() {
-        let name = &lf.name;
-        let index = left.get_index(name)?;
-        let field = left.get_field(name)?.clone();
-        result.push(JoinOutputColumn {
-            is_right: false,
-            index,
-            field,
-        });
-        names_so_far.insert(name.clone());
-    }
-
-    let left_names = lfields.iter().map(|e| e.name.as_str());
-    let right_names = rfields.iter().map(|e| e.name.as_str());
-
-    // Then Add Left Table non-join-key columns
-    for (index, field) in left.fields.values().enumerate() {
-        if names_so_far.contains(&field.name) {
-            continue;
-        } else {
-            result.push(JoinOutputColumn {
-                is_right: false,
-                index,
-                field: field.clone(),
-            });
-            names_so_far.insert(field.name.clone());
-        }
-    }
-
-    let zipped_names: Vec<_> = left_names.zip(right_names).collect();
-    let right_to_left_keys: HashMap<&str, &str> = HashMap::from_iter(zipped_names.iter().copied());
-
-    // Then Add Right Table non-join-key columns
-
-    for (index, field) in right.fields.values().enumerate() {
-        // Skip fields if they were used in the join and have the same name as the corresponding left field
-        match right_to_left_keys.get(field.name.as_str()) {
-            Some(val) if val.eq(&field.name.as_str()) => {
-                continue;
-            }
-            _ => (),
-        }
-
-        let mut curr_name = field.name.clone();
-        while names_so_far.contains(&curr_name) {
-            curr_name = "right.".to_string() + curr_name.as_str();
-        }
-
-        names_so_far.insert(curr_name.clone());
-
-        let new_field = field.rename(curr_name.clone());
-        result.push(JoinOutputColumn {
-            is_right: true,
-            index,
-            field: new_field,
-        });
-    }
-    JoinOutputMapper::try_new(result)
-}
-
 fn add_non_join_key_columns(
     left: &Table,
     right: &Table,
     lidx: Series,
     ridx: Series,
-    left_on: &[ExprRef],
-    right_on: &[ExprRef],
     mut join_series: Vec<Series>,
 ) -> DaftResult<Vec<Series>> {
-    let mut names_so_far = join_series
+    let join_keys = join_series
         .iter()
         .map(|s| s.name().to_string())
         .collect::<HashSet<_>>();
 
     // TODO(Clark): Parallelize with rayon.
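+    // `join_series` already holds the merged join-key columns, so the loops below
+    // skip any column whose name matches a join key and append the remaining
+    // columns from each side.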
     for field in left.schema.fields.values() {
-        if names_so_far.contains(&field.name) {
+        if join_keys.contains(&field.name) {
             continue;
         } else {
             join_series.push(left.get_column(&field.name)?.take(&lidx)?);
-            names_so_far.insert(field.name.clone());
         }
     }
 
     drop(lidx);
 
-    // Zip the names of the left and right expressions into a HashMap
-    let left_names = left_on
-        .iter()
-        .map(|e| e.to_field(&left.schema).map(|f| f.name))
-        .collect::<DaftResult<Vec<_>>>()?;
-    let right_names = right_on
-        .iter()
-        .map(|e| e.to_field(&right.schema).map(|f| f.name))
-        .collect::<DaftResult<Vec<_>>>()?;
-    let right_to_left_keys: HashMap<String, String> =
-        HashMap::from_iter(left_names.into_iter().zip(right_names));
-
-    // TODO(Clark): Parallelize with Rayon.
     for field in right.schema.fields.values() {
-        // Skip fields if they were used in the join and have the same name as the corresponding left field
-        match right_to_left_keys.get(&field.name) {
-            Some(val) if val.eq(&field.name) => {
-                continue;
-            }
-            _ => (),
-        }
-
-        let mut curr_name = field.name.clone();
-        while names_so_far.contains(&curr_name) {
-            curr_name = "right.".to_string() + curr_name.as_str();
+        if join_keys.contains(&field.name) {
+            continue;
+        } else {
+            join_series.push(right.get_column(&field.name)?.take(&ridx)?);
         }
-        join_series.push(
-            right
-                .get_column(&field.name)?
-                .rename(curr_name.clone())
-                .take(&ridx)?,
-        );
-        names_so_far.insert(curr_name);
     }
 
     Ok(join_series)
@@ -427,19 +157,13 @@ impl Table {
         let (ltable, rtable) = match_types_for_tables(&ltable, &rtable)?;
         let (lidx, ridx) = merge_join::merge_inner_join(&ltable, &rtable)?;
 
-        let mut join_series = Vec::with_capacity(ltable.num_columns());
-
-        for (l, r) in ltable
-            .column_names()
-            .iter()
-            .zip(rtable.column_names().iter())
-        {
-            if l == r {
-                let lcol = self.get_column(l)?;
-                let rcol = right.get_column(r)?;
+        let mut join_series = get_common_join_keys(left_on, right_on)
+            .map(|name| {
+                let lcol = self.get_column(name)?;
+                let rcol = right.get_column(name)?;
 
                 let mut growable =
-                    make_growable(l, lcol.data_type(), vec![lcol, rcol], false, lcol.len());
+                    make_growable(name, lcol.data_type(), vec![lcol, rcol], false, lcol.len());
 
                 for (li, ri) in lidx.u64()?.into_iter().zip(ridx.u64()?)
{ match (li, ri) { @@ -449,18 +173,15 @@ impl Table { } } - join_series.push(growable.build()?); - } else { - join_series.push(self.get_column(l)?.take(&lidx)?); - } - } + growable.build() + }) + .collect::>>()?; drop(ltable); drop(rtable); let num_rows = lidx.len(); - join_series = - add_non_join_key_columns(self, right, lidx, ridx, left_on, right_on, join_series)?; + join_series = add_non_join_key_columns(self, right, lidx, ridx, join_series)?; Table::new_with_size(join_schema, join_series, num_rows) } diff --git a/src/daft-table/src/ops/mod.rs b/src/daft-table/src/ops/mod.rs index de83d3fad1..6d53b60e13 100644 --- a/src/daft-table/src/ops/mod.rs +++ b/src/daft-table/src/ops/mod.rs @@ -8,5 +8,3 @@ mod pivot; mod search_sorted; mod sort; mod unpivot; - -pub use joins::{infer_join_schema, infer_join_schema_mapper, JoinOutputMapper}; diff --git a/tests/dataframe/test_joins.py b/tests/dataframe/test_joins.py index 56314f4549..18ba1b8ab3 100644 --- a/tests/dataframe/test_joins.py +++ b/tests/dataframe/test_joins.py @@ -3,6 +3,7 @@ import pyarrow as pa import pytest +from daft import col from daft.datatype import DataType from daft.errors import ExpressionTypeError from tests.utils import sort_arrow_table @@ -832,3 +833,241 @@ def test_join_semi_anti_different_names(join_strategy, join_type, expected, make assert sort_arrow_table(pa.Table.from_pydict(daft_df.to_pydict()), "id_left") == sort_arrow_table( pa.Table.from_pydict(expected), "id_left" ) + + +@pytest.mark.parametrize("join_type", ["inner", "left", "right", "outer"]) +def test_join_true_join_keys(join_type, make_df): + daft_df = make_df( + { + "id": [1, 2, 3], + "values": ["a", "b", "c"], + } + ) + daft_df2 = make_df( + { + "id": [2.0, 2.5, 3.0, 4.0], + "values": ["a2", "b2", "c2", "d2"], + } + ) + + result = daft_df.join(daft_df2, left_on=["id", "values"], right_on=["id", col("values").str.left(1)], how=join_type) + + assert result.schema().column_names() == ["id", "values", "right.values"] + assert result.schema()["id"].dtype == daft_df.schema()["id"].dtype + assert result.schema()["values"].dtype == daft_df.schema()["values"].dtype + assert result.schema()["right.values"].dtype == daft_df2.schema()["values"].dtype + + +@pytest.mark.parametrize( + "join_strategy", + [None, "hash", "sort_merge", "sort_merge_aligned_boundaries", "broadcast"], + indirect=True, +) +@pytest.mark.parametrize( + "join_type,expected", + [ + ( + "inner", + { + "a": [2, 3], + "b": [2, 3], + }, + ), + ( + "left", + { + "a": [1, 2, 3], + "b": [None, 2, 3], + }, + ), + ( + "right", + { + "a": [2, 3, None], + "b": [2, 3, 4], + }, + ), + ( + "outer", + { + "a": [1, 2, 3, None], + "b": [None, 2, 3, 4], + }, + ), + ( + "semi", + { + "a": [2, 3], + }, + ), + ( + "anti", + { + "a": [1], + }, + ), + ], +) +def test_join_with_alias_in_key(join_strategy, join_type, expected, make_df): + skip_invalid_join_strategies(join_strategy, join_type) + + daft_df1 = make_df( + { + "a": [1, 2, 3], + } + ) + daft_df2 = make_df( + { + "b": [2, 3, 4], + } + ) + + daft_df = daft_df1.join(daft_df2, left_on=col("a").alias("x"), right_on="b", how=join_type, strategy=join_strategy) + + assert sort_arrow_table(pa.Table.from_pydict(daft_df.to_pydict()), "a") == sort_arrow_table( + pa.Table.from_pydict(expected), "a" + ) + + +@pytest.mark.parametrize( + "join_strategy", + [None, "hash", "sort_merge", "sort_merge_aligned_boundaries", "broadcast"], + indirect=True, +) +@pytest.mark.parametrize( + "join_type,expected", + [ + ( + "inner", + { + "a": [2, 3], + "right.a": [2, 3], + }, + ), + ( + 
"left", + { + "a": [1, 2, 3], + "right.a": [None, 2, 3], + }, + ), + ( + "right", + { + "a": [2, 3, None], + "right.a": [2, 3, 4], + }, + ), + ( + "outer", + { + "a": [1, 2, 3, None], + "right.a": [None, 2, 3, 4], + }, + ), + ( + "semi", + { + "a": [2, 3], + }, + ), + ( + "anti", + { + "a": [1], + }, + ), + ], +) +def test_join_same_name_alias(join_strategy, join_type, expected, make_df): + skip_invalid_join_strategies(join_strategy, join_type) + + daft_df1 = make_df( + { + "a": [1, 2, 3], + } + ) + daft_df2 = make_df( + { + "a": [2, 3, 4], + } + ) + + daft_df = daft_df1.join(daft_df2, left_on="a", right_on=col("a").alias("b"), how=join_type, strategy=join_strategy) + + assert sort_arrow_table(pa.Table.from_pydict(daft_df.to_pydict()), "a") == sort_arrow_table( + pa.Table.from_pydict(expected), "a" + ) + + +@pytest.mark.parametrize( + "join_strategy", + [None, "hash", "sort_merge", "sort_merge_aligned_boundaries", "broadcast"], + indirect=True, +) +@pytest.mark.parametrize( + "join_type,expected", + [ + ( + "inner", + { + "a": [0.2, 0.3], + "right.a": [20, 30], + }, + ), + ( + "left", + { + "a": [0.1, 0.2, 0.3], + "right.a": [None, 20, 30], + }, + ), + ( + "right", + { + "a": [0.2, 0.3, None], + "right.a": [20, 30, 40], + }, + ), + ( + "outer", + { + "a": [0.1, 0.2, 0.3, None], + "right.a": [None, 20, 30, 40], + }, + ), + ( + "semi", + { + "a": [0.2, 0.3], + }, + ), + ( + "anti", + { + "a": [0.1], + }, + ), + ], +) +def test_join_same_name_alias_with_compute(join_strategy, join_type, expected, make_df): + skip_invalid_join_strategies(join_strategy, join_type) + + daft_df1 = make_df( + { + "a": [0.1, 0.2, 0.3], + } + ) + daft_df2 = make_df( + { + "a": [20, 30, 40], + } + ) + + daft_df = daft_df1.join( + daft_df2, left_on=col("a") * 10, right_on=(col("a") / 10).alias("b"), how=join_type, strategy=join_strategy + ) + + assert sort_arrow_table(pa.Table.from_pydict(daft_df.to_pydict()), "a") == sort_arrow_table( + pa.Table.from_pydict(expected), "a" + ) diff --git a/tests/table/test_joins.py b/tests/table/test_joins.py index 5664950121..968ab97630 100644 --- a/tests/table/test_joins.py +++ b/tests/table/test_joins.py @@ -217,53 +217,10 @@ def test_table_join_no_columns(join_impl) -> None: getattr(left_table, join_impl)(right_table, left_on=[], right_on=[]) -@pytest.mark.parametrize("join_impl", ["hash_join", "sort_merge_join"]) -def test_table_join_single_column_name_conflicts(join_impl) -> None: - left_table = MicroPartition.from_pydict({"x": [0, 1, 2, 3], "y": [2, 3, 4, 5]}) - right_table = MicroPartition.from_pydict({"x": [3, 2, 1, 0], "y": [6, 7, 8, 9]}) - - result_table = getattr(left_table, join_impl)(right_table, left_on=[col("x")], right_on=[col("x")]) - assert result_table.column_names() == ["x", "y", "right.y"] - result_sorted = result_table.sort([col("x")]) - assert result_sorted.get_column("y").to_pylist() == [2, 3, 4, 5] - - assert result_sorted.get_column("right.y").to_pylist() == [9, 8, 7, 6] - - -@pytest.mark.parametrize("join_impl", ["hash_join", "sort_merge_join"]) -def test_table_join_single_column_name_conflicts_different_named_join(join_impl) -> None: - left_table = MicroPartition.from_pydict({"x": [0, 1, 2, 3], "y": [2, 3, 4, 5]}) - right_table = MicroPartition.from_pydict({"y": [3, 2, 1, 0], "x": [6, 7, 8, 9]}) - - result_table = getattr(left_table, join_impl)(right_table, left_on=[col("x")], right_on=[col("y")]) - - # NOTE: right.y is not dropped because it has a different name from the corresponding left - # column it is joined on, left_table["x"] - assert 
result_table.column_names() == ["x", "y", "right.y", "right.x"] - - result_sorted = result_table.sort([col("x")]) - assert result_sorted.get_column("y").to_pylist() == [2, 3, 4, 5] - assert result_sorted.get_column("right.x").to_pylist() == [9, 8, 7, 6] - - -@pytest.mark.parametrize("join_impl", ["hash_join", "sort_merge_join"]) -def test_table_join_single_column_name_multiple_conflicts(join_impl) -> None: - left_table = MicroPartition.from_pydict({"x": [0, 1, 2, 3], "y": [2, 3, 4, 5], "right.y": [6, 7, 8, 9]}) - right_table = MicroPartition.from_pydict({"x": [3, 2, 1, 0], "y": [10, 11, 12, 13]}) - - result_table = getattr(left_table, join_impl)(right_table, left_on=[col("x")], right_on=[col("x")]) - assert result_table.column_names() == ["x", "y", "right.y", "right.right.y"] - result_sorted = result_table.sort([col("x")]) - assert result_sorted.get_column("y").to_pylist() == [2, 3, 4, 5] - - assert result_sorted.get_column("right.y").to_pylist() == [6, 7, 8, 9] - assert result_sorted.get_column("right.right.y").to_pylist() == [13, 12, 11, 10] - - @pytest.mark.parametrize("join_impl", ["hash_join", "sort_merge_join"]) def test_table_join_single_column_name_boolean(join_impl) -> None: left_table = MicroPartition.from_pydict({"x": [False, True, None], "y": [0, 1, 2]}) - right_table = MicroPartition.from_pydict({"x": [None, True, False, None], "y": [0, 1, 2, 3]}) + right_table = MicroPartition.from_pydict({"x": [None, True, False, None], "right.y": [0, 1, 2, 3]}) result_table = getattr(left_table, join_impl)(right_table, left_on=[col("x")], right_on=[col("x")]) assert result_table.column_names() == ["x", "y", "right.y"] @@ -275,7 +232,7 @@ def test_table_join_single_column_name_boolean(join_impl) -> None: @pytest.mark.parametrize("join_impl", ["hash_join", "sort_merge_join"]) def test_table_join_single_column_name_null(join_impl) -> None: left_table = MicroPartition.from_pydict({"x": [None, None, None], "y": [0, 1, 2]}) - right_table = MicroPartition.from_pydict({"x": [None, None, None, None], "y": [0, 1, 2, 3]}) + right_table = MicroPartition.from_pydict({"x": [None, None, None, None], "right.y": [0, 1, 2, 3]}) result_table = getattr(left_table, join_impl)(right_table, left_on=[col("x")], right_on=[col("x")]) assert result_table.column_names() == ["x", "y", "right.y"]