diff --git a/daft/context.py b/daft/context.py index 38dcd17501..aa98ff23fc 100644 --- a/daft/context.py +++ b/daft/context.py @@ -39,20 +39,54 @@ def _get_runner_config_from_env() -> _RunnerConfig: To use: 1. PyRunner: set DAFT_RUNNER=py - 2. RayRunner: set DAFT_RUNNER=ray and optionally DAFT_RAY_ADDRESS=ray://... + 2. RayRunner: set DAFT_RUNNER=ray and optionally RAY_ADDRESS=ray://... """ - runner = os.getenv("DAFT_RUNNER") or "PY" - if runner.upper() == "RAY": - task_backlog_env = os.getenv("DAFT_DEVELOPER_RAY_MAX_TASK_BACKLOG") + runner_from_envvar = os.getenv("DAFT_RUNNER") + task_backlog_env = os.getenv("DAFT_DEVELOPER_RAY_MAX_TASK_BACKLOG") + use_thread_pool_env = os.getenv("DAFT_DEVELOPER_USE_THREAD_POOL") + use_thread_pool = bool(int(use_thread_pool_env)) if use_thread_pool_env is not None else None + + ray_is_initialized = False + in_ray_worker = False + try: + import ray + + if ray.is_initialized(): + ray_is_initialized = True + # Check if running inside a Ray worker + if ray._private.worker.global_worker.mode == ray.WORKER_MODE: + in_ray_worker = True + except ImportError: + pass + + # Retrieve the runner from environment variables + if runner_from_envvar and runner_from_envvar.upper() == "RAY": + ray_address = os.getenv("DAFT_RAY_ADDRESS") + if ray_address is not None: + warnings.warn( + "Detected usage of the $DAFT_RAY_ADDRESS environment variable. This will be deprecated, please use $RAY_ADDRESS instead." + ) + else: + ray_address = os.getenv("RAY_ADDRESS") + return _RayRunnerConfig( + address=ray_address, + max_task_backlog=int(task_backlog_env) if task_backlog_env else None, + ) + elif runner_from_envvar and runner_from_envvar.upper() == "PY": + return _PyRunnerConfig(use_thread_pool=use_thread_pool) + elif runner_from_envvar is not None: + raise ValueError(f"Unsupported DAFT_RUNNER variable: {runner_from_envvar}") + + # Retrieve the runner from current initialized Ray environment, only if not running in a Ray worker + elif ray_is_initialized and not in_ray_worker: return _RayRunnerConfig( - address=os.getenv("DAFT_RAY_ADDRESS"), + address=None, # No address supplied, use the existing connection max_task_backlog=int(task_backlog_env) if task_backlog_env else None, ) - elif runner.upper() == "PY": - use_thread_pool_env = os.getenv("DAFT_DEVELOPER_USE_THREAD_POOL") - use_thread_pool = bool(int(use_thread_pool_env)) if use_thread_pool_env is not None else None + + # Fall back on PyRunner + else: return _PyRunnerConfig(use_thread_pool=use_thread_pool) - raise ValueError(f"Unsupported DAFT_RUNNER variable: {runner}") @dataclasses.dataclass @@ -66,7 +100,7 @@ class DaftContext: # Non-execution calls (e.g. creation of a dataframe, logical plan building etc) directly reference values in this config _daft_planning_config: PyDaftPlanningConfig = PyDaftPlanningConfig() - _runner_config: _RunnerConfig = dataclasses.field(default_factory=_get_runner_config_from_env) + _runner_config: _RunnerConfig | None = None _disallow_set_runner: bool = False _runner: Runner | None = None @@ -100,13 +134,20 @@ def daft_planning_config(self) -> PyDaftPlanningConfig: @property def runner_config(self) -> _RunnerConfig: with self._lock: + return self._get_runner_config() + + def _get_runner_config(self) -> _RunnerConfig: + if self._runner_config is not None: return self._runner_config + self._runner_config = _get_runner_config_from_env() + return self._runner_config def _get_runner(self) -> Runner: if self._runner is not None: return self._runner - if self._runner_config.name == "ray": + runner_config = self._get_runner_config() + if runner_config.name == "ray": from daft.runners.ray_runner import RayRunner assert isinstance(self._runner_config, _RayRunnerConfig) @@ -114,27 +155,14 @@ def _get_runner(self) -> Runner: address=self._runner_config.address, max_task_backlog=self._runner_config.max_task_backlog, ) - elif self._runner_config.name == "py": + elif runner_config.name == "py": from daft.runners.pyrunner import PyRunner - try: - import ray - - if ray.is_initialized(): - logger.warning( - "WARNING: Daft is NOT using Ray for execution!\n" - "Daft is using the PyRunner but we detected an active Ray connection. " - "If you intended to use the Daft RayRunner, please first run `daft.context.set_runner_ray()` " - "before executing Daft queries." - ) - except ImportError: - pass - assert isinstance(self._runner_config, _PyRunnerConfig) self._runner = PyRunner(use_thread_pool=self._runner_config.use_thread_pool) else: - raise NotImplementedError(f"Runner config implemented: {self._runner_config.name}") + raise NotImplementedError(f"Runner config not implemented: {runner_config.name}") # Mark DaftContext as having the runner set, which prevents any subsequent setting of the config # after the runner has been initialized once @@ -165,7 +193,7 @@ def set_runner_ray( Alternatively, users can set this behavior via environment variables: 1. DAFT_RUNNER=ray - 2. Optionally, DAFT_RAY_ADDRESS=ray://... + 2. Optionally, RAY_ADDRESS=ray://... **This function will throw an error if called multiple times in the same process.** @@ -178,6 +206,7 @@ def set_runner_ray( Returns: DaftContext: Daft context after setting the Ray runner """ + ctx = get_context() with ctx._lock: if ctx._disallow_set_runner: diff --git a/daft/runners/ray_runner.py b/daft/runners/ray_runner.py index d8208ce752..0c48f9da74 100644 --- a/daft/runners/ray_runner.py +++ b/daft/runners/ray_runner.py @@ -739,10 +739,19 @@ def __init__( ) -> None: super().__init__() if ray.is_initialized(): - logger.warning("Ray has already been initialized, Daft will reuse the existing Ray context.") - self.ray_context = ray.init(address=address, ignore_reinit_error=True) + if address is not None: + logger.warning( + "Ray has already been initialized, Daft will reuse the existing Ray context and ignore the " + "supplied address: %s", + address, + ) + else: + ray.init(address=address) + + # Check if Ray is running in "client mode" (connected to a Ray cluster via a Ray client) + self.ray_client_mode = ray.util.client.ray.get_context().is_connected() - if isinstance(self.ray_context, ray.client_builder.ClientContext): + if self.ray_client_mode: # Run scheduler remotely if the cluster is connected remotely. self.scheduler_actor = SchedulerActor.options( # type: ignore name=SCHEDULER_ACTOR_NAME, @@ -759,7 +768,7 @@ def __init__( ) def active_plans(self) -> list[str]: - if isinstance(self.ray_context, ray.client_builder.ClientContext): + if self.ray_client_mode: return ray.get(self.scheduler_actor.active_plans.remote()) else: return self.scheduler.active_plans() @@ -772,7 +781,7 @@ def _start_plan( ) -> str: psets = {k: v.values() for k, v in self._part_set_cache.get_all_partition_sets().items()} result_uuid = str(uuid.uuid4()) - if isinstance(self.ray_context, ray.client_builder.ClientContext): + if self.ray_client_mode: ray.get( self.scheduler_actor.start_plan.remote( daft_execution_config=daft_execution_config, @@ -795,7 +804,7 @@ def _start_plan( def _stream_plan(self, result_uuid: str) -> Iterator[RayMaterializedResult]: try: while True: - if isinstance(self.ray_context, ray.client_builder.ClientContext): + if self.ray_client_mode: result = ray.get(self.scheduler_actor.next.remote(result_uuid)) else: result = self.scheduler.next(result_uuid) @@ -808,7 +817,7 @@ def _stream_plan(self, result_uuid: str) -> Iterator[RayMaterializedResult]: yield result finally: # Generator is out of scope, ensure that state has been cleaned up - if isinstance(self.ray_context, ray.client_builder.ClientContext): + if self.ray_client_mode: ray.get(self.scheduler_actor.stop_plan.remote(result_uuid)) else: self.scheduler.stop_plan(result_uuid) diff --git a/tutorials/text_to_image/using_cloud_with_ray.ipynb b/tutorials/text_to_image/using_cloud_with_ray.ipynb index d0e44d039c..c14aa26be9 100644 --- a/tutorials/text_to_image/using_cloud_with_ray.ipynb +++ b/tutorials/text_to_image/using_cloud_with_ray.ipynb @@ -71,7 +71,7 @@ "\n", "To activate the RayRunner, you can either:\n", "\n", - "1. Use the `DAFT_RUNNER=ray` and optionally the `DAFT_RAY_ADDRESS` environment variables\n", + "1. Use the `DAFT_RUNNER=ray` and optionally the `RAY_ADDRESS` environment variables\n", "2. Call `daft.context.set_runner_ray(...)` at the start of your program.\n", "\n", "We'll demonstrate option 2 here!"