Skip to content

Commit

Permalink
Add tests for merge_scan_task
Browse files Browse the repository at this point in the history
  • Loading branch information
Jay Chia committed Dec 5, 2023
1 parent 2b04c99 commit 4ad1edf
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 0 deletions.
10 changes: 10 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,16 @@
from daft.table import MicroPartition


@pytest.fixture(scope="session", autouse=True)
def set_configs():
    """Applies global Daft config overrides once for the whole test session.

    Both merge thresholds are zeroed out so that no ScanTask ever qualifies
    for merging, keeping partition counts deterministic across tests.
    """
    # Zero for both bounds disables ScanTask merging entirely.
    disable_merging = {
        "merge_scan_tasks_min_size_bytes": 0,
        "merge_scan_tasks_max_size_bytes": 0,
    }
    daft.context.set_config(**disable_merging)


def pytest_configure(config):
config.addinivalue_line(
"markers", "integration: mark test as an integration test that runs with external dependencies"
Expand Down
69 changes: 69 additions & 0 deletions tests/io/test_merge_scan_tasks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from __future__ import annotations

import contextlib

import pytest

import daft


@contextlib.contextmanager
def override_merge_scan_tasks_configs(merge_scan_tasks_min_size_bytes: int, merge_scan_tasks_max_size_bytes: int):
    """Temporarily overrides the ScanTask merge thresholds in the global Daft config.

    On exit (normal or via exception) the previous threshold values are restored,
    so tests using this context manager do not leak config changes.
    """
    current_config = daft.context.get_context().daft_config
    # Snapshot the values we are about to clobber so they can be restored.
    saved_min, saved_max = (
        current_config.merge_scan_tasks_min_size_bytes,
        current_config.merge_scan_tasks_max_size_bytes,
    )

    try:
        daft.context.set_config(
            merge_scan_tasks_min_size_bytes=merge_scan_tasks_min_size_bytes,
            merge_scan_tasks_max_size_bytes=merge_scan_tasks_max_size_bytes,
        )
        yield
    finally:
        # Always restore, even if the override or the wrapped body raised.
        daft.context.set_config(
            merge_scan_tasks_min_size_bytes=saved_min,
            merge_scan_tasks_max_size_bytes=saved_max,
        )


@pytest.fixture(scope="function")
def csv_files(tmpdir):
    """Writes 3 CSV files, each 10 bytes in size, into tmpdir and returns tmpdir."""

    for file_idx in range(3):
        csv_path = tmpdir.join(f"file.{file_idx}.csv")
        csv_path.write_text("a,b,c\n1,2,", "utf8")  # exactly 10 bytes on disk

    return tmpdir


def test_merge_scan_task_exceed_max(csv_files):
    # Any merged pair (20 bytes) would exceed the 10-byte max, so nothing merges.
    with override_merge_scan_tasks_configs(1, 10):
        df = daft.read_csv(str(csv_files))
        num_partitions = df.num_partitions()
    assert (
        num_partitions == 3
    ), "Should have 3 partitions since all merges are more than the maximum (>10 bytes)"


def test_merge_scan_task_below_max(csv_files):
    # First pair merges to exactly 20 bytes; adding the third (30 bytes) exceeds the max.
    with override_merge_scan_tasks_configs(1, 20):
        df = daft.read_csv(str(csv_files))
        num_partitions = df.num_partitions()
    assert (
        num_partitions == 2
    ), "Should have 2 partitions [(CSV1, CSV2), (CSV3)] since the second merge is too large (>20 bytes)"


def test_merge_scan_task_above_min(csv_files):
    # Merging the first two files yields 20 bytes, already past the 12-byte min,
    # so the third file stays in its own partition.
    with override_merge_scan_tasks_configs(12, 40):
        df = daft.read_csv(str(csv_files))
        num_partitions = df.num_partitions()
    assert (
        num_partitions == 2
    ), "Should have 2 partitions [(CSV1, CSV2), (CSV3)] since the first merge is above the minimum (>12 bytes)"


def test_merge_scan_task_below_min(csv_files):
    # All three files together (30 bytes) stay under both the 35-byte min and
    # the 40-byte max, so everything collapses into a single partition.
    with override_merge_scan_tasks_configs(35, 40):
        df = daft.read_csv(str(csv_files))
        num_partitions = df.num_partitions()
    assert (
        num_partitions == 1
    ), "Should have 1 partition [(CSV1, CSV2, CSV3)] since both merges are below the minimum and maximum"

0 comments on commit 4ad1edf

Please sign in to comment.