From 9b6cb946471b45a2b9f98e0ed233f624db348bdf Mon Sep 17 00:00:00 2001 From: Sammy Sidhu Date: Fri, 9 Feb 2024 12:58:03 -0800 Subject: [PATCH] [BUG] Protect Global Context With Mutex (#1857) * Fixes a race condition that occurs when Daft is run with multiple threads and creates multiple Runners. This leads to cached partition sets being dropped, as seen in https://github.com/Eventual-Inc/Daft/issues/1843 closes: https://github.com/Eventual-Inc/Daft/issues/1843 --- daft/context.py | 153 ++++++++++++++++++++++------------- daft/runners/partitioning.py | 28 ++++--- daft/runners/pyrunner.py | 6 +- daft/runners/ray_runner.py | 7 +- 4 files changed, 118 insertions(+), 76 deletions(-) diff --git a/daft/context.py b/daft/context.py index 84e219e27f..d12e34fd71 100644 --- a/daft/context.py +++ b/daft/context.py @@ -13,6 +13,8 @@ logger = logging.getLogger(__name__) +import threading + class _RunnerConfig: name = ClassVar[str] @@ -59,28 +61,60 @@ class DaftContext: # When a dataframe is executed, this config is copied into the Runner # which then keeps track of a per-unique-execution-ID copy of the config, using it consistently throughout the execution - daft_execution_config: PyDaftExecutionConfig = PyDaftExecutionConfig() + _daft_execution_config: PyDaftExecutionConfig = PyDaftExecutionConfig() # Non-execution calls (e.g. 
creation of a dataframe, logical plan building etc) directly reference values in this config - daft_planning_config: PyDaftPlanningConfig = PyDaftPlanningConfig() + _daft_planning_config: PyDaftPlanningConfig = PyDaftPlanningConfig() - runner_config: _RunnerConfig = dataclasses.field(default_factory=_get_runner_config_from_env) - disallow_set_runner: bool = False + _runner_config: _RunnerConfig = dataclasses.field(default_factory=_get_runner_config_from_env) + _disallow_set_runner: bool = False _runner: Runner | None = None + _instance: ClassVar[DaftContext | None] = None + _lock: ClassVar[threading.Lock] = threading.Lock() + + def __new__(cls): + if cls._instance is None: + with cls._lock: + # Another thread could have created the instance + # before we acquired the lock. So check that the + # instance is still nonexistent. + if not cls._instance: + cls._instance = super().__new__(cls) + return cls._instance + def runner(self) -> Runner: + with self._lock: + return self._get_runner() + + @property + def daft_execution_config(self) -> PyDaftExecutionConfig: + with self._lock: + return self._daft_execution_config + + @property + def daft_planning_config(self) -> PyDaftPlanningConfig: + with self._lock: + return self._daft_planning_config + + @property + def runner_config(self) -> _RunnerConfig: + with self._lock: + return self._runner_config + + def _get_runner(self) -> Runner: if self._runner is not None: return self._runner - if self.runner_config.name == "ray": + if self._runner_config.name == "ray": from daft.runners.ray_runner import RayRunner - assert isinstance(self.runner_config, _RayRunnerConfig) + assert isinstance(self._runner_config, _RayRunnerConfig) self._runner = RayRunner( - address=self.runner_config.address, - max_task_backlog=self.runner_config.max_task_backlog, + address=self._runner_config.address, + max_task_backlog=self._runner_config.max_task_backlog, ) - elif self.runner_config.name == "py": + elif self._runner_config.name == "py": from 
daft.runners.pyrunner import PyRunner try: @@ -96,21 +130,22 @@ def runner(self) -> Runner: except ImportError: pass - assert isinstance(self.runner_config, _PyRunnerConfig) - self._runner = PyRunner(use_thread_pool=self.runner_config.use_thread_pool) + assert isinstance(self._runner_config, _PyRunnerConfig) + self._runner = PyRunner(use_thread_pool=self._runner_config.use_thread_pool) else: - raise NotImplementedError(f"Runner config implemented: {self.runner_config.name}") + raise NotImplementedError(f"Runner config implemented: {self._runner_config.name}") # Mark DaftContext as having the runner set, which prevents any subsequent setting of the config # after the runner has been initialized once - self.disallow_set_runner = True + self._disallow_set_runner = True return self._runner @property def is_ray_runner(self) -> bool: - return isinstance(self.runner_config, _RayRunnerConfig) + with self._lock: + return isinstance(self._runner_config, _RayRunnerConfig) _DaftContext = DaftContext() @@ -144,20 +179,21 @@ def set_runner_ray( DaftContext: Daft context after setting the Ray runner """ ctx = get_context() - if ctx.disallow_set_runner: - if noop_if_initialized: - warnings.warn( - "Calling daft.context.set_runner_ray(noop_if_initialized=True) multiple times has no effect beyond the first call." - ) - return ctx - raise RuntimeError("Cannot set runner more than once") - - ctx.runner_config = _RayRunnerConfig( - address=address, - max_task_backlog=max_task_backlog, - ) - ctx.disallow_set_runner = True - return ctx + with ctx._lock: + if ctx._disallow_set_runner: + if noop_if_initialized: + warnings.warn( + "Calling daft.context.set_runner_ray(noop_if_initialized=True) multiple times has no effect beyond the first call." 
+ ) + return ctx + raise RuntimeError("Cannot set runner more than once") + + ctx._runner_config = _RayRunnerConfig( + address=address, + max_task_backlog=max_task_backlog, + ) + ctx._disallow_set_runner = True + return ctx def set_runner_py(use_thread_pool: bool | None = None) -> DaftContext: @@ -169,12 +205,13 @@ def set_runner_py(use_thread_pool: bool | None = None) -> DaftContext: DaftContext: Daft context after setting the Py runner """ ctx = get_context() - if ctx.disallow_set_runner: - raise RuntimeError("Cannot set runner more than once") + with ctx._lock: + if ctx._disallow_set_runner: + raise RuntimeError("Cannot set runner more than once") - ctx.runner_config = _PyRunnerConfig(use_thread_pool=use_thread_pool) - ctx.disallow_set_runner = True - return ctx + ctx._runner_config = _PyRunnerConfig(use_thread_pool=use_thread_pool) + ctx._disallow_set_runner = True + return ctx def set_planning_config( @@ -192,13 +229,14 @@ def set_planning_config( """ # Replace values in the DaftPlanningConfig with user-specified overrides ctx = get_context() - old_daft_planning_config = ctx.daft_planning_config if config is None else config - new_daft_planning_config = old_daft_planning_config.with_config_values( - default_io_config=default_io_config, - ) + with ctx._lock: + old_daft_planning_config = ctx._daft_planning_config if config is None else config + new_daft_planning_config = old_daft_planning_config.with_config_values( + default_io_config=default_io_config, + ) - ctx.daft_planning_config = new_daft_planning_config - return ctx + ctx._daft_planning_config = new_daft_planning_config + return ctx def set_execution_config( @@ -246,21 +284,22 @@ def set_execution_config( """ # Replace values in the DaftExecutionConfig with user-specified overrides ctx = get_context() - old_daft_execution_config = ctx.daft_execution_config if config is None else config - new_daft_execution_config = old_daft_execution_config.with_config_values( - 
scan_tasks_min_size_bytes=scan_tasks_min_size_bytes, - scan_tasks_max_size_bytes=scan_tasks_max_size_bytes, - broadcast_join_size_bytes_threshold=broadcast_join_size_bytes_threshold, - parquet_split_row_groups_max_files=parquet_split_row_groups_max_files, - sort_merge_join_sort_with_aligned_boundaries=sort_merge_join_sort_with_aligned_boundaries, - sample_size_for_sort=sample_size_for_sort, - num_preview_rows=num_preview_rows, - parquet_target_filesize=parquet_target_filesize, - parquet_target_row_group_size=parquet_target_row_group_size, - parquet_inflation_factor=parquet_inflation_factor, - csv_target_filesize=csv_target_filesize, - csv_inflation_factor=csv_inflation_factor, - ) - - ctx.daft_execution_config = new_daft_execution_config - return ctx + with ctx._lock: + old_daft_execution_config = ctx._daft_execution_config if config is None else config + new_daft_execution_config = old_daft_execution_config.with_config_values( + scan_tasks_min_size_bytes=scan_tasks_min_size_bytes, + scan_tasks_max_size_bytes=scan_tasks_max_size_bytes, + broadcast_join_size_bytes_threshold=broadcast_join_size_bytes_threshold, + parquet_split_row_groups_max_files=parquet_split_row_groups_max_files, + sort_merge_join_sort_with_aligned_boundaries=sort_merge_join_sort_with_aligned_boundaries, + sample_size_for_sort=sample_size_for_sort, + num_preview_rows=num_preview_rows, + parquet_target_filesize=parquet_target_filesize, + parquet_target_row_group_size=parquet_target_row_group_size, + parquet_inflation_factor=parquet_inflation_factor, + csv_target_filesize=csv_target_filesize, + csv_inflation_factor=csv_inflation_factor, + ) + + ctx._daft_execution_config = new_daft_execution_config + return ctx diff --git a/daft/runners/partitioning.py b/daft/runners/partitioning.py index e43d6eaee1..8836a3bc5c 100644 --- a/daft/runners/partitioning.py +++ b/daft/runners/partitioning.py @@ -1,6 +1,7 @@ from __future__ import annotations import sys +import threading import weakref from abc import 
abstractmethod from dataclasses import dataclass @@ -296,24 +297,33 @@ def size_bytes(self) -> int | None: class PartitionSetCache: def __init__(self) -> None: - self._uuid_to_partition_set: weakref.WeakValueDictionary[ + self.__uuid_to_partition_set: weakref.WeakValueDictionary[ str, PartitionCacheEntry ] = weakref.WeakValueDictionary() + self._lock = threading.Lock() def get_partition_set(self, pset_id: str) -> PartitionCacheEntry: - assert pset_id in self._uuid_to_partition_set - return self._uuid_to_partition_set[pset_id] + with self._lock: + assert pset_id in self.__uuid_to_partition_set + return self.__uuid_to_partition_set[pset_id] + + def get_all_partition_sets(self) -> dict[str, PartitionSet]: + with self._lock: + return {key: entry.value for key, entry in self.__uuid_to_partition_set.items() if entry.value is not None} def put_partition_set(self, pset: PartitionSet) -> PartitionCacheEntry: pset_id = uuid4().hex part_entry = PartitionCacheEntry(pset_id, pset) - self._uuid_to_partition_set[pset_id] = part_entry - return part_entry + with self._lock: + self.__uuid_to_partition_set[pset_id] = part_entry + return part_entry def rm(self, pset_id: str) -> None: - if pset_id in self._uuid_to_partition_set: - del self._uuid_to_partition_set[pset_id] + with self._lock: + if pset_id in self.__uuid_to_partition_set: + del self.__uuid_to_partition_set[pset_id] def clear(self) -> None: - del self._uuid_to_partition_set - self._uuid_to_partition_set = weakref.WeakValueDictionary() + with self._lock: + del self.__uuid_to_partition_set + self.__uuid_to_partition_set = weakref.WeakValueDictionary() diff --git a/daft/runners/pyrunner.py b/daft/runners/pyrunner.py index 6b6a7e98ba..2be28e4c54 100644 --- a/daft/runners/pyrunner.py +++ b/daft/runners/pyrunner.py @@ -148,11 +148,7 @@ def run_iter( # Finalize the logical plan and get a physical plan scheduler for translating the # physical plan to executable tasks. 
plan_scheduler = builder.to_physical_plan_scheduler(daft_execution_config) - psets = { - key: entry.value.values() - for key, entry in self._part_set_cache._uuid_to_partition_set.items() - if entry.value is not None - } + psets = {k: v.values() for k, v in self._part_set_cache.get_all_partition_sets().items()} # Get executable tasks from planner. tasks = plan_scheduler.to_partition_tasks(psets, is_ray_runner=False) with profiler("profile_PyRunner.run_{datetime.now().isoformat()}.json"): diff --git a/daft/runners/ray_runner.py b/daft/runners/ray_runner.py index d765a313d5..f8803995c0 100644 --- a/daft/runners/ray_runner.py +++ b/daft/runners/ray_runner.py @@ -727,11 +727,8 @@ def run_iter( # physical plan to executable tasks. plan_scheduler = builder.to_physical_plan_scheduler(daft_execution_config) - psets = { - key: entry.value.values() - for key, entry in self._part_set_cache._uuid_to_partition_set.items() - if entry.value is not None - } + psets = {k: v.values() for k, v in self._part_set_cache.get_all_partition_sets().items()} + result_uuid = str(uuid.uuid4()) if isinstance(self.ray_context, ray.client_builder.ClientContext): ray.get(