diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index ad13922f3..adf35e0aa 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -2,7 +2,7 @@ module ReactantCUDAExt using CUDA using Reactant: - Reactant, TracedRArray, AnyTracedRArray, materialize_traced_array, MLIR, TracedRNumber + Reactant, TracedRArray, AnyTracedRArray, MLIR, TracedRNumber using ReactantCore: @trace using Adapt @@ -465,7 +465,7 @@ function transpose_val(val) return MLIR.IR.result(MLIR.Dialects.stablehlo.transpose(val; permutation=attr), 1) end -function (func::LLVMFunc{F,tt})(args...; convert=Val(false), blocks::CuDim=1, threads::CuDim=1, +Reactant.@reactant_override @noinline function (func::LLVMFunc{F,tt})(args...; convert=Val(false), blocks::CuDim=1, threads::CuDim=1, cooperative::Bool=false, shmem::Integer=0, call_kwargs...) where{F, tt} @show args @show call_kwargs @@ -522,7 +522,7 @@ function compiler_cache(ctx::MLIR.IR.Context) return cache end -Reactant.@reactant_override function CUDA.cufunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT} +Reactant.@reactant_override @noinline function CUDA.cufunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT} @show "recufunction", f, tt res = Base.@lock CUDA.cufunction_lock begin # compile the function @@ -543,6 +543,7 @@ Reactant.@reactant_override function CUDA.cufunction(f::F, tt::TT=Tuple{}; kwarg config = CUDA.CompilerConfig(CUDA.PTXCompilerTarget(; cap=llvm_cap, ptx=llvm_ptx, debuginfo), CUDA.CUDACompilerParams(; cap=cuda_cap, ptx=cuda_ptx); kernel, name, always_inline) CUDA.GPUCompiler.cached_compilation(cache, source, config, compile, link) end + @show res res end diff --git a/src/utils.jl b/src/utils.jl index b65077c03..3fdd646e0 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -307,78 +307,6 @@ function call_with_reactant_generator( # No method could be found (including in our method table), bail with an error if lookup_result == nothing return stub(world, source, method_error) - tmp_min_world = Ref{UInt}(typemin(UInt)) - tmp_max_world = Ref{UInt}(typemax(UInt)) - match = ccall( - :jl_gf_invoke_lookup_worlds, - Any, - (Any, Any, Csize_t, Ref{Csize_t}, Ref{Csize_t}), - Tuple{typeof(throw_method_error),sig}, - nothing, - world, - tmp_min_world, - tmp_max_world, - ) #=mt=# - @assert match !== nothing - - # look up the method and code instance - mi = ccall( - :jl_specializations_get_linfo, - Ref{Core.MethodInstance}, - (Any, Any, Any), - match.method, - match.spec_types, - match.sparams, - ) - - ci = Core.Compiler.retrieve_code_info(mi, world)::Core.Compiler.CodeInfo - - src = copy(ci) - src.slotnames = Any[:call_with_reactant, REDUB_ARGUMENTS_NAME] - - src.edges = Any[ - ccall(:jl_method_table_for, Any, (Any,), sig)::Core.MethodTable, sig - ] - src.min_world = min_world[] - src.max_world = max_world[] - - push!(overdubbed_code, :($(Base.getindex)($(Core.Argument(2)), 1))) - push!(overdubbed_codelocs, 0) - - expr_fn = Core.SSAValue(length(overdubbed_code)) - - push!(overdubbed_code, :($(Base.lastindex)($(Core.Argument(2))))) - push!(overdubbed_codelocs, 0) - - expr_lastindex = Core.SSAValue(length(overdubbed_code)) - - push!(overdubbed_code, :(2:($expr_lastindex))) - push!(overdubbed_codelocs, 0) - - expr_slice = Core.SSAValue(length(overdubbed_code)) - - push!(overdubbed_code, :($(Base.getindex)($(Core.Argument(2)), $expr_slice))) - push!(overdubbed_codelocs, 0) - - expr_args = Core.SSAValue(length(overdubbed_code)) - - push!(overdubbed_code, :($(Base.MethodError)($expr_fn, $expr_args, $world))) - push!(overdubbed_codelocs, 0) - - expr_method = Core.SSAValue(length(overdubbed_code)) - - push!(overdubbed_code, :($(Base.throw)($expr_method))) - push!(overdubbed_codelocs, 0) - - push!(overdubbed_code, Core.ReturnNode(Core.SSAValue(length(overdubbed_code)))) - push!(overdubbed_codelocs, 0) - - src.code = overdubbed_code - src.codelocs = overdubbed_codelocs - src.ssavaluetypes = length(overdubbed_code) - src.ssaflags = [0x00 for _ in 1:length(overdubbed_code)] # XXX we need to copy flags that are set for the original code - - return src end match = lookup_result::Core.MethodMatch @@ -438,17 +366,19 @@ function call_with_reactant_generator( # Also rewrite invoke (type stable call) to be :call, since otherwise apparently # screws up type inference after this (TODO this should be fixed). any_changed = false - for (i, inst) in enumerate(ir.stmts) - @static if VERSION < v"1.11" - changed, next = rewrite_inst(inst[:inst], ir, interp) - Core.Compiler.setindex!(ir.stmts[i], next, :inst) - else - changed, next = rewrite_inst(inst[:stmt], ir, interp) - Core.Compiler.setindex!(ir.stmts[i], next, :stmt) - end - if changed - any_changed = true - Core.Compiler.setindex!(ir.stmts[i], Any, :type) + if should_rewrite_ft(args[1]) && !is_reactant_method(mi) + for (i, inst) in enumerate(ir.stmts) + @static if VERSION < v"1.11" + changed, next = rewrite_inst(inst[:inst], ir, interp) + Core.Compiler.setindex!(ir.stmts[i], next, :inst) + else + changed, next = rewrite_inst(inst[:stmt], ir, interp) + Core.Compiler.setindex!(ir.stmts[i], next, :stmt) + end + if changed + any_changed = true + Core.Compiler.setindex!(ir.stmts[i], Any, :type) + end end end