Skip to content

Commit

Permalink
Bump jax 38 (#846)
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh authored Dec 24, 2024
1 parent d9678f5 commit 736d9ea
Show file tree
Hide file tree
Showing 8 changed files with 8 additions and 9 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/launch_small_fast.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ jobs:
- name: Install locally
run: |
python -m pip install --upgrade pip
pip install -e .[test] "jax[cpu]==0.4.30"
pip install -e .[test] "jax[cpu]==0.4.38"
- name: Launch Small Fast TPU Train LM job
run: |
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/run_entry_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ jobs:
strategy:
matrix:
python-version: ["3.10"]
jax-version: ["0.4.26"]
jax-version: ["0.4.38"]

steps:
- uses: actions/checkout@v3
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/run_pre_commit.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ jobs:
strategy:
matrix:
python-version: ["3.10"]
jax-version: ["0.4.14"]
jax-version: ["0.4.38"]

steps:
- uses: actions/checkout@v3
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/run_ray_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ jobs:
strategy:
matrix:
python-version: ["3.10"]
jax-version: ["0.4.26"]
jax-version: ["0.4.38"]

steps:
- uses: actions/checkout@v3
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/run_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ jobs:
strategy:
matrix:
python-version: ["3.10"]
jax-version: ["0.4.26"]
jax-version: ["0.4.38"]

steps:
- uses: actions/checkout@v3
Expand Down
3 changes: 1 addition & 2 deletions docker/tpu/Dockerfile.base
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@ RUN pip install virtualenv
# venv binaries encode their directory, so we need to setup the venv in the final location
RUN virtualenv -p python3.10 /opt/levanter/.venv
ENV PATH /opt/levanter/.venv/bin:$PATH
#RUN /opt/levanter/.venv/bin/pip install -U "jax[tpu]==0.4.34" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
RUN /opt/levanter/.venv/bin/pip install -U "jax[tpu]@git+https://github.com/dlwh/jax@retry_refuse" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
RUN /opt/levanter/.venv/bin/pip install -U "jax[tpu]==0.4.38" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

# Install package dependencies to make incremental builds faster.
WORKDIR /tmp/
Expand Down
2 changes: 1 addition & 1 deletion infra/helpers/setup-tpu-vm-tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ pip install -U wheel

# jax and jaxlib
# libtpu sometimes has issues installing for clinical (probably firewall?)
retry pip install -U "jax[tpu]@git+https://github.com/dlwh/jax@retry_refuse" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
retru pip install -U "jax[tpu]==0.4.38" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

# clone levanter
git clone $REPO levanter
Expand Down
2 changes: 1 addition & 1 deletion infra/helpers/setup-tpu-vm.sh
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ pip install -U wheel

# jax and jaxlib
# libtpu sometimes has issues installing for clinical (probably firewall?)
retry pip install -U "jax[tpu]@git+https://github.com/dlwh/jax@retry_refuse" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
retru pip install -U "jax[tpu]==0.4.38" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

# clone levanter
git clone $REPO levanter
Expand Down

0 comments on commit 736d9ea

Please sign in to comment.