Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
William Moses committed Dec 15, 2024
1 parent 8a0bfb0 commit 0c61f5d
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 86 deletions.
7 changes: 4 additions & 3 deletions ext/ReactantCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module ReactantCUDAExt

using CUDA
using Reactant:
Reactant, TracedRArray, AnyTracedRArray, materialize_traced_array, MLIR, TracedRNumber
Reactant, TracedRArray, AnyTracedRArray, MLIR, TracedRNumber
using ReactantCore: @trace

using Adapt
Expand Down Expand Up @@ -465,7 +465,7 @@ function transpose_val(val)
return MLIR.IR.result(MLIR.Dialects.stablehlo.transpose(val; permutation=attr), 1)
end

function (func::LLVMFunc{F,tt})(args...; convert=Val(false), blocks::CuDim=1, threads::CuDim=1,
Reactant.@reactant_override @noinline function (func::LLVMFunc{F,tt})(args...; convert=Val(false), blocks::CuDim=1, threads::CuDim=1,
cooperative::Bool=false, shmem::Integer=0, call_kwargs...) where{F, tt}
@show args
@show call_kwargs
Expand Down Expand Up @@ -522,7 +522,7 @@ function compiler_cache(ctx::MLIR.IR.Context)
return cache
end

Reactant.@reactant_override function CUDA.cufunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT}
Reactant.@reactant_override @noinline function CUDA.cufunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT}
@show "recufunction", f, tt
res = Base.@lock CUDA.cufunction_lock begin
# compile the function
Expand All @@ -543,6 +543,7 @@ Reactant.@reactant_override function CUDA.cufunction(f::F, tt::TT=Tuple{}; kwarg
config = CUDA.CompilerConfig(CUDA.PTXCompilerTarget(; cap=llvm_cap, ptx=llvm_ptx, debuginfo), CUDA.CUDACompilerParams(; cap=cuda_cap, ptx=cuda_ptx); kernel, name, always_inline)
CUDA.GPUCompiler.cached_compilation(cache, source, config, compile, link)
end
@show res
res
end

Expand Down
96 changes: 13 additions & 83 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -307,78 +307,6 @@ function call_with_reactant_generator(
# No method could be found (including in our method table), bail with an error
if lookup_result == nothing
return stub(world, source, method_error)
tmp_min_world = Ref{UInt}(typemin(UInt))
tmp_max_world = Ref{UInt}(typemax(UInt))
match = ccall(
:jl_gf_invoke_lookup_worlds,
Any,
(Any, Any, Csize_t, Ref{Csize_t}, Ref{Csize_t}),
Tuple{typeof(throw_method_error),sig},
nothing,
world,
tmp_min_world,
tmp_max_world,
) #=mt=#
@assert match !== nothing

# look up the method and code instance
mi = ccall(
:jl_specializations_get_linfo,
Ref{Core.MethodInstance},
(Any, Any, Any),
match.method,
match.spec_types,
match.sparams,
)

ci = Core.Compiler.retrieve_code_info(mi, world)::Core.Compiler.CodeInfo

src = copy(ci)
src.slotnames = Any[:call_with_reactant, REDUB_ARGUMENTS_NAME]

src.edges = Any[
ccall(:jl_method_table_for, Any, (Any,), sig)::Core.MethodTable, sig
]
src.min_world = min_world[]
src.max_world = max_world[]

push!(overdubbed_code, :($(Base.getindex)($(Core.Argument(2)), 1)))
push!(overdubbed_codelocs, 0)

expr_fn = Core.SSAValue(length(overdubbed_code))

push!(overdubbed_code, :($(Base.lastindex)($(Core.Argument(2)))))
push!(overdubbed_codelocs, 0)

expr_lastindex = Core.SSAValue(length(overdubbed_code))

push!(overdubbed_code, :(2:($expr_lastindex)))
push!(overdubbed_codelocs, 0)

expr_slice = Core.SSAValue(length(overdubbed_code))

push!(overdubbed_code, :($(Base.getindex)($(Core.Argument(2)), $expr_slice)))
push!(overdubbed_codelocs, 0)

expr_args = Core.SSAValue(length(overdubbed_code))

push!(overdubbed_code, :($(Base.MethodError)($expr_fn, $expr_args, $world)))
push!(overdubbed_codelocs, 0)

expr_method = Core.SSAValue(length(overdubbed_code))

push!(overdubbed_code, :($(Base.throw)($expr_method)))
push!(overdubbed_codelocs, 0)

push!(overdubbed_code, Core.ReturnNode(Core.SSAValue(length(overdubbed_code))))
push!(overdubbed_codelocs, 0)

src.code = overdubbed_code
src.codelocs = overdubbed_codelocs
src.ssavaluetypes = length(overdubbed_code)
src.ssaflags = [0x00 for _ in 1:length(overdubbed_code)] # XXX we need to copy flags that are set for the original code

return src
end

match = lookup_result::Core.MethodMatch
Expand Down Expand Up @@ -438,17 +366,19 @@ function call_with_reactant_generator(
# Also rewrite invoke (type stable call) to be :call, since otherwise apparently
# screws up type inference after this (TODO this should be fixed).
any_changed = false
for (i, inst) in enumerate(ir.stmts)
@static if VERSION < v"1.11"
changed, next = rewrite_inst(inst[:inst], ir, interp)
Core.Compiler.setindex!(ir.stmts[i], next, :inst)
else
changed, next = rewrite_inst(inst[:stmt], ir, interp)
Core.Compiler.setindex!(ir.stmts[i], next, :stmt)
end
if changed
any_changed = true
Core.Compiler.setindex!(ir.stmts[i], Any, :type)
if should_rewrite_ft(args[1]) && !is_reactant_method(mi)
for (i, inst) in enumerate(ir.stmts)
@static if VERSION < v"1.11"
changed, next = rewrite_inst(inst[:inst], ir, interp)
Core.Compiler.setindex!(ir.stmts[i], next, :inst)
else
changed, next = rewrite_inst(inst[:stmt], ir, interp)
Core.Compiler.setindex!(ir.stmts[i], next, :stmt)
end
if changed
any_changed = true
Core.Compiler.setindex!(ir.stmts[i], Any, :type)
end
end
end

Expand Down

0 comments on commit 0c61f5d

Please sign in to comment.