diff --git a/.github/ci-scripts/templatize_ray_config.py b/.github/ci-scripts/templatize_ray_config.py new file mode 100644 index 0000000000..ea597c1dd1 --- /dev/null +++ b/.github/ci-scripts/templatize_ray_config.py @@ -0,0 +1,80 @@ +import sys +from argparse import ArgumentParser +from dataclasses import dataclass +from typing import Optional + +CLUSTER_NAME_PLACEHOLDER = "{{CLUSTER_NAME}}" +DAFT_VERSION_PLACEHOLDER = "{{DAFT_VERSION}}" +PYTHON_VERSION_PLACEHOLDER = "{{PYTHON_VERSION}}" +CLUSTER_PROFILE__NODE_COUNT = "'{{CLUSTER_PROFILE/node_count}}'" +CLUSTER_PROFILE__INSTANCE_TYPE = "{{CLUSTER_PROFILE/instance_type}}" +CLUSTER_PROFILE__IMAGE_ID = "{{CLUSTER_PROFILE/image_id}}" +CLUSTER_PROFILE__VOLUME_MOUNT = "'{{CLUSTER_PROFILE/volume_mount}}'" + + +@dataclass +class Profile: + node_count: int + instance_type: int + image_id: int + volume_mount: Optional[int] + + +profiles: dict[str, Optional[Profile]] = { + "debug_xs-x86": None, + "medium-x86": Profile( + instance_type="i3.2xlarge", + image_id="ami-04dd23e62ed049936", + node_count=4, + volume_mount=""" | + findmnt /tmp 1> /dev/null + code=$? + if [ $code -ne 0 ]; then + sudo mkfs.ext4 /dev/nvme0n1 + sudo mount -t ext4 /dev/nvme0n1 /tmp + sudo chmod 777 /tmp + fi""", + ), +} + + +if __name__ == "__main__": + content = sys.stdin.read() + + parser = ArgumentParser() + parser.add_argument("--cluster-name") + parser.add_argument("--daft-version") + parser.add_argument("--python-version") + parser.add_argument("--cluster-profile") + args = parser.parse_args() + + if args.cluster_name: + content = content.replace(CLUSTER_NAME_PLACEHOLDER, args.cluster_name) + + if args.daft_version: + content = content.replace(DAFT_VERSION_PLACEHOLDER, f"=={args.daft_version}") + else: + content = content.replace(DAFT_VERSION_PLACEHOLDER, "") + + if args.python_version: + content = content.replace(PYTHON_VERSION_PLACEHOLDER, args.python_version) + + if cluster_profile := args.cluster_profile: + cluster_profile: str + if cluster_profile not in profiles: + raise Exception(f'Cluster profile "{cluster_profile}" not found') + + profile = profiles[cluster_profile] + if profile is None: + raise Exception(f'Cluster profile "{cluster_profile}" not yet implemented') + + assert profile is not None + content = content.replace(CLUSTER_PROFILE__NODE_COUNT, str(profile.node_count)) + content = content.replace(CLUSTER_PROFILE__INSTANCE_TYPE, profile.instance_type) + content = content.replace(CLUSTER_PROFILE__IMAGE_ID, profile.image_id) + if profile.volume_mount: + content = content.replace(CLUSTER_PROFILE__VOLUME_MOUNT, profile.volume_mount) + else: + content = content.replace(CLUSTER_PROFILE__VOLUME_MOUNT, "echo 'Nothing to mount; skipping'") + + print(content) diff --git a/.github/workflows/run-cluster.yaml b/.github/workflows/run-cluster.yaml index 1a296b36a0..3d3e5b1ae6 100644 --- a/.github/workflows/run-cluster.yaml +++ b/.github/workflows/run-cluster.yaml @@ -12,6 +12,13 @@ on: description: The version of python to use required: false default: "3.9" + cluster_profile: + type: choice + options: + - medium-x86 + description: The profile to use for the cluster + required: false + default: medium-x86 command: type: string description: The command to run on the cluster @@ -50,14 +57,16 @@ jobs: uv pip install ray[default] boto3 - name: Dynamically update ray config file run: | - id="ray-ci-run-${{ github.run_id }}_${{ github.run_attempt }}" - sed -i "s|{{RAY_CLUSTER_NAME}}|$id|g" .github/assets/benchmarking_ray_config.yaml - sed -i 's|{{PYTHON_VERSION}}|${{ inputs.python_version }}|g' .github/assets/benchmarking_ray_config.yaml - if [[ '${{ inputs.daft_version }}' ]]; then - sed -i 's|{{DAFT_VERSION}}|==${{ inputs.daft_version }}|g' .github/assets/benchmarking_ray_config.yaml - else - sed -i 's|{{DAFT_VERSION}}||g' .github/assets/benchmarking_ray_config.yaml - fi + source .venv/bin/activate + (cat .github/assets/template.yaml \ + | python .github/ci-scripts/templatize_ray_config.py \ + --cluster-name "ray-ci-run-${{ github.run_id }}_${{ github.run_attempt }}" \ + --daft-version ${{ inputs.daft_version }} \ + --python-version ${{ inputs.python_version }} \ + --cluster-profile 'medium-x86' + ) >> .github/assets/ray.yaml + echo "Ray configuration file:" >> $GITHUB_STEP_SUMMARY + cat .github/assets/ray.yaml >> $GITHUB_STEP_SUMMARY - name: Download private ssh key run: | KEY=$(aws secretsmanager get-secret-value --secret-id ci-github-actions-ray-cluster-key-3 --query SecretString --output text) @@ -66,11 +75,11 @@ jobs: - name: Spin up ray cluster run: | source .venv/bin/activate - ray up .github/assets/benchmarking_ray_config.yaml -y + ray up .github/assets/ray.yaml -y - name: Setup connection to ray cluster run: | source .venv/bin/activate - ray dashboard .github/assets/benchmarking_ray_config.yaml & + ray dashboard .github/assets/ray.yaml & - name: Submit job to ray cluster run: | source .venv/bin/activate @@ -86,7 +95,7 @@ jobs: - name: Download log files from ray cluster run: | source .venv/bin/activate - ray rsync-down .github/assets/benchmarking_ray_config.yaml /tmp/ray/session_*/logs ray-daft-logs + ray rsync-down .github/assets/ray.yaml /tmp/ray/session_*/logs ray-daft-logs - name: Kill connection to ray cluster run: | PID=$(lsof -t -i:8265) @@ -103,7 +112,7 @@ jobs: if: always() run: | source .venv/bin/activate - ray down .github/assets/benchmarking_ray_config.yaml -y + ray down .github/assets/ray.yaml -y - name: Upload log files uses: actions/upload-artifact@v4 with: