From 736d9ea1db597d47803a2814e852a8cb52d070c4 Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 24 Dec 2024 00:17:06 -0800 Subject: [PATCH] Bump jax 38 (#846) --- .github/workflows/launch_small_fast.yaml | 2 +- .github/workflows/run_entry_tests.yaml | 2 +- .github/workflows/run_pre_commit.yaml | 2 +- .github/workflows/run_ray_tests.yaml | 2 +- .github/workflows/run_tests.yaml | 2 +- docker/tpu/Dockerfile.base | 3 +-- infra/helpers/setup-tpu-vm-tests.sh | 2 +- infra/helpers/setup-tpu-vm.sh | 2 +- 8 files changed, 8 insertions(+), 9 deletions(-) diff --git a/.github/workflows/launch_small_fast.yaml b/.github/workflows/launch_small_fast.yaml index 15f423674..584795201 100644 --- a/.github/workflows/launch_small_fast.yaml +++ b/.github/workflows/launch_small_fast.yaml @@ -41,7 +41,7 @@ jobs: - name: Install locally run: | python -m pip install --upgrade pip - pip install -e .[test] "jax[cpu]==0.4.30" + pip install -e .[test] "jax[cpu]==0.4.38" - name: Launch Small Fast TPU Train LM job run: | diff --git a/.github/workflows/run_entry_tests.yaml b/.github/workflows/run_entry_tests.yaml index ab08013ee..d9de2d815 100644 --- a/.github/workflows/run_entry_tests.yaml +++ b/.github/workflows/run_entry_tests.yaml @@ -9,7 +9,7 @@ jobs: strategy: matrix: python-version: ["3.10"] - jax-version: ["0.4.26"] + jax-version: ["0.4.38"] steps: - uses: actions/checkout@v3 diff --git a/.github/workflows/run_pre_commit.yaml b/.github/workflows/run_pre_commit.yaml index ee3f0f587..842354ae0 100644 --- a/.github/workflows/run_pre_commit.yaml +++ b/.github/workflows/run_pre_commit.yaml @@ -10,7 +10,7 @@ jobs: strategy: matrix: python-version: ["3.10"] - jax-version: ["0.4.14"] + jax-version: ["0.4.38"] steps: - uses: actions/checkout@v3 diff --git a/.github/workflows/run_ray_tests.yaml b/.github/workflows/run_ray_tests.yaml index 42139e576..a1788f777 100644 --- a/.github/workflows/run_ray_tests.yaml +++ b/.github/workflows/run_ray_tests.yaml @@ -9,7 +9,7 @@ jobs: strategy: matrix: python-version: ["3.10"] - jax-version: ["0.4.26"] + jax-version: ["0.4.38"] steps: - uses: actions/checkout@v3 diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml index 6e9ed7024..ac01bcf5e 100644 --- a/.github/workflows/run_tests.yaml +++ b/.github/workflows/run_tests.yaml @@ -10,7 +10,7 @@ jobs: strategy: matrix: python-version: ["3.10"] - jax-version: ["0.4.26"] + jax-version: ["0.4.38"] steps: - uses: actions/checkout@v3 diff --git a/docker/tpu/Dockerfile.base b/docker/tpu/Dockerfile.base index e2e032e82..09914eb79 100644 --- a/docker/tpu/Dockerfile.base +++ b/docker/tpu/Dockerfile.base @@ -5,8 +5,7 @@ RUN pip install virtualenv # venv binaries encode their directory, so we need to setup the venv in the final location RUN virtualenv -p python3.10 /opt/levanter/.venv ENV PATH /opt/levanter/.venv/bin:$PATH -#RUN /opt/levanter/.venv/bin/pip install -U "jax[tpu]==0.4.34" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html -RUN /opt/levanter/.venv/bin/pip install -U "jax[tpu]@git+https://github.com/dlwh/jax@retry_refuse" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html +RUN /opt/levanter/.venv/bin/pip install -U "jax[tpu]==0.4.38" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html # Install package dependencies to make incremental builds faster. WORKDIR /tmp/ diff --git a/infra/helpers/setup-tpu-vm-tests.sh b/infra/helpers/setup-tpu-vm-tests.sh index 33c1c4add..a5cae20e1 100755 --- a/infra/helpers/setup-tpu-vm-tests.sh +++ b/infra/helpers/setup-tpu-vm-tests.sh @@ -105,7 +105,7 @@ pip install -U wheel # jax and jaxlib # libtpu sometimes has issues installing for clinical (probably firewall?) -retry pip install -U "jax[tpu]@git+https://github.com/dlwh/jax@retry_refuse" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html +retru pip install -U "jax[tpu]==0.4.38" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html # clone levanter git clone $REPO levanter diff --git a/infra/helpers/setup-tpu-vm.sh b/infra/helpers/setup-tpu-vm.sh index 3ca81d76b..5bca127e9 100755 --- a/infra/helpers/setup-tpu-vm.sh +++ b/infra/helpers/setup-tpu-vm.sh @@ -105,7 +105,7 @@ pip install -U wheel # jax and jaxlib # libtpu sometimes has issues installing for clinical (probably firewall?) -retry pip install -U "jax[tpu]@git+https://github.com/dlwh/jax@retry_refuse" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html +retru pip install -U "jax[tpu]==0.4.38" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html # clone levanter git clone $REPO levanter