Skip to content

Commit

Permalink
Add (very basic) argument parsing into a list of strings
Browse files Browse the repository at this point in the history
  • Loading branch information
raunakab committed Dec 14, 2024
1 parent 1ccac49 commit 98c52ad
Showing 1 changed file with 25 additions and 21 deletions.
46 changes: 25 additions & 21 deletions .github/ci-scripts/job_runner.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse
import asyncio
import json
from pathlib import Path

from ray.job_submission import JobStatus, JobSubmissionClient
Expand Down Expand Up @@ -29,27 +30,30 @@ def submit_job(
env_vars: str,
enable_ray_tracing: bool,
):
env_vars = parse_env_var_str(entrypoint_args)
if args.enable_ray_tracing:
env_vars["DAFT_ENABLE_RAY_TRACING"] = "1"

client = JobSubmissionClient(address="http://localhost:8265")
job_id = client.submit_job(
entrypoint=f"python {entrypoint_script} {entrypoint_args}",
runtime_env={
"working_dir": working_dir,
"env_vars": env_vars,
},
)

asyncio.run(wait_on_job(client.tail_job_logs(job_id), timeout_s=60 * 30))

status = client.get_job_status(job_id)
assert status.is_terminal(), "Job should have terminated"
if status != JobStatus.SUCCEEDED:
job_info = client.get_job_info(job_id)
raise RuntimeError(f"Job failed with {job_info.error_type} error: {job_info.message}")
print(f"Job completed with {status}")
args_list: list[str] = json.loads(f"[{entrypoint_args}]")

for args in args_list:
env_vars = parse_env_var_str(entrypoint_args)
if args_list.enable_ray_tracing:
env_vars["DAFT_ENABLE_RAY_TRACING"] = "1"

client = JobSubmissionClient(address="http://localhost:8265")
job_id = client.submit_job(
entrypoint=f"python {entrypoint_script} {args}",
runtime_env={
"working_dir": working_dir,
"env_vars": env_vars,
},
)

asyncio.run(wait_on_job(client.tail_job_logs(job_id), timeout_s=60 * 30))

status = client.get_job_status(job_id)
assert status.is_terminal(), "Job should have terminated"
if status != JobStatus.SUCCEEDED:
job_info = client.get_job_info(job_id)
raise RuntimeError(f"Job failed with {job_info.error_type} error: {job_info.message}")
print(f"Job completed with {status}")


if __name__ == "__main__":
Expand Down

0 comments on commit 98c52ad

Please sign in to comment.