[FEAT] Add ability to set global IOConfig (#1710)
Planning configs:
* Config flags that are used in LogicalPlan building and construction
* Includes IOConfigs, and these are eagerly created and put onto the
logical plans
* These exist on a mutable global singleton and are not propagated to
remote workers, since all planning work should happen client-side

Execution configs:
* Config flags that are used in the execution chain
(`logical-to-physical-translation`, `physical-to-task-generation`,
`task-scheduling`, `task-execution`)
* These should be "frozen" per-unique-execution, and used unambiguously
throughout that execution
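
A minimal usage sketch of the two new global setters (values are illustrative, not defaults):

```python
import daft
from daft.io import IOConfig, S3Config

# Planning config: referenced eagerly while building logical plans client-side.
daft.set_planning_config(
    default_io_config=IOConfig(s3=S3Config(max_connections=8)),
)

# Execution config: copied into the runner and used consistently for each execution.
daft.set_execution_config(
    merge_scan_tasks_min_size_bytes=64 * 1024 * 1024,
    merge_scan_tasks_max_size_bytes=512 * 1024 * 1024,
    broadcast_join_size_bytes_threshold=10 * 1024 * 1024,
)
```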

---------

Co-authored-by: Jay Chia <[email protected]@users.noreply.github.com>
jaychia and Jay Chia authored Dec 12, 2023
1 parent 668399b commit 97753a2
Showing 24 changed files with 281 additions and 195 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

3 changes: 3 additions & 0 deletions daft/__init__.py
@@ -54,6 +54,7 @@ def get_build_type() -> str:
# Daft top-level imports
###

from daft.context import set_execution_config, set_planning_config
from daft.convert import (
from_arrow,
from_dask_dataframe,
@@ -94,4 +95,6 @@ def get_build_type() -> str:
"register_viz_hook",
"udf",
"ResourceRequest",
"set_planning_config",
"set_execution_config",
]
79 changes: 42 additions & 37 deletions daft/context.py
@@ -6,7 +6,7 @@
import warnings
from typing import TYPE_CHECKING, ClassVar

from daft.daft import PyDaftConfig
from daft.daft import IOConfig, PyDaftExecutionConfig, PyDaftPlanningConfig

if TYPE_CHECKING:
from daft.runners.runner import Runner
@@ -57,7 +57,13 @@ def _get_runner_config_from_env() -> _RunnerConfig:
class DaftContext:
"""Global context for the current Daft execution environment"""

daft_config: PyDaftConfig = PyDaftConfig()
# When a dataframe is executed, this config is copied into the Runner
# which then keeps track of a per-unique-execution-ID copy of the config, using it consistently throughout the execution
daft_execution_config: PyDaftExecutionConfig = PyDaftExecutionConfig()

# Non-execution calls (e.g. creation of a dataframe, logical plan building etc) directly reference values in this config
daft_planning_config: PyDaftPlanningConfig = PyDaftPlanningConfig()

runner_config: _RunnerConfig = dataclasses.field(default_factory=_get_runner_config_from_env)
disallow_set_runner: bool = False
_runner: Runner | None = None
@@ -71,7 +77,6 @@ def runner(self) -> Runner:

assert isinstance(self.runner_config, _RayRunnerConfig)
self._runner = RayRunner(
daft_config=self.daft_config,
address=self.runner_config.address,
max_task_backlog=self.runner_config.max_task_backlog,
)
@@ -92,7 +97,7 @@ def runner(self) -> Runner:
pass

assert isinstance(self.runner_config, _PyRunnerConfig)
self._runner = PyRunner(daft_config=self.daft_config, use_thread_pool=self.runner_config.use_thread_pool)
self._runner = PyRunner(use_thread_pool=self.runner_config.use_thread_pool)

else:
raise NotImplementedError(f"Runner config not implemented: {self.runner_config.name}")
@@ -115,25 +120,6 @@ def get_context() -> DaftContext:
return _DaftContext


def set_context(ctx: DaftContext) -> DaftContext:
global _DaftContext

pop_context()
_DaftContext = ctx

return _DaftContext


def pop_context() -> DaftContext:
"""Helper used in tests and test fixtures to clear the global runner and allow for re-setting of configs."""
global _DaftContext

old_daft_context = _DaftContext
_DaftContext = DaftContext()

return old_daft_context


def set_runner_ray(
address: str | None = None,
noop_if_initialized: bool = False,
@@ -191,16 +177,41 @@ def set_runner_py(use_thread_pool: bool | None = None) -> DaftContext:
return ctx


def set_config(
config: PyDaftConfig | None = None,
def set_planning_config(
config: PyDaftPlanningConfig | None = None,
default_io_config: IOConfig | None = None,
) -> DaftContext:
"""Globally sets varioous configuration parameters which control Daft plan construction behavior. These configuration values
are used when a Dataframe is being constructed (e.g. calls to create a Dataframe, or to build on an existing Dataframe)
Args:
config: A PyDaftPlanningConfig object to set the config to, before applying other kwargs. Defaults to None which indicates
that the old (current) config should be used.
default_io_config: A default IOConfig to use in the absence of one being explicitly passed into any Expression (e.g. `.url.download()`)
or Dataframe operation (e.g. `daft.read_parquet()`).
"""
# Replace values in the DaftPlanningConfig with user-specified overrides
ctx = get_context()
old_daft_planning_config = ctx.daft_planning_config if config is None else config
new_daft_planning_config = old_daft_planning_config.with_config_values(
default_io_config=default_io_config,
)

ctx.daft_planning_config = new_daft_planning_config
return ctx


def set_execution_config(
config: PyDaftExecutionConfig | None = None,
merge_scan_tasks_min_size_bytes: int | None = None,
merge_scan_tasks_max_size_bytes: int | None = None,
broadcast_join_size_bytes_threshold: int | None = None,
) -> DaftContext:
"""Globally sets various configuration parameters which control various aspects of Daft execution
"""Globally sets various configuration parameters which control various aspects of Daft execution. These configuration values
are used when a Dataframe is executed (e.g. calls to `.write_*`, `.collect()` or `.show()`)
Args:
config: A PyDaftConfig object to set the config to, before applying other kwargs. Defaults to None which indicates
config: A PyDaftExecutionConfig object to set the config to, before applying other kwargs. Defaults to None which indicates
that the old (current) config should be used.
merge_scan_tasks_min_size_bytes: Minimum size in bytes when merging ScanTasks when reading files from storage.
Increasing this value will make Daft perform more merging of files into a single partition before yielding,
@@ -211,20 +222,14 @@ def set_config(
broadcast_join_size_bytes_threshold: If one side of a join is smaller than this threshold, a broadcast join will be used.
Default is 10 MiB.
"""
# Replace values in the DaftExecutionConfig with user-specified overrides
ctx = get_context()
if ctx._runner is not None:
raise RuntimeError(
"Cannot call `set_config` after the runner has already been created. "
"Please call `set_config` before any dataframe creation or execution."
)

# Replace values in the DaftConfig with user-specified overrides
old_daft_config = ctx.daft_config if config is None else config
new_daft_config = old_daft_config.with_config_values(
old_daft_execution_config = ctx.daft_execution_config if config is None else config
new_daft_execution_config = old_daft_execution_config.with_config_values(
merge_scan_tasks_min_size_bytes=merge_scan_tasks_min_size_bytes,
merge_scan_tasks_max_size_bytes=merge_scan_tasks_max_size_bytes,
broadcast_join_size_bytes_threshold=broadcast_join_size_bytes_threshold,
)

ctx.daft_config = new_daft_config
ctx.daft_execution_config = new_daft_execution_config
return ctx
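
Both setters merge only the supplied keyword arguments into the current global config via `with_config_values`, so repeated calls compose. A brief sketch of that behavior (assuming the semantics shown above):

```python
import daft
from daft.context import get_context

daft.set_execution_config(broadcast_join_size_bytes_threshold=20 * 1024 * 1024)
daft.set_execution_config(merge_scan_tasks_max_size_bytes=256 * 1024 * 1024)

cfg = get_context().daft_execution_config
# The first override is preserved by the second call, since unspecified fields keep their values.
assert cfg.broadcast_join_size_bytes_threshold == 20 * 1024 * 1024
```
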
22 changes: 15 additions & 7 deletions daft/daft.pyi
@@ -494,7 +494,7 @@ class NativeStorageConfig:
multithreaded_io: bool
io_config: IOConfig

def __init__(self, multithreaded_io: bool, io_config: IOConfig | None = None): ...
def __init__(self, multithreaded_io: bool, io_config: IOConfig): ...

class PythonStorageConfig:
"""
@@ -503,7 +503,7 @@ class PythonStorageConfig:

io_config: IOConfig

def __init__(self, io_config: IOConfig | None = None): ...
def __init__(self, io_config: IOConfig): ...

class StorageConfig:
"""
@@ -836,7 +836,7 @@ class PyExpr:
def list_join(self, delimiter: PyExpr) -> PyExpr: ...
def list_lengths(self) -> PyExpr: ...
def url_download(
self, max_connections: int, raise_error_on_failure: bool, multi_thread: bool, config: IOConfig | None = None
self, max_connections: int, raise_error_on_failure: bool, multi_thread: bool, config: IOConfig
) -> PyExpr: ...

def eq(expr1: PyExpr, expr2: PyExpr) -> bool: ...
@@ -1094,23 +1094,31 @@ class LogicalPlanBuilder:
) -> LogicalPlanBuilder: ...
def schema(self) -> PySchema: ...
def optimize(self) -> LogicalPlanBuilder: ...
def to_physical_plan_scheduler(self, cfg: PyDaftConfig) -> PhysicalPlanScheduler: ...
def to_physical_plan_scheduler(self, cfg: PyDaftExecutionConfig) -> PhysicalPlanScheduler: ...
def repr_ascii(self, simple: bool) -> str: ...

class PyDaftConfig:
class PyDaftExecutionConfig:
def with_config_values(
self,
merge_scan_tasks_min_size_bytes: int | None = None,
merge_scan_tasks_max_size_bytes: int | None = None,
broadcast_join_size_bytes_threshold: int | None = None,
) -> PyDaftConfig: ...
) -> PyDaftExecutionConfig: ...
@property
def merge_scan_tasks_min_size_bytes(self): ...
def merge_scan_tasks_min_size_bytes(self) -> int: ...
@property
def merge_scan_tasks_max_size_bytes(self): ...
@property
def broadcast_join_size_bytes_threshold(self): ...

class PyDaftPlanningConfig:
def with_config_values(
self,
default_io_config: IOConfig | None = None,
) -> PyDaftPlanningConfig: ...
@property
def default_io_config(self) -> IOConfig: ...

def build_type() -> str: ...
def version() -> str: ...
def __getattr__(name) -> Any: ...
8 changes: 6 additions & 2 deletions daft/dataframe/dataframe.py
@@ -131,8 +131,8 @@ def explain(self, show_optimized: bool = False, simple=False) -> None:
print(builder.pretty_print(simple))

def num_partitions(self) -> int:
daft_config = get_context().daft_config
return self.__builder.to_physical_plan_scheduler(daft_config).num_partitions()
daft_execution_config = get_context().daft_execution_config
return self.__builder.to_physical_plan_scheduler(daft_execution_config).num_partitions()

@DataframePublicAPI
def schema(self) -> Schema:
@@ -319,6 +319,8 @@ def write_parquet(
.. NOTE::
This call is **blocking** and will execute the DataFrame when called
"""
io_config = get_context().daft_planning_config.default_io_config if io_config is None else io_config

cols: Optional[List[Expression]] = None
if partition_cols is not None:
cols = self.__column_input_to_expression(tuple(partition_cols))
@@ -365,6 +367,8 @@ def write_csv(
Returns:
DataFrame: The filenames that were written out as strings.
"""
io_config = get_context().daft_planning_config.default_io_config if io_config is None else io_config

cols: Optional[List[Expression]] = None
if partition_cols is not None:
cols = self.__column_input_to_expression(tuple(partition_cols))
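
The write path gets the same fallback: when no `io_config` is passed, the globally configured default is used, and an explicit argument still takes precedence. A sketch (output paths are hypothetical):

```python
import daft
from daft.io import IOConfig, S3Config

daft.set_planning_config(default_io_config=IOConfig(s3=S3Config(max_connections=16)))
df = daft.from_pydict({"x": [1, 2, 3]})

df.write_parquet("s3://my-bucket/out/")                         # falls back to the global default IOConfig
df.write_parquet("s3://my-bucket/out2/", io_config=IOConfig())  # explicit config wins
```
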
9 changes: 2 additions & 7 deletions daft/expressions/expressions.py
@@ -445,8 +445,6 @@ def download(
Returns:
Expression: a Binary expression which is the bytes contents of the URL, or None if an error occurred during download
"""
from daft.io import IOConfig, S3Config

if use_native_downloader:

raise_on_error = False
@@ -464,11 +462,8 @@
# This is because the max parallelism is actually `min(S3Config's max_connections, url_download's max_connections)` under the hood.
# However, default max_connections on S3Config is only 8, and even if we specify 32 here we are bottlenecked there.
# Therefore for S3 downloads, we override `max_connections` kwarg to have the intended effect.
io_config = (
IOConfig(s3=S3Config(max_connections=max_connections))
if io_config is None
else io_config.replace(s3=io_config.s3.replace(max_connections=max_connections))
)
io_config = context.get_context().daft_planning_config.default_io_config if io_config is None else io_config
io_config = io_config.replace(s3=io_config.s3.replace(max_connections=max_connections))

using_ray_runner = context.get_context().is_ray_runner
return Expression._from_pyexpr(
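
At the user level this means `.url.download()` also falls back to the global default and then overrides only `max_connections` on its S3 config, per the logic above. A hedged sketch (URLs are made up):

```python
import daft
from daft.io import IOConfig, S3Config

daft.set_planning_config(default_io_config=IOConfig(s3=S3Config(max_connections=8)))

df = daft.from_pydict({"urls": ["s3://bucket/a.bin", "s3://bucket/b.bin"]})
# The global default IOConfig is used, with max_connections=32 applied on top for this download.
df = df.with_column("data", df["urls"].url.download(max_connections=32, use_native_downloader=True))
```
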
3 changes: 3 additions & 0 deletions daft/io/_csv.py
@@ -2,6 +2,7 @@

from typing import Dict, List, Optional, Union

from daft import context
from daft.api_annotations import PublicAPI
from daft.daft import (
CsvSourceConfig,
@@ -66,6 +67,8 @@ def read_csv(
if isinstance(path, list) and len(path) == 0:
raise ValueError(f"Cannot read DataFrame from from empty list of CSV filepaths")

io_config = context.get_context().daft_planning_config.default_io_config if io_config is None else io_config

csv_config = CsvSourceConfig(
delimiter=delimiter,
has_headers=has_headers,
8 changes: 6 additions & 2 deletions daft/io/_iceberg.py
@@ -72,8 +72,12 @@ def read_iceberg(
) -> DataFrame:
from daft.iceberg.iceberg_scan import IcebergScanOperator

if io_config is None:
io_config = _convert_iceberg_file_io_properties_to_io_config(pyiceberg_table.io.properties)
io_config = (
_convert_iceberg_file_io_properties_to_io_config(pyiceberg_table.io.properties)
if io_config is None
else io_config
)
io_config = context.get_context().daft_planning_config.default_io_config if io_config is None else io_config

multithreaded_io = not context.get_context().is_ray_runner
storage_config = StorageConfig.native(NativeStorageConfig(multithreaded_io, io_config))
3 changes: 3 additions & 0 deletions daft/io/_json.py
@@ -2,6 +2,7 @@

from typing import Dict, List, Optional, Union

from daft import context
from daft.api_annotations import PublicAPI
from daft.daft import (
FileFormatConfig,
@@ -47,6 +48,8 @@ def read_json(
if isinstance(path, list) and len(path) == 0:
raise ValueError(f"Cannot read DataFrame from from empty list of JSON filepaths")

io_config = context.get_context().daft_planning_config.default_io_config if io_config is None else io_config

json_config = JsonSourceConfig(_buffer_size, _chunk_size)
file_format_config = FileFormatConfig.from_json_config(json_config)
if use_native_downloader:
Expand Down
1 change: 1 addition & 0 deletions daft/io/_parquet.py
@@ -46,6 +46,7 @@ def read_parquet(
Returns:
DataFrame: parsed DataFrame
"""
io_config = context.get_context().daft_planning_config.default_io_config if io_config is None else io_config

if isinstance(path, list) and len(path) == 0:
raise ValueError(f"Cannot read DataFrame from from empty list of Parquet filepaths")
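
Each reader now resolves its `io_config` the same way, so a globally configured IOConfig is picked up when none is passed explicitly. A sketch (bucket and path are hypothetical):

```python
import daft
from daft.io import IOConfig, S3Config

daft.set_planning_config(default_io_config=IOConfig(s3=S3Config(max_connections=16)))

# No io_config passed: read_csv/read_json/read_parquet fall back to the global default.
df = daft.read_parquet("s3://my-bucket/dataset/*.parquet")
```
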
8 changes: 4 additions & 4 deletions daft/logical/builder.py
@@ -14,7 +14,7 @@
from daft.daft import LogicalPlanBuilder as _LogicalPlanBuilder
from daft.daft import (
PartitionScheme,
PyDaftConfig,
PyDaftExecutionConfig,
ResourceRequest,
ScanOperatorHandle,
StorageConfig,
@@ -35,7 +35,7 @@ class LogicalPlanBuilder:
def __init__(self, builder: _LogicalPlanBuilder) -> None:
self._builder = builder

def to_physical_plan_scheduler(self, daft_config: PyDaftConfig) -> PhysicalPlanScheduler:
def to_physical_plan_scheduler(self, daft_execution_config: PyDaftExecutionConfig) -> PhysicalPlanScheduler:
"""
Convert the underlying logical plan to a physical plan scheduler, which is
used to generate executable tasks for the physical plan.
@@ -44,7 +44,7 @@ def to_physical_plan_scheduler(self, daft_config: PyDaftConfig) -> PhysicalPlanS
"""
from daft.plan_scheduler.physical_plan_scheduler import PhysicalPlanScheduler

return PhysicalPlanScheduler(self._builder.to_physical_plan_scheduler(daft_config))
return PhysicalPlanScheduler(self._builder.to_physical_plan_scheduler(daft_execution_config))

def schema(self) -> Schema:
"""
Expand Down Expand Up @@ -207,9 +207,9 @@ def write_tabular(
self,
root_dir: str | pathlib.Path,
file_format: FileFormat,
io_config: IOConfig,
partition_cols: list[Expression] | None = None,
compression: str | None = None,
io_config: IOConfig | None = None,
) -> LogicalPlanBuilder:
if file_format != FileFormat.Csv and file_format != FileFormat.Parquet:
raise ValueError(f"Writing is only supported for Parquet and CSV file formats, but got: {file_format}")