Skip to content

Commit

Permalink
[FEAT] Automatically use Ray Runner if Ray is initialized (#2282)
Browse files Browse the repository at this point in the history
1. Automatically switches Daft to use the Ray Runner if a user calls
`ray.init(...)` before running any Daft queries
2. Also changes behavior to begin deprecating the `DAFT_RAY_ADDRESS`
environment variable, so that we can standardize on the normal
`RAY_ADDRESS` behavior

This PR ensures the following behavior:

* If a user explicitly calls `daft.context.set_runner_ray` or
`daft.context.set_runner_py`, this overrides all other behavior
* If a user calls daft.context.set_runner_ray with a specified address,
but Ray is already initialized, we warn them that their address is being
ignored
* Otherwise, on first execution Daft will attempt to retrieve the runner
config from the current environment:
    * Check for the `DAFT_RUNNER` environment variable for `RAY`/`PY`
    * Check to see if Ray is initialized, and if we aren't running in a
      Ray worker: `RAY`
    * Fall back on: `PY`

Ray connection detection, on driver vs on worker:

<img width="1470" alt="image"
src="https://github.com/Eventual-Inc/Daft/assets/17691182/885a9dbc-0687-42ca-b0fa-6e99675bb00b">

Warning if set_runner_ray is called with an address after Ray is already
initialized:
<img width="1470" alt="image"
src="https://github.com/Eventual-Inc/Daft/assets/17691182/8468220d-ed5a-49e6-b6e4-61c4c24f7e4e">

---------

Co-authored-by: Jay Chia <[email protected]@users.noreply.github.com>
  • Loading branch information
jaychia and Jay Chia authored Jun 19, 2024
1 parent 682a18c commit 4b67b85
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 35 deletions.
83 changes: 56 additions & 27 deletions daft/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,20 +39,54 @@ def _get_runner_config_from_env() -> _RunnerConfig:
To use:
1. PyRunner: set DAFT_RUNNER=py
2. RayRunner: set DAFT_RUNNER=ray and optionally DAFT_RAY_ADDRESS=ray://...
2. RayRunner: set DAFT_RUNNER=ray and optionally RAY_ADDRESS=ray://...
"""
runner = os.getenv("DAFT_RUNNER") or "PY"
if runner.upper() == "RAY":
task_backlog_env = os.getenv("DAFT_DEVELOPER_RAY_MAX_TASK_BACKLOG")
runner_from_envvar = os.getenv("DAFT_RUNNER")
task_backlog_env = os.getenv("DAFT_DEVELOPER_RAY_MAX_TASK_BACKLOG")
use_thread_pool_env = os.getenv("DAFT_DEVELOPER_USE_THREAD_POOL")
use_thread_pool = bool(int(use_thread_pool_env)) if use_thread_pool_env is not None else None

ray_is_initialized = False
in_ray_worker = False
try:
import ray

if ray.is_initialized():
ray_is_initialized = True
# Check if running inside a Ray worker
if ray._private.worker.global_worker.mode == ray.WORKER_MODE:
in_ray_worker = True
except ImportError:
pass

# Retrieve the runner from environment variables
if runner_from_envvar and runner_from_envvar.upper() == "RAY":
ray_address = os.getenv("DAFT_RAY_ADDRESS")
if ray_address is not None:
warnings.warn(
"Detected usage of the $DAFT_RAY_ADDRESS environment variable. This will be deprecated, please use $RAY_ADDRESS instead."
)
else:
ray_address = os.getenv("RAY_ADDRESS")
return _RayRunnerConfig(
address=ray_address,
max_task_backlog=int(task_backlog_env) if task_backlog_env else None,
)
elif runner_from_envvar and runner_from_envvar.upper() == "PY":
return _PyRunnerConfig(use_thread_pool=use_thread_pool)
elif runner_from_envvar is not None:
raise ValueError(f"Unsupported DAFT_RUNNER variable: {runner_from_envvar}")

# Retrieve the runner from current initialized Ray environment, only if not running in a Ray worker
elif ray_is_initialized and not in_ray_worker:
return _RayRunnerConfig(
address=os.getenv("DAFT_RAY_ADDRESS"),
address=None, # No address supplied, use the existing connection
max_task_backlog=int(task_backlog_env) if task_backlog_env else None,
)
elif runner.upper() == "PY":
use_thread_pool_env = os.getenv("DAFT_DEVELOPER_USE_THREAD_POOL")
use_thread_pool = bool(int(use_thread_pool_env)) if use_thread_pool_env is not None else None

# Fall back on PyRunner
else:
return _PyRunnerConfig(use_thread_pool=use_thread_pool)
raise ValueError(f"Unsupported DAFT_RUNNER variable: {runner}")


@dataclasses.dataclass
Expand All @@ -66,7 +100,7 @@ class DaftContext:
# Non-execution calls (e.g. creation of a dataframe, logical plan building etc) directly reference values in this config
_daft_planning_config: PyDaftPlanningConfig = PyDaftPlanningConfig()

_runner_config: _RunnerConfig = dataclasses.field(default_factory=_get_runner_config_from_env)
_runner_config: _RunnerConfig | None = None
_disallow_set_runner: bool = False
_runner: Runner | None = None

Expand Down Expand Up @@ -100,41 +134,35 @@ def daft_planning_config(self) -> PyDaftPlanningConfig:
@property
def runner_config(self) -> _RunnerConfig:
with self._lock:
return self._get_runner_config()

def _get_runner_config(self) -> _RunnerConfig:
if self._runner_config is not None:
return self._runner_config
self._runner_config = _get_runner_config_from_env()
return self._runner_config

def _get_runner(self) -> Runner:
if self._runner is not None:
return self._runner

if self._runner_config.name == "ray":
runner_config = self._get_runner_config()
if runner_config.name == "ray":
from daft.runners.ray_runner import RayRunner

assert isinstance(self._runner_config, _RayRunnerConfig)
self._runner = RayRunner(
address=self._runner_config.address,
max_task_backlog=self._runner_config.max_task_backlog,
)
elif self._runner_config.name == "py":
elif runner_config.name == "py":
from daft.runners.pyrunner import PyRunner

try:
import ray

if ray.is_initialized():
logger.warning(
"WARNING: Daft is NOT using Ray for execution!\n"
"Daft is using the PyRunner but we detected an active Ray connection. "
"If you intended to use the Daft RayRunner, please first run `daft.context.set_runner_ray()` "
"before executing Daft queries."
)
except ImportError:
pass

assert isinstance(self._runner_config, _PyRunnerConfig)
self._runner = PyRunner(use_thread_pool=self._runner_config.use_thread_pool)

else:
raise NotImplementedError(f"Runner config implemented: {self._runner_config.name}")
raise NotImplementedError(f"Runner config not implemented: {runner_config.name}")

# Mark DaftContext as having the runner set, which prevents any subsequent setting of the config
# after the runner has been initialized once
Expand Down Expand Up @@ -165,7 +193,7 @@ def set_runner_ray(
Alternatively, users can set this behavior via environment variables:
1. DAFT_RUNNER=ray
2. Optionally, DAFT_RAY_ADDRESS=ray://...
2. Optionally, RAY_ADDRESS=ray://...
**This function will throw an error if called multiple times in the same process.**
Expand All @@ -178,6 +206,7 @@ def set_runner_ray(
Returns:
DaftContext: Daft context after setting the Ray runner
"""

ctx = get_context()
with ctx._lock:
if ctx._disallow_set_runner:
Expand Down
23 changes: 16 additions & 7 deletions daft/runners/ray_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,10 +739,19 @@ def __init__(
) -> None:
super().__init__()
if ray.is_initialized():
logger.warning("Ray has already been initialized, Daft will reuse the existing Ray context.")
self.ray_context = ray.init(address=address, ignore_reinit_error=True)
if address is not None:
logger.warning(
"Ray has already been initialized, Daft will reuse the existing Ray context and ignore the "
"supplied address: %s",
address,
)
else:
ray.init(address=address)

# Check if Ray is running in "client mode" (connected to a Ray cluster via a Ray client)
self.ray_client_mode = ray.util.client.ray.get_context().is_connected()

if isinstance(self.ray_context, ray.client_builder.ClientContext):
if self.ray_client_mode:
# Run scheduler remotely if the cluster is connected remotely.
self.scheduler_actor = SchedulerActor.options( # type: ignore
name=SCHEDULER_ACTOR_NAME,
Expand All @@ -759,7 +768,7 @@ def __init__(
)

def active_plans(self) -> list[str]:
if isinstance(self.ray_context, ray.client_builder.ClientContext):
if self.ray_client_mode:
return ray.get(self.scheduler_actor.active_plans.remote())
else:
return self.scheduler.active_plans()
Expand All @@ -772,7 +781,7 @@ def _start_plan(
) -> str:
psets = {k: v.values() for k, v in self._part_set_cache.get_all_partition_sets().items()}
result_uuid = str(uuid.uuid4())
if isinstance(self.ray_context, ray.client_builder.ClientContext):
if self.ray_client_mode:
ray.get(
self.scheduler_actor.start_plan.remote(
daft_execution_config=daft_execution_config,
Expand All @@ -795,7 +804,7 @@ def _start_plan(
def _stream_plan(self, result_uuid: str) -> Iterator[RayMaterializedResult]:
try:
while True:
if isinstance(self.ray_context, ray.client_builder.ClientContext):
if self.ray_client_mode:
result = ray.get(self.scheduler_actor.next.remote(result_uuid))
else:
result = self.scheduler.next(result_uuid)
Expand All @@ -808,7 +817,7 @@ def _stream_plan(self, result_uuid: str) -> Iterator[RayMaterializedResult]:
yield result
finally:
# Generator is out of scope, ensure that state has been cleaned up
if isinstance(self.ray_context, ray.client_builder.ClientContext):
if self.ray_client_mode:
ray.get(self.scheduler_actor.stop_plan.remote(result_uuid))
else:
self.scheduler.stop_plan(result_uuid)
Expand Down
2 changes: 1 addition & 1 deletion tutorials/text_to_image/using_cloud_with_ray.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
"\n",
"To activate the RayRunner, you can either:\n",
"\n",
"1. Use the `DAFT_RUNNER=ray` and optionally the `DAFT_RAY_ADDRESS` environment variables\n",
"1. Use the `DAFT_RUNNER=ray` and optionally the `RAY_ADDRESS` environment variables\n",
"2. Call `daft.context.set_runner_ray(...)` at the start of your program.\n",
"\n",
"We'll demonstrate option 2 here!"
Expand Down

0 comments on commit 4b67b85

Please sign in to comment.