From 82271d5183b1b5c6869965ef5db8cf09cfe4c295 Mon Sep 17 00:00:00 2001 From: Jay Chia Date: Tue, 5 Dec 2023 10:42:32 -0800 Subject: [PATCH] Add tests for merge_scan_task --- tests/conftest.py | 10 +++++ tests/io/test_merge_scan_tasks.py | 69 +++++++++++++++++++++++++++++++ 2 files changed, 79 insertions(+) create mode 100644 tests/io/test_merge_scan_tasks.py diff --git a/tests/conftest.py b/tests/conftest.py index dfb396c053..805b33e288 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,6 +10,16 @@ from daft.table import MicroPartition +@pytest.fixture(scope="session", autouse=True) +def set_configs(): + """Sets global Daft config for testing""" + daft.context.set_config( + # Disables merging of ScanTasks + merge_scan_tasks_min_size_bytes=0, + merge_scan_tasks_max_size_bytes=0, + ) + + def pytest_configure(config): config.addinivalue_line( "markers", "integration: mark test as an integration test that runs with external dependencies" diff --git a/tests/io/test_merge_scan_tasks.py b/tests/io/test_merge_scan_tasks.py new file mode 100644 index 0000000000..8ff6b47ad0 --- /dev/null +++ b/tests/io/test_merge_scan_tasks.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +import contextlib + +import pytest + +import daft + + +@contextlib.contextmanager +def override_merge_scan_tasks_configs(merge_scan_tasks_min_size_bytes: int, merge_scan_tasks_max_size_bytes: int): + config = daft.context.get_context().daft_config + original_merge_scan_tasks_min_size_bytes = config.merge_scan_tasks_min_size_bytes + original_merge_scan_tasks_max_size_bytes = config.merge_scan_tasks_max_size_bytes + + try: + daft.context.set_config( + merge_scan_tasks_min_size_bytes=merge_scan_tasks_min_size_bytes, + merge_scan_tasks_max_size_bytes=merge_scan_tasks_max_size_bytes, + ) + yield + finally: + daft.context.set_config( + merge_scan_tasks_min_size_bytes=original_merge_scan_tasks_min_size_bytes, + merge_scan_tasks_max_size_bytes=original_merge_scan_tasks_max_size_bytes, + ) + + +@pytest.fixture(scope="function") +def csv_files(tmpdir): + """Writes 3 CSV files, each of 10 bytes in size, to tmpdir and yield tmpdir""" + + for i in range(3): + path = tmpdir / f"file.{i}.csv" + path.write_text("a,b,c\n1,2,", "utf8") # 10 bytes + + return tmpdir + + +def test_merge_scan_task_exceed_max(csv_files): + with override_merge_scan_tasks_configs(0, 0): + df = daft.read_csv(str(csv_files)) + assert ( + df.num_partitions() == 3 + ), "Should have 3 partitions since all merges are more than the maximum (>0 bytes)" + + +def test_merge_scan_task_below_max(csv_files): + with override_merge_scan_tasks_configs(1, 20): + df = daft.read_csv(str(csv_files)) + assert ( + df.num_partitions() == 2 + ), "Should have 2 partitions [(CSV1, CSV2), (CSV3)] since the second merge is too large (>20 bytes)" + + +def test_merge_scan_task_above_min(csv_files): + with override_merge_scan_tasks_configs(0, 40): + df = daft.read_csv(str(csv_files)) + assert ( + df.num_partitions() == 2 + ), "Should have 2 partitions [(CSV1, CSV2), (CSV3)] since the first merge is above the minimum (>0 bytes)" + + +def test_merge_scan_task_below_min(csv_files): + with override_merge_scan_tasks_configs(35, 40): + df = daft.read_csv(str(csv_files)) + assert ( + df.num_partitions() == 1 + ), "Should have 1 partition [(CSV1, CSV2, CSV3)] since both merges are below the minimum and maximum"