From d4554c6c80f2b28d73091a7a4717e4c8d270dd66 Mon Sep 17 00:00:00 2001 From: Sammy Sidhu Date: Thu, 8 Feb 2024 20:12:34 -0800 Subject: [PATCH] protect global context with mutex --- daft/context.py | 115 ++++++++++++++++++++++++++++-------------------- 1 file changed, 68 insertions(+), 47 deletions(-) diff --git a/daft/context.py b/daft/context.py index 84e219e27f..d274fe49bf 100644 --- a/daft/context.py +++ b/daft/context.py @@ -13,7 +13,7 @@ logger = logging.getLogger(__name__) - +import threading class _RunnerConfig: name = ClassVar[str] @@ -68,7 +68,25 @@ class DaftContext: disallow_set_runner: bool = False _runner: Runner | None = None + _instance = None + _lock = threading.Lock() + + def __new__(cls): + if cls._instance is None: + with cls._lock: + # Another thread could have created the instance + # before we acquired the lock. So check that the + # instance is still nonexistent. + if not cls._instance: + cls._instance = super().__new__(cls) + return cls._instance + + def runner(self) -> Runner: + with self._lock: + return self._get_runner() + + def _get_runner(self) -> Runner: if self._runner is not None: return self._runner @@ -110,16 +128,15 @@ def runner(self) -> Runner: @property def is_ray_runner(self) -> bool: - return isinstance(self.runner_config, _RayRunnerConfig) + with self._lock: + return isinstance(self.runner_config, _RayRunnerConfig) _DaftContext = DaftContext() - def get_context() -> DaftContext: return _DaftContext - def set_runner_ray( address: str | None = None, noop_if_initialized: bool = False, @@ -144,20 +161,21 @@ def set_runner_ray( DaftContext: Daft context after setting the Ray runner """ ctx = get_context() - if ctx.disallow_set_runner: - if noop_if_initialized: - warnings.warn( - "Calling daft.context.set_runner_ray(noop_if_initialized=True) multiple times has no effect beyond the first call." - ) - return ctx - raise RuntimeError("Cannot set runner more than once") - - ctx.runner_config = _RayRunnerConfig( - address=address, - max_task_backlog=max_task_backlog, - ) - ctx.disallow_set_runner = True - return ctx + with ctx._lock: + if ctx.disallow_set_runner: + if noop_if_initialized: + warnings.warn( + "Calling daft.context.set_runner_ray(noop_if_initialized=True) multiple times has no effect beyond the first call." + ) + return ctx + raise RuntimeError("Cannot set runner more than once") + + ctx.runner_config = _RayRunnerConfig( + address=address, + max_task_backlog=max_task_backlog, + ) + ctx.disallow_set_runner = True + return ctx def set_runner_py(use_thread_pool: bool | None = None) -> DaftContext: @@ -169,12 +187,13 @@ def set_runner_py(use_thread_pool: bool | None = None) -> DaftContext: DaftContext: Daft context after setting the Py runner """ ctx = get_context() - if ctx.disallow_set_runner: - raise RuntimeError("Cannot set runner more than once") + with ctx._lock: + if ctx.disallow_set_runner: + raise RuntimeError("Cannot set runner more than once") - ctx.runner_config = _PyRunnerConfig(use_thread_pool=use_thread_pool) - ctx.disallow_set_runner = True - return ctx + ctx.runner_config = _PyRunnerConfig(use_thread_pool=use_thread_pool) + ctx.disallow_set_runner = True + return ctx def set_planning_config( @@ -192,13 +211,14 @@ def set_planning_config( """ # Replace values in the DaftPlanningConfig with user-specified overrides ctx = get_context() - old_daft_planning_config = ctx.daft_planning_config if config is None else config - new_daft_planning_config = old_daft_planning_config.with_config_values( - default_io_config=default_io_config, - ) + with ctx._lock: + old_daft_planning_config = ctx.daft_planning_config if config is None else config + new_daft_planning_config = old_daft_planning_config.with_config_values( + default_io_config=default_io_config, + ) - ctx.daft_planning_config = new_daft_planning_config - return ctx + ctx.daft_planning_config = new_daft_planning_config + return ctx def set_execution_config( @@ -246,21 +266,22 @@ def set_execution_config( """ # Replace values in the DaftExecutionConfig with user-specified overrides ctx = get_context() - old_daft_execution_config = ctx.daft_execution_config if config is None else config - new_daft_execution_config = old_daft_execution_config.with_config_values( - scan_tasks_min_size_bytes=scan_tasks_min_size_bytes, - scan_tasks_max_size_bytes=scan_tasks_max_size_bytes, - broadcast_join_size_bytes_threshold=broadcast_join_size_bytes_threshold, - parquet_split_row_groups_max_files=parquet_split_row_groups_max_files, - sort_merge_join_sort_with_aligned_boundaries=sort_merge_join_sort_with_aligned_boundaries, - sample_size_for_sort=sample_size_for_sort, - num_preview_rows=num_preview_rows, - parquet_target_filesize=parquet_target_filesize, - parquet_target_row_group_size=parquet_target_row_group_size, - parquet_inflation_factor=parquet_inflation_factor, - csv_target_filesize=csv_target_filesize, - csv_inflation_factor=csv_inflation_factor, - ) - - ctx.daft_execution_config = new_daft_execution_config - return ctx + with ctx._lock: + old_daft_execution_config = ctx.daft_execution_config if config is None else config + new_daft_execution_config = old_daft_execution_config.with_config_values( + scan_tasks_min_size_bytes=scan_tasks_min_size_bytes, + scan_tasks_max_size_bytes=scan_tasks_max_size_bytes, + broadcast_join_size_bytes_threshold=broadcast_join_size_bytes_threshold, + parquet_split_row_groups_max_files=parquet_split_row_groups_max_files, + sort_merge_join_sort_with_aligned_boundaries=sort_merge_join_sort_with_aligned_boundaries, + sample_size_for_sort=sample_size_for_sort, + num_preview_rows=num_preview_rows, + parquet_target_filesize=parquet_target_filesize, + parquet_target_row_group_size=parquet_target_row_group_size, + parquet_inflation_factor=parquet_inflation_factor, + csv_target_filesize=csv_target_filesize, + csv_inflation_factor=csv_inflation_factor, + ) + + ctx.daft_execution_config = new_daft_execution_config + return ctx