Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Dec 17, 2024
1 parent e3f4df2 commit 019a690
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 16 deletions.
12 changes: 5 additions & 7 deletions deps/ReactantExtra/API.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -459,15 +459,13 @@ extern "C" void XLAExecute(xla::PjRtLoadedExecutable* exec, int num_args, PjRtBu
}
}

void prepareRegistry(mlir::DialectRegistry &registry);

extern "C" void RegisterDialects(MlirContext cctx) {
mlir::MLIRContext &context = *unwrap(cctx);
context.loadDialect<mlir::arith::ArithDialect>();
context.loadDialect<mlir::enzyme::EnzymeDialect>();
context.loadDialect<mlir::tensor::TensorDialect>();
context.loadDialect<mlir::func::FuncDialect>();
context.loadDialect<mlir::mhlo::MhloDialect>();
context.loadDialect<mlir::stablehlo::StablehloDialect>();
context.loadDialect<mlir::chlo::ChloDialect>();
DialectRegistry registry;
prepareRegistry(registry);
context.appendDialectRegistry(registry);
}

#include "mlir/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.h"
Expand Down
2 changes: 1 addition & 1 deletion deps/ReactantExtra/WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ http_archive(
urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)],
)

ENZYMEXLA_COMMIT = "e059f8c6e559c92846b110537c9a8b53f65ec053"
ENZYMEXLA_COMMIT = "fb483c06f697990c60cc3c0bda7fb1d730fca3de"
ENZYMEXLA_SHA256 = ""

http_archive(
Expand Down
15 changes: 7 additions & 8 deletions ext/ReactantCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -354,24 +354,23 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(args...; co
# Force public for now while we don't have real users
# MLIR.IR.rmattr!(func.entry, "sym_visibility")

op_ty_results = IR.Type[result_0...,]
operands = MLIR.IR.Value[]
for idx in (blockdim.x, blockdim.y, blockdim.z, threaddim.x, threaddim.y, threaddim.z)
push!(operands, TracedUtils.promote_to(TracedRNumber{Int}, idx).mlir_data)
for idx in (blockdim.x, blockdim.y, blockdim.z, threaddim.x, threaddim.y, threaddim.z, shmem)
push!(operands, Reactant.TracedUtils.promote_to(Reactant.TracedRNumber{Int}, idx).mlir_data)
end
for arg in mlir_ir_args
for arg in mlir_args
push!(operands, arg)
end
owned_regions = MLIR.IR.Region[]
successors = MLIR.IR.Block[]
attributes = MLIR.IR.NamedAttribute[
MLIR.IR.namedattribute("fn", fname),
MLIR.IR.namedattribute("output_operand_aliases", output_operand_aliases)
MLIR.IR.NamedAttribute("fn", fname),
MLIR.IR.NamedAttribute("output_operand_aliases", output_operand_aliases)
]

location = MLIR.IR.Location()
call = MLIR.IR.create_operation(
"stablehlo.custom_call",
"enzymexla.kern_call",
location;
operands,
owned_regions,
Expand Down Expand Up @@ -404,7 +403,7 @@ function compiler_cache(ctx::MLIR.IR.Context)
return cache
end

Reactant.@reactant_override @noinline function CUDA.cufunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT}
Reactant.@reactant_overlay @noinline function CUDA.cufunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT}
res = Base.@lock CUDA.cufunction_lock begin
# compile the function
cache = compiler_cache(MLIR.IR.context())
Expand Down

0 comments on commit 019a690

Please sign in to comment.