Skip to content

Commit

Permalink
continuing
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Dec 15, 2024
1 parent 3bfd8a3 commit 5a43ae8
Showing 1 changed file with 85 additions and 7 deletions.
92 changes: 85 additions & 7 deletions ext/ReactantCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -214,20 +214,88 @@ struct LLVMFunc{F,tt}
end


const GPUCompiler = CUDA.GPUCompiler
const LLVM = GPUCompiler.LLVM


GPULowerCPUFeaturesPass() = LLVM.NewPMModulePass("GPULowerCPUFeatures", GPUCompiler.cpu_features!)
GPULowerPTLSPass() = LLVM.NewPMModulePass("GPULowerPTLS", GPUCompiler.lower_ptls!)
GPULowerGCFramePass() = LLVM.NewPMFunctionPass("GPULowerGCFrame", GPUCompiler.lower_gc_frame!)
function noop_pass(x)
return false
end
function kern_pass(mod)
for fname in ("julia.gpu.state_getter",)
if LLVM.haskey(LLVM.functions(mod), fname)
fn = LLVM.functions(mod)[fname]
insts = LLVM.Instruction[]
for u in LLVM.uses(fn)
u = LLVM.user(u)
LLVM.replace_uses!(u, LLVM.UndefValue(LLVM.value_type(u)))
push!(insts, u)
end
for inst in insts
Reactant.Enzyme.Compiler.eraseInst(LLVM.parent(inst), inst)
end
Reactant.Enzyme.Compiler.eraseInst(mod, fn)
end
end

return true
end
AddKernelStatePass() = LLVM.NewPMModulePass("AddKernelStatePass", kern_pass)
LowerKernelStatePass() = LLVM.NewPMFunctionPass("LowerKernelStatePass", noop_pass)
CleanupKernelStatePass() = LLVM.NewPMModulePass("CleanupKernelStatePass", noop_pass)

# compile to executable machine code
function compile(job)

# lower to PTX
# TODO: on 1.9, this actually creates a context. cache those.
modstr, image, entry = CUDA.GPUCompiler.JuliaContext() do ctx
asm, meta = CUDA.GPUCompiler.compile(:asm, job)
mod = meta.ir

modstr, image, entry = GPUCompiler.JuliaContext() do ctx
mod, meta = GPUCompiler.compile(:llvm, job; optimize=false, cleanup=false, validate=false)
GPUCompiler.optimize_module!(job, mod)
opt_level = 2
tm = GPUCompiler.llvm_machine(job.config.target)
LLVM.@dispose pb=LLVM.NewPMPassBuilder() begin
LLVM.register!(pb, GPULowerCPUFeaturesPass())
LLVM.register!(pb, GPULowerPTLSPass())
LLVM.register!(pb, GPULowerGCFramePass())
LLVM.register!(pb, AddKernelStatePass())
LLVM.register!(pb, LowerKernelStatePass())
LLVM.register!(pb, CleanupKernelStatePass())

LLVM.add!(pb, LLVM.NewPMModulePassManager()) do mpm
GPUCompiler.buildNewPMPipeline!(mpm, job, opt_level)
end
LLVM.run!(pb, mod, tm)
end
GPUCompiler.optimize_module!(job, mod)
LLVM.run!(CUDA.GPUCompiler.DeadArgumentEliminationPass(), mod, tm)


for fname in ("gpu_report_exception", "gpu_signal_exception")
if LLVM.haskey(LLVM.functions(mod), fname)
fn = LLVM.functions(mod)[fname]
insts = LLVM.Instruction[]
for u in LLVM.uses(fn)
push!(insts, LLVM.user(u))
end
for inst in insts
Reactant.Enzyme.Compiler.eraseInst(LLVM.parent(inst), inst)
end
Reactant.Enzyme.Compiler.eraseInst(mod, fn)
end
end

LLVM.strip_debuginfo!(mod)
modstr = string(mod)

# This is a bit weird since we're taking a module from julia's llvm into reactant's llvm version
# it is probably safer to reparse a string using the right llvm module api, so we will do that.

mmod = MLIR.IR.Module(@ccall MLIR.API.mlir_c.ConvertLLVMToMLIR(mod::CUDA.LLVM.API.LLVMModuleRef, MLIR.IR.context()::MLIR.API.MlirContext)::MLIR.API.MlirModule)
println(string(modstr))
mmod = MLIR.IR.Module(@ccall MLIR.API.mlir_c.ConvertLLVMStrToMLIR(modstr::Cstring, MLIR.IR.context()::MLIR.API.MlirContext)::MLIR.API.MlirModule)
@show mmod

# check if we'll need the device runtime
Expand Down Expand Up @@ -461,8 +529,18 @@ Reactant.@reactant_override function CUDA.cufunction(f::F, tt::TT=Tuple{}; kwarg
cache = compiler_cache(MLIR.IR.context())
source = CUDA.methodinstance(F, tt)

cuda = CUDA.active_state()
config = CUDA.compiler_config(cuda.device; kwargs...)::CUDA.CUDACompilerConfig
# cuda = CUDA.active_state()
device = nothing # cuda.device
# config = CUDA.compiler_config(device; kwargs...)::CUDA.CUDACompilerConfig
cuda_cap=v"5.0"
cuda_ptx=v"6.3"
llvm_cap=v"5.0"
llvm_ptx=v"6.3"
kernel=true
always_inline=false
name=nothing
debuginfo=false
config = CUDA.CompilerConfig(CUDA.PTXCompilerTarget(; cap=llvm_cap, ptx=llvm_ptx, debuginfo), CUDA.CUDACompilerParams(; cap=cuda_cap, ptx=cuda_ptx); kernel, name, always_inline)
CUDA.GPUCompiler.cached_compilation(cache, source, config, compile, link)
end
res
Expand Down

0 comments on commit 5a43ae8

Please sign in to comment.