Skip to content

Commit

Permalink
[BUG] Allow for use of Ray jobs for benchmarking (#1690)
Browse files Browse the repository at this point in the history
Co-authored-by: Jay Chia <[email protected]@users.noreply.github.com>
  • Loading branch information
jaychia and Jay Chia authored Dec 2, 2023
1 parent 2d499c4 commit 2fbf885
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 10 deletions.
62 changes: 52 additions & 10 deletions benchmarking/tpch/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import logging
import math
import os
import pathlib
import platform
import socket
import subprocess
Expand Down Expand Up @@ -120,7 +121,13 @@ def _get_df(table_name: str) -> DataFrame:
return _get_df


def run_all_benchmarks(parquet_folder: str, skip_questions: set[int], csv_output_location: str | None):
def run_all_benchmarks(
parquet_folder: str,
skip_questions: set[int],
csv_output_location: str | None,
ray_job_dashboard_url: str | None = None,
requirements: str | None = None,
):
get_df = get_df_with_parquet_folder(parquet_folder)

daft_context = get_context()
Expand All @@ -131,11 +138,32 @@ def run_all_benchmarks(parquet_folder: str, skip_questions: set[int], csv_output
logger.warning(f"Skipping TPC-H q{i}")
continue

answer = getattr(answers, f"q{i}")
daft_df = answer(get_df)
# Run as a Ray Job if dashboard URL is provided
if ray_job_dashboard_url is not None:
from benchmarking.tpch import ray_job_runner

working_dir = pathlib.Path(os.path.dirname(__file__))
entrypoint = working_dir / "ray_job_runner.py"
job_params = ray_job_runner.ray_job_params(
parquet_folder_path=parquet_folder,
tpch_qnum=i,
working_dir=working_dir,
entrypoint=entrypoint,
runtime_env=get_ray_runtime_env(requirements),
)
with metrics_builder.collect_metrics(i):
ray_job_runner.run_on_ray(
ray_job_dashboard_url,
job_params,
)

# Run locally (potentially on a local Ray cluster)
else:
answer = getattr(answers, f"q{i}")
daft_df = answer(get_df)

with metrics_builder.collect_metrics(i):
daft_df.collect()
with metrics_builder.collect_metrics(i):
daft_df.collect()

if csv_output_location:
logger.info(f"Writing CSV to: {csv_output_location}")
Expand All @@ -162,16 +190,23 @@ def get_daft_version() -> str:
return daft.get_version()


def get_ray_runtime_env(requirements: str | None) -> dict:
    """Build the Ray runtime environment shared by benchmarking runs.

    Ships the local `daft` module to the cluster via ``py_modules``, installs
    dependencies eagerly, and disables Daft's progress bar on the workers.

    Args:
        requirements: Optional pip requirements (e.g. a requirements file) to
            install on the cluster; omitted from the environment when falsy.

    Returns:
        A runtime_env mapping suitable for `ray.init` or Ray job submission.
    """
    env: dict = {
        "py_modules": [daft],
        "eager_install": True,
        "env_vars": {"DAFT_PROGRESS_BAR": "0"},
    }
    if requirements:
        env["pip"] = requirements
    return env


def warmup_environment(requirements: str | None, parquet_folder: str):
"""Performs necessary setup of Daft on the current benchmarking environment"""
ctx = daft.context.get_context()

if ctx.runner_config.name == "ray":
runtime_env = {"py_modules": [daft]}
if requirements:
runtime_env.update({"pip": requirements})
if runtime_env:
runtime_env.update({"eager_install": True})
runtime_env = get_ray_runtime_env(requirements)

ray.init(
address=ctx.runner_config.address,
Expand Down Expand Up @@ -238,6 +273,11 @@ def warm_up_function():
action="store_true",
help="Skip warming up data before benchmark",
)
parser.add_argument(
"--ray_job_dashboard_url",
default=None,
help="Ray Dashboard URL to submit jobs instead of using Ray client, most useful when running on a remote cluster",
)

args = parser.parse_args()
if args.output_csv_headers:
Expand All @@ -264,4 +304,6 @@ def warm_up_function():
parquet_folder,
skip_questions={int(s) for s in args.skip_questions.split(",")} if args.skip_questions is not None else set(),
csv_output_location=args.output_csv,
ray_job_dashboard_url=args.ray_job_dashboard_url,
requirements=args.requirements,
)
86 changes: 86 additions & 0 deletions benchmarking/tpch/ray_job_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from __future__ import annotations

import argparse
import asyncio
import os
import pathlib
import time
import uuid
from typing import Callable

from ray.job_submission import JobStatus, JobSubmissionClient

import daft


async def print_logs(logs):
    """Consume an async iterator of log chunks, echoing each one to stdout.

    Chunks are assumed to carry their own trailing newlines (they come from
    log tailing), so `print` is told not to append another one.
    """
    async for chunk in logs:
        print(chunk, end="")


async def wait_on_job(logs, timeout_s):
    """Stream job logs to stdout until exhausted.

    Raises asyncio.TimeoutError if streaming does not finish within
    `timeout_s` seconds.
    """
    streaming = print_logs(logs)
    await asyncio.wait_for(streaming, timeout=timeout_s)


def run_on_ray(ray_address: str, job_params: dict, timeout_s: int = 1500):
    """Submits a job to run in the Ray cluster and tails its logs to completion.

    Args:
        ray_address: Address of the Ray dashboard to submit the job to.
        job_params: Keyword arguments forwarded to `JobSubmissionClient.submit_job`.
        timeout_s: Seconds to wait for the job before stopping it.

    Raises:
        RuntimeError: If the job does not reach a terminal state, or terminates
            in any state other than SUCCEEDED.
    """

    print("Submitting benchmarking job to Ray cluster...")
    print("Parameters:")
    print(job_params)

    client = JobSubmissionClient(address=ray_address)
    job_id = client.submit_job(**job_params)
    print(f"Submitted job: {job_id}")

    try:
        asyncio.run(wait_on_job(client.tail_job_logs(job_id), timeout_s))
    except asyncio.TimeoutError:
        print(f"Job timed out after {timeout_s}s! Stopping job now...")
        client.stop_job(job_id)
        # Give Ray a moment to transition the stopped job into a terminal
        # state before polling its status below.
        time.sleep(16)

    status = client.get_job_status(job_id)
    # Don't use `assert` for this check: asserts are stripped under
    # `python -O`, which would silently skip failure detection here.
    if not status.is_terminal():
        raise RuntimeError(f"Job should have terminated, but is in state: {status}")
    if status != JobStatus.SUCCEEDED:
        job_info = client.get_job_info(job_id)
        raise RuntimeError(f"Job failed with {job_info.error_type} error: {job_info.message}")
    print(f"Job completed with {status}")


def ray_job_params(
    parquet_folder_path: str,
    tpch_qnum: int,
    working_dir: pathlib.Path,
    entrypoint: pathlib.Path,
    runtime_env: dict,
) -> dict:
    """Assemble `JobSubmissionClient.submit_job` kwargs for one TPC-H question.

    Args:
        parquet_folder_path: Path to the TPC-H Parquet data visible to workers.
        tpch_qnum: TPC-H question number to run.
        working_dir: Local directory uploaded to the cluster as the job's
            working dir; `entrypoint` must live inside it.
        entrypoint: Path to the runner script to execute on the cluster.
        runtime_env: Extra runtime environment options merged into the job's env.

    Returns:
        A dict with `submission_id`, `entrypoint`, and `runtime_env` keys.
    """
    # Short random suffix keeps repeated submissions of the same question unique.
    submission_id = f"tpch-q{tpch_qnum}-{str(uuid.uuid4())[:4]}"
    script = entrypoint.relative_to(working_dir)
    command = f"python {script} --parquet-folder {parquet_folder_path} --question-number {tpch_qnum}"
    merged_env = {"working_dir": str(working_dir), **runtime_env}
    return {
        "submission_id": submission_id,
        "entrypoint": command,
        "runtime_env": merged_env,
    }


def get_df_with_parquet_folder(parquet_folder: str) -> Callable[[str], daft.DataFrame]:
    """Return a loader mapping a TPC-H table name to a Daft DataFrame.

    The returned callable reads every `*.parquet` file found under
    `<parquet_folder>/<table_name>/`.
    """

    def _read_table(table_name: str) -> daft.DataFrame:
        glob_path = os.path.join(parquet_folder, table_name, "*.parquet")
        return daft.read_parquet(glob_path)

    return _read_table


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--parquet-folder", help="Path to TPC-H data stored on workers", required=True)
parser.add_argument("--question-number", help="Question number to run", required=True)
args = parser.parse_args()

import answers

get_df = get_df_with_parquet_folder(args.parquet_folder)
answer = getattr(answers, f"q{args.question_number}")
daft_df = answer(get_df)
daft_df.collect()

0 comments on commit 2fbf885

Please sign in to comment.