diff --git a/daft/context.py b/daft/context.py index b21fa09b60..a7f1948bf3 100644 --- a/daft/context.py +++ b/daft/context.py @@ -57,6 +57,7 @@ def _get_runner_config_from_env() -> _RunnerConfig: ) ray_is_initialized = False + ray_is_in_job = False in_ray_worker = False try: import ray @@ -66,6 +67,10 @@ def _get_runner_config_from_env() -> _RunnerConfig: # Check if running inside a Ray worker if ray._private.worker.global_worker.mode == ray.WORKER_MODE: in_ray_worker = True + # In a Ray job, Ray might not be initialized yet but we can pick up an environment variable as a heuristic here + elif os.getenv("RAY_JOB_ID") is not None: + ray_is_in_job = True + except ImportError: pass @@ -89,7 +94,7 @@ def _get_runner_config_from_env() -> _RunnerConfig: raise ValueError(f"Unsupported DAFT_RUNNER variable: {runner_from_envvar}") # Retrieve the runner from current initialized Ray environment, only if not running in a Ray worker - elif ray_is_initialized and not in_ray_worker: + elif not in_ray_worker and (ray_is_initialized or ray_is_in_job): return _RayRunnerConfig( address=None, # No address supplied, use the existing connection max_task_backlog=task_backlog,