Skip to content

Commit

Permalink
XLA calling convention fixes (#17)
Browse files Browse the repository at this point in the history
* Minor int and multifunction fixes

* continue

* [broken] wip

* cleaning up

* Restored ad functionality

* format

* Format checker

* Fix macos fname

* fix llvm namespace

* Handle cast return

* Fix single return

* Consider JaX arg elimination [primal]

* tmp

* continuing

* memstore fix

* continuing fixups

* handle memset_pattern16
  • Loading branch information
wsmoses authored Dec 10, 2023
1 parent e47ec7b commit c94ed35
Show file tree
Hide file tree
Showing 11 changed files with 1,520 additions and 358 deletions.
19 changes: 19 additions & 0 deletions .github/workflows/format.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
name: Clang-Format

on:
push:
pull_request:
merge_group:

jobs:
build:
name: Format
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v3
- uses: DoozyX/[email protected]
with:
source: 'enzyme_jax'
style: 'llvm'
clangFormatVersion: 16
16 changes: 8 additions & 8 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ load("@rules_cc//cc:repositories.bzl", "rules_cc_dependencies")

rules_cc_dependencies()

LLVM_COMMIT = "5e5a22caf88ac1ccfa8dc5720295fdeba0ad9372"
LLVM_SHA256 = ""
LLVM_COMMIT = "668865789620f390fbad4d7093ed8ca6eb932c31"
LLVM_SHA256 = "8d7cbbe492a17656c09af1e79b802303f11cb47d64768760b70d52f11ed4d9da"
LLVM_TARGETS = ["X86", "AArch64", "AMDGPU", "NVPTX"]

http_archive(
Expand All @@ -30,8 +30,8 @@ http_archive(
load("@llvm-raw//utils/bazel:configure.bzl", "llvm_configure")
llvm_configure(name = "llvm-project", targets = LLVM_TARGETS)

XLA_COMMIT = "fecd5e7e9f00f4a197ad54206f2bc0ca1058c858"
XLA_SHA256 = ""
XLA_COMMIT = "a6e6c1f6a53d4a23451c649110519c7ba8581bf9"
XLA_SHA256 = "5fe6dfa30621bd50b022a6cab026d6f4cde9883a3e150ce1b6fd52822a57c59a"

http_archive(
name = "xla",
Expand Down Expand Up @@ -60,8 +60,8 @@ load("@rules_python//python/pip_install:repositories.bzl", "pip_install_dependen

pip_install_dependencies()

ENZYME_COMMIT = "77b4fff47701a240b537a93a2e722626f7421342"
ENZYME_SHA256 = ""
ENZYME_COMMIT = "cbb970161fd41ce55da028f0960a441382b07112"
ENZYME_SHA256 = "ec0450fdbc7f18cab46492acd3288b8347fa222317f9ff475768f5f10c45478c"

http_archive(
name = "enzyme",
Expand All @@ -70,8 +70,8 @@ http_archive(
urls = ["https://github.com/EnzymeAD/Enzyme/archive/{commit}.tar.gz".format(commit = ENZYME_COMMIT)],
)

JAX_COMMIT = "32a317f7a43440800e1e39e00ed5f2980e088ab1"
JAX_SHA256 = "6e2147be7360a5c0672b6ba0d654cdb2ac96113b63ef457dfdc76cd50fe69ff1"
JAX_COMMIT = "f691fe468a8e1f8545f7d624055d58b823ee3201"
JAX_SHA256 = ""

http_archive(
name = "jax",
Expand Down
2 changes: 2 additions & 0 deletions enzyme_jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ py_library(
pybind_library(
name = "compile_with_xla",
srcs = ["compile_with_xla.cc"],
hdrs = ["compile_with_xla.h"],
deps = [
# This is similar to xla_binary rule and is needed to make XLA client compile.
"@tsl//tsl/framework:allocator",
Expand Down Expand Up @@ -70,6 +71,7 @@ pybind_library(
"@xla//xla/service:buffer_assignment_proto_cc",
"@xla//xla/service:buffer_assignment_proto_cc_impl",
"@xla//xla/service/cpu:cpu_executable",
"@xla//xla/service/cpu:backend_config_proto_cc",
"@xla//xla/service/gpu:backend_configs_cc",
"@xla//xla/service/gpu:backend_configs_cc_impl",
"@xla//xla/service:hlo_proto_cc",
Expand Down
Loading

0 comments on commit c94ed35

Please sign in to comment.