diff --git a/daft/__init__.py b/daft/__init__.py
index ec230ad5fa..153e6b8eb2 100644
--- a/daft/__init__.py
+++ b/daft/__init__.py
@@ -60,7 +60,7 @@ def refresh_logger() -> None:
 # Daft top-level imports
 ###
 
-from daft.context import set_execution_config, set_planning_config
+from daft.context import set_execution_config, set_planning_config, execution_config_ctx, planning_config_ctx
 from daft.convert import (
     from_arrow,
     from_dask_dataframe,
@@ -128,6 +128,8 @@ def refresh_logger() -> None:
     "Schema",
     "set_planning_config",
     "set_execution_config",
+    "planning_config_ctx",
+    "execution_config_ctx",
     "sql",
     "sql_expr",
     "to_struct",
diff --git a/daft/context.py b/daft/context.py
index 38ef8545d5..e854e87ec4 100644
--- a/daft/context.py
+++ b/daft/context.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import contextlib
 import dataclasses
 import logging
 import os
@@ -244,6 +245,17 @@ def set_runner_py(use_thread_pool: bool | None = None) -> DaftContext:
         return ctx
 
 
+@contextlib.contextmanager
+def planning_config_ctx(**kwargs):
+    """Context manager that wraps set_planning_config to reset the config to its original setting afterwards."""
+    original_config = get_context().daft_planning_config
+    try:
+        set_planning_config(**kwargs)
+        yield
+    finally:
+        set_planning_config(config=original_config)
+
+
 def set_planning_config(
     config: PyDaftPlanningConfig | None = None,
     default_io_config: IOConfig | None = None,
@@ -269,6 +281,17 @@ def set_planning_config(
         return ctx
 
 
+@contextlib.contextmanager
+def execution_config_ctx(**kwargs):
+    """Context manager that wraps set_execution_config to reset the config to its original setting afterwards."""
+    original_config = get_context().daft_execution_config
+    try:
+        set_execution_config(**kwargs)
+        yield
+    finally:
+        set_execution_config(config=original_config)
+
+
 def set_execution_config(
     config: PyDaftExecutionConfig | None = None,
     scan_tasks_min_size_bytes: int | None = None,
diff --git a/daft/runners/ray_runner.py b/daft/runners/ray_runner.py
index 935359bed5..88044b95d3 100644
--- a/daft/runners/ray_runner.py
+++ b/daft/runners/ray_runner.py
@@ -10,7 +10,7 @@
 
 import pyarrow as pa
 
-from daft.context import get_context, set_execution_config
+from daft.context import execution_config_ctx, get_context
 from daft.logical.builder import LogicalPlanBuilder
 from daft.plan_scheduler import PhysicalPlanScheduler
 from daft.runners.progress_bar import ProgressBar
@@ -350,8 +350,10 @@ def single_partition_pipeline(
     partial_metadatas: list[PartitionMetadata],
     *inputs: MicroPartition,
 ) -> list[list[PartitionMetadata] | MicroPartition]:
-    set_execution_config(daft_execution_config)
-    return build_partitions(instruction_stack, partial_metadatas, *inputs)
+    with execution_config_ctx(
+        config=daft_execution_config,
+    ):
+        return build_partitions(instruction_stack, partial_metadatas, *inputs)
 
 
 @ray.remote
@@ -361,8 +363,8 @@ def fanout_pipeline(
     partial_metadatas: list[PartitionMetadata],
     *inputs: MicroPartition,
 ) -> list[list[PartitionMetadata] | MicroPartition]:
-    set_execution_config(daft_execution_config)
-    return build_partitions(instruction_stack, partial_metadatas, *inputs)
+    with execution_config_ctx(config=daft_execution_config):
+        return build_partitions(instruction_stack, partial_metadatas, *inputs)
 
 
 @ray.remote(scheduling_strategy="SPREAD")
@@ -374,8 +376,8 @@ def reduce_pipeline(
 ) -> list[list[PartitionMetadata] | MicroPartition]:
     import ray
 
-    set_execution_config(daft_execution_config)
-    return build_partitions(instruction_stack, partial_metadatas, *ray.get(inputs))
+    with execution_config_ctx(config=daft_execution_config):
+        return build_partitions(instruction_stack, partial_metadatas, *ray.get(inputs))
 
 
 @ray.remote(scheduling_strategy="SPREAD")
@@ -387,8 +389,8 @@ def reduce_and_fanout(
 ) -> list[list[PartitionMetadata] | MicroPartition]:
     import ray
 
-    set_execution_config(daft_execution_config)
-    return build_partitions(instruction_stack, partial_metadatas, *ray.get(inputs))
+    with execution_config_ctx(config=daft_execution_config):
+        return build_partitions(instruction_stack, partial_metadatas, *ray.get(inputs))
 
 
 @ray.remote
diff --git a/docs/source/api_docs/configs.rst b/docs/source/api_docs/configs.rst
index cc1322b08d..d11ce27021 100644
--- a/docs/source/api_docs/configs.rst
+++ b/docs/source/api_docs/configs.rst
@@ -23,7 +23,9 @@ Configure Daft in various ways during execution.
     :toctree: doc_gen/configuration_functions
 
     daft.set_planning_config
+    daft.planning_config_ctx
     daft.set_execution_config
+    daft.execution_config_ctx
 
 I/O Configurations
 ******************
diff --git a/src/daft-parquet/src/file.rs b/src/daft-parquet/src/file.rs
index 3febcad055..1ec92c6027 100644
--- a/src/daft-parquet/src/file.rs
+++ b/src/daft-parquet/src/file.rs
@@ -120,7 +120,7 @@ pub(crate) fn build_row_ranges(
         let mut rows_to_add: i64 = limit.unwrap_or(i64::MAX);
         for i in row_groups {
             let i = *i as usize;
-            if !(0..metadata.row_groups.len()).contains(&i) {
+            if !metadata.row_groups.keys().any(|x| *x == i) {
                 return Err(super::Error::ParquetRowGroupOutOfIndex {
                     path: uri.to_string(),
                     row_group: i as i64,
diff --git a/src/daft-parquet/src/read.rs b/src/daft-parquet/src/read.rs
index 4639a4d3f4..2a4e8524f5 100644
--- a/src/daft-parquet/src/read.rs
+++ b/src/daft-parquet/src/read.rs
@@ -201,12 +201,6 @@ async fn read_parquet_single(
         ))
     }?;
 
-    let rows_per_row_groups = metadata
-        .row_groups
-        .values()
-        .map(|m| m.num_rows())
-        .collect::<Vec<_>>();
-
     let metadata_num_rows = metadata.num_rows;
     let metadata_num_columns = metadata.schema().fields().len();
 
@@ -265,7 +259,7 @@ async fn read_parquet_single(
     } else if let Some(row_groups) = row_groups {
         let expected_rows = row_groups
             .iter()
-            .map(|i| rows_per_row_groups.get(*i as usize).unwrap())
+            .map(|i| metadata.row_groups.get(&(*i as usize)).unwrap().num_rows())
             .sum::<usize>()
             - num_deleted_rows;
         if expected_rows != table.len() {
diff --git a/tests/benchmarks/test_local_tpch.py b/tests/benchmarks/test_local_tpch.py
index 7f06dcb4d5..45ad164523 100644
--- a/tests/benchmarks/test_local_tpch.py
+++ b/tests/benchmarks/test_local_tpch.py
@@ -33,8 +33,8 @@ def gen_tpch(request):
     csv_files_location = data_generation.gen_csv_files(TPCH_DBGEN_DIR, num_parts, SCALE_FACTOR)
 
     # Disable native executor to generate parquet files, remove once native executor supports writing parquet files
-    daft.context.set_execution_config(enable_native_executor=False)
-    parquet_files_location = data_generation.gen_parquet(csv_files_location)
+    with daft.context.execution_config_ctx(enable_native_executor=False):
+        parquet_files_location = data_generation.gen_parquet(csv_files_location)
 
     in_memory_tables = {}
     for tbl_name in data_generation.SCHEMA.keys():
@@ -106,14 +106,16 @@ def test_tpch(tmp_path, check_answer, get_df, benchmark_with_memray, engine, q):
 
     def f():
         if engine == "native":
-            daft.context.set_execution_config(enable_native_executor=True)
+            ctx = daft.context.execution_config_ctx(enable_native_executor=True)
         elif engine == "python":
-            daft.context.set_execution_config(enable_native_executor=False)
+            ctx = daft.context.execution_config_ctx(enable_native_executor=False)
         else:
             raise ValueError(f"{engine} unsupported")
-        question = getattr(answers, f"q{q}")
-        daft_df = question(get_df)
-        return daft_df.to_arrow()
+
+        with ctx:
+            question = getattr(answers, f"q{q}")
+            daft_df = question(get_df)
+            return daft_df.to_arrow()
 
     benchmark_group = f"q{q}-parts-{num_parts}"
     daft_pd_df = benchmark_with_memray(f, benchmark_group).to_pandas()
diff --git a/tests/conftest.py b/tests/conftest.py
index 1efafe82cb..c8c8b3b52b 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -21,16 +21,6 @@ def pytest_addoption(parser):
     )
 
 
-@pytest.fixture(scope="session", autouse=True)
-def set_execution_configs():
-    """Sets global Daft config for testing"""
-    daft.set_execution_config(
-        # Disables merging of ScanTasks
-        scan_tasks_min_size_bytes=0,
-        scan_tasks_max_size_bytes=0,
-    )
-
-
 def pytest_configure(config):
     config.addinivalue_line(
         "markers", "integration: mark test as an integration test that runs with external dependencies"
@@ -82,14 +72,8 @@ def join_strategy(request):
     if request.param != "sort_merge_aligned_boundaries":
         yield request.param
     else:
-        old_execution_config = daft.context.get_context().daft_execution_config
-        try:
-            daft.set_execution_config(
-                sort_merge_join_sort_with_aligned_boundaries=True,
-            )
+        with daft.execution_config_ctx(sort_merge_join_sort_with_aligned_boundaries=True):
             yield "sort_merge"
-        finally:
-            daft.set_execution_config(old_execution_config)
 
 
 @pytest.fixture(scope="function")
@@ -135,7 +119,12 @@ def _make_df(
         else:
             raise NotImplementedError(f"make_df not implemented for: {variant}")
 
-    yield _make_df
+    with daft.execution_config_ctx(
+        # Disables merging of ScanTasks of Parquet when reading small Parquet files
+        scan_tasks_min_size_bytes=0,
+        scan_tasks_max_size_bytes=0,
+    ):
+        yield _make_df
 
 
 def assert_df_equals(
diff --git a/tests/cookbook/test_write.py b/tests/cookbook/test_write.py
index 9927c17c6b..c197ed922d 100644
--- a/tests/cookbook/test_write.py
+++ b/tests/cookbook/test_write.py
@@ -160,12 +160,8 @@ def test_parquet_write_with_null_values(tmp_path):
 
 @pytest.fixture()
 def smaller_parquet_target_filesize():
-    old_execution_config = daft.context.get_context().daft_execution_config
-    try:
-        daft.set_execution_config(parquet_target_filesize=1024)
+    with daft.execution_config_ctx(parquet_target_filesize=1024):
         yield
-    finally:
-        daft.set_execution_config(old_execution_config)
 
 
 @pytest.mark.skipif(
diff --git a/tests/dataframe/test_aggregations.py b/tests/dataframe/test_aggregations.py
index 912e1c2f2f..1f68b68a2a 100644
--- a/tests/dataframe/test_aggregations.py
+++ b/tests/dataframe/test_aggregations.py
@@ -363,26 +363,26 @@ def test_groupby_result_partitions_smaller_than_input(shuffle_aggregation_defaul
     if shuffle_aggregation_default_partitions is None:
         min_partitions = get_context().daft_execution_config.shuffle_aggregation_default_partitions
     else:
-        daft.set_execution_config(shuffle_aggregation_default_partitions=shuffle_aggregation_default_partitions)
         min_partitions = shuffle_aggregation_default_partitions
 
-    for partition_size in [1, min_partitions, min_partitions + 1]:
-        df = daft.from_pydict(
-            {"group": [i for i in range(min_partitions + 1)], "value": [i for i in range(min_partitions + 1)]}
-        )
-        df = df.into_partitions(partition_size)
+    with daft.execution_config_ctx(shuffle_aggregation_default_partitions=shuffle_aggregation_default_partitions):
+        for partition_size in [1, min_partitions, min_partitions + 1]:
+            df = daft.from_pydict(
+                {"group": [i for i in range(min_partitions + 1)], "value": [i for i in range(min_partitions + 1)]}
+            )
+            df = df.into_partitions(partition_size)
 
-        df = df.groupby(col("group")).agg(
-            [
-                col("value").sum().alias("sum"),
-                col("value").mean().alias("mean"),
-                col("value").min().alias("min"),
-            ]
-        )
+            df = df.groupby(col("group")).agg(
+                [
+                    col("value").sum().alias("sum"),
+                    col("value").mean().alias("mean"),
+                    col("value").min().alias("min"),
+                ]
+            )
 
-        df = df.collect()
+            df = df.collect()
 
-        assert df.num_partitions() == min(min_partitions, partition_size)
+            assert df.num_partitions() == min(min_partitions, partition_size)
 
 
 @pytest.mark.parametrize("repartition_nparts", [1, 2, 4])
diff --git a/tests/integration/io/parquet/test_reads_public_data.py b/tests/integration/io/parquet/test_reads_public_data.py
index 55c163f9b5..ec9c9a397f 100644
--- a/tests/integration/io/parquet/test_reads_public_data.py
+++ b/tests/integration/io/parquet/test_reads_public_data.py
@@ -217,16 +217,11 @@ def set_split_config(request):
     max_size = 0 if request.param[0] else 384 * 1024 * 1024
     min_size = 0 if request.param[1] else 96 * 1024 * 1024
 
-    old_execution_config = daft.context.get_context().daft_execution_config
-
-    try:
-        daft.set_execution_config(
-            scan_tasks_max_size_bytes=max_size,
-            scan_tasks_min_size_bytes=min_size,
-        )
+    with daft.execution_config_ctx(
+        scan_tasks_max_size_bytes=max_size,
+        scan_tasks_min_size_bytes=min_size,
+    ):
         yield
-    finally:
-        daft.set_execution_config(old_execution_config)
 
 
 def read_parquet_with_pyarrow(path) -> pa.Table:
diff --git a/tests/integration/sql/test_sql.py b/tests/integration/sql/test_sql.py
index ddafca8947..2c395f458b 100644
--- a/tests/integration/sql/test_sql.py
+++ b/tests/integration/sql/test_sql.py
@@ -8,7 +8,6 @@
 import sqlalchemy
 
 import daft
-from daft.context import set_execution_config
 from tests.conftest import assert_df_equals
 from tests.integration.sql.conftest import TEST_TABLE_NAME
 
@@ -36,11 +35,14 @@ def test_sql_create_dataframe_ok(test_db, pdf) -> None:
 def test_sql_partitioned_read(test_db, num_partitions, pdf) -> None:
     row_size_bytes = daft.from_pandas(pdf).schema().estimate_row_size_bytes()
     num_rows_per_partition = len(pdf) / num_partitions
-    set_execution_config(read_sql_partition_size_bytes=math.ceil(row_size_bytes * num_rows_per_partition))
-
-    df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db, partition_col="id")
-    assert df.num_partitions() == num_partitions
-    assert_df_equals(df.to_pandas(coerce_temporal_nanoseconds=True), pdf, sort_key="id")
+    with daft.execution_config_ctx(
+        read_sql_partition_size_bytes=math.ceil(row_size_bytes * num_rows_per_partition),
+        scan_tasks_min_size_bytes=0,
+        scan_tasks_max_size_bytes=0,
+    ):
+        df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db, partition_col="id")
+        assert df.num_partitions() == num_partitions
+        assert_df_equals(df.to_pandas(coerce_temporal_nanoseconds=True), pdf, sort_key="id")
 
 
 @pytest.mark.integration()
@@ -49,27 +51,35 @@ def test_sql_partitioned_read_with_custom_num_partitions_and_partition_col(
     test_db, num_partitions, partition_col, pdf
 ) -> None:
-    df = daft.read_sql(
-        f"SELECT * FROM {TEST_TABLE_NAME}",
-        test_db,
-        partition_col=partition_col,
-        num_partitions=num_partitions,
-    )
-    assert df.num_partitions() == num_partitions
-    assert_df_equals(df.to_pandas(coerce_temporal_nanoseconds=True), pdf, sort_key="id")
+    with daft.execution_config_ctx(
+        scan_tasks_min_size_bytes=0,
+        scan_tasks_max_size_bytes=0,
+    ):
+        df = daft.read_sql(
+            f"SELECT * FROM {TEST_TABLE_NAME}",
+            test_db,
+            partition_col=partition_col,
+            num_partitions=num_partitions,
+        )
+        assert df.num_partitions() == num_partitions
+        assert_df_equals(df.to_pandas(coerce_temporal_nanoseconds=True), pdf, sort_key="id")
 
 
 @pytest.mark.integration()
 @pytest.mark.parametrize("num_partitions", [1, 2, 3, 4])
 def test_sql_partitioned_read_with_non_uniformly_distributed_column(test_db, num_partitions, pdf) -> None:
-    df = daft.read_sql(
-        f"SELECT * FROM {TEST_TABLE_NAME}",
-        test_db,
-        partition_col="non_uniformly_distributed_col",
-        num_partitions=num_partitions,
-    )
-    assert df.num_partitions() == num_partitions
-    assert_df_equals(df.to_pandas(coerce_temporal_nanoseconds=True), pdf, sort_key="id")
+    with daft.execution_config_ctx(
+        scan_tasks_min_size_bytes=0,
+        scan_tasks_max_size_bytes=0,
+    ):
+        df = daft.read_sql(
+            f"SELECT * FROM {TEST_TABLE_NAME}",
+            test_db,
+            partition_col="non_uniformly_distributed_col",
+            num_partitions=num_partitions,
+        )
+        assert df.num_partitions() == num_partitions
+        assert_df_equals(df.to_pandas(coerce_temporal_nanoseconds=True), pdf, sort_key="id")
 
 
 @pytest.mark.integration()
diff --git a/tests/io/delta_lake/test_table_read.py b/tests/io/delta_lake/test_table_read.py
index 0e96fc9968..9cb5881a72 100644
--- a/tests/io/delta_lake/test_table_read.py
+++ b/tests/io/delta_lake/test_table_read.py
@@ -1,6 +1,5 @@
 from __future__ import annotations
 
-import contextlib
 import sys
 
 import pyarrow as pa
@@ -11,19 +10,6 @@
 from daft.logical.schema import Schema
 from tests.utils import assert_pyarrow_tables_equal
 
-
-@contextlib.contextmanager
-def split_small_pq_files():
-    old_config = daft.context.get_context().daft_execution_config
-    daft.set_execution_config(
-        # Splits any parquet files >100 bytes in size
-        scan_tasks_min_size_bytes=1,
-        scan_tasks_max_size_bytes=100,
-    )
-    yield
-    daft.set_execution_config(config=old_config)
-
-
 PYARROW_LE_8_0_0 = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) < (8, 0, 0)
 PYTHON_LT_3_8 = sys.version_info[:2] < (3, 8)
 pytestmark = pytest.mark.skipif(
@@ -65,7 +51,10 @@ def test_deltalake_read_row_group_splits(tmp_path, base_table):
     deltalake.write_deltalake(path, base_table, min_rows_per_group=1, max_rows_per_group=2)
 
     # Force file splitting
-    with split_small_pq_files():
+    with daft.execution_config_ctx(
+        scan_tasks_min_size_bytes=1,
+        scan_tasks_max_size_bytes=100,
+    ):
         df = daft.read_deltalake(str(path))
         df.collect()
         assert len(df) == 3, "Length of non-materialized data when read through deltalake should be correct"
@@ -79,7 +68,10 @@ def test_deltalake_read_row_group_splits_with_filter(tmp_path, base_table):
     deltalake.write_deltalake(path, base_table, min_rows_per_group=1, max_rows_per_group=2)
 
     # Force file splitting
-    with split_small_pq_files():
+    with daft.execution_config_ctx(
+        scan_tasks_min_size_bytes=1,
+        scan_tasks_max_size_bytes=100,
+    ):
         df = daft.read_deltalake(str(path))
         df = df.where(df["a"] > 1)
         df.collect()
@@ -94,7 +86,10 @@ def test_deltalake_read_row_group_splits_with_limit(tmp_path, base_table):
     deltalake.write_deltalake(path, base_table, min_rows_per_group=1, max_rows_per_group=2)
 
     # Force file splitting
-    with split_small_pq_files():
+    with daft.execution_config_ctx(
+        scan_tasks_min_size_bytes=1,
+        scan_tasks_max_size_bytes=100,
+    ):
         df = daft.read_deltalake(str(path))
         df = df.limit(2)
         df.collect()
diff --git a/tests/io/delta_lake/test_table_write.py b/tests/io/delta_lake/test_table_write.py
index d7eb25678f..63f9bd1881 100644
--- a/tests/io/delta_lake/test_table_write.py
+++ b/tests/io/delta_lake/test_table_write.py
@@ -1,6 +1,5 @@
 from __future__ import annotations
 
-import contextlib
 import sys
 
 import pyarrow as pa
@@ -10,19 +9,6 @@
 from daft.io.object_store_options import io_config_to_storage_options
 from daft.logical.schema import Schema
 
-
-@contextlib.contextmanager
-def split_small_pq_files():
-    old_config = daft.context.get_context().daft_execution_config
-    daft.set_execution_config(
-        # Splits any parquet files >100 bytes in size
-        scan_tasks_min_size_bytes=1,
-        scan_tasks_max_size_bytes=100,
-    )
-    yield
-    daft.set_execution_config(config=old_config)
-
-
 PYARROW_LE_8_0_0 = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) < (8, 0, 0)
 PYTHON_LT_3_8 = sys.version_info[:2] < (3, 8)
 pytestmark = pytest.mark.skipif(
diff --git a/tests/io/test_merge_scan_tasks.py b/tests/io/test_merge_scan_tasks.py
index d52b2e0c02..fcbd7514ad 100644
--- a/tests/io/test_merge_scan_tasks.py
+++ b/tests/io/test_merge_scan_tasks.py
@@ -1,26 +1,10 @@
 from __future__ import annotations
 
-import contextlib
-
 import pytest
 
 import daft
 
 
-@contextlib.contextmanager
-def override_merge_scan_tasks_configs(scan_tasks_min_size_bytes: int, scan_tasks_max_size_bytes: int):
-    old_execution_config = daft.context.get_context().daft_execution_config
-
-    try:
-        daft.set_execution_config(
-            scan_tasks_min_size_bytes=scan_tasks_min_size_bytes,
-            scan_tasks_max_size_bytes=scan_tasks_max_size_bytes,
-        )
-        yield
-    finally:
-        daft.set_execution_config(old_execution_config)
-
-
 @pytest.fixture(scope="function")
 def csv_files(tmpdir):
     """Writes 3 CSV files, each of 10 bytes in size, to tmpdir and yield tmpdir"""
@@ -33,7 +17,10 @@ def test_merge_scan_task_exceed_max(csv_files):
-    with override_merge_scan_tasks_configs(0, 0):
+    with daft.execution_config_ctx(
+        scan_tasks_min_size_bytes=0,
+        scan_tasks_max_size_bytes=0,
+    ):
         df = daft.read_csv(str(csv_files))
         assert (
             df.num_partitions() == 3
@@ -41,7 +28,10 @@ def test_merge_scan_task_below_max(csv_files):
-    with override_merge_scan_tasks_configs(11, 12):
+    with daft.execution_config_ctx(
+        scan_tasks_min_size_bytes=11,
+        scan_tasks_max_size_bytes=12,
+    ):
         df = daft.read_csv(str(csv_files))
         assert (
             df.num_partitions() == 2
@@ -49,7 +39,10 @@ def test_merge_scan_task_above_min(csv_files):
-    with override_merge_scan_tasks_configs(9, 20):
+    with daft.execution_config_ctx(
+        scan_tasks_min_size_bytes=9,
+        scan_tasks_max_size_bytes=20,
+    ):
         df = daft.read_csv(str(csv_files))
         assert (
             df.num_partitions() == 2
@@ -57,7 +50,10 @@ def test_merge_scan_task_below_min(csv_files):
-    with override_merge_scan_tasks_configs(17, 20):
+    with daft.execution_config_ctx(
+        scan_tasks_min_size_bytes=17,
+        scan_tasks_max_size_bytes=20,
+    ):
         df = daft.read_csv(str(csv_files))
         assert (
             df.num_partitions() == 1
diff --git a/tests/io/test_split_scan_tasks.py b/tests/io/test_split_scan_tasks.py
new file mode 100644
index 0000000000..7b92655668
--- /dev/null
+++ b/tests/io/test_split_scan_tasks.py
@@ -0,0 +1,27 @@
+from __future__ import annotations
+
+import pyarrow as pa
+import pyarrow.parquet as papq
+import pytest
+
+import daft
+
+
+@pytest.fixture(scope="function")
+def parquet_files(tmpdir):
+    """Writes 1 Parquet file with 10 rowgroups, each of 100 bytes in size"""
+    tbl = pa.table({"data": ["aaa"] * 100})
+    path = tmpdir / "file.pq"
+    papq.write_table(tbl, str(path), row_group_size=10, use_dictionary=False)
+
+    return tmpdir
+
+
+def test_split_parquet_read(parquet_files):
+    with daft.execution_config_ctx(
+        scan_tasks_min_size_bytes=1,
+        scan_tasks_max_size_bytes=10,
+    ):
+        df = daft.read_parquet(str(parquet_files))
+        assert df.num_partitions() == 10, "Should have 10 partitions since we will split the file"
+        assert df.to_pydict() == {"data": ["aaa"] * 100}
diff --git a/tests/ray/runner.py b/tests/ray/runner.py
index 9b1e306762..3a82ed9333 100644
--- a/tests/ray/runner.py
+++ b/tests/ray/runner.py
@@ -8,53 +8,69 @@
 
 @pytest.mark.skipif(get_context().runner_config.name != "ray", reason="Needs to run on Ray runner")
 def test_active_plan_clean_up_df_show():
-    path = "tests/assets/parquet-data/mvp.parquet"
-    df = daft.read_parquet([path, path])
-    df.show()
-    runner = get_context().runner()
-    assert len(runner.active_plans()) == 0
+    with daft.execution_config_ctx(
+        scan_tasks_min_size_bytes=0,
+        scan_tasks_max_size_bytes=0,
+    ):
+        path = "tests/assets/parquet-data/mvp.parquet"
+        df = daft.read_parquet([path, path])
+        df.show()
+        runner = get_context().runner()
+        assert len(runner.active_plans()) == 0
 
 
 @pytest.mark.skipif(get_context().runner_config.name != "ray", reason="Needs to run on Ray runner")
 def test_active_plan_single_iter_partitions():
-    path = "tests/assets/parquet-data/mvp.parquet"
-    df = daft.read_parquet([path, path])
-    iter = df.iter_partitions()
-    next(iter)
-    runner = get_context().runner()
-    assert len(runner.active_plans()) == 1
-    del iter
-    assert len(runner.active_plans()) == 0
+    with daft.execution_config_ctx(
+        scan_tasks_min_size_bytes=0,
+        scan_tasks_max_size_bytes=0,
+    ):
+        path = "tests/assets/parquet-data/mvp.parquet"
+        df = daft.read_parquet([path, path])
+        iter = df.iter_partitions()
+        next(iter)
+        runner = get_context().runner()
+        assert len(runner.active_plans()) == 1
+        del iter
+        assert len(runner.active_plans()) == 0
 
 
 @pytest.mark.skipif(get_context().runner_config.name != "ray", reason="Needs to run on Ray runner")
 def test_active_plan_multiple_iter_partitions():
-    path = "tests/assets/parquet-data/mvp.parquet"
-    df = daft.read_parquet([path, path])
-    iter = df.iter_partitions()
-    next(iter)
-    runner = get_context().runner()
-    assert len(runner.active_plans()) == 1
+    with daft.execution_config_ctx(
+        scan_tasks_min_size_bytes=0,
+        scan_tasks_max_size_bytes=0,
+    ):
+        path = "tests/assets/parquet-data/mvp.parquet"
+        df = daft.read_parquet([path, path])
+        iter = df.iter_partitions()
+        next(iter)
+        runner = get_context().runner()
+        assert len(runner.active_plans()) == 1
 
-    df2 = daft.read_parquet([path, path])
-    iter2 = df2.iter_partitions()
-    next(iter2)
-    assert len(runner.active_plans()) == 2
+        df2 = daft.read_parquet([path, path])
+        iter2 = df2.iter_partitions()
+        next(iter2)
+        assert len(runner.active_plans()) == 2
 
-    del iter
-    assert len(runner.active_plans()) == 1
+        del iter
+        assert len(runner.active_plans()) == 1
 
-    del iter2
-    assert len(runner.active_plans()) == 0
+        del iter2
+        assert len(runner.active_plans()) == 0
 
 
 @pytest.mark.skipif(get_context().runner_config.name != "ray", reason="Needs to run on Ray runner")
 def test_active_plan_with_show_and_write_parquet(tmpdir):
-    df = daft.read_parquet("tests/assets/parquet-data/mvp.parquet")
-    df = df.into_partitions(8)
-    df = df.join(df, on="a")
-    df.show()
-    runner = get_context().runner()
-    assert len(runner.active_plans()) == 0
-    df.write_parquet(tmpdir.dirname)
-    assert len(runner.active_plans()) == 0
+    with daft.execution_config_ctx(
+        scan_tasks_min_size_bytes=0,
+        scan_tasks_max_size_bytes=0,
+    ):
+        df = daft.read_parquet("tests/assets/parquet-data/mvp.parquet")
+        df = df.into_partitions(8)
+        df = df.join(df, on="a")
+        df.show()
+        runner = get_context().runner()
+        assert len(runner.active_plans()) == 0
+        df.write_parquet(tmpdir.dirname)
+        assert len(runner.active_plans()) == 0
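
Usage sketch (illustrative, not part of the diff): the snippet below shows how the execution_config_ctx and planning_config_ctx context managers added in daft/context.py are expected to be used. The CSV path is hypothetical; the scan task options mirror the overrides used in the tests above, and the previous config is restored when each block exits, even on error.

import daft

# Temporarily disable ScanTask merging; the prior execution config is
# restored automatically when the block exits.
with daft.execution_config_ctx(scan_tasks_min_size_bytes=0, scan_tasks_max_size_bytes=0):
    df = daft.read_csv("data.csv")  # hypothetical path
    print(df.num_partitions())

# Planning config works the same way; passing no overrides simply
# snapshots the current planning config and restores it afterwards.
with daft.planning_config_ctx():
    df = daft.read_csv("data.csv")  # hypothetical path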