diff --git a/daft/runners/ray_runner.py b/daft/runners/ray_runner.py index afcbcc3f39..d765a313d5 100644 --- a/daft/runners/ray_runner.py +++ b/daft/runners/ray_runner.py @@ -397,14 +397,14 @@ def _ray_num_cpus_provider(ttl_seconds: int = 1) -> Generator[int, None, None]: >>> next(p) """ last_checked_time = time.time() - last_num_cpus_queried = int(ray.cluster_resources()["CPU"]) + last_num_cpus_queried = int(ray.cluster_resources().get("CPU", 0)) while True: currtime = time.time() if currtime - last_checked_time < ttl_seconds: yield last_num_cpus_queried else: last_checked_time = currtime - last_num_cpus_queried = int(ray.cluster_resources()["CPU"]) + last_num_cpus_queried = int(ray.cluster_resources().get("CPU", 0)) yield last_num_cpus_queried @@ -526,7 +526,7 @@ def place_in_queue(item): while is_active(): # Loop: Dispatch (get tasks -> batch dispatch). tasks_to_dispatch: list[PartitionTask] = [] - cores: int = next(num_cpus_provider) - self.reserved_cores + cores: int = max(next(num_cpus_provider) - self.reserved_cores, 0) max_inflight_tasks = cores + self.max_task_backlog dispatches_allowed = max_inflight_tasks - len(inflight_tasks) dispatches_allowed = min(cores, dispatches_allowed)