From 4b67b852b33057f0f2e6db44e585d5b0b2ce3273 Mon Sep 17 00:00:00 2001
From: Jay Chia <17691182+jaychia@users.noreply.github.com>
Date: Wed, 19 Jun 2024 11:26:11 -0700
Subject: [PATCH] [FEAT] Automatically use Ray Runner if Ray is initialized
(#2282)
1. Automatically switches Daft to use the Ray Runner if a user calls
`ray.init(...)` before running any Daft queries
2. Also begins deprecating the `DAFT_RAY_ADDRESS` environment variable
(it is still honored, but now emits a warning) so that we can
standardize on the normal `RAY_ADDRESS` behavior
This PR ensures the following behavior:
* If a user explicitly calls `daft.context.set_runner_ray` or
`daft.context.set_runner_py`, this overrides all other behavior
* If a user calls daft.context.set_runner_ray with a specified address,
but Ray is already initialized, we warn them that their address is being
ignored
* Otherwise, on first execution Daft will attempt to retrieve the runner
config from the current environment, checking in this order:
  1. The `DAFT_RUNNER` environment variable, if set to `RAY` or `PY`
     (any other value raises an error)
  2. Whether Ray is initialized and we aren't running inside a Ray
     worker: use `RAY`
  3. Otherwise, fall back to `PY`
Ray connection detection, on driver vs on worker:
Warning if set_runner_ray is called with an address after Ray is already
initialized:
---------
Co-authored-by: Jay Chia
---
daft/context.py | 83 +++++++++++++------
daft/runners/ray_runner.py | 23 +++--
.../text_to_image/using_cloud_with_ray.ipynb | 2 +-
3 files changed, 73 insertions(+), 35 deletions(-)
diff --git a/daft/context.py b/daft/context.py
index 38dcd17501..aa98ff23fc 100644
--- a/daft/context.py
+++ b/daft/context.py
@@ -39,20 +39,54 @@ def _get_runner_config_from_env() -> _RunnerConfig:
To use:
1. PyRunner: set DAFT_RUNNER=py
- 2. RayRunner: set DAFT_RUNNER=ray and optionally DAFT_RAY_ADDRESS=ray://...
+ 2. RayRunner: set DAFT_RUNNER=ray and optionally RAY_ADDRESS=ray://...
"""
- runner = os.getenv("DAFT_RUNNER") or "PY"
- if runner.upper() == "RAY":
- task_backlog_env = os.getenv("DAFT_DEVELOPER_RAY_MAX_TASK_BACKLOG")
+ runner_from_envvar = os.getenv("DAFT_RUNNER")
+ task_backlog_env = os.getenv("DAFT_DEVELOPER_RAY_MAX_TASK_BACKLOG")
+ use_thread_pool_env = os.getenv("DAFT_DEVELOPER_USE_THREAD_POOL")
+ use_thread_pool = bool(int(use_thread_pool_env)) if use_thread_pool_env is not None else None
+
+ ray_is_initialized = False
+ in_ray_worker = False
+ try:
+ import ray
+
+ if ray.is_initialized():
+ ray_is_initialized = True
+ # Check if running inside a Ray worker
+ if ray._private.worker.global_worker.mode == ray.WORKER_MODE:
+ in_ray_worker = True
+ except ImportError:
+ pass
+
+ # Retrieve the runner from environment variables
+ if runner_from_envvar and runner_from_envvar.upper() == "RAY":
+ ray_address = os.getenv("DAFT_RAY_ADDRESS")
+ if ray_address is not None:
+ warnings.warn(
+ "Detected usage of the $DAFT_RAY_ADDRESS environment variable. This will be deprecated, please use $RAY_ADDRESS instead."
+ )
+ else:
+ ray_address = os.getenv("RAY_ADDRESS")
+ return _RayRunnerConfig(
+ address=ray_address,
+ max_task_backlog=int(task_backlog_env) if task_backlog_env else None,
+ )
+ elif runner_from_envvar and runner_from_envvar.upper() == "PY":
+ return _PyRunnerConfig(use_thread_pool=use_thread_pool)
+ elif runner_from_envvar is not None:
+ raise ValueError(f"Unsupported DAFT_RUNNER variable: {runner_from_envvar}")
+
+ # Retrieve the runner from current initialized Ray environment, only if not running in a Ray worker
+ elif ray_is_initialized and not in_ray_worker:
return _RayRunnerConfig(
- address=os.getenv("DAFT_RAY_ADDRESS"),
+ address=None, # No address supplied, use the existing connection
max_task_backlog=int(task_backlog_env) if task_backlog_env else None,
)
- elif runner.upper() == "PY":
- use_thread_pool_env = os.getenv("DAFT_DEVELOPER_USE_THREAD_POOL")
- use_thread_pool = bool(int(use_thread_pool_env)) if use_thread_pool_env is not None else None
+
+ # Fall back on PyRunner
+ else:
return _PyRunnerConfig(use_thread_pool=use_thread_pool)
- raise ValueError(f"Unsupported DAFT_RUNNER variable: {runner}")
@dataclasses.dataclass
@@ -66,7 +100,7 @@ class DaftContext:
# Non-execution calls (e.g. creation of a dataframe, logical plan building etc) directly reference values in this config
_daft_planning_config: PyDaftPlanningConfig = PyDaftPlanningConfig()
- _runner_config: _RunnerConfig = dataclasses.field(default_factory=_get_runner_config_from_env)
+ _runner_config: _RunnerConfig | None = None
_disallow_set_runner: bool = False
_runner: Runner | None = None
@@ -100,13 +134,20 @@ def daft_planning_config(self) -> PyDaftPlanningConfig:
@property
def runner_config(self) -> _RunnerConfig:
with self._lock:
+ return self._get_runner_config()
+
+ def _get_runner_config(self) -> _RunnerConfig:
+ if self._runner_config is not None:
return self._runner_config
+ self._runner_config = _get_runner_config_from_env()
+ return self._runner_config
def _get_runner(self) -> Runner:
if self._runner is not None:
return self._runner
- if self._runner_config.name == "ray":
+ runner_config = self._get_runner_config()
+ if runner_config.name == "ray":
from daft.runners.ray_runner import RayRunner
assert isinstance(self._runner_config, _RayRunnerConfig)
@@ -114,27 +155,14 @@ def _get_runner(self) -> Runner:
address=self._runner_config.address,
max_task_backlog=self._runner_config.max_task_backlog,
)
- elif self._runner_config.name == "py":
+ elif runner_config.name == "py":
from daft.runners.pyrunner import PyRunner
- try:
- import ray
-
- if ray.is_initialized():
- logger.warning(
- "WARNING: Daft is NOT using Ray for execution!\n"
- "Daft is using the PyRunner but we detected an active Ray connection. "
- "If you intended to use the Daft RayRunner, please first run `daft.context.set_runner_ray()` "
- "before executing Daft queries."
- )
- except ImportError:
- pass
-
assert isinstance(self._runner_config, _PyRunnerConfig)
self._runner = PyRunner(use_thread_pool=self._runner_config.use_thread_pool)
else:
- raise NotImplementedError(f"Runner config implemented: {self._runner_config.name}")
+ raise NotImplementedError(f"Runner config not implemented: {runner_config.name}")
# Mark DaftContext as having the runner set, which prevents any subsequent setting of the config
# after the runner has been initialized once
@@ -165,7 +193,7 @@ def set_runner_ray(
Alternatively, users can set this behavior via environment variables:
1. DAFT_RUNNER=ray
- 2. Optionally, DAFT_RAY_ADDRESS=ray://...
+ 2. Optionally, RAY_ADDRESS=ray://...
**This function will throw an error if called multiple times in the same process.**
@@ -178,6 +206,7 @@ def set_runner_ray(
Returns:
DaftContext: Daft context after setting the Ray runner
"""
+
ctx = get_context()
with ctx._lock:
if ctx._disallow_set_runner:
diff --git a/daft/runners/ray_runner.py b/daft/runners/ray_runner.py
index d8208ce752..0c48f9da74 100644
--- a/daft/runners/ray_runner.py
+++ b/daft/runners/ray_runner.py
@@ -739,10 +739,19 @@ def __init__(
) -> None:
super().__init__()
if ray.is_initialized():
- logger.warning("Ray has already been initialized, Daft will reuse the existing Ray context.")
- self.ray_context = ray.init(address=address, ignore_reinit_error=True)
+ if address is not None:
+ logger.warning(
+ "Ray has already been initialized, Daft will reuse the existing Ray context and ignore the "
+ "supplied address: %s",
+ address,
+ )
+ else:
+ ray.init(address=address)
+
+ # Check if Ray is running in "client mode" (connected to a Ray cluster via a Ray client)
+ self.ray_client_mode = ray.util.client.ray.get_context().is_connected()
- if isinstance(self.ray_context, ray.client_builder.ClientContext):
+ if self.ray_client_mode:
# Run scheduler remotely if the cluster is connected remotely.
self.scheduler_actor = SchedulerActor.options( # type: ignore
name=SCHEDULER_ACTOR_NAME,
@@ -759,7 +768,7 @@ def __init__(
)
def active_plans(self) -> list[str]:
- if isinstance(self.ray_context, ray.client_builder.ClientContext):
+ if self.ray_client_mode:
return ray.get(self.scheduler_actor.active_plans.remote())
else:
return self.scheduler.active_plans()
@@ -772,7 +781,7 @@ def _start_plan(
) -> str:
psets = {k: v.values() for k, v in self._part_set_cache.get_all_partition_sets().items()}
result_uuid = str(uuid.uuid4())
- if isinstance(self.ray_context, ray.client_builder.ClientContext):
+ if self.ray_client_mode:
ray.get(
self.scheduler_actor.start_plan.remote(
daft_execution_config=daft_execution_config,
@@ -795,7 +804,7 @@ def _start_plan(
def _stream_plan(self, result_uuid: str) -> Iterator[RayMaterializedResult]:
try:
while True:
- if isinstance(self.ray_context, ray.client_builder.ClientContext):
+ if self.ray_client_mode:
result = ray.get(self.scheduler_actor.next.remote(result_uuid))
else:
result = self.scheduler.next(result_uuid)
@@ -808,7 +817,7 @@ def _stream_plan(self, result_uuid: str) -> Iterator[RayMaterializedResult]:
yield result
finally:
# Generator is out of scope, ensure that state has been cleaned up
- if isinstance(self.ray_context, ray.client_builder.ClientContext):
+ if self.ray_client_mode:
ray.get(self.scheduler_actor.stop_plan.remote(result_uuid))
else:
self.scheduler.stop_plan(result_uuid)
diff --git a/tutorials/text_to_image/using_cloud_with_ray.ipynb b/tutorials/text_to_image/using_cloud_with_ray.ipynb
index d0e44d039c..c14aa26be9 100644
--- a/tutorials/text_to_image/using_cloud_with_ray.ipynb
+++ b/tutorials/text_to_image/using_cloud_with_ray.ipynb
@@ -71,7 +71,7 @@
"\n",
"To activate the RayRunner, you can either:\n",
"\n",
- "1. Use the `DAFT_RUNNER=ray` and optionally the `DAFT_RAY_ADDRESS` environment variables\n",
+ "1. Use the `DAFT_RUNNER=ray` and optionally the `RAY_ADDRESS` environment variables\n",
"2. Call `daft.context.set_runner_ray(...)` at the start of your program.\n",
"\n",
"We'll demonstrate option 2 here!"