From 22db5dfd1b371794ff99a9095432c595294e7eb4 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 16 Dec 2024 18:30:18 -0600 Subject: [PATCH 01/18] Kernel-supporting jll --- deps/ReactantExtra/BUILD | 1 + deps/ReactantExtra/WORKSPACE | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index c538bbb8..559533c7 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -426,6 +426,7 @@ cc_library( "-Wl,-exported_symbol,_ifrt_*", "-Wl,-exported_symbol,_RegisterCustomCallTarget", "-Wl,-exported_symbol,_ConvertLLVMToMLIR", +"-Wl,-exported_symbol,_EnzymeGPUCustomCall", ]}), deps = [ "@enzyme//:EnzymeMLIR", diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index 174cc671..bb83aae7 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -9,7 +9,7 @@ http_archive( urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)], ) -ENZYMEXLA_COMMIT = "f6587e37ff7298f2a1a273b08c24d69fca7ff30f" +ENZYMEXLA_COMMIT = "e059f8c6e559c92846b110537c9a8b53f65ec053" ENZYMEXLA_SHA256 = "" http_archive( From b6d0615600b0291dca2061180375fba20271ef25 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 16 Dec 2024 18:34:00 -0600 Subject: [PATCH 02/18] fix rulescc --- deps/ReactantExtra/WORKSPACE | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index bb83aae7..42823c35 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -57,19 +57,6 @@ sed -i.bak0 "s/patch_cmds = \\[/patch_cmds = \\[\\\"find . -type f -name config. # """, ] -http_archive( - name = "rules_cc", - sha256 = "85723d827f080c5e927334f1fb18a294c0b3f94fee6d6b45945f5cdae6ea0fd4", - strip_prefix = "rules_cc-c8c38f8c710cbbf834283e4777916b68261b359c", - urls = [ - "https://github.com/bazelbuild/rules_cc/archive/c8c38f8c710cbbf834283e4777916b68261b359c.tar.gz", - ], -) - -load("@rules_cc//cc:repositories.bzl", "rules_cc_dependencies") - -rules_cc_dependencies() - LLVM_TARGETS = select({ "@bazel_tools//src/conditions:windows": ["AMDGPU", "NVPTX"], "@bazel_tools//src/conditions:darwin": [], From f6b223862a903a7445adbc8e95c9dd67701e5c1c Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 16 Dec 2024 18:37:19 -0600 Subject: [PATCH 03/18] adapt to hedron dep --- deps/ReactantExtra/WORKSPACE | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index 42823c35..6c35de81 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -19,6 +19,27 @@ http_archive( urls = ["https://github.com/EnzymeAD/Enzyme-JAX/archive/{commit}.tar.gz".format(commit = ENZYMEXLA_COMMIT)], ) + +# Hedron's Compile Commands Extractor for Bazel +# https://github.com/hedronvision/bazel-compile-commands-extractor +http_archive( + name = "hedron_compile_commands", + + # Replace the commit hash (0e990032f3c5a866e72615cf67e5ce22186dcb97) in both places (below) with the latest (https://github.com/hedronvision/bazel-compile-commands-extractor/commits/main), rather than using the stale one here. + # Even better, set up Renovate and let it do the work for you (see "Suggestion: Updates" in the README). + url = "https://github.com/hedronvision/bazel-compile-commands-extractor/archive/4f28899228fb3ad0126897876f147ca15026151e.tar.gz", + strip_prefix = "bazel-compile-commands-extractor-4f28899228fb3ad0126897876f147ca15026151e", + # When you first run this tool, it'll recommend a sha256 hash to put here with a message like: "DEBUG: Rule 'hedron_compile_commands' indicated that a canonical reproducible form can be obtained by modifying arguments sha256 = ..." +) +load("@hedron_compile_commands//:workspace_setup.bzl", "hedron_compile_commands_setup") +hedron_compile_commands_setup() +load("@hedron_compile_commands//:workspace_setup_transitive.bzl", "hedron_compile_commands_setup_transitive") +hedron_compile_commands_setup_transitive() +load("@hedron_compile_commands//:workspace_setup_transitive_transitive.bzl", "hedron_compile_commands_setup_transitive_transitive") +hedron_compile_commands_setup_transitive_transitive() +load("@hedron_compile_commands//:workspace_setup_transitive_transitive_transitive.bzl", "hedron_compile_commands_setup_transitive_transitive_transitive") +hedron_compile_commands_setup_transitive_transitive_transitive() + load("@enzyme_ad//:workspace.bzl", "JAX_COMMIT", "JAX_SHA256", "ENZYME_COMMIT", "ENZYME_SHA256", "XLA_PATCHES") XLA_PATCHES = XLA_PATCHES + [ From 612184759b8e5768afb35046a9196d72b497b751 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 16 Dec 2024 18:42:35 -0600 Subject: [PATCH 04/18] init target --- src/XLA.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/XLA.jl b/src/XLA.jl index 00420edb..b21999f9 100644 --- a/src/XLA.jl +++ b/src/XLA.jl @@ -131,6 +131,7 @@ function __init__() end end + @ccall MLIR.API.mlir_c.RegisterCustomCallTarget("enzymexla_gpu"::Cstring, cglobal((:EnzymeGPUCustomCall, MLIR.API.mlir_c)), "CUDA") return nothing end From d62d01ed98924412722e4a0fc27a14a6adaf8797 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 16 Dec 2024 18:48:34 -0600 Subject: [PATCH 05/18] fixup --- deps/ReactantExtra/API.cpp | 1 - ext/ReactantCUDAExt.jl | 25 ++++++++++++++++++++++--- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index 3ae7a7eb..e3068aa5 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -50,7 +50,6 @@ #include "xla/pjrt/pjrt_executable.h" #include "xla/pjrt/pjrt_api.h" #include "xla/pjrt/pjrt_c_api_client.h" -#include "xla/service/cpu/simple_orc_jit.h" #include "xla/python/ifrt/hlo/hlo_program.h" #include "llvm/MC/TargetRegistry.h" diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index b38b5500..2f87c97c 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -352,10 +352,29 @@ Reactant.@reactant_override @noinline function (func::LLVMFunc{F,tt})(args...; c fname = Reactant.TracedUtils.get_attribute_by_name(func.entry, "sym_name") # Force public for now while we don't have real users - MLIR.IR.rmattr!(func.entry, "sym_visibility") + # MLIR.IR.rmattr!(func.entry, "sym_visibility") + + op_ty_results = IR.Type[result_0...,] + operands = Value[inputs...,] + owned_regions = MLIR.IR.Region[] + successors = MLIR.IR.Block[] + attributes = MLIR.IR.NamedAttribute[ + MLIR.IR.namedattribute("fn", fname), + MLIR.IR.namedattribute("output_operand_aliases", output_operand_aliases) + ] + + location = MLIR.IR.Location() + call = MLIR.IR.create_operation( + "stablehlo.custom_call", + location; + mlir_args, + owned_regions, + successors, + attributes, + results=restys + result_inference=false, + ) - call = MLIR.Dialects.stablehlo.custom_call(mlir_args; result_0=restys, call_target_name="reactant_gpu_call", output_operand_aliases, backend_config=MLIR.IR.Attribute(fname)) - # call = MLIR.Dialects.stablehlo.custom_call(mlir_args; result_0=restys, call_target_name="reactant_gpu_call", output_operand_aliases, backend_config=MLIR.IR.Attribute(func.mod)) for (i, res) in enumerate(rarrays) res.mlir_data = transpose_val(MLIR.IR.result(call, i)) end From c3192d92725d8c404eb504d5fe3ca162dfb1f561 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 16 Dec 2024 18:57:33 -0600 Subject: [PATCH 06/18] additional fixups --- deps/ReactantExtra/BUILD | 3 +++ ext/ReactantCUDAExt.jl | 10 ++++++++-- src/Compiler.jl | 19 ++++++++++++++++++- test/runtests.jl | 34 +--------------------------------- 4 files changed, 30 insertions(+), 36 deletions(-) diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index 559533c7..b5daf012 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -470,6 +470,9 @@ cc_library( "@xla//xla/pjrt:pjrt_c_api_client", "@xla//xla/pjrt/cpu:cpu_client", + "@xla//xla/service:metrics_proto_cc", + "@xla//xla/service:metrics_proto_cc_impl", + "@xla//xla/service/cpu:cpu_compiler", "@xla//xla/stream_executor/tpu:tpu_on_demand_compiler", "@xla//xla/stream_executor/tpu:tpu_executor", diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index 2f87c97c..61306986 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -355,7 +355,13 @@ Reactant.@reactant_override @noinline function (func::LLVMFunc{F,tt})(args...; c # MLIR.IR.rmattr!(func.entry, "sym_visibility") op_ty_results = IR.Type[result_0...,] - operands = Value[inputs...,] + operands = MLIR.IR.Value[] + for idx in (blockdim.x, blockdim.y, blockdim.z, threaddim.x, threaddim.y, threaddim.z) + push!(operands, TracedUtils.promote_to(TracedRNumber{Int}, idx).mlir_data) + end + for arg in mlir_ir_args + push!(operands, arg) + end owned_regions = MLIR.IR.Region[] successors = MLIR.IR.Block[] attributes = MLIR.IR.NamedAttribute[ @@ -367,7 +373,7 @@ Reactant.@reactant_override @noinline function (func::LLVMFunc{F,tt})(args...; c call = MLIR.IR.create_operation( "stablehlo.custom_call", location; - mlir_args, + operands, owned_regions, successors, attributes, diff --git a/src/Compiler.jl b/src/Compiler.jl index cc32a90b..b2fa8cc4 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -305,6 +305,22 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true) optimize isa Bool && (optimize = ifelse(optimize, :all, :none)) if optimize === :all + run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ",")) + run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false) + run_pass_pipeline!( + mod, + join( + [ + "canonicalize", + "remove-unnecessary-enzyme-ops", + "enzyme-simplify-math", + opt_passes, + "lower-kernel" + ], + ',', + ), + ) + elseif optimize === :before_kernel run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ",")) run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false) run_pass_pipeline!( @@ -340,6 +356,7 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true) "remove-unnecessary-enzyme-ops", "enzyme-simplify-math", opt_passes, + "lower-kernel" ], ',', ), @@ -348,7 +365,7 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true) run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ",")) run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false) run_pass_pipeline!( - mod, "canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math" + mod, "canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,lower-kernel" ) elseif optimize !== :none error("Invalid optimize option: $(Meta.quot(optimize))") diff --git a/test/runtests.jl b/test/runtests.jl index fddc963c..297df1a3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -41,36 +41,4 @@ end const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) -@testset "Reactant.jl Tests" begin - if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "core" - @safetestset "Layout" include("layout.jl") - @safetestset "Tracing" include("tracing.jl") - @safetestset "Basic" include("basic.jl") - @safetestset "Autodiff" include("autodiff.jl") - @safetestset "Complex" include("complex.jl") - @safetestset "Broadcast" include("bcast.jl") - @safetestset "Struct" include("struct.jl") - @safetestset "Closure" include("closure.jl") - @safetestset "Compile" include("compile.jl") - @safetestset "Buffer Donation" include("buffer_donation.jl") - @safetestset "Shortcuts to MLIR ops" include("ops.jl") - @safetestset "Wrapped Arrays" include("wrapped_arrays.jl") - @safetestset "Control Flow" include("control_flow.jl") - end - - if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration" - @safetestset "Linear Algebra" include("integration/linear_algebra.jl") - @safetestset "AbstractFFTs" include("integration/fft.jl") - end - - if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "neural_networks" - @testset "Neural Networks" begin - @safetestset "NNlib Primitives" include("nn/nnlib.jl") - @safetestset "Flux.jl Integration" include("nn/flux.jl") - if Sys.islinux() - @safetestset "LuxLib Primitives" include("nn/luxlib.jl") - @safetestset "Lux Integration" include("nn/lux.jl") - end - end - end -end +include("cuda.jl") From e86fac616dafe809a661df99d25c1e839f08411c Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 16 Dec 2024 19:03:21 -0600 Subject: [PATCH 07/18] fixup --- src/XLA.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/XLA.jl b/src/XLA.jl index b21999f9..3fcf4423 100644 --- a/src/XLA.jl +++ b/src/XLA.jl @@ -131,7 +131,7 @@ function __init__() end end - @ccall MLIR.API.mlir_c.RegisterCustomCallTarget("enzymexla_gpu"::Cstring, cglobal((:EnzymeGPUCustomCall, MLIR.API.mlir_c)), "CUDA") + @ccall MLIR.API.mlir_c.RegisterCustomCallTarget("enzymexla_gpu"::Cstring, cglobal((:EnzymeGPUCustomCall, MLIR.API.mlir_c))::Ptr{Cvoid}, "CUDA"::Cstring)::Cvoid return nothing end From 6cea3dfbc1d6608fe09091510e09eef12a529677 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 16 Dec 2024 19:05:26 -0600 Subject: [PATCH 08/18] parse --- ext/ReactantCUDAExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index 61306986..6b9e6360 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -377,7 +377,7 @@ Reactant.@reactant_override @noinline function (func::LLVMFunc{F,tt})(args...; c owned_regions, successors, attributes, - results=restys + results=restys, result_inference=false, ) From e3f4df202bc3ab2f44aa4f60eaa49c972e30610d Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 16 Dec 2024 19:13:09 -0600 Subject: [PATCH 09/18] overlay --- ext/ReactantCUDAExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index 6b9e6360..414876e6 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -317,7 +317,7 @@ function transpose_val(val) return MLIR.IR.result(MLIR.Dialects.stablehlo.transpose(val; permutation=attr), 1) end -Reactant.@reactant_override @noinline function (func::LLVMFunc{F,tt})(args...; convert=Val(false), blocks::CuDim=1, threads::CuDim=1, +Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(args...; convert=Val(false), blocks::CuDim=1, threads::CuDim=1, cooperative::Bool=false, shmem::Integer=0, call_kwargs...) where{F, tt} @show call_kwargs From 019a69052a27dce33f080421967f862d601f0019 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 16 Dec 2024 19:28:44 -0600 Subject: [PATCH 10/18] fix --- deps/ReactantExtra/API.cpp | 12 +++++------- deps/ReactantExtra/WORKSPACE | 2 +- ext/ReactantCUDAExt.jl | 15 +++++++-------- 3 files changed, 13 insertions(+), 16 deletions(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index e3068aa5..c08b7921 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -459,15 +459,13 @@ extern "C" void XLAExecute(xla::PjRtLoadedExecutable* exec, int num_args, PjRtBu } } +void prepareRegistry(mlir::DialectRegistry ®istry); + extern "C" void RegisterDialects(MlirContext cctx) { mlir::MLIRContext &context = *unwrap(cctx); - context.loadDialect(); - context.loadDialect(); - context.loadDialect(); - context.loadDialect(); - context.loadDialect(); - context.loadDialect(); - context.loadDialect(); + DialectRegistry registry; + prepareRegistry(registry); + context.appendDialectRegistry(registry); } #include "mlir/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.h" diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index 6c35de81..88db2332 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -9,7 +9,7 @@ http_archive( urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)], ) -ENZYMEXLA_COMMIT = "e059f8c6e559c92846b110537c9a8b53f65ec053" +ENZYMEXLA_COMMIT = "fb483c06f697990c60cc3c0bda7fb1d730fca3de" ENZYMEXLA_SHA256 = "" http_archive( diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index 414876e6..01df663b 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -354,24 +354,23 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(args...; co # Force public for now while we don't have real users # MLIR.IR.rmattr!(func.entry, "sym_visibility") - op_ty_results = IR.Type[result_0...,] operands = MLIR.IR.Value[] - for idx in (blockdim.x, blockdim.y, blockdim.z, threaddim.x, threaddim.y, threaddim.z) - push!(operands, TracedUtils.promote_to(TracedRNumber{Int}, idx).mlir_data) + for idx in (blockdim.x, blockdim.y, blockdim.z, threaddim.x, threaddim.y, threaddim.z, shmem) + push!(operands, Reactant.TracedUtils.promote_to(Reactant.TracedRNumber{Int}, idx).mlir_data) end - for arg in mlir_ir_args + for arg in mlir_args push!(operands, arg) end owned_regions = MLIR.IR.Region[] successors = MLIR.IR.Block[] attributes = MLIR.IR.NamedAttribute[ - MLIR.IR.namedattribute("fn", fname), - MLIR.IR.namedattribute("output_operand_aliases", output_operand_aliases) + MLIR.IR.NamedAttribute("fn", fname), + MLIR.IR.NamedAttribute("output_operand_aliases", output_operand_aliases) ] location = MLIR.IR.Location() call = MLIR.IR.create_operation( - "stablehlo.custom_call", + "enzymexla.kern_call", location; operands, owned_regions, @@ -404,7 +403,7 @@ function compiler_cache(ctx::MLIR.IR.Context) return cache end -Reactant.@reactant_override @noinline function CUDA.cufunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT} +Reactant.@reactant_overlay @noinline function CUDA.cufunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT} res = Base.@lock CUDA.cufunction_lock begin # compile the function cache = compiler_cache(MLIR.IR.context()) From 7a4a403041a43c7a8cae26308a6dcb660b52d3bf Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 16 Dec 2024 19:33:59 -0600 Subject: [PATCH 11/18] registry utils --- deps/ReactantExtra/BUILD | 1 + test/cuda.jl | 1 + 2 files changed, 2 insertions(+) diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index b5daf012..44953fd3 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -358,6 +358,7 @@ cc_library( ], ) + [ + "@enzyme_ad//src/enzyme_ad/jax:RegistryUtils.cpp", # "@com_google_protobuf//:src/google/protobuf/io/coded_stream.cc", "@xla//xla:xla.pb.cc", "@xla//xla:xla_data.pb.cc", diff --git a/test/cuda.jl b/test/cuda.jl index 05d0777c..7fcd1150 100644 --- a/test/cuda.jl +++ b/test/cuda.jl @@ -19,6 +19,7 @@ end oA = collect(1:1:64) A = Reactant.to_rarray(oA) @show @code_hlo optimize=false square!(A) + @show @code_hlo optimize=:before_kernel square!(A) @show @code_hlo square!(A) func = @compile square!(A) @test all(Array(A) .≈ (oA .* oA)) From 8fef9ff032ee7c44fe686c98ef358d71163fd01a Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 16 Dec 2024 19:40:06 -0600 Subject: [PATCH 12/18] callname --- ext/ReactantCUDAExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index 01df663b..fb0a084c 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -370,7 +370,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(args...; co location = MLIR.IR.Location() call = MLIR.IR.create_operation( - "enzymexla.kern_call", + "enzymexla.kernel_call", location; operands, owned_regions, From fe43909cda51cb00cfecd8f0e681cb663d9de648 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 16 Dec 2024 19:50:44 -0600 Subject: [PATCH 13/18] reg --- deps/ReactantExtra/API.cpp | 32 +------------------------------- 1 file changed, 1 insertion(+), 31 deletions(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index c08b7921..f603105a 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -473,34 +473,10 @@ extern "C" void RegisterDialects(MlirContext cctx) { #include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h" extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) { mlir::DialectRegistry ®istry = *unwrap(creg); - - // Register MLIR stuff - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); - - registry.insert(); + prepareRegistry(registry); mlir::registerenzymePasses(); regsiterenzymeXLAPasses(); - mlir::enzyme::registerXLAAutoDiffInterfaces(registry); - - mlir::func::registerInlinerExtension(registry); // Register the standard passes we want. mlir::registerCSEPass(); @@ -517,7 +493,6 @@ extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) { mlir::registerLLVMDialectImport(registry); mlir::registerNVVMDialectImport(registry); - mlir::LLVM::registerInlinerInterface(registry); /* registry.addExtension(+[](MLIRContext *ctx, LLVM::LLVMDialect *dialect) { @@ -535,15 +510,10 @@ extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) { }); */ - // Register the autodiff interface implementations for upstream dialects. - enzyme::registerCoreDialectAutodiffInterfaces(registry); - // Transform dialect and extensions. mlir::transform::registerInterpreterPass(); - mlir::linalg::registerTransformDialectExtension(registry); mlir::enzyme::registerGenerateApplyPatternsPass(); mlir::enzyme::registerRemoveTransformPass(); - mlir::enzyme::registerEnzymeJaxTransformExtension(registry); } From be1951971be0bbb1499a9cd375f7d43d8ba43f45 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 16 Dec 2024 20:10:38 -0600 Subject: [PATCH 14/18] fix --- deps/ReactantExtra/API.cpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index f603105a..8be25aa8 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -466,6 +466,14 @@ extern "C" void RegisterDialects(MlirContext cctx) { DialectRegistry registry; prepareRegistry(registry); context.appendDialectRegistry(registry); + context.loadDialect(); + context.loadDialect(); + context.loadDialect(); + context.loadDialect(); + context.loadDialect(); + context.loadDialect(); + context.loadDialect(); + context.loadDialect(); } #include "mlir/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.h" From b0f679f430d31af7decff3333b40b3e81cd0a132 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 16 Dec 2024 20:25:26 -0600 Subject: [PATCH 15/18] fix bld --- deps/ReactantExtra/API.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index 8be25aa8..f374e339 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -11,6 +11,7 @@ #include "Enzyme/MLIR/Passes/Passes.h" #include "src/enzyme_ad/jax/Implementations/XLADerivatives.h" #include "src/enzyme_ad/jax/Passes/Passes.h" +#include "src/enzyme_ad/jax/Dialect/Dialect.h" #include "src/enzyme_ad/jax/TransformOps/TransformOps.h" #include "mlir/Conversion/Passes.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" @@ -36,6 +37,7 @@ #include "mlir/Transforms/Passes.h" #include "llvm/Support/TargetSelect.h" +#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h" #include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/StablehloOps.h" @@ -500,7 +502,7 @@ extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) { mlir::registerLLVMDialectImport(registry); mlir::registerNVVMDialectImport(registry); - + mlir::LLVM::registerInlinerInterface(registry); /* registry.addExtension(+[](MLIRContext *ctx, LLVM::LLVMDialect *dialect) { From 1cd21ff3a042de9954c06b26de94219cef6614fe Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 16 Dec 2024 21:01:00 -0600 Subject: [PATCH 16/18] cleanup --- deps/ReactantExtra/WORKSPACE | 2 +- ext/ReactantCUDAExt.jl | 13 ++----------- 2 files changed, 3 insertions(+), 12 deletions(-) diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index 88db2332..2e04c1ea 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -9,7 +9,7 @@ http_archive( urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)], ) -ENZYMEXLA_COMMIT = "fb483c06f697990c60cc3c0bda7fb1d730fca3de" +ENZYMEXLA_COMMIT = "3ce6c51887642fa85313dd17a4bbde227e109a35" ENZYMEXLA_SHA256 = "" http_archive( diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index fb0a084c..5557a59e 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -319,8 +319,6 @@ end Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(args...; convert=Val(false), blocks::CuDim=1, threads::CuDim=1, cooperative::Bool=false, shmem::Integer=0, call_kwargs...) where{F, tt} - @show call_kwargs - blockdim = CUDA.CuDim3(blocks) threaddim = CUDA.CuDim3(threads) @@ -364,8 +362,8 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(args...; co owned_regions = MLIR.IR.Region[] successors = MLIR.IR.Block[] attributes = MLIR.IR.NamedAttribute[ - MLIR.IR.NamedAttribute("fn", fname), - MLIR.IR.NamedAttribute("output_operand_aliases", output_operand_aliases) + MLIR.IR.NamedAttribute("fn", MLIR.IR.FlatSymbolRefAttribute(Base.String(fname))), + MLIR.IR.NamedAttribute("output_operand_aliases", MLIR.IR.Attribute(output_operand_aliases)) ] location = MLIR.IR.Location() @@ -383,13 +381,6 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(args...; co for (i, res) in enumerate(rarrays) res.mlir_data = transpose_val(MLIR.IR.result(call, i)) end - - @show blockdim - @show threaddim - #CUDA.cuLaunchKernel(f, - # blockdim.x, blockdim.y, blockdim.z, - # threaddim.x, threaddim.y, threaddim.z, - # shmem, stream, kernelParams, C_NULL) end # cache of compilation caches, per context From 8b5d1316b48bce9245640094478488a622b4aa67 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 16 Dec 2024 21:44:45 -0600 Subject: [PATCH 17/18] no pip --- deps/ReactantExtra/WORKSPACE | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index 2e04c1ea..d43833fd 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -9,7 +9,7 @@ http_archive( urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)], ) -ENZYMEXLA_COMMIT = "3ce6c51887642fa85313dd17a4bbde227e109a35" +ENZYMEXLA_COMMIT = "dea63960da134128b152c1624d1425048cd9fb3a" ENZYMEXLA_SHA256 = "" http_archive( @@ -105,7 +105,7 @@ http_archive( load("@xla//third_party/py:python_init_rules.bzl", "python_init_rules") python_init_rules() - + load("@xla//third_party/py:python_init_repositories.bzl", "python_init_repositories") python_init_repositories( requirements = { @@ -116,23 +116,23 @@ python_init_repositories( "3.13": "//build:requirements_lock_3_13.txt", }, ) - + load("@xla//third_party/py:python_init_toolchains.bzl", "python_init_toolchains") python_init_toolchains() - -load("@xla//third_party/py:python_init_pip.bzl", "python_init_pip") -python_init_pip() - -load("@xla//third_party/py:python_init_rules.bzl", "python_init_rules") -python_init_rules() - -load("@rules_python//python:repositories.bzl", "py_repositories") - -py_repositories() - -load("@rules_python//python/pip_install:repositories.bzl", "pip_install_dependencies") - -pip_install_dependencies() +# +# load("@xla//third_party/py:python_init_pip.bzl", "python_init_pip") +# python_init_pip() +# +# load("@xla//third_party/py:python_init_rules.bzl", "python_init_rules") +# python_init_rules() +# +# load("@rules_python//python:repositories.bzl", "py_repositories") +# +# py_repositories() +# +# load("@rules_python//python/pip_install:repositories.bzl", "pip_install_dependencies") +# +# pip_install_dependencies() http_archive( name = "enzyme", From 3072670eeaebd97f09d3bcf9937b05f91a02ba62 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 16 Dec 2024 22:05:58 -0600 Subject: [PATCH 18/18] fix --- src/Compiler.jl | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index b2fa8cc4..39eea231 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -1,5 +1,7 @@ module Compiler +using Reactant_jll + import ..Reactant: Reactant, MLIR, @@ -304,6 +306,11 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true) optimize isa Bool && (optimize = ifelse(optimize, :all, :none)) + toolkit = "" + if isdefined(Reactant_jll, :ptxas_path) + toolkit = Reactant_jll.ptxas_path[1:end-length("/bin/ptxas")] + end + kern = "lower-kernel{toolkitPath=$toolkit}" if optimize === :all run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ",")) run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false) @@ -315,7 +322,7 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true) "remove-unnecessary-enzyme-ops", "enzyme-simplify-math", opt_passes, - "lower-kernel" + kern ], ',', ), @@ -356,7 +363,7 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true) "remove-unnecessary-enzyme-ops", "enzyme-simplify-math", opt_passes, - "lower-kernel" + kern ], ',', ), @@ -365,7 +372,7 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true) run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ",")) run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false) run_pass_pipeline!( - mod, "canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,lower-kernel" + mod, "canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,"*kern ) elseif optimize !== :none error("Invalid optimize option: $(Meta.quot(optimize))")