Skip to content

Commit

Permalink
protect global context with mutex
Browse files Browse the repository at this point in the history
  • Loading branch information
samster25 committed Feb 9, 2024
1 parent 1066ace commit d4554c6
Showing 1 changed file with 68 additions and 47 deletions.
115 changes: 68 additions & 47 deletions daft/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

logger = logging.getLogger(__name__)


import threading
class _RunnerConfig:
name = ClassVar[str]

Expand Down Expand Up @@ -68,7 +68,25 @@ class DaftContext:
disallow_set_runner: bool = False
_runner: Runner | None = None

_instance = None
_lock = threading.Lock()

def __new__(cls):
    """Return the single shared instance, creating it on first use.

    The instance is cached on ``cls._instance``. Creation happens while
    holding ``cls._lock`` so that concurrent first calls build only one
    object.
    """
    # Unlocked check first so calls made after initialisation skip the lock.
    if cls._instance is None:
        with cls._lock:
            # Another thread could have created the instance before we
            # acquired the lock, so check again while holding it. Use the
            # same identity test as the outer check rather than truthiness.
            if cls._instance is None:
                cls._instance = super().__new__(cls)
    return cls._instance


def runner(self) -> Runner:
    """Return the runner for this context.

    The lookup is delegated to ``self._get_runner()`` and performed while
    holding ``self._lock``.
    """
    with self._lock:
        resolved = self._get_runner()
    return resolved

def _get_runner(self) -> Runner:
if self._runner is not None:
return self._runner

Expand Down Expand Up @@ -110,16 +128,15 @@ def runner(self) -> Runner:

@property
def is_ray_runner(self) -> bool:
    """Whether ``self.runner_config`` is a ``_RayRunnerConfig``.

    The check is made while holding ``self._lock``.
    """
    # NOTE(review): if _lock is a plain threading.Lock it is not reentrant,
    # so reading this property from code that already holds the lock would
    # block forever -- confirm no caller does that.
    with self._lock:
        return isinstance(self.runner_config, _RayRunnerConfig)


_DaftContext = DaftContext()


def get_context() -> DaftContext:
    """Return the module-level ``_DaftContext`` object."""
    return _DaftContext


def set_runner_ray(
address: str | None = None,
noop_if_initialized: bool = False,
Expand All @@ -144,20 +161,21 @@ def set_runner_ray(
DaftContext: Daft context after setting the Ray runner
"""
ctx = get_context()
if ctx.disallow_set_runner:
if noop_if_initialized:
warnings.warn(
"Calling daft.context.set_runner_ray(noop_if_initialized=True) multiple times has no effect beyond the first call."
)
return ctx
raise RuntimeError("Cannot set runner more than once")

ctx.runner_config = _RayRunnerConfig(
address=address,
max_task_backlog=max_task_backlog,
)
ctx.disallow_set_runner = True
return ctx
with ctx._lock:
if ctx.disallow_set_runner:
if noop_if_initialized:
warnings.warn(

Check warning on line 167 in daft/context.py

View check run for this annotation

Codecov / codecov/patch

daft/context.py#L164-L167

Added lines #L164 - L167 were not covered by tests
"Calling daft.context.set_runner_ray(noop_if_initialized=True) multiple times has no effect beyond the first call."
)
return ctx
raise RuntimeError("Cannot set runner more than once")

Check warning on line 171 in daft/context.py

View check run for this annotation

Codecov / codecov/patch

daft/context.py#L170-L171

Added lines #L170 - L171 were not covered by tests

ctx.runner_config = _RayRunnerConfig(

Check warning on line 173 in daft/context.py

View check run for this annotation

Codecov / codecov/patch

daft/context.py#L173

Added line #L173 was not covered by tests
address=address,
max_task_backlog=max_task_backlog,
)
ctx.disallow_set_runner = True
return ctx

Check warning on line 178 in daft/context.py

View check run for this annotation

Codecov / codecov/patch

daft/context.py#L177-L178

Added lines #L177 - L178 were not covered by tests


def set_runner_py(use_thread_pool: bool | None = None) -> DaftContext:
Expand All @@ -169,12 +187,13 @@ def set_runner_py(use_thread_pool: bool | None = None) -> DaftContext:
DaftContext: Daft context after setting the Py runner
"""
ctx = get_context()
if ctx.disallow_set_runner:
raise RuntimeError("Cannot set runner more than once")
with ctx._lock:
if ctx.disallow_set_runner:
raise RuntimeError("Cannot set runner more than once")

Check warning on line 192 in daft/context.py

View check run for this annotation

Codecov / codecov/patch

daft/context.py#L190-L192

Added lines #L190 - L192 were not covered by tests

ctx.runner_config = _PyRunnerConfig(use_thread_pool=use_thread_pool)
ctx.disallow_set_runner = True
return ctx
ctx.runner_config = _PyRunnerConfig(use_thread_pool=use_thread_pool)
ctx.disallow_set_runner = True
return ctx

Check warning on line 196 in daft/context.py

View check run for this annotation

Codecov / codecov/patch

daft/context.py#L194-L196

Added lines #L194 - L196 were not covered by tests


def set_planning_config(
Expand All @@ -192,13 +211,14 @@ def set_planning_config(
"""
# Replace values in the DaftPlanningConfig with user-specified overrides
ctx = get_context()
old_daft_planning_config = ctx.daft_planning_config if config is None else config
new_daft_planning_config = old_daft_planning_config.with_config_values(
default_io_config=default_io_config,
)
with ctx._lock:
old_daft_planning_config = ctx.daft_planning_config if config is None else config
new_daft_planning_config = old_daft_planning_config.with_config_values(

Check warning on line 216 in daft/context.py

View check run for this annotation

Codecov / codecov/patch

daft/context.py#L214-L216

Added lines #L214 - L216 were not covered by tests
default_io_config=default_io_config,
)

ctx.daft_planning_config = new_daft_planning_config
return ctx
ctx.daft_planning_config = new_daft_planning_config
return ctx

Check warning on line 221 in daft/context.py

View check run for this annotation

Codecov / codecov/patch

daft/context.py#L220-L221

Added lines #L220 - L221 were not covered by tests


def set_execution_config(
Expand Down Expand Up @@ -246,21 +266,22 @@ def set_execution_config(
"""
# Replace values in the DaftExecutionConfig with user-specified overrides
ctx = get_context()
old_daft_execution_config = ctx.daft_execution_config if config is None else config
new_daft_execution_config = old_daft_execution_config.with_config_values(
scan_tasks_min_size_bytes=scan_tasks_min_size_bytes,
scan_tasks_max_size_bytes=scan_tasks_max_size_bytes,
broadcast_join_size_bytes_threshold=broadcast_join_size_bytes_threshold,
parquet_split_row_groups_max_files=parquet_split_row_groups_max_files,
sort_merge_join_sort_with_aligned_boundaries=sort_merge_join_sort_with_aligned_boundaries,
sample_size_for_sort=sample_size_for_sort,
num_preview_rows=num_preview_rows,
parquet_target_filesize=parquet_target_filesize,
parquet_target_row_group_size=parquet_target_row_group_size,
parquet_inflation_factor=parquet_inflation_factor,
csv_target_filesize=csv_target_filesize,
csv_inflation_factor=csv_inflation_factor,
)

ctx.daft_execution_config = new_daft_execution_config
return ctx
with ctx._lock:
old_daft_execution_config = ctx.daft_execution_config if config is None else config
new_daft_execution_config = old_daft_execution_config.with_config_values(
scan_tasks_min_size_bytes=scan_tasks_min_size_bytes,
scan_tasks_max_size_bytes=scan_tasks_max_size_bytes,
broadcast_join_size_bytes_threshold=broadcast_join_size_bytes_threshold,
parquet_split_row_groups_max_files=parquet_split_row_groups_max_files,
sort_merge_join_sort_with_aligned_boundaries=sort_merge_join_sort_with_aligned_boundaries,
sample_size_for_sort=sample_size_for_sort,
num_preview_rows=num_preview_rows,
parquet_target_filesize=parquet_target_filesize,
parquet_target_row_group_size=parquet_target_row_group_size,
parquet_inflation_factor=parquet_inflation_factor,
csv_target_filesize=csv_target_filesize,
csv_inflation_factor=csv_inflation_factor,
)

ctx.daft_execution_config = new_daft_execution_config
return ctx

0 comments on commit d4554c6

Please sign in to comment.