diff --git a/benchmarking/tpch/__main__.py b/benchmarking/tpch/__main__.py
index 7f116896dd..8ad131e08f 100644
--- a/benchmarking/tpch/__main__.py
+++ b/benchmarking/tpch/__main__.py
@@ -12,14 +12,13 @@
 import subprocess
 import warnings
 from datetime import datetime, timezone
-from typing import Any, Callable
+from typing import Any, Callable, Literal
 
 import ray
 
 import daft
 from benchmarking.tpch import answers, data_generation
 from daft import DataFrame
-from daft.context import get_context
 from daft.runners.profiler import profiler
 
 logger = logging.getLogger(__name__)
@@ -130,8 +129,7 @@ def run_all_benchmarks(
 ):
     get_df = get_df_with_parquet_folder(parquet_folder)
 
-    daft_context = get_context()
-    metrics_builder = MetricsBuilder(daft_context.get_or_create_runner().name)
+    metrics_builder = MetricsBuilder(get_daft_benchmark_runner_name())
 
     for i in questions:
         # Run as a Ray Job if dashboard URL is provided
@@ -194,6 +192,29 @@ def get_daft_version() -> str:
     return daft.get_version()
 
 
+def get_daft_benchmark_runner_name() -> Literal["ray", "py", "native"]:
+    """Return the runner name selected for benchmarking via the ``DAFT_RUNNER`` environment variable.
+
+    The value is lower-cased before validation, so ``DAFT_RUNNER=Ray`` is accepted.
+
+    Returns:
+        One of ``"ray"``, ``"py"`` or ``"native"``.
+
+    Raises:
+        RuntimeError: If ``DAFT_RUNNER`` is not set.
+        ValueError: If ``DAFT_RUNNER`` names an unrecognized runner.
+    """
+    name = os.getenv("DAFT_RUNNER")
+    # Validate with explicit raises rather than ``assert``: assertions are stripped under ``python -O``.
+    if name is None:
+        raise RuntimeError("Tests must be run with $DAFT_RUNNER env var")
+    name = name.lower()
+
+    if name not in {"ray", "py", "native"}:
+        raise ValueError(f"Runner name not recognized: {name}")
+    return name
+
+
 def get_ray_runtime_env(requirements: str | None) -> dict:
     runtime_env = {
         "py_modules": [daft],
@@ -210,13 +231,10 @@ def get_ray_runtime_env(requirements: str | None) -> dict:
 
 def warmup_environment(requirements: str | None, parquet_folder: str):
     """Performs necessary setup of Daft on the current benchmarking environment"""
-    ctx = daft.context.get_context()
-
-    if ctx.get_or_create_runner().name == "ray":
+    if get_daft_benchmark_runner_name() == "ray":
         runtime_env = get_ray_runtime_env(requirements)
 
         ray.init(
-            address=ctx._runner.ray_address,
             runtime_env=runtime_env,
         )
 