From 6a7d83138ee03451f9735d5831e77664bbb724a9 Mon Sep 17 00:00:00 2001 From: Sammy Sidhu Date: Mon, 17 Jun 2024 20:38:05 -0700 Subject: [PATCH] [FEAT] Implement Anti and Semi Join (#2379) closes: https://github.com/Eventual-Inc/Daft/issues/2369 --- daft/daft.pyi | 2 + daft/dataframe/dataframe.py | 4 +- daft/hudi/hudi_scan.py | 2 +- src/daft-core/src/join.rs | 6 +- src/daft-micropartition/src/ops/join.rs | 11 ++- src/daft-plan/src/logical_ops/join.rs | 5 +- .../rules/push_down_projection.rs | 52 ++++++++++- .../src/physical_planner/translate.rs | 10 +++ src/daft-table/src/ops/hash.rs | 43 +++++++++ src/daft-table/src/ops/joins/hash_join.rs | 87 ++++++++++++++++++- src/daft-table/src/ops/joins/mod.rs | 36 ++++++-- tests/dataframe/test_joins.py | 57 ++++++++++++ 12 files changed, 289 insertions(+), 26 deletions(-) diff --git a/daft/daft.pyi b/daft/daft.pyi index c74ffe8afa..4b8140e6bd 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -89,6 +89,8 @@ class JoinType(Enum): Left: int Right: int Outer: int + Semi: int + Anti: int @staticmethod def from_join_type_str(join_type: str) -> JoinType: diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py index 8821f31f4f..824e2ea302 100644 --- a/daft/dataframe/dataframe.py +++ b/daft/dataframe/dataframe.py @@ -1164,8 +1164,10 @@ def join( if join_strategy == JoinStrategy.SortMerge and join_type != JoinType.Inner: raise ValueError("Sort merge join only supports inner joins") - if join_strategy == JoinStrategy.Broadcast and join_type == JoinType.Outer: + elif join_strategy == JoinStrategy.Broadcast and join_type == JoinType.Outer: raise ValueError("Broadcast join does not support outer joins") + elif join_strategy == JoinStrategy.Broadcast and join_type == JoinType.Anti: + raise ValueError("Broadcast join does not support Anti joins") left_exprs = self.__column_input_to_expression(tuple(left_on) if isinstance(left_on, list) else (left_on,)) right_exprs = self.__column_input_to_expression(tuple(right_on) if isinstance(right_on, list) else (right_on,)) diff --git a/daft/hudi/hudi_scan.py b/daft/hudi/hudi_scan.py index fcef509bff..8c1c98f298 100644 --- a/daft/hudi/hudi_scan.py +++ b/daft/hudi/hudi_scan.py @@ -115,8 +115,8 @@ def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]: file=path, file_format=file_format_config, schema=self._schema._schema, - num_rows=record_count, storage_config=self._storage_config, + num_rows=record_count, size_bytes=size_bytes, pushdowns=pushdowns, partition_values=partition_values, diff --git a/src/daft-core/src/join.rs b/src/daft-core/src/join.rs index 3d1385a1be..018fba15b9 100644 --- a/src/daft-core/src/join.rs +++ b/src/daft-core/src/join.rs @@ -21,6 +21,8 @@ pub enum JoinType { Left, Right, Outer, + Anti, + Semi, } #[cfg(feature = "python")] @@ -46,7 +48,7 @@ impl JoinType { pub fn iterator() -> std::slice::Iter<'static, JoinType> { use JoinType::*; - static JOIN_TYPES: [JoinType; 4] = [Inner, Left, Right, Outer]; + static JOIN_TYPES: [JoinType; 6] = [Inner, Left, Right, Outer, Anti, Semi]; JOIN_TYPES.iter() } } @@ -62,6 +64,8 @@ impl FromStr for JoinType { "left" => Ok(Left), "right" => Ok(Right), "outer" => Ok(Outer), + "anti" => Ok(Anti), + "semi" => Ok(Semi), _ => Err(DaftError::TypeError(format!( "Join type {} is not supported; only the following types are supported: {:?}", join_type, diff --git a/src/daft-micropartition/src/ops/join.rs b/src/daft-micropartition/src/ops/join.rs index 8c0c26b4a3..8cc5c3d078 100644 --- a/src/daft-micropartition/src/ops/join.rs +++ 
b/src/daft-micropartition/src/ops/join.rs @@ -23,15 +23,14 @@ impl MicroPartition { where F: FnOnce(&Table, &Table, &[ExprRef], &[ExprRef], JoinType) -> DaftResult, { - let join_schema = infer_join_schema(&self.schema, &right.schema, left_on, right_on)?; - + let join_schema = infer_join_schema(&self.schema, &right.schema, left_on, right_on, how)?; match (how, self.len(), right.len()) { (JoinType::Inner, 0, _) | (JoinType::Inner, _, 0) | (JoinType::Left, 0, _) | (JoinType::Right, _, 0) | (JoinType::Outer, 0, 0) => { - return Ok(Self::empty(Some(join_schema.into()))); + return Ok(Self::empty(Some(join_schema))); } _ => {} } @@ -58,7 +57,7 @@ impl MicroPartition { } }; if let TruthValue::False = tv { - return Ok(Self::empty(Some(join_schema.into()))); + return Ok(Self::empty(Some(join_schema))); } } @@ -67,11 +66,11 @@ impl MicroPartition { let rt = right.concat_or_get(io_stats)?; match (lt.as_slice(), rt.as_slice()) { - ([], _) | (_, []) => Ok(Self::empty(Some(join_schema.into()))), + ([], _) | (_, []) => Ok(Self::empty(Some(join_schema))), ([lt], [rt]) => { let joined_table = table_join(lt, rt, left_on, right_on, how)?; Ok(MicroPartition::new_loaded( - join_schema.into(), + join_schema, vec![joined_table].into(), None, )) diff --git a/src/daft-plan/src/logical_ops/join.rs b/src/daft-plan/src/logical_ops/join.rs index 6d3c57b8a4..8fe2bad6fb 100644 --- a/src/daft-plan/src/logical_ops/join.rs +++ b/src/daft-plan/src/logical_ops/join.rs @@ -82,8 +82,9 @@ impl Join { .map(|(_, field)| field) .cloned() .chain(right.schema().fields.iter().filter_map(|(rname, rfield)| { - if left_join_keys.contains(rname.as_str()) - && right_join_keys.contains(rname.as_str()) + if (left_join_keys.contains(rname.as_str()) + && right_join_keys.contains(rname.as_str())) + || matches!(join_type, JoinType::Anti | JoinType::Semi) { right_input_mapping.insert(rname.clone(), rname.clone()); None diff --git a/src/daft-plan/src/logical_optimization/rules/push_down_projection.rs b/src/daft-plan/src/logical_optimization/rules/push_down_projection.rs index 26b64d5551..db91617d15 100644 --- a/src/daft-plan/src/logical_optimization/rules/push_down_projection.rs +++ b/src/daft-plan/src/logical_optimization/rules/push_down_projection.rs @@ -2,12 +2,12 @@ use std::{collections::HashMap, sync::Arc}; use common_error::DaftResult; -use daft_core::schema::Schema; +use daft_core::{schema::Schema, JoinType}; use daft_dsl::{col, optimization::replace_columns_with_expressions, Expr, ExprRef}; use indexmap::IndexSet; use crate::{ - logical_ops::{Aggregate, Pivot, Project, Source}, + logical_ops::{Aggregate, Join, Pivot, Project, Source}, source_info::SourceInfo, LogicalPlan, ResourceRequest, }; @@ -478,6 +478,52 @@ impl PushDownProjection { } } + fn try_optimize_join( + &self, + join: &Join, + plan: Arc, + ) -> DaftResult>> { + // If this join prunes columns from its upstream, + // then explicitly create a projection to do so. + // this is the case for semi and anti joins. + + if matches!(join.join_type, JoinType::Anti | JoinType::Semi) { + let required_cols = plan.required_columns(); + let right_required_cols = required_cols + .get(1) + .expect("we expect 2 set of required columns for join"); + let right_schema = join.right.schema(); + + if right_required_cols.len() < right_schema.fields.len() { + let new_subprojection: LogicalPlan = { + let pushdown_column_exprs = right_required_cols + .iter() + .map(|s| col(s.as_str())) + .collect::>(); + + Project::try_new( + join.right.clone(), + pushdown_column_exprs, + Default::default(), + )? 
+ .into() + }; + + let new_join = plan + .with_new_children(&[(join.left).clone(), new_subprojection.into()]) + .arced(); + + Ok(self + .try_optimize(new_join.clone())? + .or(Transformed::Yes(new_join))) + } else { + Ok(Transformed::No(plan)) + } + } else { + Ok(Transformed::No(plan)) + } + } + fn try_optimize_pivot( &self, pivot: &Pivot, @@ -524,6 +570,8 @@ impl OptimizerRule for PushDownProjection { LogicalPlan::Aggregate(aggregation) => { self.try_optimize_aggregation(aggregation, plan.clone()) } + // Joins also do column projection + LogicalPlan::Join(join) => self.try_optimize_join(join, plan.clone()), // Pivots also do column projection LogicalPlan::Pivot(pivot) => self.try_optimize_pivot(pivot, plan.clone()), _ => Ok(Transformed::No(plan)), diff --git a/src/daft-plan/src/physical_planner/translate.rs b/src/daft-plan/src/physical_planner/translate.rs index ec2511f867..072e308321 100644 --- a/src/daft-plan/src/physical_planner/translate.rs +++ b/src/daft-plan/src/physical_planner/translate.rs @@ -558,6 +558,16 @@ pub(super) fn translate_single_logical_node( "Broadcast join does not support outer joins.".to_string(), )); } + (JoinType::Anti, _) => { + return Err(common_error::DaftError::ValueError( + "Broadcast join does not support anti joins.".to_string(), + )); + } + (JoinType::Semi, _) => { + return Err(common_error::DaftError::ValueError( + "Broadcast join does not support semi joins.".to_string(), + )); + } }; if is_swapped { diff --git a/src/daft-table/src/ops/hash.rs b/src/daft-table/src/ops/hash.rs index 735fa01199..d71ec9d3d8 100644 --- a/src/daft-table/src/ops/hash.rs +++ b/src/daft-table/src/ops/hash.rs @@ -103,4 +103,47 @@ impl Table { } Ok(probe_table) } + + pub fn to_probe_hash_map_without_idx( + &self, + ) -> DaftResult> { + let hashes = self.hash_rows()?; + + const DEFAULT_SIZE: usize = 20; + let comparator = build_multi_array_is_equal( + self.columns.as_slice(), + self.columns.as_slice(), + true, + true, + )?; + + let mut probe_table = + HashMap::::with_capacity_and_hasher( + DEFAULT_SIZE, + Default::default(), + ); + // TODO(Sammy): Drop nulls using validity array if requested + for (i, h) in hashes.as_arrow().values_iter().enumerate() { + let entry = probe_table.raw_entry_mut().from_hash(*h, |other| { + (*h == other.hash) && { + let j = other.idx; + comparator(i, j as usize) + } + }); + match entry { + RawEntryMut::Vacant(entry) => { + entry.insert_hashed_nocheck( + *h, + IndexHash { + idx: i as u64, + hash: *h, + }, + (), + ); + } + RawEntryMut::Occupied(_) => {} + } + } + Ok(probe_table) + } } diff --git a/src/daft-table/src/ops/joins/hash_join.rs b/src/daft-table/src/ops/joins/hash_join.rs index 6ae40a8f3a..1b8ad0bc62 100644 --- a/src/daft-table/src/ops/joins/hash_join.rs +++ b/src/daft-table/src/ops/joins/hash_join.rs @@ -4,7 +4,7 @@ use arrow2::{bitmap::MutableBitmap, types::IndexRange}; use daft_core::{ array::ops::{arrow2::comparison::build_multi_array_is_equal, full::FullNull}, datatypes::{BooleanArray, UInt64Array}, - DataType, IntoSeries, + DataType, IntoSeries, JoinType, }; use daft_dsl::ExprRef; @@ -21,7 +21,13 @@ pub(super) fn hash_inner_join( left_on: &[ExprRef], right_on: &[ExprRef], ) -> DaftResult
{ - let join_schema = infer_join_schema(&left.schema, &right.schema, left_on, right_on)?; + let join_schema = infer_join_schema( + &left.schema, + &right.schema, + left_on, + right_on, + JoinType::Inner, + )?; let lkeys = left.eval_expression_list(left_on)?; let rkeys = right.eval_expression_list(right_on)?; @@ -103,7 +109,13 @@ pub(super) fn hash_left_right_join( right_on: &[ExprRef], left_side: bool, ) -> DaftResult
{ - let join_schema = infer_join_schema(&left.schema, &right.schema, left_on, right_on)?; + let join_schema = infer_join_schema( + &left.schema, + &right.schema, + left_on, + right_on, + JoinType::Right, + )?; let lkeys = left.eval_expression_list(left_on)?; let rkeys = right.eval_expression_list(right_on)?; @@ -212,13 +224,80 @@ pub(super) fn hash_left_right_join( Table::new(join_schema, join_series) } +pub(super) fn hash_semi_anti_join( + left: &Table, + right: &Table, + left_on: &[ExprRef], + right_on: &[ExprRef], + is_anti: bool, +) -> DaftResult
{ + let lkeys = left.eval_expression_list(left_on)?; + let rkeys = right.eval_expression_list(right_on)?; + + let (lkeys, rkeys) = match_types_for_tables(&lkeys, &rkeys)?; + + let lidx = if lkeys.columns.iter().any(|s| s.data_type().is_null()) + || rkeys.columns.iter().any(|s| s.data_type().is_null()) + { + if is_anti { + // if we have a null column match, then all of the rows match for an anti join! + return Ok(left.clone()); + } else { + UInt64Array::empty("left_indices", &DataType::UInt64).into_series() + } + } else { + let probe_table = rkeys.to_probe_hash_map_without_idx()?; + + let l_hashes = lkeys.hash_rows()?; + + let is_equal = build_multi_array_is_equal( + lkeys.columns.as_slice(), + rkeys.columns.as_slice(), + false, + false, + )?; + let rows = rkeys.len(); + + drop(lkeys); + drop(rkeys); + + let mut left_idx = Vec::with_capacity(rows); + let is_semi = !is_anti; + for (l_idx, h) in l_hashes.as_arrow().values_iter().enumerate() { + let is_match = probe_table + .raw_entry() + .from_hash(*h, |other| { + *h == other.hash && { + let r_idx = other.idx as usize; + is_equal(l_idx, r_idx) + } + }) + .is_some(); + dbg!(l_idx); + if is_match == is_semi { + left_idx.push(l_idx as u64); + } + } + + UInt64Array::from(("left_indices", left_idx)).into_series() + }; + + left.take(&lidx) +} + pub(super) fn hash_outer_join( left: &Table, right: &Table, left_on: &[ExprRef], right_on: &[ExprRef], ) -> DaftResult
{ - let join_schema = infer_join_schema(&left.schema, &right.schema, left_on, right_on)?; + let join_schema = infer_join_schema( + &left.schema, + &right.schema, + left_on, + right_on, + JoinType::Outer, + )?; let lkeys = left.eval_expression_list(left_on)?; let rkeys = right.eval_expression_list(right_on)?; diff --git a/src/daft-table/src/ops/joins/mod.rs b/src/daft-table/src/ops/joins/mod.rs index d0805258f2..418b12a623 100644 --- a/src/daft-table/src/ops/joins/mod.rs +++ b/src/daft-table/src/ops/joins/mod.rs @@ -1,12 +1,18 @@ -use std::collections::{HashMap, HashSet}; +use std::{ + collections::{HashMap, HashSet}, + sync::Arc, +}; use daft_core::{ - array::growable::make_growable, schema::Schema, utils::supertype::try_get_supertype, JoinType, - Series, + array::growable::make_growable, + schema::{Schema, SchemaRef}, + utils::supertype::try_get_supertype, + JoinType, Series, }; use common_error::{DaftError, DaftResult}; use daft_dsl::ExprRef; +use hash_join::hash_semi_anti_join; use crate::Table; @@ -36,11 +42,12 @@ fn match_types_for_tables(left: &Table, right: &Table) -> DaftResult<(Table, Tab } pub fn infer_join_schema( - left: &Schema, - right: &Schema, + left: &SchemaRef, + right: &SchemaRef, left_on: &[ExprRef], right_on: &[ExprRef], -) -> DaftResult { + how: JoinType, +) -> DaftResult { if left_on.len() != right_on.len() { return Err(DaftError::ValueError(format!( "Length of left_on does not match length of right_on for Join {} vs {}", @@ -48,6 +55,9 @@ pub fn infer_join_schema( right_on.len() ))); } + if matches!(how, JoinType::Anti | JoinType::Semi) { + return Ok(left.clone()); + } let lfields = left_on .iter() @@ -104,8 +114,8 @@ pub fn infer_join_schema( join_fields.push(field.rename(curr_name.clone())); names_so_far.insert(curr_name.clone()); } - - Schema::new(join_fields) + let schema = Schema::new(join_fields)?; + Ok(Arc::new(schema)) } fn add_non_join_key_columns( @@ -199,6 +209,8 @@ impl Table { JoinType::Left => hash_left_right_join(self, right, left_on, right_on, true), JoinType::Right => hash_left_right_join(self, right, left_on, right_on, false), JoinType::Outer => hash_outer_join(self, right, left_on, right_on), + JoinType::Semi => hash_semi_anti_join(self, right, left_on, right_on, false), + JoinType::Anti => hash_semi_anti_join(self, right, left_on, right_on, true), } } @@ -239,7 +251,13 @@ impl Table { return left.sort_merge_join(&right, left_on, right_on, true); } - let join_schema = infer_join_schema(&self.schema, &right.schema, left_on, right_on)?; + let join_schema = infer_join_schema( + &self.schema, + &right.schema, + left_on, + right_on, + JoinType::Inner, + )?; let ltable = self.eval_expression_list(left_on)?; let rtable = right.eval_expression_list(right_on)?; diff --git a/tests/dataframe/test_joins.py b/tests/dataframe/test_joins.py index 0c47513cd6..b416319637 100644 --- a/tests/dataframe/test_joins.py +++ b/tests/dataframe/test_joins.py @@ -13,6 +13,10 @@ def skip_invalid_join_strategies(join_strategy, join_type): pytest.skip("Sort merge currently only supports inner joins") elif join_strategy == "broadcast" and join_type == "outer": pytest.skip("Broadcast join does not support outer joins") + elif join_strategy == "broadcast" and join_type == "anti": + pytest.skip("Broadcast join does not support anti joins") + elif join_strategy == "broadcast" and join_type == "semi": + pytest.skip("Broadcast join does not support semi joins") def test_invalid_join_strategies(make_df): @@ -720,3 +724,56 @@ def test_join_null_type_column(join_strategy, join_type, 
make_df): with pytest.raises((ExpressionTypeError, ValueError)): daft_df.join(daft_df2, on="id", how=join_type, strategy=join_strategy) + + +@pytest.mark.parametrize("repartition_nparts", [1, 2, 4]) +@pytest.mark.parametrize( + "join_strategy", + [None, "hash", "sort_merge", "sort_merge_aligned_boundaries", "broadcast"], + indirect=True, +) +@pytest.mark.parametrize( + "join_type,expected", + [ + ( + "semi", + { + "id": [2, 3], + "values_left": ["b1", "c1"], + }, + ), + ( + "anti", + { + "id": [1, None], + "values_left": ["a1", "d1"], + }, + ), + ], +) +def test_join_semi_anti(join_strategy, join_type, expected, make_df, repartition_nparts): + skip_invalid_join_strategies(join_strategy, join_type) + + daft_df1 = make_df( + { + "id": [1, 2, 3, None], + "values_left": ["a1", "b1", "c1", "d1"], + }, + repartition=repartition_nparts, + ) + daft_df2 = make_df( + { + "id": [2, 2, 3, 4], + "values_right": ["a2", "b2", "c2", "d2"], + }, + repartition=repartition_nparts, + ) + daft_df = ( + daft_df1.with_column("id", daft_df1["id"].cast(DataType.int64())) + .join(daft_df2, on="id", how=join_type, strategy=join_strategy) + .sort(["id", "values_left"]) + ).select("id", "values_left") + + assert sort_arrow_table(pa.Table.from_pydict(daft_df.to_pydict()), "id") == sort_arrow_table( + pa.Table.from_pydict(expected), "id" + )
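
For reference, a minimal usage sketch of the new join types, mirroring test_join_semi_anti above. This is only a sketch: it assumes the daft.from_pydict constructor and the default hash join strategy; the column data comes from the test, and the semantics follow hash_semi_anti_join as implemented in this patch.

    import daft

    left = daft.from_pydict({"id": [1, 2, 3, None], "values_left": ["a1", "b1", "c1", "d1"]})
    right = daft.from_pydict({"id": [2, 2, 3, 4], "values_right": ["a2", "b2", "c2", "d2"]})

    # Semi join: keep left rows whose key has at least one match on the right.
    # Only the left table's columns are returned (infer_join_schema yields the left schema).
    semi = left.join(right, on="id", how="semi")
    # expected ids: 2, 3

    # Anti join: keep left rows whose key has no match on the right.
    # Null keys never match, so the row with id=None is kept here.
    anti = left.join(right, on="id", how="anti")
    # expected ids: 1, None

Note that the broadcast strategy does not support semi or anti joins; the planner raises a ValueError for them, so use the default hash strategy with these join types.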