Skip to content

Commit

Permalink
Bumped jax_triton version and updated JAX/Triton requirements
Browse files Browse the repository at this point in the history
  • Loading branch information
superbobry committed Oct 21, 2024
1 parent cd49678 commit 61e2c33
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 22 deletions.
28 changes: 10 additions & 18 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,33 +24,25 @@ Check out the [JAX installation guide](https://github.com/google/jax#pip-install

### Installation at HEAD

JAX-Triton and Pallas are developed at JAX and Jaxlib HEAD and close to Triton HEAD. To get a bleeding edge installation of JAX-Triton, run:
JAX-Triton is developed at JAX and jaxlib HEAD and close to Triton HEAD. To get
a bleeding edge installation of JAX-Triton, run:

```bash
$ pip install 'jax-triton @ git+https://github.com/jax-ml/jax-triton.git'
```

This should install compatible versions of JAX and Triton.

JAX-Triton does depend on Jaxlib but it's usually a more stable dependency. You might be able to get away with using a recent jaxlib release:
```bash
$ pip install jaxlib[cuda]
$ # or
$ pip install jaxlib[cuda11_pip]
$ # or
$ pip install jaxlib[cuda12_pip]
```
JAX-Triton requires jaxlib with GPU support. You could install the latest stable
release via

If you find there are issues with the latest Jaxlib release, you can try using a Jaxlib nightly.
To install a new jaxlib, you can find a link to a [CUDA 11 nightly](https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda_releases.html) or [CUDA 12 nightly](https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html). Then install it via:
```bash
$ pip install 'jaxlib @ <link to nightly>'
```
or to install CUDA via pip automatically, you can do:
```bash
$ pip install 'jaxlib[cuda11_pip] @ <link to nightly>'
$ # or
$ pip install 'jaxlib[cuda12_pip] @ <link to nightly>'
$ pip install jaxlib[cuda12]
```

In rare cases JAX-Triton might need a nighly version of jaxlib. You can install
it following the instructions
[here](https://jax.readthedocs.io/en/latest/installation.html#jax-nightly-installation).

### Quickstart

Expand Down
5 changes: 4 additions & 1 deletion jax_triton/triton_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,10 @@ def triton_kernel_call_lowering(
named_args = dict(unsafe_zip(fn.arg_names, args))

if isinstance(fn, autotuner.Autotuner):
key_idxs = [fn.arg_names.index(k) for k in fn.keys]
if hasattr(fn, "key_idx"):
key_idxs = fn.key_idx # Triton <=3.2
else:
key_idxs = [fn.arg_names.index(k) for k in fn.keys]
if any(idx not in key_idxs for idx, _, _ in scalar_args):
logging.warning(
"Auto-tuning key does not include all scalar arguments. "
Expand Down
2 changes: 1 addition & 1 deletion jax_triton/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version_info__ = (0, 2, 0)
__version_info__ = (0, 3, 0)
__version__ = ".".join(str(v) for v in __version_info__)
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ readme = "README.md"
requires-python = ">=3.10"
dependencies = [
"absl-py>=1.4.0",
"jax>=0.4.31",
"triton>=3.0",
"jax>=0.4.34",
"triton>=3.1",
]

[project.optional-dependencies]
Expand Down

0 comments on commit 61e2c33

Please sign in to comment.