Skip to content

Commit

Permalink
Put Runner into DaftContext since it now holds a copy of the DaftConfig
Browse files Browse the repository at this point in the history
  • Loading branch information
Jay Chia committed Dec 8, 2023
1 parent 144d89f commit 018748e
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 61 deletions.
85 changes: 30 additions & 55 deletions daft/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
from daft.daft import PyDaftConfig

if TYPE_CHECKING:
pass

from daft.runners.runner import Runner

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -55,27 +53,24 @@ def _get_runner_config_from_env() -> _RunnerConfig:
raise ValueError(f"Unsupported DAFT_RUNNER variable: {runner}")


# Global Runner singleton, initialized when accessed through the DaftContext
_RUNNER: Runner | None = None


@dataclasses.dataclass(frozen=True)
@dataclasses.dataclass()
class DaftContext:
"""Global context for the current Daft execution environment"""

daft_config: PyDaftConfig = PyDaftConfig()
runner_config: _RunnerConfig = dataclasses.field(default_factory=_get_runner_config_from_env)
disallow_set_runner: bool = False
_runner: Runner | None = None

def runner(self) -> Runner:
global _RUNNER
if _RUNNER is not None:
return _RUNNER
if self._runner is not None:
return self._runner

if self.runner_config.name == "ray":
from daft.runners.ray_runner import RayRunner

assert isinstance(self.runner_config, _RayRunnerConfig)
_RUNNER = RayRunner(
self._runner = RayRunner(
daft_config=self.daft_config,
address=self.runner_config.address,
max_task_backlog=self.runner_config.max_task_backlog,
Expand All @@ -97,20 +92,16 @@ def runner(self) -> Runner:
pass

assert isinstance(self.runner_config, _PyRunnerConfig)
_RUNNER = PyRunner(use_thread_pool=self.runner_config.use_thread_pool)
self._runner = PyRunner(daft_config=self.daft_config, use_thread_pool=self.runner_config.use_thread_pool)

else:
raise NotImplementedError(f"Runner config implemented: {self.runner_config.name}")

# Mark the DaftContext as having its runner set, which prevents any subsequent setting of the runner
# or the config once the runner has been initialized
global _DaftContext
_DaftContext = dataclasses.replace(
_DaftContext,
disallow_set_runner=True,
)
self.disallow_set_runner = True

return _RUNNER
return self._runner

@property
def is_ray_runner(self) -> bool:
Expand All @@ -124,13 +115,7 @@ def get_context() -> DaftContext:
return _DaftContext


def _set_context(ctx: DaftContext):
global _DaftContext

_DaftContext = ctx


def _pop_context() -> DaftContext:
def pop_context() -> DaftContext:
"""Helper used in tests and test fixtures to clear the global runner and allow for re-setting of configs."""
global _DaftContext

Expand Down Expand Up @@ -163,24 +148,21 @@ def set_runner_ray(
Returns:
DaftContext: Daft context after setting the Ray runner
"""
old_ctx = get_context()
if old_ctx.disallow_set_runner:
ctx = get_context()
if ctx.disallow_set_runner:

Check warning on line 152 in daft/context.py

View check run for this annotation

Codecov / codecov/patch

daft/context.py#L151-L152

Added lines #L151 - L152 were not covered by tests
if noop_if_initialized:
warnings.warn(
"Calling daft.context.set_runner_ray(noop_if_initialized=True) multiple times has no effect beyond the first call."
)
return old_ctx
return ctx

Check warning on line 157 in daft/context.py

View check run for this annotation

Codecov / codecov/patch

daft/context.py#L157

Added line #L157 was not covered by tests
raise RuntimeError("Cannot set runner more than once")
new_ctx = dataclasses.replace(
old_ctx,
runner_config=_RayRunnerConfig(
address=address,
max_task_backlog=max_task_backlog,
),
disallow_set_runner=True,

ctx.runner_config = _RayRunnerConfig(

Check warning on line 160 in daft/context.py

View check run for this annotation

Codecov / codecov/patch

daft/context.py#L160

Added line #L160 was not covered by tests
address=address,
max_task_backlog=max_task_backlog,
)
_set_context(new_ctx)
return new_ctx
ctx.disallow_set_runner = True
return ctx

Check warning on line 165 in daft/context.py

View check run for this annotation

Codecov / codecov/patch

daft/context.py#L164-L165

Added lines #L164 - L165 were not covered by tests


def set_runner_py(use_thread_pool: bool | None = None) -> DaftContext:
Expand All @@ -191,16 +173,13 @@ def set_runner_py(use_thread_pool: bool | None = None) -> DaftContext:
Returns:
DaftContext: Daft context after setting the Py runner
"""
old_ctx = get_context()
if old_ctx.disallow_set_runner:
ctx = get_context()
if ctx.disallow_set_runner:

Check warning on line 177 in daft/context.py

View check run for this annotation

Codecov / codecov/patch

daft/context.py#L176-L177

Added lines #L176 - L177 were not covered by tests
raise RuntimeError("Cannot set runner more than once")
new_ctx = dataclasses.replace(
old_ctx,
runner_config=_PyRunnerConfig(use_thread_pool=use_thread_pool),
disallow_set_runner=True,
)
_set_context(new_ctx)
return new_ctx

ctx.runner_config = _PyRunnerConfig(use_thread_pool=use_thread_pool)
ctx.disallow_set_runner = True
return ctx

Check warning on line 182 in daft/context.py

View check run for this annotation

Codecov / codecov/patch

daft/context.py#L180-L182

Added lines #L180 - L182 were not covered by tests


def set_config(
Expand All @@ -220,23 +199,19 @@ def set_config(
Increasing this value will increase the upper bound of the size of merged ScanTasks, which leads to bigger but
fewer partitions. (Defaults to 512MB)
"""
old_ctx = get_context()
if old_ctx.disallow_set_runner:
ctx = get_context()
if ctx.disallow_set_runner:
raise RuntimeError(

Check warning on line 204 in daft/context.py

View check run for this annotation

Codecov / codecov/patch

daft/context.py#L204

Added line #L204 was not covered by tests
"Cannot call `set_config` after the runner has already been created. "
"Please call `set_config` before any calls to set the runner and before any dataframe creation or execution."
)

# Replace values in the DaftConfig with user-specified overrides
old_daft_config = old_ctx.daft_config if config is None else config
old_daft_config = ctx.daft_config if config is None else config
new_daft_config = old_daft_config.with_config_values(
merge_scan_tasks_min_size_bytes=merge_scan_tasks_min_size_bytes,
merge_scan_tasks_max_size_bytes=merge_scan_tasks_max_size_bytes,
)

new_ctx = dataclasses.replace(
old_ctx,
daft_config=new_daft_config,
)
_set_context(new_ctx)
return new_ctx
ctx.daft_config = new_daft_config
return ctx
9 changes: 4 additions & 5 deletions daft/runners/pyrunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@

import psutil

from daft.context import get_context
from daft.daft import (
FileFormatConfig,
FileInfos,
IOConfig,
PyDaftConfig,
ResourceRequest,
StorageConfig,
)
Expand Down Expand Up @@ -105,8 +105,9 @@ def get_schema_from_first_filepath(


class PyRunner(Runner[MicroPartition]):
def __init__(self, use_thread_pool: bool | None) -> None:
def __init__(self, daft_config: PyDaftConfig, use_thread_pool: bool | None) -> None:
super().__init__()
self.daft_config = daft_config
self._use_thread_pool: bool = use_thread_pool if use_thread_pool is not None else True

self.num_cpus = multiprocessing.cpu_count()
Expand All @@ -132,13 +133,11 @@ def run_iter(
# NOTE: PyRunner does not run any async execution, so it ignores `results_buffer_size` which is essentially 0
results_buffer_size: int | None = None,
) -> Iterator[PyMaterializedResult]:
daft_config = get_context().daft_config

# Optimize the logical plan.
builder = builder.optimize()
# Finalize the logical plan and get a physical plan scheduler for translating the
# physical plan to executable tasks.
plan_scheduler = builder.to_physical_plan_scheduler(daft_config)
plan_scheduler = builder.to_physical_plan_scheduler(self.daft_config)
psets = {
key: entry.value.values()
for key, entry in self._part_set_cache._uuid_to_partition_set.items()
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def set_configs():
"""Sets global Daft config for testing"""

# Pop the old context, which gets rid of the old Runner as well
daft.context._pop_context()
daft.context.pop_context()

daft.context.set_config(
# Disables merging of ScanTasks
Expand Down
2 changes: 2 additions & 0 deletions tests/io/test_merge_scan_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@ def override_merge_scan_tasks_configs(merge_scan_tasks_min_size_bytes: int, merg
original_merge_scan_tasks_max_size_bytes = config.merge_scan_tasks_max_size_bytes

try:
daft.context.pop_context()
daft.context.set_config(
merge_scan_tasks_min_size_bytes=merge_scan_tasks_min_size_bytes,
merge_scan_tasks_max_size_bytes=merge_scan_tasks_max_size_bytes,
)
yield
finally:
daft.context.pop_context()
daft.context.set_config(
merge_scan_tasks_min_size_bytes=original_merge_scan_tasks_min_size_bytes,
merge_scan_tasks_max_size_bytes=original_merge_scan_tasks_max_size_bytes,
Expand Down

0 comments on commit 018748e

Please sign in to comment.