Skip to content

Commit

Permalink
check env not runner
Browse files Browse the repository at this point in the history
  • Loading branch information
Colin Ho authored and Colin Ho committed Nov 14, 2024
1 parent 711e862 commit 945e1b6
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions benchmarking/tpch/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import subprocess
import warnings
from datetime import datetime, timezone
from typing import Any, Callable
from typing import Any, Callable, Literal

import ray

Expand Down Expand Up @@ -194,6 +194,16 @@ def get_daft_version() -> str:
return daft.get_version()


def get_daft_benchmark_runner_name() -> Literal["ray"] | Literal["py"] | Literal["native"]:
"""Test utility that checks the environment variable for the runner that is being used for the benchmarking"""
name = os.getenv("DAFT_RUNNER")
assert name is not None, "Tests must be run with $DAFT_RUNNER env var"
name = name.lower()

assert name in {"ray", "py", "native"}, f"Runner name not recognized: {name}"
return name


def get_ray_runtime_env(requirements: str | None) -> dict:
runtime_env = {
"py_modules": [daft],
Expand All @@ -212,7 +222,7 @@ def warmup_environment(requirements: str | None, parquet_folder: str):
"""Performs necessary setup of Daft on the current benchmarking environment"""
ctx = daft.context.get_context()

if ctx.get_or_create_runner().name == "ray":
if get_daft_benchmark_runner_name() == "ray":
runtime_env = get_ray_runtime_env(requirements)

ray.init(
Expand Down

0 comments on commit 945e1b6

Please sign in to comment.