diff --git a/daft/context.py b/daft/context.py
index 10be0264eb..61b69284af 100644
--- a/daft/context.py
+++ b/daft/context.py
@@ -308,6 +308,7 @@ def set_execution_config(
     csv_target_filesize: int | None = None,
     csv_inflation_factor: float | None = None,
     shuffle_aggregation_default_partitions: int | None = None,
+    shuffle_join_default_partitions: int | None = None,
     read_sql_partition_size_bytes: int | None = None,
     enable_aqe: bool | None = None,
     enable_native_executor: bool | None = None,
@@ -344,6 +345,7 @@ def set_execution_config(
         csv_target_filesize: Target File Size when writing out CSV Files. Defaults to 512MB
         csv_inflation_factor: Inflation Factor of CSV files (In-Memory-Size / File-Size) ratio. Defaults to 0.5
         shuffle_aggregation_default_partitions: Minimum number of partitions to create when performing aggregations. Defaults to 200, unless the number of input partitions is less than 200.
+        shuffle_join_default_partitions: Minimum number of partitions to create when performing joins. Defaults to 16, unless the number of input partitions is greater than 16.
         read_sql_partition_size_bytes: Target size of partition when reading from SQL databases. Defaults to 512MB
         enable_aqe: Enables Adaptive Query Execution, Defaults to False
         enable_native_executor: Enables new local executor. Defaults to False
@@ -369,6 +371,7 @@ def set_execution_config(
         csv_target_filesize=csv_target_filesize,
         csv_inflation_factor=csv_inflation_factor,
         shuffle_aggregation_default_partitions=shuffle_aggregation_default_partitions,
+        shuffle_join_default_partitions=shuffle_join_default_partitions,
         read_sql_partition_size_bytes=read_sql_partition_size_bytes,
         enable_aqe=enable_aqe,
         enable_native_executor=enable_native_executor,
diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi
index b08a0633e4..83f1eff059 100644
--- a/daft/daft/__init__.pyi
+++ b/daft/daft/__init__.pyi
@@ -1753,6 +1753,7 @@ class PyDaftExecutionConfig:
         csv_target_filesize: int | None = None,
         csv_inflation_factor: float | None = None,
         shuffle_aggregation_default_partitions: int | None = None,
+        shuffle_join_default_partitions: int | None = None,
         read_sql_partition_size_bytes: int | None = None,
         enable_aqe: bool | None = None,
         enable_native_executor: bool | None = None,
@@ -1785,6 +1786,8 @@ class PyDaftExecutionConfig:
    @property
    def shuffle_aggregation_default_partitions(self) -> int: ...
    @property
+   def shuffle_join_default_partitions(self) -> int: ...
+   @property
    def read_sql_partition_size_bytes(self) -> int: ...
    @property
    def enable_aqe(self) -> bool: ...
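For reviewers trying out the new knob, a minimal usage sketch (this assumes the top-level `daft.set_execution_config` re-export of the function patched above; the value 32 is purely illustrative):

```python
import daft

# Ask the planner to repartition joins to at least 32 partitions
# instead of the new default floor of 16 (32 is an arbitrary example value).
daft.set_execution_config(shuffle_join_default_partitions=32)
```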
diff --git a/src/common/daft-config/src/lib.rs b/src/common/daft-config/src/lib.rs
index dcaef0a2f8..153d2a80c5 100644
--- a/src/common/daft-config/src/lib.rs
+++ b/src/common/daft-config/src/lib.rs
@@ -52,6 +52,7 @@ pub struct DaftExecutionConfig {
     pub csv_target_filesize: usize,
     pub csv_inflation_factor: f64,
     pub shuffle_aggregation_default_partitions: usize,
+    pub shuffle_join_default_partitions: usize,
     pub read_sql_partition_size_bytes: usize,
     pub enable_aqe: bool,
     pub enable_native_executor: bool,
@@ -75,6 +76,7 @@ impl Default for DaftExecutionConfig {
             csv_target_filesize: 512 * 1024 * 1024, // 512MB
             csv_inflation_factor: 0.5,
             shuffle_aggregation_default_partitions: 200,
+            shuffle_join_default_partitions: 16,
             read_sql_partition_size_bytes: 512 * 1024 * 1024, // 512MB
             enable_aqe: false,
             enable_native_executor: false,
diff --git a/src/common/daft-config/src/python.rs b/src/common/daft-config/src/python.rs
index 5dda71eda8..818934261a 100644
--- a/src/common/daft-config/src/python.rs
+++ b/src/common/daft-config/src/python.rs
@@ -94,6 +94,7 @@ impl PyDaftExecutionConfig {
        csv_target_filesize: Option<usize>,
        csv_inflation_factor: Option<f64>,
        shuffle_aggregation_default_partitions: Option<usize>,
+       shuffle_join_default_partitions: Option<usize>,
        read_sql_partition_size_bytes: Option<usize>,
        enable_aqe: Option<bool>,
        enable_native_executor: Option<bool>,
@@ -143,10 +144,16 @@ impl PyDaftExecutionConfig {
        if let Some(csv_inflation_factor) = csv_inflation_factor {
            config.csv_inflation_factor = csv_inflation_factor;
        }
+
        if let Some(shuffle_aggregation_default_partitions) = shuffle_aggregation_default_partitions {
            config.shuffle_aggregation_default_partitions = shuffle_aggregation_default_partitions;
        }
+
+       if let Some(shuffle_join_default_partitions) = shuffle_join_default_partitions {
+           config.shuffle_join_default_partitions = shuffle_join_default_partitions;
+       }
+
        if let Some(read_sql_partition_size_bytes) = read_sql_partition_size_bytes {
            config.read_sql_partition_size_bytes = read_sql_partition_size_bytes;
        }
@@ -231,6 +238,11 @@ impl PyDaftExecutionConfig {
        Ok(self.config.shuffle_aggregation_default_partitions)
    }

+   #[getter]
+   fn get_shuffle_join_default_partitions(&self) -> PyResult<usize> {
+       Ok(self.config.shuffle_join_default_partitions)
+   }
+
    #[getter]
    fn get_read_sql_partition_size_bytes(&self) -> PyResult<usize> {
        Ok(self.config.read_sql_partition_size_bytes)
diff --git a/src/daft-plan/src/physical_planner/translate.rs b/src/daft-plan/src/physical_planner/translate.rs
index 639c571871..408d4f62a6 100644
--- a/src/daft-plan/src/physical_planner/translate.rs
+++ b/src/daft-plan/src/physical_planner/translate.rs
@@ -571,6 +571,7 @@ pub(super) fn translate_single_logical_node(
                        "Sort-merge join currently only supports inner joins".to_string(),
                    ));
                }
+               let num_partitions = max(num_partitions, cfg.shuffle_join_default_partitions);

                let needs_presort = if cfg.sort_merge_join_sort_with_aligned_boundaries {
                    // Use the special-purpose presorting that ensures join inputs are sorted with aligned
@@ -616,7 +617,6 @@ pub(super) fn translate_single_logical_node(
                // allow for leniency in partition size to avoid minor repartitions
                let num_left_partitions = left_clustering_spec.num_partitions();
                let num_right_partitions = right_clustering_spec.num_partitions();
-
                let num_partitions = match (
                    is_left_hash_partitioned,
                    is_right_hash_partitioned,
@@ -637,6 +637,7 @@ pub(super) fn translate_single_logical_node(
                    }
                    (_, _, a, b) => max(a, b),
                };
+               let num_partitions = max(num_partitions, cfg.shuffle_join_default_partitions);

                if num_left_partitions != num_partitions
                    || (num_partitions > 1 && !is_left_hash_partitioned)
@@ -1076,6 +1077,13 @@ mod tests {
            Self::Reversed(v) => Self::Reversed(v * x),
            }
        }
+       fn unwrap(&self) -> usize {
+           match self {
+               Self::Good(v) => *v,
+               Self::Bad(v) => *v,
+               Self::Reversed(v) => *v,
+           }
+       }
    }

    fn force_repartition(
@@ -1128,21 +1136,31 @@ mod tests {

    fn check_physical_matches(
        plan: PhysicalPlanRef,
+       left_partition_size: usize,
+       right_partition_size: usize,
        left_repartitions: bool,
        right_repartitions: bool,
+       shuffle_join_default_partitions: usize,
    ) -> bool {
        match plan.as_ref() {
            PhysicalPlan::HashJoin(HashJoin { left, right, .. }) => {
-               let left_works = match (left.as_ref(), left_repartitions) {
+               let left_works = match (
+                   left.as_ref(),
+                   left_repartitions || left_partition_size < shuffle_join_default_partitions,
+               ) {
                    (PhysicalPlan::ReduceMerge(_), true) => true,
                    (PhysicalPlan::Project(_), false) => true,
                    _ => false,
                };
-               let right_works = match (right.as_ref(), right_repartitions) {
+               let right_works = match (
+                   right.as_ref(),
+                   right_repartitions || right_partition_size < shuffle_join_default_partitions,
+               ) {
                    (PhysicalPlan::ReduceMerge(_), true) => true,
                    (PhysicalPlan::Project(_), false) => true,
                    _ => false,
                };
+
                left_works && right_works
            }
            _ => false,
@@ -1152,7 +1170,7 @@ mod tests {
    /// Tests a variety of settings regarding hash join repartitioning.
    #[test]
    fn repartition_hash_join_tests() -> DaftResult<()> {
-       use RepartitionOptions::*;
+       use RepartitionOptions::{Bad, Good, Reversed};
        let cases = vec![
            (Good(30), Good(30), false, false),
            (Good(30), Good(40), true, false),
@@ -1170,9 +1188,17 @@
        let cfg: Arc<DaftExecutionConfig> = DaftExecutionConfig::default().into();
        for (l_opts, r_opts, l_exp, r_exp) in cases {
            for mult in [1, 10] {
-               let plan =
-                   get_hash_join_plan(cfg.clone(), l_opts.scale_by(mult), r_opts.scale_by(mult))?;
-               if !check_physical_matches(plan, l_exp, r_exp) {
+               let l_opts = l_opts.scale_by(mult);
+               let r_opts = r_opts.scale_by(mult);
+               let plan = get_hash_join_plan(cfg.clone(), l_opts.clone(), r_opts.clone())?;
+               if !check_physical_matches(
+                   plan,
+                   l_opts.unwrap(),
+                   r_opts.unwrap(),
+                   l_exp,
+                   r_exp,
+                   cfg.shuffle_join_default_partitions,
+               ) {
                    panic!(
                        "Failed hash join test on case ({:?}, {:?}, {}, {}) with mult {}",
                        l_opts, r_opts, l_exp, r_exp, mult
@@ -1180,9 +1206,15 @@
                }

                // reversed direction
-               let plan =
-                   get_hash_join_plan(cfg.clone(), r_opts.scale_by(mult), l_opts.scale_by(mult))?;
-               if !check_physical_matches(plan, r_exp, l_exp) {
+               let plan = get_hash_join_plan(cfg.clone(), r_opts.clone(), l_opts.clone())?;
+               if !check_physical_matches(
+                   plan,
+                   r_opts.unwrap(),
+                   l_opts.unwrap(),
+                   r_exp,
+                   l_exp,
+                   cfg.shuffle_join_default_partitions,
+               ) {
                    panic!(
                        "Failed hash join test on case ({:?}, {:?}, {}, {}) with mult {}",
                        r_opts, l_opts, r_exp, l_exp, mult
@@ -1199,27 +1231,38 @@
        let mut cfg = DaftExecutionConfig::default();
        cfg.hash_join_partition_size_leniency = 0.8;
        let cfg = Arc::new(cfg);
+       let (l_opts, r_opts) = (RepartitionOptions::Good(30), RepartitionOptions::Bad(40));
+       let physical_plan = get_hash_join_plan(cfg.clone(), l_opts.clone(), r_opts.clone())?;
+       assert!(check_physical_matches(
+           physical_plan,
+           l_opts.unwrap(),
+           r_opts.unwrap(),
+           true,
+           true,
+           cfg.shuffle_join_default_partitions
+       ));

-       let physical_plan = get_hash_join_plan(
-           cfg.clone(),
-           RepartitionOptions::Good(20),
-           RepartitionOptions::Bad(40),
-       )?;
-       assert!(check_physical_matches(physical_plan, true, true));
-
-       let physical_plan = get_hash_join_plan(
-           cfg.clone(),
-           RepartitionOptions::Good(20),
-           RepartitionOptions::Bad(25),
-       )?;
-       assert!(check_physical_matches(physical_plan, false, true));
+       let (l_opts, r_opts) = (RepartitionOptions::Good(20), RepartitionOptions::Bad(25));
+       let physical_plan = get_hash_join_plan(cfg.clone(), l_opts.clone(), r_opts.clone())?;
+       assert!(check_physical_matches(
+           physical_plan,
+           l_opts.unwrap(),
+           r_opts.unwrap(),
+           false,
+           true,
+           cfg.shuffle_join_default_partitions
+       ));

-       let physical_plan = get_hash_join_plan(
-           cfg.clone(),
-           RepartitionOptions::Good(20),
-           RepartitionOptions::Bad(26),
-       )?;
-       assert!(check_physical_matches(physical_plan, true, true));
+       let (l_opts, r_opts) = (RepartitionOptions::Good(20), RepartitionOptions::Bad(26));
+       let physical_plan = get_hash_join_plan(cfg.clone(), l_opts.clone(), r_opts.clone())?;
+       assert!(check_physical_matches(
+           physical_plan,
+           l_opts.unwrap(),
+           r_opts.unwrap(),
+           true,
+           true,
+           cfg.shuffle_join_default_partitions
+       ));

        Ok(())
    }
@@ -1237,7 +1280,14 @@
        let cfg: Arc<DaftExecutionConfig> = DaftExecutionConfig::default().into();
        for (l_opts, r_opts, l_exp, r_exp) in cases {
            let plan = get_hash_join_plan(cfg.clone(), l_opts, r_opts)?;
-           if !check_physical_matches(plan, l_exp, r_exp) {
+           if !check_physical_matches(
+               plan,
+               l_opts.unwrap(),
+               r_opts.unwrap(),
+               l_exp,
+               r_exp,
+               cfg.shuffle_join_default_partitions,
+           ) {
                panic!(
                    "Failed single partition hash join test on case ({:?}, {:?}, {}, {})",
                    l_opts, r_opts, l_exp, r_exp
@@ -1246,7 +1296,14 @@

            // reversed direction
            let plan = get_hash_join_plan(cfg.clone(), r_opts, l_opts)?;
-           if !check_physical_matches(plan, r_exp, l_exp) {
+           if !check_physical_matches(
+               plan,
+               r_opts.unwrap(),
+               l_opts.unwrap(),
+               r_exp,
+               l_exp,
+               cfg.shuffle_join_default_partitions,
+           ) {
                panic!(
                    "Failed single partition hash join test on case ({:?}, {:?}, {}, {})",
                    r_opts, l_opts, r_exp, l_exp
diff --git a/tests/dataframe/test_joins.py b/tests/dataframe/test_joins.py
index b0bdbf9df4..8ccc3f72cd 100644
--- a/tests/dataframe/test_joins.py
+++ b/tests/dataframe/test_joins.py
@@ -3,14 +3,16 @@
 import pyarrow as pa
 import pytest

-from daft import col, context
+import daft
+from daft import col
+from daft.context import get_context
 from daft.datatype import DataType
 from daft.errors import ExpressionTypeError
 from tests.utils import sort_arrow_table


 def skip_invalid_join_strategies(join_strategy, join_type):
-    if context.get_context().daft_execution_config.enable_native_executor is True:
+    if get_context().daft_execution_config.enable_native_executor is True:
         if join_type == "outer" or join_strategy not in [None, "hash"]:
             pytest.skip("Native executor fails for these tests")
     else:
@@ -1075,3 +1077,92 @@ def test_join_same_name_alias_with_compute(join_strategy, join_type, expected, m
     assert sort_arrow_table(pa.Table.from_pydict(daft_df.to_pydict()), "a") == sort_arrow_table(
         pa.Table.from_pydict(expected), "a"
     )
+
+
+# the result partition count should be max(shuffle_join_default_partitions, left_partition_size, right_partition_size)
+@pytest.mark.parametrize("shuffle_join_default_partitions", [None, 20])
+def test_join_result_partitions_smaller_than_input(shuffle_join_default_partitions):
+    skip_invalid_join_strategies("hash", "inner")
+    if shuffle_join_default_partitions is None:
+        min_partitions = get_context().daft_execution_config.shuffle_join_default_partitions
+    else:
+        min_partitions = shuffle_join_default_partitions
+
+    with daft.execution_config_ctx(shuffle_join_default_partitions=shuffle_join_default_partitions):
+        right_partition_size = 50
+        for left_partition_size in [1, min_partitions, min_partitions + 1]:
+            df_left = daft.from_pydict(
+                {"group": [i for i in range(min_partitions + 1)], "value": [i for i in range(min_partitions + 1)]}
+            )
+            df_left = df_left.into_partitions(left_partition_size)
+
+            df_right = daft.from_pydict(
+                {"group": [i for i in range(right_partition_size)], "value": [i for i in range(right_partition_size)]}
+            )
+            df_right = df_right.into_partitions(right_partition_size)
+
+            actual = df_left.join(df_right, on="group", how="inner", strategy="hash").collect()
+            n_partitions = actual.num_partitions()
+            expected_n_partitions = max(min_partitions, left_partition_size, right_partition_size)
+            assert n_partitions == expected_n_partitions
+
+
+def test_join_right_single_partition():
+    skip_invalid_join_strategies("hash", "inner")
+    shuffle_join_default_partitions = 16
+    df_left = daft.from_pydict({"group": [i for i in range(300)], "value": [i for i in range(300)]}).repartition(
+        300, "group"
+    )
+    df_right = daft.from_pydict({"group": [i for i in range(100)], "value": [i for i in range(100)]}).repartition(
+        1, "group"
+    )
+
+    with daft.execution_config_ctx(shuffle_join_default_partitions=shuffle_join_default_partitions):
+        actual = df_left.join(df_right, on="group", how="inner", strategy="hash").collect()
+        n_partitions = actual.num_partitions()
+        assert n_partitions == 300
+
+
+def test_join_right_smaller_than_cfg():
+    skip_invalid_join_strategies("hash", "inner")
+    shuffle_join_default_partitions = 200
+    df_left = daft.from_pydict({"group": [i for i in range(199)], "value": [i for i in range(199)]}).repartition(
+        199, "group"
+    )
+    df_right = daft.from_pydict({"group": [i for i in range(100)], "value": [i for i in range(100)]}).repartition(
+        100, "group"
+    )
+
+    with daft.execution_config_ctx(shuffle_join_default_partitions=shuffle_join_default_partitions):
+        actual = df_left.join(df_right, on="group", how="inner", strategy="hash").collect()
+        n_partitions = actual.num_partitions()
+        assert n_partitions == 200
+
+
+# for sort_merge, the result partition count should always be max(shuffle_join_default_partitions, left_partition_size, right_partition_size)
+@pytest.mark.parametrize("shuffle_join_default_partitions", [None, 20])
+def test_join_result_partitions_for_sortmerge(shuffle_join_default_partitions):
+    skip_invalid_join_strategies("sort_merge", "inner")
+
+    if shuffle_join_default_partitions is None:
+        min_partitions = get_context().daft_execution_config.shuffle_join_default_partitions
+    else:
+        min_partitions = shuffle_join_default_partitions
+
+    with daft.execution_config_ctx(shuffle_join_default_partitions=shuffle_join_default_partitions):
+        for partition_size in [1, min_partitions, min_partitions + 1]:
+            df_left = daft.from_pydict(
+                {"group": [i for i in range(min_partitions + 1)], "value": [i for i in range(min_partitions + 1)]}
+            )
+            df_left = df_left.into_partitions(partition_size)
+
+            df_right = daft.from_pydict({"group": [i for i in range(50)], "value": [i for i in range(50)]})
+            df_right = df_right.into_partitions(50)
+
+            actual = df_left.join(df_right, on="group", how="inner", strategy="sort_merge").collect()
+
+            assert actual.num_partitions() == max(min_partitions, partition_size, 50)
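Net effect of the planner change, as exercised by the tests above: the join output partition count becomes max(max(left, right), shuffle_join_default_partitions). A quick sketch under assumed input sizes (30 and 8 partitions; the exact numbers are illustrative, using the same `daft.execution_config_ctx` entry point as the tests):

```python
import daft

df_left = daft.from_pydict({"group": list(range(300)), "value": list(range(300))}).into_partitions(30)
df_right = daft.from_pydict({"group": list(range(300)), "value": list(range(300))}).into_partitions(8)

with daft.execution_config_ctx(shuffle_join_default_partitions=16):
    joined = df_left.join(df_right, on="group", how="inner", strategy="hash").collect()
    # max(max(30, 8), 16) == 30; if both inputs had fewer than 16 partitions,
    # the configured floor of 16 would win instead.
    assert joined.num_partitions() == 30
```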