Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/Eventual-Inc/Daft into daft…
Browse files Browse the repository at this point in the history
…-docs-v2
  • Loading branch information
ccmao1130 committed Dec 19, 2024
2 parents d4e1397 + ca4d3f7 commit ba3dad4
Show file tree
Hide file tree
Showing 123 changed files with 3,297 additions and 2,720 deletions.
114 changes: 114 additions & 0 deletions .github/ci-scripts/job_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# /// script
# requires-python = ">=3.12"
# dependencies = []
# ///

import argparse
import asyncio
import json
from dataclasses import dataclass
from datetime import datetime, timedelta
from pathlib import Path
from typing import Optional

from ray.job_submission import JobStatus, JobSubmissionClient


def parse_env_var_str(env_var_str: str) -> dict:
iter = map(
lambda s: s.strip().split("="),
filter(lambda s: s, env_var_str.split(",")),
)
return {k: v for k, v in iter}


async def print_logs(logs):
async for lines in logs:
print(lines, end="")


async def wait_on_job(logs, timeout_s):
await asyncio.wait_for(print_logs(logs), timeout=timeout_s)


@dataclass
class Result:
query: int
duration: timedelta
error_msg: Optional[str]


def submit_job(
working_dir: Path,
entrypoint_script: str,
entrypoint_args: str,
env_vars: str,
enable_ray_tracing: bool,
):
env_vars_dict = parse_env_var_str(env_vars)
if enable_ray_tracing:
env_vars_dict["DAFT_ENABLE_RAY_TRACING"] = "1"

client = JobSubmissionClient(address="http://localhost:8265")

if entrypoint_args.startswith("[") and entrypoint_args.endswith("]"):
# this is a json-encoded list of strings; parse accordingly
list_of_entrypoint_args: list[str] = json.loads(entrypoint_args)
else:
list_of_entrypoint_args: list[str] = [entrypoint_args]

results = []

for index, args in enumerate(list_of_entrypoint_args):
entrypoint = f"DAFT_RUNNER=ray python {entrypoint_script} {args}"
print(f"{entrypoint=}")
start = datetime.now()
job_id = client.submit_job(
entrypoint=entrypoint,
runtime_env={
"working_dir": working_dir,
"env_vars": env_vars_dict,
},
)

asyncio.run(wait_on_job(client.tail_job_logs(job_id), timeout_s=60 * 30))

status = client.get_job_status(job_id)
assert status.is_terminal(), "Job should have terminated"
end = datetime.now()
duration = end - start
error_msg = None
if status != JobStatus.SUCCEEDED:
job_info = client.get_job_info(job_id)
error_msg = job_info.message

result = Result(query=index, duration=duration, error_msg=error_msg)
results.append(result)

print(f"{results=}")


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--working-dir", type=Path, required=True)
parser.add_argument("--entrypoint-script", type=str, required=True)
parser.add_argument("--entrypoint-args", type=str, required=True)
parser.add_argument("--env-vars", type=str, required=True)
parser.add_argument("--enable-ray-tracing", action="store_true")

args = parser.parse_args()

if not (args.working_dir.exists() and args.working_dir.is_dir()):
raise ValueError("The working-dir must exist and be a directory")

entrypoint: Path = args.working_dir / args.entrypoint_script
if not (entrypoint.exists() and entrypoint.is_file()):
raise ValueError("The entrypoint script must exist and be a file")

submit_job(
working_dir=args.working_dir,
entrypoint_script=args.entrypoint_script,
entrypoint_args=args.entrypoint_args,
env_vars=args.env_vars,
enable_ray_tracing=args.enable_ray_tracing,
)
2 changes: 2 additions & 0 deletions .github/ci-scripts/templatize_ray_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,5 +110,7 @@ class Metadata(BaseModel, extra="allow"):
if metadata:
metadata = Metadata(**metadata)
content = content.replace(OTHER_INSTALL_PLACEHOLDER, " ".join(metadata.dependencies))
else:
content = content.replace(OTHER_INSTALL_PLACEHOLDER, "")

print(content)
38 changes: 16 additions & 22 deletions .github/workflows/run-cluster.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ on:
type: string
required: true
entrypoint_args:
description: Entry-point arguments
description: Entry-point arguments (either a simple string or a JSON list)
type: string
required: false
default: ""
Expand Down Expand Up @@ -79,24 +79,15 @@ jobs:
uv run \
--python 3.12 \
.github/ci-scripts/templatize_ray_config.py \
--cluster-name "ray-ci-run-${{ github.run_id }}_${{ github.run_attempt }}" \
--daft-wheel-url '${{ inputs.daft_wheel_url }}' \
--daft-version '${{ inputs.daft_version }}' \
--python-version '${{ inputs.python_version }}' \
--cluster-profile '${{ inputs.cluster_profile }}' \
--working-dir '${{ inputs.working_dir }}' \
--entrypoint-script '${{ inputs.entrypoint_script }}'
--cluster-name="ray-ci-run-${{ github.run_id }}_${{ github.run_attempt }}" \
--daft-wheel-url='${{ inputs.daft_wheel_url }}' \
--daft-version='${{ inputs.daft_version }}' \
--python-version='${{ inputs.python_version }}' \
--cluster-profile='${{ inputs.cluster_profile }}' \
--working-dir='${{ inputs.working_dir }}' \
--entrypoint-script='${{ inputs.entrypoint_script }}'
) >> .github/assets/ray.yaml
cat .github/assets/ray.yaml
- name: Setup ray env vars
run: |
source .venv/bin/activate
ray_env_var=$(python .github/ci-scripts/format_env_vars.py \
--env-vars '${{ inputs.env_vars }}' \
--enable-ray-tracing \
)
echo $ray_env_var
echo "ray_env_var=$ray_env_var" >> $GITHUB_ENV
- name: Download private ssh key
run: |
KEY=$(aws secretsmanager get-secret-value --secret-id ci-github-actions-ray-cluster-key-3 --query SecretString --output text)
Expand All @@ -117,12 +108,14 @@ jobs:
echo 'Invalid command submitted; command cannot be empty'
exit 1
fi
ray job submit \
--working-dir ${{ inputs.working_dir }} \
--address http://localhost:8265 \
--runtime-env-json "$ray_env_var" \
-- python ${{ inputs.entrypoint_script }} ${{ inputs.entrypoint_args }}
python .github/ci-scripts/job_runner.py \
--working-dir='${{ inputs.working_dir }}' \
--entrypoint-script='${{ inputs.entrypoint_script }}' \
--entrypoint-args='${{ inputs.entrypoint_args }}' \
--env-vars='${{ inputs.env_vars }}' \
--enable-ray-tracing
- name: Download log files from ray cluster
if: always()
run: |
source .venv/bin/activate
ray rsync-down .github/assets/ray.yaml /tmp/ray/session_*/logs ray-daft-logs
Expand Down Expand Up @@ -152,6 +145,7 @@ jobs:
source .venv/bin/activate
ray down .github/assets/ray.yaml -y
- name: Upload log files
if: always()
uses: actions/upload-artifact@v4
with:
name: ray-daft-logs
Expand Down
51 changes: 42 additions & 9 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit ba3dad4

Please sign in to comment.