diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl
index a31c07a1..ad13922f 100644
--- a/ext/ReactantCUDAExt.jl
+++ b/ext/ReactantCUDAExt.jl
@@ -214,20 +214,88 @@ struct LLVMFunc{F,tt}
 end
 
+const GPUCompiler = CUDA.GPUCompiler
+const LLVM = GPUCompiler.LLVM
+
+
+GPULowerCPUFeaturesPass() = LLVM.NewPMModulePass("GPULowerCPUFeatures", GPUCompiler.cpu_features!)
+GPULowerPTLSPass() = LLVM.NewPMModulePass("GPULowerPTLS", GPUCompiler.lower_ptls!)
+GPULowerGCFramePass() = LLVM.NewPMFunctionPass("GPULowerGCFrame", GPUCompiler.lower_gc_frame!)
+function noop_pass(x)
+    return false
+end
+function kern_pass(mod)
+    for fname in ("julia.gpu.state_getter",)
+        if LLVM.haskey(LLVM.functions(mod), fname)
+            fn = LLVM.functions(mod)[fname]
+            insts = LLVM.Instruction[]
+            for u in LLVM.uses(fn)
+                u = LLVM.user(u)
+                LLVM.replace_uses!(u, LLVM.UndefValue(LLVM.value_type(u)))
+                push!(insts, u)
+            end
+            for inst in insts
+                Reactant.Enzyme.Compiler.eraseInst(LLVM.parent(inst), inst)
+            end
+            Reactant.Enzyme.Compiler.eraseInst(mod, fn)
+        end
+    end
+
+    return true
+end
+AddKernelStatePass() = LLVM.NewPMModulePass("AddKernelStatePass", kern_pass)
+LowerKernelStatePass() = LLVM.NewPMFunctionPass("LowerKernelStatePass", noop_pass)
+CleanupKernelStatePass() = LLVM.NewPMModulePass("CleanupKernelStatePass", noop_pass)
+
 # compile to executable machine code
 function compile(job)
+    # lower to PTX
     # TODO: on 1.9, this actually creates a context. cache those.
-    modstr, image, entry = CUDA.GPUCompiler.JuliaContext() do ctx
-        asm, meta = CUDA.GPUCompiler.compile(:asm, job)
-        mod = meta.ir
-
+    modstr, image, entry = GPUCompiler.JuliaContext() do ctx
+        mod, meta = GPUCompiler.compile(:llvm, job; optimize=false, cleanup=false, validate=false)
+        GPUCompiler.optimize_module!(job, mod)
+        opt_level = 2
+        tm = GPUCompiler.llvm_machine(job.config.target)
+        LLVM.@dispose pb=LLVM.NewPMPassBuilder() begin
+            LLVM.register!(pb, GPULowerCPUFeaturesPass())
+            LLVM.register!(pb, GPULowerPTLSPass())
+            LLVM.register!(pb, GPULowerGCFramePass())
+            LLVM.register!(pb, AddKernelStatePass())
+            LLVM.register!(pb, LowerKernelStatePass())
+            LLVM.register!(pb, CleanupKernelStatePass())
+
+            LLVM.add!(pb, LLVM.NewPMModulePassManager()) do mpm
+                GPUCompiler.buildNewPMPipeline!(mpm, job, opt_level)
+            end
+            LLVM.run!(pb, mod, tm)
+        end
+        GPUCompiler.optimize_module!(job, mod)
+        LLVM.run!(CUDA.GPUCompiler.DeadArgumentEliminationPass(), mod, tm)
+
+
+        for fname in ("gpu_report_exception", "gpu_signal_exception")
+            if LLVM.haskey(LLVM.functions(mod), fname)
+                fn = LLVM.functions(mod)[fname]
+                insts = LLVM.Instruction[]
+                for u in LLVM.uses(fn)
+                    push!(insts, LLVM.user(u))
+                end
+                for inst in insts
+                    Reactant.Enzyme.Compiler.eraseInst(LLVM.parent(inst), inst)
+                end
+                Reactant.Enzyme.Compiler.eraseInst(mod, fn)
+            end
+        end
+
+        LLVM.strip_debuginfo!(mod)
         modstr = string(mod)
 
         # This is a bit weird since we're taking a module from julia's llvm into reactant's llvm version
         # it is probably safer to reparse a string using the right llvm module api, so we will do that.
-        mmod = MLIR.IR.Module(@ccall MLIR.API.mlir_c.ConvertLLVMToMLIR(mod::CUDA.LLVM.API.LLVMModuleRef, MLIR.IR.context()::MLIR.API.MlirContext)::MLIR.API.MlirModule)
+        println(string(modstr))
+        mmod = MLIR.IR.Module(@ccall MLIR.API.mlir_c.ConvertLLVMStrToMLIR(modstr::Cstring, MLIR.IR.context()::MLIR.API.MlirContext)::MLIR.API.MlirModule)
         @show mmod
 
         # check if we'll need the device runtime
@@ -461,8 +529,18 @@ Reactant.@reactant_override function CUDA.cufunction(f::F, tt::TT=Tuple{}; kwarg
         cache = compiler_cache(MLIR.IR.context())
         source = CUDA.methodinstance(F, tt)
 
-        cuda = CUDA.active_state()
-        config = CUDA.compiler_config(cuda.device; kwargs...)::CUDA.CUDACompilerConfig
+        # cuda = CUDA.active_state()
+        device = nothing # cuda.device
+        # config = CUDA.compiler_config(device; kwargs...)::CUDA.CUDACompilerConfig
+        cuda_cap=v"5.0"
+        cuda_ptx=v"6.3"
+        llvm_cap=v"5.0"
+        llvm_ptx=v"6.3"
+        kernel=true
+        always_inline=false
+        name=nothing
+        debuginfo=false
+        config = CUDA.CompilerConfig(CUDA.PTXCompilerTarget(; cap=llvm_cap, ptx=llvm_ptx, debuginfo), CUDA.CUDACompilerParams(; cap=cuda_cap, ptx=cuda_ptx); kernel, name, always_inline)
         CUDA.GPUCompiler.cached_compilation(cache, source, config, compile, link)
     end
     res
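
For reference, a minimal standalone sketch of the LLVM.jl new-pass-manager pattern the compile(job) hunk relies on (wrap a Julia callback in a named pass, register it on a pass builder, add it to a module pipeline, run it over a module). It is not part of the patch: the pass name "DemoNoop", the demo_noop callback, the inline IR, and the use of InitializeAllTargets()/JITTargetMachine() to get a host target machine (instead of GPUCompiler.llvm_machine) are illustrative assumptions.

using LLVM

# Module-pass callback: receives the LLVM.Module and returns whether it modified it
# (false = nothing changed), matching the convention of noop_pass/kern_pass above.
demo_noop(mod::LLVM.Module) = false

# Host target setup; the patch gets its TargetMachine from the compiler job instead.
LLVM.InitializeAllTargetInfos()
LLVM.InitializeAllTargets()
LLVM.InitializeAllTargetMCs()

LLVM.@dispose ctx=LLVM.Context() begin
    # A tiny module to run the pipeline over.
    mod = parse(LLVM.Module, """
        define void @kernel() {
          ret void
        }
        """)
    tm = LLVM.JITTargetMachine()
    LLVM.@dispose pb=LLVM.NewPMPassBuilder() begin
        # Register the custom pass under its name, add it to a module pipeline,
        # then run it, mirroring the register!/add!/run! sequence in the hunk.
        LLVM.register!(pb, LLVM.NewPMModulePass("DemoNoop", demo_noop))
        LLVM.add!(pb, LLVM.NewPMModulePassManager()) do mpm
            LLVM.add!(mpm, LLVM.NewPMModulePass("DemoNoop", demo_noop))
        end
        LLVM.run!(pb, mod, tm)
    end
end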