[BUG] Protect Global Context With Mutex (#1857)
* Fixes a race condition where Daft, when run with multiple threads, could create multiple Runners. This caused cached partition sets to be dropped, as seen in #1843.

closes: #1843
samster25 authored Feb 9, 2024
1 parent 573ea12 commit 9b6cb94
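
For context, here is a minimal sketch (not Daft code; all names are hypothetical) of the kind of check-then-act race this commit closes: without a lock, two threads can both observe that no instance exists yet and each construct their own, so state attached to the "losing" instance, such as a cache of partition sets, is silently dropped.

import threading

class UnsafeSingleton:
    _instance = None

    def __new__(cls):
        # Unsynchronized check-then-act: two threads can both see None
        # here and each construct a separate "singleton".
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

created = []
threads = [threading.Thread(target=lambda: created.append(id(UnsafeSingleton()))) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
# With unlucky scheduling this can print more than one id.
print(set(created))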
Showing 4 changed files with 118 additions and 76 deletions.
153 changes: 96 additions & 57 deletions daft/context.py
@@ -13,6 +13,8 @@

logger = logging.getLogger(__name__)

+import threading
+

class _RunnerConfig:
    name = ClassVar[str]
@@ -59,28 +61,60 @@ class DaftContext:

    # When a dataframe is executed, this config is copied into the Runner
    # which then keeps track of a per-unique-execution-ID copy of the config, using it consistently throughout the execution
-    daft_execution_config: PyDaftExecutionConfig = PyDaftExecutionConfig()
+    _daft_execution_config: PyDaftExecutionConfig = PyDaftExecutionConfig()

    # Non-execution calls (e.g. creation of a dataframe, logical plan building etc) directly reference values in this config
-    daft_planning_config: PyDaftPlanningConfig = PyDaftPlanningConfig()
+    _daft_planning_config: PyDaftPlanningConfig = PyDaftPlanningConfig()

-    runner_config: _RunnerConfig = dataclasses.field(default_factory=_get_runner_config_from_env)
-    disallow_set_runner: bool = False
+    _runner_config: _RunnerConfig = dataclasses.field(default_factory=_get_runner_config_from_env)
+    _disallow_set_runner: bool = False
    _runner: Runner | None = None

+    _instance: ClassVar[DaftContext | None] = None
+    _lock: ClassVar[threading.Lock] = threading.Lock()
+
+    def __new__(cls):
+        if cls._instance is None:
+            with cls._lock:
+                # Another thread could have created the instance
+                # before we acquired the lock. So check that the
+                # instance is still nonexistent.
+                if not cls._instance:
+                    cls._instance = super().__new__(cls)
+        return cls._instance
+
    def runner(self) -> Runner:
+        with self._lock:
+            return self._get_runner()
+
+    @property
+    def daft_execution_config(self) -> PyDaftExecutionConfig:
+        with self._lock:
+            return self._daft_execution_config
+
+    @property
+    def daft_planning_config(self) -> PyDaftPlanningConfig:
+        with self._lock:
+            return self._daft_planning_config
+
+    @property
+    def runner_config(self) -> _RunnerConfig:
+        with self._lock:
+            return self._runner_config
+
+    def _get_runner(self) -> Runner:
        if self._runner is not None:
            return self._runner

-        if self.runner_config.name == "ray":
+        if self._runner_config.name == "ray":
            from daft.runners.ray_runner import RayRunner

-            assert isinstance(self.runner_config, _RayRunnerConfig)
+            assert isinstance(self._runner_config, _RayRunnerConfig)
            self._runner = RayRunner(
-                address=self.runner_config.address,
-                max_task_backlog=self.runner_config.max_task_backlog,
+                address=self._runner_config.address,
+                max_task_backlog=self._runner_config.max_task_backlog,
            )
-        elif self.runner_config.name == "py":
+        elif self._runner_config.name == "py":
            from daft.runners.pyrunner import PyRunner

            try:
@@ -96,21 +130,22 @@ def runner(self) -> Runner:
            except ImportError:
                pass

-            assert isinstance(self.runner_config, _PyRunnerConfig)
-            self._runner = PyRunner(use_thread_pool=self.runner_config.use_thread_pool)
+            assert isinstance(self._runner_config, _PyRunnerConfig)
+            self._runner = PyRunner(use_thread_pool=self._runner_config.use_thread_pool)

        else:
-            raise NotImplementedError(f"Runner config implemented: {self.runner_config.name}")
+            raise NotImplementedError(f"Runner config implemented: {self._runner_config.name}")

        # Mark DaftContext as having the runner set, which prevents any subsequent setting of the config
        # after the runner has been initialized once
-        self.disallow_set_runner = True
+        self._disallow_set_runner = True

        return self._runner

    @property
    def is_ray_runner(self) -> bool:
-        return isinstance(self.runner_config, _RayRunnerConfig)
+        with self._lock:
+            return isinstance(self._runner_config, _RayRunnerConfig)


_DaftContext = DaftContext()
@@ -144,20 +179,21 @@ def set_runner_ray(
        DaftContext: Daft context after setting the Ray runner
    """
    ctx = get_context()
-    if ctx.disallow_set_runner:
-        if noop_if_initialized:
-            warnings.warn(
-                "Calling daft.context.set_runner_ray(noop_if_initialized=True) multiple times has no effect beyond the first call."
-            )
-            return ctx
-        raise RuntimeError("Cannot set runner more than once")
-
-    ctx.runner_config = _RayRunnerConfig(
-        address=address,
-        max_task_backlog=max_task_backlog,
-    )
-    ctx.disallow_set_runner = True
-    return ctx
+    with ctx._lock:
+        if ctx._disallow_set_runner:
+            if noop_if_initialized:
+                warnings.warn(
+                    "Calling daft.context.set_runner_ray(noop_if_initialized=True) multiple times has no effect beyond the first call."
+                )
+                return ctx
+            raise RuntimeError("Cannot set runner more than once")
+
+        ctx._runner_config = _RayRunnerConfig(
+            address=address,
+            max_task_backlog=max_task_backlog,
+        )
+        ctx._disallow_set_runner = True
+        return ctx


def set_runner_py(use_thread_pool: bool | None = None) -> DaftContext:
@@ -169,12 +205,13 @@ def set_runner_py(use_thread_pool: bool | None = None) -> DaftContext:
        DaftContext: Daft context after setting the Py runner
    """
    ctx = get_context()
-    if ctx.disallow_set_runner:
-        raise RuntimeError("Cannot set runner more than once")
+    with ctx._lock:
+        if ctx._disallow_set_runner:
+            raise RuntimeError("Cannot set runner more than once")

-    ctx.runner_config = _PyRunnerConfig(use_thread_pool=use_thread_pool)
-    ctx.disallow_set_runner = True
-    return ctx
+        ctx._runner_config = _PyRunnerConfig(use_thread_pool=use_thread_pool)
+        ctx._disallow_set_runner = True
+        return ctx


def set_planning_config(
@@ -192,13 +229,14 @@ def set_planning_config(
"""
# Replace values in the DaftPlanningConfig with user-specified overrides
ctx = get_context()
old_daft_planning_config = ctx.daft_planning_config if config is None else config
new_daft_planning_config = old_daft_planning_config.with_config_values(
default_io_config=default_io_config,
)
with ctx._lock:
old_daft_planning_config = ctx._daft_planning_config if config is None else config
new_daft_planning_config = old_daft_planning_config.with_config_values(
default_io_config=default_io_config,
)

ctx.daft_planning_config = new_daft_planning_config
return ctx
ctx._daft_planning_config = new_daft_planning_config
return ctx


def set_execution_config(
@@ -246,21 +284,22 @@ def set_execution_config(
"""
# Replace values in the DaftExecutionConfig with user-specified overrides
ctx = get_context()
old_daft_execution_config = ctx.daft_execution_config if config is None else config
new_daft_execution_config = old_daft_execution_config.with_config_values(
scan_tasks_min_size_bytes=scan_tasks_min_size_bytes,
scan_tasks_max_size_bytes=scan_tasks_max_size_bytes,
broadcast_join_size_bytes_threshold=broadcast_join_size_bytes_threshold,
parquet_split_row_groups_max_files=parquet_split_row_groups_max_files,
sort_merge_join_sort_with_aligned_boundaries=sort_merge_join_sort_with_aligned_boundaries,
sample_size_for_sort=sample_size_for_sort,
num_preview_rows=num_preview_rows,
parquet_target_filesize=parquet_target_filesize,
parquet_target_row_group_size=parquet_target_row_group_size,
parquet_inflation_factor=parquet_inflation_factor,
csv_target_filesize=csv_target_filesize,
csv_inflation_factor=csv_inflation_factor,
)

ctx.daft_execution_config = new_daft_execution_config
return ctx
with ctx._lock:
old_daft_execution_config = ctx._daft_execution_config if config is None else config
new_daft_execution_config = old_daft_execution_config.with_config_values(
scan_tasks_min_size_bytes=scan_tasks_min_size_bytes,
scan_tasks_max_size_bytes=scan_tasks_max_size_bytes,
broadcast_join_size_bytes_threshold=broadcast_join_size_bytes_threshold,
parquet_split_row_groups_max_files=parquet_split_row_groups_max_files,
sort_merge_join_sort_with_aligned_boundaries=sort_merge_join_sort_with_aligned_boundaries,
sample_size_for_sort=sample_size_for_sort,
num_preview_rows=num_preview_rows,
parquet_target_filesize=parquet_target_filesize,
parquet_target_row_group_size=parquet_target_row_group_size,
parquet_inflation_factor=parquet_inflation_factor,
csv_target_filesize=csv_target_filesize,
csv_inflation_factor=csv_inflation_factor,
)

ctx._daft_execution_config = new_daft_execution_config
return ctx
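
The __new__ in the diff above is the classic double-checked locking pattern: the unlocked check keeps the common path cheap once the instance exists, and the second check under the lock handles the thread that lost the race. A standalone sketch of the pattern (illustrative only, not the Daft implementation):

import threading

class Singleton:
    _instance = None
    _lock = threading.Lock()

    def __new__(cls):
        if cls._instance is None:  # fast path: no locking once initialized
            with cls._lock:
                if cls._instance is None:  # re-check: another thread may have won
                    cls._instance = super().__new__(cls)
        return cls._instance

a, b = Singleton(), Singleton()
assert a is b  # every construction returns the same object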
28 changes: 19 additions & 9 deletions daft/runners/partitioning.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import sys
+import threading
import weakref
from abc import abstractmethod
from dataclasses import dataclass
@@ -296,24 +297,33 @@ def size_bytes(self) -> int | None:

class PartitionSetCache:
    def __init__(self) -> None:
-        self._uuid_to_partition_set: weakref.WeakValueDictionary[
+        self.__uuid_to_partition_set: weakref.WeakValueDictionary[
            str, PartitionCacheEntry
        ] = weakref.WeakValueDictionary()
+        self._lock = threading.Lock()

    def get_partition_set(self, pset_id: str) -> PartitionCacheEntry:
-        assert pset_id in self._uuid_to_partition_set
-        return self._uuid_to_partition_set[pset_id]
+        with self._lock:
+            assert pset_id in self.__uuid_to_partition_set
+            return self.__uuid_to_partition_set[pset_id]
+
+    def get_all_partition_sets(self) -> dict[str, PartitionSet]:
+        with self._lock:
+            return {key: entry.value for key, entry in self.__uuid_to_partition_set.items() if entry.value is not None}

    def put_partition_set(self, pset: PartitionSet) -> PartitionCacheEntry:
        pset_id = uuid4().hex
        part_entry = PartitionCacheEntry(pset_id, pset)
-        self._uuid_to_partition_set[pset_id] = part_entry
-        return part_entry
+        with self._lock:
+            self.__uuid_to_partition_set[pset_id] = part_entry
+            return part_entry

    def rm(self, pset_id: str) -> None:
-        if pset_id in self._uuid_to_partition_set:
-            del self._uuid_to_partition_set[pset_id]
+        with self._lock:
+            if pset_id in self.__uuid_to_partition_set:
+                del self.__uuid_to_partition_set[pset_id]

    def clear(self) -> None:
-        del self._uuid_to_partition_set
-        self._uuid_to_partition_set = weakref.WeakValueDictionary()
+        with self._lock:
+            del self.__uuid_to_partition_set
+            self.__uuid_to_partition_set = weakref.WeakValueDictionary()
6 changes: 1 addition & 5 deletions daft/runners/pyrunner.py
@@ -148,11 +148,7 @@ def run_iter(
        # Finalize the logical plan and get a physical plan scheduler for translating the
        # physical plan to executable tasks.
        plan_scheduler = builder.to_physical_plan_scheduler(daft_execution_config)
-        psets = {
-            key: entry.value.values()
-            for key, entry in self._part_set_cache._uuid_to_partition_set.items()
-            if entry.value is not None
-        }
+        psets = {k: v.values() for k, v in self._part_set_cache.get_all_partition_sets().items()}
        # Get executable tasks from planner.
        tasks = plan_scheduler.to_partition_tasks(psets, is_ray_runner=False)
        with profiler("profile_PyRunner.run_{datetime.now().isoformat()}.json"):
7 changes: 2 additions & 5 deletions daft/runners/ray_runner.py
@@ -727,11 +727,8 @@ def run_iter(
        # physical plan to executable tasks.
        plan_scheduler = builder.to_physical_plan_scheduler(daft_execution_config)

-        psets = {
-            key: entry.value.values()
-            for key, entry in self._part_set_cache._uuid_to_partition_set.items()
-            if entry.value is not None
-        }
+        psets = {k: v.values() for k, v in self._part_set_cache.get_all_partition_sets().items()}

        result_uuid = str(uuid.uuid4())
        if isinstance(self.ray_context, ray.client_builder.ClientContext):
            ray.get(
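
Both runners previously reached into the cache's private dictionary and iterated it directly; they now call get_all_partition_sets(), which copies the live entries into a plain dict while holding the lock. Snapshotting under a lock is a standard way to let callers iterate safely while other threads mutate the cache; a generic sketch of the idea (not Daft code):

import threading
import weakref

class SnapshotCache:
    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._entries: weakref.WeakValueDictionary = weakref.WeakValueDictionary()

    def snapshot(self) -> dict:
        # Copy under the lock so callers iterate a stable plain dict
        # instead of a weak dict that can change beneath them.
        with self._lock:
            return dict(self._entries.items())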
