From 4fdbf9fa1ba28d5ad6708bba2334b845e7acb019 Mon Sep 17 00:00:00 2001 From: "Oleksandr \"Alex\" Zinenko" Date: Wed, 13 Dec 2023 16:14:50 +0100 Subject: [PATCH] Improve package layout and bazel build (#20) * Improve package layout and bazel build Follow PEP420 and other common practices to set up a namespace package `enzyme_ad` containing a subpackage `jax` that is specific for JAX connection. This will allow one to use Enzyme bindings without necessarily connecting to JAX in the future. Package import now starts at `enzyme_ad` and JAX-related functionality can be obtained as `from enzyme_ad import jax as enzyme_jax` if desired. Update the wheel building rules accordingly and rename the distribution to `enzyme_ad`. Declare a dependency on JAX until that functionality is separated into a separate distribution. * fix path * Update signature --------- Co-authored-by: William S. Moses --- .bazelversion | 1 + .buildkite/pipeline.yml | 4 +-- .buildkite/secure_pipeline.yml | 4 +-- .buildkite/secure_pipeline.yml.signature | Bin 96 -> 96 bytes .github/workflows/build.yml | 14 +++++----- .github/workflows/format.yml | 6 ++-- .github/workflows/tag.yml | 6 ++-- BUILD | 26 ++++++++++++------ README.md | 19 +++++++------ enzyme_jax/__init__.py | 1 - package.bzl | 4 --- {enzyme_jax => src/enzyme_ad/jax}/BUILD | 0 src/enzyme_ad/jax/__init__.py | 1 + .../enzyme_ad/jax}/clang_compile.cc | 0 .../enzyme_ad/jax}/clang_compile.h | 0 .../enzyme_ad/jax}/compile_with_xla.cc | 0 .../enzyme_ad/jax}/compile_with_xla.h | 0 .../enzyme_ad/jax}/enzyme_call.cc | 0 .../enzyme_ad/jax}/primitives.py | 3 +- test/bench_vs_xla.py | 2 +- test/lit_tests/ir.pyt | 2 +- test/llama.py | 2 +- test/test.py | 6 ++-- 23 files changed, 54 insertions(+), 47 deletions(-) create mode 100644 .bazelversion delete mode 100644 enzyme_jax/__init__.py rename {enzyme_jax => src/enzyme_ad/jax}/BUILD (100%) create mode 100644 src/enzyme_ad/jax/__init__.py rename {enzyme_jax => src/enzyme_ad/jax}/clang_compile.cc (100%) rename {enzyme_jax => src/enzyme_ad/jax}/clang_compile.h (100%) rename {enzyme_jax => src/enzyme_ad/jax}/compile_with_xla.cc (100%) rename {enzyme_jax => src/enzyme_ad/jax}/compile_with_xla.h (100%) rename {enzyme_jax => src/enzyme_ad/jax}/enzyme_call.cc (100%) rename {enzyme_jax => src/enzyme_ad/jax}/primitives.py (99%) diff --git a/.bazelversion b/.bazelversion new file mode 100644 index 00000000..024b066c --- /dev/null +++ b/.bazelversion @@ -0,0 +1 @@ +6.2.1 diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index cd8044e4..049072fb 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -32,7 +32,7 @@ steps: chmod +x Miniconda*.sh ./Miniconda*.sh -b -p `pwd`/conda rm Miniconda*.sh - elif [ "{{matrix.os}}" == "linux" ]; then + elif [ "{{matrix.os}}" == "linux" ]; then if [ "{{matrix.arch}}" == "aarch64" ]; then curl -fLO https://github.com/bazelbuild/bazelisk/releases/download/v1.17.0/bazelisk-linux-arm64 else @@ -59,7 +59,7 @@ steps: python -m ensurepip --upgrade python -m pip install --user numpy wheel mkdir baztmp - bazel --output_user_root=`pwd`/baztmp build :enzyme_jax + bazel --output_user_root=`pwd`/baztmp build :enzyme_ad cp bazel-bin/*.whl . python -m pip install *.whl cd test && python -m pip install "jax[cpu]" && python test.py diff --git a/.buildkite/secure_pipeline.yml b/.buildkite/secure_pipeline.yml index 91b97da1..3f313dd0 100644 --- a/.buildkite/secure_pipeline.yml +++ b/.buildkite/secure_pipeline.yml @@ -66,8 +66,8 @@ steps: python -m pip install --user numpy wheel mkdir baztmp export TAG=`echo $BUILDKITE_TAG | cut -c2-` - sed -i.bak "s~version = \"[0-9.]*\"~version = \"\$TAG\"~g" BUILD - bazel --output_user_root=`pwd`/baztmp build :enzyme_jax + sed -i.bak "s~version = \"[0-9.]*\"~version = \"\$TAG\"~g" BUILD + bazel --output_user_root=`pwd`/baztmp build :enzyme_ad cp bazel-bin/*.whl . python -m pip install *.whl cd test && python -m pip install "jax[cpu]" && python test.py && cd .. diff --git a/.buildkite/secure_pipeline.yml.signature b/.buildkite/secure_pipeline.yml.signature index 571b249b5dd1b5ae828e6a9c5ac8c3d58874e0af..d04fa0e08b9b2498376adedb9ed321bad0b1bf09 100644 GIT binary patch literal 96 zcmV-m0H6O;VQh3|WM5y4eEE!ka#$5Td*QBjZ_!1A*#Ht5{9ifPRjKa0!SP&xMS0h3MlXe<}t861%M=Ez^8GyPh#n43jcH1Qf CD=TLJ diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 506e54d5..7bb36e90 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -15,9 +15,9 @@ jobs: fail-fast: false matrix: os: [openstack22] - timeout-minutes: 500 + timeout-minutes: 500 steps: - - name: add llvm + - name: add llvm run: | if [ "$(expr substr $(uname -s) 1 5)" == "Linux" ]; then sudo apt-get update @@ -36,21 +36,21 @@ jobs: key: bazel-${{ matrix.os }} - run: find ~/.cache/bazel ~/.cache/bazelisk -iname "*.whl" -exec rm {} \; - run: | - bazel build :enzyme_jax @llvm-project//llvm:FileCheck - bazel cquery "allpaths(//enzyme_jax:enzyme_call,@xla//xla/stream_executor:executor_cache)" --notool_deps + bazel build :enzyme_ad @llvm-project//llvm:FileCheck + bazel cquery "allpaths(//src/enzyme_ad/jax:enzyme_call,@xla//xla/stream_executor:executor_cache)" --notool_deps bazel --version nm -C $(find bazel-out/ -name enzyme_call.so -type f) | grep ExecutorCache:: - run: cp bazel-bin/*.whl . - + - name: test run: | python3 -m pip install --user --force-reinstall "jax[cpu]" *.whl cd test - nm -C $(python3 -c "from enzyme_jax import enzyme_call; print(enzyme_call.__file__)") | grep ExecutorCache:: + nm -C $(python3 -c "from enzyme_ad.jax import enzyme_call; print(enzyme_call.__file__)") | grep ExecutorCache:: python3 test.py cd lit_tests lit . --verbose - + - name: Upload Build uses: actions/upload-artifact@v3 with: diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml index c47a301b..d5861cd8 100644 --- a/.github/workflows/format.yml +++ b/.github/workflows/format.yml @@ -4,16 +4,16 @@ on: push: pull_request: merge_group: - + jobs: build: name: Format runs-on: ubuntu-latest - + steps: - uses: actions/checkout@v3 - uses: DoozyX/clang-format-lint-action@v0.16.2 with: - source: 'enzyme_jax' + source: 'src' style: 'llvm' clangFormatVersion: 16 diff --git a/.github/workflows/tag.yml b/.github/workflows/tag.yml index 89c7c7eb..a51000b1 100644 --- a/.github/workflows/tag.yml +++ b/.github/workflows/tag.yml @@ -12,9 +12,9 @@ jobs: fail-fast: false matrix: os: [openstack22] - timeout-minutes: 500 + timeout-minutes: 500 steps: - - name: add llvm + - name: add llvm run: | if [ "$(expr substr $(uname -s) 1 5)" == "Linux" ]; then sudo apt-get update @@ -36,7 +36,7 @@ jobs: path: "~/.cache/bazel" key: bazel-${{ matrix.os }} - run: find ~/.cache/bazel ~/.cache/bazelisk -iname "*.whl" -exec rm {} \; - - run: bazel build :enzyme_jax + - run: bazel build :enzyme_ad - env: TWINE_USERNAME: __token__ TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }} diff --git a/BUILD b/BUILD index 880271ee..286c625a 100644 --- a/BUILD +++ b/BUILD @@ -11,23 +11,26 @@ package( py_package( name = "enzyme_jax_data", deps = [ - "//enzyme_jax:enzyme_call.so", + "//src/enzyme_ad/jax:enzyme_call.so", "@llvm-project//clang:builtin_headers_gen", ], # Only include these Python packages. - packages = ["@//enzyme_jax:enzyme_call.so", "@llvm-project//clang:builtin_headers_gen"], - prefix = "enzyme_jax/" + packages = ["@//src/enzyme_ad/jax:enzyme_call.so", "@llvm-project//clang:builtin_headers_gen"], ) py_wheel( - name = "enzyme_jax", - # Package data. We're building "example_minimal_package-0.0.1-py3-none-any.whl" - distribution = "enzyme_jax", + name = "enzyme_ad", + distribution = "enzyme_ad", + summary = "Enzyme automatic differentiation tool.", + homepage = "https://enzyme.mit.edu/", + project_urls = { + "GitHub": "https://github.com/EnzymeAD/Enzyme-JAX/", + }, author="Enzyme Authors", - license='LLVM', + license="LLVM", author_email="wmoses@mit.edu, zinenko@google.com", python_tag = "py3", - version = "0.0.5", + version = "0.0.6", platform = select({ "@bazel_tools//src/conditions:windows_x64": "win_amd64", "@bazel_tools//src/conditions:darwin_arm64": "macosx_11_0_arm64", @@ -36,5 +39,10 @@ py_wheel( "@bazel_tools//src/conditions:linux_x86_64": "manylinux2014_x86_64", "@bazel_tools//src/conditions:linux_ppc64le": "manylinux2014_ppc64le", }), - deps = ["//enzyme_jax:enzyme_jax_internal", ":enzyme_jax_data"] + deps = ["//src/enzyme_ad/jax:enzyme_jax_internal", ":enzyme_jax_data"], + strip_path_prefixes = ["src/"], + requires = [ + "jax >= 0.4.21", + "jaxlib >= 0.4.21", + ], ) diff --git a/README.md b/README.md index 1929c0bf..1ae9f714 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,15 @@ # Enzyme-JAX -Custom bindings for Enzyme within JaX. Currently this is set up to allow you -to automatically import, and automatically differentiate (both jvp and vjp) -external C++ code into JaX. As Enzyme is language-agnostic, this can be extended -for arbitrary programming languages (Julia, Swift, Fortran, Rust, and even Python)! +Custom bindings for Enzyme automatic differentiation tool and interfacing with +JAX. Currently this is set up to allow you to automatically import, and +automatically differentiate (both jvp and vjp) external C++ code into JAX. As +Enzyme is language-agnostic, this can be extended for arbitrary programming +languages (Julia, Swift, Fortran, Rust, and even Python)! -You can use +You can use ```python -from enzyme_jax import cpp_call +from enzyme_ad.jax import cpp_call # Forward-mode C++ AD example @@ -48,13 +49,13 @@ Requirements: `bazel-6.2.1`, `clang++`, `python`, `python-virtualenv`, Build our extension with: ```sh -# Will create a whl in bazel-bin/enzyme_jax-VERSION-SYSTEM.whl -bazel build :enzyme_jax +# Will create a whl in bazel-bin/enzyme_ad-VERSION-SYSTEM.whl +bazel build :enzyme_ad ``` Finally, install the built library with: ```sh -pip install bazel-bin/enzyme_jax-VERSION-SYSTEM.whl +pip install bazel-bin/enzyme_ad-VERSION-SYSTEM.whl ``` Note that you cannot run code from the root of the git directory. For instance, in the code below, you have to first run `cd test` before running `test.py`. diff --git a/enzyme_jax/__init__.py b/enzyme_jax/__init__.py deleted file mode 100644 index 201fa273..00000000 --- a/enzyme_jax/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from enzyme_jax.primitives import cpp_call, enzyme_jax_ir diff --git a/package.bzl b/package.bzl index c86a93cb..8e3b0411 100644 --- a/package.bzl +++ b/package.bzl @@ -22,10 +22,6 @@ py_package_lib = struct( "deps": attr.label_list( doc = "", ), - "prefix": attr.string( - doc = "Prefix", - mandatory = True, - ), "packages": attr.string_list( mandatory = False, allow_empty = True, diff --git a/enzyme_jax/BUILD b/src/enzyme_ad/jax/BUILD similarity index 100% rename from enzyme_jax/BUILD rename to src/enzyme_ad/jax/BUILD diff --git a/src/enzyme_ad/jax/__init__.py b/src/enzyme_ad/jax/__init__.py new file mode 100644 index 00000000..0519e33a --- /dev/null +++ b/src/enzyme_ad/jax/__init__.py @@ -0,0 +1 @@ +from enzyme_ad.jax.primitives import cpp_call, enzyme_jax_ir diff --git a/enzyme_jax/clang_compile.cc b/src/enzyme_ad/jax/clang_compile.cc similarity index 100% rename from enzyme_jax/clang_compile.cc rename to src/enzyme_ad/jax/clang_compile.cc diff --git a/enzyme_jax/clang_compile.h b/src/enzyme_ad/jax/clang_compile.h similarity index 100% rename from enzyme_jax/clang_compile.h rename to src/enzyme_ad/jax/clang_compile.h diff --git a/enzyme_jax/compile_with_xla.cc b/src/enzyme_ad/jax/compile_with_xla.cc similarity index 100% rename from enzyme_jax/compile_with_xla.cc rename to src/enzyme_ad/jax/compile_with_xla.cc diff --git a/enzyme_jax/compile_with_xla.h b/src/enzyme_ad/jax/compile_with_xla.h similarity index 100% rename from enzyme_jax/compile_with_xla.h rename to src/enzyme_ad/jax/compile_with_xla.h diff --git a/enzyme_jax/enzyme_call.cc b/src/enzyme_ad/jax/enzyme_call.cc similarity index 100% rename from enzyme_jax/enzyme_call.cc rename to src/enzyme_ad/jax/enzyme_call.cc diff --git a/enzyme_jax/primitives.py b/src/enzyme_ad/jax/primitives.py similarity index 99% rename from enzyme_jax/primitives.py rename to src/enzyme_ad/jax/primitives.py index 2fffc767..99a9cfa8 100644 --- a/enzyme_jax/primitives.py +++ b/src/enzyme_ad/jax/primitives.py @@ -25,7 +25,8 @@ def resource_dir(): import os dn = os.path.dirname(enzyme_call.__file__) - return os.path.join(dn, "..", "clang", "staging") + res = os.path.join(dn, "..", "..", "clang", "staging") + return res def cflags(): import platform diff --git a/test/bench_vs_xla.py b/test/bench_vs_xla.py index 617f658f..63a2bd9d 100644 --- a/test/bench_vs_xla.py +++ b/test/bench_vs_xla.py @@ -1,6 +1,6 @@ import jax import jax.numpy as jnp -from enzyme_jax import enzyme_jax_ir +from enzyme_ad.jax import enzyme_jax_ir @enzyme_jax_ir() diff --git a/test/lit_tests/ir.pyt b/test/lit_tests/ir.pyt index c7237b10..71e538c2 100644 --- a/test/lit_tests/ir.pyt +++ b/test/lit_tests/ir.pyt @@ -2,7 +2,7 @@ import jax import jax.numpy as jnp -from enzyme_jax import cpp_call +from enzyme_ad.jax import cpp_call def do_something(ones, twos): shape = jax.core.ShapedArray(tuple(3 * s for s in ones.shape), ones.dtype) diff --git a/test/llama.py b/test/llama.py index f90aa28a..b0fe50b7 100644 --- a/test/llama.py +++ b/test/llama.py @@ -1,7 +1,7 @@ import jax.numpy as jnp import jax.random import jax.lax -import enzyme_jax +import enzyme_ad.jax as enzyme_jax def rmsnorm(x, weight): ss = 1 / jnp.sqrt(x.dot(x) / x.shape[0] + 1e-5) diff --git a/test/test.py b/test/test.py index ac562b9c..2eb5866d 100644 --- a/test/test.py +++ b/test/test.py @@ -1,6 +1,6 @@ import jax import jax.numpy as jnp -from enzyme_jax import cpp_call +from enzyme_ad.jax import cpp_call @jax.jit def do_something(ones): @@ -46,7 +46,7 @@ def do_something(ones): print(grads) # Test enzyme mlir jit -from enzyme_jax import enzyme_jax_ir +from enzyme_ad.jax import enzyme_jax_ir @enzyme_jax_ir() @@ -65,4 +65,4 @@ def add_one(x: jax.Array, y) -> jax.Array: primals, f_vjp = jax.vjp(add_one, jnp.array([1., 2., 3.]), jnp.array([10., 20., 30.])) grads = f_vjp(jnp.array([500., 700., 110.])) print(primals) -print(grads) \ No newline at end of file +print(grads)