From e811034ea6f5e775abd58f3f46ae970897ec0a9e Mon Sep 17 00:00:00 2001 From: Jay Chia Date: Thu, 19 Dec 2024 16:35:32 +0800 Subject: [PATCH] Add config flag and unit test --- daft/context.py | 3 ++ daft/daft/__init__.pyi | 1 + src/common/daft-config/src/lib.rs | 6 +++ src/common/daft-config/src/python.rs | 15 +++++++ src/daft-scan/src/scan_task_iters/mod.rs | 48 +++++++++++--------- tests/io/test_split_scan_tasks.py | 57 ++++++++++++++++++++++++ 6 files changed, 110 insertions(+), 20 deletions(-) diff --git a/daft/context.py b/daft/context.py index 7f975ae83e..d96400980d 100644 --- a/daft/context.py +++ b/daft/context.py @@ -352,6 +352,7 @@ def set_execution_config( shuffle_algorithm: str | None = None, pre_shuffle_merge_threshold: int | None = None, enable_ray_tracing: bool | None = None, + scantask_splitting_level: int | None = None, ) -> DaftContext: """Globally sets various configuration parameters which control various aspects of Daft execution. @@ -395,6 +396,7 @@ def set_execution_config( shuffle_algorithm: The shuffle algorithm to use. Defaults to "map_reduce". Other options are "pre_shuffle_merge". pre_shuffle_merge_threshold: Memory threshold in bytes for pre-shuffle merge. Defaults to 1GB enable_ray_tracing: Enable tracing for Ray. Accessible in `/tmp/ray/session_latest/logs/daft` after the run completes. Defaults to False. + scantask_splitting_level: How aggressively to split scan tasks. Setting this to `2` will use a more aggressive ScanTask splitting algorithm which might be more expensive to run but results in more even splits of partitions. Defaults to 1. """ # Replace values in the DaftExecutionConfig with user-specified overrides ctx = get_context() @@ -425,6 +427,7 @@ def set_execution_config( shuffle_algorithm=shuffle_algorithm, pre_shuffle_merge_threshold=pre_shuffle_merge_threshold, enable_ray_tracing=enable_ray_tracing, + scantask_splitting_level=scantask_splitting_level, ) ctx._daft_execution_config = new_daft_execution_config diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi index 6860f72491..b53956f516 100644 --- a/daft/daft/__init__.pyi +++ b/daft/daft/__init__.pyi @@ -1716,6 +1716,7 @@ class PyDaftExecutionConfig: enable_ray_tracing: bool | None = None, shuffle_algorithm: str | None = None, pre_shuffle_merge_threshold: int | None = None, + scantask_splitting_level: int | None = None, ) -> PyDaftExecutionConfig: ... @property def scan_tasks_min_size_bytes(self) -> int: ... diff --git a/src/common/daft-config/src/lib.rs b/src/common/daft-config/src/lib.rs index 590fd5cf6c..b598b2072d 100644 --- a/src/common/daft-config/src/lib.rs +++ b/src/common/daft-config/src/lib.rs @@ -53,6 +53,7 @@ pub struct DaftExecutionConfig { pub shuffle_algorithm: String, pub pre_shuffle_merge_threshold: usize, pub enable_ray_tracing: bool, + pub scantask_splitting_level: i32, } impl Default for DaftExecutionConfig { @@ -81,6 +82,7 @@ impl Default for DaftExecutionConfig { shuffle_algorithm: "map_reduce".to_string(), pre_shuffle_merge_threshold: 1024 * 1024 * 1024, // 1GB enable_ray_tracing: false, + scantask_splitting_level: 1, } } } @@ -118,6 +120,10 @@ impl DaftExecutionConfig { if let Ok(val) = std::env::var(shuffle_algorithm_env_var_name) { cfg.shuffle_algorithm = val; } + let enable_aggressive_scantask_splitting_env_var_name = "DAFT_SCANTASK_SPLITTING_LEVEL"; + if let Ok(val) = std::env::var(enable_aggressive_scantask_splitting_env_var_name) { + cfg.scantask_splitting_level = val.parse::().unwrap_or(0); + } cfg } } diff --git a/src/common/daft-config/src/python.rs b/src/common/daft-config/src/python.rs index 3228263b07..f60a14b537 100644 --- a/src/common/daft-config/src/python.rs +++ b/src/common/daft-config/src/python.rs @@ -99,6 +99,7 @@ impl PyDaftExecutionConfig { shuffle_algorithm: Option<&str>, pre_shuffle_merge_threshold: Option, enable_ray_tracing: Option, + scantask_splitting_level: Option, ) -> PyResult { let mut config = self.config.as_ref().clone(); @@ -184,6 +185,15 @@ impl PyDaftExecutionConfig { config.enable_ray_tracing = enable_ray_tracing; } + if let Some(scantask_splitting_level) = scantask_splitting_level { + if !matches!(scantask_splitting_level, 1 | 2) { + return Err(PyErr::new::( + "scantask_splitting_level must be 1 or 2", + )); + } + config.scantask_splitting_level = scantask_splitting_level; + } + Ok(Self { config: Arc::new(config), }) @@ -293,6 +303,11 @@ impl PyDaftExecutionConfig { fn enable_ray_tracing(&self) -> PyResult { Ok(self.config.enable_ray_tracing) } + + #[getter] + fn scantask_splitting_level(&self) -> PyResult { + Ok(self.config.scantask_splitting_level) + } } impl_bincode_py_state_serialization!(PyDaftExecutionConfig); diff --git a/src/daft-scan/src/scan_task_iters/mod.rs b/src/daft-scan/src/scan_task_iters/mod.rs index 2d11c3f251..b4e4d6e3c3 100644 --- a/src/daft-scan/src/scan_task_iters/mod.rs +++ b/src/daft-scan/src/scan_task_iters/mod.rs @@ -316,26 +316,34 @@ fn split_and_merge_pass( .iter() .all(|st| st.as_any().downcast_ref::().is_some()) { - // TODO(desmond): Here we downcast Arc to Arc. ScanTask and DummyScanTask (test only) are - // the only non-test implementer of ScanTaskLike. It might be possible to avoid the downcast by implementing merging - // at the trait level, but today that requires shifting around a non-trivial amount of code to avoid circular dependencies. - let iter: BoxScanTaskIter = Box::new(scan_tasks.as_ref().iter().map(|st| { - st.clone() - .as_any_arc() - .downcast::() - .map_err(|e| DaftError::TypeError(format!("Expected Arc, found {:?}", e))) - })); - let split_tasks = split_by_row_groups( - iter, - cfg.parquet_split_row_groups_max_files, - cfg.scan_tasks_min_size_bytes, - cfg.scan_tasks_max_size_bytes, - ); - let merged_tasks = merge_by_sizes(split_tasks, pushdowns, cfg); - let scan_tasks: Vec> = merged_tasks - .map(|st| st.map(|task| task as Arc)) - .collect::>>()?; - Ok(Arc::new(scan_tasks)) + if cfg.scantask_splitting_level == 1 { + // TODO(desmond): Here we downcast Arc to Arc. ScanTask and DummyScanTask (test only) are + // the only non-test implementer of ScanTaskLike. It might be possible to avoid the downcast by implementing merging + // at the trait level, but today that requires shifting around a non-trivial amount of code to avoid circular dependencies. + let iter: BoxScanTaskIter = Box::new(scan_tasks.as_ref().iter().map(|st| { + st.clone().as_any_arc().downcast::().map_err(|e| { + DaftError::TypeError(format!("Expected Arc, found {:?}", e)) + }) + })); + let split_tasks = split_by_row_groups( + iter, + cfg.parquet_split_row_groups_max_files, + cfg.scan_tasks_min_size_bytes, + cfg.scan_tasks_max_size_bytes, + ); + let merged_tasks = merge_by_sizes(split_tasks, pushdowns, cfg); + let scan_tasks: Vec> = merged_tasks + .map(|st| st.map(|task| task as Arc)) + .collect::>>()?; + Ok(Arc::new(scan_tasks)) + } else if cfg.scantask_splitting_level == 2 { + todo!("Implement aggressive scantask splitting"); + } else { + panic!( + "DAFT_SCANTASK_SPLITTING_LEVEL must be either 1 or 2, received: {}", + cfg.scantask_splitting_level + ); + } } else { Ok(scan_tasks) } diff --git a/tests/io/test_split_scan_tasks.py b/tests/io/test_split_scan_tasks.py index 96ca4ba83f..86990678b3 100644 --- a/tests/io/test_split_scan_tasks.py +++ b/tests/io/test_split_scan_tasks.py @@ -25,3 +25,60 @@ def test_split_parquet_read(parquet_files): df = daft.read_parquet(str(parquet_files)) assert df.num_partitions() == 10, "Should have 10 partitions since we will split the file" assert df.to_pydict() == {"data": ["aaa"] * 100} + + +def test_split_parquet_read_some_splits(tmpdir): + with daft.execution_config_ctx(scantask_splitting_level=2): + # Write a mix of 20 large and 20 small files + # Small ones should not be split, large ones should be split into 10 rowgroups each + # This gives us a total of 200 + 20 scantasks + + # Write 20 large files into tmpdir + large_file_paths = [] + for i in range(20): + tbl = pa.table({"data": [str(f"large{i}") for i in range(100)]}) + path = tmpdir / f"file.{i}.large.pq" + papq.write_table(tbl, str(path), row_group_size=10, use_dictionary=False) + large_file_paths.append(str(path)) + + # Write 20 small files into tmpdir + small_file_paths = [] + for i in range(20): + tbl = pa.table({"data": ["small"]}) + path = tmpdir / f"file.{i}.small.pq" + papq.write_table(tbl, str(path), row_group_size=1, use_dictionary=False) + small_file_paths.append(str(path)) + + # Test [large_paths, ..., small_paths, ...] + with daft.execution_config_ctx( + scan_tasks_min_size_bytes=20, + scan_tasks_max_size_bytes=100, + ): + df = daft.read_parquet(large_file_paths + small_file_paths) + assert ( + df.num_partitions() == 220 + ), "Should have 220 partitions since we will split all large files (20 * 10 rowgroups) but keep small files unsplit" + assert df.to_pydict() == {"data": [str(f"large{i}") for i in range(100)] * 20 + ["small"] * 20} + + # Test interleaved [large_path, small_path, large_path, small_path, ...] + with daft.execution_config_ctx( + scan_tasks_min_size_bytes=20, + scan_tasks_max_size_bytes=100, + ): + interleaved_paths = [path for pair in zip(large_file_paths, small_file_paths) for path in pair] + df = daft.read_parquet(interleaved_paths) + assert ( + df.num_partitions() == 220 + ), "Should have 220 partitions since we will split all large files (20 * 10 rowgroups) but keep small files unsplit" + assert df.to_pydict() == {"data": ([str(f"large{i}") for i in range(100)] + ["small"]) * 20} + + # Test [small_paths, ..., large_paths] + with daft.execution_config_ctx( + scan_tasks_min_size_bytes=20, + scan_tasks_max_size_bytes=100, + ): + df = daft.read_parquet(small_file_paths + large_file_paths) + assert ( + df.num_partitions() == 220 + ), "Should have 220 partitions since we will split all large files (20 * 10 rowgroups) but keep small files unsplit" + assert df.to_pydict() == {"data": ["small"] * 20 + [str(f"large{i}") for i in range(100)] * 20}