From 019a69052a27dce33f080421967f862d601f0019 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Mon, 16 Dec 2024 19:28:44 -0600 Subject: [PATCH] fix --- deps/ReactantExtra/API.cpp | 12 +++++------- deps/ReactantExtra/WORKSPACE | 2 +- ext/ReactantCUDAExt.jl | 15 +++++++-------- 3 files changed, 13 insertions(+), 16 deletions(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index e3068aa5..c08b7921 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -459,15 +459,13 @@ extern "C" void XLAExecute(xla::PjRtLoadedExecutable* exec, int num_args, PjRtBu } } +void prepareRegistry(mlir::DialectRegistry &registry); + extern "C" void RegisterDialects(MlirContext cctx) { mlir::MLIRContext &context = *unwrap(cctx); - context.loadDialect(); - context.loadDialect(); - context.loadDialect(); - context.loadDialect(); - context.loadDialect(); - context.loadDialect(); - context.loadDialect(); + DialectRegistry registry; + prepareRegistry(registry); + context.appendDialectRegistry(registry); } #include "mlir/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.h" diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index 6c35de81..88db2332 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -9,7 +9,7 @@ http_archive( urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)], ) -ENZYMEXLA_COMMIT = "e059f8c6e559c92846b110537c9a8b53f65ec053" +ENZYMEXLA_COMMIT = "fb483c06f697990c60cc3c0bda7fb1d730fca3de" ENZYMEXLA_SHA256 = "" http_archive( diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index 414876e6..01df663b 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -354,24 +354,23 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(args...; co # Force public for now while we don't have real users # MLIR.IR.rmattr!(func.entry, "sym_visibility") - op_ty_results = IR.Type[result_0...,] operands = MLIR.IR.Value[] - 
for idx in (blockdim.x, blockdim.y, blockdim.z, threaddim.x, threaddim.y, threaddim.z) - push!(operands, TracedUtils.promote_to(TracedRNumber{Int}, idx).mlir_data) + for idx in (blockdim.x, blockdim.y, blockdim.z, threaddim.x, threaddim.y, threaddim.z, shmem) + push!(operands, Reactant.TracedUtils.promote_to(Reactant.TracedRNumber{Int}, idx).mlir_data) end - for arg in mlir_ir_args + for arg in mlir_args push!(operands, arg) end owned_regions = MLIR.IR.Region[] successors = MLIR.IR.Block[] attributes = MLIR.IR.NamedAttribute[ - MLIR.IR.namedattribute("fn", fname), - MLIR.IR.namedattribute("output_operand_aliases", output_operand_aliases) + MLIR.IR.NamedAttribute("fn", fname), + MLIR.IR.NamedAttribute("output_operand_aliases", output_operand_aliases) ] location = MLIR.IR.Location() call = MLIR.IR.create_operation( - "stablehlo.custom_call", + "enzymexla.kern_call", location; operands, owned_regions, @@ -404,7 +403,7 @@ function compiler_cache(ctx::MLIR.IR.Context) return cache end -Reactant.@reactant_override @noinline function CUDA.cufunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT} +Reactant.@reactant_overlay @noinline function CUDA.cufunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT} res = Base.@lock CUDA.cufunction_lock begin # compile the function cache = compiler_cache(MLIR.IR.context())