Skip to content

Commit

Permalink
Improve package layout and bazel build (#20)
Browse files Browse the repository at this point in the history
* Improve package layout and bazel build

Follow PEP420 and other common practices to set up a namespace package
`enzyme_ad` containing a subpackage `jax` that is specific for JAX
connection. This will allow one to use Enzyme bindings without
necessarily connecting to JAX in the future. Package import now starts
at `enzyme_ad` and JAX-related functionality can be obtained as
`from enzyme_ad import jax as enzyme_jax` if desired.

Update the wheel building rules accordingly and rename the distribution
to `enzyme_ad`. Declare a dependency on JAX until that functionality is
separated into a separate distribution.

* fix path

* Update signature

---------

Co-authored-by: William S. Moses <[email protected]>
  • Loading branch information
ftynse and wsmoses authored Dec 13, 2023
1 parent c94ed35 commit 4fdbf9f
Show file tree
Hide file tree
Showing 23 changed files with 54 additions and 47 deletions.
1 change: 1 addition & 0 deletions .bazelversion
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
6.2.1
4 changes: 2 additions & 2 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ steps:
chmod +x Miniconda*.sh
./Miniconda*.sh -b -p `pwd`/conda
rm Miniconda*.sh
elif [ "{{matrix.os}}" == "linux" ]; then
elif [ "{{matrix.os}}" == "linux" ]; then
if [ "{{matrix.arch}}" == "aarch64" ]; then
curl -fLO https://github.com/bazelbuild/bazelisk/releases/download/v1.17.0/bazelisk-linux-arm64
else
Expand All @@ -59,7 +59,7 @@ steps:
python -m ensurepip --upgrade
python -m pip install --user numpy wheel
mkdir baztmp
bazel --output_user_root=`pwd`/baztmp build :enzyme_jax
bazel --output_user_root=`pwd`/baztmp build :enzyme_ad
cp bazel-bin/*.whl .
python -m pip install *.whl
cd test && python -m pip install "jax[cpu]" && python test.py
Expand Down
4 changes: 2 additions & 2 deletions .buildkite/secure_pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ steps:
python -m pip install --user numpy wheel
mkdir baztmp
export TAG=`echo $BUILDKITE_TAG | cut -c2-`
sed -i.bak "s~version = \"[0-9.]*\"~version = \"\$TAG\"~g" BUILD
bazel --output_user_root=`pwd`/baztmp build :enzyme_jax
sed -i.bak "s~version = \"[0-9.]*\"~version = \"\$TAG\"~g" BUILD
bazel --output_user_root=`pwd`/baztmp build :enzyme_ad
cp bazel-bin/*.whl .
python -m pip install *.whl
cd test && python -m pip install "jax[cpu]" && python test.py && cd ..
Expand Down
Binary file modified .buildkite/secure_pipeline.yml.signature
Binary file not shown.
14 changes: 7 additions & 7 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ jobs:
fail-fast: false
matrix:
os: [openstack22]
timeout-minutes: 500
timeout-minutes: 500
steps:
- name: add llvm
- name: add llvm
run: |
if [ "$(expr substr $(uname -s) 1 5)" == "Linux" ]; then
sudo apt-get update
Expand All @@ -36,21 +36,21 @@ jobs:
key: bazel-${{ matrix.os }}
- run: find ~/.cache/bazel ~/.cache/bazelisk -iname "*.whl" -exec rm {} \;
- run: |
bazel build :enzyme_jax @llvm-project//llvm:FileCheck
bazel cquery "allpaths(//enzyme_jax:enzyme_call,@xla//xla/stream_executor:executor_cache)" --notool_deps
bazel build :enzyme_ad @llvm-project//llvm:FileCheck
bazel cquery "allpaths(//src/enzyme_ad/jax:enzyme_call,@xla//xla/stream_executor:executor_cache)" --notool_deps
bazel --version
nm -C $(find bazel-out/ -name enzyme_call.so -type f) | grep ExecutorCache::
- run: cp bazel-bin/*.whl .

- name: test
run: |
python3 -m pip install --user --force-reinstall "jax[cpu]" *.whl
cd test
nm -C $(python3 -c "from enzyme_jax import enzyme_call; print(enzyme_call.__file__)") | grep ExecutorCache::
nm -C $(python3 -c "from enzyme_ad.jax import enzyme_call; print(enzyme_call.__file__)") | grep ExecutorCache::
python3 test.py
cd lit_tests
lit . --verbose
- name: Upload Build
uses: actions/upload-artifact@v3
with:
Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/format.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@ on:
push:
pull_request:
merge_group:

jobs:
build:
name: Format
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v3
- uses: DoozyX/[email protected]
with:
source: 'enzyme_jax'
source: 'src'
style: 'llvm'
clangFormatVersion: 16
6 changes: 3 additions & 3 deletions .github/workflows/tag.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ jobs:
fail-fast: false
matrix:
os: [openstack22]
timeout-minutes: 500
timeout-minutes: 500
steps:
- name: add llvm
- name: add llvm
run: |
if [ "$(expr substr $(uname -s) 1 5)" == "Linux" ]; then
sudo apt-get update
Expand All @@ -36,7 +36,7 @@ jobs:
path: "~/.cache/bazel"
key: bazel-${{ matrix.os }}
- run: find ~/.cache/bazel ~/.cache/bazelisk -iname "*.whl" -exec rm {} \;
- run: bazel build :enzyme_jax
- run: bazel build :enzyme_ad
- env:
TWINE_USERNAME: __token__
TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
Expand Down
26 changes: 17 additions & 9 deletions BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,26 @@ package(
py_package(
name = "enzyme_jax_data",
deps = [
"//enzyme_jax:enzyme_call.so",
"//src/enzyme_ad/jax:enzyme_call.so",
"@llvm-project//clang:builtin_headers_gen",
],
# Only include these Python packages.
packages = ["@//enzyme_jax:enzyme_call.so", "@llvm-project//clang:builtin_headers_gen"],
prefix = "enzyme_jax/"
packages = ["@//src/enzyme_ad/jax:enzyme_call.so", "@llvm-project//clang:builtin_headers_gen"],
)

py_wheel(
name = "enzyme_jax",
# Package data. We're building "example_minimal_package-0.0.1-py3-none-any.whl"
distribution = "enzyme_jax",
name = "enzyme_ad",
distribution = "enzyme_ad",
summary = "Enzyme automatic differentiation tool.",
homepage = "https://enzyme.mit.edu/",
project_urls = {
"GitHub": "https://github.com/EnzymeAD/Enzyme-JAX/",
},
author="Enzyme Authors",
license='LLVM',
license="LLVM",
author_email="[email protected], [email protected]",
python_tag = "py3",
version = "0.0.5",
version = "0.0.6",
platform = select({
"@bazel_tools//src/conditions:windows_x64": "win_amd64",
"@bazel_tools//src/conditions:darwin_arm64": "macosx_11_0_arm64",
Expand All @@ -36,5 +39,10 @@ py_wheel(
"@bazel_tools//src/conditions:linux_x86_64": "manylinux2014_x86_64",
"@bazel_tools//src/conditions:linux_ppc64le": "manylinux2014_ppc64le",
}),
deps = ["//enzyme_jax:enzyme_jax_internal", ":enzyme_jax_data"]
deps = ["//src/enzyme_ad/jax:enzyme_jax_internal", ":enzyme_jax_data"],
strip_path_prefixes = ["src/"],
requires = [
"jax >= 0.4.21",
"jaxlib >= 0.4.21",
],
)
19 changes: 10 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
# Enzyme-JAX

Custom bindings for Enzyme within JaX. Currently this is set up to allow you
to automatically import, and automatically differentiate (both jvp and vjp)
external C++ code into JaX. As Enzyme is language-agnostic, this can be extended
for arbitrary programming languages (Julia, Swift, Fortran, Rust, and even Python)!
Custom bindings for Enzyme automatic differentiation tool and interfacing with
JAX. Currently this is set up to allow you to automatically import, and
automatically differentiate (both jvp and vjp) external C++ code into JAX. As
Enzyme is language-agnostic, this can be extended for arbitrary programming
languages (Julia, Swift, Fortran, Rust, and even Python)!

You can use
You can use

```python
from enzyme_jax import cpp_call
from enzyme_ad.jax import cpp_call

# Forward-mode C++ AD example

Expand Down Expand Up @@ -48,13 +49,13 @@ Requirements: `bazel-6.2.1`, `clang++`, `python`, `python-virtualenv`,

Build our extension with:
```sh
# Will create a whl in bazel-bin/enzyme_jax-VERSION-SYSTEM.whl
bazel build :enzyme_jax
# Will create a whl in bazel-bin/enzyme_ad-VERSION-SYSTEM.whl
bazel build :enzyme_ad
```

Finally, install the built library with:
```sh
pip install bazel-bin/enzyme_jax-VERSION-SYSTEM.whl
pip install bazel-bin/enzyme_ad-VERSION-SYSTEM.whl
```
Note that you cannot run code from the root of the git directory. For instance, in the code below, you have to first run `cd test` before running `test.py`.

Expand Down
1 change: 0 additions & 1 deletion enzyme_jax/__init__.py

This file was deleted.

4 changes: 0 additions & 4 deletions package.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,6 @@ py_package_lib = struct(
"deps": attr.label_list(
doc = "",
),
"prefix": attr.string(
doc = "Prefix",
mandatory = True,
),
"packages": attr.string_list(
mandatory = False,
allow_empty = True,
Expand Down
File renamed without changes.
1 change: 1 addition & 0 deletions src/enzyme_ad/jax/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from enzyme_ad.jax.primitives import cpp_call, enzyme_jax_ir
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
3 changes: 2 additions & 1 deletion enzyme_jax/primitives.py → src/enzyme_ad/jax/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
def resource_dir():
import os
dn = os.path.dirname(enzyme_call.__file__)
return os.path.join(dn, "..", "clang", "staging")
res = os.path.join(dn, "..", "..", "clang", "staging")
return res

def cflags():
import platform
Expand Down
2 changes: 1 addition & 1 deletion test/bench_vs_xla.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import jax
import jax.numpy as jnp
from enzyme_jax import enzyme_jax_ir
from enzyme_ad.jax import enzyme_jax_ir


@enzyme_jax_ir()
Expand Down
2 changes: 1 addition & 1 deletion test/lit_tests/ir.pyt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import jax
import jax.numpy as jnp
from enzyme_jax import cpp_call
from enzyme_ad.jax import cpp_call

def do_something(ones, twos):
shape = jax.core.ShapedArray(tuple(3 * s for s in ones.shape), ones.dtype)
Expand Down
2 changes: 1 addition & 1 deletion test/llama.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import jax.numpy as jnp
import jax.random
import jax.lax
import enzyme_jax
import enzyme_ad.jax as enzyme_jax

def rmsnorm(x, weight):
ss = 1 / jnp.sqrt(x.dot(x) / x.shape[0] + 1e-5)
Expand Down
6 changes: 3 additions & 3 deletions test/test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import jax
import jax.numpy as jnp
from enzyme_jax import cpp_call
from enzyme_ad.jax import cpp_call

@jax.jit
def do_something(ones):
Expand Down Expand Up @@ -46,7 +46,7 @@ def do_something(ones):
print(grads)

# Test enzyme mlir jit
from enzyme_jax import enzyme_jax_ir
from enzyme_ad.jax import enzyme_jax_ir


@enzyme_jax_ir()
Expand All @@ -65,4 +65,4 @@ def add_one(x: jax.Array, y) -> jax.Array:
primals, f_vjp = jax.vjp(add_one, jnp.array([1., 2., 3.]), jnp.array([10., 20., 30.]))
grads = f_vjp(jnp.array([500., 700., 110.]))
print(primals)
print(grads)
print(grads)

0 comments on commit 4fdbf9f

Please sign in to comment.