From 25c3b26f4e795b96d396473b8085b36f9ac82168 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Fri, 15 Nov 2024 07:35:10 +0800 Subject: [PATCH] [BUG] Check env in benchmarking script (#3297) Using `ctx.get_or_create_runner` in benchmarking warmup code / metrics builder causes subsequent `ray.inits` to crash. Just check the `DAFT_RUNNER` environment var instead, which should be set. Tested: - local -> https://github.com/Eventual-Inc/daft-benchmarking/actions/runs/11838323155 - remote -> https://github.com/Eventual-Inc/daft-benchmarking/actions/runs/11838783067 --------- Co-authored-by: Colin Ho --- benchmarking/tpch/__main__.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/benchmarking/tpch/__main__.py b/benchmarking/tpch/__main__.py index 7f116896dd..8ad131e08f 100644 --- a/benchmarking/tpch/__main__.py +++ b/benchmarking/tpch/__main__.py @@ -12,14 +12,13 @@ import subprocess import warnings from datetime import datetime, timezone -from typing import Any, Callable +from typing import Any, Callable, Literal import ray import daft from benchmarking.tpch import answers, data_generation from daft import DataFrame -from daft.context import get_context from daft.runners.profiler import profiler logger = logging.getLogger(__name__) @@ -130,8 +129,7 @@ def run_all_benchmarks( ): get_df = get_df_with_parquet_folder(parquet_folder) - daft_context = get_context() - metrics_builder = MetricsBuilder(daft_context.get_or_create_runner().name) + metrics_builder = MetricsBuilder(get_daft_benchmark_runner_name()) for i in questions: # Run as a Ray Job if dashboard URL is provided @@ -194,6 +192,16 @@ def get_daft_version() -> str: return daft.get_version() +def get_daft_benchmark_runner_name() -> Literal["ray"] | Literal["py"] | Literal["native"]: + """Test utility that checks the environment variable for the runner that is being used for the benchmarking""" + name = os.getenv("DAFT_RUNNER") + assert name is not None, "Tests must be run with $DAFT_RUNNER env var" + name = name.lower() + + assert name in {"ray", "py", "native"}, f"Runner name not recognized: {name}" + return name + + def get_ray_runtime_env(requirements: str | None) -> dict: runtime_env = { "py_modules": [daft], @@ -210,13 +218,10 @@ def get_ray_runtime_env(requirements: str | None) -> dict: def warmup_environment(requirements: str | None, parquet_folder: str): """Performs necessary setup of Daft on the current benchmarking environment""" - ctx = daft.context.get_context() - - if ctx.get_or_create_runner().name == "ray": + if get_daft_benchmark_runner_name() == "ray": runtime_env = get_ray_runtime_env(requirements) ray.init( - address=ctx._runner.ray_address, runtime_env=runtime_env, )