diff --git a/Cargo.lock b/Cargo.lock index 7b1912399b..ef09b60856 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -907,6 +907,7 @@ name = "common-daft-config" version = "0.1.10" dependencies = [ "bincode", + "common-io-config", "lazy_static", "pyo3", "serde", diff --git a/daft/__init__.py b/daft/__init__.py index baffd00a84..bdbb9f0fb5 100644 --- a/daft/__init__.py +++ b/daft/__init__.py @@ -54,6 +54,7 @@ def get_build_type() -> str: # Daft top-level imports ### +from daft.context import set_execution_config, set_planning_config from daft.convert import ( from_arrow, from_dask_dataframe, @@ -94,4 +95,6 @@ def get_build_type() -> str: "register_viz_hook", "udf", "ResourceRequest", + "set_planning_config", + "set_execution_config", ] diff --git a/daft/context.py b/daft/context.py index 8b1b16fe7a..d691511a47 100644 --- a/daft/context.py +++ b/daft/context.py @@ -6,7 +6,7 @@ import warnings from typing import TYPE_CHECKING, ClassVar -from daft.daft import PyDaftConfig +from daft.daft import IOConfig, PyDaftExecutionConfig, PyDaftPlanningConfig if TYPE_CHECKING: from daft.runners.runner import Runner @@ -57,7 +57,13 @@ def _get_runner_config_from_env() -> _RunnerConfig: class DaftContext: """Global context for the current Daft execution environment""" - daft_config: PyDaftConfig = PyDaftConfig() + # When a dataframe is executed, this config is copied into the Runner + # which then keeps track of a per-unique-execution-ID copy of the config, using it consistently throughout the execution + daft_execution_config: PyDaftExecutionConfig = PyDaftExecutionConfig() + + # Non-execution calls (e.g. creation of a dataframe, logical plan building etc) directly reference values in this config + daft_planning_config: PyDaftPlanningConfig = PyDaftPlanningConfig() + runner_config: _RunnerConfig = dataclasses.field(default_factory=_get_runner_config_from_env) disallow_set_runner: bool = False _runner: Runner | None = None @@ -71,7 +77,6 @@ def runner(self) -> Runner: assert isinstance(self.runner_config, _RayRunnerConfig) self._runner = RayRunner( - daft_config=self.daft_config, address=self.runner_config.address, max_task_backlog=self.runner_config.max_task_backlog, ) @@ -92,7 +97,7 @@ def runner(self) -> Runner: pass assert isinstance(self.runner_config, _PyRunnerConfig) - self._runner = PyRunner(daft_config=self.daft_config, use_thread_pool=self.runner_config.use_thread_pool) + self._runner = PyRunner(use_thread_pool=self.runner_config.use_thread_pool) else: raise NotImplementedError(f"Runner config implemented: {self.runner_config.name}") @@ -115,25 +120,6 @@ def get_context() -> DaftContext: return _DaftContext -def set_context(ctx: DaftContext) -> DaftContext: - global _DaftContext - - pop_context() - _DaftContext = ctx - - return _DaftContext - - -def pop_context() -> DaftContext: - """Helper used in tests and test fixtures to clear the global runner and allow for re-setting of configs.""" - global _DaftContext - - old_daft_context = _DaftContext - _DaftContext = DaftContext() - - return old_daft_context - - def set_runner_ray( address: str | None = None, noop_if_initialized: bool = False, @@ -191,16 +177,41 @@ def set_runner_py(use_thread_pool: bool | None = None) -> DaftContext: return ctx -def set_config( - config: PyDaftConfig | None = None, +def set_planning_config( + config: PyDaftPlanningConfig | None = None, + default_io_config: IOConfig | None = None, +) -> DaftContext: + """Globally sets varioous configuration parameters which control Daft plan construction behavior. These configuration values + are used when a Dataframe is being constructed (e.g. calls to create a Dataframe, or to build on an existing Dataframe) + + Args: + config: A PyDaftPlanningConfig object to set the config to, before applying other kwargs. Defaults to None which indicates + that the old (current) config should be used. + default_io_config: A default IOConfig to use in the absence of one being explicitly passed into any Expression (e.g. `.url.download()`) + or Dataframe operation (e.g. `daft.read_parquet()`). + """ + # Replace values in the DaftPlanningConfig with user-specified overrides + ctx = get_context() + old_daft_planning_config = ctx.daft_planning_config if config is None else config + new_daft_planning_config = old_daft_planning_config.with_config_values( + default_io_config=default_io_config, + ) + + ctx.daft_planning_config = new_daft_planning_config + return ctx + + +def set_execution_config( + config: PyDaftExecutionConfig | None = None, merge_scan_tasks_min_size_bytes: int | None = None, merge_scan_tasks_max_size_bytes: int | None = None, broadcast_join_size_bytes_threshold: int | None = None, ) -> DaftContext: - """Globally sets various configuration parameters which control various aspects of Daft execution + """Globally sets various configuration parameters which control various aspects of Daft execution. These configuration values + are used when a Dataframe is executed (e.g. calls to `.write_*`, `.collect()` or `.show()`) Args: - config: A PyDaftConfig object to set the config to, before applying other kwargs. Defaults to None which indicates + config: A PyDaftExecutionConfig object to set the config to, before applying other kwargs. Defaults to None which indicates that the old (current) config should be used. merge_scan_tasks_min_size_bytes: Minimum size in bytes when merging ScanTasks when reading files from storage. Increasing this value will make Daft perform more merging of files into a single partition before yielding, @@ -211,20 +222,14 @@ def set_config( broadcast_join_size_bytes_threshold: If one side of a join is smaller than this threshold, a broadcast join will be used. Default is 10 MiB. """ + # Replace values in the DaftExecutionConfig with user-specified overrides ctx = get_context() - if ctx._runner is not None: - raise RuntimeError( - "Cannot call `set_config` after the runner has already been created. " - "Please call `set_config` before any dataframe creation or execution." - ) - - # Replace values in the DaftConfig with user-specified overrides - old_daft_config = ctx.daft_config if config is None else config - new_daft_config = old_daft_config.with_config_values( + old_daft_execution_config = ctx.daft_execution_config if config is None else config + new_daft_execution_config = old_daft_execution_config.with_config_values( merge_scan_tasks_min_size_bytes=merge_scan_tasks_min_size_bytes, merge_scan_tasks_max_size_bytes=merge_scan_tasks_max_size_bytes, broadcast_join_size_bytes_threshold=broadcast_join_size_bytes_threshold, ) - ctx.daft_config = new_daft_config + ctx.daft_execution_config = new_daft_execution_config return ctx diff --git a/daft/daft.pyi b/daft/daft.pyi index 7fe6f6a4b4..eb796aebd9 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -494,7 +494,7 @@ class NativeStorageConfig: multithreaded_io: bool io_config: IOConfig - def __init__(self, multithreaded_io: bool, io_config: IOConfig | None = None): ... + def __init__(self, multithreaded_io: bool, io_config: IOConfig): ... class PythonStorageConfig: """ @@ -503,7 +503,7 @@ class PythonStorageConfig: io_config: IOConfig - def __init__(self, io_config: IOConfig | None = None): ... + def __init__(self, io_config: IOConfig): ... class StorageConfig: """ @@ -836,7 +836,7 @@ class PyExpr: def list_join(self, delimiter: PyExpr) -> PyExpr: ... def list_lengths(self) -> PyExpr: ... def url_download( - self, max_connections: int, raise_error_on_failure: bool, multi_thread: bool, config: IOConfig | None = None + self, max_connections: int, raise_error_on_failure: bool, multi_thread: bool, config: IOConfig ) -> PyExpr: ... def eq(expr1: PyExpr, expr2: PyExpr) -> bool: ... @@ -1094,23 +1094,31 @@ class LogicalPlanBuilder: ) -> LogicalPlanBuilder: ... def schema(self) -> PySchema: ... def optimize(self) -> LogicalPlanBuilder: ... - def to_physical_plan_scheduler(self, cfg: PyDaftConfig) -> PhysicalPlanScheduler: ... + def to_physical_plan_scheduler(self, cfg: PyDaftExecutionConfig) -> PhysicalPlanScheduler: ... def repr_ascii(self, simple: bool) -> str: ... -class PyDaftConfig: +class PyDaftExecutionConfig: def with_config_values( self, merge_scan_tasks_min_size_bytes: int | None = None, merge_scan_tasks_max_size_bytes: int | None = None, broadcast_join_size_bytes_threshold: int | None = None, - ) -> PyDaftConfig: ... + ) -> PyDaftExecutionConfig: ... @property - def merge_scan_tasks_min_size_bytes(self): ... + def merge_scan_tasks_min_size_bytes(self) -> int: ... @property def merge_scan_tasks_max_size_bytes(self): ... @property def broadcast_join_size_bytes_threshold(self): ... +class PyDaftPlanningConfig: + def with_config_values( + self, + default_io_config: IOConfig | None = None, + ) -> PyDaftPlanningConfig: ... + @property + def default_io_config(self) -> IOConfig: ... + def build_type() -> str: ... def version() -> str: ... def __getattr__(name) -> Any: ... diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py index c4b729e039..1e191b6494 100644 --- a/daft/dataframe/dataframe.py +++ b/daft/dataframe/dataframe.py @@ -131,8 +131,8 @@ def explain(self, show_optimized: bool = False, simple=False) -> None: print(builder.pretty_print(simple)) def num_partitions(self) -> int: - daft_config = get_context().daft_config - return self.__builder.to_physical_plan_scheduler(daft_config).num_partitions() + daft_execution_config = get_context().daft_execution_config + return self.__builder.to_physical_plan_scheduler(daft_execution_config).num_partitions() @DataframePublicAPI def schema(self) -> Schema: @@ -319,6 +319,8 @@ def write_parquet( .. NOTE:: This call is **blocking** and will execute the DataFrame when called """ + io_config = get_context().daft_planning_config.default_io_config if io_config is None else io_config + cols: Optional[List[Expression]] = None if partition_cols is not None: cols = self.__column_input_to_expression(tuple(partition_cols)) @@ -365,6 +367,8 @@ def write_csv( Returns: DataFrame: The filenames that were written out as strings. """ + io_config = get_context().daft_planning_config.default_io_config if io_config is None else io_config + cols: Optional[List[Expression]] = None if partition_cols is not None: cols = self.__column_input_to_expression(tuple(partition_cols)) diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index 33d3c7e499..a5e9baebf0 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -445,8 +445,6 @@ def download( Returns: Expression: a Binary expression which is the bytes contents of the URL, or None if an error occured during download """ - from daft.io import IOConfig, S3Config - if use_native_downloader: raise_on_error = False @@ -464,11 +462,8 @@ def download( # This is because the max parallelism is actually `min(S3Config's max_connections, url_download's max_connections)` under the hood. # However, default max_connections on S3Config is only 8, and even if we specify 32 here we are bottlenecked there. # Therefore for S3 downloads, we override `max_connections` kwarg to have the intended effect. - io_config = ( - IOConfig(s3=S3Config(max_connections=max_connections)) - if io_config is None - else io_config.replace(s3=io_config.s3.replace(max_connections=max_connections)) - ) + io_config = context.get_context().daft_planning_config.default_io_config if io_config is None else io_config + io_config = io_config.replace(s3=io_config.s3.replace(max_connections=max_connections)) using_ray_runner = context.get_context().is_ray_runner return Expression._from_pyexpr( diff --git a/daft/io/_csv.py b/daft/io/_csv.py index 9dcc561cba..9101f83872 100644 --- a/daft/io/_csv.py +++ b/daft/io/_csv.py @@ -2,6 +2,7 @@ from typing import Dict, List, Optional, Union +from daft import context from daft.api_annotations import PublicAPI from daft.daft import ( CsvSourceConfig, @@ -66,6 +67,8 @@ def read_csv( if isinstance(path, list) and len(path) == 0: raise ValueError(f"Cannot read DataFrame from from empty list of CSV filepaths") + io_config = context.get_context().daft_planning_config.default_io_config if io_config is None else io_config + csv_config = CsvSourceConfig( delimiter=delimiter, has_headers=has_headers, diff --git a/daft/io/_iceberg.py b/daft/io/_iceberg.py index 896a35b165..0d3b1af102 100644 --- a/daft/io/_iceberg.py +++ b/daft/io/_iceberg.py @@ -72,8 +72,12 @@ def read_iceberg( ) -> DataFrame: from daft.iceberg.iceberg_scan import IcebergScanOperator - if io_config is None: - io_config = _convert_iceberg_file_io_properties_to_io_config(pyiceberg_table.io.properties) + io_config = ( + _convert_iceberg_file_io_properties_to_io_config(pyiceberg_table.io.properties) + if io_config is None + else io_config + ) + io_config = context.get_context().daft_planning_config.default_io_config if io_config is None else io_config multithreaded_io = not context.get_context().is_ray_runner storage_config = StorageConfig.native(NativeStorageConfig(multithreaded_io, io_config)) diff --git a/daft/io/_json.py b/daft/io/_json.py index cb40bbea4c..309686c61b 100644 --- a/daft/io/_json.py +++ b/daft/io/_json.py @@ -2,6 +2,7 @@ from typing import Dict, List, Optional, Union +from daft import context from daft.api_annotations import PublicAPI from daft.daft import ( FileFormatConfig, @@ -47,6 +48,8 @@ def read_json( if isinstance(path, list) and len(path) == 0: raise ValueError(f"Cannot read DataFrame from from empty list of JSON filepaths") + io_config = context.get_context().daft_planning_config.default_io_config if io_config is None else io_config + json_config = JsonSourceConfig(_buffer_size, _chunk_size) file_format_config = FileFormatConfig.from_json_config(json_config) if use_native_downloader: diff --git a/daft/io/_parquet.py b/daft/io/_parquet.py index b7f6e64fe1..6b95723a45 100644 --- a/daft/io/_parquet.py +++ b/daft/io/_parquet.py @@ -46,6 +46,7 @@ def read_parquet( returns: DataFrame: parsed DataFrame """ + io_config = context.get_context().daft_planning_config.default_io_config if io_config is None else io_config if isinstance(path, list) and len(path) == 0: raise ValueError(f"Cannot read DataFrame from from empty list of Parquet filepaths") diff --git a/daft/logical/builder.py b/daft/logical/builder.py index 5f9f59dc08..9d535151cb 100644 --- a/daft/logical/builder.py +++ b/daft/logical/builder.py @@ -14,7 +14,7 @@ from daft.daft import LogicalPlanBuilder as _LogicalPlanBuilder from daft.daft import ( PartitionScheme, - PyDaftConfig, + PyDaftExecutionConfig, ResourceRequest, ScanOperatorHandle, StorageConfig, @@ -35,7 +35,7 @@ class LogicalPlanBuilder: def __init__(self, builder: _LogicalPlanBuilder) -> None: self._builder = builder - def to_physical_plan_scheduler(self, daft_config: PyDaftConfig) -> PhysicalPlanScheduler: + def to_physical_plan_scheduler(self, daft_execution_config: PyDaftExecutionConfig) -> PhysicalPlanScheduler: """ Convert the underlying logical plan to a physical plan scheduler, which is used to generate executable tasks for the physical plan. @@ -44,7 +44,7 @@ def to_physical_plan_scheduler(self, daft_config: PyDaftConfig) -> PhysicalPlanS """ from daft.plan_scheduler.physical_plan_scheduler import PhysicalPlanScheduler - return PhysicalPlanScheduler(self._builder.to_physical_plan_scheduler(daft_config)) + return PhysicalPlanScheduler(self._builder.to_physical_plan_scheduler(daft_execution_config)) def schema(self) -> Schema: """ @@ -207,9 +207,9 @@ def write_tabular( self, root_dir: str | pathlib.Path, file_format: FileFormat, + io_config: IOConfig, partition_cols: list[Expression] | None = None, compression: str | None = None, - io_config: IOConfig | None = None, ) -> LogicalPlanBuilder: if file_format != FileFormat.Csv and file_format != FileFormat.Parquet: raise ValueError(f"Writing is only supported for Parquet and CSV file formats, but got: {file_format}") diff --git a/daft/runners/pyrunner.py b/daft/runners/pyrunner.py index 34c9d5e599..6b6a7e98ba 100644 --- a/daft/runners/pyrunner.py +++ b/daft/runners/pyrunner.py @@ -8,11 +8,11 @@ import psutil +from daft.context import get_context from daft.daft import ( FileFormatConfig, FileInfos, IOConfig, - PyDaftConfig, ResourceRequest, StorageConfig, ) @@ -113,9 +113,8 @@ def get_schema_from_first_filepath( class PyRunner(Runner[MicroPartition]): - def __init__(self, daft_config: PyDaftConfig, use_thread_pool: bool | None) -> None: + def __init__(self, use_thread_pool: bool | None) -> None: super().__init__() - self.daft_config = daft_config self._use_thread_pool: bool = use_thread_pool if use_thread_pool is not None else True self.num_cpus = multiprocessing.cpu_count() @@ -141,11 +140,14 @@ def run_iter( # NOTE: PyRunner does not run any async execution, so it ignores `results_buffer_size` which is essentially 0 results_buffer_size: int | None = None, ) -> Iterator[PyMaterializedResult]: + # NOTE: Freeze and use this same execution config for the entire execution + daft_execution_config = get_context().daft_execution_config + # Optimize the logical plan. builder = builder.optimize() # Finalize the logical plan and get a physical plan scheduler for translating the # physical plan to executable tasks. - plan_scheduler = builder.to_physical_plan_scheduler(self.daft_config) + plan_scheduler = builder.to_physical_plan_scheduler(daft_execution_config) psets = { key: entry.value.values() for key, entry in self._part_set_cache._uuid_to_partition_set.items() diff --git a/daft/runners/ray_runner.py b/daft/runners/ray_runner.py index 06fde27c76..7eacf32778 100644 --- a/daft/runners/ray_runner.py +++ b/daft/runners/ray_runner.py @@ -11,7 +11,7 @@ import pyarrow as pa -from daft.context import set_config +from daft.context import get_context, set_execution_config from daft.logical.builder import LogicalPlanBuilder from daft.plan_scheduler import PhysicalPlanScheduler from daft.runners.progress_bar import ProgressBar @@ -30,7 +30,7 @@ FileFormatConfig, FileInfos, IOConfig, - PyDaftConfig, + PyDaftExecutionConfig, ResourceRequest, StorageConfig, ) @@ -76,13 +76,10 @@ @ray.remote def _glob_path_into_file_infos( - daft_config: PyDaftConfig, paths: list[str], file_format_config: FileFormatConfig | None, io_config: IOConfig | None, ) -> MicroPartition: - set_config(daft_config) - file_infos = FileInfos() file_format = file_format_config.file_format() if file_format_config is not None else None for path in paths: @@ -95,9 +92,7 @@ def _glob_path_into_file_infos( @ray.remote -def _make_ray_block_from_vpartition(daft_config: PyDaftConfig, partition: MicroPartition) -> RayDatasetBlock: - set_config(daft_config) - +def _make_ray_block_from_vpartition(partition: MicroPartition) -> RayDatasetBlock: try: return partition.to_arrow(cast_tensors_to_ray_tensor_dtype=True) except pa.ArrowInvalid: @@ -106,20 +101,15 @@ def _make_ray_block_from_vpartition(daft_config: PyDaftConfig, partition: MicroP @ray.remote def _make_daft_partition_from_ray_dataset_blocks( - daft_config: PyDaftConfig, ray_dataset_block: pa.MicroPartition, daft_schema: Schema + ray_dataset_block: pa.MicroPartition, daft_schema: Schema ) -> MicroPartition: - set_config(daft_config) - return MicroPartition.from_arrow(ray_dataset_block) @ray.remote(num_returns=2) def _make_daft_partition_from_dask_dataframe_partitions( - daft_config: PyDaftConfig, dask_df_partition: pd.DataFrame, ) -> tuple[MicroPartition, pa.Schema]: - set_config(daft_config) - vpart = MicroPartition.from_pandas(dask_df_partition) return vpart, vpart.schema() @@ -138,21 +128,17 @@ def _to_pandas_ref(df: pd.DataFrame | ray.ObjectRef[pd.DataFrame]) -> ray.Object @ray.remote def sample_schema_from_filepath( - daft_config: PyDaftConfig, first_file_path: str, file_format_config: FileFormatConfig, storage_config: StorageConfig, ) -> Schema: """Ray remote function to run schema sampling on top of a MicroPartition containing a single filepath""" - set_config(daft_config) - # Currently just samples the Schema from the first file return runner_io.sample_schema(first_file_path, file_format_config, storage_config) @dataclass class RayPartitionSet(PartitionSet[ray.ObjectRef]): - _daft_config_objref: ray.ObjectRef _results: dict[PartID, RayMaterializedResult] def items(self) -> list[tuple[PartID, ray.ObjectRef]]: @@ -171,10 +157,7 @@ def to_ray_dataset(self) -> RayDataset: "Unable to import `ray.data.from_arrow_refs`. Please ensure that you have a compatible version of Ray >= 1.10 installed." ) - blocks = [ - _make_ray_block_from_vpartition.remote(self._daft_config_objref, self._results[k].partition()) - for k in self._results.keys() - ] + blocks = [_make_ray_block_from_vpartition.remote(self._results[k].partition()) for k in self._results.keys()] # NOTE: although the Ray method is called `from_arrow_refs`, this method works also when the blocks are List[T] types # instead of Arrow tables as the codepath for Dataset creation is the same. return from_arrow_refs(blocks) @@ -231,8 +214,7 @@ def wait(self) -> None: class RayRunnerIO(runner_io.RunnerIO): - def __init__(self, daft_config_objref: ray.ObjectRef, *args, **kwargs): - self.daft_config_objref = daft_config_objref + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def glob_paths_details( @@ -243,11 +225,7 @@ def glob_paths_details( ) -> FileInfos: # Synchronously fetch the file infos, for now. return FileInfos.from_table( - ray.get( - _glob_path_into_file_infos.remote( - self.daft_config_objref, source_paths, file_format_config, io_config=io_config - ) - ) + ray.get(_glob_path_into_file_infos.remote(source_paths, file_format_config, io_config=io_config)) .to_table() ._table ) @@ -264,7 +242,6 @@ def get_schema_from_first_filepath( first_path = file_infos[0].file_path return ray.get( sample_schema_from_filepath.remote( - self.daft_config_objref, first_path, file_format_config, storage_config, @@ -302,18 +279,11 @@ def partition_set_from_ray_dataset( # NOTE: This materializes the entire Ray Dataset - we could make this more intelligent by creating a new RayDatasetScan node # which can iterate on Ray Dataset blocks and materialize as-needed daft_vpartitions = [ - _make_daft_partition_from_ray_dataset_blocks.remote(self.daft_config_objref, block, daft_schema) - for block in block_refs + _make_daft_partition_from_ray_dataset_blocks.remote(block, daft_schema) for block in block_refs ] return ( - RayPartitionSet( - _daft_config_objref=self.daft_config_objref, - _results={ - i: RayMaterializedResult(obj, daft_config_objref=self.daft_config_objref) - for i, obj in enumerate(daft_vpartitions) - }, - ), + RayPartitionSet(_results={i: RayMaterializedResult(obj) for i, obj in enumerate(daft_vpartitions)}), daft_schema, ) @@ -329,9 +299,7 @@ def partition_set_from_dask_dataframe( raise ValueError("Can't convert an empty Dask DataFrame (with no partitions) to a Daft DataFrame.") persisted_partitions = dask.persist(*partitions, scheduler=ray_dask_get) parts = [_to_pandas_ref(next(iter(part.dask.values()))) for part in persisted_partitions] - daft_vpartitions, schemas = zip( - *(_make_daft_partition_from_dask_dataframe_partitions.remote(self.daft_config_objref, p) for p in parts) - ) + daft_vpartitions, schemas = zip(*(_make_daft_partition_from_dask_dataframe_partitions.remote(p) for p in parts)) schemas = ray.get(list(schemas)) # Dask shouldn't allow inconsistent schemas across partitions, but we double-check here. if not all(schemas[0] == schema for schema in schemas[1:]): @@ -340,13 +308,7 @@ def partition_set_from_dask_dataframe( schemas, ) return ( - RayPartitionSet( - _daft_config_objref=self.daft_config_objref, - _results={ - i: RayMaterializedResult(obj, daft_config_objref=self.daft_config_objref) - for i, obj in enumerate(daft_vpartitions) - }, - ), + RayPartitionSet(_results={i: RayMaterializedResult(obj) for i, obj in enumerate(daft_vpartitions)}), schemas[0], ) @@ -385,43 +347,42 @@ def build_partitions( @ray.remote def single_partition_pipeline( - daft_config: PyDaftConfig, instruction_stack: list[Instruction], *inputs: MicroPartition + daft_execution_config: PyDaftExecutionConfig, instruction_stack: list[Instruction], *inputs: MicroPartition ) -> list[list[PartitionMetadata] | MicroPartition]: - set_config(daft_config) + set_execution_config(daft_execution_config) return build_partitions(instruction_stack, *inputs) @ray.remote def fanout_pipeline( - daft_config: PyDaftConfig, instruction_stack: list[Instruction], *inputs: MicroPartition + daft_execution_config: PyDaftExecutionConfig, instruction_stack: list[Instruction], *inputs: MicroPartition ) -> list[list[PartitionMetadata] | MicroPartition]: - set_config(daft_config) + set_execution_config(daft_execution_config) return build_partitions(instruction_stack, *inputs) @ray.remote(scheduling_strategy="SPREAD") def reduce_pipeline( - daft_config: PyDaftConfig, instruction_stack: list[Instruction], inputs: list + daft_execution_config: PyDaftExecutionConfig, instruction_stack: list[Instruction], inputs: list ) -> list[list[PartitionMetadata] | MicroPartition]: import ray - set_config(daft_config) + set_execution_config(daft_execution_config) return build_partitions(instruction_stack, *ray.get(inputs)) @ray.remote(scheduling_strategy="SPREAD") def reduce_and_fanout( - daft_config: PyDaftConfig, instruction_stack: list[Instruction], inputs: list + daft_execution_config: PyDaftExecutionConfig, instruction_stack: list[Instruction], inputs: list ) -> list[list[PartitionMetadata] | MicroPartition]: import ray - set_config(daft_config) + set_execution_config(daft_execution_config) return build_partitions(instruction_stack, *ray.get(inputs)) @ray.remote -def get_metas(daft_config: PyDaftConfig, *partitions: MicroPartition) -> list[PartitionMetadata]: - set_config(daft_config) +def get_metas(*partitions: MicroPartition) -> list[PartitionMetadata]: return [PartitionMetadata.from_table(partition) for partition in partitions] @@ -448,7 +409,7 @@ def _ray_num_cpus_provider(ttl_seconds: int = 1) -> Generator[int, None, None]: class Scheduler: - def __init__(self, daft_config_objref: ray.ObjectRef, max_task_backlog: int | None, use_ray_tqdm: bool) -> None: + def __init__(self, max_task_backlog: int | None, use_ray_tqdm: bool) -> None: """ max_task_backlog: Max number of inflight tasks waiting for cores. """ @@ -465,7 +426,7 @@ def __init__(self, daft_config_objref: ray.ObjectRef, max_task_backlog: int | No self.reserved_cores = 0 - self.daft_config_objref = daft_config_objref + self.execution_configs_objref_by_df: dict[str, ray.ObjectRef] = dict() self.threads_by_df: dict[str, threading.Thread] = dict() self.results_by_df: dict[str, Queue] = {} self.active_by_df: dict[str, bool] = dict() @@ -492,8 +453,10 @@ def start_plan( plan_scheduler: PhysicalPlanScheduler, psets: dict[str, ray.ObjectRef], result_uuid: str, + daft_execution_config: PyDaftExecutionConfig, results_buffer_size: int | None = None, ) -> None: + self.execution_configs_objref_by_df[result_uuid] = ray.put(daft_execution_config) self.results_by_df[result_uuid] = Queue(maxsize=results_buffer_size or -1) self.active_by_df[result_uuid] = True @@ -532,6 +495,7 @@ def _run_plan( # Get executable tasks from plan scheduler. tasks = plan_scheduler.to_partition_tasks(psets, is_ray_runner=True) + daft_execution_config = self.execution_configs_objref_by_df[result_uuid] inflight_tasks: dict[str, PartitionTask[ray.ObjectRef]] = dict() inflight_ref_to_task: dict[ray.ObjectRef, str] = dict() pbar = ProgressBar(use_ray_tqdm=self.use_ray_tqdm) @@ -585,10 +549,7 @@ def place_in_queue(item): logger.debug("Running task synchronously in main thread: %s", next_step) assert isinstance(next_step, SingleOutputPartitionTask) next_step.set_result( - [ - RayMaterializedResult(partition, daft_config_objref=self.daft_config_objref) - for partition in next_step.inputs - ] + [RayMaterializedResult(partition) for partition in next_step.inputs] ) next_step = next(tasks) @@ -608,7 +569,7 @@ def place_in_queue(item): break for task in tasks_to_dispatch: - results = _build_partitions(self.daft_config_objref, task) + results = _build_partitions(daft_execution_config, task) logger.debug("%s -> %s", task, results) inflight_tasks[task.id()] = task for result in results: @@ -686,7 +647,9 @@ def __init__(self, *n, **kw) -> None: self.reserved_cores = 1 -def _build_partitions(daft_config_objref: ray.ObjectRef, task: PartitionTask[ray.ObjectRef]) -> list[ray.ObjectRef]: +def _build_partitions( + daft_execution_config_objref: ray.ObjectRef, task: PartitionTask[ray.ObjectRef] +) -> list[ray.ObjectRef]: """Run a PartitionTask and return the resulting list of partitions.""" ray_options: dict[str, Any] = { "num_returns": task.num_results + 1, @@ -698,14 +661,16 @@ def _build_partitions(daft_config_objref: ray.ObjectRef, task: PartitionTask[ray if isinstance(task.instructions[0], ReduceInstruction): build_remote = reduce_and_fanout if isinstance(task.instructions[-1], FanoutInstruction) else reduce_pipeline build_remote = build_remote.options(**ray_options) - [metadatas_ref, *partitions] = build_remote.remote(daft_config_objref, task.instructions, task.inputs) + [metadatas_ref, *partitions] = build_remote.remote(daft_execution_config_objref, task.instructions, task.inputs) else: build_remote = ( fanout_pipeline if isinstance(task.instructions[-1], FanoutInstruction) else single_partition_pipeline ) build_remote = build_remote.options(**ray_options) - [metadatas_ref, *partitions] = build_remote.remote(daft_config_objref, task.instructions, *task.inputs) + [metadatas_ref, *partitions] = build_remote.remote( + daft_execution_config_objref, task.instructions, *task.inputs + ) metadatas_accessor = PartitionMetadataAccessor(metadatas_ref) task.set_result( @@ -725,7 +690,6 @@ def _build_partitions(daft_config_objref: ray.ObjectRef, task: PartitionTask[ray class RayRunner(Runner[ray.ObjectRef]): def __init__( self, - daft_config: PyDaftConfig, address: str | None, max_task_backlog: int | None, ) -> None: @@ -734,18 +698,16 @@ def __init__( logger.warning(f"Ray has already been initialized, Daft will reuse the existing Ray context.") self.ray_context = ray.init(address=address, ignore_reinit_error=True) - # We put a frozen copy of the Daft config into the cluster to be used across all subsequent Daft function calls - self.daft_config_objref = ray.put(daft_config) - self.daft_config = daft_config - if isinstance(self.ray_context, ray.client_builder.ClientContext): # Run scheduler remotely if the cluster is connected remotely. self.scheduler_actor = SchedulerActor.remote( # type: ignore - daft_config_objref=self.daft_config_objref, max_task_backlog=max_task_backlog, use_ray_tqdm=True + max_task_backlog=max_task_backlog, + use_ray_tqdm=True, ) else: self.scheduler = Scheduler( - daft_config_objref=self.daft_config_objref, max_task_backlog=max_task_backlog, use_ray_tqdm=False + max_task_backlog=max_task_backlog, + use_ray_tqdm=False, ) def active_plans(self) -> list[str]: @@ -757,12 +719,15 @@ def active_plans(self) -> list[str]: def run_iter( self, builder: LogicalPlanBuilder, results_buffer_size: int | None = None ) -> Iterator[RayMaterializedResult]: + # Grab and freeze the current DaftExecutionConfig + daft_execution_config = get_context().daft_execution_config + # Optimize the logical plan. builder = builder.optimize() # Finalize the logical plan and get a physical plan scheduler for translating the # physical plan to executable tasks. - plan_scheduler = builder.to_physical_plan_scheduler(self.daft_config) + plan_scheduler = builder.to_physical_plan_scheduler(daft_execution_config) psets = { key: entry.value.values() @@ -773,6 +738,7 @@ def run_iter( if isinstance(self.ray_context, ray.client_builder.ClientContext): ray.get( self.scheduler_actor.start_plan.remote( + daft_execution_config=daft_execution_config, plan_scheduler=plan_scheduler, psets=psets, result_uuid=result_uuid, @@ -781,6 +747,7 @@ def run_iter( ) else: self.scheduler.start_plan( + daft_execution_config=daft_execution_config, plan_scheduler=plan_scheduler, psets=psets, result_uuid=result_uuid, @@ -814,7 +781,7 @@ def run_iter_tables( yield ray.get(result.partition()) def run(self, builder: LogicalPlanBuilder) -> PartitionCacheEntry: - result_pset = RayPartitionSet(_daft_config_objref=self.daft_config_objref, _results={}) + result_pset = RayPartitionSet(_results={}) results_iter = self.run_iter(builder) @@ -828,9 +795,10 @@ def run(self, builder: LogicalPlanBuilder) -> PartitionCacheEntry: def put_partition_set_into_cache(self, pset: PartitionSet) -> PartitionCacheEntry: if isinstance(pset, LocalPartitionSet): pset = RayPartitionSet( - _daft_config_objref=self.daft_config_objref, _results={ - pid: RayMaterializedResult(ray.put(val), daft_config_objref=self.daft_config_objref) + pid: RayMaterializedResult( + ray.put(val), + ) for pid, val in pset._partitions.items() }, ) @@ -838,22 +806,20 @@ def put_partition_set_into_cache(self, pset: PartitionSet) -> PartitionCacheEntr return self._part_set_cache.put_partition_set(pset=pset) def runner_io(self) -> RayRunnerIO: - return RayRunnerIO(daft_config_objref=self.daft_config_objref) + return RayRunnerIO() class RayMaterializedResult(MaterializedResult[ray.ObjectRef]): def __init__( self, partition: ray.ObjectRef, - daft_config_objref: ray.ObjectRef | None = None, metadatas: PartitionMetadataAccessor | None = None, metadata_idx: int | None = None, ): self._partition = partition if metadatas is None: assert metadata_idx is None - assert daft_config_objref is not None - metadatas = PartitionMetadataAccessor(get_metas.remote(daft_config_objref, self._partition)) + metadatas = PartitionMetadataAccessor(get_metas.remote(self._partition)) metadata_idx = 0 self._metadatas = metadatas self._metadata_idx = metadata_idx diff --git a/docs/source/api_docs/context.rst b/docs/source/api_docs/context.rst index e7d94488c8..e97a3f9d88 100644 --- a/docs/source/api_docs/context.rst +++ b/docs/source/api_docs/context.rst @@ -12,3 +12,15 @@ Control the execution backend that Daft will run on by calling these functions o daft.context.set_runner_py daft.context.set_runner_ray + +Setting configurations +********************** + +Configure Daft in various ways during execution. + +.. autosummary:: + :nosignatures: + :toctree: doc_gen/configuration_functions + + daft.set_planning_config + daft.set_execution_config diff --git a/src/common/daft-config/Cargo.toml b/src/common/daft-config/Cargo.toml index 6d759e8d8e..959853b60f 100644 --- a/src/common/daft-config/Cargo.toml +++ b/src/common/daft-config/Cargo.toml @@ -1,12 +1,13 @@ [dependencies] bincode = {workspace = true} +common-io-config = {path = "../io-config", default-features = false} lazy_static = {workspace = true} pyo3 = {workspace = true, optional = true} serde = {workspace = true} [features] default = ["python"] -python = ["dep:pyo3"] +python = ["dep:pyo3", "common-io-config/python"] [package] edition = {workspace = true} diff --git a/src/common/daft-config/src/lib.rs b/src/common/daft-config/src/lib.rs index 7134f65893..aa6ca5b575 100644 --- a/src/common/daft-config/src/lib.rs +++ b/src/common/daft-config/src/lib.rs @@ -1,17 +1,37 @@ +use common_io_config::IOConfig; use serde::{Deserialize, Serialize}; +/// Configurations for Daft to use during the building of a Dataframe's plan. +/// +/// 1. Creation of a Dataframe including any file listing and schema inference that needs to happen. Note +/// that this does not include the actual scan, which is taken care of by the DaftExecutionConfig. +/// 2. Building of logical plan nodes +#[derive(Clone, Serialize, Deserialize, Default)] +pub struct DaftPlanningConfig { + pub default_io_config: IOConfig, +} + +/// Configurations for Daft to use during the execution of a Dataframe +/// Note that this should be immutable for a given end-to-end execution of a logical plan. +/// +/// Execution entails everything that happens when a Dataframe `.collect()`, `.show()` or similar is called: +/// 1. Logical plan optimization +/// 2. Logical-to-physical-plan translation +/// 3. Task generation from physical plan +/// 4. Task scheduling +/// 5. Task local execution #[derive(Clone, Serialize, Deserialize)] -pub struct DaftConfig { +pub struct DaftExecutionConfig { pub merge_scan_tasks_min_size_bytes: usize, pub merge_scan_tasks_max_size_bytes: usize, pub broadcast_join_size_bytes_threshold: usize, } -impl Default for DaftConfig { +impl Default for DaftExecutionConfig { fn default() -> Self { - DaftConfig { - merge_scan_tasks_min_size_bytes: 64 * 1024 * 1024, // 64 MiB - merge_scan_tasks_max_size_bytes: 512 * 1024 * 1024, // 512 MiB + DaftExecutionConfig { + merge_scan_tasks_min_size_bytes: 64 * 1024 * 1024, // 64MB + merge_scan_tasks_max_size_bytes: 512 * 1024 * 1024, // 512MB broadcast_join_size_bytes_threshold: 10 * 1024 * 1024, // 10 MiB } } @@ -21,14 +41,15 @@ impl Default for DaftConfig { mod python; #[cfg(feature = "python")] -pub use python::PyDaftConfig; +pub use python::PyDaftExecutionConfig; #[cfg(feature = "python")] use pyo3::prelude::*; #[cfg(feature = "python")] pub fn register_modules(_py: Python, parent: &PyModule) -> PyResult<()> { - parent.add_class::()?; + parent.add_class::()?; + parent.add_class::()?; Ok(()) } diff --git a/src/common/daft-config/src/python.rs b/src/common/daft-config/src/python.rs index 353858152c..7fc7359ffb 100644 --- a/src/common/daft-config/src/python.rs +++ b/src/common/daft-config/src/python.rs @@ -3,19 +3,76 @@ use std::sync::Arc; use pyo3::{prelude::*, PyTypeInfo}; use serde::{Deserialize, Serialize}; -use crate::DaftConfig; +use crate::{DaftExecutionConfig, DaftPlanningConfig}; +use common_io_config::python::IOConfig as PyIOConfig; #[derive(Clone, Default, Serialize, Deserialize)] #[pyclass(module = "daft.daft")] -pub struct PyDaftConfig { - pub config: Arc, +pub struct PyDaftPlanningConfig { + pub config: Arc, } #[pymethods] -impl PyDaftConfig { +impl PyDaftPlanningConfig { #[new] pub fn new() -> Self { - PyDaftConfig::default() + PyDaftPlanningConfig::default() + } + + fn with_config_values( + &mut self, + default_io_config: Option, + ) -> PyResult { + let mut config = self.config.as_ref().clone(); + + if let Some(default_io_config) = default_io_config { + config.default_io_config = default_io_config.config; + } + + Ok(PyDaftPlanningConfig { + config: Arc::new(config), + }) + } + + #[getter(default_io_config)] + fn default_io_config(&self) -> PyResult { + Ok(PyIOConfig { + config: self.config.default_io_config.clone(), + }) + } + + fn __reduce__(&self, py: Python) -> PyResult<(PyObject, (Vec,))> { + let bin_data = bincode::serialize(self.config.as_ref()) + .expect("DaftPlanningConfig should be serializable to bytes"); + Ok(( + Self::type_object(py) + .getattr("_from_serialized")? + .to_object(py), + (bin_data,), + )) + } + + #[staticmethod] + fn _from_serialized(bin_data: Vec) -> PyResult { + let daft_planning_config: DaftPlanningConfig = bincode::deserialize(bin_data.as_slice()) + .expect("DaftExecutionConfig should be deserializable from bytes"); + Ok(PyDaftPlanningConfig { + config: daft_planning_config.into(), + }) + } +} + +#[derive(Clone, Default, Serialize, Deserialize)] +#[pyclass(module = "daft.daft")] +pub struct PyDaftExecutionConfig { + pub config: Arc, +} + +#[pymethods] +impl PyDaftExecutionConfig { + #[new] + pub fn new() -> Self { + PyDaftExecutionConfig::default() } fn with_config_values( @@ -23,7 +80,7 @@ impl PyDaftConfig { merge_scan_tasks_min_size_bytes: Option, merge_scan_tasks_max_size_bytes: Option, broadcast_join_size_bytes_threshold: Option, - ) -> PyResult { + ) -> PyResult { let mut config = self.config.as_ref().clone(); if let Some(merge_scan_tasks_max_size_bytes) = merge_scan_tasks_max_size_bytes { @@ -36,7 +93,7 @@ impl PyDaftConfig { config.broadcast_join_size_bytes_threshold = broadcast_join_size_bytes_threshold; } - Ok(PyDaftConfig { + Ok(PyDaftExecutionConfig { config: Arc::new(config), }) } @@ -58,7 +115,7 @@ impl PyDaftConfig { fn __reduce__(&self, py: Python) -> PyResult<(PyObject, (Vec,))> { let bin_data = bincode::serialize(self.config.as_ref()) - .expect("DaftConfig should be serializable to bytes"); + .expect("DaftExecutionConfig should be serializable to bytes"); Ok(( Self::type_object(py) .getattr("_from_serialized")? @@ -68,11 +125,11 @@ impl PyDaftConfig { } #[staticmethod] - fn _from_serialized(bin_data: Vec) -> PyResult { - let daft_config: DaftConfig = bincode::deserialize(bin_data.as_slice()) - .expect("DaftConfig should be deserializable from bytes"); - Ok(PyDaftConfig { - config: daft_config.into(), + fn _from_serialized(bin_data: Vec) -> PyResult { + let daft_execution_config: DaftExecutionConfig = bincode::deserialize(bin_data.as_slice()) + .expect("DaftExecutionConfig should be deserializable from bytes"); + Ok(PyDaftExecutionConfig { + config: daft_execution_config.into(), }) } } diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs index 202e4e624d..ba1a82c162 100644 --- a/src/daft-dsl/src/python.rs +++ b/src/daft-dsl/src/python.rs @@ -378,7 +378,7 @@ impl PyExpr { max_connections: i64, raise_error_on_failure: bool, multi_thread: bool, - config: Option, + config: PyIOConfig, ) -> PyResult { if max_connections <= 0 { return Err(PyValueError::new_err(format!( @@ -391,7 +391,7 @@ impl PyExpr { max_connections as usize, raise_error_on_failure, multi_thread, - config.map(|c| c.config), + Some(config.config), ) .into()) } diff --git a/src/daft-plan/src/builder.rs b/src/daft-plan/src/builder.rs index 874871c4fc..1ebb2b6dd3 100644 --- a/src/daft-plan/src/builder.rs +++ b/src/daft-plan/src/builder.rs @@ -26,7 +26,7 @@ use daft_scan::{ #[cfg(feature = "python")] use { crate::{physical_plan::PhysicalPlan, source_info::InMemoryInfo}, - common_daft_config::PyDaftConfig, + common_daft_config::PyDaftExecutionConfig, daft_core::python::schema::PySchema, daft_dsl::python::PyExpr, daft_scan::{file_format::PyFileFormatConfig, python::pylib::ScanOperatorHandle}, @@ -474,7 +474,7 @@ impl PyLogicalPlanBuilder { pub fn to_physical_plan_scheduler( &self, py: Python, - cfg: PyDaftConfig, + cfg: PyDaftExecutionConfig, ) -> PyResult { py.allow_threads(|| { let logical_plan = self.builder.build(); diff --git a/src/daft-plan/src/physical_ops/project.rs b/src/daft-plan/src/physical_ops/project.rs index 093ba9ea8c..0b34874fba 100644 --- a/src/daft-plan/src/physical_ops/project.rs +++ b/src/daft-plan/src/physical_ops/project.rs @@ -186,7 +186,7 @@ impl Project { #[cfg(test)] mod tests { - use common_daft_config::DaftConfig; + use common_daft_config::DaftExecutionConfig; use common_error::DaftResult; use daft_core::{datatypes::Field, DataType}; use daft_dsl::{col, lit, Expr}; @@ -198,7 +198,7 @@ mod tests { /// do not destroy the partition spec. #[test] fn test_partition_spec_preserving() -> DaftResult<()> { - let cfg = DaftConfig::default().into(); + let cfg = DaftExecutionConfig::default().into(); let expressions = vec![ (col("a") % lit(2)), // this is now "a" col("b"), @@ -242,7 +242,7 @@ mod tests { )] projection: Vec, ) -> DaftResult<()> { - let cfg = DaftConfig::default().into(); + let cfg = DaftExecutionConfig::default().into(); let logical_plan = dummy_scan_node(vec![ Field::new("a", DataType::Int64), Field::new("b", DataType::Int64), @@ -271,7 +271,7 @@ mod tests { /// i.e. ("a", "a" as "b") remains partitioned by "a", not "b" #[test] fn test_partition_spec_prefer_existing_names() -> DaftResult<()> { - let cfg = DaftConfig::default().into(); + let cfg = DaftExecutionConfig::default().into(); let expressions = vec![col("a").alias("y"), col("a"), col("a").alias("z"), col("b")]; let logical_plan = dummy_scan_node(vec![ diff --git a/src/daft-plan/src/planner.rs b/src/daft-plan/src/planner.rs index c59bcdd944..ba43518b73 100644 --- a/src/daft-plan/src/planner.rs +++ b/src/daft-plan/src/planner.rs @@ -2,7 +2,7 @@ use std::cmp::Ordering; use std::sync::Arc; use std::{cmp::max, collections::HashMap}; -use common_daft_config::DaftConfig; +use common_daft_config::DaftExecutionConfig; use common_error::DaftResult; use daft_core::count_mode::CountMode; use daft_dsl::Expr; @@ -26,7 +26,7 @@ use crate::{FileFormat, PartitionScheme}; use crate::physical_ops::InMemoryScan; /// Translate a logical plan to a physical plan. -pub fn plan(logical_plan: &LogicalPlan, cfg: Arc) -> DaftResult { +pub fn plan(logical_plan: &LogicalPlan, cfg: Arc) -> DaftResult { match logical_plan { LogicalPlan::Source(Source { output_schema, @@ -602,7 +602,7 @@ pub fn plan(logical_plan: &LogicalPlan, cfg: Arc) -> DaftResult upstream_op #[test] fn repartition_dropped_redundant_into_partitions() -> DaftResult<()> { - let cfg: Arc = DaftConfig::default().into(); + let cfg: Arc = DaftExecutionConfig::default().into(); // dummy_scan_node() will create the default PartitionSpec, which only has a single partition. let builder = dummy_scan_node(vec![ Field::new("a", DataType::Int64), @@ -646,7 +646,7 @@ mod tests { /// Repartition-upstream_op -> upstream_op #[test] fn repartition_dropped_single_partition() -> DaftResult<()> { - let cfg: Arc = DaftConfig::default().into(); + let cfg: Arc = DaftExecutionConfig::default().into(); // dummy_scan_node() will create the default PartitionSpec, which only has a single partition. let builder = dummy_scan_node(vec![ Field::new("a", DataType::Int64), @@ -671,7 +671,7 @@ mod tests { /// Repartition-upstream_op -> upstream_op #[test] fn repartition_dropped_same_partition_spec() -> DaftResult<()> { - let cfg = DaftConfig::default().into(); + let cfg = DaftExecutionConfig::default().into(); let logical_plan = dummy_scan_node(vec![ Field::new("a", DataType::Int64), Field::new("b", DataType::Utf8), @@ -691,7 +691,7 @@ mod tests { /// Repartition-Aggregation -> Aggregation #[test] fn repartition_dropped_same_partition_spec_agg() -> DaftResult<()> { - let cfg = DaftConfig::default().into(); + let cfg = DaftExecutionConfig::default().into(); let logical_plan = dummy_scan_node(vec![ Field::new("a", DataType::Int64), Field::new("b", DataType::Int64), diff --git a/tests/conftest.py b/tests/conftest.py index 6516864390..4696998424 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,13 +11,9 @@ @pytest.fixture(scope="session", autouse=True) -def set_configs(): +def set_execution_configs(): """Sets global Daft config for testing""" - - # Pop the old context, which gets rid of the old Runner as well - daft.context.pop_context() - - daft.context.set_config( + daft.set_execution_config( # Disables merging of ScanTasks merge_scan_tasks_min_size_bytes=0, merge_scan_tasks_max_size_bytes=0, diff --git a/tests/dataframe/test_joins.py b/tests/dataframe/test_joins.py index a538bb5c53..5c1f61539c 100644 --- a/tests/dataframe/test_joins.py +++ b/tests/dataframe/test_joins.py @@ -13,11 +13,15 @@ def broadcast_join_enabled(request): # Toggles between default broadcast join threshold (10 MiB), and a threshold of 0, which disables broadcast joins. broadcast_threshold = 10 * 1024 * 1024 if request.param else 0 - old_context = daft.context.pop_context() + + old_execution_config = daft.context.get_context().daft_execution_config try: - yield daft.context.set_config(broadcast_join_size_bytes_threshold=broadcast_threshold) + daft.set_execution_config( + broadcast_join_size_bytes_threshold=broadcast_threshold, + ) + yield finally: - daft.context.set_context(old_context) + daft.set_execution_config(old_execution_config) @pytest.mark.parametrize("n_partitions", [1, 2, 4]) diff --git a/tests/io/test_merge_scan_tasks.py b/tests/io/test_merge_scan_tasks.py index 3ed35eb467..f4f2e39134 100644 --- a/tests/io/test_merge_scan_tasks.py +++ b/tests/io/test_merge_scan_tasks.py @@ -10,16 +10,16 @@ @contextlib.contextmanager def override_merge_scan_tasks_configs(merge_scan_tasks_min_size_bytes: int, merge_scan_tasks_max_size_bytes: int): - old_context = daft.context.pop_context() + old_execution_config = daft.context.get_context().daft_execution_config try: - daft.context.set_config( + daft.set_execution_config( merge_scan_tasks_min_size_bytes=merge_scan_tasks_min_size_bytes, merge_scan_tasks_max_size_bytes=merge_scan_tasks_max_size_bytes, ) yield finally: - daft.context.set_context(old_context) + daft.set_execution_config(old_execution_config) @pytest.fixture(scope="function")