[BUG] Add a with_execution/planning_config context manager and fix tests for splitting of parquet (#2713)

1. Adds a context manager to perform overrides of the configs that will reset themselves afterwards (see the usage sketch below)
2. Adds some fixes for splitting of Parquet files
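
A minimal usage sketch of the new context managers (the keyword arguments shown mirror ones exercised in the diffs below; the values and the default-constructed IOConfig are purely illustrative):

import daft
from daft.io import IOConfig

# Execution options are overridden only inside the block; the previous
# execution config is restored on exit, even if the body raises.
with daft.execution_config_ctx(scan_tasks_min_size_bytes=0, scan_tasks_max_size_bytes=0):
    df = daft.from_pydict({"a": [1, 2, 3]})
    df.collect()

# Planning-time options (e.g. the default IOConfig used when building plans)
# get the same treatment via planning_config_ctx.
with daft.planning_config_ctx(default_io_config=IOConfig()):
    df = daft.from_pydict({"a": [1, 2, 3]})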

---------

Co-authored-by: Jay Chia <[email protected]@users.noreply.github.com>
Co-authored-by: Colin Ho <[email protected]>
3 people authored Aug 23, 2024
1 parent a18b30a commit 959f8f6
Showing 17 changed files with 215 additions and 180 deletions.
4 changes: 3 additions & 1 deletion daft/__init__.py
@@ -60,7 +60,7 @@ def refresh_logger() -> None:
# Daft top-level imports
###

from daft.context import set_execution_config, set_planning_config
from daft.context import set_execution_config, set_planning_config, execution_config_ctx, planning_config_ctx
from daft.convert import (
from_arrow,
from_dask_dataframe,
@@ -128,6 +128,8 @@ def refresh_logger() -> None:
"Schema",
"set_planning_config",
"set_execution_config",
"planning_config_ctx",
"execution_config_ctx",
"sql",
"sql_expr",
"to_struct",
23 changes: 23 additions & 0 deletions daft/context.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import contextlib
import dataclasses
import logging
import os
@@ -244,6 +245,17 @@ def set_runner_py(use_thread_pool: bool | None = None) -> DaftContext:
return ctx


@contextlib.contextmanager
def planning_config_ctx(**kwargs):
"""Context manager that wraps set_planning_config to reset the config to its original setting afternwards"""
original_config = get_context().daft_planning_config
try:
set_planning_config(**kwargs)
yield
finally:
set_planning_config(config=original_config)


def set_planning_config(
config: PyDaftPlanningConfig | None = None,
default_io_config: IOConfig | None = None,
@@ -269,6 +281,17 @@ def set_planning_config(
return ctx


@contextlib.contextmanager
def execution_config_ctx(**kwargs):
"""Context manager that wraps set_execution_config to reset the config to its original setting afternwards"""
original_config = get_context().daft_execution_config
try:
set_execution_config(**kwargs)
yield
finally:
set_execution_config(config=original_config)


def set_execution_config(
config: PyDaftExecutionConfig | None = None,
scan_tasks_min_size_bytes: int | None = None,
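For readers unfamiliar with the idiom, here is a standalone sketch of the same snapshot-and-restore pattern used by planning_config_ctx and execution_config_ctx above (the set_config/config_ctx names and the dict-backed config are hypothetical, not Daft APIs); the finally block guarantees the original settings come back even if the body raises:

import contextlib

_config = {"scan_tasks_min_size_bytes": 96 * 1024 * 1024}

def set_config(**overrides):
    _config.update(overrides)

@contextlib.contextmanager
def config_ctx(**overrides):
    original = dict(_config)      # snapshot the current settings
    try:
        set_config(**overrides)   # apply the overrides
        yield
    finally:
        _config.clear()
        _config.update(original)  # restore the snapshot on exit

with config_ctx(scan_tasks_min_size_bytes=0):
    assert _config["scan_tasks_min_size_bytes"] == 0
assert _config["scan_tasks_min_size_bytes"] == 96 * 1024 * 1024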
20 changes: 11 additions & 9 deletions daft/runners/ray_runner.py
@@ -10,7 +10,7 @@

import pyarrow as pa

from daft.context import get_context, set_execution_config
from daft.context import execution_config_ctx, get_context
from daft.logical.builder import LogicalPlanBuilder
from daft.plan_scheduler import PhysicalPlanScheduler
from daft.runners.progress_bar import ProgressBar
@@ -350,8 +350,10 @@ def single_partition_pipeline(
partial_metadatas: list[PartitionMetadata],
*inputs: MicroPartition,
) -> list[list[PartitionMetadata] | MicroPartition]:
set_execution_config(daft_execution_config)
return build_partitions(instruction_stack, partial_metadatas, *inputs)
with execution_config_ctx(
config=daft_execution_config,
):
return build_partitions(instruction_stack, partial_metadatas, *inputs)


@ray.remote
@@ -361,8 +363,8 @@ def fanout_pipeline(
partial_metadatas: list[PartitionMetadata],
*inputs: MicroPartition,
) -> list[list[PartitionMetadata] | MicroPartition]:
set_execution_config(daft_execution_config)
return build_partitions(instruction_stack, partial_metadatas, *inputs)
with execution_config_ctx(config=daft_execution_config):
return build_partitions(instruction_stack, partial_metadatas, *inputs)


@ray.remote(scheduling_strategy="SPREAD")
@@ -374,8 +376,8 @@ def reduce_pipeline(
) -> list[list[PartitionMetadata] | MicroPartition]:
import ray

set_execution_config(daft_execution_config)
return build_partitions(instruction_stack, partial_metadatas, *ray.get(inputs))
with execution_config_ctx(config=daft_execution_config):
return build_partitions(instruction_stack, partial_metadatas, *ray.get(inputs))


@ray.remote(scheduling_strategy="SPREAD")
@@ -387,8 +389,8 @@ def reduce_and_fanout(
) -> list[list[PartitionMetadata] | MicroPartition]:
import ray

set_execution_config(daft_execution_config)
return build_partitions(instruction_stack, partial_metadatas, *ray.get(inputs))
with execution_config_ctx(config=daft_execution_config):
return build_partitions(instruction_stack, partial_metadatas, *ray.get(inputs))


@ray.remote
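A hedged sketch of the pattern the Ray pipelines above now follow (assumes Ray is installed and initialized; the task name and payload are hypothetical): the driver ships its execution config alongside the task arguments, and the worker applies it only for the duration of that task instead of mutating worker-global state.

import ray

from daft.context import execution_config_ctx

@ray.remote
def task_with_driver_config(daft_execution_config, payload):
    # Apply the driver's config for this task only; the worker's previous
    # execution config is restored when the task returns.
    with execution_config_ctx(config=daft_execution_config):
        return payload  # real pipelines call build_partitions(...) here

# Driver side (illustrative):
#   ray.init()
#   cfg = daft.context.get_context().daft_execution_config
#   result = ray.get(task_with_driver_config.remote(cfg, 42))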
2 changes: 2 additions & 0 deletions docs/source/api_docs/configs.rst
@@ -23,7 +23,9 @@ Configure Daft in various ways during execution.
:toctree: doc_gen/configuration_functions

daft.set_planning_config
daft.planning_config_ctx
daft.set_execution_config
daft.execution_config_ctx

I/O Configurations
******************
2 changes: 1 addition & 1 deletion src/daft-parquet/src/file.rs
@@ -120,7 +120,7 @@ pub(crate) fn build_row_ranges(
let mut rows_to_add: i64 = limit.unwrap_or(i64::MAX);
for i in row_groups {
let i = *i as usize;
if !(0..metadata.row_groups.len()).contains(&i) {
if !metadata.row_groups.keys().any(|x| *x == i) {
return Err(super::Error::ParquetRowGroupOutOfIndex {
path: uri.to_string(),
row_group: i as i64,
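The change above replaces a contiguous 0..len bounds check with a key lookup: once Parquet metadata is split per scan task, row_groups is keyed by the original row-group indices, so the valid indices are no longer 0..len. A hypothetical Python analogue of the check:

# Split metadata keeps the original row-group indices as keys, so index 0
# may legitimately be absent even though the map is non-empty.
row_groups = {3: {"num_rows": 1000}, 7: {"num_rows": 250}}

def check_row_group_index(i: int) -> None:
    if i not in row_groups:  # membership on keys, not `0 <= i < len(row_groups)`
        raise IndexError(f"Parquet row group {i} is out of index")

check_row_group_index(7)   # fine
# check_row_group_index(0) would raise, even though len(row_groups) == 2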
8 changes: 1 addition & 7 deletions src/daft-parquet/src/read.rs
@@ -201,12 +201,6 @@ async fn read_parquet_single(
))
}?;

let rows_per_row_groups = metadata
.row_groups
.values()
.map(|m| m.num_rows())
.collect::<Vec<_>>();

let metadata_num_rows = metadata.num_rows;
let metadata_num_columns = metadata.schema().fields().len();

@@ -265,7 +259,7 @@
} else if let Some(row_groups) = row_groups {
let expected_rows = row_groups
.iter()
.map(|i| rows_per_row_groups.get(*i as usize).unwrap())
.map(|i| metadata.row_groups.get(&(*i as usize)).unwrap().num_rows())
.sum::<usize>()
- num_deleted_rows;
if expected_rows != table.len() {
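Likewise, the expected-row check above now reads num_rows directly from the keyed row-group map instead of a positionally indexed Vec. A small illustration with made-up numbers:

row_groups = {3: {"num_rows": 1000}, 7: {"num_rows": 250}}  # keyed by original index
selected_row_groups = [3, 7]
num_deleted_rows = 10

expected_rows = sum(row_groups[i]["num_rows"] for i in selected_row_groups) - num_deleted_rows
assert expected_rows == 1240  # compared against the length of the materialized table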
16 changes: 9 additions & 7 deletions tests/benchmarks/test_local_tpch.py
@@ -33,8 +33,8 @@ def gen_tpch(request):
csv_files_location = data_generation.gen_csv_files(TPCH_DBGEN_DIR, num_parts, SCALE_FACTOR)

# Disable native executor to generate parquet files, remove once native executor supports writing parquet files
daft.context.set_execution_config(enable_native_executor=False)
parquet_files_location = data_generation.gen_parquet(csv_files_location)
with daft.context.execution_config_ctx(enable_native_executor=False):
parquet_files_location = data_generation.gen_parquet(csv_files_location)

in_memory_tables = {}
for tbl_name in data_generation.SCHEMA.keys():
@@ -106,14 +106,16 @@ def test_tpch(tmp_path, check_answer, get_df, benchmark_with_memray, engine, q):

def f():
if engine == "native":
daft.context.set_execution_config(enable_native_executor=True)
ctx = daft.context.execution_config_ctx(enable_native_executor=True)
elif engine == "python":
daft.context.set_execution_config(enable_native_executor=False)
ctx = daft.context.execution_config_ctx(enable_native_executor=False)
else:
raise ValueError(f"{engine} unsupported")
question = getattr(answers, f"q{q}")
daft_df = question(get_df)
return daft_df.to_arrow()

with ctx:
question = getattr(answers, f"q{q}")
daft_df = question(get_df)
return daft_df.to_arrow()

benchmark_group = f"q{q}-parts-{num_parts}"
daft_pd_df = benchmark_with_memray(f, benchmark_group).to_pandas()
25 changes: 7 additions & 18 deletions tests/conftest.py
@@ -21,16 +21,6 @@ def pytest_addoption(parser):
)


@pytest.fixture(scope="session", autouse=True)
def set_execution_configs():
"""Sets global Daft config for testing"""
daft.set_execution_config(
# Disables merging of ScanTasks
scan_tasks_min_size_bytes=0,
scan_tasks_max_size_bytes=0,
)


def pytest_configure(config):
config.addinivalue_line(
"markers", "integration: mark test as an integration test that runs with external dependencies"
@@ -82,14 +72,8 @@ def join_strategy(request):
if request.param != "sort_merge_aligned_boundaries":
yield request.param
else:
old_execution_config = daft.context.get_context().daft_execution_config
try:
daft.set_execution_config(
sort_merge_join_sort_with_aligned_boundaries=True,
)
with daft.execution_config_ctx(sort_merge_join_sort_with_aligned_boundaries=True):
yield "sort_merge"
finally:
daft.set_execution_config(old_execution_config)


@pytest.fixture(scope="function")
@@ -135,7 +119,12 @@ def _make_df(
else:
raise NotImplementedError(f"make_df not implemented for: {variant}")

yield _make_df
with daft.execution_config_ctx(
# Disables merging of ScanTasks of Parquet when reading small Parquet files
scan_tasks_min_size_bytes=0,
scan_tasks_max_size_bytes=0,
):
yield _make_df


def assert_df_equals(
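The context manager also composes cleanly with pytest yield fixtures, which is the shape most of the test changes in this commit take. A minimal sketch (the fixture name is hypothetical):

import pytest

import daft

@pytest.fixture
def no_scan_task_merging():
    # The override is active for every test that requests this fixture and is
    # reverted automatically once the test finishes.
    with daft.execution_config_ctx(scan_tasks_min_size_bytes=0, scan_tasks_max_size_bytes=0):
        yield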
6 changes: 1 addition & 5 deletions tests/cookbook/test_write.py
@@ -160,12 +160,8 @@ def test_parquet_write_with_null_values(tmp_path):

@pytest.fixture()
def smaller_parquet_target_filesize():
old_execution_config = daft.context.get_context().daft_execution_config
try:
daft.set_execution_config(parquet_target_filesize=1024)
with daft.execution_config_ctx(parquet_target_filesize=1024):
yield
finally:
daft.set_execution_config(old_execution_config)


@pytest.mark.skipif(
30 changes: 15 additions & 15 deletions tests/dataframe/test_aggregations.py
@@ -363,26 +363,26 @@ def test_groupby_result_partitions_smaller_than_input(shuffle_aggregation_defaul
if shuffle_aggregation_default_partitions is None:
min_partitions = get_context().daft_execution_config.shuffle_aggregation_default_partitions
else:
daft.set_execution_config(shuffle_aggregation_default_partitions=shuffle_aggregation_default_partitions)
min_partitions = shuffle_aggregation_default_partitions

for partition_size in [1, min_partitions, min_partitions + 1]:
df = daft.from_pydict(
{"group": [i for i in range(min_partitions + 1)], "value": [i for i in range(min_partitions + 1)]}
)
df = df.into_partitions(partition_size)
with daft.execution_config_ctx(shuffle_aggregation_default_partitions=shuffle_aggregation_default_partitions):
for partition_size in [1, min_partitions, min_partitions + 1]:
df = daft.from_pydict(
{"group": [i for i in range(min_partitions + 1)], "value": [i for i in range(min_partitions + 1)]}
)
df = df.into_partitions(partition_size)

df = df.groupby(col("group")).agg(
[
col("value").sum().alias("sum"),
col("value").mean().alias("mean"),
col("value").min().alias("min"),
]
)
df = df.groupby(col("group")).agg(
[
col("value").sum().alias("sum"),
col("value").mean().alias("mean"),
col("value").min().alias("min"),
]
)

df = df.collect()
df = df.collect()

assert df.num_partitions() == min(min_partitions, partition_size)
assert df.num_partitions() == min(min_partitions, partition_size)


@pytest.mark.parametrize("repartition_nparts", [1, 2, 4])
13 changes: 4 additions & 9 deletions tests/integration/io/parquet/test_reads_public_data.py
@@ -217,16 +217,11 @@ def set_split_config(request):
max_size = 0 if request.param[0] else 384 * 1024 * 1024
min_size = 0 if request.param[1] else 96 * 1024 * 1024

old_execution_config = daft.context.get_context().daft_execution_config

try:
daft.set_execution_config(
scan_tasks_max_size_bytes=max_size,
scan_tasks_min_size_bytes=min_size,
)
with daft.execution_config_ctx(
scan_tasks_max_size_bytes=max_size,
scan_tasks_min_size_bytes=min_size,
):
yield
finally:
daft.set_execution_config(old_execution_config)


def read_parquet_with_pyarrow(path) -> pa.Table:
54 changes: 32 additions & 22 deletions tests/integration/sql/test_sql.py
@@ -8,7 +8,6 @@
import sqlalchemy

import daft
from daft.context import set_execution_config
from tests.conftest import assert_df_equals
from tests.integration.sql.conftest import TEST_TABLE_NAME

@@ -36,11 +35,14 @@ def test_sql_partitioned_read(test_db, num_partitions, pdf) -> None:
def test_sql_partitioned_read(test_db, num_partitions, pdf) -> None:
row_size_bytes = daft.from_pandas(pdf).schema().estimate_row_size_bytes()
num_rows_per_partition = len(pdf) / num_partitions
set_execution_config(read_sql_partition_size_bytes=math.ceil(row_size_bytes * num_rows_per_partition))

df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db, partition_col="id")
assert df.num_partitions() == num_partitions
assert_df_equals(df.to_pandas(coerce_temporal_nanoseconds=True), pdf, sort_key="id")
with daft.execution_config_ctx(
read_sql_partition_size_bytes=math.ceil(row_size_bytes * num_rows_per_partition),
scan_tasks_min_size_bytes=0,
scan_tasks_max_size_bytes=0,
):
df = daft.read_sql(f"SELECT * FROM {TEST_TABLE_NAME}", test_db, partition_col="id")
assert df.num_partitions() == num_partitions
assert_df_equals(df.to_pandas(coerce_temporal_nanoseconds=True), pdf, sort_key="id")


@pytest.mark.integration()
@@ -49,27 +51,35 @@ def test_sql_partitioned_read_with_custom_num_partitions_and_partition_col(
def test_sql_partitioned_read_with_custom_num_partitions_and_partition_col(
test_db, num_partitions, partition_col, pdf
) -> None:
df = daft.read_sql(
f"SELECT * FROM {TEST_TABLE_NAME}",
test_db,
partition_col=partition_col,
num_partitions=num_partitions,
)
assert df.num_partitions() == num_partitions
assert_df_equals(df.to_pandas(coerce_temporal_nanoseconds=True), pdf, sort_key="id")
with daft.execution_config_ctx(
scan_tasks_min_size_bytes=0,
scan_tasks_max_size_bytes=0,
):
df = daft.read_sql(
f"SELECT * FROM {TEST_TABLE_NAME}",
test_db,
partition_col=partition_col,
num_partitions=num_partitions,
)
assert df.num_partitions() == num_partitions
assert_df_equals(df.to_pandas(coerce_temporal_nanoseconds=True), pdf, sort_key="id")


@pytest.mark.integration()
@pytest.mark.parametrize("num_partitions", [1, 2, 3, 4])
def test_sql_partitioned_read_with_non_uniformly_distributed_column(test_db, num_partitions, pdf) -> None:
df = daft.read_sql(
f"SELECT * FROM {TEST_TABLE_NAME}",
test_db,
partition_col="non_uniformly_distributed_col",
num_partitions=num_partitions,
)
assert df.num_partitions() == num_partitions
assert_df_equals(df.to_pandas(coerce_temporal_nanoseconds=True), pdf, sort_key="id")
with daft.execution_config_ctx(
scan_tasks_min_size_bytes=0,
scan_tasks_max_size_bytes=0,
):
df = daft.read_sql(
f"SELECT * FROM {TEST_TABLE_NAME}",
test_db,
partition_col="non_uniformly_distributed_col",
num_partitions=num_partitions,
)
assert df.num_partitions() == num_partitions
assert_df_equals(df.to_pandas(coerce_temporal_nanoseconds=True), pdf, sort_key="id")


@pytest.mark.integration()
Expand Down
Loading

0 comments on commit 959f8f6

Please sign in to comment.