From c94ed3561fbbbfa291ec6d0cdbfb10a06edb622b Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 10 Dec 2023 15:14:39 -0600 Subject: [PATCH] XLA calling convention fixes (#17) * Minor int and multifunction fixes * continue * [broken] wip * cleaning up * Restored ad functionality * format * Format checker * Fix macos fname * fix llvm namespace * Handle cast return * Fix single return * Consider JaX arg elimination [primal] * tmp * continuing * memstore fix * continuing fixups * handle memset_pattern16 --- .github/workflows/format.yml | 19 + WORKSPACE | 16 +- enzyme_jax/BUILD | 2 + enzyme_jax/clang_compile.cc | 230 ++++++--- enzyme_jax/clang_compile.h | 10 +- enzyme_jax/compile_with_xla.cc | 96 +++- enzyme_jax/compile_with_xla.h | 7 + enzyme_jax/enzyme_call.cc | 915 +++++++++++++++++++++++++-------- enzyme_jax/primitives.py | 135 +++-- test/bench_vs_xla.py | 154 +++++- test/llama.py | 294 +++++++++++ 11 files changed, 1520 insertions(+), 358 deletions(-) create mode 100644 .github/workflows/format.yml create mode 100644 enzyme_jax/compile_with_xla.h create mode 100644 test/llama.py diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml new file mode 100644 index 00000000..c47a301b --- /dev/null +++ b/.github/workflows/format.yml @@ -0,0 +1,19 @@ +name: Clang-Format + +on: + push: + pull_request: + merge_group: + +jobs: + build: + name: Format + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + - uses: DoozyX/clang-format-lint-action@v0.16.2 + with: + source: 'enzyme_jax' + style: 'llvm' + clangFormatVersion: 16 diff --git a/WORKSPACE b/WORKSPACE index 92cb3e42..a08ba1f0 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -13,8 +13,8 @@ load("@rules_cc//cc:repositories.bzl", "rules_cc_dependencies") rules_cc_dependencies() -LLVM_COMMIT = "5e5a22caf88ac1ccfa8dc5720295fdeba0ad9372" -LLVM_SHA256 = "" +LLVM_COMMIT = "668865789620f390fbad4d7093ed8ca6eb932c31" +LLVM_SHA256 = "8d7cbbe492a17656c09af1e79b802303f11cb47d64768760b70d52f11ed4d9da" LLVM_TARGETS = ["X86", "AArch64", "AMDGPU", "NVPTX"] http_archive( @@ -30,8 +30,8 @@ http_archive( load("@llvm-raw//utils/bazel:configure.bzl", "llvm_configure") llvm_configure(name = "llvm-project", targets = LLVM_TARGETS) -XLA_COMMIT = "fecd5e7e9f00f4a197ad54206f2bc0ca1058c858" -XLA_SHA256 = "" +XLA_COMMIT = "a6e6c1f6a53d4a23451c649110519c7ba8581bf9" +XLA_SHA256 = "5fe6dfa30621bd50b022a6cab026d6f4cde9883a3e150ce1b6fd52822a57c59a" http_archive( name = "xla", @@ -60,8 +60,8 @@ load("@rules_python//python/pip_install:repositories.bzl", "pip_install_dependen pip_install_dependencies() -ENZYME_COMMIT = "77b4fff47701a240b537a93a2e722626f7421342" -ENZYME_SHA256 = "" +ENZYME_COMMIT = "cbb970161fd41ce55da028f0960a441382b07112" +ENZYME_SHA256 = "ec0450fdbc7f18cab46492acd3288b8347fa222317f9ff475768f5f10c45478c" http_archive( name = "enzyme", @@ -70,8 +70,8 @@ http_archive( urls = ["https://github.com/EnzymeAD/Enzyme/archive/{commit}.tar.gz".format(commit = ENZYME_COMMIT)], ) -JAX_COMMIT = "32a317f7a43440800e1e39e00ed5f2980e088ab1" -JAX_SHA256 = "6e2147be7360a5c0672b6ba0d654cdb2ac96113b63ef457dfdc76cd50fe69ff1" +JAX_COMMIT = "f691fe468a8e1f8545f7d624055d58b823ee3201" +JAX_SHA256 = "" http_archive( name = "jax", diff --git a/enzyme_jax/BUILD b/enzyme_jax/BUILD index bb111b0c..3ee8465c 100644 --- a/enzyme_jax/BUILD +++ b/enzyme_jax/BUILD @@ -42,6 +42,7 @@ py_library( pybind_library( name = "compile_with_xla", srcs = ["compile_with_xla.cc"], + hdrs = ["compile_with_xla.h"], deps = [ # This is similar to xla_binary rule and is needed to make XLA 
client compile. "@tsl//tsl/framework:allocator", @@ -70,6 +71,7 @@ pybind_library( "@xla//xla/service:buffer_assignment_proto_cc", "@xla//xla/service:buffer_assignment_proto_cc_impl", "@xla//xla/service/cpu:cpu_executable", + "@xla//xla/service/cpu:backend_config_proto_cc", "@xla//xla/service/gpu:backend_configs_cc", "@xla//xla/service/gpu:backend_configs_cc_impl", "@xla//xla/service:hlo_proto_cc", diff --git a/enzyme_jax/clang_compile.cc b/enzyme_jax/clang_compile.cc index fce4dfcc..c84e057d 100644 --- a/enzyme_jax/clang_compile.cc +++ b/enzyme_jax/clang_compile.cc @@ -9,16 +9,18 @@ #include "clang_compile.h" #include "llvm/IRReader/IRReader.h" +#include #include #include -#include -#include -#include -#include -#include #include +#include +#include +#include #include +#include +#include "clang/CodeGen/CodeGenAction.h" +#include "llvm-c/Core.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallString.h" #include "llvm/ADT/StringRef.h" @@ -27,15 +29,13 @@ #include "llvm/AsmParser/LLToken.h" #include "llvm/AsmParser/Parser.h" #include "llvm/AsmParser/SlotMapping.h" -#include "llvm-c/Core.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/IR/Type.h" -#include "llvm/Support/raw_ostream.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/TargetSelect.h" -#include "clang/CodeGen/CodeGenAction.h" +#include "llvm/Support/raw_ostream.h" #include "clang/AST/Decl.h" #include "clang/Basic/DiagnosticOptions.h" @@ -52,21 +52,21 @@ #include "clang/Frontend/CompilerInstance.h" #include "clang/Frontend/CompilerInvocation.h" #include "clang/Frontend/FrontendOptions.h" +#include "clang/Frontend/TextDiagnosticBuffer.h" #include "clang/Frontend/TextDiagnosticPrinter.h" #include "clang/Frontend/Utils.h" +#include "clang/FrontendTool/Utils.h" #include "clang/Parse/ParseAST.h" #include "clang/Parse/Parser.h" #include "clang/Sema/Sema.h" #include "clang/Sema/SemaDiagnostic.h" +#include "llvm/CodeGen/CommandFlags.h" +#include "llvm/Linker/Linker.h" +#include "llvm/MC/TargetRegistry.h" #include "llvm/Support/Debug.h" -#include "llvm/Support/raw_ostream.h" -#include "clang/Frontend/TextDiagnosticBuffer.h" #include "llvm/Support/Host.h" -#include "clang/FrontendTool/Utils.h" -#include "llvm/MC/TargetRegistry.h" -#include "llvm/CodeGen/CommandFlags.h" #include "llvm/Support/MemoryBufferRef.h" -#include "llvm/Linker/Linker.h" +#include "llvm/Support/raw_ostream.h" #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" @@ -74,6 +74,7 @@ #include #include "Enzyme/Enzyme.h" +#include "Enzyme/Utils.h" namespace clang { namespace driver { @@ -86,9 +87,9 @@ namespace tools { void addDirectoryList(const llvm::opt::ArgList &Args, llvm::opt::ArgStringList &CmdArgs, const char *ArgName, const char *EnvVar); -} -} -} +} // namespace tools +} // namespace driver +} // namespace clang using namespace clang; using namespace llvm; @@ -127,7 +128,7 @@ class ArgumentList { /// /// The return value of this operation could be invalidated by subsequent /// calls to push_back() or emplace_back(). - llvm::opt::ArgStringList& getArguments() { return Args; } + llvm::opt::ArgStringList &getArguments() { return Args; } }; /* @@ -148,9 +149,10 @@ PYBIND11_DECLARE_HOLDER_TYPE(T, ptr_wrapper, true); */ // Returns the TargetMachine instance or zero if no triple is provided. 
-static TargetMachine* GetTargetMachine(llvm::Triple TheTriple, StringRef CPUStr, +static TargetMachine *GetTargetMachine(llvm::Triple TheTriple, StringRef CPUStr, StringRef FeaturesStr, - const llvm::TargetOptions &Options, CodeGenOptLevel level) { + const llvm::TargetOptions &Options, + CodeGenOptLevel level) { std::string Error; const Target *TheTarget = TargetRegistry::lookupTarget(codegen::getMArch(), TheTriple, Error); @@ -165,9 +167,12 @@ static TargetMachine* GetTargetMachine(llvm::Triple TheTriple, StringRef CPUStr, codegen::getExplicitCodeModel(), level); } -std::unique_ptr GetLLVMFromJob(std::string filename, std::string filecontents, bool cpp, ArrayRef pyargv, LLVMContext* Context, std::unique_ptr linkMod) { - const llvm::opt::InputArgList Args; - const char *binary = cpp ? "clang++" : "clang"; +std::unique_ptr +GetLLVMFromJob(std::string filename, std::string filecontents, bool cpp, + ArrayRef pyargv, LLVMContext *Context, + std::unique_ptr linkMod) { + const llvm::opt::InputArgList Args; + const char *binary = cpp ? "clang++" : "clang"; // Buffer diagnostics from argument parsing so that we can output them using a // well formed diagnostic object. IntrusiveRefCntPtr DiagOpts = new DiagnosticOptions(); @@ -179,15 +184,15 @@ std::unique_ptr GetLLVMFromJob(std::string filename, std::string f IntrusiveRefCntPtr DiagOpts0 = new DiagnosticOptions(); IntrusiveRefCntPtr DiagID0(new DiagnosticIDs()); DiagnosticsEngine Diags0(DiagID0, &*DiagOpts0, DiagsBuffer0); - const std::unique_ptr driver( - new clang::driver::Driver(binary, llvm::sys::getDefaultTargetTriple(), Diags0)); + const std::unique_ptr driver(new clang::driver::Driver( + binary, llvm::sys::getDefaultTargetTriple(), Diags0)); ArgumentList Argv; - + Argv.emplace_back(StringRef(filename)); for (auto v : pyargv) Argv.emplace_back(v); - SmallVector PreArgs; + SmallVector PreArgs; PreArgs.push_back(binary); PreArgs.append(Argv.getArguments()); PreArgs[1] = "-"; @@ -204,27 +209,34 @@ std::unique_ptr GetLLVMFromJob(std::string filename, std::string f // frontend into the driver. It will allow deleting 4 otherwise unused flags. // CPATH - included following the user specified includes (but prior to // builtin and standard includes). - clang::driver::tools::addDirectoryList(Args, Argv.getArguments(), "-I", "CPATH"); + clang::driver::tools::addDirectoryList(Args, Argv.getArguments(), "-I", + "CPATH"); // C_INCLUDE_PATH - system includes enabled when compiling C. - clang::driver::tools::addDirectoryList(Args, Argv.getArguments(), "-c-isystem", "C_INCLUDE_PATH"); + clang::driver::tools::addDirectoryList(Args, Argv.getArguments(), + "-c-isystem", "C_INCLUDE_PATH"); // CPLUS_INCLUDE_PATH - system includes enabled when compiling C++. - clang::driver::tools::addDirectoryList(Args, Argv.getArguments(), "-cxx-isystem", "CPLUS_INCLUDE_PATH"); + clang::driver::tools::addDirectoryList(Args, Argv.getArguments(), + "-cxx-isystem", "CPLUS_INCLUDE_PATH"); // OBJC_INCLUDE_PATH - system includes enabled when compiling ObjC. - clang::driver::tools::addDirectoryList(Args, Argv.getArguments(), "-objc-isystem", "OBJC_INCLUDE_PATH"); + clang::driver::tools::addDirectoryList(Args, Argv.getArguments(), + "-objc-isystem", "OBJC_INCLUDE_PATH"); // OBJCPLUS_INCLUDE_PATH - system includes enabled when compiling ObjC++. 
- clang::driver::tools::addDirectoryList(Args, Argv.getArguments(), "-objcxx-isystem", "OBJCPLUS_INCLUDE_PATH"); + clang::driver::tools::addDirectoryList( + Args, Argv.getArguments(), "-objcxx-isystem", "OBJCPLUS_INCLUDE_PATH"); auto &TC = compilation->getDefaultToolChain(); if (cpp) { - bool HasStdlibxxIsystem = false; // Args.hasArg(options::OPT_stdlibxx_isystem); - HasStdlibxxIsystem ? TC.AddClangCXXStdlibIsystemArgs(Args, Argv.getArguments()) - : TC.AddClangCXXStdlibIncludeArgs(Args, Argv.getArguments()); + bool HasStdlibxxIsystem = + false; // Args.hasArg(options::OPT_stdlibxx_isystem); + HasStdlibxxIsystem + ? TC.AddClangCXXStdlibIsystemArgs(Args, Argv.getArguments()) + : TC.AddClangCXXStdlibIncludeArgs(Args, Argv.getArguments()); } - TC.AddClangSystemIncludeArgs(Args, Argv.getArguments()); - + TC.AddClangSystemIncludeArgs(Args, Argv.getArguments()); + SmallVector outputvec; - + std::unique_ptr Clang(new CompilerInstance()); // Register the support for object-file-wrapped Clang modules. @@ -232,20 +244,27 @@ std::unique_ptr GetLLVMFromJob(std::string filename, std::string f // PCHOps->registerWriter(std::make_unique()); // PCHOps->registerReader(std::make_unique()); + auto baseFS = createVFSFromCompilerInvocation(Clang->getInvocation(), Diags); - auto baseFS = createVFSFromCompilerInvocation(Clang->getInvocation(), - Diags); - - IntrusiveRefCntPtr fs(new llvm::vfs::InMemoryFileSystem()); + IntrusiveRefCntPtr fs( + new llvm::vfs::InMemoryFileSystem()); struct tm y2k = {}; - y2k.tm_hour = 0; y2k.tm_min = 0; y2k.tm_sec = 0; - y2k.tm_year = 100; y2k.tm_mon = 0; y2k.tm_mday = 1; + y2k.tm_hour = 0; + y2k.tm_min = 0; + y2k.tm_sec = 0; + y2k.tm_year = 100; + y2k.tm_mon = 0; + y2k.tm_mday = 1; time_t timer = mktime(&y2k); - fs->addFile(filename, timer, llvm::MemoryBuffer::getMemBuffer(filecontents, filename, /*RequiresNullTerminator*/false)); - fs->addFile("/enzyme/enzyme/utils", timer, llvm::MemoryBuffer::getMemBuffer(R"( + fs->addFile(filename, timer, + llvm::MemoryBuffer::getMemBuffer( + filecontents, filename, /*RequiresNullTerminator*/ false)); + fs->addFile("/enzyme/enzyme/utils", timer, + llvm::MemoryBuffer::getMemBuffer( + R"( namespace enzyme { template RT __enzyme_fwddiff(Args...); @@ -258,14 +277,18 @@ namespace enzyme { template std::size_t __enzyme_augmentsize(Args...); } +extern "C" void prevent_stores(void*, ...); extern "C" int enzyme_dup; extern "C" int enzyme_const; extern "C" int enzyme_dupnoneed; extern "C" int enzyme_nooverwrite; extern "C" int enzyme_tape; extern "C" int enzyme_allocated; - )", "/enzyme/enzyme/utils", /*RequiresNullTerminator*/false)); - fs->addFile("/enzyme/enzyme/tensor", timer, llvm::MemoryBuffer::getMemBuffer(R"( + )", + "/enzyme/enzyme/utils", /*RequiresNullTerminator*/ false)); + fs->addFile("/enzyme/enzyme/tensor", timer, + llvm::MemoryBuffer::getMemBuffer( + R"( #include #include namespace enzyme { @@ -445,26 +468,28 @@ struct tensor }; } - )", "/enzyme/enzyme/tensor", /*RequiresNullTerminator*/false)); + )", + "/enzyme/enzyme/tensor", /*RequiresNullTerminator*/ false)); - std::unique_ptr outputStream(new llvm::raw_svector_ostream(outputvec)); + std::unique_ptr outputStream( + new llvm::raw_svector_ostream(outputvec)); Clang->setOutputStream(std::move(outputStream)); - IntrusiveRefCntPtr fuseFS(new llvm::vfs::OverlayFileSystem(baseFS)); + IntrusiveRefCntPtr fuseFS( + new llvm::vfs::OverlayFileSystem(baseFS)); fuseFS->pushOverlay(fs); fuseFS->pushOverlay(baseFS); Clang->createFileManager(fuseFS); - - bool Success = 
CompilerInvocation::CreateFromArgs(Clang->getInvocation(), - Argv.getArguments(), Diags, binary); + bool Success = CompilerInvocation::CreateFromArgs( + Clang->getInvocation(), Argv.getArguments(), Diags, binary); // Infer the builtin include path if unspecified. if (Clang->getHeaderSearchOpts().UseBuiltinIncludes && Clang->getHeaderSearchOpts().ResourceDir.empty()) Clang->getHeaderSearchOpts().ResourceDir = - CompilerInvocation::GetResourcesPath(binary, /*MainAddr*/0x0); + CompilerInvocation::GetResourcesPath(binary, /*MainAddr*/ 0x0); // Create the actual diagnostics engine. Clang->createDiagnostics(); @@ -500,11 +525,13 @@ struct tensor } for (auto &f : *mod) { - if (f.empty()) continue; - if (f.getName() == "entry") continue; + if (f.empty()) + continue; + if (f.getName() == "entry") + continue; f.setLinkage(Function::LinkageTypes::InternalLinkage); } - + PipelineTuningOptions PTO; LoopAnalysisManager LAM; FunctionAnalysisManager FAM; @@ -518,13 +545,14 @@ struct tensor llvm::driver::createTLII(triple, Clang->getCodeGenOpts().getVecLib())); FAM.registerPass([&] { return TargetLibraryAnalysis(*TLII); }); - - auto level = CodeGenOptLevel::Aggressive; //OptimizationLevel::O3; + auto level = CodeGenOptLevel::Aggressive; // OptimizationLevel::O3; Triple ModuleTriple(mod->getTargetTriple()); std::string CPUStr, FeaturesStr; - auto ETM = llvm::orc::JITTargetMachineBuilder(llvm::Triple(mod->getTargetTriple())).createTargetMachine (); + auto ETM = + llvm::orc::JITTargetMachineBuilder(llvm::Triple(mod->getTargetTriple())) + .createTargetMachine(); if (!ETM) { throw pybind11::value_error("failed to create targetmachine"); } @@ -546,6 +574,84 @@ struct tensor ModulePassManager MPM; PB.parsePassPipeline(MPM, "default"); MPM.run(*mod, MAM); + + auto F = mod->getFunction("prevent_stores"); + if (F) { + for (const auto user : llvm::make_early_inc_range(F->users())) { + auto CI = dyn_cast(user); + if (!CI) + continue; + std::deque> todo; + SmallVector cargs; + for (auto &arg : CI->args()) + cargs.push_back(arg); + CI->eraseFromParent(); + for (auto &arg : cargs) { + Value *cur = getBaseObject(arg); + assert(isa(cur)); + for (auto U : cur->users()) + todo.emplace_back(U, cur); + } + std::set> seen; + SmallPtrSet toErase; + while (todo.size()) { + auto pair = todo.back(); + todo.pop_back(); + auto [cur, prev] = pair; + if (seen.count(pair)) + continue; + seen.insert(pair); + if (isPointerArithmeticInst(cur)) { + for (auto u : cur->users()) + todo.emplace_back(u, cur); + continue; + } + if (isa(cur)) + continue; + if (auto MTI = dyn_cast(cur)) { + if (MTI->getSource() == prev) + continue; + } + if (auto CI = dyn_cast(cur)) + if (auto F = CI->getCalledFunction()) + if (F->getName() == "memset_pattern16") + continue; + if (auto MS = dyn_cast(cur)) { + toErase.insert(MS); + continue; + } + if (auto II = dyn_cast(cur)) { + if (II->getIntrinsicID() == llvm::Intrinsic::dbg_value) + continue; + } + if (isa(cur)) + continue; + if (auto SI = dyn_cast(cur)) { + assert(SI->getPointerOperand() == prev); + auto C = dyn_cast(SI->getValueOperand()); + if (C && C->isNullValue()) { + } else if (auto CF = dyn_cast_or_null(C)) { + assert(CF->isZero()); + } else { + llvm::errs() << "SI: " << *SI << " C: " << *SI->getValueOperand() + << "\n"; + assert(0); + } + toErase.insert(SI); + continue; + } + std::string err_str; + llvm::raw_string_ostream ss(err_str); + ss << *mod << "\n"; + ss << " unsupported value to erase:\n"; + ss << " cur: " << *cur << " prev: " << *prev << "\n"; + throw pybind11::value_error(ss.str()); + } + for 
(auto I : toErase) { + I->eraseFromParent(); + } + } + } + return mod; } - diff --git a/enzyme_jax/clang_compile.h b/enzyme_jax/clang_compile.h index ff9a32c1..7dad0fcd 100644 --- a/enzyme_jax/clang_compile.h +++ b/enzyme_jax/clang_compile.h @@ -9,10 +9,14 @@ #ifndef ENZYME_JAX_CLANG_COMPILE_H #define ENZYME_JAX_CLANG_COMPILE_H +#include "llvm/IR/Module.h" #include #include -#include "llvm/IR/Module.h" -std::unique_ptr GetLLVMFromJob(std::string filename, std::string filecontents, bool cpp, llvm::ArrayRef pyargv, llvm::LLVMContext*ctx=nullptr, std::unique_ptr linkMod=nullptr); +std::unique_ptr +GetLLVMFromJob(std::string filename, std::string filecontents, bool cpp, + llvm::ArrayRef pyargv, + llvm::LLVMContext *ctx = nullptr, + std::unique_ptr linkMod = nullptr); -#endif // ENZYME_JAX_CLANG_COMPILE_H +#endif // ENZYME_JAX_CLANG_COMPILE_H diff --git a/enzyme_jax/compile_with_xla.cc b/enzyme_jax/compile_with_xla.cc index 2b8ad74e..c6e24b3b 100644 --- a/enzyme_jax/compile_with_xla.cc +++ b/enzyme_jax/compile_with_xla.cc @@ -1,13 +1,12 @@ -//===----------------------------------------------------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// +#define protected public +#include "xla/service/service.h" +#undef protected -#include -#include +#include "xla/service/cpu/cpu_executable.h" +#include "xla/service/local_service_utils.h" + +#include "absl/status/statusor.h" +#include "llvm/ADT/StringRef.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -15,14 +14,22 @@ #include "xla/client/client_library.h" #include "xla/client/executable_build_options.h" #include "xla/client/xla_computation.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_instructions.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/printer.h" +#include "xla/service/cpu/backend_config.pb.h" #include "xla/service/cpu/cpu_executable.h" #include "xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" #include "xla/translate/mhlo_to_hlo/type_to_shape.h" +#include "pybind11/pybind11.h" + +#include "compile_with_xla.h" + // Compile an MHLO module given as a string to LLVM IR using XLA. -absl::StatusOr compile_mhlo_to_llvm_with_xla( - llvm::StringRef mhlo_text) { +std::unique_ptr +compile_mhlo_to_llvm_with_xla(llvm::StringRef mhlo_text, std::string &output) { // Parse MLIR. mlir::MLIRContext context; context.loadDialect(); @@ -36,6 +43,21 @@ absl::StatusOr compile_mhlo_to_llvm_with_xla( xla::HloProto hlo_proto; mlir::ConvertMlirHloToHlo(*parsed_module, &hlo_proto, /*use_tuple_args=*/false, /*return_tuple=*/false); + + for (auto &computation : + *hlo_proto.mutable_hlo_module()->mutable_computations()) { + if (computation.id() != hlo_proto.hlo_module().entry_computation_id()) + continue; + // Assume root is the last instruction. + xla::HloInstructionProto &instruction = + *computation.mutable_instructions()->rbegin(); + xla::cpu::BackendConfig backend_config; + backend_config.ParseFromString(instruction.backend_config()); + backend_config.Clear(); + instruction.set_backend_config(backend_config.SerializeAsString()); + break; + } + xla::XlaComputation xla_computation(hlo_proto.hlo_module()); // Extract and convert the shapes fro MHLO. 
@@ -60,19 +82,51 @@ absl::StatusOr compile_mhlo_to_llvm_with_xla( // generation only to throw away the binary. absl::StatusOr local_client_or_error = xla::ClientLibrary::GetOrCreateLocalClient(); - if (!local_client_or_error.ok()) return local_client_or_error.status(); + if (!local_client_or_error.ok()) { + throw pybind11::value_error(local_client_or_error.status().ToString()); + } xla::LocalClient *local_client = local_client_or_error.value(); + xla::ExecutableBuildOptions build_options; build_options.mutable_debug_options()->set_xla_embed_ir_in_executable(true); - absl::StatusOr>> - local_executables = - local_client->Compile(xla_computation, shape_pointers, build_options); - if (!local_executables.ok()) return local_executables.status(); - // Extract the LLVM IR stored in the executable. - xla::LocalExecutable &local_executable = *local_executables.value()[0]; + if (build_options.device_ordinal() == -1) { + build_options.set_device_ordinal(local_client->default_device_ordinal()); + } + + absl::StatusOr> module_config_or_error = + xla::GetHloModuleConfig( + xla_computation, shape_pointers, build_options, + /*(serice) options=*/&local_client->local_service()->options_, + local_client->mutable_backend()); + if (!module_config_or_error.ok()) { + throw pybind11::value_error(module_config_or_error.status().ToString()); + } + module_config_or_error.value()->set_intra_op_parallelism_threads(1); + + auto executor = local_client->mutable_backend()->stream_executor( + build_options.device_ordinal()); + if (!executor.ok()) { + throw pybind11::value_error(executor.status().ToString()); + } + + auto executable = local_client->local_service()->BuildExecutable( + xla_computation.proto(), std::move(module_config_or_error.value()), + local_client->mutable_backend(), executor.value(), + {build_options.device_allocator(), build_options.compile_thread_pool(), + build_options.layout_canonicalization_callback()}, + build_options.run_backend_only()); + if (!executable.ok()) { + throw pybind11::value_error(executable.status().ToString()); + } + + auto local_executable = std::make_unique( + std::move(executable.value()), + local_client->local_service()->mutable_backend(), build_options); + auto *cpu_executable = - static_cast(local_executable.executable()); - const std::string &llvm_ir = cpu_executable->ir_module_string(); - return llvm_ir; + static_cast(local_executable->executable()); + + output = cpu_executable->ir_module_string(); + return std::move(local_executable); } diff --git a/enzyme_jax/compile_with_xla.h b/enzyme_jax/compile_with_xla.h new file mode 100644 index 00000000..841b3d1c --- /dev/null +++ b/enzyme_jax/compile_with_xla.h @@ -0,0 +1,7 @@ +#pragma once +#include "xla/client/local_client.h" +#include "llvm/ADT/StringRef.h" +#include +// Compile an MHLO module given as a string to LLVM IR using XLA. 
+std::unique_ptr +compile_mhlo_to_llvm_with_xla(llvm::StringRef mhlo_text, std::string &output); diff --git a/enzyme_jax/enzyme_call.cc b/enzyme_jax/enzyme_call.cc index ae13383b..1bfaaecd 100644 --- a/enzyme_jax/enzyme_call.cc +++ b/enzyme_jax/enzyme_call.cc @@ -10,10 +10,12 @@ #include #include #include +#include #include #include "absl/status/statusor.h" #include "clang_compile.h" +#include "pybind11/pybind11.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/StringRef.h" #include "llvm/ExecutionEngine/Orc/CompileUtils.h" @@ -29,40 +31,43 @@ #include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h" #include "llvm/ExecutionEngine/SectionMemoryManager.h" #include "llvm/IR/Instructions.h" +#include "llvm/IRReader/IRReader.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/RWMutex.h" +#include "llvm/Support/SourceMgr.h" #include "llvm/Support/TargetSelect.h" #include "llvm/Support/raw_ostream.h" -#include "pybind11/pybind11.h" -#include "llvm/IRReader/IRReader.h" -#include "llvm/Support/SourceMgr.h" -absl::StatusOr compile_mhlo_to_llvm_with_xla( - llvm::StringRef mhlo_text); +#include "compile_with_xla.h" +#include "xla/hlo/ir/hlo_casting_utils.h" +#include "xla/hlo/ir/hlo_instructions.h" +#include "xla/service/cpu/cpu_executable.h" -enum class Language : int { - CPP = 0, - LLVM = 1, - MHLO = 2 -}; +#include "Enzyme/FunctionUtils.h" + +enum class ABI { Primal, Forward, Augmented, Reverse, Tape }; + +enum class Language : int { CPP = 0, LLVM = 1, MHLO = 2 }; namespace { class CpuKernel { // static llvm::orc::ExecutionSession ES; static std::unique_ptr DL; - static std::unique_ptr JIT; + static std::unique_ptr JIT; int64_t identifier; size_t num_out; uint64_t addr; - public: - CpuKernel(int64_t identifier, - size_t num_out, uint64_t addr) - : identifier(identifier), num_out(num_out), addr(addr) { - } - static std::string make_type(std::string typenam, llvm::ArrayRef shape, bool constv, Language lang) { - std::string s = std::string(constv ? "const " : "") + "enzyme::tensor<" + typenam; +public: + CpuKernel(int64_t identifier, size_t num_out, uint64_t addr) + : identifier(identifier), num_out(num_out), addr(addr) {} + + static std::string make_type(std::string typenam, + llvm::ArrayRef shape, bool constv, + Language lang) { + std::string s = + std::string(constv ? 
"const " : "") + "enzyme::tensor<" + typenam; for (auto v : shape) { s += ", " + std::to_string(v); } @@ -70,45 +75,78 @@ class CpuKernel { } static std::tuple, - std::unique_ptr, size_t> - createLLVMMod(llvm::StringRef fn, llvm::StringRef source, + std::unique_ptr, size_t, size_t> + createLLVMMod(std::string fn, llvm::StringRef source, llvm::ArrayRef> out_shapes, llvm::ArrayRef out_names, llvm::ArrayRef> in_shapes, llvm::ArrayRef in_names, PyObject *pyargv, - int mode, Language lang) { + ABI mode, Language lang) { auto llvm_ctx = std::make_unique(); std::string input; llvm::raw_string_ostream ss(input); - size_t num_out; ss << "#include \n"; ss << "#include \n"; ss << "#include \n"; std::unique_ptr linkMod; + std::unique_ptr local_executable; std::string stringbuf; + size_t tmpBuf = 0; + llvm::StringRef origSource = source; switch (lang) { case Language::CPP: ss << source << "\n"; break; - - case Language::MHLO:{ - absl::StatusOr llvm_ir = - compile_mhlo_to_llvm_with_xla(source); - if (!llvm_ir.ok()) { - throw std::runtime_error("failed to compile to LLVM IR with XLA:" + - llvm_ir.status().ToString()); + case Language::MHLO: { + local_executable = compile_mhlo_to_llvm_with_xla(source, stringbuf); + auto *cpu_executable = static_cast( + local_executable->executable()); + auto &assignment = cpu_executable->buffer_assignment(); + size_t num_in = 0; + for (auto &buf2 : assignment.Allocations()) { + if (buf2.is_entry_computation_parameter()) { + num_in++; + } + } + if (num_in != in_shapes.size()) { + std::string err_str; + llvm::raw_string_ostream ss(err_str); + ss << " Number of mhlo inputs (" << num_in + << ") != number of jax inputs (" << in_shapes.size() << "):\n"; + ss << source << "\n"; + throw pybind11::value_error(ss.str()); + } + for (size_t i = 0; i < in_shapes.size(); i++) { + ssize_t idx = -1; + for (auto &buf2 : assignment.Allocations()) { + if (!buf2.is_entry_computation_parameter()) + continue; + if (buf2.parameter_number() != i) + continue; + assert(idx == -1); + idx = buf2.index(); + } + if (idx == -1) { + std::string err_str; + llvm::raw_string_ostream ss(err_str); + ss << " Could not find input parameter (" << i + << ") as hlo parameter:\n"; + ss << source << "\n"; + throw pybind11::value_error(ss.str()); + } } - stringbuf = *llvm_ir; source = stringbuf; + tmpBuf = assignment.temp_allocation_total_size(); // explicitly fall through } case Language::LLVM: llvm::SMDiagnostic Err; - linkMod = llvm::parseIR(llvm::MemoryBufferRef(source, ""), Err, *llvm_ctx); + linkMod = llvm::parseIR(llvm::MemoryBufferRef(source, ""), Err, + *llvm_ctx); if (!linkMod) { std::string err_str; llvm::raw_string_ostream ss(err_str); @@ -117,171 +155,532 @@ class CpuKernel { } assert(linkMod); if (lang == Language::MHLO) { - for (auto &lfn : linkMod->functions()) { - if (lfn.empty()) continue; - assert(fn != "mhlo_main"); - fn = "mhlo_main"; - lfn.setName(fn); - lfn.addFnAttr(llvm::Attribute::AlwaysInline); + auto *cpu_executable = static_cast( + local_executable->executable()); + llvm::StringRef fname = cpu_executable->module_name(); + if (fname.size() && fname[0] == '_') + fname = fname.substr(1); + auto F = linkMod->getFunction(fname); + if (!F) { + llvm::errs() << *linkMod << "\n"; + llvm::errs() << "fname: " << fname << "\n"; } + assert(F); + fn = "mhlo_main"; + F->setName(fn); + assert(!F->empty()); + for (auto &F2 : *linkMod) + if (!F2.empty()) { + F2.addFnAttr(llvm::Attribute::AlwaysInline); + // Remove invariant_load if we expect enzyme to cache explicitly all + // data. 
Otherwise invariant_load allows Enzyme to assume it need + // not cache, and it is illegal for us to pass in nullptr as the + // primal (since it may be needed). + if (mode == ABI::Augmented || mode == ABI::Reverse || + mode == ABI::Tape) { + for (auto &BB : F2) + for (auto &I : BB) + if (auto LI = llvm::dyn_cast(&I)) + if (LI->hasMetadata(llvm::LLVMContext::MD_invariant_load)) + LI->setMetadata(llvm::LLVMContext::MD_invariant_load, + nullptr); + } + } } - ss << " extern \"C\" void " << fn << "(void* retval, void* run_options, void* params, void* buffer_table, void* status, void* prof_counters);\n\n"; + ss << " extern \"C\" void " << fn + << "(void* retval, void* run_options, void* params, void* " + "buffer_table, void* status, void* prof_counters);\n\n"; - ss << " __attribute__((always_inline)) static inline void abi_wrap("; - bool comma = false; - for (size_t i=0, off=0; i( + local_executable->executable()); + auto &assignment = cpu_executable->buffer_assignment(); + for (auto &buf : assignment.Allocations()) { + if (!buf.is_constant()) + continue; + assert(buf.assigned_buffers().size() == 1); + auto hlo = buf.assigned_buffers().begin()->first; + auto tyenum = hlo->shape().element_type(); + std::string ty; + switch (tyenum) { + case xla::PrimitiveType::S8: + ty = "int8_t"; + break; + case xla::PrimitiveType::S16: + ty = "int16_t"; + break; + case xla::PrimitiveType::S32: + ty = "int32_t"; + break; + case xla::PrimitiveType::S64: + ty = "int64_t"; + break; + case xla::PrimitiveType::U8: + ty = "uint8_t"; + break; + case xla::PrimitiveType::U16: + ty = "uint16_t"; + break; + case xla::PrimitiveType::U32: + ty = "uint32_t"; + break; + case xla::PrimitiveType::U64: + ty = "uint64_t"; + break; + case xla::PrimitiveType::F16: + ty = "half"; + break; + case xla::PrimitiveType::F32: + ty = "float"; + break; + case xla::PrimitiveType::F64: + ty = "double"; + break; + default: { + std::string err; + llvm::raw_string_ostream ess(err); + ess << " Failed to compile mhlo, unknown constant element type: " + << hlo->shape().ToString() << "\n"; + throw std::runtime_error(ess.str()); + } + } + auto val = xla::Cast(hlo->instruction()); + + llvm::ArrayRef shape(hlo->shape().dimensions().begin(), + hlo->shape().dimensions().end()); + ss << " static constexpr " + << make_type(ty, shape, /*const*/ false, lang) << " const_" + << buf.index() << " = "; + + xla::StringPrinter printer; + val->literal().PrintWithoutShape(&printer); + auto str = std::move(printer).ToString(); + if (shape.size() == 0) + ss << "{"; + str = std::regex_replace(str, std::regex("\\{"), "{{"); + str = std::regex_replace(str, std::regex("\\}"), "}}"); + ss << str; + if (shape.size() == 0) + ss << "}"; + ss << ";\n"; } - for (size_t i=0, off=0; i & __restrict__ tmpBuf"; + comma = true; + } + for (size_t i = 0; i < in_shapes.size(); i++) { + if (comma) + ss << ", "; + ss << " " << make_type(in_names[i], in_shapes[i], true, lang) << "& in_" + << i; comma = true; } ss << ") {\n"; - ss << " void* buffers[" << (out_shapes.size() + in_shapes.size()) << "] = {"; - comma = false; - for (size_t i=0, off=0; i out_idxs; + if (local_executable) { + auto *cpu_executable = static_cast( + local_executable->executable()); + auto &assignment = cpu_executable->buffer_assignment(); + numBuffers = assignment.Allocations().size(); + if (out_shapes.size() == 1) { + ssize_t idx = -1; + for (auto &buf2 : assignment.Allocations()) { + if (!buf2.maybe_live_out()) + continue; + assert(!buf2.is_tuple()); + assert(idx == -1); + idx = buf2.index(); + } + assert(idx != 
-1); + out_idxs.push_back(idx); + } else { + // If a tuple, find the tuple buf, then use that to index the outputs. + ssize_t tupidx = -1; + for (auto &buf2 : assignment.Allocations()) { + if (!buf2.maybe_live_out()) + continue; + if (!buf2.is_tuple()) + continue; + assert(tupidx == -1); + tupidx = buf2.index(); + } + assert(tupidx != -1); + auto &tup_buf = assignment.Allocations()[tupidx]; + assert(tup_buf.assigned_buffers().size() == 1); + auto hlo = tup_buf.assigned_buffers().begin()->first; + auto val = hlo->instruction(); + assert(val->operand_count() == out_shapes.size()); + for (size_t i = 0; i < out_shapes.size(); i++) { + ssize_t found = -1; + auto operand = val->operand(i); + while (found == -1) { + for (auto &buf : assignment.Allocations()) { + if (!buf.maybe_live_out()) + continue; + if (buf.is_tuple()) + continue; + bool contains_output = false; + for (auto &pair : buf.assigned_buffers()) { + if (pair.first->instruction() != operand) + continue; + assert(!contains_output); + contains_output = true; + assert(pair.second.offset == 0); + } + if (!contains_output) + continue; + assert(found == -1); + found = buf.index(); + } + if (operand->opcode() == xla::HloOpcode::kBitcast) { + operand = operand->operand(0); + continue; + } + break; + } + if (found == -1) { + llvm::errs() << "assignment: " << assignment.ToString() << "\n"; + llvm::errs() << "val: " << val->ToString() << "\n"; + llvm::errs() << "vop: " << val->operand(i)->ToString() << "\n"; + llvm::errs() << "i: " << i << "\n"; + } + assert(found != -1); + out_idxs.push_back((int)found); + } + } + for (auto &buf : assignment.Allocations()) { + if (buf.is_thread_local()) { + ss << " char local_" << buf.index() << "[" << buf.size() << "];\n"; + continue; + } + if (!buf.maybe_live_out()) + continue; + if (!buf.is_tuple()) + continue; + ss << " void* tup_" << buf.index() << "[" << out_idxs.size() + << "] = {"; + + for (size_t i = 0; i < out_idxs.size(); i++) { + if (i != 0) + ss << ", "; + ss << " " + << "(void*)&out_" << i; + } + ss << "};\n"; + } + } + ss << " void* buffers[" << numBuffers << "] = {"; + + if (local_executable) { + auto *cpu_executable = static_cast( + local_executable->executable()); + auto &assignment = cpu_executable->buffer_assignment(); + for (auto &buf : assignment.Allocations()) { + if (buf.index() != 0) + ss << ", "; + if (buf.is_entry_computation_parameter()) { + ss << " " + << "(void*)&in_" << buf.parameter_number(); + } else if (buf.IsPreallocatedTempBuffer()) { + ss << " " + << "(void*)&tmpBuf"; + } else if (buf.maybe_live_out()) { + if (buf.is_tuple()) { + assert(out_shapes.size() != 1); + ss << " " + << "(void*)&tup_" << buf.index(); + continue; + } + auto it = std::find(out_idxs.begin(), out_idxs.end(), buf.index()); + assert(it != out_idxs.end()); + int index = it - out_idxs.begin(); + ss << " " + << "(void*)&out_" << index; + } else if (buf.is_constant()) { + ss << " " + << "(void*)&const_" << buf.index(); + } else if (buf.is_thread_local()) { + ss << " " + << "(void*)&local_" << buf.index(); + } else { + std::string err; + llvm::raw_string_ostream ess(err); + ess << " Failed to compile mhlo, unknown buffer type\n"; + ess << origSource << "\n"; + ess << source << "\n"; + ess << local_executable->executable()->module().ToString() << "\n"; + ess << " unknown buffer type: " << buf.ToString() << "\n"; + throw std::runtime_error(ess.str()); + } + } + } else { + comma = false; + for (size_t i = 0; i < out_shapes.size(); i++) { + if (comma) + ss << ", "; + ss << " " + << "(void*)&out_" << i; + comma = 
true; + } + for (size_t i = 0; i < in_shapes.size(); i++) { + if (comma) + ss << ", "; + ss << " " + << "(void*)&in_" << i; comma = true; } - for (size_t i=0, off=0; i & __restrict__ tmpBuf"; + comma = true; + } + for (size_t i = 0; i < in_shapes.size(); i++) { + if (comma) + ss << ", "; + ss << " " << make_type(in_names[i], in_shapes[i], true, lang) << "& in_" + << i; comma = true; } ss << ") {\n"; ss << " " << fn << "("; comma = false; - for (size_t i=0, off=0; i& tmpBuf = " + << "*(enzyme::tensor*)outs[" << out_off + << "];\n"; + out_off++; + } + // forward mode, we have undef dtmpbuf + if (mode == ABI::Forward && tmpBuf != 0) { + ss << " enzyme::tensor& dtmpBuf = " + << "*(enzyme::tensor*)outs[" << out_off + << "];\n"; out_off++; } - if (mode == 0) { - num_out = out_shapes.size(); - ss << " " << fn << "("; - bool comma = false; - for (size_t i=0; i& dtmpBuf = " + << "*(enzyme::tensor*)(nullptr);\n"; + ss << "#pragma clang diagnostic pop\n"; + } + // reverse mode, we have zero'd + if (mode == ABI::Reverse && tmpBuf != 0) { + ss << "#pragma clang diagnostic push\n"; + ss << "#pragma clang diagnostic ignored \"-Wnull-dereference\"\n"; + ss << " enzyme::tensor& tmpBuf = " + << "*(enzyme::tensor*)(nullptr);\n"; + ss << "#pragma clang diagnostic pop\n"; + ss << " __builtin_memset(outs[" << out_off << "], 0, " << tmpBuf + << ");\n"; + ss << " enzyme::tensor& dtmpBuf = " + << "*(enzyme::tensor*)outs[" << out_off + << "];\n"; + out_off++; + } + + if (mode == ABI::Primal) { + ss << " " << fn << "("; + bool comma = false; + for (size_t i = 0; i < out_shapes.size(); i++) { + if (comma) + ss << ", "; ss << "out_" << i; comma = true; - } - for (size_t i=0; i(" << fn << ", enzyme_allocated, tapesize, enzyme_tape, &tape"; - for (size_t i=0; i(" << fn + << ", enzyme_allocated, tapesize, enzyme_tape, &tape"; + for (size_t i = 0; i < out_shapes.size(); i++) { + ss << ", enzyme_dup, &out_" << i << ", nullptr"; + } + if (tmpBuf != 0) { + ss << ", enzyme_dup, &tmpBuf, &dtmpBuf"; } - for (size_t i=0; i(" << fn + << ", enzyme_allocated, tapesize, enzyme_tape, &tape"; + for (size_t i = 0; i < out_shapes.size(); i++) { + ss << ", enzyme_dup, nullptr, &dout_" << i; } - ss << " enzyme::__enzyme_reverse(" << fn << ", enzyme_allocated, tapesize, enzyme_tape, &tape"; - for (size_t i=0; i pyargv_strs; - assert (PySequence_Check(pyargv)); - auto sz = PySequence_Size(pyargv); + assert(PySequence_Check(pyargv)); + auto sz = PySequence_Size(pyargv); for (Py_ssize_t i = 0; i < sz; ++i) { - PyObject* item = PySequence_GetItem(pyargv, i); + PyObject *item = PySequence_GetItem(pyargv, i); #if PY_VERSION_HEX < 0x03000000 - auto argv = PyString_AsString(item); + auto argv = PyString_AsString(item); #else - auto argv = PyUnicode_AsUTF8(item); + auto argv = PyUnicode_AsUTF8(item); #endif - Py_DECREF(item); - assert(argv); - pyargv_strs.emplace_back(argv); + Py_DECREF(item); + assert(argv); + pyargv_strs.emplace_back(argv); #if PY_VERSION_HEX < 0x03000000 - free(argv); + free(argv); #else - // should not free py3+ + // should not free py3+ #endif } - auto mod = GetLLVMFromJob("/enzyme_call/source.cpp", ss.str(), /*cpp*/true, pyargv_strs, llvm_ctx.get(), std::move(linkMod)); + auto mod = GetLLVMFromJob("/enzyme_call/source.cpp", ss.str(), /*cpp*/ true, + pyargv_strs, llvm_ctx.get(), std::move(linkMod)); if (!mod) throw pybind11::value_error("failed to compile C++"); - return std::make_tuple(std::move(mod), std::move(llvm_ctx), num_out); + return std::make_tuple(std::move(mod), std::move(llvm_ctx), out_off, + tmpBuf); } - static 
size_t tapeSize(llvm::StringRef fn, llvm::StringRef source, - llvm::ArrayRef> out_shapes, - llvm::ArrayRef out_names, - llvm::ArrayRef> in_shapes, - llvm::ArrayRef in_names, - PyObject* pyargv, Language lang) { - int mode = 4; - auto [mod, llvm_ctx, num_out] = createLLVMMod(fn, source, out_shapes, out_names, in_shapes, in_names, pyargv, mode, lang); + static std::pair + tapeAndTempSize(std::string fn, llvm::StringRef source, + llvm::ArrayRef> out_shapes, + llvm::ArrayRef out_names, + llvm::ArrayRef> in_shapes, + llvm::ArrayRef in_names, PyObject *pyargv, + Language lang) { + auto mode = ABI::Tape; + auto [mod, llvm_ctx, num_out, tmpBuf] = + createLLVMMod(fn, source, out_shapes, out_names, in_shapes, in_names, + pyargv, mode, lang); auto lfn = mod->getFunction("entry"); - auto RI = llvm::cast(lfn->getEntryBlock().getTerminator()); + auto RI = + llvm::cast(lfn->getEntryBlock().getTerminator()); auto val = llvm::cast(RI->getReturnValue()); size_t res = val->getZExtValue(); // force deletion of mod first explicitly mod = nullptr; - return res; + return std::make_pair(res, tmpBuf); } + static size_t tempSize(llvm::StringRef source, Language lang) { + switch (lang) { + case Language::MHLO: { + std::string llvm_ir; + auto local_executable = compile_mhlo_to_llvm_with_xla(source, llvm_ir); + auto *cpu_executable = static_cast( + local_executable->executable()); + auto &assignment = cpu_executable->buffer_assignment(); + return assignment.temp_allocation_total_size(); + } + default: + return 0; + } + } - static int64_t create(llvm::StringRef fn, llvm::StringRef source, - llvm::ArrayRef> out_shapes, - llvm::ArrayRef out_names, - llvm::ArrayRef> in_shapes, - llvm::ArrayRef in_names, - PyObject* pyargv, int mode, Language lang) { + static std::tuple + create(std::string fn, llvm::StringRef source, + llvm::ArrayRef> out_shapes, + llvm::ArrayRef out_names, + llvm::ArrayRef> in_shapes, + llvm::ArrayRef in_names, PyObject *pyargv, ABI mode, + Language lang) { llvm::sys::SmartScopedWriter lock(kernel_mutex); - int64_t identifier = last_identifier++; + size_t identifier = last_identifier++; - auto [mod, llvm_ctx, num_out] = createLLVMMod(fn, source, out_shapes, out_names, in_shapes, in_names, pyargv, mode, lang); + auto [mod, llvm_ctx, num_out, tmpBuf] = + createLLVMMod(fn, source, out_shapes, out_names, in_shapes, in_names, + pyargv, mode, lang); if (!JIT) { DL = std::make_unique(mod.get()); - auto tJIT = llvm::orc::LLJITBuilder().setDataLayout(*DL.get()).setLinkProcessSymbolsByDefault(true).setObjectLinkingLayerCreator( - [](llvm::orc::ExecutionSession & ES, const llvm::Triple &OLL) -> llvm::Expected> { - return std::make_unique(ES); - }).setJITTargetMachineBuilder(llvm::orc::JITTargetMachineBuilder(llvm::Triple(mod->getTargetTriple()))).create(); + auto tJIT = + llvm::orc::LLJITBuilder() + .setDataLayout(*DL.get()) + .setLinkProcessSymbolsByDefault(true) + .setObjectLinkingLayerCreator( + [](llvm::orc::ExecutionSession &ES, const llvm::Triple &OLL) + -> llvm::Expected< + std::unique_ptr> { + auto obj = std::make_unique< + llvm::orc::RTDyldObjectLinkingLayer>(ES, []() { + return std::make_unique(); + }); + if (getenv("ENABLE_GDBLISTENER")) { + auto list = llvm::JITEventListener:: + createGDBRegistrationListener(); + obj->registerJITEventListener(*list); + } + return obj; + }) + .setJITTargetMachineBuilder(llvm::orc::JITTargetMachineBuilder( + llvm::Triple(mod->getTargetTriple()))) + .create(); if (!tJIT) { llvm::errs() << tJIT.takeError() << "\n"; throw pybind11::value_error("failed to create jit"); @@ 
-394,12 +850,16 @@ class CpuKernel { assert(JIT); } - auto LibA = JIT->createJITDylib("enzymedl_"+std::to_string(identifier)); + auto LibA = JIT->createJITDylib("enzymedl_" + std::to_string(identifier)); // Add the module. - // if (auto Err = JIT->addIRModule(llvm::orc::ThreadSafeModule(std::move(mod), std::move(llvm_ctx)))) { - if (auto Err = JIT->addIRModule(LibA.get(), llvm::orc::ThreadSafeModule(std::move(mod), std::move(llvm_ctx)))) { - llvm::errs() <<" error " << Err << "\n"; + // if (auto Err = + // JIT->addIRModule(llvm::orc::ThreadSafeModule(std::move(mod), + // std::move(llvm_ctx)))) { + if (auto Err = JIT->addIRModule( + LibA.get(), + llvm::orc::ThreadSafeModule(std::move(mod), std::move(llvm_ctx)))) { + llvm::errs() << " error " << Err << "\n"; throw pybind11::value_error("failed to add IR module"); } @@ -412,44 +872,44 @@ class CpuKernel { // Cast the entry point address to a function pointer. auto Entry = EntrySym->getValue(); - + kernels.try_emplace( - identifier, - std::make_unique(identifier, num_out, Entry)); - return identifier; + identifier, std::make_unique(identifier, num_out, Entry)); + return std::make_tuple(identifier, tmpBuf); } static CpuKernel *get(int64_t identifier) { llvm::sys::SmartScopedReader lock(kernel_mutex); auto it = kernels.find(identifier); - if (it == kernels.end()) return nullptr; + if (it == kernels.end()) + return nullptr; return it->getSecond().get(); } void call(void *out, void **ins) const { void **outs = num_out > 1 ? reinterpret_cast(out) : &out; - for(int i=0; i> kernels; - static int64_t last_identifier; + static size_t last_identifier; static llvm::sys::SmartRWMutex kernel_mutex; }; -llvm::DenseMap> - CpuKernel::kernels; -int64_t CpuKernel::last_identifier = 1; +llvm::DenseMap> CpuKernel::kernels; +size_t CpuKernel::last_identifier = 1; llvm::sys::SmartRWMutex CpuKernel::kernel_mutex; std::unique_ptr CpuKernel::DL; -std::unique_ptr CpuKernel::JIT = nullptr; -// llvm::orc::ExecutionSession CpuKernel::ES(std::move(*llvm::orc::SelfExecutorProcessControl::Create())); -} // namespace +std::unique_ptr CpuKernel::JIT = nullptr; +// llvm::orc::ExecutionSession +// CpuKernel::ES(std::move(*llvm::orc::SelfExecutorProcessControl::Create())); +} // namespace void CpuCallback(void *out, void **ins) { int64_t identifier = *reinterpret_cast(ins[0]); @@ -466,16 +926,25 @@ PYBIND11_MODULE(enzyme_call, m) { llvm::InitializeAllTargetMCs(); llvm::InitializeAllAsmPrinters(); llvm::InitializeAllAsmParsers(); + EnzymeAlwaysInlineDiff.setValue(true); pybind11::enum_(m, "Language") - .value("CPP", Language::CPP) - .value("LLVM", Language::LLVM) - .value("MHLO", Language::MHLO); + .value("CPP", Language::CPP) + .value("LLVM", Language::LLVM) + .value("MHLO", Language::MHLO); + + pybind11::enum_(m, "ABI") + .value("Primal", ABI::Primal) + .value("Forward", ABI::Forward) + .value("Augmented", ABI::Augmented) + .value("Reverse", ABI::Reverse) + .value("Tape", ABI::Tape); m.def("create_enzyme_cpu_kernel", - [](const std::string &source, const std::string &fn, const pybind11::list &py_out_shapes, - const pybind11::list &py_in_shapes, - pybind11::object pyargv, int mode, Language lang) -> int64_t { + [](const std::string &source, const std::string &fn, + const pybind11::list &py_out_shapes, + const pybind11::list &py_in_shapes, pybind11::object pyargv, + ABI mode, Language lang) -> std::tuple { llvm::SmallVector> out_shapes; out_shapes.reserve(pybind11::len(py_out_shapes)); llvm::SmallVector> in_shapes; @@ -509,13 +978,20 @@ PYBIND11_MODULE(enzyme_call, m) { 
target.push_back(nested_element.cast()); } } - return CpuKernel::create(fn, source, out_shapes, out_types, in_shapes, in_types, pyargv.ptr(), mode, (Language)lang); + return CpuKernel::create(fn, source, out_shapes, out_types, in_shapes, + in_types, pyargv.ptr(), mode, + (Language)lang); }); - m.def("tape_size", - [](const std::string &source, const std::string &fn, const pybind11::list &py_out_shapes, - const pybind11::list &py_in_shapes, - pybind11::object pyargv, Language lang) -> int64_t { + m.def("tmp_size", [](const std::string &source, Language lang) -> size_t { + return CpuKernel::tempSize(source, (Language)lang); + }); + + m.def("tape_and_tmp_size", + [](const std::string &source, const std::string &fn, + const pybind11::list &py_out_shapes, + const pybind11::list &py_in_shapes, pybind11::object pyargv, + Language lang) -> std::pair { llvm::SmallVector> out_shapes; out_shapes.reserve(pybind11::len(py_out_shapes)); llvm::SmallVector> in_shapes; @@ -549,7 +1025,9 @@ PYBIND11_MODULE(enzyme_call, m) { target.push_back(nested_element.cast()); } } - return (int64_t)CpuKernel::tapeSize(fn, source, out_shapes, out_types, in_shapes, in_types, pyargv.ptr(), (Language)lang); + return CpuKernel::tapeAndTempSize(fn, source, out_shapes, out_types, + in_shapes, in_types, pyargv.ptr(), + (Language)lang); }); m.def("get_cpu_callback", []() { @@ -558,13 +1036,8 @@ PYBIND11_MODULE(enzyme_call, m) { }); m.def("compile_mhlo_to_llvm_with_xla", [](const std::string &mhlo_text) { - absl::StatusOr llvm_ir = - compile_mhlo_to_llvm_with_xla(mhlo_text); - if (!llvm_ir.ok()) { - throw std::runtime_error("failed to compile to LLVM IR with XLA:" + - llvm_ir.status().ToString()); - } - return *llvm_ir; + std::string llvm_ir; + compile_mhlo_to_llvm_with_xla(mhlo_text, llvm_ir); + return llvm_ir; }); } - diff --git a/enzyme_jax/primitives.py b/enzyme_jax/primitives.py index 31ebfad7..2fffc767 100644 --- a/enzyme_jax/primitives.py +++ b/enzyme_jax/primitives.py @@ -31,9 +31,12 @@ def cflags(): import platform import os if platform.system() == 'Darwin': - return ('-isysroot', '/Library/Developer/CommandLineTools/SDKs/MacOSX.sdk', "-isystem", "/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/usr/include/c++/v1", "-internal-isystem", os.path.join(resource_dir(), "include"), "-internal-externc-isystem", "/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/usr/include", "-internal-externc-isystem", "/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/include", "-fgnuc-version=4.2.1") + res = ('-isysroot', '/Library/Developer/CommandLineTools/SDKs/MacOSX.sdk', "-isystem", "/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/usr/include/c++/v1", "-internal-isystem", os.path.join(resource_dir(), "include"), "-internal-externc-isystem", "/Applications/Xcode.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk/usr/include", "-internal-externc-isystem", "/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/include", "-fgnuc-version=4.2.1") else: - return () + res = () + if os.getenv("ENABLE_GDBLISTENER") is not None: + res = res + ('-debug-info-kind=standalone', '-dwarf-version=5', '-debugger-tuning=gdb',) + return res def _enzyme_primal_impl( *args_flat: jax.Array, @@ -98,11 +101,9 @@ def _enzyme_primal_abstract_eval( out_shapes: Sequence[jax.core.ShapedArray], lang: enzyme_call.Language ) -> 
Sequence[jax.core.ShapedArray]: - del source, fn, args_flat - # TODO: we may attempt some lightweight parsing of source to extract the # result types instead. - return tuple(out_shapes) + return out_shapes def _enzyme_fwd_abstract_eval( *args_flat: jax.core.ShapedArray, @@ -113,8 +114,6 @@ def _enzyme_fwd_abstract_eval( lang: enzyme_call.Language, ) -> Sequence[jax.core.ShapedArray]: del source, fn, args_flat - - # each return is duplicated return tuple(o for o in out_shapes for _ in range(2)) def absmaketup(ty): @@ -145,10 +144,12 @@ def _enzyme_aug_abstract_eval( lowered_func = jax.jit(func).lower(*avals_in) mhlo = lowered_func.compiler_ir(dialect='mhlo') source = str(mhlo) + kept = lowered_func.compile()._executable._kept_var_idx + in_shapes = [shape for (i, shape) in enumerate(in_shapes) if i in kept] argv = argv + ( "-resource-dir", resource_dir()) + cflags() - tapeSize = enzyme_call.tape_size(source, fn, out_shapes, in_shapes, argv, lang) + tapeSize, tmpSize = enzyme_call.tape_and_tmp_size(source, fn, out_shapes, in_shapes, argv, lang) res = tuple(prev_out_shapes) + (jax.core.ShapedArray((tapeSize,), (jax.numpy.int8)),) return res @@ -171,17 +172,18 @@ def _enzyme_rev_abstract_eval( in_shapes, lang: enzyme_call.Language ) -> Sequence[jax.core.ShapedArray]: - del source, fn, args_flat - return tuple(jax.core.ShapedArray(shape, dejaxify(tyid)) for (shape, tyid) in in_shapes) def maketup(ty): ty = ir.RankedTensorType(ty) - tystr = ty.element_type.__str__() - tystr = {'f32':'float','f64':'double'}[tystr] + tystr = ty.element_type.__str__() + tystr = {'f32':'float','f64':'double','i32':'int32_t','i64':'int64_t'}[tystr] return (tystr, ty.shape) - +def to_jax(ty): + tystr = ty.__str__() + return {'f32':jnp.float32,'f64':jnp.float64}[tystr] + def _enzyme_primal_lowering( ctx: jax_mlir.LoweringRuleContext, *args_flat: ir.Value, @@ -200,26 +202,39 @@ def _enzyme_primal_lowering( out_shapes = list(map(maketup, out_types)) in_shapes = list(map(lambda x: maketup(x.type), args_flat)) + in_args = (*args_flat,) + if lang == LANG_MHLO: (in_tree, func) = source avals_in = jax.tree_util.tree_unflatten(in_tree, ctx.avals_in) lowered_func = jax.jit(func).lower(*avals_in) mhlo = lowered_func.compiler_ir(dialect='mhlo') source = str(mhlo) + kept = lowered_func.compile()._executable._kept_var_idx + in_args = tuple(arg for (i, arg) in enumerate(in_args) if i in kept) + in_shapes = [shape for (i, shape) in enumerate(in_shapes) if i in kept] argv = argv + ( "-resource-dir", resource_dir() ) + cflags() - mode = 0 - identifier = enzyme_call.create_enzyme_cpu_kernel(source, fn, out_shapes, in_shapes, argv, mode, lang) + identifier, tmpBuf = enzyme_call.create_enzyme_cpu_kernel(source, fn, out_shapes, in_shapes, argv, enzyme_call.ABI.Primal, lang) identifier_attr = jax_mlir.dense_int_elements([identifier]) identifier_op = stablehlo.ConstantOp(identifier_attr) - mlir_args = (identifier_op, *args_flat) + mlir_args = (identifier_op,) + in_args + + if tmpBuf != 0: + sa = ir.RankedTensorType.get((tmpBuf,), ir.IntegerType.get_signless(8)) + out_types = out_types + (sa,) + custom_call = stablehlo.CustomCallOp( out_types, mlir_args, call_target_name="jaxzyme.primal" ) - return custom_call.results + results = custom_call.results + if tmpBuf != 0: + results = results[:-1] + + return results def _enzyme_fwd_lowering( ctx: jax_mlir.LoweringRuleContext, @@ -239,6 +254,8 @@ def _enzyme_fwd_lowering( out_shapes = list(map(maketup, out_types[::2])) in_shapes = list(map(lambda x: maketup(x.type), args_flat[::2])) + + in_args = 
(*args_flat,) if lang == LANG_MHLO: (in_tree, func) = source @@ -246,19 +263,30 @@ def _enzyme_fwd_lowering( lowered_func = jax.jit(func).lower(*avals_in) mhlo = lowered_func.compiler_ir(dialect='mhlo') source = str(mhlo) + kept = lowered_func.compile()._executable._kept_var_idx + in_args = tuple(arg for (i, arg) in enumerate(in_args) if i//2 in kept) + in_shapes = [shape for (i, shape) in enumerate(in_shapes) if i in kept] argv = argv + ( "-resource-dir", resource_dir() ) + cflags() - mode = 1 - identifier = enzyme_call.create_enzyme_cpu_kernel(source, fn, out_shapes, in_shapes, argv, mode, lang) + identifier, tmpBuf = enzyme_call.create_enzyme_cpu_kernel(source, fn, out_shapes, in_shapes, argv, enzyme_call.ABI.Forward, lang) identifier_attr = jax_mlir.dense_int_elements([identifier]) identifier_op = stablehlo.ConstantOp(identifier_attr) - mlir_args = (identifier_op, *args_flat) + mlir_args = (identifier_op,) + in_args + + if tmpBuf != 0: + sa = ir.RankedTensorType.get((tmpBuf,), ir.IntegerType.get_signless(8)) + out_types = out_types + (sa,sa) + custom_call = stablehlo.CustomCallOp( out_types, mlir_args, call_target_name="jaxzyme.fwd" ) - return custom_call.results + results = custom_call.results + if tmpBuf != 0: + results = results[:-2] + + return results def _enzyme_aug_lowering( @@ -280,25 +308,36 @@ def _enzyme_aug_lowering( in_shapes = list(map(lambda x: maketup(x.type), args_flat)) + in_args = (*args_flat,) + if lang == LANG_MHLO: (in_tree, func) = source avals_in = jax.tree_util.tree_unflatten(in_tree, ctx.avals_in) lowered_func = jax.jit(func).lower(*avals_in) mhlo = lowered_func.compiler_ir(dialect='mhlo') source = str(mhlo) + kept = lowered_func.compile()._executable._kept_var_idx + in_args = tuple(arg for (i, arg) in enumerate(in_args) if i in kept) + in_shapes = [shape for (i, shape) in enumerate(in_shapes) if i in kept] argv = argv + ( "-resource-dir", resource_dir()) + cflags() - mode = 2 - identifier = enzyme_call.create_enzyme_cpu_kernel(source, fn, out_shapes, in_shapes, argv, mode, lang) + identifier, tmpBuf = enzyme_call.create_enzyme_cpu_kernel(source, fn, out_shapes, in_shapes, argv, enzyme_call.ABI.Augmented, lang) identifier_attr = jax_mlir.dense_int_elements([identifier]) identifier_op = stablehlo.ConstantOp(identifier_attr) - mlir_args = (identifier_op, *args_flat) + if tmpBuf != 0: + sa = ir.RankedTensorType.get((tmpBuf,), ir.IntegerType.get_signless(8)) + out_types = out_types + (sa,) + + mlir_args = (identifier_op,) + in_args custom_call = stablehlo.CustomCallOp( out_types, mlir_args, call_target_name="jaxzyme.aug" ) - return custom_call.results + results = custom_call.results + if tmpBuf != 0: + results = results[:-1] + return results def _enzyme_rev_lowering( ctx: jax_mlir.LoweringRuleContext, @@ -311,32 +350,62 @@ def _enzyme_rev_lowering( ) -> Sequence[ir.Value]: del in_shapes - in_types = tuple( + pre_in_types = tuple( itertools.chain(*map(jax_mlir.aval_to_ir_types, ctx.avals_out)) ) - in_shapes = list(map(maketup, in_types)) + in_shapes = list(map(maketup, pre_in_types)) + pre_in_shapes = in_shapes out_shapes = list(map(lambda x: maketup(x.type), args_flat[1:])) + + in_args = (*args_flat,) + rev_return_types = pre_in_types + + kept = None if lang == LANG_MHLO: (in_tree, func) = source avals_in = jax.tree_util.tree_unflatten(in_tree, ctx.avals_out) lowered_func = jax.jit(func).lower(*avals_in) mhlo = lowered_func.compiler_ir(dialect='mhlo') source = str(mhlo) + kept = lowered_func.compile()._executable._kept_var_idx + # in_args = tuple(arg for (i, arg) 
in enumerate(in_args) if i in kept) + in_shapes = [shape for (i, shape) in enumerate(in_shapes) if i in kept] + rev_return_types = tuple(retty for (i, retty) in enumerate(rev_return_types) if i in kept) argv = tuple(argv) + ( "-resource-dir", resource_dir()) + cflags() - mode = 3 - identifier = enzyme_call.create_enzyme_cpu_kernel(source, fn, out_shapes, in_shapes, argv, mode, lang) + identifier, tmpBuf = enzyme_call.create_enzyme_cpu_kernel(source, fn, out_shapes, in_shapes, argv, enzyme_call.ABI.Reverse, lang) identifier_attr = jax_mlir.dense_int_elements([identifier]) identifier_op = stablehlo.ConstantOp(identifier_attr) - mlir_args = (identifier_op, *args_flat) + mlir_args = (identifier_op,) + in_args + + if tmpBuf != 0: + sa = ir.RankedTensorType.get((tmpBuf,), ir.IntegerType.get_signless(8)) + rev_return_types = rev_return_types + (sa,) + custom_call = stablehlo.CustomCallOp( - in_types, mlir_args, call_target_name="jaxzyme.rev" + rev_return_types, mlir_args, call_target_name="jaxzyme.rev" ) - return custom_call.results + results = custom_call.results + if tmpBuf != 0: + results = results[:-1] + if kept != None: + results = [] + cur_idx = 0 + for i, ty in enumerate(pre_in_types): + if i in kept: + results.append(custom_call.results[cur_idx]) + cur_idx += 1 + else: + ty = ir.RankedTensorType(ty) + shape = ty.shape + element_type = ty.element_type + import numpy as np + results.append(stablehlo.ConstantOp(ir.DenseElementsAttr.get(np.zeros(shape, dtype=to_jax(element_type)))).results[0]) + return results def ffi_call(*args, out_shapes: Sequence[jax.core.ShapedArray], source, fn:str="f", argv: tuple[str]=(), lang:int=LANG_CPP): return _enzyme_primal_p.bind( @@ -465,4 +534,4 @@ def wrapped(*args: Any): out_flat = ffi_call(*args_flat, source=(in_tree, func), fn="", out_shapes=out_shape_flat, argv=argv, lang=LANG_MHLO) return jax.tree_util.tree_unflatten(out_tree, out_flat) return wrapped - return decorator \ No newline at end of file + return decorator diff --git a/test/bench_vs_xla.py b/test/bench_vs_xla.py index 1c736a92..617f658f 100644 --- a/test/bench_vs_xla.py +++ b/test/bench_vs_xla.py @@ -11,18 +11,35 @@ def add_one(x: jax.Array, y) -> jax.Array: def add_one_plain(x: jax.Array, y) -> jax.Array: return x + 1 + y -in0, in1 = jnp.array([1., 2., 3.]), jnp.array([10., 20., 30.]) +@enzyme_jax_ir() +def add_two(x: jax.Array, z, y) -> jax.Array: + return x + y + +@jax.jit +def add_two_plain(x: jax.Array, z, y) -> jax.Array: + return x + y + +in0, in1, in2 = jnp.array([1., 2., 3.]), jnp.array([10., 20., 30.]), jnp.array([100., 200., 300.]) # TODO: this currently throws NYI as it is not yet connected to JIT and runtime. # But it should print LLVM IR in the process. 
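The add_two checks below exercise the dead-argument path introduced in the lowerings above: the compiled executable reports the surviving operands through the private _kept_var_idx set, and the forward-mode lowering, whose operands arrive interleaved as (primal0, tangent0, primal1, tangent1, ...), keeps an operand only when its argument index i//2 survives. A minimal standalone illustration of that filtering (filter_fwd_operands and kept_var_idx are names invented here, not part of the patch):

def filter_fwd_operands(args_flat, kept_var_idx):
    # Operand i belongs to argument i // 2, so a primal/tangent pair is kept
    # exactly when XLA kept the corresponding argument.
    return tuple(a for i, a in enumerate(args_flat) if i // 2 in kept_var_idx)

# add_two(x, z, y) never reads z, so argument index 1 is eliminated:
assert filter_fwd_operands(("x", "dx", "z", "dz", "y", "dy"), {0, 2}) == ("x", "dx", "y", "dy")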
-add_one(in0, in1) -add_one_plain(in0, in1) +ao = add_one(in0, in1) +aop = add_one_plain(in0, in1) +assert (jnp.abs(ao-aop) < 1e-6).all() +print("Primal success") + +at = add_two(in0, in1, in2) +atp = add_two_plain(in0, in1, in2) + +assert (jnp.abs(at-atp) < 1e-6).all() +print("Primal Deadarg success") + import timeit print(timeit.Timer('add_one(in0, in1)', globals={'add_one':add_one, 'in0':in0, 'in1':in1}).timeit()) print(timeit.Timer('add_one_plain(in0, in1)', globals={'add_one_plain':add_one_plain, 'in0':in0, 'in1':in1}).timeit()) -din0, din1 = (jnp.array([.1, .2, .3]), jnp.array([50., 70., 110.])) +din0, din1, din2 = (jnp.array([.1, .2, .3]), jnp.array([50., 70., 110.]), jnp.array([1300., 1700., 1900.])) @jax.jit def fwd(in0, in1, din0, din1): @@ -30,10 +47,36 @@ def fwd(in0, in1, din0, din1): @jax.jit def fwd_plain(in0, in1, din0, din1): - return jax.jvp(add_one, (in0, in1), (din0, din1)) + return jax.jvp(add_one_plain, (in0, in1), (din0, din1)) primals, tangents = fwd(in0, in1, din0, din1) -primals, tangents = fwd_plain(in0, in1, din0, din1) +primals_p, tangents_p = fwd_plain(in0, in1, din0, din1) + +assert (jnp.abs(primals-primals_p) < 1e-6).all() +for t, t_p in zip(tangents, tangents_p): + assert (jnp.abs(t-t_p) < 1e-6).all() + +print("Tangent success") + +@jax.jit +def fwd2(in0, in1, in2, din0, din1, din2): + return jax.jvp(add_two, (in0, in1, in2), (din0, din1, din2)) + +@jax.jit +def fwd2_plain(in0, in1, in2, din0, din1, din2): + return jax.jvp(add_two_plain, (in0, in1, in2), (din0, din1, din2)) + +primals, tangents = fwd2(in0, in1, in2, din0, din1, din2) +primals_p, tangents_p = fwd2_plain(in0, in1, in2, din0, din1, din2) + +print(primals, primals_p) +assert (jnp.abs(primals-primals_p) < 1e-6).all() +for i, (t, t_p) in enumerate(zip(tangents, tangents_p)): + print(i, t, t_p) + assert (jnp.abs(t-t_p) < 1e-6).all() + +print("Tangent deadarg success") + print(timeit.Timer('fwd(in0, in1, din0, din1)', globals={'fwd':fwd, 'in0':in0, 'in1':in1, 'din0':din0, 'din1':din1}).timeit()) print(timeit.Timer('fwd_plain(in0, in1, din0, din1)', globals={'fwd_plain':fwd_plain, 'in0':in0, 'in1':in1, 'din0':din0, 'din1':din1}).timeit()) @@ -53,13 +96,104 @@ def rev_plain(in0, in1, dout): dout = jnp.array([500., 700., 110.]) -rev(in0, in1, dout) -rev_plain(in0, in1, dout) +primals, grads = rev(in0, in1, dout) +# TODO enzyme will in place 0 the gradient inputs, which may not be expected +print(dout) +dout = jnp.array([500., 700., 110.]) +primals_p, grads_p = rev_plain(in0, in1, dout) + +assert (jnp.abs(primals-primals_p) < 1e-6).all() +for g, g_p in zip(grads, grads_p): + print(i, g, g_p) + assert (jnp.abs(g-g_p) < 1e-6).all() + +print("Gradient success") + +@jax.jit +def rev2(in0, in1, in2, dout): + primals, f_vjp = jax.vjp(add_two, in0, in1, in2) + grads = f_vjp(dout) + return primals, grads + +@jax.jit +def rev2_plain(in0, in1, in2, dout): + primals, f_vjp = jax.vjp(add_two_plain, in0, in1, in2) + grads = f_vjp(dout) + return primals, grads + -print(rev_plain.lower(in0, in1, dout).compiler_ir(dialect="mhlo")) +dout = jnp.array([500., 700., 110.]) +primals, grads = rev2(in0, in1, in2, dout) +# TODO enzyme will in place 0 the gradient inputs, which may not be expected +print(dout) +dout = jnp.array([500., 700., 110.]) +primals_p, grads_p = rev2_plain(in0, in1, in2, dout) + +assert (jnp.abs(primals-primals_p) < 1e-6).all() +for g, g_p in zip(grads, grads_p): + print(i, g, g_p) + assert (jnp.abs(g-g_p) < 1e-6).all() -rev_plain(in0, in1, dout) +print("Gradient deadarg success") 
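The "Gradient deadarg" check above passes because the reverse lowering still returns one cotangent per original argument: for arguments XLA eliminated (those not in the kept set) it emits an all-zeros constant, matching the zero gradient plain JAX reports for an unused input. A minimal standalone sketch of that reconstruction (scatter_cotangents and kept_var_idx are illustrative names, not part of the patch):

import jax.numpy as jnp

def scatter_cotangents(kept_grads, arg_shapes, kept_var_idx):
    # One gradient per original argument; eliminated arguments get zeros,
    # mirroring the stablehlo.ConstantOp of np.zeros built in
    # _enzyme_rev_lowering.
    grads, it = [], iter(kept_grads)
    for i, shape in enumerate(arg_shapes):
        grads.append(next(it) if i in kept_var_idx else jnp.zeros(shape))
    return tuple(grads)

# add_two(x, z, y): z (index 1) is dead, so its cotangent comes back as zeros.
print(scatter_cotangents([jnp.ones(3), 2 * jnp.ones(3)], [(3,), (3,), (3,)], {0, 2}))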
print(timeit.Timer('rev(in0, in1, dout)', globals={'rev':rev, 'in0':in0, 'in1':in1, 'dout':dout}).timeit()) print(timeit.Timer('rev_plain(in0, in1, dout)', globals={'rev_plain':rev_plain, 'in0':in0, 'in1':in1, 'dout':dout}).timeit()) +x = jnp.array(range(50), dtype=jnp.float32) +dx = jnp.array([i*i for i in range(50)], dtype=jnp.float32) + +@enzyme_jax_ir() +def esum(x): + return jnp.sum(x) + +eres = esum(x) +print(eres) +assert jnp.abs(eres-50*49/2)<1e-6 + +@jax.jit +def sumfwd(in0, din0): + return jax.jvp(esum, (in0,), (din0,)) + +primals, tangents = sumfwd(x, dx) +print(primals, tangents) +assert jnp.abs(primals-50*49/2)<1e-6 +assert jnp.abs(tangents-50*49*99/6)<1e-6 + +@jax.jit +def sumrev_p(in0): + primals, f_vjp = jax.vjp(jnp.sum, in0) + grads = f_vjp(1.0) + return primals, grads + +primals, grads = sumrev_p(x) +print(primals, grads) + +@jax.jit +def sumrev(in0): + primals, f_vjp = jax.vjp(esum, in0) + grads = f_vjp(1.0) + return primals, grads + +primals, grads = sumrev(x) +print(primals, grads) +assert jnp.abs(primals-50*49/2)<1e-6 +assert (jnp.abs(grads[0]-1) <1e-6).all() + +@enzyme_jax_ir() +def ecache(x): + return x * x[0] + +@jax.jit +def cacherev(in0, din0): + primals, f_vjp = jax.vjp(ecache, in0) + grads = f_vjp(din0) + return grads + +dim = 288 + +x = jnp.array(range(dim), dtype=jnp.float32) +dx = jnp.array(range(dim), dtype=jnp.float32) + +grads = cacherev(x, dx) +assert jnp.abs(grads[0][0]-287*288*(2*287+1)/6)<1e-6 +assert (jnp.abs(grads[0][1:]) <1e-6).all() diff --git a/test/llama.py b/test/llama.py new file mode 100644 index 00000000..f90aa28a --- /dev/null +++ b/test/llama.py @@ -0,0 +1,294 @@ +import jax.numpy as jnp +import jax.random +import jax.lax +import enzyme_jax + +def rmsnorm(x, weight): + ss = 1 / jnp.sqrt(x.dot(x) / x.shape[0] + 1e-5) + return weight * x * ss + +def softmax(x): + max_val = jnp.max(x) + x = jnp.exp(x - max_val) + return x / sum(x) + +def sigmoid(x): + return 1 / (1 + jnp.exp(-x)) + +def silu(x): + return x * sigmoid(x) + + +# Token is token value +asserts = True +def forward(x, config, weights, key_cache, value_cache): + pos = key_cache.shape[1] + assert pos == key_cache.shape[1] + assert pos == value_cache.shape[1] + + n_layers = config['n_layers'] + seq_len = config['seq_len'] + n_heads = config['n_heads'] + vocab_size = config['vocab_size'] + + # Total number of parameters of the recurrent state + dim = config['dim'] + + n_kv_heads = config['n_kv_heads'] + + # number of hidden dimensions? 
+ hidden_dim = config['hidden_dim'] + + + # Number of parameters per head + head_size = dim // n_heads + + # Number of heads per kv + kv_mul = n_heads // n_kv_heads + + # Number of parameters in a kv + kv_dim = dim // n_heads * n_kv_heads + + + wo = weights['wo'] + if asserts: assert wo.shape == (n_layers, dim, dim) + rms_ffn_weight = weights['rms_ffn_weight'] + if asserts: assert rms_ffn_weight.shape == (n_layers, dim) + w1 = weights['w1'] + if asserts: assert w1.shape == (n_layers, hidden_dim, dim) + w3 = weights['w3'] + if asserts: assert w3.shape == (n_layers, hidden_dim, dim) + w2 = weights['w2'] + if asserts: assert w2.shape == (n_layers, dim, hidden_dim) + + rms_att_weight = weights['rms_att_weight'] + if asserts: assert rms_att_weight.shape == (n_layers,dim) + + rms_final_weight = weights['rms_final_weight'] + if asserts: assert rms_final_weight.shape == (dim,) + wcls = weights['wcls'] + if asserts: assert wcls.shape == (vocab_size, dim) + + # token_embedding_table = weights['token_embedding_table'] + # if asserts: assert token_embedding_table.shape == (vocab_size, dim) + + # x = token_embedding_table[token, :] + # if asserts: assert x.shape == (dim, ) + + wq = weights['wq'] + if asserts: assert wq.shape == (n_layers, dim, dim) + + wk = weights['wk'] + if asserts: assert wk.shape == (n_layers, kv_dim, dim) + + wv = weights['wv'] + if asserts: assert wv.shape == (n_layers, kv_dim, dim) + + toconv = [] + + for i in range(0, dim, 2): + freq = 1 / jnp.power(10000, (i % head_size) / head_size) + val = pos * freq + fcr = jnp.cos(val) + fci = jnp.sin(val) + + rotM = jnp.array([[fcr, -fci], + [fci, fcr]]) + toconv.append(rotM) + toconv2 = toconv[:kv_dim//2] + [jnp.eye(2)] * (dim//2 - kv_dim//2) + + toconv = jnp.array(toconv) + toconv2 = jnp.array(toconv2) + + keys2 = [] + values2 = [] + for l in range(n_layers): + xb = rmsnorm(x, rms_att_weight[l, :]) + if asserts: assert xb.shape == (dim, ) + + q = wq[l, :, :] @ xb + if asserts: assert q.shape == (dim, ) + + k = wk[l, :, :] @ xb + if asserts: assert q.shape == (kv_dim, ) + + v = wv[l, :, :] @ xb + if asserts: assert q.shape == (kv_dim, ) + + q_tmp = jnp.reshape(q, (dim // 2, 2)) + k_tmp = jnp.reshape(k, (dim // 2, 2)) + + # dim == head_size * n_heads + + # Batched gemv + k = jnp.reshape(jnp.einsum('ijk,ik -> ij', toconv2, k_tmp), (dim,)) + q = jnp.reshape(jnp.einsum('ijk,ik -> ij', toconv, q_tmp), (dim,)) + + key_cache_l = key_cache[l, :, :] + key_cache_l = jnp.append(key_cache_l, jnp.reshape(k, (1, dim)), axis=0) + value_cache_l = value_cache[l, :, :] + value_cache_l = jnp.append(value_cache_l, jnp.reshape(v, (1, dim)), axis=0) + keys2.append(key_cache_l) + values2.append(value_cache_l) + + xbs2 = [] + for h in range(n_heads): + + q2 = q[head_size*h:head_size*(h+1)] + if asserts: assert q2.shape == (head_size,) + + # For kv_mul consecutive heads, they share the same kv cache + # reshape key_cache last dim from (kv_dim,) to (kv_mul, head_size) + # generalized einsum reducing the last dim, the rest are batch + att = [] + + key_index = h // kv_mul + + att = jnp.einsum('ij,j->i', key_cache_l[:, key_index * head_size : (key_index+1) * head_size], q2) + + att = att / jnp.sqrt(head_size) + + att = softmax(att) + + x_tmp = jnp.einsum('ij,i->j', value_cache_l[:, key_index * head_size : (key_index+1) * head_size], att) + + xbs2.append(x_tmp) + + # Todo right concat + xb = jnp.concatenate(xbs2, axis=None) + + xb2 = wo[l, :, :] @ xb + if asserts: assert xb2.shape == (dim, ) + + x += xb2 + + # Rmsnorm and feedforward swiglu + + xb = rmsnorm(x, 
rms_ffn_weight[l, :]) + + # Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x)) + # first calculate self.w1(x) and self.w3(x) + + + hb = w1[l, :, :] @ xb + hb2 = w3[l, :, :] @ xb + + hb = silu(hb) + + hb = hb * hb2 + + + xb = w2[l, :, :] @ hb + + x += xb + + + x = rmsnorm(x, rms_final_weight) + logits = wcls @ x + + return x + +import numpy as np + +config = {'dim': 288, 'hidden_dim': 768, 'n_layers': 6, 'n_heads': 6, 'n_kv_heads': 6, 'vocab_size': 32000, 'seq_len': 256} + +n_layers = config['n_layers'] +seq_len = config['seq_len'] +n_heads = config['n_heads'] +dim = config['dim'] +n_kv_heads = config['n_kv_heads'] +vocab_size = config['vocab_size'] +hidden_dim = config['hidden_dim'] +kv_dim = dim // n_heads * n_kv_heads +head_size = dim // n_heads + +key = jax.random.PRNGKey(0) +weights = {} +dweights = {} + +for name, shape in [("rms_att_weight", (n_layers, dim)), + ("wq", (n_layers, dim, n_heads * head_size)), + ("wk", (n_layers, dim, n_kv_heads * head_size)), + ("wv", (n_layers, dim, n_kv_heads * head_size)), + ("wo", (n_layers, dim, dim)), + ("rms_ffn_weight", (n_layers, dim)), + ("w1", (n_layers, hidden_dim, dim)), + ("w2", (n_layers, dim, hidden_dim)), + ("w3", (n_layers, hidden_dim, dim)), + ("rms_final_weight", (dim,)), + ("wcls", (vocab_size, dim)) + ]: + key, subkey = jax.random.split(key) + key, subkey2 = jax.random.split(key) + weights[name] = jax.random.uniform(subkey, shape=shape) + dweights[name] = jax.random.uniform(subkey2, shape=shape) + +key, subkey = jax.random.split(key) +x = jax.random.uniform(subkey, shape=(dim,)) +key, subkey = jax.random.split(key) +dx = jax.random.uniform(subkey, shape=(dim,)) + +def partial(func, config): + def sfn(x, weights, key_cache, value_cache): + return func(x, config, weights, key_cache, value_cache) + return sfn + +pos = 1 +key_cache = jnp.zeros((n_layers, pos,kv_dim)) +value_cache = jnp.zeros((n_layers, pos,kv_dim)) + +key, subkey = jax.random.split(key) +dkc = jax.random.uniform(subkey, shape=(n_layers,pos+1,kv_dim)) +key, subkey = jax.random.split(key) +dvc = jax.random.uniform(subkey, shape=(n_layers,pos+1,kv_dim)) + +func = partial(forward, config) + +@jax.jit +def jfunc(x, weights, key_cache, value_cache): + return func(x, weights, key_cache, value_cache) + +@enzyme_jax.enzyme_jax_ir() +def efunc(x, weights, key_cache, value_cache): + return func(x, weights, key_cache, value_cache) + +# eres = efunc(x, weights, key_cache, value_cache) +# print("Enzyme primal", eres) +# res = func(x, weights, key_cache, value_cache) +# print("Jax primal", res) +# print (" max error", jnp.max(jnp.abs(eres-res))) +# assert (jnp.abs(eres - res) < 1e-3).all() + +#jfunc = jax.jit(partial(forward, config)) +# mlir = jax.jit(partial(forward, config)).lower(1, weights, key_cache, value_cache).compiler_ir(dialect="mhlo") + +@jax.jit +def jfwd(x, dx, weights, dweights, kc, dkc, vc, dvc): + return jax.jvp(jfunc, (x, weights, kc, vc), (x, weights, dkc, dvc)) + +@jax.jit +def efwd(x, dx, weights, dweights, kc, dkc, vc, dvc): + return jax.jvp(efunc, (x, weights, kc, vc), (x, weights, dkc, dvc)) + +# print("pre fwd diff") +# eres = efwd(x, dx, weights, dweights, key_cache, key_cache, value_cache, value_cache) +# print("Enzyme fwd", eres) +# jres = jfwd(x, dx, weights, dweights, key_cache, key_cache, value_cache, value_cache) +# print("Jax fwd", jres) + + +@jax.jit +def jrev(x, weights, kc, vc, dx, dkc, dvc): + primals, f_vjp = jax.vjp(jfunc, x, weights, kc, vc) + return f_vjp(dx) #, dkc, dvc) + +@jax.jit +def erev(x, weights, kc, vc, dx, dkc, 
dvc):
+    primals, f_vjp = jax.vjp(efunc, x, weights, kc, vc)
+    return f_vjp(dx) #, dkc, dvc)
+
+eres = erev(x, weights, key_cache, value_cache, dx, dkc, dvc)
+print("Enzyme rev", eres)
+jres = jrev(x, weights, key_cache, value_cache, dx, dkc, dvc)
+print("Jax rev", jres)
+
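The reverse results above are only printed; a possible closing check (illustrative only, reusing the 1e-3 tolerance from the commented-out primal comparison earlier in this file) flattens both gradient pytrees and compares them leaf by leaf:

import jax
import jax.numpy as jnp

# Both erev and jrev return the same pytree structure (cotangents for x,
# weights, key_cache, value_cache), so the flattened leaves line up.
eflat, _ = jax.tree_util.tree_flatten(eres)
jflat, _ = jax.tree_util.tree_flatten(jres)
assert len(eflat) == len(jflat)
for g_e, g_j in zip(eflat, jflat):
    print(" max error", jnp.max(jnp.abs(g_e - g_j)))
    assert jnp.max(jnp.abs(g_e - g_j)) < 1e-3
print("Reverse-mode gradients match within tolerance")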