From 715975645e1f2f41a6b0890a4d81df2b4c436939 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 29 Nov 2024 17:49:17 -0500 Subject: [PATCH 01/78] WIP: kernels --- Project.toml | 3 +++ test/Project.toml | 1 + test/cuda.jl | 23 +++++++++++++++++++++++ test/runtests.jl | 21 +++++++++++---------- 4 files changed, 38 insertions(+), 10 deletions(-) create mode 100644 test/cuda.jl diff --git a/Project.toml b/Project.toml index ac37e645a..a5d243705 100644 --- a/Project.toml +++ b/Project.toml @@ -21,6 +21,7 @@ Scratch = "6c6a2e73-6563-6170-7368-637461726353" [weakdeps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df" @@ -31,6 +32,7 @@ path = "lib/ReactantCore" [extensions] ReactantAbstractFFTsExt = "AbstractFFTs" ReactantArrayInterfaceExt = "ArrayInterface" +ReactantCUDAExt = "CUDA" ReactantNNlibExt = "NNlib" ReactantStatisticsExt = "Statistics" ReactantYaoBlocksExt = "YaoBlocks" @@ -58,4 +60,5 @@ julia = "1.10" [extras] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" diff --git a/test/Project.toml b/test/Project.toml index 4b50a487f..9956337ea 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,6 +1,7 @@ [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" diff --git a/test/cuda.jl b/test/cuda.jl new file mode 100644 index 000000000..2475e6275 --- /dev/null +++ b/test/cuda.jl @@ -0,0 +1,23 @@ +using Reactant +using Test +using CUDA + +function square_kernel!(x) + i = threadIdx().x + x[i] *= x[i] + sync_threads() + return nothing +end + +# basic squaring on GPU +function square!(x) + @cuda blocks = 1 threads = length(x) square_kernel!(x) + return nothing +end + +@testset "Square Kernel" begin + oA = collect(1:1:64) + A = Reactant.to_rarray(oA) + func = @compile square!(A) + @test all(A .≈ (oA .* oA)) +end diff --git a/test/runtests.jl b/test/runtests.jl index fddc963ce..f1f52ba96 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -60,17 +60,18 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration" @safetestset "Linear Algebra" include("integration/linear_algebra.jl") + @safetestset "CUDA" include("cuda.jl") @safetestset "AbstractFFTs" include("integration/fft.jl") end - if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "neural_networks" - @testset "Neural Networks" begin - @safetestset "NNlib Primitives" include("nn/nnlib.jl") - @safetestset "Flux.jl Integration" include("nn/flux.jl") - if Sys.islinux() - @safetestset "LuxLib Primitives" include("nn/luxlib.jl") - @safetestset "Lux Integration" include("nn/lux.jl") - end - end - end + # if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "neural_networks" + # @testset "Neural Networks" begin + # @safetestset "NNlib Primitives" include("nn/nnlib.jl") + # @safetestset "Flux.jl Integration" include("nn/flux.jl") + # if Sys.islinux() + # @safetestset "LuxLib Primitives" include("nn/luxlib.jl") + # @safetestset "Lux Integration" include("nn/lux.jl") + # end + # end + # end end From f891bded2340f42cc47e91c5711b6cca303db68a Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 29 Nov 2024 17:49:52 -0500 Subject: [PATCH 02/78] more files --- ext/ReactantCUDAExt.jl | 81 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) create mode 100644 ext/ReactantCUDAExt.jl diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl new file mode 100644 index 000000000..3dbd73491 --- /dev/null +++ b/ext/ReactantCUDAExt.jl @@ -0,0 +1,81 @@ +module ReactantCUDAExt + +using CUDA +using Reactant: + Reactant, TracedRArray, AnyTracedRArray, materialize_traced_array, MLIR, TracedRNumber +using ReactantCore: @trace + + +const _kernel_instances = Dict{Any, Any}() + +function recufunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT} + cuda = CUDA.active_state() + + F2 = Reactant.traced_type(F, (), Val(Reactant.TracedToConcrete)) + tt2 = Reactant.traced_type(tt, (), Val(Reactant.TracedToConcrete)) + + + Base.@lock CUDA.cufunction_lock begin + # compile the function + cache = CUDA.compiler_cache(cuda.context) + source = CUDA.methodinstance(F2, tt2) + config = CUDA.compiler_config(cuda.device; kwargs...)::CUDA.CUDACompilerConfig + fun = CUDA.GPUCompiler.cached_compilation(cache, source, config, CUDA.compile, CUDA.link) + + @show fun + @show fun.mod + # create a callable object that captures the function instance. we don't need to think + # about world age here, as GPUCompiler already does and will return a different object + key = (objectid(source), hash(fun), f) + kernel = get(_kernel_instances, key, nothing) + if kernel === nothing + # create the kernel state object + state = CUDA.KernelState(create_exceptions!(fun.mod), UInt32(0)) + + kernel = CUDA.HostKernel{F,tt}(f, fun, state) + _kernel_instances[key] = kernel + end + return kernel::CUDA.HostKernel{F,tt} + end +end + +const CC = Core.Compiler + +import Core.Compiler: + AbstractInterpreter, + abstract_call, + abstract_call_known, + ArgInfo, + StmtInfo, + AbsIntState, + get_max_methods, + CallMeta, + Effects, + NoCallInfo, + widenconst, + mapany, + MethodResultPure + + +function Reactant.set_reactant_abi( + interp, + f::typeof(CUDA.cufunction), + arginfo::ArgInfo, + si::StmtInfo, + sv::AbsIntState, + max_methods::Int=get_max_methods(interp, f, sv), +) + (; fargs, argtypes) = arginfo + + arginfo2 = ArgInfo( + if fargs isa Nothing + nothing + else + [:($(recufunction)), fargs[2:end]...] + end, + [Core.Const(recufunction), argtypes[2:end]...], + ) + return abstract_call_known(interp, recufunction, arginfo2, si, sv, max_methods) +end + +end # module ReactantCUDAExt From 14174db6f9f63dd2e1a633c049365e4f62ac9189 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 29 Nov 2024 18:02:53 -0500 Subject: [PATCH 03/78] fix --- ext/ReactantCUDAExt.jl | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index 3dbd73491..87f2ed416 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -5,20 +5,21 @@ using Reactant: Reactant, TracedRArray, AnyTracedRArray, materialize_traced_array, MLIR, TracedRNumber using ReactantCore: @trace +using Adapt + +function Adapt.adapt_storage(::CUDA.KernelAdaptor, xs::TracedRArray{T,N}) where {T,N} + CuDeviceArray{T,N,CUDA.AS.Global}(pointer(xs.mlir_data.value), size(xs)) +end const _kernel_instances = Dict{Any, Any}() function recufunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT} cuda = CUDA.active_state() - F2 = Reactant.traced_type(F, (), Val(Reactant.TracedToConcrete)) - tt2 = Reactant.traced_type(tt, (), Val(Reactant.TracedToConcrete)) - - Base.@lock CUDA.cufunction_lock begin # compile the function cache = CUDA.compiler_cache(cuda.context) - source = CUDA.methodinstance(F2, tt2) + source = CUDA.methodinstance(F, tt) config = CUDA.compiler_config(cuda.device; kwargs...)::CUDA.CUDACompilerConfig fun = CUDA.GPUCompiler.cached_compilation(cache, source, config, CUDA.compile, CUDA.link) From e401e4a15f9745e50f3a78682327f28eb666defd Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 4 Dec 2024 21:22:41 -0500 Subject: [PATCH 04/78] wip --- Project.toml | 1 + ext/ReactantCUDAExt.jl | 200 +++++++++++++++++++++++++++++++++++++++-- 2 files changed, 192 insertions(+), 9 deletions(-) diff --git a/Project.toml b/Project.toml index a5d243705..7f120f778 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "0.2.10" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index 87f2ed416..22543a976 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -8,35 +8,217 @@ using ReactantCore: @trace using Adapt function Adapt.adapt_storage(::CUDA.KernelAdaptor, xs::TracedRArray{T,N}) where {T,N} - CuDeviceArray{T,N,CUDA.AS.Global}(pointer(xs.mlir_data.value), size(xs)) + res = CuDeviceArray{T,N,CUDA.AS.Global}(Base.reinterpret(Core.LLVMPtr{T,CUDA.AS.Global}, xs.mlir_data.value.ptr), size(xs)) + @show res, xs + return res end const _kernel_instances = Dict{Any, Any}() + + +# compile to executable machine code +function compile(job) + # lower to PTX + # TODO: on 1.9, this actually creates a context. cache those. + modstr = JuliaContext() do ctx + mod, meta = GPUCompiler.compile(:llvm, job) + string(mod) + end + return modstr +#= + # check if we'll need the device runtime + undefined_fs = filter(collect(functions(meta.ir))) do f + isdeclaration(f) && !LLVM.isintrinsic(f) + end + intrinsic_fns = ["vprintf", "malloc", "free", "__assertfail", + "__nvvm_reflect" #= TODO: should have been optimized away =#] + needs_cudadevrt = !isempty(setdiff(LLVM.name.(undefined_fs), intrinsic_fns)) + + # prepare invocations of CUDA compiler tools + ptxas_opts = String[] + nvlink_opts = String[] + ## debug flags + if Base.JLOptions().debug_level == 1 + push!(ptxas_opts, "--generate-line-info") + elseif Base.JLOptions().debug_level >= 2 + push!(ptxas_opts, "--device-debug") + push!(nvlink_opts, "--debug") + end + ## relocatable device code + if needs_cudadevrt + push!(ptxas_opts, "--compile-only") + end + + ptx = job.config.params.ptx + cap = job.config.params.cap + arch = "sm_$(cap.major)$(cap.minor)" + + # validate use of parameter memory + argtypes = filter([KernelState, job.source.specTypes.parameters...]) do dt + !isghosttype(dt) && !Core.Compiler.isconstType(dt) + end + param_usage = sum(sizeof, argtypes) + param_limit = 4096 + if cap >= v"7.0" && ptx >= v"8.1" + param_limit = 32764 + end + if param_usage > param_limit + msg = """Kernel invocation uses too much parameter memory. + $(Base.format_bytes(param_usage)) exceeds the $(Base.format_bytes(param_limit)) limit imposed by sm_$(cap.major)$(cap.minor) / PTX v$(ptx.major).$(ptx.minor).""" + + try + details = "\n\nRelevant parameters:" + + source_types = job.source.specTypes.parameters + source_argnames = Base.method_argnames(job.source.def) + while length(source_argnames) < length(source_types) + # this is probably due to a trailing vararg; repeat its name + push!(source_argnames, source_argnames[end]) + end + + for (i, typ) in enumerate(source_types) + if isghosttype(typ) || Core.Compiler.isconstType(typ) + continue + end + name = source_argnames[i] + details *= "\n [$(i-1)] $name::$typ uses $(Base.format_bytes(sizeof(typ)))" + end + details *= "\n" + + if cap >= v"7.0" && ptx < v"8.1" && param_usage < 32764 + details *= "\nNote: use a newer CUDA to support more parameters on your device.\n" + end + + msg *= details + catch err + @error "Failed to analyze kernel parameter usage; please file an issue with a reproducer." + end + error(msg) + end + + # compile to machine code + # NOTE: we use tempname since mktemp doesn't support suffixes, and mktempdir is slow + ptx_input = tempname(cleanup=false) * ".ptx" + ptxas_output = tempname(cleanup=false) * ".cubin" + write(ptx_input, asm) + + # we could use the driver's embedded JIT compiler, but that has several disadvantages: + # 1. fixes and improvements are slower to arrive, by using `ptxas` we only need to + # upgrade the toolkit to get a newer compiler; + # 2. version checking is simpler, we otherwise need to use NVML to query the driver + # version, which is hard to correlate to PTX JIT improvements; + # 3. if we want to be able to use newer (minor upgrades) of the CUDA toolkit on an + # older driver, we should use the newer compiler to ensure compatibility. + append!(ptxas_opts, [ + "--verbose", + "--gpu-name", arch, + "--output-file", ptxas_output, + ptx_input + ]) + proc, log = run_and_collect(`$(ptxas()) $ptxas_opts`) + log = strip(log) + if !success(proc) + reason = proc.termsignal > 0 ? "ptxas received signal $(proc.termsignal)" : + "ptxas exited with code $(proc.exitcode)" + msg = "Failed to compile PTX code ($reason)" + msg *= "\nInvocation arguments: $(join(ptxas_opts, ' '))" + if !isempty(log) + msg *= "\n" * log + end + msg *= "\nIf you think this is a bug, please file an issue and attach $(ptx_input)" + if parse(Bool, get(ENV, "BUILDKITE", "false")) + run(`buildkite-agent artifact upload $(ptx_input)`) + end + error(msg) + elseif !isempty(log) + @debug "PTX compiler log:\n" * log + end + rm(ptx_input) +=# +#= + # link device libraries, if necessary + # + # this requires relocatable device code, which prevents certain optimizations and + # hurts performance. as such, we only do so when absolutely necessary. + # TODO: try LTO, `--link-time-opt --nvvmpath /opt/cuda/nvvm`. + # fails with `Ignoring -lto option because no LTO objects found` + if needs_cudadevrt + nvlink_output = tempname(cleanup=false) * ".cubin" + append!(nvlink_opts, [ + "--verbose", "--extra-warnings", + "--arch", arch, + "--library-path", dirname(libcudadevrt), + "--library", "cudadevrt", + "--output-file", nvlink_output, + ptxas_output + ]) + proc, log = run_and_collect(`$(nvlink()) $nvlink_opts`) + log = strip(log) + if !success(proc) + reason = proc.termsignal > 0 ? "nvlink received signal $(proc.termsignal)" : + "nvlink exited with code $(proc.exitcode)" + msg = "Failed to link PTX code ($reason)" + msg *= "\nInvocation arguments: $(join(nvlink_opts, ' '))" + if !isempty(log) + msg *= "\n" * log + end + msg *= "\nIf you think this is a bug, please file an issue and attach $(ptxas_output)" + error(msg) + elseif !isempty(log) + @debug "PTX linker info log:\n" * log + end + rm(ptxas_output) + + image = read(nvlink_output) + rm(nvlink_output) + else + image = read(ptxas_output) + rm(ptxas_output) + end +=# + return (image, entry=LLVM.name(meta.entry)) +end + +# link into an executable kernel +function link(job, compiled) + # load as an executable kernel object + return compiled +end + +struct LLVMFunc{F,tt} + f::F + mod::String +end + +function (func::LLVMFunc{F,tt})(args...) where{F, tt} + +end + function recufunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT} cuda = CUDA.active_state() + @show f, tt + flush(stdout) Base.@lock CUDA.cufunction_lock begin # compile the function cache = CUDA.compiler_cache(cuda.context) source = CUDA.methodinstance(F, tt) config = CUDA.compiler_config(cuda.device; kwargs...)::CUDA.CUDACompilerConfig - fun = CUDA.GPUCompiler.cached_compilation(cache, source, config, CUDA.compile, CUDA.link) + fun = CUDA.GPUCompiler.cached_compilation(cache, source, config, compile, link) @show fun - @show fun.mod + println(string(fun)) + #@show fun.mod # create a callable object that captures the function instance. we don't need to think # about world age here, as GPUCompiler already does and will return a different object - key = (objectid(source), hash(fun), f) + key = (objectid(source)) kernel = get(_kernel_instances, key, nothing) if kernel === nothing - # create the kernel state object - state = CUDA.KernelState(create_exceptions!(fun.mod), UInt32(0)) - - kernel = CUDA.HostKernel{F,tt}(f, fun, state) + kernel = LLVMFunc{F,tt}(f, fun) _kernel_instances[key] = kernel end - return kernel::CUDA.HostKernel{F,tt} + return kernel::LLVMFunc{F,tt} end end From e7bc31877b48c3631a968c5148585f4b1443c29b Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 5 Dec 2024 17:58:52 -0500 Subject: [PATCH 05/78] wqtmp --- ext/ReactantCUDAExt.jl | 32 ++++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index 22543a976..eef007634 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -16,15 +16,19 @@ end const _kernel_instances = Dict{Any, Any}() - # compile to executable machine code function compile(job) # lower to PTX # TODO: on 1.9, this actually creates a context. cache those. - modstr = JuliaContext() do ctx - mod, meta = GPUCompiler.compile(:llvm, job) + modstr = CUDA.GPUCompiler.JuliaContext() do ctx + mod, meta = CUDA.GPUCompiler.compile(:llvm, job) string(mod) end + println(string(modstr)) + @show job + @show job.params + @show job.source + kernel = LLVMFunc{F,tt}(f, modstr) return modstr #= # check if we'll need the device runtime @@ -187,12 +191,23 @@ function link(job, compiled) end struct LLVMFunc{F,tt} - f::F - mod::String + f::F + mod::String end function (func::LLVMFunc{F,tt})(args...) where{F, tt} - + +end + +# cache of compilation caches, per context +const _compiler_caches = Dict{MLIR.IR.Context, Dict{Any, LLVMFunc}}(); +function compiler_cache(ctx::MLIR.IR.Context) + cache = get(_compiler_caches, ctx, nothing) + if cache === nothing + cache = Dict{Any, LLVMFunc}() + _compiler_caches[ctx] = cache + end + return cache end function recufunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT} @@ -202,20 +217,17 @@ function recufunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT} Base.@lock CUDA.cufunction_lock begin # compile the function - cache = CUDA.compiler_cache(cuda.context) + cache = compiler_cache(MLIR.IR.context()) source = CUDA.methodinstance(F, tt) config = CUDA.compiler_config(cuda.device; kwargs...)::CUDA.CUDACompilerConfig fun = CUDA.GPUCompiler.cached_compilation(cache, source, config, compile, link) - @show fun - println(string(fun)) #@show fun.mod # create a callable object that captures the function instance. we don't need to think # about world age here, as GPUCompiler already does and will return a different object key = (objectid(source)) kernel = get(_kernel_instances, key, nothing) if kernel === nothing - kernel = LLVMFunc{F,tt}(f, fun) _kernel_instances[key] = kernel end return kernel::LLVMFunc{F,tt} From ae6c7c63d17cbb4d61f26a993a7ca78f36026b28 Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 5 Dec 2024 21:08:28 -0500 Subject: [PATCH 06/78] wip --- ext/ReactantCUDAExt.jl | 44 ++++++++++++++++++++++++++---------------- 1 file changed, 27 insertions(+), 17 deletions(-) diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index eef007634..1b22999d3 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -24,20 +24,13 @@ function compile(job) mod, meta = CUDA.GPUCompiler.compile(:llvm, job) string(mod) end - println(string(modstr)) - @show job - @show job.params - @show job.source - kernel = LLVMFunc{F,tt}(f, modstr) - return modstr -#= # check if we'll need the device runtime undefined_fs = filter(collect(functions(meta.ir))) do f - isdeclaration(f) && !LLVM.isintrinsic(f) + isdeclaration(f) && !CUDA.LLVM.isintrinsic(f) end intrinsic_fns = ["vprintf", "malloc", "free", "__assertfail", "__nvvm_reflect" #= TODO: should have been optimized away =#] - needs_cudadevrt = !isempty(setdiff(LLVM.name.(undefined_fs), intrinsic_fns)) + needs_cudadevrt = !isempty(setdiff(CUDA.LLVM.name.(undefined_fs), intrinsic_fns)) # prepare invocations of CUDA compiler tools ptxas_opts = String[] @@ -59,7 +52,7 @@ function compile(job) arch = "sm_$(cap.major)$(cap.minor)" # validate use of parameter memory - argtypes = filter([KernelState, job.source.specTypes.parameters...]) do dt + argtypes = filter([CUDA.KernelState, job.source.specTypes.parameters...]) do dt !isghosttype(dt) && !Core.Compiler.isconstType(dt) end param_usage = sum(sizeof, argtypes) @@ -120,7 +113,7 @@ function compile(job) "--output-file", ptxas_output, ptx_input ]) - proc, log = run_and_collect(`$(ptxas()) $ptxas_opts`) + proc, log = CUDA.run_and_collect(`$(ptxas()) $ptxas_opts`) log = strip(log) if !success(proc) reason = proc.termsignal > 0 ? "ptxas received signal $(proc.termsignal)" : @@ -139,8 +132,7 @@ function compile(job) @debug "PTX compiler log:\n" * log end rm(ptx_input) -=# -#= + # link device libraries, if necessary # # this requires relocatable device code, which prevents certain optimizations and @@ -180,8 +172,12 @@ function compile(job) image = read(ptxas_output) rm(ptxas_output) end -=# - return (image, entry=LLVM.name(meta.entry)) + + println(string(modstr)) + @show job + @show job.source + @show job.config + LLVMFunc{F,job.source.specTypes}(f, modstr, image, LLVM.name(meta.entry)) end # link into an executable kernel @@ -193,10 +189,24 @@ end struct LLVMFunc{F,tt} f::F mod::String + image + entry::String end -function (func::LLVMFunc{F,tt})(args...) where{F, tt} - +function (func::LLVMFunc{F,tt})(args...; blocks::CUDA.CuDim=1, threads::CUDA.CuDim=1, + shmem::Integer=0) where{F, tt} + blockdim = CUDA.CuDim3(blocks) + threaddim = CUDA.CuDim3(threads) + + @show args + +# void MosaicGPUCustomCall(void* stream, void** buffers, char* opaque, +# size_t opaque_len, XlaCustomCallStatus* status) { + + CUDA.cuLaunchKernel(f, + blockdim.x, blockdim.y, blockdim.z, + threaddim.x, threaddim.y, threaddim.z, + shmem, stream, kernelParams, C_NULL) end # cache of compilation caches, per context From b8f206f6c6c25d5ec3fc33a565001ab78d37203e Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 7 Dec 2024 16:37:23 -0500 Subject: [PATCH 07/78] inc --- src/Interpreter.jl | 1 + test/runtests.jl | 1 + 2 files changed, 2 insertions(+) diff --git a/src/Interpreter.jl b/src/Interpreter.jl index 2efb53792..52a98962b 100644 --- a/src/Interpreter.jl +++ b/src/Interpreter.jl @@ -98,6 +98,7 @@ function set_reactant_abi( end end + @show f, arginfo return Base.@invoke abstract_call_known( interp::AbstractInterpreter, f::Any, diff --git a/test/runtests.jl b/test/runtests.jl index f1f52ba96..ce0fefede 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -41,6 +41,7 @@ end const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) +include("cuda.jl") @testset "Reactant.jl Tests" begin if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "core" @safetestset "Layout" include("layout.jl") From c414601dba63853c8ba83f4178d9746925fe60bc Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 7 Dec 2024 18:32:17 -0500 Subject: [PATCH 08/78] continuing --- ext/ReactantCUDAExt.jl | 369 +++++++++++++++++++++-------------------- src/Interpreter.jl | 1 - test/runtests.jl | 2 + 3 files changed, 192 insertions(+), 180 deletions(-) diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index 1b22999d3..1c29905c8 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -7,11 +7,11 @@ using ReactantCore: @trace using Adapt -function Adapt.adapt_storage(::CUDA.KernelAdaptor, xs::TracedRArray{T,N}) where {T,N} - res = CuDeviceArray{T,N,CUDA.AS.Global}(Base.reinterpret(Core.LLVMPtr{T,CUDA.AS.Global}, xs.mlir_data.value.ptr), size(xs)) - @show res, xs - return res -end +#function Adapt.adapt_storage(::CUDA.KernelAdaptor, xs::TracedRArray{T,N}) where {T,N} +# res = CuDeviceArray{T,N,CUDA.AS.Global}(Base.reinterpret(Core.LLVMPtr{T,CUDA.AS.Global}, xs.mlir_data.value.ptr), size(xs)) +# @show res, xs +# return res +#end const _kernel_instances = Dict{Any, Any}() @@ -20,164 +20,167 @@ const _kernel_instances = Dict{Any, Any}() function compile(job) # lower to PTX # TODO: on 1.9, this actually creates a context. cache those. - modstr = CUDA.GPUCompiler.JuliaContext() do ctx - mod, meta = CUDA.GPUCompiler.compile(:llvm, job) - string(mod) - end - # check if we'll need the device runtime - undefined_fs = filter(collect(functions(meta.ir))) do f - isdeclaration(f) && !CUDA.LLVM.isintrinsic(f) - end - intrinsic_fns = ["vprintf", "malloc", "free", "__assertfail", - "__nvvm_reflect" #= TODO: should have been optimized away =#] - needs_cudadevrt = !isempty(setdiff(CUDA.LLVM.name.(undefined_fs), intrinsic_fns)) - - # prepare invocations of CUDA compiler tools - ptxas_opts = String[] - nvlink_opts = String[] - ## debug flags - if Base.JLOptions().debug_level == 1 - push!(ptxas_opts, "--generate-line-info") - elseif Base.JLOptions().debug_level >= 2 - push!(ptxas_opts, "--device-debug") - push!(nvlink_opts, "--debug") - end - ## relocatable device code - if needs_cudadevrt - push!(ptxas_opts, "--compile-only") - end - - ptx = job.config.params.ptx - cap = job.config.params.cap - arch = "sm_$(cap.major)$(cap.minor)" - - # validate use of parameter memory - argtypes = filter([CUDA.KernelState, job.source.specTypes.parameters...]) do dt - !isghosttype(dt) && !Core.Compiler.isconstType(dt) - end - param_usage = sum(sizeof, argtypes) - param_limit = 4096 - if cap >= v"7.0" && ptx >= v"8.1" - param_limit = 32764 - end - if param_usage > param_limit - msg = """Kernel invocation uses too much parameter memory. - $(Base.format_bytes(param_usage)) exceeds the $(Base.format_bytes(param_limit)) limit imposed by sm_$(cap.major)$(cap.minor) / PTX v$(ptx.major).$(ptx.minor).""" - - try - details = "\n\nRelevant parameters:" - - source_types = job.source.specTypes.parameters - source_argnames = Base.method_argnames(job.source.def) - while length(source_argnames) < length(source_types) - # this is probably due to a trailing vararg; repeat its name - push!(source_argnames, source_argnames[end]) - end - - for (i, typ) in enumerate(source_types) - if isghosttype(typ) || Core.Compiler.isconstType(typ) - continue - end - name = source_argnames[i] - details *= "\n [$(i-1)] $name::$typ uses $(Base.format_bytes(sizeof(typ)))" - end - details *= "\n" - - if cap >= v"7.0" && ptx < v"8.1" && param_usage < 32764 - details *= "\nNote: use a newer CUDA to support more parameters on your device.\n" - end - - msg *= details - catch err - @error "Failed to analyze kernel parameter usage; please file an issue with a reproducer." - end - error(msg) - end - - # compile to machine code - # NOTE: we use tempname since mktemp doesn't support suffixes, and mktempdir is slow - ptx_input = tempname(cleanup=false) * ".ptx" - ptxas_output = tempname(cleanup=false) * ".cubin" - write(ptx_input, asm) - - # we could use the driver's embedded JIT compiler, but that has several disadvantages: - # 1. fixes and improvements are slower to arrive, by using `ptxas` we only need to - # upgrade the toolkit to get a newer compiler; - # 2. version checking is simpler, we otherwise need to use NVML to query the driver - # version, which is hard to correlate to PTX JIT improvements; - # 3. if we want to be able to use newer (minor upgrades) of the CUDA toolkit on an - # older driver, we should use the newer compiler to ensure compatibility. - append!(ptxas_opts, [ - "--verbose", - "--gpu-name", arch, - "--output-file", ptxas_output, - ptx_input - ]) - proc, log = CUDA.run_and_collect(`$(ptxas()) $ptxas_opts`) - log = strip(log) - if !success(proc) - reason = proc.termsignal > 0 ? "ptxas received signal $(proc.termsignal)" : - "ptxas exited with code $(proc.exitcode)" - msg = "Failed to compile PTX code ($reason)" - msg *= "\nInvocation arguments: $(join(ptxas_opts, ' '))" - if !isempty(log) - msg *= "\n" * log - end - msg *= "\nIf you think this is a bug, please file an issue and attach $(ptx_input)" - if parse(Bool, get(ENV, "BUILDKITE", "false")) - run(`buildkite-agent artifact upload $(ptx_input)`) - end - error(msg) - elseif !isempty(log) - @debug "PTX compiler log:\n" * log - end - rm(ptx_input) - - # link device libraries, if necessary - # - # this requires relocatable device code, which prevents certain optimizations and - # hurts performance. as such, we only do so when absolutely necessary. - # TODO: try LTO, `--link-time-opt --nvvmpath /opt/cuda/nvvm`. - # fails with `Ignoring -lto option because no LTO objects found` - if needs_cudadevrt - nvlink_output = tempname(cleanup=false) * ".cubin" - append!(nvlink_opts, [ - "--verbose", "--extra-warnings", - "--arch", arch, - "--library-path", dirname(libcudadevrt), - "--library", "cudadevrt", - "--output-file", nvlink_output, - ptxas_output - ]) - proc, log = run_and_collect(`$(nvlink()) $nvlink_opts`) - log = strip(log) - if !success(proc) - reason = proc.termsignal > 0 ? "nvlink received signal $(proc.termsignal)" : - "nvlink exited with code $(proc.exitcode)" - msg = "Failed to link PTX code ($reason)" - msg *= "\nInvocation arguments: $(join(nvlink_opts, ' '))" - if !isempty(log) - msg *= "\n" * log - end - msg *= "\nIf you think this is a bug, please file an issue and attach $(ptxas_output)" - error(msg) - elseif !isempty(log) - @debug "PTX linker info log:\n" * log - end - rm(ptxas_output) - - image = read(nvlink_output) - rm(nvlink_output) - else - image = read(ptxas_output) - rm(ptxas_output) + modstr, image, entry = CUDA.GPUCompiler.JuliaContext() do ctx + asm, meta = CUDA.GPUCompiler.compile(:asm, job) + mod = meta.ir + modstr = string(mod) + # check if we'll need the device runtime + undefined_fs = filter(collect(functions(meta.ir))) do f + isdeclaration(f) && !CUDA.LLVM.isintrinsic(f) + end + intrinsic_fns = ["vprintf", "malloc", "free", "__assertfail", + "__nvvm_reflect" #= TODO: should have been optimized away =#] + needs_cudadevrt = !isempty(setdiff(CUDA.LLVM.name.(undefined_fs), intrinsic_fns)) + + # prepare invocations of CUDA compiler tools + ptxas_opts = String[] + nvlink_opts = String[] + ## debug flags + if Base.JLOptions().debug_level == 1 + push!(ptxas_opts, "--generate-line-info") + elseif Base.JLOptions().debug_level >= 2 + push!(ptxas_opts, "--device-debug") + push!(nvlink_opts, "--debug") + end + ## relocatable device code + if needs_cudadevrt + push!(ptxas_opts, "--compile-only") + end + + ptx = job.config.params.ptx + cap = job.config.params.cap + arch = "sm_$(cap.major)$(cap.minor)" + + # validate use of parameter memory + argtypes = filter([CUDA.KernelState, job.source.specTypes.parameters...]) do dt + !isghosttype(dt) && !Core.Compiler.isconstType(dt) + end + param_usage = sum(sizeof, argtypes) + param_limit = 4096 + if cap >= v"7.0" && ptx >= v"8.1" + param_limit = 32764 + end + if param_usage > param_limit + msg = """Kernel invocation uses too much parameter memory. + $(Base.format_bytes(param_usage)) exceeds the $(Base.format_bytes(param_limit)) limit imposed by sm_$(cap.major)$(cap.minor) / PTX v$(ptx.major).$(ptx.minor).""" + + try + details = "\n\nRelevant parameters:" + + source_types = job.source.specTypes.parameters + source_argnames = Base.method_argnames(job.source.def) + while length(source_argnames) < length(source_types) + # this is probably due to a trailing vararg; repeat its name + push!(source_argnames, source_argnames[end]) + end + + for (i, typ) in enumerate(source_types) + if isghosttype(typ) || Core.Compiler.isconstType(typ) + continue + end + name = source_argnames[i] + details *= "\n [$(i-1)] $name::$typ uses $(Base.format_bytes(sizeof(typ)))" + end + details *= "\n" + + if cap >= v"7.0" && ptx < v"8.1" && param_usage < 32764 + details *= "\nNote: use a newer CUDA to support more parameters on your device.\n" + end + + msg *= details + catch err + @error "Failed to analyze kernel parameter usage; please file an issue with a reproducer." + end + error(msg) + end + + # compile to machine code + # NOTE: we use tempname since mktemp doesn't support suffixes, and mktempdir is slow + ptx_input = tempname(cleanup=false) * ".ptx" + ptxas_output = tempname(cleanup=false) * ".cubin" + write(ptx_input, asm) + + # we could use the driver's embedded JIT compiler, but that has several disadvantages: + # 1. fixes and improvements are slower to arrive, by using `ptxas` we only need to + # upgrade the toolkit to get a newer compiler; + # 2. version checking is simpler, we otherwise need to use NVML to query the driver + # version, which is hard to correlate to PTX JIT improvements; + # 3. if we want to be able to use newer (minor upgrades) of the CUDA toolkit on an + # older driver, we should use the newer compiler to ensure compatibility. + append!(ptxas_opts, [ + "--verbose", + "--gpu-name", arch, + "--output-file", ptxas_output, + ptx_input + ]) + proc, log = CUDA.run_and_collect(`$(ptxas()) $ptxas_opts`) + log = strip(log) + if !success(proc) + reason = proc.termsignal > 0 ? "ptxas received signal $(proc.termsignal)" : + "ptxas exited with code $(proc.exitcode)" + msg = "Failed to compile PTX code ($reason)" + msg *= "\nInvocation arguments: $(join(ptxas_opts, ' '))" + if !isempty(log) + msg *= "\n" * log + end + msg *= "\nIf you think this is a bug, please file an issue and attach $(ptx_input)" + if parse(Bool, get(ENV, "BUILDKITE", "false")) + run(`buildkite-agent artifact upload $(ptx_input)`) + end + error(msg) + elseif !isempty(log) + @debug "PTX compiler log:\n" * log + end + rm(ptx_input) + + # link device libraries, if necessary + # + # this requires relocatable device code, which prevents certain optimizations and + # hurts performance. as such, we only do so when absolutely necessary. + # TODO: try LTO, `--link-time-opt --nvvmpath /opt/cuda/nvvm`. + # fails with `Ignoring -lto option because no LTO objects found` + if needs_cudadevrt + nvlink_output = tempname(cleanup=false) * ".cubin" + append!(nvlink_opts, [ + "--verbose", "--extra-warnings", + "--arch", arch, + "--library-path", dirname(libcudadevrt), + "--library", "cudadevrt", + "--output-file", nvlink_output, + ptxas_output + ]) + proc, log = run_and_collect(`$(nvlink()) $nvlink_opts`) + log = strip(log) + if !success(proc) + reason = proc.termsignal > 0 ? "nvlink received signal $(proc.termsignal)" : + "nvlink exited with code $(proc.exitcode)" + msg = "Failed to link PTX code ($reason)" + msg *= "\nInvocation arguments: $(join(nvlink_opts, ' '))" + if !isempty(log) + msg *= "\n" * log + end + msg *= "\nIf you think this is a bug, please file an issue and attach $(ptxas_output)" + error(msg) + elseif !isempty(log) + @debug "PTX linker info log:\n" * log + end + rm(ptxas_output) + + image = read(nvlink_output) + rm(nvlink_output) + else + image = read(ptxas_output) + rm(ptxas_output) + end + + modstr, image, meta.entry end println(string(modstr)) @show job @show job.source @show job.config - LLVMFunc{F,job.source.specTypes}(f, modstr, image, LLVM.name(meta.entry)) + LLVMFunc{F,job.source.specTypes}(f, modstr, image, LLVM.name(entry)) end # link into an executable kernel @@ -200,13 +203,32 @@ function (func::LLVMFunc{F,tt})(args...; blocks::CUDA.CuDim=1, threads::CUDA.CuD @show args -# void MosaicGPUCustomCall(void* stream, void** buffers, char* opaque, -# size_t opaque_len, XlaCustomCallStatus* status) { + mlir_args = MLIR.IR.Value[] + restys = MLIR.IR.Type[] + aliases = MLIR.API.MlirAttribute[] + for (i, a) in enumerate(args) + @show a + arg = nothing + arg = Reactant.Compiler.transpose_val(arg) + push!(restys, MLIR.IR.Type(arg)) + push!(aliases, + MLIR.IR.Dialects.stablehlo.stablehloOutputOperandAliasGet( + MLIR.IR.context(), + len(args) == 1 ? 0 : 1, + len(args) == 1 ? C_NULL : Ref{Int64}(i-1), + i-1, + 0, + C_NULL + ) + ) + end - CUDA.cuLaunchKernel(f, - blockdim.x, blockdim.y, blockdim.z, - threaddim.x, threaddim.y, threaddim.z, - shmem, stream, kernelParams, C_NULL) + output_operand_aliases=MLIR.ArrayAttr.get(MLIR.IR.context(), aliases) + MLIR.IR.Dialects.stablehlo.custom_call(mlir_args; result_0=restys, call_target_name="reactant_gpu_call", output_operand_aliases) + #CUDA.cuLaunchKernel(f, + # blockdim.x, blockdim.y, blockdim.z, + # threaddim.x, threaddim.y, threaddim.z, + # shmem, stream, kernelParams, C_NULL) end # cache of compilation caches, per context @@ -221,27 +243,16 @@ function compiler_cache(ctx::MLIR.IR.Context) end function recufunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT} - cuda = CUDA.active_state() - @show f, tt - flush(stdout) - - Base.@lock CUDA.cufunction_lock begin + res = Base.@lock CUDA.cufunction_lock begin # compile the function cache = compiler_cache(MLIR.IR.context()) source = CUDA.methodinstance(F, tt) - config = CUDA.compiler_config(cuda.device; kwargs...)::CUDA.CUDACompilerConfig - fun = CUDA.GPUCompiler.cached_compilation(cache, source, config, compile, link) - #@show fun.mod - # create a callable object that captures the function instance. we don't need to think - # about world age here, as GPUCompiler already does and will return a different object - key = (objectid(source)) - kernel = get(_kernel_instances, key, nothing) - if kernel === nothing - _kernel_instances[key] = kernel - end - return kernel::LLVMFunc{F,tt} + cuda = CUDA.active_state() + config = CUDA.compiler_config(cuda.device; kwargs...)::CUDA.CUDACompilerConfig + CUDA.GPUCompiler.cached_compilation(cache, source, config, compile, link) end + res end const CC = Core.Compiler diff --git a/src/Interpreter.jl b/src/Interpreter.jl index 52a98962b..2efb53792 100644 --- a/src/Interpreter.jl +++ b/src/Interpreter.jl @@ -98,7 +98,6 @@ function set_reactant_abi( end end - @show f, arginfo return Base.@invoke abstract_call_known( interp::AbstractInterpreter, f::Any, diff --git a/test/runtests.jl b/test/runtests.jl index ce0fefede..87e1a3702 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -42,6 +42,7 @@ end const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) include("cuda.jl") +@static if false @testset "Reactant.jl Tests" begin if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "core" @safetestset "Layout" include("layout.jl") @@ -76,3 +77,4 @@ include("cuda.jl") # end # end end +end From 9138a36eef0cb211b62eb153806c05654825f8f7 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 7 Dec 2024 22:54:14 -0500 Subject: [PATCH 09/78] wip --- ext/ReactantCUDAExt.jl | 207 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 201 insertions(+), 6 deletions(-) diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index 1c29905c8..84368ed60 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -7,11 +7,202 @@ using ReactantCore: @trace using Adapt -#function Adapt.adapt_storage(::CUDA.KernelAdaptor, xs::TracedRArray{T,N}) where {T,N} -# res = CuDeviceArray{T,N,CUDA.AS.Global}(Base.reinterpret(Core.LLVMPtr{T,CUDA.AS.Global}, xs.mlir_data.value.ptr), size(xs)) -# @show res, xs -# return res -#end +struct CuTracedArray{T,N,A,Size} <: DenseArray{T,N} + ptr::Core.LLVMPtr{T,A} +end + + +Base.show(io::IO, a::AT) where AT <: CuTracedArray = + CUDA.Printf.@printf(io, "%s cu traced array at %p", join(size(a), '×'), Int(pointer(a))) + +## array interface + +Base.elsize(::Type{<:CuTracedArray{T}}) where {T} = sizeof(T) +Base.size(g::CuTracedArray{T,N,A,Size}) where {T,N,A,Size} = Size +Base.sizeof(x::CuTracedArray) = Base.elsize(x) * length(x) +Base.pointer(x::CuTracedArray{T,<:Any,A}) where {T,A} = Base.unsafe_convert(Core.LLVMPtr{T,A}, x) +@inline function Base.pointer(x::CuTracedArray{T,<:Any,A}, i::Integer) where {T,A} + Base.unsafe_convert(Core.LLVMPtr{T,A}, x) + Base._memory_offset(x, i) +end + + +## conversions + +Base.unsafe_convert(::Type{Core.LLVMPtr{T,A}}, x::CuTracedArray{T,<:Any,A}) where {T,A} = + x.ptr + + +## indexing intrinsics + +CUDA.@device_function @inline function arrayref(A::CuTracedArray{T}, index::Integer) where {T} + @boundscheck checkbounds(A, index) + if Base.isbitsunion(T) + arrayref_union(A, index) + else + arrayref_bits(A, index) + end +end + +@inline function arrayref_bits(A::CuTracedArray{T}, index::Integer) where {T} + unsafe_load(pointer(A), index) +end + +@inline @generated function arrayref_union(A::CuTracedArray{T,<:Any,AS}, index::Integer) where {T,AS} + typs = Base.uniontypes(T) + + # generate code that conditionally loads a value based on the selector value. + # lacking noreturn, we return T to avoid inference thinking this can return Nothing. + ex = :(Base.llvmcall("unreachable", $T, Tuple{})) + for (sel, typ) in Iterators.reverse(enumerate(typs)) + ex = quote + if selector == $(sel-1) + ptr = reinterpret(Core.LLVMPtr{$typ,AS}, data_ptr) + unsafe_load(ptr, 1) + else + $ex + end + end + end + + quote + selector_ptr = typetagdata(A, index) + selector = unsafe_load(selector_ptr) + + data_ptr = pointer(A, index) + + return $ex + end +end + +CUDA.@device_function @inline function arrayset(A::CuTracedArray{T}, x::T, index::Integer) where {T} + @boundscheck checkbounds(A, index) + if Base.isbitsunion(T) + arrayset_union(A, x, index) + else + arrayset_bits(A, x, index) + end + return A +end + +@inline function arrayset_bits(A::CuTracedArray{T}, x::T, index::Integer) where {T} + unsafe_store!(pointer(A), x, index) +end + +@inline @generated function arrayset_union(A::CuTracedArray{T,<:Any,AS}, x::T, index::Integer) where {T,AS} + typs = Base.uniontypes(T) + sel = findfirst(isequal(x), typs) + + quote + selector_ptr = typetagdata(A, index) + unsafe_store!(selector_ptr, $(UInt8(sel-1))) + + data_ptr = pointer(A, index) + + unsafe_store!(reinterpret(Core.LLVMPtr{$x,AS}, data_ptr), x, 1) + return + end +end + +CUDA.@device_function @inline function const_arrayref(A::CuTracedArray{T}, index::Integer) where {T} + @boundscheck checkbounds(A, index) + unsafe_cached_load(pointer(A), index) +end + + +## indexing + +Base.IndexStyle(::Type{<:CuTracedArray}) = Base.IndexLinear() + +Base.@propagate_inbounds Base.getindex(A::CuTracedArray{T}, i1::Integer) where {T} = + arrayref(A, i1) +Base.@propagate_inbounds Base.setindex!(A::CuTracedArray{T}, x, i1::Integer) where {T} = + arrayset(A, convert(T,x)::T, i1) + +# preserve the specific integer type when indexing device arrays, +# to avoid extending 32-bit hardware indices to 64-bit. +Base.to_index(::CuTracedArray, i::Integer) = i + +# Base doesn't like Integer indices, so we need our own ND get and setindex! routines. +# See also: https://github.com/JuliaLang/julia/pull/42289 +Base.@propagate_inbounds Base.getindex(A::CuTracedArray, + I::Union{Integer, CartesianIndex}...) = + A[Base._to_linear_index(A, to_indices(A, I)...)] +Base.@propagate_inbounds Base.setindex!(A::CuTracedArray, x, + I::Union{Integer, CartesianIndex}...) = + A[Base._to_linear_index(A, to_indices(A, I)...)] = x + + +## const indexing + +""" + Const(A::CuTracedArray) + +Mark a CuTracedArray as constant/read-only. The invariant guaranteed is that you will not +modify an CuTracedArray for the duration of the current kernel. + +This API can only be used on devices with compute capability 3.5 or higher. + +!!! warning + Experimental API. Subject to change without deprecation. +""" +struct Const{T,N,AS} <: DenseArray{T,N} + a::CuTracedArray{T,N,AS} +end +Base.Experimental.Const(A::CuTracedArray) = Const(A) + +Base.IndexStyle(::Type{<:Const}) = IndexLinear() +Base.size(C::Const) = size(C.a) +Base.axes(C::Const) = axes(C.a) +Base.@propagate_inbounds Base.getindex(A::Const, i1::Integer) = const_arrayref(A.a, i1) + +# deprecated +Base.@propagate_inbounds ldg(A::CuTracedArray, i1::Integer) = const_arrayref(A, i1) + + +## other + +@inline function Base.iterate(A::CuTracedArray, i=1) + if (i % UInt) - 1 < length(A) + (@inbounds A[i], i + 1) + else + nothing + end +end + +function Base.reinterpret(::Type{T}, a::CuTracedArray{S,N,A}) where {T,S,N,A} + err = GPUArrays._reinterpret_exception(T, a) + err === nothing || throw(err) + + if sizeof(T) == sizeof(S) # fast case + return CuTracedArray{T,N,A}(reinterpret(Core.LLVMPtr{T,A}, a.ptr), size(a), a.maxsize) + end + + isize = size(a) + size1 = div(isize[1]*sizeof(S), sizeof(T)) + osize = tuple(size1, Base.tail(isize)...) + return CuTracedArray{T,N,A}(reinterpret(Core.LLVMPtr{T,A}, a.ptr), osize, a.maxsize) +end + + +## reshape + +function Base.reshape(a::CuTracedArray{T,M,A}, dims::NTuple{N,Int}) where {T,N,M,A} + if prod(dims) != length(a) + throw(DimensionMismatch("new dimensions (argument `dims`) must be consistent with array size (`size(a)`)")) + end + if N == M && dims == size(a) + return a + end + _derived_array(a, T, dims) +end + + + +function Adapt.adapt_storage(::CUDA.KernelAdaptor, xs::TracedRArray{T,N}) where {T,N} + res = CuTracedArray{T,N,CUDA.AS.Global, size(xs)}(Base.reinterpret(Core.LLVMPtr{T,CUDA.AS.Global}, Base.pointer_from_objref(xs))) + @show res, xs + return res +end const _kernel_instances = Dict{Any, Any}() @@ -24,6 +215,8 @@ function compile(job) asm, meta = CUDA.GPUCompiler.compile(:asm, job) mod = meta.ir modstr = string(mod) + @show mod + @show modstr # check if we'll need the device runtime undefined_fs = filter(collect(functions(meta.ir))) do f isdeclaration(f) && !CUDA.LLVM.isintrinsic(f) @@ -208,7 +401,9 @@ function (func::LLVMFunc{F,tt})(args...; blocks::CUDA.CuDim=1, threads::CUDA.CuD aliases = MLIR.API.MlirAttribute[] for (i, a) in enumerate(args) @show a - arg = nothing + @assert a isa CuDeviceArray + ta = Base.pointer_to_objref(a.ptr)::TracedRArray + arg = ta.mlir_data arg = Reactant.Compiler.transpose_val(arg) push!(restys, MLIR.IR.Type(arg)) push!(aliases, From c2ca4cb19a7f7d4d7bb1d2aab682e07b3b9e3fa5 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 8 Dec 2024 19:49:23 -0500 Subject: [PATCH 10/78] more work --- Project.toml | 1 - ext/ReactantCUDAExt.jl | 42 +-------------- src/Interpreter.jl | 57 ++++----------------- src/Reactant.jl | 3 ++ src/utils.jl | 113 +++++++++++++++++++++++++++++++---------- test/cuda.jl | 4 +- 6 files changed, 105 insertions(+), 115 deletions(-) diff --git a/Project.toml b/Project.toml index 7f120f778..a5d243705 100644 --- a/Project.toml +++ b/Project.toml @@ -6,7 +6,6 @@ version = "0.2.10" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" -CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index 84368ed60..fe3279df0 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -437,7 +437,8 @@ function compiler_cache(ctx::MLIR.IR.Context) return cache end -function recufunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT} +Reactant.@overlay function CUDA.cufunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT} + @show "recufunction", f, tt res = Base.@lock CUDA.cufunction_lock begin # compile the function cache = compiler_cache(MLIR.IR.context()) @@ -450,43 +451,4 @@ function recufunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT} res end -const CC = Core.Compiler - -import Core.Compiler: - AbstractInterpreter, - abstract_call, - abstract_call_known, - ArgInfo, - StmtInfo, - AbsIntState, - get_max_methods, - CallMeta, - Effects, - NoCallInfo, - widenconst, - mapany, - MethodResultPure - - -function Reactant.set_reactant_abi( - interp, - f::typeof(CUDA.cufunction), - arginfo::ArgInfo, - si::StmtInfo, - sv::AbsIntState, - max_methods::Int=get_max_methods(interp, f, sv), -) - (; fargs, argtypes) = arginfo - - arginfo2 = ArgInfo( - if fargs isa Nothing - nothing - else - [:($(recufunction)), fargs[2:end]...] - end, - [Core.Const(recufunction), argtypes[2:end]...], - ) - return abstract_call_known(interp, recufunction, arginfo2, si, sv, max_methods) -end - end # module ReactantCUDAExt diff --git a/src/Interpreter.jl b/src/Interpreter.jl index 2efb53792..39986d12b 100644 --- a/src/Interpreter.jl +++ b/src/Interpreter.jl @@ -21,6 +21,15 @@ import Core.Compiler: mapany, MethodResultPure + +Base.Experimental.@MethodTable REACTANT_METHOD_TABLE + +macro overlay(method_expr) + def = splitdef(method_expr) + def[:name] = Expr(:overlay, :(Reactant.REACTANT_METHOD_TABLE), def[:name]) + return esc(combinedef(def)) +end + function set_reactant_abi( interp, @nospecialize(f), @@ -54,50 +63,6 @@ function set_reactant_abi( end end - if length(argtypes) >= 5 && - f === Core.kwcall && - ( - widenconst(argtypes[3]) == typeof(Enzyme.gradient) || - widenconst(argtypes[3]) == typeof(Enzyme.jacobian) - ) && - widenconst(argtypes[4]) <: Enzyme.Mode - newmode = Enzyme.set_abi(widenconst(argtypes[4]), ReactantABI) - if newmode != widenconst(argtypes[4]) - newmodev = newmode() - arginfo2 = ArgInfo( - if fargs isa Nothing - nothing - else - [fargs[1:3]..., :($(newmodev)), fargs[5:end]...] - end, - [argtypes[1:3]..., Core.Const(newmodev), argtypes[5:end]...], - ) - return abstract_call_known(interp, f, arginfo2, si, sv, max_methods) - end - end - - if length(argtypes) >= 5 && - methods(f)[1].module == Enzyme && - widenconst(argtypes[5]) <: Enzyme.Mode && - ( - widenconst(argtypes[4]) == typeof(Enzyme.gradient) || - widenconst(argtypes[4]) == typeof(Enzyme.jacobian) - ) - newmode = Enzyme.set_abi(widenconst(argtypes[5]), ReactantABI) - if newmode != widenconst(argtypes[5]) - newmodev = newmode() - arginfo2 = ArgInfo( - if fargs isa Nothing - nothing - else - [fargs[1:4]..., :($(newmodev)), fargs[6:end]...] - end, - [argtypes[1:4]..., Core.Const(newmodev), argtypes[6:end]...], - ) - return abstract_call_known(interp, f, arginfo2, si, sv, max_methods) - end - end - return Base.@invoke abstract_call_known( interp::AbstractInterpreter, f::Any, @@ -116,7 +81,7 @@ function set_reactant_abi end function ReactantInterpreter(; world::UInt=Base.get_world_counter()) return Enzyme.Compiler.Interpreter.EnzymeInterpreter( ReactantCacheToken(), - nothing, #=mt=# + REACTANT_METHOD_TABLE, world, true, #=forward_rules=# true, #=reverse_rules=# @@ -132,7 +97,7 @@ else ) return Enzyme.Compiler.Interpreter.EnzymeInterpreter( REACTANT_CACHE, - nothing, #=mt=# + REACTANT_METHOD_TABLE, world, true, #=forward_rules=# true, #=forward_rules=# diff --git a/src/Reactant.jl b/src/Reactant.jl index 06fd59aff..1623503df 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -124,4 +124,7 @@ function set_default_backend(backend::String) return set_default_backend(XLA.backends[backend]) end +# include("../ext/ReactantCUDAExt.jl") + end # module + diff --git a/src/utils.jl b/src/utils.jl index b37e00fd1..50da46e3d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -37,6 +37,88 @@ function apply(f, args...; kwargs...) return f(args...; kwargs...) end +function call_with_reactant end + +function rewrite_inst(inst) + @show inst + if Meta.isexpr(inst, :call) + rep = Expr(:call, call_with_reactant, inst.args...) + @show rep + return rep + end + return inst +end + +function call_with_reactant_generator(world::UInt, source::LineNumberNode, @nospecialize(F::Type), @nospecialize(N::Int), self, @nospecialize(f::Type), @nospecialize(args)) + @nospecialize + @show f, args + + stub = Core.GeneratedFunctionStub(identity, Core.svec(:methodinstance, :f, :args), Core.svec()) + + # look up the method match + method_error = :(throw(MethodError(f, args, $world))) + + interp = ReactantInterpreter(; world) + + mt = interp.method_table + + sig = Tuple{F, args...} + min_world = Ref{UInt}(typemin(UInt)) + max_world = Ref{UInt}(typemax(UInt)) + match = ccall(:jl_gf_invoke_lookup_worlds, Any, + (Any, Any, Csize_t, Ref{Csize_t}, Ref{Csize_t}), + sig, mt, world, min_world, max_world) + match === nothing && return stub(world, source, method_error) + + # look up the method and code instance + mi = ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance}, + (Any, Any, Any), match.method, match.spec_types, match.sparams) + + result = Core.Compiler.InferenceResult(mi, Core.Compiler.typeinf_lattice(interp)) + frame = Core.Compiler.InferenceState(result, #=cache_mode=#:global, interp) + @assert frame !== nothing + Core.Compiler.typeinf(interp, frame) + @assert Core.Compiler.is_inferred(frame) + + #if Core.Compiler.result_is_constabi(interp, frame.result) + # rt = frame.result.result::Core.Compiler.Const + # src = Core.Compiler.codeinfo_for_const(interp, frame.linfo, rt.val) + #else + opt = Core.Compiler.OptimizationState(frame, interp) + caller = frame.result + @static if VERSION < v"1.11-" + ir = Core.Compiler.run_passes(opt.src, opt, caller) + else + ir = Core.Compiler.run_passes_ipo_safe(opt.src, opt, caller) + Core.Compiler.ipo_dataflow_analysis!(interp, opt, ir, caller) + end + @show ir + for (i, inst) in enumerate(ir.stmts) + @static if VERSION < v"1.11" + Core.Compiler.setindex!(ir.stmts[i], rewrite_inst(inst[:inst]), :inst) + else + Core.Compiler.setindex!(ir.stmts[i], rewrite_inst(inst[:stmt]), :stmt) + end + end + @show ir + Core.Compiler.finish(interp, opt, ir, caller) + src = Core.Compiler.ir_to_codeinf!(opt) + #end + + new_ci = copy(src) + new_ci.slotnames = Symbol[Symbol("#self#"), :f, :args] + new_ci.edges = Core.MethodInstance[mi] + new_ci.min_world = min_world[] + new_ci.max_world = max_world[] + + return new_ci +end + +@eval function call_with_reactant(f::F, args::Vararg{Any, N}) where {F, N} + $(Expr(:meta, :generated_only)) + $(Expr(:meta, :generated, call_with_reactant_generator)) +end + function make_mlir_fn( f, args, @@ -131,36 +213,13 @@ function make_mlir_fn( interp = ReactantInterpreter() # TODO replace with `Base.invoke_within` if julia#52964 lands - # TODO fix it for kwargs - ircoderes = Base.code_ircode(f, map(typeof, traced_args); interp) - - if length(ircoderes) != 1 - throw( - AssertionError( - "Could not find unique ircode for $f $traced_args, found $ircoderes" - ), - ) - end - ir, ty = ircoderes[1] - oc = Core.OpaqueClosure(ir) + # TODO fix it for kwargs + oc = call_with_reactant # Core.OpaqueClosure(ir) if f === Reactant.apply - oc(traced_args[1], (traced_args[2:end]...,)) + oc(f, traced_args[1], (traced_args[2:end]...,)) else - if (length(traced_args) + 1 != length(ir.argtypes)) || ( - length(traced_args) > 0 && - length(ir.argtypes) > 0 && - !(last(ir.argtypes) isa Core.Const) && - last(ir.argtypes) != typeof(traced_args[end]) - ) - @assert ir.argtypes[end] <: Tuple - oc( - traced_args[1:(length(ir.argtypes) - 2)]..., - (traced_args[(length(ir.argtypes) - 1):end]...,), - ) - else - oc(traced_args...) - end + oc(f, traced_args...) end end diff --git a/test/cuda.jl b/test/cuda.jl index 2475e6275..64a8caba0 100644 --- a/test/cuda.jl +++ b/test/cuda.jl @@ -11,7 +11,9 @@ end # basic squaring on GPU function square!(x) - @cuda blocks = 1 threads = length(x) square_kernel!(x) + # @cuda blocks = 1 threads = length(x) square_kernel!(x) + cr = @cuda launch=false square_kernel!(x) + @show cr return nothing end From da328d348bc823cb57590e1a95ac56ee940f490f Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 9 Dec 2024 02:04:15 -0500 Subject: [PATCH 11/78] inf rec --- ext/ReactantCUDAExt.jl | 2 +- src/Interpreter.jl | 10 +- src/utils.jl | 227 +++++++++++++++++++++++++++++++++-------- 3 files changed, 189 insertions(+), 50 deletions(-) diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index fe3279df0..f396c7ebf 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -437,7 +437,7 @@ function compiler_cache(ctx::MLIR.IR.Context) return cache end -Reactant.@overlay function CUDA.cufunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT} +Reactant.@reactant_override function CUDA.cufunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT} @show "recufunction", f, tt res = Base.@lock CUDA.cufunction_lock begin # compile the function diff --git a/src/Interpreter.jl b/src/Interpreter.jl index 39986d12b..9708a76f9 100644 --- a/src/Interpreter.jl +++ b/src/Interpreter.jl @@ -22,12 +22,12 @@ import Core.Compiler: MethodResultPure -Base.Experimental.@MethodTable REACTANT_METHOD_TABLE +Base.Experimental.@MethodTable(REACTANT_METHOD_TABLE) -macro overlay(method_expr) - def = splitdef(method_expr) - def[:name] = Expr(:overlay, :(Reactant.REACTANT_METHOD_TABLE), def[:name]) - return esc(combinedef(def)) +function var"@reactant_override"(__source__::LineNumberNode, __module__::Module, def) + return Base.Experimental.var"@overlay"( + __source__, __module__, :(Reactant.REACTANT_METHOD_TABLE), def + ) end function set_reactant_abi( diff --git a/src/utils.jl b/src/utils.jl index 50da46e3d..97a6bd799 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -39,43 +39,183 @@ end function call_with_reactant end -function rewrite_inst(inst) - @show inst - if Meta.isexpr(inst, :call) - rep = Expr(:call, call_with_reactant, inst.args...) - @show rep - return rep - end - return inst +# generate a LineInfoNode for the current source code location +macro LineInfoNode(method) + Core.LineInfoNode(__module__, method, __source__.file, Int32(__source__.line), Int32(0)) end -function call_with_reactant_generator(world::UInt, source::LineNumberNode, @nospecialize(F::Type), @nospecialize(N::Int), self, @nospecialize(f::Type), @nospecialize(args)) + +const REDUB_ARGUMENTS_NAME = gensym("redub_arguments") + +function call_with_reactant_generator(world::UInt, source::LineNumberNode, self, @nospecialize(args)) @nospecialize - @show f, args + + @show args - stub = Core.GeneratedFunctionStub(identity, Core.svec(:methodinstance, :f, :args), Core.svec()) + stub = Core.GeneratedFunctionStub(identity, Core.svec(:call_with_reactant, REDUB_ARGUMENTS_NAME), Core.svec()) # look up the method match - method_error = :(throw(MethodError(f, args, $world))) + builtin_error = :(throw(AssertionError("Unsupported call_with_reactant of builtin $args"))) + + if args[1] <: Core.Builtin + return stub(world, source, builtin_error) + end + + method_error = :(throw(MethodError(args[1], args[2:end], $world))) interp = ReactantInterpreter(; world) - mt = interp.method_table + sig = Tuple{args...} + lookup_result = Core.Compiler.findall(sig, Core.Compiler.method_table(interp)).matches + + if lookup_result === nothing || lookup_result === missing + return stub(world, source, method_error) + end + + matches = lookup_result.matches - sig = Tuple{F, args...} - min_world = Ref{UInt}(typemin(UInt)) - max_world = Ref{UInt}(typemax(UInt)) - match = ccall(:jl_gf_invoke_lookup_worlds, Any, - (Any, Any, Csize_t, Ref{Csize_t}, Ref{Csize_t}), - sig, mt, world, min_world, max_world) - match === nothing && return stub(world, source, method_error) + if length(matches) != 1 + return stub(world, source, method_error) + end + match = matches[1]::Core.MethodMatch + # look up the method and code instance mi = ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance}, (Any, Any, Any), match.method, match.spec_types, match.sparams) - + result = Core.Compiler.InferenceResult(mi, Core.Compiler.typeinf_lattice(interp)) - frame = Core.Compiler.InferenceState(result, #=cache_mode=#:global, interp) + src = Core.Compiler.retrieve_code_info(mi, world) + + # prepare a new code info + code_info = copy(src) + method = match.method + static_params = match.sparams + signature = sig + is_invoke = args[1] === typeof(Core.invoke) + + # propagate edge metadata + code_info.edges = Core.MethodInstance[mi] + code_info.min_world = lookup_result.valid_worlds.min_world + code_info.max_world = lookup_result.valid_worlds.max_world + + code_info.slotnames = Any[:call_with_reactant, REDUB_ARGUMENTS_NAME, code_info.slotnames...] + code_info.slotflags = UInt8[0x00, 0x00, code_info.slotflags...] + #code_info.slotnames = Any[:call_with_reactant, REDUB_ARGUMENTS_NAME] #code_info.slotnames...] + #code_info.slotflags = UInt8[0x00, 0x00] # code_info.slotflags...] + n_prepended_slots = 2 + overdub_args_slot = Core.SlotNumber(n_prepended_slots) + + # For the sake of convenience, the rest of this pass will translate `code_info`'s fields + # into these overdubbed equivalents instead of updating `code_info` in-place. Then, at + # the end of the pass, we'll reset `code_info` fields accordingly. + overdubbed_code = Any[] + overdubbed_codelocs = Int32[] + + # destructure the generated argument slots into the overdubbed method's argument slots. + n_actual_args = fieldcount(signature) + n_method_args = Int(method.nargs) + offset = 1 + fn_args = Any[] + for i in 1:n_method_args + if is_invoke && (i == 1 || i == 2) + # With an invoke call, we have: 1 is invoke, 2 is f, 3 is Tuple{}, 4... is args. + # In the first loop iteration, we should skip invoke and process f. + # In the second loop iteration, we should skip the Tuple type and process args[1]. + offset += 1 + end + slot = i + n_prepended_slots + actual_argument = Expr(:call, Core.GlobalRef(Core, :getfield), overdub_args_slot, offset) + push!(overdubbed_code, :($(Core.SlotNumber(slot)) = $actual_argument)) + push!(overdubbed_codelocs, code_info.codelocs[1]) + code_info.slotflags[slot] |= 0x02 # ensure this slotflag has the "assigned" bit set + offset += 1 + + #push!(overdubbed_code, actual_argument) + push!(fn_args, Core.SSAValue(length(overdubbed_code))) + end + + # If `method` is a varargs method, we have to restructure the original method call's + # trailing arguments into a tuple and assign that tuple to the expected argument slot. + if method.isva + if !isempty(overdubbed_code) + # remove the final slot reassignment leftover from the previous destructuring + pop!(overdubbed_code) + pop!(overdubbed_codelocs) + pop!(fn_args) + end + trailing_arguments = Expr(:call, Core.GlobalRef(Core, :tuple)) + for i in n_method_args:n_actual_args + push!(overdubbed_code, Expr(:call, Core.GlobalRef(Core, :getfield), overdub_args_slot, offset - 1)) + push!(overdubbed_codelocs, code_info.codelocs[1]) + push!(trailing_arguments.args, Core.SSAValue(length(overdubbed_code))) + offset += 1 + end + push!(overdubbed_code, Expr(:(=), Core.SlotNumber(n_method_args + n_prepended_slots), trailing_arguments)) + push!(overdubbed_codelocs, code_info.codelocs[1]) + push!(fn_args, Core.SSAValue(length(overdubbed_code))) + end + + #=== finish initialization of `overdubbed_code`/`overdubbed_codelocs` ===# + + # substitute static parameters, offset slot numbers by number of added slots, and + # offset statement indices by the number of additional statements + @show code_info.code + + @show n_prepended_slots + Base.Meta.partially_inline!(code_info.code, fn_args, method.sig, Any[static_params...], + n_prepended_slots, length(overdubbed_code), :propagate) + @show code_info.code + + #callexpr = Expr(:call, Core.OpaqueClosure(ir), fn_args...) + #push!(overdubbed_code, callexpr) + #push!(overdubbed_codelocs, code_info.codelocs[1]) + + #push!(new_ci.code, Core.Compiler.ReturnNode(Core.SSAValue(length(overdubbed_code)))) + #push!(overdubbed_codelocs, code_info.codelocs[1]) + + # original_code_start_index = length(overdubbed_code) + 1 + + append!(overdubbed_code, code_info.code) + append!(overdubbed_codelocs, code_info.codelocs) + + @show overdubbed_code + + for i in eachindex(overdubbed_code) + prev = overdubbed_code[i] + if Base.Meta.isexpr(prev, :call) + @show prev + @show prev.args[1] + @show prev.args[1] isa Core.IntrinsicFunction + if !(prev.args[1] isa Core.IntrinsicFunction) + overdubbed_code[i] = Expr(:call, GlobalRef(Reactant, :call_with_reactant), prev.args...) + @show "post", overdubbed_code[i] + end + end + end + + #=== set `code_info`/`reflection` fields accordingly ===# + + if code_info.method_for_inference_limit_heuristics === nothing + code_info.method_for_inference_limit_heuristics = method + end + + code_info.code = overdubbed_code + code_info.codelocs = overdubbed_codelocs + code_info.ssavaluetypes = length(overdubbed_code) + code_info.ssaflags = [0x00 for _ in 1:length(overdubbed_code)] # XXX we need to copy flags that are set for the original code + self_result = Core.Compiler.InferenceResult(self_mi, Core.Compiler.typeinf_lattice(interp)) + + @show code_info + + @show self + self_meths = Base._methods_by_ftype(Tuple{self, Vararg{Any}}, -1, world) + @show self_meths + self_method = (self_meths[1]::Core.MethodMatch).method + self_mi = Core.Compiler.specialize_method(self_method, Tuple{typeof(Reactant.call_with_reactant), sig.parameters...}, Core.svec()) + @show self_mi + self_result = Core.Compiler.InferenceResult(self_mi, Core.Compiler.typeinf_lattice(interp)) + frame = Core.Compiler.InferenceState(self_result, code_info, #=cache_mode=#:global, interp) @assert frame !== nothing Core.Compiler.typeinf(interp, frame) @assert Core.Compiler.is_inferred(frame) @@ -85,36 +225,37 @@ function call_with_reactant_generator(world::UInt, source::LineNumberNode, @nosp # src = Core.Compiler.codeinfo_for_const(interp, frame.linfo, rt.val) #else opt = Core.Compiler.OptimizationState(frame, interp) + + ir = opt.src + @show ir + for (i, stmt) in enumerate(ir.stmts) + @show stmt + + end + + @show ir + caller = frame.result @static if VERSION < v"1.11-" - ir = Core.Compiler.run_passes(opt.src, opt, caller) + ir = Core.Compiler.run_passes(ir, opt, caller) else - ir = Core.Compiler.run_passes_ipo_safe(opt.src, opt, caller) + ir = Core.Compiler.run_passes_ipo_safe(ir, opt, caller) Core.Compiler.ipo_dataflow_analysis!(interp, opt, ir, caller) end - @show ir - for (i, inst) in enumerate(ir.stmts) - @static if VERSION < v"1.11" - Core.Compiler.setindex!(ir.stmts[i], rewrite_inst(inst[:inst]), :inst) - else - Core.Compiler.setindex!(ir.stmts[i], rewrite_inst(inst[:stmt]), :stmt) - end - end - @show ir Core.Compiler.finish(interp, opt, ir, caller) + src = Core.Compiler.ir_to_codeinf!(opt) #end - new_ci = copy(src) - new_ci.slotnames = Symbol[Symbol("#self#"), :f, :args] - new_ci.edges = Core.MethodInstance[mi] - new_ci.min_world = min_world[] - new_ci.max_world = max_world[] + src = copy(src) + src.ssavaluetypes = length(src.code) - return new_ci + @show src + + return src end -@eval function call_with_reactant(f::F, args::Vararg{Any, N}) where {F, N} +@eval function call_with_reactant($REDUB_ARGUMENTS_NAME...) $(Expr(:meta, :generated_only)) $(Expr(:meta, :generated, call_with_reactant_generator)) end @@ -214,12 +355,10 @@ function make_mlir_fn( # TODO replace with `Base.invoke_within` if julia#52964 lands # TODO fix it for kwargs - oc = call_with_reactant # Core.OpaqueClosure(ir) - if f === Reactant.apply - oc(f, traced_args[1], (traced_args[2:end]...,)) + call_with_reactant(f, traced_args[1], (traced_args[2:end]...,)) else - oc(f, traced_args...) + call_with_reactant(f, traced_args...) end end From ad4d05bc8745b00ea3b615100f7c38661bfd7093 Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 9 Dec 2024 10:43:13 -0500 Subject: [PATCH 12/78] fix --- src/utils.jl | 210 +++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 204 insertions(+), 6 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 97a6bd799..c7ae2567b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -45,14 +45,168 @@ macro LineInfoNode(method) end +function rewrite_inst(inst) + @show inst + if Meta.isexpr(inst, :call) + rep = Expr(:call, call_with_reactant, inst.args...) + @show rep + return rep + end + if Meta.isexpr(inst, :invoke) + return Expr(:call, inst.args[2:end]...) + end + return inst +end + const REDUB_ARGUMENTS_NAME = gensym("redub_arguments") -function call_with_reactant_generator(world::UInt, source::LineNumberNode, self, @nospecialize(args)) +function arg_partially_inline!(code::Vector{Any}, slot_replacements::Vector{Any}, + @nospecialize(type_signature)#=::Type{<:Tuple}=#, + static_param_values::Vector{Any}, + slot_offset::Int, arg_offset::Int, statement_offset::Int, + boundscheck::Symbol) + for i = 1:length(code) + isassigned(code, i) || continue + code[i] = _arg_partially_inline!(code[i], slot_replacements, type_signature, + static_param_values, slot_offset, arg_offset, + statement_offset, boundscheck) + end + return code +end + +function _arg_partially_inline!(@nospecialize(x), slot_replacements::Vector{Any}, + @nospecialize(type_signature), static_param_values::Vector{Any}, + slot_offset::Int, arg_offset::Int, statement_offset::Int, + boundscheck::Symbol) + if isa(x, Core.SSAValue) + return Core.SSAValue(x.id + statement_offset) + end + if isa(x, Core.GotoNode) + return Core.GotoNode(x.label + statement_offset) + end + if isa(x, Core.SlotNumber) + id = x.id + if 1 <= id <= length(slot_replacements) + return slot_replacements[id] + end + return Core.SlotNumber(id + slot_offset) + end + if isa(x, Core.Argument) + return Core.SlotNumber(x.n + arg_offset) + end + if isa(x, Core.NewvarNode) + return Core.NewvarNode(_arg_partially_inline!(x.slot, slot_replacements, type_signature, + static_param_values, slot_offset, arg_offset, + statement_offset, boundscheck)) + end + if isa(x, Core.PhiNode) + arg_partially_inline!(x.values, slot_replacements, type_signature, static_param_values, + slot_offset, arg_offset, statement_offset, boundscheck) + x.edges .+= slot_offset + return x + end + if isa(x, Core.ReturnNode) + return Core.ReturnNode( + _arg_partially_inline!(x.val, slot_replacements, type_signature, static_param_values, + slot_offset, arg_offset, statement_offset, boundscheck), + ) + end + if isa(x, Core.GotoIfNot) + return Core.GotoIfNot( + _arg_partially_inline!(x.cond, slot_replacements, type_signature, static_param_values, + slot_offset, arg_offset, statement_offset, boundscheck), + x.dest + statement_offset, + ) + end + if isdefined(Core, :EnterNode) && isa(x, Core.EnterNode) + return Core.EnterNode(x, x.catch_dest + statement_offset) + end + if isa(x, Expr) + head = x.head + if head === :static_parameter + if isassigned(static_param_values, x.args[1]) + return QuoteNode(static_param_values[x.args[1]]) + end + return x + elseif head === :cfunction + @assert !isa(type_signature, UnionAll) || !isempty(spvals) + if !isa(x.args[2], QuoteNode) # very common no-op + x.args[2] = Core.Compiler._partially_inline!(x.args[2], slot_replacements, type_signature, + static_param_values, slot_offset, arg_offset, + statement_offset, boundscheck) + end + x.args[3] = Core.Compiler._instantiate_type_in_env(x.args[3], type_signature, static_param_values) + x.args[4] = Core.svec(Any[Core.Compiler._instantiate_type_in_env(argt, type_signature, static_param_values) for argt in x.args[4]]...) + elseif head === :foreigncall + @assert !isa(type_signature, UnionAll) || !isempty(static_param_values) + for i = 1:length(x.args) + if i == 2 + x.args[2] = Core.Compiler._instantiate_type_in_env(x.args[2], type_signature, static_param_values) + elseif i == 3 + x.args[3] = Core.svec(Any[Core.Compiler._instantiate_type_in_env(argt, type_signature, static_param_values) for argt in x.args[3]]...) + elseif i == 4 + @assert isa(x.args[4], Int) + elseif i == 5 + @assert isa((x.args[5]::QuoteNode).value, Union{Symbol, Tuple{Symbol, UInt8}}) + else + x.args[i] = _arg_partially_inline!(x.args[i], slot_replacements, + type_signature, static_param_values, + slot_offset, statement_offset, arg_offset, + boundscheck) + end + end + elseif head === :boundscheck + if boundscheck === :propagate + return x + elseif boundscheck === :off + return false + else + return true + end + elseif head === :gotoifnot + x.args[1] = _arg_partially_inline!(x.args[1], slot_replacements, type_signature, + static_param_values, slot_offset, arg_offset, + statement_offset, boundscheck) + x.args[2] += statement_offset + elseif head === :isdefined + arg = x.args[1] + # inlining a QuoteNode or literal into `Expr(:isdefined, x)` is invalid, replace with true + if isa(arg, Core.SlotNumber) + id = arg.id + if 1 <= id <= length(slot_replacements) + replacement = slot_replacements[id] + if isa(replacement, Union{Core.SlotNumber, GlobalRef, Symbol}) + return Expr(:isdefined, replacement) + else + @assert !isa(replacement, Expr) + return true + end + end + return Expr(:isdefined, Core.SlotNumber(id + slot_offset)) + elseif isexpr(arg, :static_parameter) + if isassigned(static_param_values, arg.args[1]) + return true + end + return x + else + @assert isa(arg, Union{GlobalRef, Symbol}) + return x + end + elseif !Core.Compiler.is_meta_expr_head(head) + arg_partially_inline!(x.args, slot_replacements, type_signature, static_param_values, + slot_offset, arg_offset, statement_offset, boundscheck) + end + end + return x +end + +function call_with_reactant_generator(world::UInt, source::LineNumberNode, self, @nospecialize(redub_arguments)) @nospecialize + args = redub_arguments @show args - stub = Core.GeneratedFunctionStub(identity, Core.svec(:call_with_reactant, REDUB_ARGUMENTS_NAME), Core.svec()) + stub = Core.GeneratedFunctionStub(identity, Core.svec(:call_with_reactant, :redub_arguments), Core.svec()) # look up the method match builtin_error = :(throw(AssertionError("Unsupported call_with_reactant of builtin $args"))) @@ -85,7 +239,40 @@ function call_with_reactant_generator(world::UInt, source::LineNumberNode, self, (Any, Any, Any), match.method, match.spec_types, match.sparams) result = Core.Compiler.InferenceResult(mi, Core.Compiler.typeinf_lattice(interp)) - src = Core.Compiler.retrieve_code_info(mi, world) + @static if true + frame = Core.Compiler.InferenceState(result, #=cache_mode=#:local, interp) + @assert frame !== nothing + Core.Compiler.typeinf(interp, frame) + @assert Core.Compiler.is_inferred(frame) + + #if Core.Compiler.result_is_constabi(interp, frame.result) + # rt = frame.result.result::Core.Compiler.Const + # src = Core.Compiler.codeinfo_for_const(interp, frame.linfo, rt.val) + #else + opt = Core.Compiler.OptimizationState(frame, interp) + + caller = frame.result + @static if VERSION < v"1.11-" + ir = Core.Compiler.run_passes(opt.src, opt, caller) + else + ir = Core.Compiler.run_passes_ipo_safe(opt.src, opt, caller) + Core.Compiler.ipo_dataflow_analysis!(interp, opt, ir, caller) + end + + for (i, inst) in enumerate(ir.stmts) + @static if VERSION < v"1.11" + Core.Compiler.setindex!(ir.stmts[i], rewrite_inst(inst[:inst]), :inst) + else + Core.Compiler.setindex!(ir.stmts[i], rewrite_inst(inst[:stmt]), :stmt) + end + end + Core.Compiler.finish(interp, opt, ir, caller) + + src = Core.Compiler.ir_to_codeinf!(opt) + #end + else + src = Core.Compiler.retrieve_code_info(mi, world) + end # prepare a new code info code_info = copy(src) @@ -99,7 +286,7 @@ function call_with_reactant_generator(world::UInt, source::LineNumberNode, self, code_info.min_world = lookup_result.valid_worlds.min_world code_info.max_world = lookup_result.valid_worlds.max_world - code_info.slotnames = Any[:call_with_reactant, REDUB_ARGUMENTS_NAME, code_info.slotnames...] + code_info.slotnames = Any[:call_with_reactant, :redub_arguments, code_info.slotnames...] code_info.slotflags = UInt8[0x00, 0x00, code_info.slotflags...] #code_info.slotnames = Any[:call_with_reactant, REDUB_ARGUMENTS_NAME] #code_info.slotnames...] #code_info.slotflags = UInt8[0x00, 0x00] # code_info.slotflags...] @@ -163,8 +350,13 @@ function call_with_reactant_generator(world::UInt, source::LineNumberNode, self, @show code_info.code @show n_prepended_slots + @static if false Base.Meta.partially_inline!(code_info.code, fn_args, method.sig, Any[static_params...], n_prepended_slots, length(overdubbed_code), :propagate) + else + arg_partially_inline!(code_info.code, fn_args, method.sig, Any[static_params...], + n_prepended_slots, n_prepended_slots, length(overdubbed_code), :propagate) + end @show code_info.code #callexpr = Expr(:call, Core.OpaqueClosure(ir), fn_args...) @@ -181,6 +373,7 @@ function call_with_reactant_generator(world::UInt, source::LineNumberNode, self, @show overdubbed_code + @static if false for i in eachindex(overdubbed_code) prev = overdubbed_code[i] if Base.Meta.isexpr(prev, :call) @@ -193,6 +386,7 @@ function call_with_reactant_generator(world::UInt, source::LineNumberNode, self, end end end + end #=== set `code_info`/`reflection` fields accordingly ===# @@ -204,9 +398,12 @@ function call_with_reactant_generator(world::UInt, source::LineNumberNode, self, code_info.codelocs = overdubbed_codelocs code_info.ssavaluetypes = length(overdubbed_code) code_info.ssaflags = [0x00 for _ in 1:length(overdubbed_code)] # XXX we need to copy flags that are set for the original code - self_result = Core.Compiler.InferenceResult(self_mi, Core.Compiler.typeinf_lattice(interp)) @show code_info + return code_info + + self_result = Core.Compiler.InferenceResult(self_mi, Core.Compiler.typeinf_lattice(interp)) + @show self self_meths = Base._methods_by_ftype(Tuple{self, Vararg{Any}}, -1, world) @@ -241,6 +438,7 @@ function call_with_reactant_generator(world::UInt, source::LineNumberNode, self, else ir = Core.Compiler.run_passes_ipo_safe(ir, opt, caller) Core.Compiler.ipo_dataflow_analysis!(interp, opt, ir, caller) + end Core.Compiler.finish(interp, opt, ir, caller) @@ -255,7 +453,7 @@ function call_with_reactant_generator(world::UInt, source::LineNumberNode, self, return src end -@eval function call_with_reactant($REDUB_ARGUMENTS_NAME...) +@eval function call_with_reactant(redub_arguments...) $(Expr(:meta, :generated_only)) $(Expr(:meta, :generated, call_with_reactant_generator)) end From 750661a9721ce21412d374487d3837602d6b3aea Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 9 Dec 2024 17:32:05 -0500 Subject: [PATCH 13/78] overload working --- ext/ReactantCUDAExt.jl | 32 +++++------ src/utils.jl | 117 +++++++++-------------------------------- 2 files changed, 39 insertions(+), 110 deletions(-) diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index f396c7ebf..c8021975f 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -206,6 +206,13 @@ end const _kernel_instances = Dict{Any, Any}() +struct LLVMFunc{F,tt} + f::Union{F, Nothing} + mod::String + image + entry::String +end + # compile to executable machine code function compile(job) @@ -218,8 +225,8 @@ function compile(job) @show mod @show modstr # check if we'll need the device runtime - undefined_fs = filter(collect(functions(meta.ir))) do f - isdeclaration(f) && !CUDA.LLVM.isintrinsic(f) + undefined_fs = filter(collect(CUDA.LLVM.functions(meta.ir))) do f + CUDA.LLVM.isdeclaration(f) && !CUDA.LLVM.isintrinsic(f) end intrinsic_fns = ["vprintf", "malloc", "free", "__assertfail", "__nvvm_reflect" #= TODO: should have been optimized away =#] @@ -246,7 +253,7 @@ function compile(job) # validate use of parameter memory argtypes = filter([CUDA.KernelState, job.source.specTypes.parameters...]) do dt - !isghosttype(dt) && !Core.Compiler.isconstType(dt) + !CUDA.isghosttype(dt) && !Core.Compiler.isconstType(dt) end param_usage = sum(sizeof, argtypes) param_limit = 4096 @@ -268,7 +275,7 @@ function compile(job) end for (i, typ) in enumerate(source_types) - if isghosttype(typ) || Core.Compiler.isconstType(typ) + if CUDA.isghosttype(typ) || Core.Compiler.isconstType(typ) continue end name = source_argnames[i] @@ -306,7 +313,7 @@ function compile(job) "--output-file", ptxas_output, ptx_input ]) - proc, log = CUDA.run_and_collect(`$(ptxas()) $ptxas_opts`) + proc, log = CUDA.run_and_collect(`$(CUDA.ptxas()) $ptxas_opts`) log = strip(log) if !success(proc) reason = proc.termsignal > 0 ? "ptxas received signal $(proc.termsignal)" : @@ -342,7 +349,7 @@ function compile(job) "--output-file", nvlink_output, ptxas_output ]) - proc, log = run_and_collect(`$(nvlink()) $nvlink_opts`) + proc, log = run_and_collect(`$(CUDA.nvlink()) $nvlink_opts`) log = strip(log) if !success(proc) reason = proc.termsignal > 0 ? "nvlink received signal $(proc.termsignal)" : @@ -369,11 +376,7 @@ function compile(job) modstr, image, meta.entry end - println(string(modstr)) - @show job - @show job.source - @show job.config - LLVMFunc{F,job.source.specTypes}(f, modstr, image, LLVM.name(entry)) + LLVMFunc{job.source.specTypes[1],job.source.specTypes}(nothing, modstr, image, LLVM.name(entry)) end # link into an executable kernel @@ -382,13 +385,6 @@ function link(job, compiled) return compiled end -struct LLVMFunc{F,tt} - f::F - mod::String - image - entry::String -end - function (func::LLVMFunc{F,tt})(args...; blocks::CUDA.CuDim=1, threads::CUDA.CuDim=1, shmem::Integer=0) where{F, tt} blockdim = CUDA.CuDim3(blocks) diff --git a/src/utils.jl b/src/utils.jl index c7ae2567b..704051649 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -45,12 +45,26 @@ macro LineInfoNode(method) end -function rewrite_inst(inst) - @show inst + +function maybe_argextype( + @nospecialize(x), + src, +) + return try + Core.Compiler.argextype(x, src) + catch err + !(err isa Core.Compiler.InvalidIRError) && rethrow() + nothing + end +end + +function rewrite_inst(inst, ir) if Meta.isexpr(inst, :call) - rep = Expr(:call, call_with_reactant, inst.args...) - @show rep - return rep + ft = Core.Compiler.widenconst(maybe_argextype(inst.args[1], ir)) + if !(ft <: Core.IntrinsicFunction) && !(ft <: Core.Builtin) + rep = Expr(:call, call_with_reactant, inst.args...) + return rep + end end if Meta.isexpr(inst, :invoke) return Expr(:call, inst.args[2:end]...) @@ -204,12 +218,11 @@ function call_with_reactant_generator(world::UInt, source::LineNumberNode, self, @nospecialize args = redub_arguments - @show args stub = Core.GeneratedFunctionStub(identity, Core.svec(:call_with_reactant, :redub_arguments), Core.svec()) # look up the method match - builtin_error = :(throw(AssertionError("Unsupported call_with_reactant of builtin $args"))) + builtin_error = :(throw(AssertionError("Unsupported call_with_reactant of builtin $redub_arguments"))) if args[1] <: Core.Builtin return stub(world, source, builtin_error) @@ -218,7 +231,7 @@ function call_with_reactant_generator(world::UInt, source::LineNumberNode, self, method_error = :(throw(MethodError(args[1], args[2:end], $world))) interp = ReactantInterpreter(; world) - + sig = Tuple{args...} lookup_result = Core.Compiler.findall(sig, Core.Compiler.method_table(interp)).matches @@ -239,7 +252,6 @@ function call_with_reactant_generator(world::UInt, source::LineNumberNode, self, (Any, Any, Any), match.method, match.spec_types, match.sparams) result = Core.Compiler.InferenceResult(mi, Core.Compiler.typeinf_lattice(interp)) - @static if true frame = Core.Compiler.InferenceState(result, #=cache_mode=#:local, interp) @assert frame !== nothing Core.Compiler.typeinf(interp, frame) @@ -260,19 +272,16 @@ function call_with_reactant_generator(world::UInt, source::LineNumberNode, self, end for (i, inst) in enumerate(ir.stmts) + @static if VERSION < v"1.11" - Core.Compiler.setindex!(ir.stmts[i], rewrite_inst(inst[:inst]), :inst) + Core.Compiler.setindex!(ir.stmts[i], rewrite_inst(inst[:inst], ir), :inst) else - Core.Compiler.setindex!(ir.stmts[i], rewrite_inst(inst[:stmt]), :stmt) + Core.Compiler.setindex!(ir.stmts[i], rewrite_inst(inst[:stmt], ir), :stmt) end + Core.Compiler.setindex!(ir.stmts[i], Any, :type) end Core.Compiler.finish(interp, opt, ir, caller) - src = Core.Compiler.ir_to_codeinf!(opt) - #end - else - src = Core.Compiler.retrieve_code_info(mi, world) - end # prepare a new code info code_info = copy(src) @@ -347,17 +356,9 @@ function call_with_reactant_generator(world::UInt, source::LineNumberNode, self, # substitute static parameters, offset slot numbers by number of added slots, and # offset statement indices by the number of additional statements - @show code_info.code - @show n_prepended_slots - @static if false - Base.Meta.partially_inline!(code_info.code, fn_args, method.sig, Any[static_params...], - n_prepended_slots, length(overdubbed_code), :propagate) - else arg_partially_inline!(code_info.code, fn_args, method.sig, Any[static_params...], n_prepended_slots, n_prepended_slots, length(overdubbed_code), :propagate) - end - @show code_info.code #callexpr = Expr(:call, Core.OpaqueClosure(ir), fn_args...) #push!(overdubbed_code, callexpr) @@ -371,23 +372,6 @@ function call_with_reactant_generator(world::UInt, source::LineNumberNode, self, append!(overdubbed_code, code_info.code) append!(overdubbed_codelocs, code_info.codelocs) - @show overdubbed_code - - @static if false - for i in eachindex(overdubbed_code) - prev = overdubbed_code[i] - if Base.Meta.isexpr(prev, :call) - @show prev - @show prev.args[1] - @show prev.args[1] isa Core.IntrinsicFunction - if !(prev.args[1] isa Core.IntrinsicFunction) - overdubbed_code[i] = Expr(:call, GlobalRef(Reactant, :call_with_reactant), prev.args...) - @show "post", overdubbed_code[i] - end - end - end - end - #=== set `code_info`/`reflection` fields accordingly ===# if code_info.method_for_inference_limit_heuristics === nothing @@ -399,58 +383,7 @@ function call_with_reactant_generator(world::UInt, source::LineNumberNode, self, code_info.ssavaluetypes = length(overdubbed_code) code_info.ssaflags = [0x00 for _ in 1:length(overdubbed_code)] # XXX we need to copy flags that are set for the original code - @show code_info return code_info - - self_result = Core.Compiler.InferenceResult(self_mi, Core.Compiler.typeinf_lattice(interp)) - - - @show self - self_meths = Base._methods_by_ftype(Tuple{self, Vararg{Any}}, -1, world) - @show self_meths - self_method = (self_meths[1]::Core.MethodMatch).method - self_mi = Core.Compiler.specialize_method(self_method, Tuple{typeof(Reactant.call_with_reactant), sig.parameters...}, Core.svec()) - @show self_mi - self_result = Core.Compiler.InferenceResult(self_mi, Core.Compiler.typeinf_lattice(interp)) - frame = Core.Compiler.InferenceState(self_result, code_info, #=cache_mode=#:global, interp) - @assert frame !== nothing - Core.Compiler.typeinf(interp, frame) - @assert Core.Compiler.is_inferred(frame) - - #if Core.Compiler.result_is_constabi(interp, frame.result) - # rt = frame.result.result::Core.Compiler.Const - # src = Core.Compiler.codeinfo_for_const(interp, frame.linfo, rt.val) - #else - opt = Core.Compiler.OptimizationState(frame, interp) - - ir = opt.src - @show ir - for (i, stmt) in enumerate(ir.stmts) - @show stmt - - end - - @show ir - - caller = frame.result - @static if VERSION < v"1.11-" - ir = Core.Compiler.run_passes(ir, opt, caller) - else - ir = Core.Compiler.run_passes_ipo_safe(ir, opt, caller) - Core.Compiler.ipo_dataflow_analysis!(interp, opt, ir, caller) - - end - Core.Compiler.finish(interp, opt, ir, caller) - - src = Core.Compiler.ir_to_codeinf!(opt) - #end - - src = copy(src) - src.ssavaluetypes = length(src.code) - - @show src - - return src end @eval function call_with_reactant(redub_arguments...) From 12dec6c37426df95fbb2744224b10c570686b264 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 10 Dec 2024 00:28:31 -0500 Subject: [PATCH 14/78] continuing --- ext/ReactantCUDAExt.jl | 25 ++++++++++++---------- src/utils.jl | 48 +++++++++++++++++++++++++++++++++--------- test/cuda.jl | 7 +++--- 3 files changed, 55 insertions(+), 25 deletions(-) diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index c8021975f..1b1feabf2 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -222,8 +222,6 @@ function compile(job) asm, meta = CUDA.GPUCompiler.compile(:asm, job) mod = meta.ir modstr = string(mod) - @show mod - @show modstr # check if we'll need the device runtime undefined_fs = filter(collect(CUDA.LLVM.functions(meta.ir))) do f CUDA.LLVM.isdeclaration(f) && !CUDA.LLVM.isintrinsic(f) @@ -375,8 +373,7 @@ function compile(job) modstr, image, meta.entry end - - LLVMFunc{job.source.specTypes[1],job.source.specTypes}(nothing, modstr, image, LLVM.name(entry)) + LLVMFunc{job.source.specTypes.parameters[1],job.source.specTypes}(nothing, modstr, image, CUDA.LLVM.name(entry)) end # link into an executable kernel @@ -385,20 +382,23 @@ function link(job, compiled) return compiled end -function (func::LLVMFunc{F,tt})(args...; blocks::CUDA.CuDim=1, threads::CUDA.CuDim=1, - shmem::Integer=0) where{F, tt} +function (func::LLVMFunc{F,tt})(args...; convert=Val(false), blocks::CuDim=1, threads::CuDim=1, + cooperative::Bool=false, shmem::Integer=0, call_kwargs...) where{F, tt} + @show args + @show call_kwargs + blockdim = CUDA.CuDim3(blocks) threaddim = CUDA.CuDim3(threads) - @show args - mlir_args = MLIR.IR.Value[] restys = MLIR.IR.Type[] aliases = MLIR.API.MlirAttribute[] + rarrays = TracedRArray[] for (i, a) in enumerate(args) @show a - @assert a isa CuDeviceArray - ta = Base.pointer_to_objref(a.ptr)::TracedRArray + @assert a isa CuTracedArray + ta = Base.unsafe_pointer_to_objref(Base.reinterpret(Ptr{Cvoid}, a.ptr))::TracedRArray + push!(rarrays, ta) arg = ta.mlir_data arg = Reactant.Compiler.transpose_val(arg) push!(restys, MLIR.IR.Type(arg)) @@ -415,7 +415,10 @@ function (func::LLVMFunc{F,tt})(args...; blocks::CUDA.CuDim=1, threads::CUDA.CuD end output_operand_aliases=MLIR.ArrayAttr.get(MLIR.IR.context(), aliases) - MLIR.IR.Dialects.stablehlo.custom_call(mlir_args; result_0=restys, call_target_name="reactant_gpu_call", output_operand_aliases) + call = MLIR.IR.Dialects.stablehlo.custom_call(mlir_args; result_0=restys, call_target_name="reactant_gpu_call", output_operand_aliases) + for (i, res) in enumerate(rarrays) + ta.mlir_data = Reactant.Compiler.transpose_val(MLIR.IR.result(call, i-1)) + end #CUDA.cuLaunchKernel(f, # blockdim.x, blockdim.y, blockdim.z, # threaddim.x, threaddim.y, threaddim.z, diff --git a/src/utils.jl b/src/utils.jl index 704051649..759c2a4b0 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -63,13 +63,13 @@ function rewrite_inst(inst, ir) ft = Core.Compiler.widenconst(maybe_argextype(inst.args[1], ir)) if !(ft <: Core.IntrinsicFunction) && !(ft <: Core.Builtin) rep = Expr(:call, call_with_reactant, inst.args...) - return rep + return true, rep end end if Meta.isexpr(inst, :invoke) - return Expr(:call, inst.args[2:end]...) + return false, Expr(:call, inst.args[2:end]...) end - return inst + return false, inst end const REDUB_ARGUMENTS_NAME = gensym("redub_arguments") @@ -120,10 +120,14 @@ function _arg_partially_inline!(@nospecialize(x), slot_replacements::Vector{Any} return x end if isa(x, Core.ReturnNode) - return Core.ReturnNode( + if !isdefined(x, :val) + return Core.ReturnNode(:nothing) + else + return Core.ReturnNode( _arg_partially_inline!(x.val, slot_replacements, type_signature, static_param_values, slot_offset, arg_offset, statement_offset, boundscheck), - ) + ) + end end if isa(x, Core.GotoIfNot) return Core.GotoIfNot( @@ -257,12 +261,19 @@ function call_with_reactant_generator(world::UInt, source::LineNumberNode, self, Core.Compiler.typeinf(interp, frame) @assert Core.Compiler.is_inferred(frame) + method = match.method + @show mi + @show method + #if Core.Compiler.result_is_constabi(interp, frame.result) # rt = frame.result.result::Core.Compiler.Const # src = Core.Compiler.codeinfo_for_const(interp, frame.linfo, rt.val) #else opt = Core.Compiler.OptimizationState(frame, interp) + @show Core.Compiler.retrieve_code_info(mi, world) + @show opt.src + caller = frame.result @static if VERSION < v"1.11-" ir = Core.Compiler.run_passes(opt.src, opt, caller) @@ -271,21 +282,35 @@ function call_with_reactant_generator(world::UInt, source::LineNumberNode, self, Core.Compiler.ipo_dataflow_analysis!(interp, opt, ir, caller) end - for (i, inst) in enumerate(ir.stmts) + @show ir + any_changed = false + for (i, inst) in enumerate(ir.stmts) + @static if VERSION < v"1.11" - Core.Compiler.setindex!(ir.stmts[i], rewrite_inst(inst[:inst], ir), :inst) + changed, next = rewrite_inst(inst[:inst], ir) + Core.Compiler.setindex!(ir.stmts[i], next, :inst) else - Core.Compiler.setindex!(ir.stmts[i], rewrite_inst(inst[:stmt], ir), :stmt) + changed, next = rewrite_inst(inst[:stmt], ir) + Core.Compiler.setindex!(ir.stmts[i], next, :stmt) end - Core.Compiler.setindex!(ir.stmts[i], Any, :type) + if changed + any_changed = true + Core.Compiler.setindex!(ir.stmts[i], Any, :type) + end end Core.Compiler.finish(interp, opt, ir, caller) + @show "post", ir src = Core.Compiler.ir_to_codeinf!(opt) + + @show any_changed, src + if !any_changed + src = Core.Compiler.retrieve_code_info(mi, world) + @show "post non change", src + end # prepare a new code info code_info = copy(src) - method = match.method static_params = match.sparams signature = sig is_invoke = args[1] === typeof(Core.invoke) @@ -352,6 +377,7 @@ function call_with_reactant_generator(world::UInt, source::LineNumberNode, self, push!(fn_args, Core.SSAValue(length(overdubbed_code))) end + @show code_info.code #=== finish initialization of `overdubbed_code`/`overdubbed_codelocs` ===# # substitute static parameters, offset slot numbers by number of added slots, and @@ -383,6 +409,8 @@ function call_with_reactant_generator(world::UInt, source::LineNumberNode, self, code_info.ssavaluetypes = length(overdubbed_code) code_info.ssaflags = [0x00 for _ in 1:length(overdubbed_code)] # XXX we need to copy flags that are set for the original code + @show code_info + return code_info end diff --git a/test/cuda.jl b/test/cuda.jl index 64a8caba0..a02f45eec 100644 --- a/test/cuda.jl +++ b/test/cuda.jl @@ -11,15 +11,14 @@ end # basic squaring on GPU function square!(x) - # @cuda blocks = 1 threads = length(x) square_kernel!(x) - cr = @cuda launch=false square_kernel!(x) - @show cr + @cuda blocks = 1 threads = length(x) square_kernel!(x) return nothing end @testset "Square Kernel" begin oA = collect(1:1:64) A = Reactant.to_rarray(oA) + @show @code_hlo square!(A) func = @compile square!(A) - @test all(A .≈ (oA .* oA)) + @test all(Array(A) .≈ (oA .* oA)) end From a6cd1044f3aeb6985ef665c8b17ba21543b54065 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 10 Dec 2024 02:09:25 -0500 Subject: [PATCH 15/78] continuing --- ext/ReactantCUDAExt.jl | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index 1b1feabf2..82a2c8514 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -382,6 +382,13 @@ function link(job, compiled) return compiled end +function transpose_val(val) + attr = MLIR.IR.DenseArrayAttribute( + Int64[reverse(0:(length(size(MLIR.IR.type(val))) - 1))...] + ) + return MLIR.IR.result(MLIR.Dialects.stablehlo.transpose(val; permutation=attr), 1) +end + function (func::LLVMFunc{F,tt})(args...; convert=Val(false), blocks::CuDim=1, threads::CuDim=1, cooperative::Bool=false, shmem::Integer=0, call_kwargs...) where{F, tt} @show args @@ -392,7 +399,7 @@ function (func::LLVMFunc{F,tt})(args...; convert=Val(false), blocks::CuDim=1, th mlir_args = MLIR.IR.Value[] restys = MLIR.IR.Type[] - aliases = MLIR.API.MlirAttribute[] + aliases = MLIR.IR.Attribute[] rarrays = TracedRArray[] for (i, a) in enumerate(args) @show a @@ -400,24 +407,25 @@ function (func::LLVMFunc{F,tt})(args...; convert=Val(false), blocks::CuDim=1, th ta = Base.unsafe_pointer_to_objref(Base.reinterpret(Ptr{Cvoid}, a.ptr))::TracedRArray push!(rarrays, ta) arg = ta.mlir_data - arg = Reactant.Compiler.transpose_val(arg) - push!(restys, MLIR.IR.Type(arg)) + arg = transpose_val(arg) + @show arg + push!(restys, MLIR.IR.type(arg)) push!(aliases, - MLIR.IR.Dialects.stablehlo.stablehloOutputOperandAliasGet( + MLIR.IR.Attribute(MLIR.API.stablehloOutputOperandAliasGet( MLIR.IR.context(), - len(args) == 1 ? 0 : 1, - len(args) == 1 ? C_NULL : Ref{Int64}(i-1), + length(args) == 1 ? 0 : 1, + length(args) == 1 ? C_NULL : Ref{Int64}(i-1), i-1, 0, C_NULL - ) + )) ) end - output_operand_aliases=MLIR.ArrayAttr.get(MLIR.IR.context(), aliases) - call = MLIR.IR.Dialects.stablehlo.custom_call(mlir_args; result_0=restys, call_target_name="reactant_gpu_call", output_operand_aliases) + output_operand_aliases=MLIR.IR.Attribute(aliases) + call = MLIR.Dialects.stablehlo.custom_call(mlir_args; result_0=restys, call_target_name="reactant_gpu_call", output_operand_aliases) for (i, res) in enumerate(rarrays) - ta.mlir_data = Reactant.Compiler.transpose_val(MLIR.IR.result(call, i-1)) + res.mlir_data = transpose_val(MLIR.IR.result(call, i)) end #CUDA.cuLaunchKernel(f, # blockdim.x, blockdim.y, blockdim.z, From db6e37b1c99cecaf64034ff2875a3400808179e2 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 10 Dec 2024 02:19:15 -0500 Subject: [PATCH 16/78] push --- ext/ReactantCUDAExt.jl | 1 + test/cuda.jl | 1 + 2 files changed, 2 insertions(+) diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index 82a2c8514..c369ebdd4 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -410,6 +410,7 @@ function (func::LLVMFunc{F,tt})(args...; convert=Val(false), blocks::CuDim=1, th arg = transpose_val(arg) @show arg push!(restys, MLIR.IR.type(arg)) + push!(mlir_args, arg) push!(aliases, MLIR.IR.Attribute(MLIR.API.stablehloOutputOperandAliasGet( MLIR.IR.context(), diff --git a/test/cuda.jl b/test/cuda.jl index a02f45eec..ae1b473f6 100644 --- a/test/cuda.jl +++ b/test/cuda.jl @@ -18,6 +18,7 @@ end @testset "Square Kernel" begin oA = collect(1:1:64) A = Reactant.to_rarray(oA) + @show @code_hlo optimize=false square!(A) @show @code_hlo square!(A) func = @compile square!(A) @test all(Array(A) .≈ (oA .* oA)) From 8831f4d6940431c60a4595e72e5519cfb90f58c4 Mon Sep 17 00:00:00 2001 From: jumerckx <31353884+jumerckx@users.noreply.github.com> Date: Tue, 10 Dec 2024 17:51:24 +0100 Subject: [PATCH 17/78] fix `call_with_reactant_generator` for Julia 1.11 (#359) --- src/utils.jl | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 759c2a4b0..db1de4505 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -237,7 +237,10 @@ function call_with_reactant_generator(world::UInt, source::LineNumberNode, self, interp = ReactantInterpreter(; world) sig = Tuple{args...} - lookup_result = Core.Compiler.findall(sig, Core.Compiler.method_table(interp)).matches + lookup_result = Core.Compiler.findall(sig, Core.Compiler.method_table(interp)) + @static if VERSION < v"1.11-" + lookup_result = lookup_result.matches + end if lookup_result === nothing || lookup_result === missing return stub(world, source, method_error) @@ -259,6 +262,10 @@ function call_with_reactant_generator(world::UInt, source::LineNumberNode, self, frame = Core.Compiler.InferenceState(result, #=cache_mode=#:local, interp) @assert frame !== nothing Core.Compiler.typeinf(interp, frame) + @static if VERSION >= v"1.11" + # `typeinf` doesn't update the cfg. We need to do it manually. + frame.cfg = Core.Compiler.compute_basic_blocks(frame.src.code) + end @assert Core.Compiler.is_inferred(frame) method = match.method @@ -279,7 +286,7 @@ function call_with_reactant_generator(world::UInt, source::LineNumberNode, self, ir = Core.Compiler.run_passes(opt.src, opt, caller) else ir = Core.Compiler.run_passes_ipo_safe(opt.src, opt, caller) - Core.Compiler.ipo_dataflow_analysis!(interp, opt, ir, caller) + Core.Compiler.ipo_dataflow_analysis!(interp, ir, caller) end @show ir From e2ffe874a16a8f750c3a326f2c12b69657634691 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 10 Dec 2024 15:25:18 -0500 Subject: [PATCH 18/78] conversion --- deps/ReactantExtra/API.cpp | 8 ++++++++ ext/ReactantCUDAExt.jl | 15 ++++++++++++++- 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index 8614bdcd9..622dde4c4 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -376,6 +376,14 @@ extern "C" MlirModule ConvertLLVMToMLIR(LLVMModuleRef lmod, MlirContext cctx) { return wrap(res); } +extern "C" MlirModule ConvertLLVMStrToMLIR(const char* lmod, MlirContext cctx) { + LLVMContext Context; + auto llvmModule = llvm::parseIR(llvm::MemoryBufferRef(lmod, "conversion"), Context); + mlir::MLIRContext &context = *unwrap(cctx); + auto res = mlir::translateLLVMIRToModule(std::move(llvmModule), &context, /*emitExpensiveWarnings*/false, /*dropDICompositeElements*/false).release(); + return wrap(res); +} + /* Note that this */ extern "C" xla::PjRtLoadedExecutable* ClientCompile(PjRtClient * client, MlirModule cmod) { diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index c369ebdd4..a31c07a19 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -221,7 +221,15 @@ function compile(job) modstr, image, entry = CUDA.GPUCompiler.JuliaContext() do ctx asm, meta = CUDA.GPUCompiler.compile(:asm, job) mod = meta.ir + modstr = string(mod) + + # This is a bit weird since we're taking a module from julia's llvm into reactant's llvm version + # it is probably safer to reparse a string using the right llvm module api, so we will do that. + + mmod = MLIR.IR.Module(@ccall MLIR.API.mlir_c.ConvertLLVMToMLIR(mod::CUDA.LLVM.API.LLVMModuleRef, MLIR.IR.context()::MLIR.API.MlirContext)::MLIR.API.MlirModule) + @show mmod + # check if we'll need the device runtime undefined_fs = filter(collect(CUDA.LLVM.functions(meta.ir))) do f CUDA.LLVM.isdeclaration(f) && !CUDA.LLVM.isintrinsic(f) @@ -424,7 +432,8 @@ function (func::LLVMFunc{F,tt})(args...; convert=Val(false), blocks::CuDim=1, th end output_operand_aliases=MLIR.IR.Attribute(aliases) - call = MLIR.Dialects.stablehlo.custom_call(mlir_args; result_0=restys, call_target_name="reactant_gpu_call", output_operand_aliases) + call = MLIR.Dialects.stablehlo.custom_call(mlir_args; result_0=restys, call_target_name="reactant_gpu_call", output_operand_aliases, backend_config=MLIR.IR.Attribute("configstr")) + # call = MLIR.Dialects.stablehlo.custom_call(mlir_args; result_0=restys, call_target_name="reactant_gpu_call", output_operand_aliases, backend_config=MLIR.IR.Attribute(func.mod)) for (i, res) in enumerate(rarrays) res.mlir_data = transpose_val(MLIR.IR.result(call, i)) end @@ -459,4 +468,8 @@ Reactant.@reactant_override function CUDA.cufunction(f::F, tt::TT=Tuple{}; kwarg res end +function __init__() + +end + end # module ReactantCUDAExt From 364823afc05b1358ce729b2e99af5a75021a1512 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Tue, 10 Dec 2024 20:27:14 -0500 Subject: [PATCH 19/78] continuing --- deps/ReactantExtra/API.cpp | 4 +- deps/ReactantExtra/BUILD | 2 + ext/ReactantCUDAExt.jl | 92 +++++++++++++++++++++++++++++++++++--- 3 files changed, 90 insertions(+), 8 deletions(-) diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index 622dde4c4..f93b32ea4 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -376,9 +376,11 @@ extern "C" MlirModule ConvertLLVMToMLIR(LLVMModuleRef lmod, MlirContext cctx) { return wrap(res); } +#include "llvm/IRReader/IRReader.h" extern "C" MlirModule ConvertLLVMStrToMLIR(const char* lmod, MlirContext cctx) { LLVMContext Context; - auto llvmModule = llvm::parseIR(llvm::MemoryBufferRef(lmod, "conversion"), Context); + SMDiagnostic Err; + auto llvmModule = llvm::parseIR(llvm::MemoryBufferRef(lmod, "conversion"), Err, Context); mlir::MLIRContext &context = *unwrap(cctx); auto res = mlir::translateLLVMIRToModule(std::move(llvmModule), &context, /*emitExpensiveWarnings*/false, /*dropDICompositeElements*/false).release(); return wrap(res); diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index e7157d89c..c718304bd 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -450,6 +450,8 @@ cc_library( "@llvm-project//mlir:SCFDialect", "@llvm-project//mlir:TransformDialect", "@llvm-project//mlir:Transforms", + + "@llvm-project//llvm:IRReader", "@llvm-project//llvm:Support", "@llvm-project//llvm:AArch64AsmParser", "@llvm-project//llvm:AArch64CodeGen", diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index a31c07a19..ad13922f3 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -214,20 +214,88 @@ struct LLVMFunc{F,tt} end +const GPUCompiler = CUDA.GPUCompiler +const LLVM = GPUCompiler.LLVM + + +GPULowerCPUFeaturesPass() = LLVM.NewPMModulePass("GPULowerCPUFeatures", GPUCompiler.cpu_features!) +GPULowerPTLSPass() = LLVM.NewPMModulePass("GPULowerPTLS", GPUCompiler.lower_ptls!) +GPULowerGCFramePass() = LLVM.NewPMFunctionPass("GPULowerGCFrame", GPUCompiler.lower_gc_frame!) +function noop_pass(x) + return false +end +function kern_pass(mod) + for fname in ("julia.gpu.state_getter",) + if LLVM.haskey(LLVM.functions(mod), fname) + fn = LLVM.functions(mod)[fname] + insts = LLVM.Instruction[] + for u in LLVM.uses(fn) + u = LLVM.user(u) + LLVM.replace_uses!(u, LLVM.UndefValue(LLVM.value_type(u))) + push!(insts, u) + end + for inst in insts + Reactant.Enzyme.Compiler.eraseInst(LLVM.parent(inst), inst) + end + Reactant.Enzyme.Compiler.eraseInst(mod, fn) + end + end + + return true +end +AddKernelStatePass() = LLVM.NewPMModulePass("AddKernelStatePass", kern_pass) +LowerKernelStatePass() = LLVM.NewPMFunctionPass("LowerKernelStatePass", noop_pass) +CleanupKernelStatePass() = LLVM.NewPMModulePass("CleanupKernelStatePass", noop_pass) + # compile to executable machine code function compile(job) + # lower to PTX # TODO: on 1.9, this actually creates a context. cache those. - modstr, image, entry = CUDA.GPUCompiler.JuliaContext() do ctx - asm, meta = CUDA.GPUCompiler.compile(:asm, job) - mod = meta.ir - + modstr, image, entry = GPUCompiler.JuliaContext() do ctx + mod, meta = GPUCompiler.compile(:llvm, job; optimize=false, cleanup=false, validate=false) + GPUCompiler.optimize_module!(job, mod) + opt_level = 2 + tm = GPUCompiler.llvm_machine(job.config.target) + LLVM.@dispose pb=LLVM.NewPMPassBuilder() begin + LLVM.register!(pb, GPULowerCPUFeaturesPass()) + LLVM.register!(pb, GPULowerPTLSPass()) + LLVM.register!(pb, GPULowerGCFramePass()) + LLVM.register!(pb, AddKernelStatePass()) + LLVM.register!(pb, LowerKernelStatePass()) + LLVM.register!(pb, CleanupKernelStatePass()) + + LLVM.add!(pb, LLVM.NewPMModulePassManager()) do mpm + GPUCompiler.buildNewPMPipeline!(mpm, job, opt_level) + end + LLVM.run!(pb, mod, tm) + end + GPUCompiler.optimize_module!(job, mod) + LLVM.run!(CUDA.GPUCompiler.DeadArgumentEliminationPass(), mod, tm) + + + for fname in ("gpu_report_exception", "gpu_signal_exception") + if LLVM.haskey(LLVM.functions(mod), fname) + fn = LLVM.functions(mod)[fname] + insts = LLVM.Instruction[] + for u in LLVM.uses(fn) + push!(insts, LLVM.user(u)) + end + for inst in insts + Reactant.Enzyme.Compiler.eraseInst(LLVM.parent(inst), inst) + end + Reactant.Enzyme.Compiler.eraseInst(mod, fn) + end + end + + LLVM.strip_debuginfo!(mod) modstr = string(mod) # This is a bit weird since we're taking a module from julia's llvm into reactant's llvm version # it is probably safer to reparse a string using the right llvm module api, so we will do that. - mmod = MLIR.IR.Module(@ccall MLIR.API.mlir_c.ConvertLLVMToMLIR(mod::CUDA.LLVM.API.LLVMModuleRef, MLIR.IR.context()::MLIR.API.MlirContext)::MLIR.API.MlirModule) + println(string(modstr)) + mmod = MLIR.IR.Module(@ccall MLIR.API.mlir_c.ConvertLLVMStrToMLIR(modstr::Cstring, MLIR.IR.context()::MLIR.API.MlirContext)::MLIR.API.MlirModule) @show mmod # check if we'll need the device runtime @@ -461,8 +529,18 @@ Reactant.@reactant_override function CUDA.cufunction(f::F, tt::TT=Tuple{}; kwarg cache = compiler_cache(MLIR.IR.context()) source = CUDA.methodinstance(F, tt) - cuda = CUDA.active_state() - config = CUDA.compiler_config(cuda.device; kwargs...)::CUDA.CUDACompilerConfig + # cuda = CUDA.active_state() + device = nothing # cuda.device + # config = CUDA.compiler_config(device; kwargs...)::CUDA.CUDACompilerConfig + cuda_cap=v"5.0" + cuda_ptx=v"6.3" + llvm_cap=v"5.0" + llvm_ptx=v"6.3" + kernel=true + always_inline=false + name=nothing + debuginfo=false + config = CUDA.CompilerConfig(CUDA.PTXCompilerTarget(; cap=llvm_cap, ptx=llvm_ptx, debuginfo), CUDA.CUDACompilerParams(; cap=cuda_cap, ptx=cuda_ptx); kernel, name, always_inline) CUDA.GPUCompiler.cached_compilation(cache, source, config, compile, link) end res From ff729ce0f9f56712b38964a95e7249c765f76aaa Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Tue, 10 Dec 2024 22:40:02 -0600 Subject: [PATCH 20/78] Cleanup --- Project.toml | 5 +- ext/ReactantCUDAExt.jl | 553 ----------------------------------------- src/Reactant.jl | 3 - src/utils.jl | 84 ++++--- test/Project.toml | 1 - test/runtests.jl | 24 +- 6 files changed, 61 insertions(+), 609 deletions(-) delete mode 100644 ext/ReactantCUDAExt.jl diff --git a/Project.toml b/Project.toml index a5d243705..dd4d67325 100644 --- a/Project.toml +++ b/Project.toml @@ -21,7 +21,6 @@ Scratch = "6c6a2e73-6563-6170-7368-637461726353" [weakdeps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" -CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df" @@ -32,7 +31,6 @@ path = "lib/ReactantCore" [extensions] ReactantAbstractFFTsExt = "AbstractFFTs" ReactantArrayInterfaceExt = "ArrayInterface" -ReactantCUDAExt = "CUDA" ReactantNNlibExt = "NNlib" ReactantStatisticsExt = "Statistics" ReactantYaoBlocksExt = "YaoBlocks" @@ -43,7 +41,7 @@ Adapt = "4" ArrayInterface = "7.10" CEnum = "0.4, 0.5" Downloads = "1.6" -Enzyme = "0.13.21" +Enzyme = "0.13.22" EnzymeCore = "0.8.8" GPUArraysCore = "0.1.6, 0.2" LinearAlgebra = "1.10" @@ -60,5 +58,4 @@ julia = "1.10" [extras] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" -CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl deleted file mode 100644 index ad13922f3..000000000 --- a/ext/ReactantCUDAExt.jl +++ /dev/null @@ -1,553 +0,0 @@ -module ReactantCUDAExt - -using CUDA -using Reactant: - Reactant, TracedRArray, AnyTracedRArray, materialize_traced_array, MLIR, TracedRNumber -using ReactantCore: @trace - -using Adapt - -struct CuTracedArray{T,N,A,Size} <: DenseArray{T,N} - ptr::Core.LLVMPtr{T,A} -end - - -Base.show(io::IO, a::AT) where AT <: CuTracedArray = - CUDA.Printf.@printf(io, "%s cu traced array at %p", join(size(a), '×'), Int(pointer(a))) - -## array interface - -Base.elsize(::Type{<:CuTracedArray{T}}) where {T} = sizeof(T) -Base.size(g::CuTracedArray{T,N,A,Size}) where {T,N,A,Size} = Size -Base.sizeof(x::CuTracedArray) = Base.elsize(x) * length(x) -Base.pointer(x::CuTracedArray{T,<:Any,A}) where {T,A} = Base.unsafe_convert(Core.LLVMPtr{T,A}, x) -@inline function Base.pointer(x::CuTracedArray{T,<:Any,A}, i::Integer) where {T,A} - Base.unsafe_convert(Core.LLVMPtr{T,A}, x) + Base._memory_offset(x, i) -end - - -## conversions - -Base.unsafe_convert(::Type{Core.LLVMPtr{T,A}}, x::CuTracedArray{T,<:Any,A}) where {T,A} = - x.ptr - - -## indexing intrinsics - -CUDA.@device_function @inline function arrayref(A::CuTracedArray{T}, index::Integer) where {T} - @boundscheck checkbounds(A, index) - if Base.isbitsunion(T) - arrayref_union(A, index) - else - arrayref_bits(A, index) - end -end - -@inline function arrayref_bits(A::CuTracedArray{T}, index::Integer) where {T} - unsafe_load(pointer(A), index) -end - -@inline @generated function arrayref_union(A::CuTracedArray{T,<:Any,AS}, index::Integer) where {T,AS} - typs = Base.uniontypes(T) - - # generate code that conditionally loads a value based on the selector value. - # lacking noreturn, we return T to avoid inference thinking this can return Nothing. - ex = :(Base.llvmcall("unreachable", $T, Tuple{})) - for (sel, typ) in Iterators.reverse(enumerate(typs)) - ex = quote - if selector == $(sel-1) - ptr = reinterpret(Core.LLVMPtr{$typ,AS}, data_ptr) - unsafe_load(ptr, 1) - else - $ex - end - end - end - - quote - selector_ptr = typetagdata(A, index) - selector = unsafe_load(selector_ptr) - - data_ptr = pointer(A, index) - - return $ex - end -end - -CUDA.@device_function @inline function arrayset(A::CuTracedArray{T}, x::T, index::Integer) where {T} - @boundscheck checkbounds(A, index) - if Base.isbitsunion(T) - arrayset_union(A, x, index) - else - arrayset_bits(A, x, index) - end - return A -end - -@inline function arrayset_bits(A::CuTracedArray{T}, x::T, index::Integer) where {T} - unsafe_store!(pointer(A), x, index) -end - -@inline @generated function arrayset_union(A::CuTracedArray{T,<:Any,AS}, x::T, index::Integer) where {T,AS} - typs = Base.uniontypes(T) - sel = findfirst(isequal(x), typs) - - quote - selector_ptr = typetagdata(A, index) - unsafe_store!(selector_ptr, $(UInt8(sel-1))) - - data_ptr = pointer(A, index) - - unsafe_store!(reinterpret(Core.LLVMPtr{$x,AS}, data_ptr), x, 1) - return - end -end - -CUDA.@device_function @inline function const_arrayref(A::CuTracedArray{T}, index::Integer) where {T} - @boundscheck checkbounds(A, index) - unsafe_cached_load(pointer(A), index) -end - - -## indexing - -Base.IndexStyle(::Type{<:CuTracedArray}) = Base.IndexLinear() - -Base.@propagate_inbounds Base.getindex(A::CuTracedArray{T}, i1::Integer) where {T} = - arrayref(A, i1) -Base.@propagate_inbounds Base.setindex!(A::CuTracedArray{T}, x, i1::Integer) where {T} = - arrayset(A, convert(T,x)::T, i1) - -# preserve the specific integer type when indexing device arrays, -# to avoid extending 32-bit hardware indices to 64-bit. -Base.to_index(::CuTracedArray, i::Integer) = i - -# Base doesn't like Integer indices, so we need our own ND get and setindex! routines. -# See also: https://github.com/JuliaLang/julia/pull/42289 -Base.@propagate_inbounds Base.getindex(A::CuTracedArray, - I::Union{Integer, CartesianIndex}...) = - A[Base._to_linear_index(A, to_indices(A, I)...)] -Base.@propagate_inbounds Base.setindex!(A::CuTracedArray, x, - I::Union{Integer, CartesianIndex}...) = - A[Base._to_linear_index(A, to_indices(A, I)...)] = x - - -## const indexing - -""" - Const(A::CuTracedArray) - -Mark a CuTracedArray as constant/read-only. The invariant guaranteed is that you will not -modify an CuTracedArray for the duration of the current kernel. - -This API can only be used on devices with compute capability 3.5 or higher. - -!!! warning - Experimental API. Subject to change without deprecation. -""" -struct Const{T,N,AS} <: DenseArray{T,N} - a::CuTracedArray{T,N,AS} -end -Base.Experimental.Const(A::CuTracedArray) = Const(A) - -Base.IndexStyle(::Type{<:Const}) = IndexLinear() -Base.size(C::Const) = size(C.a) -Base.axes(C::Const) = axes(C.a) -Base.@propagate_inbounds Base.getindex(A::Const, i1::Integer) = const_arrayref(A.a, i1) - -# deprecated -Base.@propagate_inbounds ldg(A::CuTracedArray, i1::Integer) = const_arrayref(A, i1) - - -## other - -@inline function Base.iterate(A::CuTracedArray, i=1) - if (i % UInt) - 1 < length(A) - (@inbounds A[i], i + 1) - else - nothing - end -end - -function Base.reinterpret(::Type{T}, a::CuTracedArray{S,N,A}) where {T,S,N,A} - err = GPUArrays._reinterpret_exception(T, a) - err === nothing || throw(err) - - if sizeof(T) == sizeof(S) # fast case - return CuTracedArray{T,N,A}(reinterpret(Core.LLVMPtr{T,A}, a.ptr), size(a), a.maxsize) - end - - isize = size(a) - size1 = div(isize[1]*sizeof(S), sizeof(T)) - osize = tuple(size1, Base.tail(isize)...) - return CuTracedArray{T,N,A}(reinterpret(Core.LLVMPtr{T,A}, a.ptr), osize, a.maxsize) -end - - -## reshape - -function Base.reshape(a::CuTracedArray{T,M,A}, dims::NTuple{N,Int}) where {T,N,M,A} - if prod(dims) != length(a) - throw(DimensionMismatch("new dimensions (argument `dims`) must be consistent with array size (`size(a)`)")) - end - if N == M && dims == size(a) - return a - end - _derived_array(a, T, dims) -end - - - -function Adapt.adapt_storage(::CUDA.KernelAdaptor, xs::TracedRArray{T,N}) where {T,N} - res = CuTracedArray{T,N,CUDA.AS.Global, size(xs)}(Base.reinterpret(Core.LLVMPtr{T,CUDA.AS.Global}, Base.pointer_from_objref(xs))) - @show res, xs - return res -end - -const _kernel_instances = Dict{Any, Any}() - -struct LLVMFunc{F,tt} - f::Union{F, Nothing} - mod::String - image - entry::String -end - - -const GPUCompiler = CUDA.GPUCompiler -const LLVM = GPUCompiler.LLVM - - -GPULowerCPUFeaturesPass() = LLVM.NewPMModulePass("GPULowerCPUFeatures", GPUCompiler.cpu_features!) -GPULowerPTLSPass() = LLVM.NewPMModulePass("GPULowerPTLS", GPUCompiler.lower_ptls!) -GPULowerGCFramePass() = LLVM.NewPMFunctionPass("GPULowerGCFrame", GPUCompiler.lower_gc_frame!) -function noop_pass(x) - return false -end -function kern_pass(mod) - for fname in ("julia.gpu.state_getter",) - if LLVM.haskey(LLVM.functions(mod), fname) - fn = LLVM.functions(mod)[fname] - insts = LLVM.Instruction[] - for u in LLVM.uses(fn) - u = LLVM.user(u) - LLVM.replace_uses!(u, LLVM.UndefValue(LLVM.value_type(u))) - push!(insts, u) - end - for inst in insts - Reactant.Enzyme.Compiler.eraseInst(LLVM.parent(inst), inst) - end - Reactant.Enzyme.Compiler.eraseInst(mod, fn) - end - end - - return true -end -AddKernelStatePass() = LLVM.NewPMModulePass("AddKernelStatePass", kern_pass) -LowerKernelStatePass() = LLVM.NewPMFunctionPass("LowerKernelStatePass", noop_pass) -CleanupKernelStatePass() = LLVM.NewPMModulePass("CleanupKernelStatePass", noop_pass) - -# compile to executable machine code -function compile(job) - - # lower to PTX - # TODO: on 1.9, this actually creates a context. cache those. - modstr, image, entry = GPUCompiler.JuliaContext() do ctx - mod, meta = GPUCompiler.compile(:llvm, job; optimize=false, cleanup=false, validate=false) - GPUCompiler.optimize_module!(job, mod) - opt_level = 2 - tm = GPUCompiler.llvm_machine(job.config.target) - LLVM.@dispose pb=LLVM.NewPMPassBuilder() begin - LLVM.register!(pb, GPULowerCPUFeaturesPass()) - LLVM.register!(pb, GPULowerPTLSPass()) - LLVM.register!(pb, GPULowerGCFramePass()) - LLVM.register!(pb, AddKernelStatePass()) - LLVM.register!(pb, LowerKernelStatePass()) - LLVM.register!(pb, CleanupKernelStatePass()) - - LLVM.add!(pb, LLVM.NewPMModulePassManager()) do mpm - GPUCompiler.buildNewPMPipeline!(mpm, job, opt_level) - end - LLVM.run!(pb, mod, tm) - end - GPUCompiler.optimize_module!(job, mod) - LLVM.run!(CUDA.GPUCompiler.DeadArgumentEliminationPass(), mod, tm) - - - for fname in ("gpu_report_exception", "gpu_signal_exception") - if LLVM.haskey(LLVM.functions(mod), fname) - fn = LLVM.functions(mod)[fname] - insts = LLVM.Instruction[] - for u in LLVM.uses(fn) - push!(insts, LLVM.user(u)) - end - for inst in insts - Reactant.Enzyme.Compiler.eraseInst(LLVM.parent(inst), inst) - end - Reactant.Enzyme.Compiler.eraseInst(mod, fn) - end - end - - LLVM.strip_debuginfo!(mod) - modstr = string(mod) - - # This is a bit weird since we're taking a module from julia's llvm into reactant's llvm version - # it is probably safer to reparse a string using the right llvm module api, so we will do that. - - println(string(modstr)) - mmod = MLIR.IR.Module(@ccall MLIR.API.mlir_c.ConvertLLVMStrToMLIR(modstr::Cstring, MLIR.IR.context()::MLIR.API.MlirContext)::MLIR.API.MlirModule) - @show mmod - - # check if we'll need the device runtime - undefined_fs = filter(collect(CUDA.LLVM.functions(meta.ir))) do f - CUDA.LLVM.isdeclaration(f) && !CUDA.LLVM.isintrinsic(f) - end - intrinsic_fns = ["vprintf", "malloc", "free", "__assertfail", - "__nvvm_reflect" #= TODO: should have been optimized away =#] - needs_cudadevrt = !isempty(setdiff(CUDA.LLVM.name.(undefined_fs), intrinsic_fns)) - - # prepare invocations of CUDA compiler tools - ptxas_opts = String[] - nvlink_opts = String[] - ## debug flags - if Base.JLOptions().debug_level == 1 - push!(ptxas_opts, "--generate-line-info") - elseif Base.JLOptions().debug_level >= 2 - push!(ptxas_opts, "--device-debug") - push!(nvlink_opts, "--debug") - end - ## relocatable device code - if needs_cudadevrt - push!(ptxas_opts, "--compile-only") - end - - ptx = job.config.params.ptx - cap = job.config.params.cap - arch = "sm_$(cap.major)$(cap.minor)" - - # validate use of parameter memory - argtypes = filter([CUDA.KernelState, job.source.specTypes.parameters...]) do dt - !CUDA.isghosttype(dt) && !Core.Compiler.isconstType(dt) - end - param_usage = sum(sizeof, argtypes) - param_limit = 4096 - if cap >= v"7.0" && ptx >= v"8.1" - param_limit = 32764 - end - if param_usage > param_limit - msg = """Kernel invocation uses too much parameter memory. - $(Base.format_bytes(param_usage)) exceeds the $(Base.format_bytes(param_limit)) limit imposed by sm_$(cap.major)$(cap.minor) / PTX v$(ptx.major).$(ptx.minor).""" - - try - details = "\n\nRelevant parameters:" - - source_types = job.source.specTypes.parameters - source_argnames = Base.method_argnames(job.source.def) - while length(source_argnames) < length(source_types) - # this is probably due to a trailing vararg; repeat its name - push!(source_argnames, source_argnames[end]) - end - - for (i, typ) in enumerate(source_types) - if CUDA.isghosttype(typ) || Core.Compiler.isconstType(typ) - continue - end - name = source_argnames[i] - details *= "\n [$(i-1)] $name::$typ uses $(Base.format_bytes(sizeof(typ)))" - end - details *= "\n" - - if cap >= v"7.0" && ptx < v"8.1" && param_usage < 32764 - details *= "\nNote: use a newer CUDA to support more parameters on your device.\n" - end - - msg *= details - catch err - @error "Failed to analyze kernel parameter usage; please file an issue with a reproducer." - end - error(msg) - end - - # compile to machine code - # NOTE: we use tempname since mktemp doesn't support suffixes, and mktempdir is slow - ptx_input = tempname(cleanup=false) * ".ptx" - ptxas_output = tempname(cleanup=false) * ".cubin" - write(ptx_input, asm) - - # we could use the driver's embedded JIT compiler, but that has several disadvantages: - # 1. fixes and improvements are slower to arrive, by using `ptxas` we only need to - # upgrade the toolkit to get a newer compiler; - # 2. version checking is simpler, we otherwise need to use NVML to query the driver - # version, which is hard to correlate to PTX JIT improvements; - # 3. if we want to be able to use newer (minor upgrades) of the CUDA toolkit on an - # older driver, we should use the newer compiler to ensure compatibility. - append!(ptxas_opts, [ - "--verbose", - "--gpu-name", arch, - "--output-file", ptxas_output, - ptx_input - ]) - proc, log = CUDA.run_and_collect(`$(CUDA.ptxas()) $ptxas_opts`) - log = strip(log) - if !success(proc) - reason = proc.termsignal > 0 ? "ptxas received signal $(proc.termsignal)" : - "ptxas exited with code $(proc.exitcode)" - msg = "Failed to compile PTX code ($reason)" - msg *= "\nInvocation arguments: $(join(ptxas_opts, ' '))" - if !isempty(log) - msg *= "\n" * log - end - msg *= "\nIf you think this is a bug, please file an issue and attach $(ptx_input)" - if parse(Bool, get(ENV, "BUILDKITE", "false")) - run(`buildkite-agent artifact upload $(ptx_input)`) - end - error(msg) - elseif !isempty(log) - @debug "PTX compiler log:\n" * log - end - rm(ptx_input) - - # link device libraries, if necessary - # - # this requires relocatable device code, which prevents certain optimizations and - # hurts performance. as such, we only do so when absolutely necessary. - # TODO: try LTO, `--link-time-opt --nvvmpath /opt/cuda/nvvm`. - # fails with `Ignoring -lto option because no LTO objects found` - if needs_cudadevrt - nvlink_output = tempname(cleanup=false) * ".cubin" - append!(nvlink_opts, [ - "--verbose", "--extra-warnings", - "--arch", arch, - "--library-path", dirname(libcudadevrt), - "--library", "cudadevrt", - "--output-file", nvlink_output, - ptxas_output - ]) - proc, log = run_and_collect(`$(CUDA.nvlink()) $nvlink_opts`) - log = strip(log) - if !success(proc) - reason = proc.termsignal > 0 ? "nvlink received signal $(proc.termsignal)" : - "nvlink exited with code $(proc.exitcode)" - msg = "Failed to link PTX code ($reason)" - msg *= "\nInvocation arguments: $(join(nvlink_opts, ' '))" - if !isempty(log) - msg *= "\n" * log - end - msg *= "\nIf you think this is a bug, please file an issue and attach $(ptxas_output)" - error(msg) - elseif !isempty(log) - @debug "PTX linker info log:\n" * log - end - rm(ptxas_output) - - image = read(nvlink_output) - rm(nvlink_output) - else - image = read(ptxas_output) - rm(ptxas_output) - end - - modstr, image, meta.entry - end - LLVMFunc{job.source.specTypes.parameters[1],job.source.specTypes}(nothing, modstr, image, CUDA.LLVM.name(entry)) -end - -# link into an executable kernel -function link(job, compiled) - # load as an executable kernel object - return compiled -end - -function transpose_val(val) - attr = MLIR.IR.DenseArrayAttribute( - Int64[reverse(0:(length(size(MLIR.IR.type(val))) - 1))...] - ) - return MLIR.IR.result(MLIR.Dialects.stablehlo.transpose(val; permutation=attr), 1) -end - -function (func::LLVMFunc{F,tt})(args...; convert=Val(false), blocks::CuDim=1, threads::CuDim=1, - cooperative::Bool=false, shmem::Integer=0, call_kwargs...) where{F, tt} - @show args - @show call_kwargs - - blockdim = CUDA.CuDim3(blocks) - threaddim = CUDA.CuDim3(threads) - - mlir_args = MLIR.IR.Value[] - restys = MLIR.IR.Type[] - aliases = MLIR.IR.Attribute[] - rarrays = TracedRArray[] - for (i, a) in enumerate(args) - @show a - @assert a isa CuTracedArray - ta = Base.unsafe_pointer_to_objref(Base.reinterpret(Ptr{Cvoid}, a.ptr))::TracedRArray - push!(rarrays, ta) - arg = ta.mlir_data - arg = transpose_val(arg) - @show arg - push!(restys, MLIR.IR.type(arg)) - push!(mlir_args, arg) - push!(aliases, - MLIR.IR.Attribute(MLIR.API.stablehloOutputOperandAliasGet( - MLIR.IR.context(), - length(args) == 1 ? 0 : 1, - length(args) == 1 ? C_NULL : Ref{Int64}(i-1), - i-1, - 0, - C_NULL - )) - ) - end - - output_operand_aliases=MLIR.IR.Attribute(aliases) - call = MLIR.Dialects.stablehlo.custom_call(mlir_args; result_0=restys, call_target_name="reactant_gpu_call", output_operand_aliases, backend_config=MLIR.IR.Attribute("configstr")) - # call = MLIR.Dialects.stablehlo.custom_call(mlir_args; result_0=restys, call_target_name="reactant_gpu_call", output_operand_aliases, backend_config=MLIR.IR.Attribute(func.mod)) - for (i, res) in enumerate(rarrays) - res.mlir_data = transpose_val(MLIR.IR.result(call, i)) - end - #CUDA.cuLaunchKernel(f, - # blockdim.x, blockdim.y, blockdim.z, - # threaddim.x, threaddim.y, threaddim.z, - # shmem, stream, kernelParams, C_NULL) -end - -# cache of compilation caches, per context -const _compiler_caches = Dict{MLIR.IR.Context, Dict{Any, LLVMFunc}}(); -function compiler_cache(ctx::MLIR.IR.Context) - cache = get(_compiler_caches, ctx, nothing) - if cache === nothing - cache = Dict{Any, LLVMFunc}() - _compiler_caches[ctx] = cache - end - return cache -end - -Reactant.@reactant_override function CUDA.cufunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT} - @show "recufunction", f, tt - res = Base.@lock CUDA.cufunction_lock begin - # compile the function - cache = compiler_cache(MLIR.IR.context()) - source = CUDA.methodinstance(F, tt) - - # cuda = CUDA.active_state() - device = nothing # cuda.device - # config = CUDA.compiler_config(device; kwargs...)::CUDA.CUDACompilerConfig - cuda_cap=v"5.0" - cuda_ptx=v"6.3" - llvm_cap=v"5.0" - llvm_ptx=v"6.3" - kernel=true - always_inline=false - name=nothing - debuginfo=false - config = CUDA.CompilerConfig(CUDA.PTXCompilerTarget(; cap=llvm_cap, ptx=llvm_ptx, debuginfo), CUDA.CUDACompilerParams(; cap=cuda_cap, ptx=cuda_ptx); kernel, name, always_inline) - CUDA.GPUCompiler.cached_compilation(cache, source, config, compile, link) - end - res -end - -function __init__() - -end - -end # module ReactantCUDAExt diff --git a/src/Reactant.jl b/src/Reactant.jl index 1623503df..06fd59aff 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -124,7 +124,4 @@ function set_default_backend(backend::String) return set_default_backend(XLA.backends[backend]) end -# include("../ext/ReactantCUDAExt.jl") - end # module - diff --git a/src/utils.jl b/src/utils.jl index db1de4505..c165ac402 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -60,6 +60,8 @@ end function rewrite_inst(inst, ir) if Meta.isexpr(inst, :call) + # Even if type unstable we do not want (or need) to replace intrinsic + # calls or builtins with our version. ft = Core.Compiler.widenconst(maybe_argextype(inst.args[1], ir)) if !(ft <: Core.IntrinsicFunction) && !(ft <: Core.Builtin) rep = Expr(:call, call_with_reactant, inst.args...) @@ -74,6 +76,8 @@ end const REDUB_ARGUMENTS_NAME = gensym("redub_arguments") +# From Julia's Base.Meta with fix from https://github.com/JuliaLang/julia/pull/56787 +# and additionally adds support for an argument rewriting into a slot function arg_partially_inline!(code::Vector{Any}, slot_replacements::Vector{Any}, @nospecialize(type_signature)#=::Type{<:Tuple}=#, static_param_values::Vector{Any}, @@ -218,12 +222,35 @@ function _arg_partially_inline!(@nospecialize(x), slot_replacements::Vector{Any} return x end + +""" + Reactant.REDUB_ARGUMENTS_NAME + +The variable name bound to `call_with_reactant`'s tuple of arguments in its +`@generated` method definition. + +This binding can be used to manually reference/destructure `call_with_reactants` arguments + +This is required because user arguments could have a name which clashes with whatever name we choose for +our argument. Thus we gensym to create it. + +This originates from https://github.com/JuliaLabs/Cassette.jl/blob/c29b237c1ec0deda3a1037ec519eebe216952bfe/src/overdub.jl#L154 +""" +const OVERDUB_ARGUMENTS_NAME = gensym("overdub_arguments") + +# Generator function which ensures that all calls to the function are executed within the ReactantInterpreter +# In particular this entails two pieces: +# 1) We enforce the use of the ReactantInterpreter method table when generating the original methodinstance +# 2) Post type inference (using of course the reactant interpreter), all type unstable call functions are +# replaced with calls to `call_with_reactant`. This allows us to circumvent long standing issues in Julia +# using a custom interpreter in type unstable code. +# `redub_arguments` is `(typeof(original_function), map(typeof, original_args_tuple)...)` function call_with_reactant_generator(world::UInt, source::LineNumberNode, self, @nospecialize(redub_arguments)) @nospecialize args = redub_arguments - stub = Core.GeneratedFunctionStub(identity, Core.svec(:call_with_reactant, :redub_arguments), Core.svec()) + stub = Core.GeneratedFunctionStub(identity, Core.svec(:call_with_reactant, OVERDUB_ARGUMENTS_NAME), Core.svec()) # look up the method match builtin_error = :(throw(AssertionError("Unsupported call_with_reactant of builtin $redub_arguments"))) @@ -248,6 +275,7 @@ function call_with_reactant_generator(world::UInt, source::LineNumberNode, self, matches = lookup_result.matches + # No method could be found (including in our method table), bail with an error if length(matches) != 1 return stub(world, source, method_error) end @@ -269,18 +297,15 @@ function call_with_reactant_generator(world::UInt, source::LineNumberNode, self, @assert Core.Compiler.is_inferred(frame) method = match.method - @show mi - @show method + # The original julia code (on 1.11+) has the potential constprop, for now + # we assume this outermost function does not constprop, for ease. #if Core.Compiler.result_is_constabi(interp, frame.result) # rt = frame.result.result::Core.Compiler.Const # src = Core.Compiler.codeinfo_for_const(interp, frame.linfo, rt.val) #else opt = Core.Compiler.OptimizationState(frame, interp) - @show Core.Compiler.retrieve_code_info(mi, world) - @show opt.src - caller = frame.result @static if VERSION < v"1.11-" ir = Core.Compiler.run_passes(opt.src, opt, caller) @@ -289,11 +314,13 @@ function call_with_reactant_generator(world::UInt, source::LineNumberNode, self, Core.Compiler.ipo_dataflow_analysis!(interp, ir, caller) end - @show ir + # Rewrite type unstable calls to recurse into call_with_reactant to ensure + # they continue to use our interpreter. Reset the derived return type + # to Any if our interpreter would change the return type of any result. + # Also rewrite invoke (type stable call) to be :call, since otherwise apparently + # screws up type inference after this (TODO this should be fixed). any_changed = false - for (i, inst) in enumerate(ir.stmts) - - + for (i, inst) in enumerate(ir.stmts) @static if VERSION < v"1.11" changed, next = rewrite_inst(inst[:inst], ir) Core.Compiler.setindex!(ir.stmts[i], next, :inst) @@ -307,10 +334,12 @@ function call_with_reactant_generator(world::UInt, source::LineNumberNode, self, end end Core.Compiler.finish(interp, opt, ir, caller) - @show "post", ir src = Core.Compiler.ir_to_codeinf!(opt) - @show any_changed, src + # Julia hits various internal errors trying to re-perform type inference + # on type infered code (that we undo inference of), if there is no type unstable + # code to be rewritten, just use the default methodinstance (still using our methodtable), + # to improve compatibility as these bugs are fixed upstream. if !any_changed src = Core.Compiler.retrieve_code_info(mi, world) @show "post non change", src @@ -322,15 +351,16 @@ function call_with_reactant_generator(world::UInt, source::LineNumberNode, self, signature = sig is_invoke = args[1] === typeof(Core.invoke) - # propagate edge metadata + # propagate edge metadata, this method is invalidated if the original function we are calling + # is invalidated code_info.edges = Core.MethodInstance[mi] code_info.min_world = lookup_result.valid_worlds.min_world code_info.max_world = lookup_result.valid_worlds.max_world - code_info.slotnames = Any[:call_with_reactant, :redub_arguments, code_info.slotnames...] + # Rewrite the arguments to this function, to prepend the two new arguments, the function :call_with_reactant, + # and the REDUB_ARGUMENTS_NAME tuple of input arguments + code_info.slotnames = Any[:call_with_reactant, REDUB_ARGUMENTS_NAME, code_info.slotnames...] code_info.slotflags = UInt8[0x00, 0x00, code_info.slotflags...] - #code_info.slotnames = Any[:call_with_reactant, REDUB_ARGUMENTS_NAME] #code_info.slotnames...] - #code_info.slotflags = UInt8[0x00, 0x00] # code_info.slotflags...] n_prepended_slots = 2 overdub_args_slot = Core.SlotNumber(n_prepended_slots) @@ -339,6 +369,9 @@ function call_with_reactant_generator(world::UInt, source::LineNumberNode, self, # the end of the pass, we'll reset `code_info` fields accordingly. overdubbed_code = Any[] overdubbed_codelocs = Int32[] + + # Rewire the arguments from our tuple input of fn and args, to the corresponding calling convention + # required by the base method. # destructure the generated argument slots into the overdubbed method's argument slots. n_actual_args = fieldcount(signature) @@ -384,24 +417,12 @@ function call_with_reactant_generator(world::UInt, source::LineNumberNode, self, push!(fn_args, Core.SSAValue(length(overdubbed_code))) end - @show code_info.code - #=== finish initialization of `overdubbed_code`/`overdubbed_codelocs` ===# - # substitute static parameters, offset slot numbers by number of added slots, and # offset statement indices by the number of additional statements arg_partially_inline!(code_info.code, fn_args, method.sig, Any[static_params...], n_prepended_slots, n_prepended_slots, length(overdubbed_code), :propagate) - #callexpr = Expr(:call, Core.OpaqueClosure(ir), fn_args...) - #push!(overdubbed_code, callexpr) - #push!(overdubbed_codelocs, code_info.codelocs[1]) - - #push!(new_ci.code, Core.Compiler.ReturnNode(Core.SSAValue(length(overdubbed_code)))) - #push!(overdubbed_codelocs, code_info.codelocs[1]) - - # original_code_start_index = length(overdubbed_code) + 1 - append!(overdubbed_code, code_info.code) append!(overdubbed_codelocs, code_info.codelocs) @@ -416,12 +437,10 @@ function call_with_reactant_generator(world::UInt, source::LineNumberNode, self, code_info.ssavaluetypes = length(overdubbed_code) code_info.ssaflags = [0x00 for _ in 1:length(overdubbed_code)] # XXX we need to copy flags that are set for the original code - @show code_info - return code_info end -@eval function call_with_reactant(redub_arguments...) +@eval function call_with_reactant($OVERDUB_ARGUMENTS_NAME...) $(Expr(:meta, :generated_only)) $(Expr(:meta, :generated, call_with_reactant_generator)) end @@ -517,9 +536,6 @@ function make_mlir_fn( end end - interp = ReactantInterpreter() - - # TODO replace with `Base.invoke_within` if julia#52964 lands # TODO fix it for kwargs if f === Reactant.apply call_with_reactant(f, traced_args[1], (traced_args[2:end]...,)) diff --git a/test/Project.toml b/test/Project.toml index 9956337ea..4b50a487f 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,7 +1,6 @@ [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" -CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" diff --git a/test/runtests.jl b/test/runtests.jl index 87e1a3702..fddc963ce 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -41,8 +41,6 @@ end const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) -include("cuda.jl") -@static if false @testset "Reactant.jl Tests" begin if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "core" @safetestset "Layout" include("layout.jl") @@ -62,19 +60,17 @@ include("cuda.jl") if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration" @safetestset "Linear Algebra" include("integration/linear_algebra.jl") - @safetestset "CUDA" include("cuda.jl") @safetestset "AbstractFFTs" include("integration/fft.jl") end - # if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "neural_networks" - # @testset "Neural Networks" begin - # @safetestset "NNlib Primitives" include("nn/nnlib.jl") - # @safetestset "Flux.jl Integration" include("nn/flux.jl") - # if Sys.islinux() - # @safetestset "LuxLib Primitives" include("nn/luxlib.jl") - # @safetestset "Lux Integration" include("nn/lux.jl") - # end - # end - # end -end + if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "neural_networks" + @testset "Neural Networks" begin + @safetestset "NNlib Primitives" include("nn/nnlib.jl") + @safetestset "Flux.jl Integration" include("nn/flux.jl") + if Sys.islinux() + @safetestset "LuxLib Primitives" include("nn/luxlib.jl") + @safetestset "Lux Integration" include("nn/lux.jl") + end + end + end end From 3bd56081e8918c67d7725e8388d042bc1ef85b9f Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 10 Dec 2024 22:44:29 -0600 Subject: [PATCH 21/78] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/Interpreter.jl | 1 - src/utils.jl | 250 +++++++++++++++++++++++++++++++-------------- 2 files changed, 173 insertions(+), 78 deletions(-) diff --git a/src/Interpreter.jl b/src/Interpreter.jl index 9708a76f9..675f9036d 100644 --- a/src/Interpreter.jl +++ b/src/Interpreter.jl @@ -21,7 +21,6 @@ import Core.Compiler: mapany, MethodResultPure - Base.Experimental.@MethodTable(REACTANT_METHOD_TABLE) function var"@reactant_override"(__source__::LineNumberNode, __module__::Module, def) diff --git a/src/utils.jl b/src/utils.jl index c165ac402..afc4bbbce 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -41,15 +41,12 @@ function call_with_reactant end # generate a LineInfoNode for the current source code location macro LineInfoNode(method) - Core.LineInfoNode(__module__, method, __source__.file, Int32(__source__.line), Int32(0)) + return Core.LineInfoNode( + __module__, method, __source__.file, Int32(__source__.line), Int32(0) + ) end - - -function maybe_argextype( - @nospecialize(x), - src, -) +function maybe_argextype(@nospecialize(x), src) return try Core.Compiler.argextype(x, src) catch err @@ -59,43 +56,61 @@ function maybe_argextype( end function rewrite_inst(inst, ir) - if Meta.isexpr(inst, :call) - # Even if type unstable we do not want (or need) to replace intrinsic - # calls or builtins with our version. - ft = Core.Compiler.widenconst(maybe_argextype(inst.args[1], ir)) - if !(ft <: Core.IntrinsicFunction) && !(ft <: Core.Builtin) - rep = Expr(:call, call_with_reactant, inst.args...) - return true, rep - end - end - if Meta.isexpr(inst, :invoke) - return false, Expr(:call, inst.args[2:end]...) - end - return false, inst + if Meta.isexpr(inst, :call) + # Even if type unstable we do not want (or need) to replace intrinsic + # calls or builtins with our version. + ft = Core.Compiler.widenconst(maybe_argextype(inst.args[1], ir)) + if !(ft <: Core.IntrinsicFunction) && !(ft <: Core.Builtin) + rep = Expr(:call, call_with_reactant, inst.args...) + return true, rep + end + end + if Meta.isexpr(inst, :invoke) + return false, Expr(:call, inst.args[2:end]...) + end + return false, inst end const REDUB_ARGUMENTS_NAME = gensym("redub_arguments") # From Julia's Base.Meta with fix from https://github.com/JuliaLang/julia/pull/56787 # and additionally adds support for an argument rewriting into a slot -function arg_partially_inline!(code::Vector{Any}, slot_replacements::Vector{Any}, - @nospecialize(type_signature)#=::Type{<:Tuple}=#, - static_param_values::Vector{Any}, - slot_offset::Int, arg_offset::Int, statement_offset::Int, - boundscheck::Symbol) - for i = 1:length(code) +function arg_partially_inline!( + code::Vector{Any}, + slot_replacements::Vector{Any}, + @nospecialize(type_signature), #=::Type{<:Tuple}=# + static_param_values::Vector{Any}, + slot_offset::Int, + arg_offset::Int, + statement_offset::Int, + boundscheck::Symbol, +) + for i in 1:length(code) isassigned(code, i) || continue - code[i] = _arg_partially_inline!(code[i], slot_replacements, type_signature, - static_param_values, slot_offset, arg_offset, - statement_offset, boundscheck) + code[i] = _arg_partially_inline!( + code[i], + slot_replacements, + type_signature, + static_param_values, + slot_offset, + arg_offset, + statement_offset, + boundscheck, + ) end return code end -function _arg_partially_inline!(@nospecialize(x), slot_replacements::Vector{Any}, - @nospecialize(type_signature), static_param_values::Vector{Any}, - slot_offset::Int, arg_offset::Int, statement_offset::Int, - boundscheck::Symbol) +function _arg_partially_inline!( + @nospecialize(x), + slot_replacements::Vector{Any}, + @nospecialize(type_signature), + static_param_values::Vector{Any}, + slot_offset::Int, + arg_offset::Int, + statement_offset::Int, + boundscheck::Symbol, +) if isa(x, Core.SSAValue) return Core.SSAValue(x.id + statement_offset) end @@ -110,33 +125,66 @@ function _arg_partially_inline!(@nospecialize(x), slot_replacements::Vector{Any} return Core.SlotNumber(id + slot_offset) end if isa(x, Core.Argument) - return Core.SlotNumber(x.n + arg_offset) + return Core.SlotNumber(x.n + arg_offset) end if isa(x, Core.NewvarNode) - return Core.NewvarNode(_arg_partially_inline!(x.slot, slot_replacements, type_signature, - static_param_values, slot_offset, arg_offset, - statement_offset, boundscheck)) + return Core.NewvarNode( + _arg_partially_inline!( + x.slot, + slot_replacements, + type_signature, + static_param_values, + slot_offset, + arg_offset, + statement_offset, + boundscheck, + ), + ) end if isa(x, Core.PhiNode) - arg_partially_inline!(x.values, slot_replacements, type_signature, static_param_values, - slot_offset, arg_offset, statement_offset, boundscheck) + arg_partially_inline!( + x.values, + slot_replacements, + type_signature, + static_param_values, + slot_offset, + arg_offset, + statement_offset, + boundscheck, + ) x.edges .+= slot_offset return x end if isa(x, Core.ReturnNode) - if !isdefined(x, :val) - return Core.ReturnNode(:nothing) - else - return Core.ReturnNode( - _arg_partially_inline!(x.val, slot_replacements, type_signature, static_param_values, - slot_offset, arg_offset, statement_offset, boundscheck), - ) - end + if !isdefined(x, :val) + return Core.ReturnNode(:nothing) + else + return Core.ReturnNode( + _arg_partially_inline!( + x.val, + slot_replacements, + type_signature, + static_param_values, + slot_offset, + arg_offset, + statement_offset, + boundscheck, + ), + ) + end end if isa(x, Core.GotoIfNot) return Core.GotoIfNot( - _arg_partially_inline!(x.cond, slot_replacements, type_signature, static_param_values, - slot_offset, arg_offset, statement_offset, boundscheck), + _arg_partially_inline!( + x.cond, + slot_replacements, + type_signature, + static_param_values, + slot_offset, + arg_offset, + statement_offset, + boundscheck, + ), x.dest + statement_offset, ) end @@ -153,28 +201,59 @@ function _arg_partially_inline!(@nospecialize(x), slot_replacements::Vector{Any} elseif head === :cfunction @assert !isa(type_signature, UnionAll) || !isempty(spvals) if !isa(x.args[2], QuoteNode) # very common no-op - x.args[2] = Core.Compiler._partially_inline!(x.args[2], slot_replacements, type_signature, - static_param_values, slot_offset, arg_offset, - statement_offset, boundscheck) + x.args[2] = Core.Compiler._partially_inline!( + x.args[2], + slot_replacements, + type_signature, + static_param_values, + slot_offset, + arg_offset, + statement_offset, + boundscheck, + ) end - x.args[3] = Core.Compiler._instantiate_type_in_env(x.args[3], type_signature, static_param_values) - x.args[4] = Core.svec(Any[Core.Compiler._instantiate_type_in_env(argt, type_signature, static_param_values) for argt in x.args[4]]...) + x.args[3] = Core.Compiler._instantiate_type_in_env( + x.args[3], type_signature, static_param_values + ) + x.args[4] = Core.svec( + Any[ + Core.Compiler._instantiate_type_in_env( + argt, type_signature, static_param_values + ) for argt in x.args[4] + ]..., + ) elseif head === :foreigncall @assert !isa(type_signature, UnionAll) || !isempty(static_param_values) - for i = 1:length(x.args) + for i in 1:length(x.args) if i == 2 - x.args[2] = Core.Compiler._instantiate_type_in_env(x.args[2], type_signature, static_param_values) + x.args[2] = Core.Compiler._instantiate_type_in_env( + x.args[2], type_signature, static_param_values + ) elseif i == 3 - x.args[3] = Core.svec(Any[Core.Compiler._instantiate_type_in_env(argt, type_signature, static_param_values) for argt in x.args[3]]...) + x.args[3] = Core.svec( + Any[ + Core.Compiler._instantiate_type_in_env( + argt, type_signature, static_param_values + ) for argt in x.args[3] + ]..., + ) elseif i == 4 @assert isa(x.args[4], Int) elseif i == 5 - @assert isa((x.args[5]::QuoteNode).value, Union{Symbol, Tuple{Symbol, UInt8}}) + @assert isa( + (x.args[5]::QuoteNode).value, Union{Symbol,Tuple{Symbol,UInt8}} + ) else - x.args[i] = _arg_partially_inline!(x.args[i], slot_replacements, - type_signature, static_param_values, - slot_offset, statement_offset, arg_offset, - boundscheck) + x.args[i] = _arg_partially_inline!( + x.args[i], + slot_replacements, + type_signature, + static_param_values, + slot_offset, + statement_offset, + arg_offset, + boundscheck, + ) end end elseif head === :boundscheck @@ -186,9 +265,16 @@ function _arg_partially_inline!(@nospecialize(x), slot_replacements::Vector{Any} return true end elseif head === :gotoifnot - x.args[1] = _arg_partially_inline!(x.args[1], slot_replacements, type_signature, - static_param_values, slot_offset, arg_offset, - statement_offset, boundscheck) + x.args[1] = _arg_partially_inline!( + x.args[1], + slot_replacements, + type_signature, + static_param_values, + slot_offset, + arg_offset, + statement_offset, + boundscheck, + ) x.args[2] += statement_offset elseif head === :isdefined arg = x.args[1] @@ -197,7 +283,7 @@ function _arg_partially_inline!(@nospecialize(x), slot_replacements::Vector{Any} id = arg.id if 1 <= id <= length(slot_replacements) replacement = slot_replacements[id] - if isa(replacement, Union{Core.SlotNumber, GlobalRef, Symbol}) + if isa(replacement, Union{Core.SlotNumber,GlobalRef,Symbol}) return Expr(:isdefined, replacement) else @assert !isa(replacement, Expr) @@ -211,18 +297,25 @@ function _arg_partially_inline!(@nospecialize(x), slot_replacements::Vector{Any} end return x else - @assert isa(arg, Union{GlobalRef, Symbol}) + @assert isa(arg, Union{GlobalRef,Symbol}) return x end elseif !Core.Compiler.is_meta_expr_head(head) - arg_partially_inline!(x.args, slot_replacements, type_signature, static_param_values, - slot_offset, arg_offset, statement_offset, boundscheck) + arg_partially_inline!( + x.args, + slot_replacements, + type_signature, + static_param_values, + slot_offset, + arg_offset, + statement_offset, + boundscheck, + ) end end return x end - """ Reactant.REDUB_ARGUMENTS_NAME @@ -245,20 +338,24 @@ const OVERDUB_ARGUMENTS_NAME = gensym("overdub_arguments") # replaced with calls to `call_with_reactant`. This allows us to circumvent long standing issues in Julia # using a custom interpreter in type unstable code. # `redub_arguments` is `(typeof(original_function), map(typeof, original_args_tuple)...)` -function call_with_reactant_generator(world::UInt, source::LineNumberNode, self, @nospecialize(redub_arguments)) +function call_with_reactant_generator( + world::UInt, source::LineNumberNode, self, @nospecialize(redub_arguments) +) @nospecialize - args = redub_arguments - stub = Core.GeneratedFunctionStub(identity, Core.svec(:call_with_reactant, OVERDUB_ARGUMENTS_NAME), Core.svec()) + stub = Core.GeneratedFunctionStub( + identity, Core.svec(:call_with_reactant, OVERDUB_ARGUMENTS_NAME), Core.svec() + ) # look up the method match - builtin_error = :(throw(AssertionError("Unsupported call_with_reactant of builtin $redub_arguments"))) - + builtin_error = :(throw( + AssertionError("Unsupported call_with_reactant of builtin $redub_arguments") + )) + if args[1] <: Core.Builtin return stub(world, source, builtin_error) end - method_error = :(throw(MethodError(args[1], args[2:end], $world))) interp = ReactantInterpreter(; world) @@ -281,7 +378,6 @@ function call_with_reactant_generator(world::UInt, source::LineNumberNode, self, end match = matches[1]::Core.MethodMatch - # look up the method and code instance mi = ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance}, (Any, Any, Any), match.method, match.spec_types, match.sparams) From 5e33afbdcdfc1fd8b133dae5e3e15a7a74a8c086 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 10 Dec 2024 22:46:21 -0600 Subject: [PATCH 22/78] Delete test/cuda.jl --- test/cuda.jl | 25 ------------------------- 1 file changed, 25 deletions(-) delete mode 100644 test/cuda.jl diff --git a/test/cuda.jl b/test/cuda.jl deleted file mode 100644 index ae1b473f6..000000000 --- a/test/cuda.jl +++ /dev/null @@ -1,25 +0,0 @@ -using Reactant -using Test -using CUDA - -function square_kernel!(x) - i = threadIdx().x - x[i] *= x[i] - sync_threads() - return nothing -end - -# basic squaring on GPU -function square!(x) - @cuda blocks = 1 threads = length(x) square_kernel!(x) - return nothing -end - -@testset "Square Kernel" begin - oA = collect(1:1:64) - A = Reactant.to_rarray(oA) - @show @code_hlo optimize=false square!(A) - @show @code_hlo square!(A) - func = @compile square!(A) - @test all(Array(A) .≈ (oA .* oA)) -end From 17c2f72ae9a809696dfb1305877db351b21ddb1a Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Tue, 10 Dec 2024 22:57:40 -0600 Subject: [PATCH 23/78] fixup --- src/utils.jl | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index afc4bbbce..5e46c143f 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -212,12 +212,12 @@ function _arg_partially_inline!( boundscheck, ) end - x.args[3] = Core.Compiler._instantiate_type_in_env( + x.args[3] = Base.Meta._instantiate_type_in_env( x.args[3], type_signature, static_param_values ) x.args[4] = Core.svec( Any[ - Core.Compiler._instantiate_type_in_env( + Base.Meta._instantiate_type_in_env( argt, type_signature, static_param_values ) for argt in x.args[4] ]..., @@ -226,13 +226,13 @@ function _arg_partially_inline!( @assert !isa(type_signature, UnionAll) || !isempty(static_param_values) for i in 1:length(x.args) if i == 2 - x.args[2] = Core.Compiler._instantiate_type_in_env( + x.args[2] = Base.Meta._instantiate_type_in_env( x.args[2], type_signature, static_param_values ) elseif i == 3 x.args[3] = Core.svec( Any[ - Core.Compiler._instantiate_type_in_env( + Base.Meta._instantiate_type_in_env( argt, type_signature, static_param_values ) for argt in x.args[3] ]..., @@ -438,7 +438,6 @@ function call_with_reactant_generator( # to improve compatibility as these bugs are fixed upstream. if !any_changed src = Core.Compiler.retrieve_code_info(mi, world) - @show "post non change", src end # prepare a new code info From 4807a7984f84161fa49b2a5f0d8474a2d136347f Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 10 Dec 2024 23:01:45 -0600 Subject: [PATCH 24/78] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/utils.jl | 141 ++++++++++++++++++++++++++++++--------------------- 1 file changed, 83 insertions(+), 58 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 5e46c143f..c423c9ca7 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -379,11 +379,17 @@ function call_with_reactant_generator( match = matches[1]::Core.MethodMatch # look up the method and code instance - mi = ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance}, - (Any, Any, Any), match.method, match.spec_types, match.sparams) - + mi = ccall( + :jl_specializations_get_linfo, + Ref{Core.MethodInstance}, + (Any, Any, Any), + match.method, + match.spec_types, + match.sparams, + ) + result = Core.Compiler.InferenceResult(mi, Core.Compiler.typeinf_lattice(interp)) - frame = Core.Compiler.InferenceState(result, #=cache_mode=#:local, interp) + frame = Core.Compiler.InferenceState(result, :local, interp) #=cache_mode=# @assert frame !== nothing Core.Compiler.typeinf(interp, frame) @static if VERSION >= v"1.11" @@ -400,45 +406,45 @@ function call_with_reactant_generator( # rt = frame.result.result::Core.Compiler.Const # src = Core.Compiler.codeinfo_for_const(interp, frame.linfo, rt.val) #else - opt = Core.Compiler.OptimizationState(frame, interp) - - caller = frame.result - @static if VERSION < v"1.11-" - ir = Core.Compiler.run_passes(opt.src, opt, caller) - else - ir = Core.Compiler.run_passes_ipo_safe(opt.src, opt, caller) - Core.Compiler.ipo_dataflow_analysis!(interp, ir, caller) - end - - # Rewrite type unstable calls to recurse into call_with_reactant to ensure - # they continue to use our interpreter. Reset the derived return type - # to Any if our interpreter would change the return type of any result. - # Also rewrite invoke (type stable call) to be :call, since otherwise apparently - # screws up type inference after this (TODO this should be fixed). - any_changed = false - for (i, inst) in enumerate(ir.stmts) - @static if VERSION < v"1.11" - changed, next = rewrite_inst(inst[:inst], ir) - Core.Compiler.setindex!(ir.stmts[i], next, :inst) - else - changed, next = rewrite_inst(inst[:stmt], ir) - Core.Compiler.setindex!(ir.stmts[i], next, :stmt) - end - if changed - any_changed = true - Core.Compiler.setindex!(ir.stmts[i], Any, :type) - end - end - Core.Compiler.finish(interp, opt, ir, caller) - src = Core.Compiler.ir_to_codeinf!(opt) - - # Julia hits various internal errors trying to re-perform type inference - # on type infered code (that we undo inference of), if there is no type unstable - # code to be rewritten, just use the default methodinstance (still using our methodtable), - # to improve compatibility as these bugs are fixed upstream. - if !any_changed - src = Core.Compiler.retrieve_code_info(mi, world) - end + opt = Core.Compiler.OptimizationState(frame, interp) + + caller = frame.result + @static if VERSION < v"1.11-" + ir = Core.Compiler.run_passes(opt.src, opt, caller) + else + ir = Core.Compiler.run_passes_ipo_safe(opt.src, opt, caller) + Core.Compiler.ipo_dataflow_analysis!(interp, ir, caller) + end + + # Rewrite type unstable calls to recurse into call_with_reactant to ensure + # they continue to use our interpreter. Reset the derived return type + # to Any if our interpreter would change the return type of any result. + # Also rewrite invoke (type stable call) to be :call, since otherwise apparently + # screws up type inference after this (TODO this should be fixed). + any_changed = false + for (i, inst) in enumerate(ir.stmts) + @static if VERSION < v"1.11" + changed, next = rewrite_inst(inst[:inst], ir) + Core.Compiler.setindex!(ir.stmts[i], next, :inst) + else + changed, next = rewrite_inst(inst[:stmt], ir) + Core.Compiler.setindex!(ir.stmts[i], next, :stmt) + end + if changed + any_changed = true + Core.Compiler.setindex!(ir.stmts[i], Any, :type) + end + end + Core.Compiler.finish(interp, opt, ir, caller) + src = Core.Compiler.ir_to_codeinf!(opt) + + # Julia hits various internal errors trying to re-perform type inference + # on type infered code (that we undo inference of), if there is no type unstable + # code to be rewritten, just use the default methodinstance (still using our methodtable), + # to improve compatibility as these bugs are fixed upstream. + if !any_changed + src = Core.Compiler.retrieve_code_info(mi, world) + end # prepare a new code info code_info = copy(src) @@ -454,7 +460,9 @@ function call_with_reactant_generator( # Rewrite the arguments to this function, to prepend the two new arguments, the function :call_with_reactant, # and the REDUB_ARGUMENTS_NAME tuple of input arguments - code_info.slotnames = Any[:call_with_reactant, REDUB_ARGUMENTS_NAME, code_info.slotnames...] + code_info.slotnames = Any[ + :call_with_reactant, REDUB_ARGUMENTS_NAME, code_info.slotnames... + ] code_info.slotflags = UInt8[0x00, 0x00, code_info.slotflags...] n_prepended_slots = 2 overdub_args_slot = Core.SlotNumber(n_prepended_slots) @@ -464,7 +472,6 @@ function call_with_reactant_generator( # the end of the pass, we'll reset `code_info` fields accordingly. overdubbed_code = Any[] overdubbed_codelocs = Int32[] - # Rewire the arguments from our tuple input of fn and args, to the corresponding calling convention # required by the base method. @@ -481,14 +488,16 @@ function call_with_reactant_generator( offset += 1 end slot = i + n_prepended_slots - actual_argument = Expr(:call, Core.GlobalRef(Core, :getfield), overdub_args_slot, offset) + actual_argument = Expr( + :call, Core.GlobalRef(Core, :getfield), overdub_args_slot, offset + ) push!(overdubbed_code, :($(Core.SlotNumber(slot)) = $actual_argument)) push!(overdubbed_codelocs, code_info.codelocs[1]) code_info.slotflags[slot] |= 0x02 # ensure this slotflag has the "assigned" bit set offset += 1 - - #push!(overdubbed_code, actual_argument) - push!(fn_args, Core.SSAValue(length(overdubbed_code))) + + #push!(overdubbed_code, actual_argument) + push!(fn_args, Core.SSAValue(length(overdubbed_code))) end # If `method` is a varargs method, we have to restructure the original method call's @@ -497,26 +506,42 @@ function call_with_reactant_generator( if !isempty(overdubbed_code) # remove the final slot reassignment leftover from the previous destructuring pop!(overdubbed_code) - pop!(overdubbed_codelocs) - pop!(fn_args) + pop!(overdubbed_codelocs) + pop!(fn_args) end trailing_arguments = Expr(:call, Core.GlobalRef(Core, :tuple)) for i in n_method_args:n_actual_args - push!(overdubbed_code, Expr(:call, Core.GlobalRef(Core, :getfield), overdub_args_slot, offset - 1)) + push!( + overdubbed_code, + Expr(:call, Core.GlobalRef(Core, :getfield), overdub_args_slot, offset - 1), + ) push!(overdubbed_codelocs, code_info.codelocs[1]) push!(trailing_arguments.args, Core.SSAValue(length(overdubbed_code))) offset += 1 end - push!(overdubbed_code, Expr(:(=), Core.SlotNumber(n_method_args + n_prepended_slots), trailing_arguments)) - push!(overdubbed_codelocs, code_info.codelocs[1]) - push!(fn_args, Core.SSAValue(length(overdubbed_code))) + push!( + overdubbed_code, + Expr( + :(=), Core.SlotNumber(n_method_args + n_prepended_slots), trailing_arguments + ), + ) + push!(overdubbed_codelocs, code_info.codelocs[1]) + push!(fn_args, Core.SSAValue(length(overdubbed_code))) end # substitute static parameters, offset slot numbers by number of added slots, and # offset statement indices by the number of additional statements - arg_partially_inline!(code_info.code, fn_args, method.sig, Any[static_params...], - n_prepended_slots, n_prepended_slots, length(overdubbed_code), :propagate) + arg_partially_inline!( + code_info.code, + fn_args, + method.sig, + Any[static_params...], + n_prepended_slots, + n_prepended_slots, + length(overdubbed_code), + :propagate, + ) append!(overdubbed_code, code_info.code) append!(overdubbed_codelocs, code_info.codelocs) @@ -537,7 +562,7 @@ end @eval function call_with_reactant($OVERDUB_ARGUMENTS_NAME...) $(Expr(:meta, :generated_only)) - $(Expr(:meta, :generated, call_with_reactant_generator)) + return $(Expr(:meta, :generated, call_with_reactant_generator)) end function make_mlir_fn( From 51286bd5b7edcf1c38a2ec3b5dd493383f13b56d Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Tue, 10 Dec 2024 23:24:17 -0600 Subject: [PATCH 25/78] fix apply --- src/utils.jl | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index c423c9ca7..5beac5b13 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -657,11 +657,7 @@ function make_mlir_fn( end # TODO fix it for kwargs - if f === Reactant.apply - call_with_reactant(f, traced_args[1], (traced_args[2:end]...,)) - else - call_with_reactant(f, traced_args...) - end + call_with_reactant(f, traced_args...) end seen_results = OrderedIdDict() From bd40b69efb97baadd4d01349fa3e2a672b9cac67 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Tue, 10 Dec 2024 23:27:20 -0600 Subject: [PATCH 26/78] indep of change --- src/utils.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 5beac5b13..a70d15cc4 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -442,9 +442,11 @@ function call_with_reactant_generator( # on type infered code (that we undo inference of), if there is no type unstable # code to be rewritten, just use the default methodinstance (still using our methodtable), # to improve compatibility as these bugs are fixed upstream. - if !any_changed - src = Core.Compiler.retrieve_code_info(mi, world) - end + # Just kidding we can't do this, since otherwise the inferred code won't guarantee to run + # within our interpreter, so we must use our generated IR here. + # if !any_changed + # src = Core.Compiler.retrieve_code_info(mi, world) + # end # prepare a new code info code_info = copy(src) From 5b3329cef748c3e2a95f2eed463664ca5b51b559 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Tue, 10 Dec 2024 23:28:27 -0600 Subject: [PATCH 27/78] minor fix in name --- src/utils.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index a70d15cc4..7ade139fb 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -329,7 +329,7 @@ our argument. Thus we gensym to create it. This originates from https://github.com/JuliaLabs/Cassette.jl/blob/c29b237c1ec0deda3a1037ec519eebe216952bfe/src/overdub.jl#L154 """ -const OVERDUB_ARGUMENTS_NAME = gensym("overdub_arguments") +const REDUB_ARGUMENTS_NAME = gensym("overdub_arguments") # Generator function which ensures that all calls to the function are executed within the ReactantInterpreter # In particular this entails two pieces: @@ -345,7 +345,7 @@ function call_with_reactant_generator( args = redub_arguments stub = Core.GeneratedFunctionStub( - identity, Core.svec(:call_with_reactant, OVERDUB_ARGUMENTS_NAME), Core.svec() + identity, Core.svec(:call_with_reactant, REDUB_ARGUMENTS_NAME), Core.svec() ) # look up the method match @@ -562,7 +562,7 @@ function call_with_reactant_generator( return code_info end -@eval function call_with_reactant($OVERDUB_ARGUMENTS_NAME...) +@eval function call_with_reactant($REDUB_ARGUMENTS_NAME...) $(Expr(:meta, :generated_only)) return $(Expr(:meta, :generated, call_with_reactant_generator)) end From 4af4a00b18fd2b53a6ae295805bf821818381734 Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 10 Dec 2024 23:30:54 -0600 Subject: [PATCH 28/78] Update utils.jl --- src/utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index 7ade139fb..c81b0e5d9 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -329,7 +329,7 @@ our argument. Thus we gensym to create it. This originates from https://github.com/JuliaLabs/Cassette.jl/blob/c29b237c1ec0deda3a1037ec519eebe216952bfe/src/overdub.jl#L154 """ -const REDUB_ARGUMENTS_NAME = gensym("overdub_arguments") +const REDUB_ARGUMENTS_NAME = gensym("redub_arguments") # Generator function which ensures that all calls to the function are executed within the ReactantInterpreter # In particular this entails two pieces: From 8379f05a9d03e80bccfc581579ff6e7b28f073e7 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Wed, 11 Dec 2024 12:25:04 -0600 Subject: [PATCH 29/78] Interp take 2 --- src/utils.jl | 95 ++++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 81 insertions(+), 14 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index c81b0e5d9..2cdd00629 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -66,13 +66,12 @@ function rewrite_inst(inst, ir) end end if Meta.isexpr(inst, :invoke) - return false, Expr(:call, inst.args[2:end]...) + @show inst + # return false, Expr(:call, inst.args[2:end]...) end return false, inst end -const REDUB_ARGUMENTS_NAME = gensym("redub_arguments") - # From Julia's Base.Meta with fix from https://github.com/JuliaLang/julia/pull/56787 # and additionally adds support for an argument rewriting into a slot function arg_partially_inline!( @@ -331,6 +330,54 @@ This originates from https://github.com/JuliaLabs/Cassette.jl/blob/c29b237c1ec0d """ const REDUB_ARGUMENTS_NAME = gensym("redub_arguments") +function compute_stateful_oc_signature(ir::Core.Compiler.IRCode, nargs::Int, isva::Bool) + argtypes = Vector{Any}(undef, nargs) + for i = 1:nargs + argtypes[i] = Core.Compiler.widenconst(ir.argtypes[i]) + end + if isva + lastarg = pop!(argtypes) + if lastarg <: Tuple + append!(argtypes, lastarg.parameters) + else + push!(argtypes, Vararg{Any}) + end + end + return Tuple{argtypes...} +end + +function StatefulOpaqueClosure(ir::Core.Compiler.IRCode, @nospecialize env...; + isva::Bool = false, + slotnames::Union{Nothing,Vector{Symbol}}=nothing, + kwargs...) + # if the user didn't specify a definition MethodInstance or filename Symbol to use for the debuginfo, set a filename now + @static if VERSION < v"1.12-" + else + ir.debuginfo.def === nothing && (ir.debuginfo.def = :var"generated IR for OpaqueClosure") + end + nargtypes = length(ir.argtypes) + nargs = nargtypes + sig = compute_stateful_oc_signature(ir, nargs, isva) + rt = Base.Experimental.compute_ir_rettype(ir) + src = ccall(:jl_new_code_info_uninit, Ref{Core.Compiler.CodeInfo}, ()) + if slotnames === nothing + src.slotnames = fill(:none, nargtypes) + else + length(slotnames) == nargtypes || error("mismatched `argtypes` and `slotnames`") + src.slotnames = slotnames + end + src.slotflags = fill(zero(UInt8), nargtypes) + src.slottypes = copy(ir.argtypes) + @static if VERSION < v"1.12-" + else + src.isva = isva + src.nargs = UInt(nargtypes) + end + src = Core.Compiler.ir_to_codeinf!(src, ir) + src.rettype = rt + return Base.Experimental.generate_opaque_closure(sig, Union{}, rt, src, nargs, isva, env...; kwargs...) +end + # Generator function which ensures that all calls to the function are executed within the ReactantInterpreter # In particular this entails two pieces: # 1) We enforce the use of the ReactantInterpreter method table when generating the original methodinstance @@ -534,19 +581,38 @@ function call_with_reactant_generator( # substitute static parameters, offset slot numbers by number of added slots, and # offset statement indices by the number of additional statements - arg_partially_inline!( - code_info.code, - fn_args, - method.sig, - Any[static_params...], - n_prepended_slots, - n_prepended_slots, - length(overdubbed_code), - :propagate, + # arg_partially_inline!( + # code_info.code, + # fn_args, + # method.sig, + # Any[static_params...], + # n_prepended_slots, + # n_prepended_slots, + # length(overdubbed_code), + # :propagate, + #) + + @show ir + + push!( + overdubbed_code, + Expr( + :(call), + StatefulOpaqueClosure(ir; isva=method.isva), + fn_args... + ), + ) + + push!(overdubbed_codelocs, code_info.codelocs[1]) + + push!( + overdubbed_code, + Core.ReturnNode(Core.SSAValue(length(overdubbed_code))) ) + push!(overdubbed_codelocs, code_info.codelocs[1]) - append!(overdubbed_code, code_info.code) - append!(overdubbed_codelocs, code_info.codelocs) + # append!(overdubbed_code, code_info.code) + # append!(overdubbed_codelocs, code_info.codelocs) #=== set `code_info`/`reflection` fields accordingly ===# @@ -559,6 +625,7 @@ function call_with_reactant_generator( code_info.ssavaluetypes = length(overdubbed_code) code_info.ssaflags = [0x00 for _ in 1:length(overdubbed_code)] # XXX we need to copy flags that are set for the original code + @show code_info return code_info end From 246ec4e9d97983f5a86a448aaf605c1cd6a82844 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Wed, 11 Dec 2024 16:34:48 -0600 Subject: [PATCH 30/78] continuing adentures --- src/utils.jl | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index 2cdd00629..43f8a543a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -67,6 +67,7 @@ function rewrite_inst(inst, ir) end if Meta.isexpr(inst, :invoke) @show inst + flush(stdout) # return false, Expr(:call, inst.args[2:end]...) end return false, inst @@ -378,6 +379,7 @@ function StatefulOpaqueClosure(ir::Core.Compiler.IRCode, @nospecialize env...; return Base.Experimental.generate_opaque_closure(sig, Union{}, rt, src, nargs, isva, env...; kwargs...) end + # Generator function which ensures that all calls to the function are executed within the ReactantInterpreter # In particular this entails two pieces: # 1) We enforce the use of the ReactantInterpreter method table when generating the original methodinstance @@ -593,12 +595,24 @@ function call_with_reactant_generator( #) @show ir + ccall(:jl_, Any, (Any,), "ir=") + ccall(:jl_, Any, (Any,), ir) + flush(stdout) + + rt = Base.Experimental.compute_ir_rettype(ir) + lno = LineNumberNode(1, :none) + oc = Base.Experimental.generate_opaque_closure(sig::Type, rt::Type, rt::Type, src::Core.Compiler.CodeInfo, Int(method.nargs), method.isva::Bool) + + @show oc + ccall(:jl_, Any, (Any,), "oc=") + ccall(:jl_, Any, (Any,), oc) + flush(stdout) push!( overdubbed_code, Expr( :(call), - StatefulOpaqueClosure(ir; isva=method.isva), + oc, fn_args... ), ) @@ -626,6 +640,10 @@ function call_with_reactant_generator( code_info.ssaflags = [0x00 for _ in 1:length(overdubbed_code)] # XXX we need to copy flags that are set for the original code @show code_info + ccall(:jl_, Any, (Any,), "code_info=") + ccall(:jl_, Any, (Any,), code_info) + flush(stdout) + return code_info end From 9a669ef00e13b922a6c6dc6b48beacefc1e198f0 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Wed, 11 Dec 2024 16:41:46 -0600 Subject: [PATCH 31/78] delcode --- src/Interpreter.jl | 89 ++------------ src/utils.jl | 295 +-------------------------------------------- 2 files changed, 10 insertions(+), 374 deletions(-) diff --git a/src/Interpreter.jl b/src/Interpreter.jl index 675f9036d..acdde888b 100644 --- a/src/Interpreter.jl +++ b/src/Interpreter.jl @@ -29,51 +29,6 @@ function var"@reactant_override"(__source__::LineNumberNode, __module__::Module, ) end -function set_reactant_abi( - interp, - @nospecialize(f), - arginfo::ArgInfo, - si::StmtInfo, - sv::AbsIntState, - max_methods::Int=get_max_methods(interp, f, sv), -) - (; fargs, argtypes) = arginfo - - if ( - (f === Enzyme.autodiff) || - (f === Enzyme.autodiff_deferred) || - (f === Enzyme.gradient) || - (f === Enzyme.jacobian) - ) && (length(argtypes) >= 2) - if widenconst(argtypes[2]) <: Enzyme.Mode - newmode = Enzyme.set_abi(widenconst(argtypes[2]), ReactantABI) - if newmode != widenconst(argtypes[2]) - newmodev = newmode() - arginfo2 = ArgInfo( - if fargs isa Nothing - nothing - else - [fargs[1], :($(newmodev)), fargs[3:end]...] - end, - [argtypes[1], Core.Const(newmodev), argtypes[3:end]...], - ) - return abstract_call_known(interp, f, arginfo2, si, sv, max_methods) - end - end - end - - return Base.@invoke abstract_call_known( - interp::AbstractInterpreter, - f::Any, - arginfo::ArgInfo, - si::StmtInfo, - sv::AbsIntState, - max_methods::Int, - ) -end - -function set_reactant_abi end - @static if Enzyme.GPUCompiler.HAS_INTEGRATED_CACHE struct ReactantCacheToken end @@ -85,7 +40,7 @@ function set_reactant_abi end true, #=forward_rules=# true, #=reverse_rules=# false, #=broadcast_rewrite=# - set_reactant_abi, + nothing, ) end else @@ -101,7 +56,7 @@ else true, #=forward_rules=# true, #=forward_rules=# false, #=broadcast_rewrite=# - set_reactant_abi, + nothing, ) end end @@ -566,58 +521,28 @@ function overload_autodiff( end end -@inline function Enzyme.autodiff_deferred( - rmode::Enzyme.ReverseMode{ - ReturnPrimal,RuntimeActivity,ReactantABI,Holomorphic,ErrIfFuncWritten - }, +@reactant_override @inline function Enzyme.autodiff_deferred( + rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs}, ) where { FA<:Annotation, A<:Annotation, - ReturnPrimal, - RuntimeActivity, - Holomorphic, Nargs, - ErrIfFuncWritten, } return overload_autodiff(rmode, f, rt, args...) end -@inline function Enzyme.autodiff_deferred( - rmode::ForwardMode{ReturnPrimal,ReactantABI,ErrIfFuncWritten,RuntimeActivity}, - f::FA, - rt::Type{A}, - args::Vararg{Annotation,Nargs}, -) where {FA<:Annotation,A<:Annotation,ReturnPrimal,Nargs,ErrIfFuncWritten,RuntimeActivity} - return overload_autodiff(rmode, f, rt, args...) -end - -@inline function Enzyme.autodiff( - rmode::Enzyme.ReverseMode{ - ReturnPrimal,RuntimeActivity,ReactantABI,Holomorphic,ErrIfFuncWritten - }, +@reactant_override @inline function Enzyme.autodiff( + rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs}, ) where { FA<:Annotation, A<:Annotation, - ReturnPrimal, - RuntimeActivity, - Holomorphic, Nargs, - ErrIfFuncWritten, } return overload_autodiff(rmode, f, rt, args...) -end - -@inline function Enzyme.autodiff( - rmode::ForwardMode{ReturnPrimal,ReactantABI,ErrIfFuncWritten,RuntimeActivity}, - f::FA, - rt::Type{A}, - args::Vararg{Annotation,Nargs}, -) where {FA<:Annotation,A<:Annotation,ReturnPrimal,Nargs,ErrIfFuncWritten,RuntimeActivity} - return overload_autodiff(rmode, f, rt, args...) -end +end \ No newline at end of file diff --git a/src/utils.jl b/src/utils.jl index 43f8a543a..8cebb048c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -73,249 +73,6 @@ function rewrite_inst(inst, ir) return false, inst end -# From Julia's Base.Meta with fix from https://github.com/JuliaLang/julia/pull/56787 -# and additionally adds support for an argument rewriting into a slot -function arg_partially_inline!( - code::Vector{Any}, - slot_replacements::Vector{Any}, - @nospecialize(type_signature), #=::Type{<:Tuple}=# - static_param_values::Vector{Any}, - slot_offset::Int, - arg_offset::Int, - statement_offset::Int, - boundscheck::Symbol, -) - for i in 1:length(code) - isassigned(code, i) || continue - code[i] = _arg_partially_inline!( - code[i], - slot_replacements, - type_signature, - static_param_values, - slot_offset, - arg_offset, - statement_offset, - boundscheck, - ) - end - return code -end - -function _arg_partially_inline!( - @nospecialize(x), - slot_replacements::Vector{Any}, - @nospecialize(type_signature), - static_param_values::Vector{Any}, - slot_offset::Int, - arg_offset::Int, - statement_offset::Int, - boundscheck::Symbol, -) - if isa(x, Core.SSAValue) - return Core.SSAValue(x.id + statement_offset) - end - if isa(x, Core.GotoNode) - return Core.GotoNode(x.label + statement_offset) - end - if isa(x, Core.SlotNumber) - id = x.id - if 1 <= id <= length(slot_replacements) - return slot_replacements[id] - end - return Core.SlotNumber(id + slot_offset) - end - if isa(x, Core.Argument) - return Core.SlotNumber(x.n + arg_offset) - end - if isa(x, Core.NewvarNode) - return Core.NewvarNode( - _arg_partially_inline!( - x.slot, - slot_replacements, - type_signature, - static_param_values, - slot_offset, - arg_offset, - statement_offset, - boundscheck, - ), - ) - end - if isa(x, Core.PhiNode) - arg_partially_inline!( - x.values, - slot_replacements, - type_signature, - static_param_values, - slot_offset, - arg_offset, - statement_offset, - boundscheck, - ) - x.edges .+= slot_offset - return x - end - if isa(x, Core.ReturnNode) - if !isdefined(x, :val) - return Core.ReturnNode(:nothing) - else - return Core.ReturnNode( - _arg_partially_inline!( - x.val, - slot_replacements, - type_signature, - static_param_values, - slot_offset, - arg_offset, - statement_offset, - boundscheck, - ), - ) - end - end - if isa(x, Core.GotoIfNot) - return Core.GotoIfNot( - _arg_partially_inline!( - x.cond, - slot_replacements, - type_signature, - static_param_values, - slot_offset, - arg_offset, - statement_offset, - boundscheck, - ), - x.dest + statement_offset, - ) - end - if isdefined(Core, :EnterNode) && isa(x, Core.EnterNode) - return Core.EnterNode(x, x.catch_dest + statement_offset) - end - if isa(x, Expr) - head = x.head - if head === :static_parameter - if isassigned(static_param_values, x.args[1]) - return QuoteNode(static_param_values[x.args[1]]) - end - return x - elseif head === :cfunction - @assert !isa(type_signature, UnionAll) || !isempty(spvals) - if !isa(x.args[2], QuoteNode) # very common no-op - x.args[2] = Core.Compiler._partially_inline!( - x.args[2], - slot_replacements, - type_signature, - static_param_values, - slot_offset, - arg_offset, - statement_offset, - boundscheck, - ) - end - x.args[3] = Base.Meta._instantiate_type_in_env( - x.args[3], type_signature, static_param_values - ) - x.args[4] = Core.svec( - Any[ - Base.Meta._instantiate_type_in_env( - argt, type_signature, static_param_values - ) for argt in x.args[4] - ]..., - ) - elseif head === :foreigncall - @assert !isa(type_signature, UnionAll) || !isempty(static_param_values) - for i in 1:length(x.args) - if i == 2 - x.args[2] = Base.Meta._instantiate_type_in_env( - x.args[2], type_signature, static_param_values - ) - elseif i == 3 - x.args[3] = Core.svec( - Any[ - Base.Meta._instantiate_type_in_env( - argt, type_signature, static_param_values - ) for argt in x.args[3] - ]..., - ) - elseif i == 4 - @assert isa(x.args[4], Int) - elseif i == 5 - @assert isa( - (x.args[5]::QuoteNode).value, Union{Symbol,Tuple{Symbol,UInt8}} - ) - else - x.args[i] = _arg_partially_inline!( - x.args[i], - slot_replacements, - type_signature, - static_param_values, - slot_offset, - statement_offset, - arg_offset, - boundscheck, - ) - end - end - elseif head === :boundscheck - if boundscheck === :propagate - return x - elseif boundscheck === :off - return false - else - return true - end - elseif head === :gotoifnot - x.args[1] = _arg_partially_inline!( - x.args[1], - slot_replacements, - type_signature, - static_param_values, - slot_offset, - arg_offset, - statement_offset, - boundscheck, - ) - x.args[2] += statement_offset - elseif head === :isdefined - arg = x.args[1] - # inlining a QuoteNode or literal into `Expr(:isdefined, x)` is invalid, replace with true - if isa(arg, Core.SlotNumber) - id = arg.id - if 1 <= id <= length(slot_replacements) - replacement = slot_replacements[id] - if isa(replacement, Union{Core.SlotNumber,GlobalRef,Symbol}) - return Expr(:isdefined, replacement) - else - @assert !isa(replacement, Expr) - return true - end - end - return Expr(:isdefined, Core.SlotNumber(id + slot_offset)) - elseif isexpr(arg, :static_parameter) - if isassigned(static_param_values, arg.args[1]) - return true - end - return x - else - @assert isa(arg, Union{GlobalRef,Symbol}) - return x - end - elseif !Core.Compiler.is_meta_expr_head(head) - arg_partially_inline!( - x.args, - slot_replacements, - type_signature, - static_param_values, - slot_offset, - arg_offset, - statement_offset, - boundscheck, - ) - end - end - return x -end - """ Reactant.REDUB_ARGUMENTS_NAME @@ -331,55 +88,6 @@ This originates from https://github.com/JuliaLabs/Cassette.jl/blob/c29b237c1ec0d """ const REDUB_ARGUMENTS_NAME = gensym("redub_arguments") -function compute_stateful_oc_signature(ir::Core.Compiler.IRCode, nargs::Int, isva::Bool) - argtypes = Vector{Any}(undef, nargs) - for i = 1:nargs - argtypes[i] = Core.Compiler.widenconst(ir.argtypes[i]) - end - if isva - lastarg = pop!(argtypes) - if lastarg <: Tuple - append!(argtypes, lastarg.parameters) - else - push!(argtypes, Vararg{Any}) - end - end - return Tuple{argtypes...} -end - -function StatefulOpaqueClosure(ir::Core.Compiler.IRCode, @nospecialize env...; - isva::Bool = false, - slotnames::Union{Nothing,Vector{Symbol}}=nothing, - kwargs...) - # if the user didn't specify a definition MethodInstance or filename Symbol to use for the debuginfo, set a filename now - @static if VERSION < v"1.12-" - else - ir.debuginfo.def === nothing && (ir.debuginfo.def = :var"generated IR for OpaqueClosure") - end - nargtypes = length(ir.argtypes) - nargs = nargtypes - sig = compute_stateful_oc_signature(ir, nargs, isva) - rt = Base.Experimental.compute_ir_rettype(ir) - src = ccall(:jl_new_code_info_uninit, Ref{Core.Compiler.CodeInfo}, ()) - if slotnames === nothing - src.slotnames = fill(:none, nargtypes) - else - length(slotnames) == nargtypes || error("mismatched `argtypes` and `slotnames`") - src.slotnames = slotnames - end - src.slotflags = fill(zero(UInt8), nargtypes) - src.slottypes = copy(ir.argtypes) - @static if VERSION < v"1.12-" - else - src.isva = isva - src.nargs = UInt(nargtypes) - end - src = Core.Compiler.ir_to_codeinf!(src, ir) - src.rettype = rt - return Base.Experimental.generate_opaque_closure(sig, Union{}, rt, src, nargs, isva, env...; kwargs...) -end - - # Generator function which ensures that all calls to the function are executed within the ReactantInterpreter # In particular this entails two pieces: # 1) We enforce the use of the ReactantInterpreter method table when generating the original methodinstance @@ -393,6 +101,9 @@ function call_with_reactant_generator( @nospecialize args = redub_arguments + @show args + flush(stdout) + stub = Core.GeneratedFunctionStub( identity, Core.svec(:call_with_reactant, REDUB_ARGUMENTS_NAME), Core.svec() ) From 623ff38002b5e772ca6d684e1dfea028c465d195 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Wed, 11 Dec 2024 22:10:05 -0600 Subject: [PATCH 32/78] fix --- src/Compiler.jl | 5 +++++ src/Interpreter.jl | 45 +++++++++++++++++++++++++++++++++++++++++++-- src/utils.jl | 24 ++++++++++++++++++++++++ 3 files changed, 72 insertions(+), 2 deletions(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index 586f33b05..4e00b9b74 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -779,6 +779,11 @@ function compile(f, args; client=nothing, optimize=true, sync=false) return register_thunk(fname, body) end +# Compiling within a compile should return simply the original function +Reactant.@reactant_override function Reactant.Compiler.compile(f, args; client=nothing, optimize=true, sync=false) + return f +end + # inspired by RuntimeGeneratedFunction.jl const __thunk_body_cache = Dict{Symbol,Expr}() diff --git a/src/Interpreter.jl b/src/Interpreter.jl index acdde888b..6a31eea07 100644 --- a/src/Interpreter.jl +++ b/src/Interpreter.jl @@ -29,6 +29,47 @@ function var"@reactant_override"(__source__::LineNumberNode, __module__::Module, ) end +function set_reactant_abi( + interp, + @nospecialize(f), + arginfo::ArgInfo, + si::StmtInfo, + sv::AbsIntState, + max_methods::Int=get_max_methods(interp, f, sv), +) + (; fargs, argtypes) = arginfo + + @show "pre", f + + # Improve inference by considering call_with_reactant as having the same results as + # the original call + if f === Reactant.call_with_reactant + arginfo2 = ArgInfo( + fargs isa Nothing ? nothing : + fargs[2:end], + argtypes[2:end], + ) + return abstract_call( + interp, + arginfo2::ArgInfo, + si, + sv, + max_methods, + ) + end + + @show "post", f + + return Base.@invoke abstract_call_known( + interp::AbstractInterpreter, + f::Any, + arginfo::ArgInfo, + si::StmtInfo, + sv::AbsIntState, + max_methods::Int, + ) +end + @static if Enzyme.GPUCompiler.HAS_INTEGRATED_CACHE struct ReactantCacheToken end @@ -40,7 +81,7 @@ end true, #=forward_rules=# true, #=reverse_rules=# false, #=broadcast_rewrite=# - nothing, + set_reactant_abi, ) end else @@ -56,7 +97,7 @@ else true, #=forward_rules=# true, #=forward_rules=# false, #=broadcast_rewrite=# - nothing, + set_reactant_abi, ) end end diff --git a/src/utils.jl b/src/utils.jl index 8cebb048c..7d7f6659d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -102,6 +102,9 @@ function call_with_reactant_generator( args = redub_arguments @show args + ccall(:jl_, Any, (Any,), "world="*string(world)) + ccall(:jl_, Any, (Any,), "args=") + ccall(:jl_, Any, (Any,), args) flush(stdout) stub = Core.GeneratedFunctionStub( @@ -114,6 +117,7 @@ function call_with_reactant_generator( )) if args[1] <: Core.Builtin + ccall(:jl_, Any, (Any,), "builtin-ret") return stub(world, source, builtin_error) end method_error = :(throw(MethodError(args[1], args[2:end], $world))) @@ -127,6 +131,7 @@ function call_with_reactant_generator( end if lookup_result === nothing || lookup_result === missing + ccall(:jl_, Any, (Any,), "no rlookup_result"*string(lookup_result)) return stub(world, source, method_error) end @@ -134,6 +139,8 @@ function call_with_reactant_generator( # No method could be found (including in our method table), bail with an error if length(matches) != 1 + ccall(:jl_, Any, (Any,), "no matches "*string(lookup_result)) + ccall(:jl_, Any, (Any,), "no matches2 "*string(matches)) return stub(world, source, method_error) end @@ -148,14 +155,19 @@ function call_with_reactant_generator( match.sparams, ) + ccall(:jl_, Any, (Any,), "mi=") + ccall(:jl_, Any, (Any,), mi) + result = Core.Compiler.InferenceResult(mi, Core.Compiler.typeinf_lattice(interp)) frame = Core.Compiler.InferenceState(result, :local, interp) #=cache_mode=# + ccall(:jl_, Any, (Any,), "frame="*string(frame)) @assert frame !== nothing Core.Compiler.typeinf(interp, frame) @static if VERSION >= v"1.11" # `typeinf` doesn't update the cfg. We need to do it manually. frame.cfg = Core.Compiler.compute_basic_blocks(frame.src.code) end + ccall(:jl_, Any, (Any,), "frameinf="*string(Core.Compiler.is_inferred(frame))) @assert Core.Compiler.is_inferred(frame) method = match.method @@ -176,6 +188,10 @@ function call_with_reactant_generator( Core.Compiler.ipo_dataflow_analysis!(interp, ir, caller) end + + ccall(:jl_, Any, (Any,), "first ir=") + ccall(:jl_, Any, (Any,), ir) + # Rewrite type unstable calls to recurse into call_with_reactant to ensure # they continue to use our interpreter. Reset the derived return type # to Any if our interpreter would change the return type of any result. @@ -196,8 +212,16 @@ function call_with_reactant_generator( end end Core.Compiler.finish(interp, opt, ir, caller) + + ccall(:jl_, Any, (Any,), "post ir=") + ccall(:jl_, Any, (Any,), ir) + src = Core.Compiler.ir_to_codeinf!(opt) + + ccall(:jl_, Any, (Any,), "post src=") + ccall(:jl_, Any, (Any,), src) + # Julia hits various internal errors trying to re-perform type inference # on type infered code (that we undo inference of), if there is no type unstable # code to be rewritten, just use the default methodinstance (still using our methodtable), From df3e27c4dff2c6e4fbf0fd2ff04743ff9a974970 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Wed, 11 Dec 2024 23:54:20 -0600 Subject: [PATCH 33/78] tmp --- src/utils.jl | 60 +++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 52 insertions(+), 8 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 7d7f6659d..f9a2e750a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -66,8 +66,6 @@ function rewrite_inst(inst, ir) end end if Meta.isexpr(inst, :invoke) - @show inst - flush(stdout) # return false, Expr(:call, inst.args[2:end]...) end return false, inst @@ -88,6 +86,10 @@ This originates from https://github.com/JuliaLabs/Cassette.jl/blob/c29b237c1ec0d """ const REDUB_ARGUMENTS_NAME = gensym("redub_arguments") +@generated function make_oc(sig, rt, isva, method) + Expr(:new_opaque_closure, sig, rt, rt, isva, method) +end + # Generator function which ensures that all calls to the function are executed within the ReactantInterpreter # In particular this entails two pieces: # 1) We enforce the use of the ReactantInterpreter method table when generating the original methodinstance @@ -101,7 +103,6 @@ function call_with_reactant_generator( @nospecialize args = redub_arguments - @show args ccall(:jl_, Any, (Any,), "world="*string(world)) ccall(:jl_, Any, (Any,), "args=") ccall(:jl_, Any, (Any,), args) @@ -329,16 +330,60 @@ function call_with_reactant_generator( # :propagate, #) - @show ir ccall(:jl_, Any, (Any,), "ir=") ccall(:jl_, Any, (Any,), ir) flush(stdout) rt = Base.Experimental.compute_ir_rettype(ir) - lno = LineNumberNode(1, :none) - oc = Base.Experimental.generate_opaque_closure(sig::Type, rt::Type, rt::Type, src::Core.Compiler.CodeInfo, Int(method.nargs), method.isva::Bool) - @show oc + meth = ccall(:jl_new_method_uninit, Ref{Method}, (Any,), Main) + meth.sig = sig + meth.isva = method.isva + meth.is_for_opaque_closure = true + meth.name = :opaque_closure + meth.nargs = method.nargs + meth.file = Symbol("") + meth.line = 0 # source + ccall(:jl_method_set_source, Cvoid, (Ref{Core.Method}, Ref{Core.Compiler.CodeInfo}), meth, src) + + nmi = ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance}, (Any, Any, Any), meth, sig, Core.svec()) + + inst2 = Core.Compiler.CodeInstance( + nmi::Core.MethodInstance, rt, #=@nospecialize(inferred_const)=#C_NULL, + #=@nospecialize(inferred)=#src, #=const_flags=#Int32(0), lookup_result.valid_worlds.min_world, lookup_result.valid_worlds.max_world, + #=ipo_effects::UInt32=#UInt32(0), #=effects::UInt32=#UInt32(0), #=@nospecialize(argescapes#=::Union{Nothing,Vector{ArgEscapeInfo}}=#)=#nothing, + #=src.relocatability=#UInt8(0)) + + ccall(:jl_mi_cache_insert, Cvoid, (Any, Any), nmi, inst2) + + # oc = make_oc(sig, rt, method.isva, meth) + + # oc = Core.OpaqueClosure(sig, rt, rt, method, C_NULL, 0, true) # + + oc = @static if sizeof(Cint) == sizeof(Int64) + Base.llvmcall((""" + declare {} addrspace(10)* @jl_new_opaque_closure_jlcall({} addrspace(10)*, {} addrspace(10)**, i64); + + declare {} addrspace(10)* @julia.call(...) + + define {} addrspace(10)* @f({} addrspace(10)* %a0, {} addrspace(10)* %a1, {} addrspace(10)* %a2, {} addrspace(10)* %a3) { + %res = call {} addrspace(10)* (...) @julia.call({} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i64)* @jl_new_opaque_closure_jlcall, {} addrspace(10)* %a0, {} addrspace(10)* %a1, {} addrspace(10)* %a2, {} addrspace(10)* %a3) + ret {} addrspace(10)* %res + } + """, "f"), Any, Tuple{Any, Any, Any, Any}, sig, rt, rt, meth) + else + Base.llvmcall((""" + declare {} addrspace(10)* @jl_new_opaque_closure_jlcall({} addrspace(10)*, {} addrspace(10)**, i32); + + declare {} addrspace(10)* @julia.call(...) + + define {} addrspace(10)* @f({} addrspace(10)* %a0, {} addrspace(10)* %a1, {} addrspace(10)* %a2, {} addrspace(10)* %a3) { + %res = call {} addrspace(10)* (...) @julia.call({} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32)* @jl_new_opaque_closure_jlcall, {} addrspace(10)* %a0, {} addrspace(10)* %a1, {} addrspace(10)* %a2, {} addrspace(10)* %a3) + ret {} addrspace(10)* %res + } + """, "f"), Any, Tuple{Any, Any, Any, Any}, sig, rt, rt, meth) + end + ccall(:jl_, Any, (Any,), "oc=") ccall(:jl_, Any, (Any,), oc) flush(stdout) @@ -374,7 +419,6 @@ function call_with_reactant_generator( code_info.ssavaluetypes = length(overdubbed_code) code_info.ssaflags = [0x00 for _ in 1:length(overdubbed_code)] # XXX we need to copy flags that are set for the original code - @show code_info ccall(:jl_, Any, (Any,), "code_info=") ccall(:jl_, Any, (Any,), code_info) flush(stdout) From bda891208fdcec4fd563ddbc4446c89e9d573c9e Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Thu, 12 Dec 2024 00:20:07 -0600 Subject: [PATCH 34/78] make --- src/utils.jl | 37 ++++++++++++------------------------- 1 file changed, 12 insertions(+), 25 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index f9a2e750a..2e5fc3534 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -86,8 +86,17 @@ This originates from https://github.com/JuliaLabs/Cassette.jl/blob/c29b237c1ec0d """ const REDUB_ARGUMENTS_NAME = gensym("redub_arguments") -@generated function make_oc(sig, rt, isva, method) - Expr(:new_opaque_closure, sig, rt, rt, isva, method) +function make_oc(@nospecialize(sig::Type), @nospecialize(rt::Type), @nospecialize(rt2::Type), method::Core.Method) + Base.llvmcall((""" + declare {} addrspace(10)* @jl_new_opaque_closure_jlcall({} addrspace(10)*, {} addrspace(10)**, i32); + + declare {} addrspace(10)* @julia.call(...) + + define {} addrspace(10)* @f({} addrspace(10)* %a0, {} addrspace(10)* %a1, {} addrspace(10)* %a2, {} addrspace(10)* %a3) { + %res = call {} addrspace(10)* (...) @julia.call({} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32)* @jl_new_opaque_closure_jlcall, {} addrspace(10)* %a0, {} addrspace(10)* %a1, {} addrspace(10)* %a2, {} addrspace(10)* %a3) + ret {} addrspace(10)* %res + } + """, "f"), Any, Tuple{Any, Any, Any, Any}, sig, rt, rt2, method) end # Generator function which ensures that all calls to the function are executed within the ReactantInterpreter @@ -360,29 +369,7 @@ function call_with_reactant_generator( # oc = Core.OpaqueClosure(sig, rt, rt, method, C_NULL, 0, true) # - oc = @static if sizeof(Cint) == sizeof(Int64) - Base.llvmcall((""" - declare {} addrspace(10)* @jl_new_opaque_closure_jlcall({} addrspace(10)*, {} addrspace(10)**, i64); - - declare {} addrspace(10)* @julia.call(...) - - define {} addrspace(10)* @f({} addrspace(10)* %a0, {} addrspace(10)* %a1, {} addrspace(10)* %a2, {} addrspace(10)* %a3) { - %res = call {} addrspace(10)* (...) @julia.call({} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i64)* @jl_new_opaque_closure_jlcall, {} addrspace(10)* %a0, {} addrspace(10)* %a1, {} addrspace(10)* %a2, {} addrspace(10)* %a3) - ret {} addrspace(10)* %res - } - """, "f"), Any, Tuple{Any, Any, Any, Any}, sig, rt, rt, meth) - else - Base.llvmcall((""" - declare {} addrspace(10)* @jl_new_opaque_closure_jlcall({} addrspace(10)*, {} addrspace(10)**, i32); - - declare {} addrspace(10)* @julia.call(...) - - define {} addrspace(10)* @f({} addrspace(10)* %a0, {} addrspace(10)* %a1, {} addrspace(10)* %a2, {} addrspace(10)* %a3) { - %res = call {} addrspace(10)* (...) @julia.call({} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32)* @jl_new_opaque_closure_jlcall, {} addrspace(10)* %a0, {} addrspace(10)* %a1, {} addrspace(10)* %a2, {} addrspace(10)* %a3) - ret {} addrspace(10)* %res - } - """, "f"), Any, Tuple{Any, Any, Any, Any}, sig, rt, rt, meth) - end + oc = make_oc(sig, rt, rt, meth) ccall(:jl_, Any, (Any,), "oc=") ccall(:jl_, Any, (Any,), oc) From fd92864b7259c695f1ad943e17b16223e7b9f7a6 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Thu, 12 Dec 2024 00:54:11 -0600 Subject: [PATCH 35/78] fix --- src/utils.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index 2e5fc3534..49882144c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -369,7 +369,11 @@ function call_with_reactant_generator( # oc = Core.OpaqueClosure(sig, rt, rt, method, C_NULL, 0, true) # - oc = make_oc(sig, rt, rt, meth) + ocargs = Any[Tuple{typeof(Reactant.apply), sig.parameters...}, rt, rt, meth] + @show meth.nargs + @show length(ocargs[1].parameters) + flush(stdout) + oc = ccall(:jl_new_opaque_closure_jlcall, Any, (Ptr{Cvoid}, Ptr{Any}, Int32), C_NULL, ocargs, Int32(4)) ccall(:jl_, Any, (Any,), "oc=") ccall(:jl_, Any, (Any,), oc) From e74173b9401307ac5e2278b4ed69471a230141be Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Thu, 12 Dec 2024 10:02:41 -0600 Subject: [PATCH 36/78] cleanup --- src/Interpreter.jl | 4 -- src/utils.jl | 117 ++++++++++----------------------------------- 2 files changed, 24 insertions(+), 97 deletions(-) diff --git a/src/Interpreter.jl b/src/Interpreter.jl index 6a31eea07..7af10535c 100644 --- a/src/Interpreter.jl +++ b/src/Interpreter.jl @@ -39,8 +39,6 @@ function set_reactant_abi( ) (; fargs, argtypes) = arginfo - @show "pre", f - # Improve inference by considering call_with_reactant as having the same results as # the original call if f === Reactant.call_with_reactant @@ -58,8 +56,6 @@ function set_reactant_abi( ) end - @show "post", f - return Base.@invoke abstract_call_known( interp::AbstractInterpreter, f::Any, diff --git a/src/utils.jl b/src/utils.jl index 49882144c..7db162cd5 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -68,7 +68,7 @@ function rewrite_inst(inst, ir) if Meta.isexpr(inst, :invoke) # return false, Expr(:call, inst.args[2:end]...) end - return false, inst + return falsse, inst end """ @@ -86,19 +86,6 @@ This originates from https://github.com/JuliaLabs/Cassette.jl/blob/c29b237c1ec0d """ const REDUB_ARGUMENTS_NAME = gensym("redub_arguments") -function make_oc(@nospecialize(sig::Type), @nospecialize(rt::Type), @nospecialize(rt2::Type), method::Core.Method) - Base.llvmcall((""" - declare {} addrspace(10)* @jl_new_opaque_closure_jlcall({} addrspace(10)*, {} addrspace(10)**, i32); - - declare {} addrspace(10)* @julia.call(...) - - define {} addrspace(10)* @f({} addrspace(10)* %a0, {} addrspace(10)* %a1, {} addrspace(10)* %a2, {} addrspace(10)* %a3) { - %res = call {} addrspace(10)* (...) @julia.call({} addrspace(10)* ({} addrspace(10)*, {} addrspace(10)**, i32)* @jl_new_opaque_closure_jlcall, {} addrspace(10)* %a0, {} addrspace(10)* %a1, {} addrspace(10)* %a2, {} addrspace(10)* %a3) - ret {} addrspace(10)* %res - } - """, "f"), Any, Tuple{Any, Any, Any, Any}, sig, rt, rt2, method) -end - # Generator function which ensures that all calls to the function are executed within the ReactantInterpreter # In particular this entails two pieces: # 1) We enforce the use of the ReactantInterpreter method table when generating the original methodinstance @@ -141,7 +128,6 @@ function call_with_reactant_generator( end if lookup_result === nothing || lookup_result === missing - ccall(:jl_, Any, (Any,), "no rlookup_result"*string(lookup_result)) return stub(world, source, method_error) end @@ -149,8 +135,6 @@ function call_with_reactant_generator( # No method could be found (including in our method table), bail with an error if length(matches) != 1 - ccall(:jl_, Any, (Any,), "no matches "*string(lookup_result)) - ccall(:jl_, Any, (Any,), "no matches2 "*string(matches)) return stub(world, source, method_error) end @@ -165,19 +149,14 @@ function call_with_reactant_generator( match.sparams, ) - ccall(:jl_, Any, (Any,), "mi=") - ccall(:jl_, Any, (Any,), mi) - result = Core.Compiler.InferenceResult(mi, Core.Compiler.typeinf_lattice(interp)) frame = Core.Compiler.InferenceState(result, :local, interp) #=cache_mode=# - ccall(:jl_, Any, (Any,), "frame="*string(frame)) @assert frame !== nothing Core.Compiler.typeinf(interp, frame) @static if VERSION >= v"1.11" # `typeinf` doesn't update the cfg. We need to do it manually. frame.cfg = Core.Compiler.compute_basic_blocks(frame.src.code) end - ccall(:jl_, Any, (Any,), "frameinf="*string(Core.Compiler.is_inferred(frame))) @assert Core.Compiler.is_inferred(frame) method = match.method @@ -198,10 +177,6 @@ function call_with_reactant_generator( Core.Compiler.ipo_dataflow_analysis!(interp, ir, caller) end - - ccall(:jl_, Any, (Any,), "first ir=") - ccall(:jl_, Any, (Any,), ir) - # Rewrite type unstable calls to recurse into call_with_reactant to ensure # they continue to use our interpreter. Reset the derived return type # to Any if our interpreter would change the return type of any result. @@ -223,15 +198,8 @@ function call_with_reactant_generator( end Core.Compiler.finish(interp, opt, ir, caller) - ccall(:jl_, Any, (Any,), "post ir=") - ccall(:jl_, Any, (Any,), ir) - src = Core.Compiler.ir_to_codeinf!(opt) - - ccall(:jl_, Any, (Any,), "post src=") - ccall(:jl_, Any, (Any,), src) - # Julia hits various internal errors trying to re-perform type inference # on type infered code (that we undo inference of), if there is no type unstable # code to be rewritten, just use the default methodinstance (still using our methodtable), @@ -325,66 +293,36 @@ function call_with_reactant_generator( push!(fn_args, Core.SSAValue(length(overdubbed_code))) end - # substitute static parameters, offset slot numbers by number of added slots, and - # offset statement indices by the number of additional statements - - # arg_partially_inline!( - # code_info.code, - # fn_args, - # method.sig, - # Any[static_params...], - # n_prepended_slots, - # n_prepended_slots, - # length(overdubbed_code), - # :propagate, - #) - - ccall(:jl_, Any, (Any,), "ir=") - ccall(:jl_, Any, (Any,), ir) - flush(stdout) - rt = Base.Experimental.compute_ir_rettype(ir) - - meth = ccall(:jl_new_method_uninit, Ref{Method}, (Any,), Main) - meth.sig = sig - meth.isva = method.isva - meth.is_for_opaque_closure = true - meth.name = :opaque_closure - meth.nargs = method.nargs - meth.file = Symbol("") - meth.line = 0 # source - ccall(:jl_method_set_source, Cvoid, (Ref{Core.Method}, Ref{Core.Compiler.CodeInfo}), meth, src) - - nmi = ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance}, (Any, Any, Any), meth, sig, Core.svec()) - - inst2 = Core.Compiler.CodeInstance( - nmi::Core.MethodInstance, rt, #=@nospecialize(inferred_const)=#C_NULL, - #=@nospecialize(inferred)=#src, #=const_flags=#Int32(0), lookup_result.valid_worlds.min_world, lookup_result.valid_worlds.max_world, - #=ipo_effects::UInt32=#UInt32(0), #=effects::UInt32=#UInt32(0), #=@nospecialize(argescapes#=::Union{Nothing,Vector{ArgEscapeInfo}}=#)=#nothing, - #=src.relocatability=#UInt8(0)) - - ccall(:jl_mi_cache_insert, Cvoid, (Any, Any), nmi, inst2) - - # oc = make_oc(sig, rt, method.isva, meth) - - # oc = Core.OpaqueClosure(sig, rt, rt, method, C_NULL, 0, true) # - - ocargs = Any[Tuple{typeof(Reactant.apply), sig.parameters...}, rt, rt, meth] - @show meth.nargs - @show length(ocargs[1].parameters) - flush(stdout) - oc = ccall(:jl_new_opaque_closure_jlcall, Any, (Ptr{Cvoid}, Ptr{Any}, Int32), C_NULL, ocargs, Int32(4)) - - ccall(:jl_, Any, (Any,), "oc=") - ccall(:jl_, Any, (Any,), oc) - flush(stdout) + + + oc = if Base.issingletontype(args[1]) + ccall(:jl_new_opaque_closure_from_code_info, Any, (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any, Cint), + Tuple{sig.parameters[2:end]...}, rt, rt, @__MODULE__, src, 0, nothing, method.nargs-1, method.isva, args[1].instance, true) + else + push!(overdubbed_code, + quote + args[1] + end + ) + push!(overdubbed_codelocs, code_info.codelocs[1]) + farg = Core.SSAValue(length(overdubbed_code)) + push!(overdubbed_code, + quote + ccall(:jl_new_opaque_closure_from_code_info, Any, (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any, Cint), + $(Tuple{sig.parameters[2:end]...}), $rt, $rt, $(@__MODULE__), $src, 0, nothing, $(method.nargs-1), $(method.isva), $farg, true) + end + ) + push!(overdubbed_codelocs, code_info.codelocs[1]) + Core.SSAValue(length(overdubbed_code)) + end push!( overdubbed_code, Expr( :(call), oc, - fn_args... + fn_args[2:end]... ), ) @@ -396,9 +334,6 @@ function call_with_reactant_generator( ) push!(overdubbed_codelocs, code_info.codelocs[1]) - # append!(overdubbed_code, code_info.code) - # append!(overdubbed_codelocs, code_info.codelocs) - #=== set `code_info`/`reflection` fields accordingly ===# if code_info.method_for_inference_limit_heuristics === nothing @@ -410,10 +345,6 @@ function call_with_reactant_generator( code_info.ssavaluetypes = length(overdubbed_code) code_info.ssaflags = [0x00 for _ in 1:length(overdubbed_code)] # XXX we need to copy flags that are set for the original code - ccall(:jl_, Any, (Any,), "code_info=") - ccall(:jl_, Any, (Any,), code_info) - flush(stdout) - return code_info end From c71942c0570174201428316cad28980441003e44 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Thu, 12 Dec 2024 10:05:16 -0600 Subject: [PATCH 37/78] continuing --- src/utils.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 7db162cd5..c263394e7 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -68,7 +68,7 @@ function rewrite_inst(inst, ir) if Meta.isexpr(inst, :invoke) # return false, Expr(:call, inst.args[2:end]...) end - return falsse, inst + return false, inst end """ @@ -99,10 +99,8 @@ function call_with_reactant_generator( @nospecialize args = redub_arguments - ccall(:jl_, Any, (Any,), "world="*string(world)) ccall(:jl_, Any, (Any,), "args=") ccall(:jl_, Any, (Any,), args) - flush(stdout) stub = Core.GeneratedFunctionStub( identity, Core.svec(:call_with_reactant, REDUB_ARGUMENTS_NAME), Core.svec() From 1fa3c930c0003a6e505acac586642432e07e2181 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Thu, 12 Dec 2024 21:42:02 -0600 Subject: [PATCH 38/78] more working --- src/utils.jl | 245 ++++++++++++++++++++++++++++++++++----------------- 1 file changed, 164 insertions(+), 81 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index c263394e7..07d5dee6d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -86,6 +86,64 @@ This originates from https://github.com/JuliaLabs/Cassette.jl/blob/c29b237c1ec0d """ const REDUB_ARGUMENTS_NAME = gensym("redub_arguments") +function throw_method_error(argtys) + throw(MethodError(argtys[1], argtys[2:end])) +end + + + +@inline function lookup_world(@nospecialize(sig::Type), world::UInt, mt::Union{Nothing,Core.MethodTable}, min_world::Ref{UInt}, max_world::Ref{UInt}) + ccall(:jl_, Any, (Any,), "pre mt "*string(world)*" mnw="*string(min_world)*" mxw"*string(max_world)) + res = ccall(:jl_gf_invoke_lookup_worlds, Any, + (Any, Any, Csize_t, Ref{Csize_t}, Ref{Csize_t}), + sig, mt, world, min_world, max_world) + ccall(:jl_, Any, (Any,), "post mt "*string(world)*" mnw="*string(min_world)* " mxw"*string(max_world)) + return res +end + +@inline function lookup_world(@nospecialize(sig::Type), world::UInt, mt::Core.Compiler.InternalMethodTable, min_world::Ref{UInt}, max_world::Ref{UInt}) + @show "pre imt", world, min_world, max_world + res = lookup_world(sig, mt.world, nothing, min_world, max_world) + @show "imt", res, world, min_world, max_world + return res +end + +@inline function lookup_world(@nospecialize(sig::Type), world::UInt, mt::Core.Compiler.OverlayMethodTable, min_world::Ref{UInt}, max_world::Ref{UInt}) + res = lookup_world(sig, mt.world, mt.mt, min_world, max_world) + if res !== nothing + return res + else + return lookup_world(sig, mt.world, nothing, min_world, max_world) + end +end + + +# HACK: in all versions of Julia, `jl_new_opaque_closure_from_code_info` doesn't take a world argument +# but instead always generates code for the current world. note that this doesn't +# actually change the world age, but just spoofs the counter `jl_create_native` reads. +# XXX: Base.get_world_counter is supposed to be monotonically increasing and is runtime global. +macro in_world(world, ex) + quote + actual_world = Base.get_world_counter() + world_counter = cglobal(:jl_world_counter, Csize_t) + unsafe_store!(world_counter, $(esc(world))) + try + $(esc(ex)) + finally + unsafe_store!(world_counter, actual_world) + end + end +end + +#define jl_current_task (container_of(jl_get_pgcstack(), jl_task_t, gcstack)) + + +function make_oc(sig, rt, src, nargs, isva, f)::Core.OpaqueClosure + ccall(:jl_new_opaque_closure_from_code_info, Any, (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any, Cint), + sig, rt, rt, @__MODULE__, src, 0, nothing, nargs, isva, f, true)::Core.OpaqueClosure +end + + # Generator function which ensures that all calls to the function are executed within the ReactantInterpreter # In particular this entails two pieces: # 1) We enforce the use of the ReactantInterpreter method table when generating the original methodinstance @@ -99,8 +157,7 @@ function call_with_reactant_generator( @nospecialize args = redub_arguments - ccall(:jl_, Any, (Any,), "args=") - ccall(:jl_, Any, (Any,), args) + ccall(:jl_, Any, (Any,), string(world)*" args="*string(args)) stub = Core.GeneratedFunctionStub( identity, Core.svec(:call_with_reactant, REDUB_ARGUMENTS_NAME), Core.svec() @@ -112,31 +169,96 @@ function call_with_reactant_generator( )) if args[1] <: Core.Builtin - ccall(:jl_, Any, (Any,), "builtin-ret") return stub(world, source, builtin_error) end - method_error = :(throw(MethodError(args[1], args[2:end], $world))) + method_error = :(throw(MethodError($REDUB_ARGUMENTS_NAME[1], $REDUB_ARGUMENTS_NAME[2:end], $world))) interp = ReactantInterpreter(; world) sig = Tuple{args...} - lookup_result = Core.Compiler.findall(sig, Core.Compiler.method_table(interp)) - @static if VERSION < v"1.11-" - lookup_result = lookup_result.matches - end - if lookup_result === nothing || lookup_result === missing - return stub(world, source, method_error) - end + min_world = Ref{UInt}(typemin(UInt)) + max_world = Ref{UInt}(typemax(UInt)) + + lookup_result = lookup_world(sig, world, Core.Compiler.method_table(interp), min_world, max_world) + + ccall(:jl_, Any, (Any,), string(lookup_result)*" sig="*string(sig)*" mw="*string(min_world)*" "*string(max_world)*" "*string(Base.get_world_counter())) - matches = lookup_result.matches + overdubbed_code = Any[] + overdubbed_codelocs = Int32[] # No method could be found (including in our method table), bail with an error - if length(matches) != 1 + if lookup_result == nothing return stub(world, source, method_error) + tmp_min_world = Ref{UInt}(typemin(UInt)) + tmp_max_world = Ref{UInt}(typemax(UInt)) + match = ccall(:jl_gf_invoke_lookup_worlds, Any, + (Any, Any, Csize_t, Ref{Csize_t}, Ref{Csize_t}), + Tuple{typeof(throw_method_error), sig}, #=mt=# nothing, world, tmp_min_world, tmp_max_world) + @assert match !== nothing + + # look up the method and code instance + mi = ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance}, + (Any, Any, Any), match.method, match.spec_types, match.sparams) + + ci = Core.Compiler.retrieve_code_info(mi, world)::Core.Compiler.CodeInfo + + src = copy(ci) + src.slotnames = Any[:call_with_reactant, REDUB_ARGUMENTS_NAME] + + src.edges = Any[ccall(:jl_method_table_for, Any, (Any,), sig)::Core.MethodTable, sig] + src.min_world = min_world[] + src.max_world = max_world[] + + push!(overdubbed_code, :($(Base.getindex)($(Core.Argument(2)), 1))) + push!(overdubbed_codelocs, 0) + + expr_fn = Core.SSAValue(length(overdubbed_code)) + + + push!(overdubbed_code, :($(Base.lastindex)($(Core.Argument(2))))) + push!(overdubbed_codelocs, 0) + + expr_lastindex = Core.SSAValue(length(overdubbed_code)) + + + push!(overdubbed_code, :(2:$expr_lastindex)) + push!(overdubbed_codelocs, 0) + + expr_slice = Core.SSAValue(length(overdubbed_code)) + + push!(overdubbed_code, :($(Base.getindex)($(Core.Argument(2)), $expr_slice))) + push!(overdubbed_codelocs, 0) + + expr_args = Core.SSAValue(length(overdubbed_code)) + + push!(overdubbed_code, :($(Base.MethodError)($expr_fn, $expr_args, $world))) + push!(overdubbed_codelocs, 0) + + expr_method = Core.SSAValue(length(overdubbed_code)) + + push!(overdubbed_code, :($(Base.throw)($expr_method))) + push!(overdubbed_codelocs, 0) + + push!( + overdubbed_code, + Core.ReturnNode(Core.SSAValue(length(overdubbed_code))) + ) + push!(overdubbed_codelocs, 0) + + src.code = overdubbed_code + src.codelocs = overdubbed_codelocs + src.ssavaluetypes = length(overdubbed_code) + src.ssaflags = [0x00 for _ in 1:length(overdubbed_code)] # XXX we need to copy flags that are set for the original code + + @show src + @show src.edges + @show typeof(src) + + return src end - match = matches[1]::Core.MethodMatch + match = lookup_result::Core.MethodMatch # look up the method and code instance mi = ccall( :jl_specializations_get_linfo, @@ -159,6 +281,9 @@ function call_with_reactant_generator( method = match.method + ccall(:jl_, Any, (Any,), ("method=")*string(method)) + ccall(:jl_, Any, (Any,), ("va=")*string(method.isva)) + # The original julia code (on 1.11+) has the potential constprop, for now # we assume this outermost function does not constprop, for ease. #if Core.Compiler.result_is_constabi(interp, frame.result) @@ -175,6 +300,8 @@ function call_with_reactant_generator( Core.Compiler.ipo_dataflow_analysis!(interp, ir, caller) end + ccall(:jl_, Any, (Any,), ("ir=")*string(ir)) + # Rewrite type unstable calls to recurse into call_with_reactant to ensure # they continue to use our interpreter. Reset the derived return type # to Any if our interpreter would change the return type of any result. @@ -198,33 +325,22 @@ function call_with_reactant_generator( src = Core.Compiler.ir_to_codeinf!(opt) - # Julia hits various internal errors trying to re-perform type inference - # on type infered code (that we undo inference of), if there is no type unstable - # code to be rewritten, just use the default methodinstance (still using our methodtable), - # to improve compatibility as these bugs are fixed upstream. - # Just kidding we can't do this, since otherwise the inferred code won't guarantee to run - # within our interpreter, so we must use our generated IR here. - # if !any_changed - # src = Core.Compiler.retrieve_code_info(mi, world) - # end + ccall(:jl_, Any, (Any,), ("src=")*string(src)) # prepare a new code info code_info = copy(src) static_params = match.sparams signature = sig - is_invoke = args[1] === typeof(Core.invoke) # propagate edge metadata, this method is invalidated if the original function we are calling # is invalidated code_info.edges = Core.MethodInstance[mi] - code_info.min_world = lookup_result.valid_worlds.min_world - code_info.max_world = lookup_result.valid_worlds.max_world + code_info.min_world = min_world[] + code_info.max_world = max_world[] # Rewrite the arguments to this function, to prepend the two new arguments, the function :call_with_reactant, # and the REDUB_ARGUMENTS_NAME tuple of input arguments - code_info.slotnames = Any[ - :call_with_reactant, REDUB_ARGUMENTS_NAME, code_info.slotnames... - ] + code_info.slotnames = Any[:call_with_reactant, REDUB_ARGUMENTS_NAME, code_info.slotnames...] code_info.slotflags = UInt8[0x00, 0x00, code_info.slotflags...] n_prepended_slots = 2 overdub_args_slot = Core.SlotNumber(n_prepended_slots) @@ -240,15 +356,10 @@ function call_with_reactant_generator( # destructure the generated argument slots into the overdubbed method's argument slots. n_actual_args = fieldcount(signature) n_method_args = Int(method.nargs) + offset = 1 fn_args = Any[] - for i in 1:n_method_args - if is_invoke && (i == 1 || i == 2) - # With an invoke call, we have: 1 is invoke, 2 is f, 3 is Tuple{}, 4... is args. - # In the first loop iteration, we should skip invoke and process f. - # In the second loop iteration, we should skip the Tuple type and process args[1]. - offset += 1 - end + for i in 1:length(redub_arguments) slot = i + n_prepended_slots actual_argument = Expr( :call, Core.GlobalRef(Core, :getfield), overdub_args_slot, offset @@ -258,59 +369,30 @@ function call_with_reactant_generator( code_info.slotflags[slot] |= 0x02 # ensure this slotflag has the "assigned" bit set offset += 1 - #push!(overdubbed_code, actual_argument) - push!(fn_args, Core.SSAValue(length(overdubbed_code))) - end - - # If `method` is a varargs method, we have to restructure the original method call's - # trailing arguments into a tuple and assign that tuple to the expected argument slot. - if method.isva - if !isempty(overdubbed_code) - # remove the final slot reassignment leftover from the previous destructuring - pop!(overdubbed_code) - pop!(overdubbed_codelocs) - pop!(fn_args) - end - trailing_arguments = Expr(:call, Core.GlobalRef(Core, :tuple)) - for i in n_method_args:n_actual_args - push!( - overdubbed_code, - Expr(:call, Core.GlobalRef(Core, :getfield), overdub_args_slot, offset - 1), - ) - push!(overdubbed_codelocs, code_info.codelocs[1]) - push!(trailing_arguments.args, Core.SSAValue(length(overdubbed_code))) - offset += 1 - end - push!( - overdubbed_code, - Expr( - :(=), Core.SlotNumber(n_method_args + n_prepended_slots), trailing_arguments - ), - ) - push!(overdubbed_codelocs, code_info.codelocs[1]) push!(fn_args, Core.SSAValue(length(overdubbed_code))) end rt = Base.Experimental.compute_ir_rettype(ir) - + # jl_new_opaque_closure forcibly executes in the current world... This means that we won't get the right + # inner code during compilation without special handling (i.e. call_in_world_total). + # Opaque closures also require takign the function argument. We can work around the latter + # if the function is stateless. But regardless, to work around this we sadly create/compile the opaque closure oc = if Base.issingletontype(args[1]) - ccall(:jl_new_opaque_closure_from_code_info, Any, (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any, Cint), - Tuple{sig.parameters[2:end]...}, rt, rt, @__MODULE__, src, 0, nothing, method.nargs-1, method.isva, args[1].instance, true) + Core._call_in_world_total(world, make_oc, Tuple{sig.parameters[2:end]...}, rt, src, method.nargs - 1, method.isva, args[1].instance)::Core.OpaqueClosure else + farg = fn_args[1] push!(overdubbed_code, - quote - args[1] - end - ) - push!(overdubbed_codelocs, code_info.codelocs[1]) - farg = Core.SSAValue(length(overdubbed_code)) - push!(overdubbed_code, - quote - ccall(:jl_new_opaque_closure_from_code_info, Any, (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any, Cint), - $(Tuple{sig.parameters[2:end]...}), $rt, $rt, $(@__MODULE__), $src, 0, nothing, $(method.nargs-1), $(method.isva), $farg, true) - end - ) + Expr(:call, + make_oc, + Tuple{sig.parameters[2:end]...}, + rt, + src, + method.nargs-1, + method.isva, + farg + ) + ) push!(overdubbed_codelocs, code_info.codelocs[1]) Core.SSAValue(length(overdubbed_code)) end @@ -343,6 +425,7 @@ function call_with_reactant_generator( code_info.ssavaluetypes = length(overdubbed_code) code_info.ssaflags = [0x00 for _ in 1:length(overdubbed_code)] # XXX we need to copy flags that are set for the original code + ccall(:jl_, Any, (Any,), "code_info="*string(code_info)) return code_info end From 07fb85699433bfbcd0e9d2d490e1ed3a0b649d4b Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Thu, 12 Dec 2024 21:47:08 -0600 Subject: [PATCH 39/78] further simplify --- src/utils.jl | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 07d5dee6d..f93ada05c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -340,8 +340,8 @@ function call_with_reactant_generator( # Rewrite the arguments to this function, to prepend the two new arguments, the function :call_with_reactant, # and the REDUB_ARGUMENTS_NAME tuple of input arguments - code_info.slotnames = Any[:call_with_reactant, REDUB_ARGUMENTS_NAME, code_info.slotnames...] - code_info.slotflags = UInt8[0x00, 0x00, code_info.slotflags...] + code_info.slotnames = Any[:call_with_reactant, REDUB_ARGUMENTS_NAME] + code_info.slotflags = UInt8[0x00, 0x00] n_prepended_slots = 2 overdub_args_slot = Core.SlotNumber(n_prepended_slots) @@ -354,21 +354,16 @@ function call_with_reactant_generator( # required by the base method. # destructure the generated argument slots into the overdubbed method's argument slots. - n_actual_args = fieldcount(signature) - n_method_args = Int(method.nargs) offset = 1 fn_args = Any[] for i in 1:length(redub_arguments) - slot = i + n_prepended_slots actual_argument = Expr( :call, Core.GlobalRef(Core, :getfield), overdub_args_slot, offset ) - push!(overdubbed_code, :($(Core.SlotNumber(slot)) = $actual_argument)) + push!(overdubbed_code, actual_argument) push!(overdubbed_codelocs, code_info.codelocs[1]) - code_info.slotflags[slot] |= 0x02 # ensure this slotflag has the "assigned" bit set offset += 1 - push!(fn_args, Core.SSAValue(length(overdubbed_code))) end From 72533ff289edceea1e5cbae1327b98699123a4a3 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Thu, 12 Dec 2024 23:00:42 -0600 Subject: [PATCH 40/78] fx --- src/utils.jl | 50 +++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 45 insertions(+), 5 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index f93ada05c..61cc28d4f 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -357,7 +357,10 @@ function call_with_reactant_generator( offset = 1 fn_args = Any[] - for i in 1:length(redub_arguments) + n_method_args = method.nargs + n_actual_args = length(redub_arguments) + + for i in 1:n_actual_args actual_argument = Expr( :call, Core.GlobalRef(Core, :getfield), overdub_args_slot, offset ) @@ -367,24 +370,61 @@ function call_with_reactant_generator( push!(fn_args, Core.SSAValue(length(overdubbed_code))) end + + # If `method` is a varargs method, we have to restructure the original method call's + # trailing arguments into a tuple and assign that tuple to the expected argument slot. + if false && method.isva + if !isempty(overdubbed_code) + # remove the final slot reassignment leftover from the previous destructuring + pop!(overdubbed_code) + pop!(overdubbed_codelocs) + pop!(fn_args) + end + trailing_arguments = Expr(:call, Core.GlobalRef(Core, :tuple)) + for i in n_method_args:n_actual_args + push!( + overdubbed_code, + Expr(:call, Core.GlobalRef(Core, :getfield), overdub_args_slot, offset - 1), + ) + push!(overdubbed_codelocs, code_info.codelocs[1]) + push!(trailing_arguments.args, Core.SSAValue(length(overdubbed_code))) + offset += 1 + end + push!( + overdubbed_code, trailing_arguments + ) + push!(overdubbed_codelocs, code_info.codelocs[1]) + push!(fn_args, Core.SSAValue(length(overdubbed_code))) + end + rt = Base.Experimental.compute_ir_rettype(ir) + ocva = method.isva + + @show method + @show method.isva, method.sig, mi.specTypes + @show method.nargs + ocnargs = method.nargs - 1 + octup = Tuple{mi.specTypes.parameters[2:end]...} + @show octup + @show mi + # jl_new_opaque_closure forcibly executes in the current world... This means that we won't get the right # inner code during compilation without special handling (i.e. call_in_world_total). # Opaque closures also require takign the function argument. We can work around the latter # if the function is stateless. But regardless, to work around this we sadly create/compile the opaque closure oc = if Base.issingletontype(args[1]) - Core._call_in_world_total(world, make_oc, Tuple{sig.parameters[2:end]...}, rt, src, method.nargs - 1, method.isva, args[1].instance)::Core.OpaqueClosure + Core._call_in_world_total(world, make_oc, octup, rt, src, ocnargs, ocva, args[1].instance)::Core.OpaqueClosure else farg = fn_args[1] push!(overdubbed_code, Expr(:call, make_oc, - Tuple{sig.parameters[2:end]...}, + octup, rt, src, - method.nargs-1, - method.isva, + ocnargs, + ocva, farg ) ) From 503f1ff9e6596230303e5b0a633e89e719f25907 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Thu, 12 Dec 2024 23:51:09 -0600 Subject: [PATCH 41/78] more improvements --- src/utils.jl | 42 +++++++++++++++++++++++++++++++----------- 1 file changed, 31 insertions(+), 11 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 61cc28d4f..7abecd183 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -360,7 +360,13 @@ function call_with_reactant_generator( n_method_args = method.nargs n_actual_args = length(redub_arguments) - for i in 1:n_actual_args + tys = [] + + iter_args = n_actual_args + if method.isva + iter_args = min(n_actual_args, n_method_args-1) + end + for i in 1:iter_args actual_argument = Expr( :call, Core.GlobalRef(Core, :getfield), overdub_args_slot, offset ) @@ -368,47 +374,61 @@ function call_with_reactant_generator( push!(overdubbed_codelocs, code_info.codelocs[1]) offset += 1 push!(fn_args, Core.SSAValue(length(overdubbed_code))) + push!(tys, redub_arguments[i]) end # If `method` is a varargs method, we have to restructure the original method call's # trailing arguments into a tuple and assign that tuple to the expected argument slot. - if false && method.isva - if !isempty(overdubbed_code) - # remove the final slot reassignment leftover from the previous destructuring - pop!(overdubbed_code) - pop!(overdubbed_codelocs) - pop!(fn_args) - end + if method.isva + @show "post pop", tys trailing_arguments = Expr(:call, Core.GlobalRef(Core, :tuple)) for i in n_method_args:n_actual_args push!( overdubbed_code, - Expr(:call, Core.GlobalRef(Core, :getfield), overdub_args_slot, offset - 1), + Expr(:call, Core.GlobalRef(Core, :getfield), overdub_args_slot, offset), ) push!(overdubbed_codelocs, code_info.codelocs[1]) push!(trailing_arguments.args, Core.SSAValue(length(overdubbed_code))) offset += 1 end + push!( overdubbed_code, trailing_arguments ) push!(overdubbed_codelocs, code_info.codelocs[1]) push!(fn_args, Core.SSAValue(length(overdubbed_code))) + push!(tys, Tuple{redub_arguments[n_method_args:n_actual_args]...}) + + @show "post redo", tys + @show Tuple{redub_arguments[n_method_args:n_actual_args]...} end + @show n_method_args, n_actual_args rt = Base.Experimental.compute_ir_rettype(ir) - ocva = method.isva + # ocva = method.isva + + ocva = false # method.isva @show method @show method.isva, method.sig, mi.specTypes @show method.nargs ocnargs = method.nargs - 1 - octup = Tuple{mi.specTypes.parameters[2:end]...} + # octup = Tuple{mi.specTypes.parameters[2:end]...} + # octup = Tuple{method.sig.parameters[2:end]...} + octup = Tuple{tys[2:end]...} + ocva = false @show octup @show mi + if false && method.isva && tys[end] == Tuple{} + octup = Tuple{tys[2:end-1]...} + ocnargs -= 1 + end + + @show "final", ocva, ocnargs, octup + # jl_new_opaque_closure forcibly executes in the current world... This means that we won't get the right # inner code during compilation without special handling (i.e. call_in_world_total). # Opaque closures also require takign the function argument. We can work around the latter From d302cd9a0be9f7f57aa94c8bbe58ad389c91da4c Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Thu, 12 Dec 2024 23:51:51 -0600 Subject: [PATCH 42/78] minus show --- src/utils.jl | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 7abecd183..3bd6f144a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -102,9 +102,7 @@ end end @inline function lookup_world(@nospecialize(sig::Type), world::UInt, mt::Core.Compiler.InternalMethodTable, min_world::Ref{UInt}, max_world::Ref{UInt}) - @show "pre imt", world, min_world, max_world res = lookup_world(sig, mt.world, nothing, min_world, max_world) - @show "imt", res, world, min_world, max_world return res end @@ -251,10 +249,6 @@ function call_with_reactant_generator( src.ssavaluetypes = length(overdubbed_code) src.ssaflags = [0x00 for _ in 1:length(overdubbed_code)] # XXX we need to copy flags that are set for the original code - @show src - @show src.edges - @show typeof(src) - return src end @@ -381,7 +375,6 @@ function call_with_reactant_generator( # If `method` is a varargs method, we have to restructure the original method call's # trailing arguments into a tuple and assign that tuple to the expected argument slot. if method.isva - @show "post pop", tys trailing_arguments = Expr(:call, Core.GlobalRef(Core, :tuple)) for i in n_method_args:n_actual_args push!( @@ -399,36 +392,24 @@ function call_with_reactant_generator( push!(overdubbed_codelocs, code_info.codelocs[1]) push!(fn_args, Core.SSAValue(length(overdubbed_code))) push!(tys, Tuple{redub_arguments[n_method_args:n_actual_args]...}) - - @show "post redo", tys - @show Tuple{redub_arguments[n_method_args:n_actual_args]...} end - @show n_method_args, n_actual_args rt = Base.Experimental.compute_ir_rettype(ir) # ocva = method.isva ocva = false # method.isva - @show method - @show method.isva, method.sig, mi.specTypes - @show method.nargs ocnargs = method.nargs - 1 # octup = Tuple{mi.specTypes.parameters[2:end]...} # octup = Tuple{method.sig.parameters[2:end]...} octup = Tuple{tys[2:end]...} ocva = false - @show octup - @show mi - if false && method.isva && tys[end] == Tuple{} octup = Tuple{tys[2:end-1]...} ocnargs -= 1 end - @show "final", ocva, ocnargs, octup - # jl_new_opaque_closure forcibly executes in the current world... This means that we won't get the right # inner code during compilation without special handling (i.e. call_in_world_total). # Opaque closures also require takign the function argument. We can work around the latter From 59a648a6f41f9e3c46476b158439b2e2e94b9212 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Thu, 12 Dec 2024 23:53:08 -0600 Subject: [PATCH 43/78] less prints --- src/utils.jl | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 3bd6f144a..b27f8f44f 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -155,8 +155,6 @@ function call_with_reactant_generator( @nospecialize args = redub_arguments - ccall(:jl_, Any, (Any,), string(world)*" args="*string(args)) - stub = Core.GeneratedFunctionStub( identity, Core.svec(:call_with_reactant, REDUB_ARGUMENTS_NAME), Core.svec() ) @@ -179,8 +177,6 @@ function call_with_reactant_generator( max_world = Ref{UInt}(typemax(UInt)) lookup_result = lookup_world(sig, world, Core.Compiler.method_table(interp), min_world, max_world) - - ccall(:jl_, Any, (Any,), string(lookup_result)*" sig="*string(sig)*" mw="*string(min_world)*" "*string(max_world)*" "*string(Base.get_world_counter())) overdubbed_code = Any[] overdubbed_codelocs = Int32[] @@ -275,9 +271,6 @@ function call_with_reactant_generator( method = match.method - ccall(:jl_, Any, (Any,), ("method=")*string(method)) - ccall(:jl_, Any, (Any,), ("va=")*string(method.isva)) - # The original julia code (on 1.11+) has the potential constprop, for now # we assume this outermost function does not constprop, for ease. #if Core.Compiler.result_is_constabi(interp, frame.result) @@ -294,8 +287,6 @@ function call_with_reactant_generator( Core.Compiler.ipo_dataflow_analysis!(interp, ir, caller) end - ccall(:jl_, Any, (Any,), ("ir=")*string(ir)) - # Rewrite type unstable calls to recurse into call_with_reactant to ensure # they continue to use our interpreter. Reset the derived return type # to Any if our interpreter would change the return type of any result. @@ -319,8 +310,6 @@ function call_with_reactant_generator( src = Core.Compiler.ir_to_codeinf!(opt) - ccall(:jl_, Any, (Any,), ("src=")*string(src)) - # prepare a new code info code_info = copy(src) static_params = match.sparams @@ -461,7 +450,6 @@ function call_with_reactant_generator( code_info.ssavaluetypes = length(overdubbed_code) code_info.ssaflags = [0x00 for _ in 1:length(overdubbed_code)] # XXX we need to copy flags that are set for the original code - ccall(:jl_, Any, (Any,), "code_info="*string(code_info)) return code_info end From d55e4e89c1ff947580c9a08d9d0eca196b1efd02 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Thu, 12 Dec 2024 23:54:56 -0600 Subject: [PATCH 44/78] even fewer --- src/utils.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index b27f8f44f..a9f4a7ddf 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -93,11 +93,9 @@ end @inline function lookup_world(@nospecialize(sig::Type), world::UInt, mt::Union{Nothing,Core.MethodTable}, min_world::Ref{UInt}, max_world::Ref{UInt}) - ccall(:jl_, Any, (Any,), "pre mt "*string(world)*" mnw="*string(min_world)*" mxw"*string(max_world)) res = ccall(:jl_gf_invoke_lookup_worlds, Any, (Any, Any, Csize_t, Ref{Csize_t}, Ref{Csize_t}), sig, mt, world, min_world, max_world) - ccall(:jl_, Any, (Any,), "post mt "*string(world)*" mnw="*string(min_world)* " mxw"*string(max_world)) return res end From db50a371db66c091f0350bea9a69a4ca6ab4ca7f Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Fri, 13 Dec 2024 00:59:28 -0600 Subject: [PATCH 45/78] confusion --- src/TracedRArray.jl | 1 + src/TracedRNumber.jl | 19 +++++++++++-------- src/utils.jl | 17 +++++++++++++++++ 3 files changed, 29 insertions(+), 8 deletions(-) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 0e9bf6f77..6fd0ad316 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -807,6 +807,7 @@ end for (minT, maxT) in Iterators.product((Number, TracedRNumber), (Number, TracedRNumber)) @eval function Base.clamp!(x::TracedRArray{T}, min::$(minT), max::$(maxT)) where {T} + @show x, min, max y = clamp.(x, min, max) x.mlir_data = y.mlir_data return x diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index ebe733ce6..8479c9436 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -290,14 +290,17 @@ Base.log1p(x::TracedRNumber{T}) where {T} = log(x + one(T)) for (minT, maxT) in Iterators.product((Number, TracedRNumber), (Number, TracedRNumber)) @eval function Base.clamp(x::TracedRNumber{T}, min::$(minT), max::$(maxT)) where {T} - min = promote_to(TracedRNumber{T}, min) - max = promote_to(TracedRNumber{T}, max) - return TracedRNumber{T}( - (), - MLIR.IR.result( - MLIR.Dialects.stablehlo.clamp(min.mlir_data, x.mlir_data, max.mlir_data), 1 - ), - ) + ccall(:jl_, Any, (Any,), min) + return x + #min = promote_to(TracedRNumber{T}, min) + #max = promote_to(TracedRNumber{T}, max) + #@show min, max + #return TracedRNumber{T}( + # (), + # MLIR.IR.result( + # MLIR.Dialects.stablehlo.clamp(min.mlir_data::MLIR.IR.Value, x.mlir_data::MLIR.IR.Value, max.mlir_data::MLIR.IR.Value), 1 + # ), + #) end end diff --git a/src/utils.jl b/src/utils.jl index a9f4a7ddf..1660ee58b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -93,9 +93,11 @@ end @inline function lookup_world(@nospecialize(sig::Type), world::UInt, mt::Union{Nothing,Core.MethodTable}, min_world::Ref{UInt}, max_world::Ref{UInt}) + ccall(:jl_, Any, (Any,), "pre mt "*string(world)*" mnw="*string(min_world)*" mxw"*string(max_world)) res = ccall(:jl_gf_invoke_lookup_worlds, Any, (Any, Any, Csize_t, Ref{Csize_t}, Ref{Csize_t}), sig, mt, world, min_world, max_world) + ccall(:jl_, Any, (Any,), "post mt "*string(world)*" mnw="*string(min_world)* " mxw"*string(max_world)) return res end @@ -153,6 +155,8 @@ function call_with_reactant_generator( @nospecialize args = redub_arguments + ccall(:jl_, Any, (Any,), string(world)*" args="*string(args)) + stub = Core.GeneratedFunctionStub( identity, Core.svec(:call_with_reactant, REDUB_ARGUMENTS_NAME), Core.svec() ) @@ -175,6 +179,8 @@ function call_with_reactant_generator( max_world = Ref{UInt}(typemax(UInt)) lookup_result = lookup_world(sig, world, Core.Compiler.method_table(interp), min_world, max_world) + + ccall(:jl_, Any, (Any,), string(lookup_result)*" sig="*string(sig)*" mw="*string(min_world)*" "*string(max_world)*" "*string(Base.get_world_counter())) overdubbed_code = Any[] overdubbed_codelocs = Int32[] @@ -269,6 +275,9 @@ function call_with_reactant_generator( method = match.method + ccall(:jl_, Any, (Any,), ("method=")*string(method)) + ccall(:jl_, Any, (Any,), ("va=")*string(method.isva)) + # The original julia code (on 1.11+) has the potential constprop, for now # we assume this outermost function does not constprop, for ease. #if Core.Compiler.result_is_constabi(interp, frame.result) @@ -285,6 +294,8 @@ function call_with_reactant_generator( Core.Compiler.ipo_dataflow_analysis!(interp, ir, caller) end + ccall(:jl_, Any, (Any,), ("ir=")*string(ir)) + # Rewrite type unstable calls to recurse into call_with_reactant to ensure # they continue to use our interpreter. Reset the derived return type # to Any if our interpreter would change the return type of any result. @@ -308,6 +319,8 @@ function call_with_reactant_generator( src = Core.Compiler.ir_to_codeinf!(opt) + ccall(:jl_, Any, (Any,), ("src=")*string(src)) + # prepare a new code info code_info = copy(src) static_params = match.sparams @@ -420,6 +433,8 @@ function call_with_reactant_generator( Core.SSAValue(length(overdubbed_code)) end + @show Base.issingletontype(args[1]), oc + push!( overdubbed_code, Expr( @@ -448,6 +463,7 @@ function call_with_reactant_generator( code_info.ssavaluetypes = length(overdubbed_code) code_info.ssaflags = [0x00 for _ in 1:length(overdubbed_code)] # XXX we need to copy flags that are set for the original code + ccall(:jl_, Any, (Any,), "code_info="*string(code_info)) return code_info end @@ -548,6 +564,7 @@ function make_mlir_fn( end # TODO fix it for kwargs + @show f, traced_args call_with_reactant(f, traced_args...) end From 284094a7cc14739637fde69f11c63a13cd569731 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Fri, 13 Dec 2024 01:21:13 -0600 Subject: [PATCH 46/78] tmp --- src/utils.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/utils.jl b/src/utils.jl index 1660ee58b..10c6de5ca 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -410,6 +410,10 @@ function call_with_reactant_generator( ocnargs -= 1 end + @show args[1] + @show method.isva + @show octup, rt, ocnargs, ocva + # jl_new_opaque_closure forcibly executes in the current world... This means that we won't get the right # inner code during compilation without special handling (i.e. call_in_world_total). # Opaque closures also require takign the function argument. We can work around the latter From 82584f8058aed523117d6f33f9c7789885bc703c Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Fri, 13 Dec 2024 01:24:00 -0600 Subject: [PATCH 47/78] force clean --- src/TracedRArray.jl | 1 - src/TracedRNumber.jl | 19 ++++++++----------- src/utils.jl | 29 ++++++----------------------- 3 files changed, 14 insertions(+), 35 deletions(-) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 6fd0ad316..0e9bf6f77 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -807,7 +807,6 @@ end for (minT, maxT) in Iterators.product((Number, TracedRNumber), (Number, TracedRNumber)) @eval function Base.clamp!(x::TracedRArray{T}, min::$(minT), max::$(maxT)) where {T} - @show x, min, max y = clamp.(x, min, max) x.mlir_data = y.mlir_data return x diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index 8479c9436..012d7961b 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -290,17 +290,14 @@ Base.log1p(x::TracedRNumber{T}) where {T} = log(x + one(T)) for (minT, maxT) in Iterators.product((Number, TracedRNumber), (Number, TracedRNumber)) @eval function Base.clamp(x::TracedRNumber{T}, min::$(minT), max::$(maxT)) where {T} - ccall(:jl_, Any, (Any,), min) - return x - #min = promote_to(TracedRNumber{T}, min) - #max = promote_to(TracedRNumber{T}, max) - #@show min, max - #return TracedRNumber{T}( - # (), - # MLIR.IR.result( - # MLIR.Dialects.stablehlo.clamp(min.mlir_data::MLIR.IR.Value, x.mlir_data::MLIR.IR.Value, max.mlir_data::MLIR.IR.Value), 1 - # ), - #) + min = promote_to(TracedRNumber{T}, min) + max = promote_to(TracedRNumber{T}, max) + return TracedRNumber{T}( + (), + MLIR.IR.result( + MLIR.Dialects.stablehlo.clamp(min.mlir_data::MLIR.IR.Value, x.mlir_data::MLIR.IR.Value, max.mlir_data::MLIR.IR.Value), 1 + ), + ) end end diff --git a/src/utils.jl b/src/utils.jl index 10c6de5ca..e21e1a00b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -155,8 +155,6 @@ function call_with_reactant_generator( @nospecialize args = redub_arguments - ccall(:jl_, Any, (Any,), string(world)*" args="*string(args)) - stub = Core.GeneratedFunctionStub( identity, Core.svec(:call_with_reactant, REDUB_ARGUMENTS_NAME), Core.svec() ) @@ -275,9 +273,6 @@ function call_with_reactant_generator( method = match.method - ccall(:jl_, Any, (Any,), ("method=")*string(method)) - ccall(:jl_, Any, (Any,), ("va=")*string(method.isva)) - # The original julia code (on 1.11+) has the potential constprop, for now # we assume this outermost function does not constprop, for ease. #if Core.Compiler.result_is_constabi(interp, frame.result) @@ -294,8 +289,6 @@ function call_with_reactant_generator( Core.Compiler.ipo_dataflow_analysis!(interp, ir, caller) end - ccall(:jl_, Any, (Any,), ("ir=")*string(ir)) - # Rewrite type unstable calls to recurse into call_with_reactant to ensure # they continue to use our interpreter. Reset the derived return type # to Any if our interpreter would change the return type of any result. @@ -319,8 +312,6 @@ function call_with_reactant_generator( src = Core.Compiler.ir_to_codeinf!(opt) - ccall(:jl_, Any, (Any,), ("src=")*string(src)) - # prepare a new code info code_info = copy(src) static_params = match.sparams @@ -405,14 +396,6 @@ function call_with_reactant_generator( # octup = Tuple{method.sig.parameters[2:end]...} octup = Tuple{tys[2:end]...} ocva = false - if false && method.isva && tys[end] == Tuple{} - octup = Tuple{tys[2:end-1]...} - ocnargs -= 1 - end - - @show args[1] - @show method.isva - @show octup, rt, ocnargs, ocva # jl_new_opaque_closure forcibly executes in the current world... This means that we won't get the right # inner code during compilation without special handling (i.e. call_in_world_total). @@ -437,8 +420,6 @@ function call_with_reactant_generator( Core.SSAValue(length(overdubbed_code)) end - @show Base.issingletontype(args[1]), oc - push!( overdubbed_code, Expr( @@ -467,7 +448,6 @@ function call_with_reactant_generator( code_info.ssavaluetypes = length(overdubbed_code) code_info.ssaflags = [0x00 for _ in 1:length(overdubbed_code)] # XXX we need to copy flags that are set for the original code - ccall(:jl_, Any, (Any,), "code_info="*string(code_info)) return code_info end @@ -567,9 +547,12 @@ function make_mlir_fn( end end - # TODO fix it for kwargs - @show f, traced_args - call_with_reactant(f, traced_args...) + # TODO fix it for kwargs + if concretein + call_with_reactant(f, traced_args...) + else + f(traced_args...) + end end seen_results = OrderedIdDict() From 00776dab0818eadca0df6b3277ce63044d5b59fd Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Fri, 13 Dec 2024 11:08:22 -0600 Subject: [PATCH 48/78] force oc --- src/utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index e21e1a00b..16ffa2a34 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -401,7 +401,7 @@ function call_with_reactant_generator( # inner code during compilation without special handling (i.e. call_in_world_total). # Opaque closures also require takign the function argument. We can work around the latter # if the function is stateless. But regardless, to work around this we sadly create/compile the opaque closure - oc = if Base.issingletontype(args[1]) + oc = if false && Base.issingletontype(args[1]) Core._call_in_world_total(world, make_oc, octup, rt, src, ocnargs, ocva, args[1].instance)::Core.OpaqueClosure else farg = fn_args[1] From ad784b30eeb1ca8fa1c2855e9f08fa849b0be677 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Fri, 13 Dec 2024 11:10:52 -0600 Subject: [PATCH 49/78] clean --- src/utils.jl | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 16ffa2a34..f26fe13cf 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -93,11 +93,9 @@ end @inline function lookup_world(@nospecialize(sig::Type), world::UInt, mt::Union{Nothing,Core.MethodTable}, min_world::Ref{UInt}, max_world::Ref{UInt}) - ccall(:jl_, Any, (Any,), "pre mt "*string(world)*" mnw="*string(min_world)*" mxw"*string(max_world)) res = ccall(:jl_gf_invoke_lookup_worlds, Any, (Any, Any, Csize_t, Ref{Csize_t}, Ref{Csize_t}), sig, mt, world, min_world, max_world) - ccall(:jl_, Any, (Any,), "post mt "*string(world)*" mnw="*string(min_world)* " mxw"*string(max_world)) return res end @@ -178,8 +176,6 @@ function call_with_reactant_generator( lookup_result = lookup_world(sig, world, Core.Compiler.method_table(interp), min_world, max_world) - ccall(:jl_, Any, (Any,), string(lookup_result)*" sig="*string(sig)*" mw="*string(min_world)*" "*string(max_world)*" "*string(Base.get_world_counter())) - overdubbed_code = Any[] overdubbed_codelocs = Int32[] From e90096bb1d1d2b437a94cb1fb86997bb484ffbe7 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Fri, 13 Dec 2024 22:24:39 -0500 Subject: [PATCH 50/78] Rewrite --- src/Compiler.jl | 2 +- src/ConcreteRArray.jl | 8 + src/ControlFlow.jl | 6 +- src/Interpreter.jl | 145 ++++-------- src/Ops.jl | 72 ++++-- src/Reactant.jl | 52 ++++- src/TracedRArray.jl | 345 ++++------------------------ src/TracedRNumber.jl | 74 +++--- src/TracedUtils.jl | 511 ++++++++++++++++++++++++++++++++++++++++++ src/utils.jl | 374 +++++++------------------------ 10 files changed, 830 insertions(+), 759 deletions(-) create mode 100644 src/TracedUtils.jl diff --git a/src/Compiler.jl b/src/Compiler.jl index 4e00b9b74..44e4f59d9 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -292,7 +292,7 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true) func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = MLIR.IR.mmodule!(mod) do MLIR.IR.block!(MLIR.IR.body(mod)) do - return Reactant.make_mlir_fn(f, args, (), "main", true) + return Reactant.TracedUtils.make_mlir_fn(f, args, (), "main", true) end end diff --git a/src/ConcreteRArray.jl b/src/ConcreteRArray.jl index dac67bf69..ba495aef3 100644 --- a/src/ConcreteRArray.jl +++ b/src/ConcreteRArray.jl @@ -345,3 +345,11 @@ end buffer_on_cpu(::Any) = true buffer_on_cpu(x::ConcreteRArray) = XLA.BufferOnCPU(x.data.buffer) + +function Ops.constant(x::ConcreteRArray; kwargs...) + return Ops.constant(Base.convert(Array, x); kwargs...) +end + +function Ops.constant(x::ConcreteRNumber{T}; kwargs...) where {T} + return Ops.constant(Base.convert(T, x); kwargs...) +end diff --git a/src/ControlFlow.jl b/src/ControlFlow.jl index 3b30c4cb6..3035e90c3 100644 --- a/src/ControlFlow.jl +++ b/src/ControlFlow.jl @@ -36,16 +36,16 @@ function ReactantCore.traced_if( returned `$(typeof(tr))`, false branch returned `$(typeof(fr))`.") elseif tr isa MissingTracedValue push!(result_types, MLIR.IR.type(fr.mlir_data)) - push!(linear_results, new_traced_value(false_linear_results[i])) + push!(linear_results, TracedUtils.new_traced_value(false_linear_results[i])) push!(true_block_insertions, (i => linear_results[end])) else push!(result_types, MLIR.IR.type(tr.mlir_data)) - push!(linear_results, new_traced_value(true_linear_results[i])) + push!(linear_results, TracedUtils.new_traced_value(true_linear_results[i])) push!(false_block_insertions, (i => linear_results[end])) end else push!(result_types, MLIR.IR.type(tr.mlir_data)) - push!(linear_results, new_traced_value(tr)) + push!(linear_results, TracedUtils.new_traced_value(tr)) end end diff --git a/src/Interpreter.jl b/src/Interpreter.jl index 724e4c78a..50f4f52d1 100644 --- a/src/Interpreter.jl +++ b/src/Interpreter.jl @@ -152,38 +152,30 @@ const enzyme_constnoneed = 5 enzyme_outnoneed end -function push_val!(ad_inputs, x, path) - for p in path - x = traced_getfield(x, p) - end - x = x.mlir_data - return push!(ad_inputs, x) -end - function push_acts!(ad_inputs, x::Const, path, reverse) - return push_val!(ad_inputs, x.val, path) + return TracedUtils.push_val!(ad_inputs, x.val, path) end function push_acts!(ad_inputs, x::Active, path, reverse) - return push_val!(ad_inputs, x.val, path) + return TracedUtils.push_val!(ad_inputs, x.val, path) end function push_acts!(ad_inputs, x::Duplicated, path, reverse) - push_val!(ad_inputs, x.val, path) + TracedUtils.push_val!(ad_inputs, x.val, path) if !reverse - push_val!(ad_inputs, x.dval, path) + TracedUtils.push_val!(ad_inputs, x.dval, path) end end function push_acts!(ad_inputs, x::DuplicatedNoNeed, path, reverse) - push_val!(ad_inputs, x.val, path) + TracedUtils.push_val!(ad_inputs, x.val, path) if !reverse - push_val!(ad_inputs, x.dval, path) + TracedUtils.push_val!(ad_inputs, x.dval, path) end end function push_acts!(ad_inputs, x::BatchDuplicated, path, reverse) - push_val!(ad_inputs, x.val, path) + TracedUtils.push_val!(ad_inputs, x.val, path) if !reverse ET = eltype(x.val) predims = size(x.val) @@ -193,12 +185,12 @@ function push_acts!(ad_inputs, x::BatchDuplicated, path, reverse) ), ) tval = TracedRArray{ET,length(predims) + 1}((), cval, (length(x.dval), predims...)) - push_val!(ad_inputs, tval, path) + TracedUtils.push_val!(ad_inputs, tval, path) end end function push_acts!(ad_inputs, x::BatchDuplicatedNoNeed, path, reverse) - push_val!(ad_inputs, x.val, path) + TracedUtils.push_val!(ad_inputs, x.val, path) if !reverse ET = eltype(x.val) predims = size(x.val) @@ -208,7 +200,7 @@ function push_acts!(ad_inputs, x::BatchDuplicatedNoNeed, path, reverse) ), ) tval = TracedRArray{ET,length(predims) + 1}((), cval, (length(x.dval), predims...)) - push_val!(ad_inputs, tval, path) + TracedUtils.push_val!(ad_inputs, tval, path) end end @@ -234,57 +226,6 @@ function set_act!(inp, path, reverse, tostore; emptypath=false) end end -function set!(x, path, tostore; emptypath=false) - for p in path - x = traced_getfield(x, p) - end - - x.mlir_data = tostore - - if emptypath - x.paths = () - end -end - -function get_argidx(x) - for path in x.paths - if length(path) == 0 - continue - end - if path[1] == :args - return path[2]::Int, path - end - end - throw(AssertionError("No path found for $x")) -end -function get_residx(x) - for path in x.paths - if length(path) == 0 - continue - end - if path[1] == :result - return path - end - end - throw(AssertionError("No path found $x")) -end - -function has_residx(x) - for path in x.paths - if length(path) == 0 - continue - end - if path[1] == :result - return true - end - end - return false -end - -function get_attribute_by_name(operation, name) - return MLIR.IR.Attribute(MLIR.API.mlirOperationGetAttributeByName(operation, name)) -end - function overload_autodiff( ::CMode, f::FA, ::Type{A}, args::Vararg{Enzyme.Annotation,Nargs} ) where {CMode<:Enzyme.Mode,FA<:Enzyme.Annotation,A<:Enzyme.Annotation,Nargs} @@ -304,7 +245,7 @@ function overload_autodiff( primf = f.val primargs = ((v.val for v in args)...,) - fnwrap, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = make_mlir_fn( + fnwrap, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = Reactant.TracedUtils.make_mlir_fn( primf, primargs, (), string(f) * "_autodiff", false ) @@ -312,7 +253,7 @@ function overload_autodiff( ad_inputs = MLIR.IR.Value[] for a in linear_args - idx, path = get_argidx(a) + idx, path = TracedUtils.get_argidx(a) if idx == 1 && fnwrap push!(activity, act_from_type(f, reverse)) push_acts!(ad_inputs, f, path[3:end], reverse) @@ -331,19 +272,19 @@ function overload_autodiff( @inline needs_primal(::Type{<:Enzyme.ForwardMode{ReturnPrimal}}) where {ReturnPrimal} = ReturnPrimal for a in linear_results - if has_residx(a) + if TracedUtils.has_residx(a) if needs_primal(CMode) - push!(outtys, transpose_ty(MLIR.IR.type(a.mlir_data))) + push!(outtys, TracedUtils.transpose_ty(MLIR.IR.type(a.mlir_data))) end if CMode <: Enzyme.ForwardMode && !(A <: Enzyme.Const) if width == 1 - push!(outtys, transpose_ty(MLIR.IR.type(a.mlir_data))) + push!(outtys, TracedUtils.transpose_ty(MLIR.IR.type(a.mlir_data))) else - push!(outtys, batch_ty(width, transpose_ty(MLIR.IR.type(a.mlir_data)))) + push!(outtys, TracedUtils.batch_ty(width, TracedUtils.transpose_ty(MLIR.IR.type(a.mlir_data)))) end end else - push!(outtys, transpose_ty(MLIR.IR.type(a.mlir_data))) + push!(outtys, TracedUtils.transpose_ty(MLIR.IR.type(a.mlir_data))) end end for (i, act) in enumerate(activity) @@ -351,14 +292,14 @@ function overload_autodiff( if width == 1 push!(outtys, in_tys[i]) else - push!(outtys, batch_ty(width, in_tys[i])) + push!(outtys, TracedUtils.batch_ty(width, in_tys[i])) end end end ret_activity = Int32[] for a in linear_results - if has_residx(a) + if TracedUtils.has_residx(a) act = act_from_type(A, reverse, needs_primal(CMode)) push!(ret_activity, act) if act == enzyme_out || act == enzyme_outnoneed @@ -367,14 +308,14 @@ function overload_autodiff( push!(ad_inputs, cst) end else - idx, path = get_argidx(a) + idx, path = TracedUtils.get_argidx(a) if idx == 1 && fnwrap act = act_from_type(f, reverse, true) push!(ret_activity, act) if act != enzyme_out && act != enzyme_outnoneed continue end - push_val!(ad_inputs, f.dval, path[3:end]) + TracedUtils.push_val!(ad_inputs, f.dval, path[3:end]) else if fnwrap idx -= 1 @@ -384,7 +325,7 @@ function overload_autodiff( if act != enzyme_out && act != enzyme_outnoneed continue end - push_val!(ad_inputs, args[idx].dval, path[3:end]) + TraceUtils.push_val!(ad_inputs, args[idx].dval, path[3:end]) end end end @@ -395,10 +336,10 @@ function overload_autodiff( )::MLIR.API.MlirAttribute return MLIR.IR.Attribute(val) end - fname = get_attribute_by_name(func2, "sym_name") + fname = TracedUtils.get_attribute_by_name(func2, "sym_name") fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) res = (reverse ? MLIR.Dialects.enzyme.autodiff : MLIR.Dialects.enzyme.fwddiff)( - [transpose_val(v) for v in ad_inputs]; + [TracedUtils.transpose_val(v) for v in ad_inputs]; outputs=outtys, fn=fname, activity=MLIR.IR.Attribute([act_attr(a) for a in activity]), @@ -421,20 +362,20 @@ function overload_autodiff( end for a in linear_results - if has_residx(a) + if TracedUtils.has_residx(a) if needs_primal(CMode) - path = get_residx(a) - tval = transpose_val(MLIR.IR.result(res, residx)) - set!(result, path[2:end], tval) + path = TracedUtils.get_residx(a) + tval = TracedUtils.transpose_val(MLIR.IR.result(res, residx)) + TracedUtils.set!(result, path[2:end], tval) residx += 1 end if CMode <: Enzyme.ForwardMode && !(A <: Enzyme.Const) - path = get_residx(a) + path = TracedUtils.get_residx(a) if width == 1 - tval = transpose_val(MLIR.IR.result(res, residx)) - set!(dresult, path[2:end], tval) + tval = TracedUtils.transpose_val(MLIR.IR.result(res, residx)) + TracedUtils.set!(dresult, path[2:end], tval) else - tval = transpose_val(MLIR.IR.result(res, residx)) + tval = TracedUtils.transpose_val(MLIR.IR.result(res, residx)) for i in 1:width sz = size(a) starts = Int64[i] @@ -444,21 +385,21 @@ function overload_autodiff( push!(limits, v) end sval = Ops.slice(sval, starts, limits) - set!(dresult[i], path[2:end], sval) + TracedUtils.set!(dresult[i], path[2:end], sval) end end residx += 1 end else - idx, path = get_argidx(a) + idx, path = TracedUtils.get_argidx(a) if idx == 1 && fnwrap - set!(f.val, path[3:end], transpose_val(MLIR.IR.result(res, residx))) + TracedUtils.set!(f.val, path[3:end], TracedUtils.transpose_val(MLIR.IR.result(res, residx))) residx += 1 else if fnwrap idx -= 1 end - set!(args[idx].val, path[3:end], transpose_val(MLIR.IR.result(res, residx))) + TracedUtils.set!(args[idx].val, path[3:end], TracedUtils.transpose_val(MLIR.IR.result(res, residx))) residx += 1 end end @@ -466,7 +407,7 @@ function overload_autodiff( restup = Any[(a isa Active) ? copy(a) : nothing for a in args] for a in linear_args - idx, path = get_argidx(a) + idx, path = TracedUtils.get_argidx(a) if idx == 1 && fnwrap if act_from_type(f, reverse) != enzyme_out continue @@ -476,7 +417,7 @@ function overload_autodiff( residx += 1 continue end - set_act!(f, path[3:end], reverse, transpose_val(MLIR.IR.result(res, residx))) + set_act!(f, path[3:end], reverse, TracedUtils.transpose_val(MLIR.IR.result(res, residx))) else if fnwrap idx -= 1 @@ -489,14 +430,14 @@ function overload_autodiff( args[idx], path[3:end], false, - transpose_val(MLIR.IR.result(res, residx)); + TracedUtils.transpose_val(MLIR.IR.result(res, residx)); emptypaths=true, ) #=reverse=# residx += 1 continue end set_act!( - args[idx], path[3:end], reverse, transpose_val(MLIR.IR.result(res, residx)) + args[idx], path[3:end], reverse, TracedUtils.transpose_val(MLIR.IR.result(res, residx)) ) end residx += 1 @@ -528,7 +469,7 @@ function overload_autodiff( end end -@reactant_override @inline function Enzyme.autodiff_deferred( +@reactant_override @noinline function Enzyme.autodiff_deferred( rmode::Enzyme.Mode, f::FA, rt::Type{A}, @@ -541,7 +482,7 @@ end return overload_autodiff(rmode, f, rt, args...) end -@reactant_override @inline function Enzyme.autodiff( +@reactant_override @noinline function Enzyme.autodiff( rmode::Enzyme.Mode, f::FA, rt::Type{A}, @@ -552,4 +493,4 @@ end Nargs, } return overload_autodiff(rmode, f, rt, args...) -end \ No newline at end of file +end diff --git a/src/Ops.jl b/src/Ops.jl index 013e0dbc8..4b335e182 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -6,12 +6,61 @@ using ..MLIR: MLIR using ..MLIR.Dialects: stablehlo, chlo, enzyme using ..Reactant: Reactant, - ConcreteRArray, - ConcreteRNumber, TracedRArray, TracedRNumber, - mlir_type, - mlir_stacktrace + RArray, + RNumber, + MissingTracedValue + +function mlir_type(x::RArray{T,N}) where {T,N} + return MLIR.IR.TensorType(size(x), MLIR.IR.Type(T)) +end + +mlir_type(::RNumber{T}) where {T} = MLIR.IR.TensorType((), MLIR.IR.Type(T)) + +mlir_type(::MissingTracedValue) = MLIR.IR.TensorType((), MLIR.IR.Type(Bool)) + +function mlir_type(::Type{<:RArray{T,N}}, shape) where {T,N} + @assert length(shape) == N + return MLIR.IR.TensorType(shape, MLIR.IR.Type(T)) +end + +function mlir_type(::Type{<:RNumber{T}}) where {T} + return MLIR.IR.TensorType((), MLIR.IR.Type(T)) +end + +function mlir_type(::Type{<:MissingTracedValue}) + return MLIR.IR.TensorType((), MLIR.IR.Type(Bool)) +end + +const DEBUG_MODE::Ref{Bool} = Ref(false) + +function with_debug(f) + old = DEBUG_MODE[] + DEBUG_MODE[] = true + try + return f() + finally + DEBUG_MODE[] = old + end +end + +function mlir_stacktrace(name, file, line)::MLIR.IR.Location + # calling `stacktrace` can add a lot of time overhead, so let's avoid adding debug info if not used + if DEBUG_MODE[] + return MLIR.IR.Location(name, MLIR.IR.Location(file, line, 0)) + end + + # retrieve current stacktrace, remove this function's frame and translate to MLIR Location + st = stacktrace() + deleteat!(st, 1) + return mapfoldl(MLIR.IR.Location, st) do stackframe + name = string(stackframe.func) + file = stackframe.file + line = stackframe.line + return MLIR.IR.Location(name, MLIR.IR.Location(file, line, 0)) + end +end struct Token mlir_data::MLIR.IR.Value @@ -27,10 +76,6 @@ function constant( return TracedRArray{T,N}((), res, size(x)) end -function constant(x::ConcreteRArray; kwargs...) - return stablehlo.constant(Base.convert(Array, x); kwargs...) -end - function constant( x::T; location=mlir_stacktrace("constant", @__FILE__, @__LINE__) ) where {T<:Number} @@ -38,17 +83,6 @@ function constant( return TracedRNumber{T}((), res.mlir_data) end -function constant( - x::ConcreteRNumber{T}; location=mlir_stacktrace("constant", @__FILE__, @__LINE__) -) where {T} - output = mlir_type(TracedRArray{T,0}, ()) - value = MLIR.IR.DenseElementsAttribute( - fill(MLIR.IR.Attribute(Base.convert(T, x)), output) - ) - res = MLIR.IR.result(stablehlo.constant(; output, value, location)) - return TracedRNumber{T,N}((), res) -end - # unary elementwise ops for (dialect, op) in [ (:stablehlo, :abs), diff --git a/src/Reactant.jl b/src/Reactant.jl index 0fc900b24..5d670558e 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -97,11 +97,59 @@ include("Interpreter.jl") include("utils.jl") -include("ConcreteRArray.jl") +mutable struct TracedRArray{T,N} <: RArray{T,N} + paths::Tuple + mlir_data::Union{Nothing,MLIR.IR.Value} + shape::NTuple{N,Int} + + function TracedRArray{T,N}( + paths::Tuple, mlir_data::Union{Nothing,MLIR.IR.Value}, shape + ) where {T,N} + shape = Tuple(shape) + if !isnothing(mlir_data) + @assert size(MLIR.IR.type(mlir_data)) == shape + end + return new{T,N}(paths, mlir_data, shape) + end +end + +const WrappedTracedRArray{T,N} = WrappedArray{T,N,TracedRArray,TracedRArray{T,N}} +const AnyTracedRArray{T,N} = Union{TracedRArray{T,N},WrappedTracedRArray{T,N}} +const AnyTracedRVector{T} = AnyTracedRArray{T,1} +const AnyTracedRMatrix{T} = Union{ + AnyTracedRArray{T,2},LinearAlgebra.Diagonal{T,TracedRArray{T,1}} +} +const AnyTracedRVecOrMat{T} = Union{AnyTracedRVector{T},AnyTracedRMatrix{T}} + +function TracedRArray(data::MLIR.IR.Value) + data_type = MLIR.IR.type(data) + return TracedRArray{eltype(MLIR.IR.julia_type(data_type)),ndims(data_type)}( + (), data, size(data_type) + ) +end + +mutable struct TracedRNumber{T} <: RNumber{T} + paths::Tuple + mlir_data::Union{Nothing,MLIR.IR.Value} + + function TracedRNumber{T}( + paths::Tuple, mlir_data::Union{Nothing,MLIR.IR.Value} + ) where {T} + if !isnothing(mlir_data) + @assert size(MLIR.IR.type(mlir_data)) == () + end + return new{T}(paths, mlir_data) + end +end + +include("Ops.jl") +include("TracedUtils.jl") + include("TracedRNumber.jl") include("TracedRArray.jl") -include("Ops.jl") +include("ConcreteRArray.jl") + include("linear_algebra.jl") diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 6bdbadc8f..988d895da 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -1,41 +1,21 @@ +module TracedRArrayOverrides + using Base.Broadcast using Base.Broadcast: BroadcastStyle, Broadcasted, AbstractArrayStyle, instantiate -mutable struct TracedRArray{T,N} <: RArray{T,N} - paths::Tuple - mlir_data::Union{Nothing,MLIR.IR.Value} - shape::NTuple{N,Int} - - function TracedRArray{T,N}( - paths::Tuple, mlir_data::Union{Nothing,MLIR.IR.Value}, shape - ) where {T,N} - shape = Tuple(shape) - if !isnothing(mlir_data) - @assert size(MLIR.IR.type(mlir_data)) == shape - end - return new{T,N}(paths, mlir_data, shape) - end -end - -const WrappedTracedRArray{T,N} = WrappedArray{T,N,TracedRArray,TracedRArray{T,N}} -const AnyTracedRArray{T,N} = Union{TracedRArray{T,N},WrappedTracedRArray{T,N}} -const AnyTracedRVector{T} = AnyTracedRArray{T,1} -const AnyTracedRMatrix{T} = Union{ - AnyTracedRArray{T,2},LinearAlgebra.Diagonal{T,TracedRArray{T,1}} -} -const AnyTracedRVecOrMat{T} = Union{AnyTracedRVector{T},AnyTracedRMatrix{T}} - -function TracedRArray(data::MLIR.IR.Value) - data_type = MLIR.IR.type(data) - return TracedRArray{eltype(MLIR.IR.julia_type(data_type)),ndims(data_type)}( - (), data, size(data_type) - ) -end +import ..TracedRArray +import ..TracedRNumber +import ..ReactantPrimitive +import ..WrappedTracedRArray +import ..AnyTracedRArray +using ..TracedUtils +import ..Ops +import ..MLIR +import ReactantCore +import ..TracedUtils: materialize_traced_array ReactantCore.is_traced(::TracedRArray) = true -new_traced_value(A::TracedRArray{T,N}) where {T,N} = TracedRArray{T,N}((), nothing, size(A)) - function Base.convert(::Type{TracedRArray{T,N}}, x::AbstractArray) where {T,N} @assert ndims(x) == N if x isa TracedRArray @@ -49,90 +29,13 @@ end TracedRArray{T,N}(x::AbstractArray) where {T,N} = convert(TracedRArray{T,N}, x) -materialize_traced_array(x::TracedRArray) = x -materialize_traced_array(x::WrappedTracedRArray) = x[axes(x)...] -function materialize_traced_array( - x::Adapt.WrappedReshapedArray{T,N,<:TracedRArray} -) where {T,N} - return Ops.reshape(materialize_traced_array(parent(x)), size(x)...) -end -function materialize_traced_array( - x::LinearAlgebra.Transpose{T,TracedRArray{T,N}} -) where {T,N} - px = parent(x) - A = ndims(px) == 1 ? reshape(px, :, 1) : px - return permutedims(A, (2, 1)) -end -function materialize_traced_array(x::LinearAlgebra.Adjoint{T,TracedRArray{T,N}}) where {T,N} - return conj(materialize_traced_array(transpose(parent(x)))) -end -function materialize_traced_array( - x::PermutedDimsArray{T,N,perm,iperm,<:TracedRArray{T,N}} -) where {T,N,perm,iperm} - return permutedims(parent(x), perm) -end -function materialize_traced_array(x::LinearAlgebra.Diagonal{T,TracedRArray{T,1}}) where {T} - return LinearAlgebra.diagm(parent(x)) -end - -get_mlir_data(x::TracedRArray) = x.mlir_data -get_mlir_data(x::AnyTracedRArray) = get_mlir_data(materialize_traced_array(x)) - -function set_mlir_data!(x::TracedRArray, data) - x.mlir_data = data - return x -end -function set_mlir_data!(x::Adapt.WrappedReshapedArray{T,N,<:TracedRArray}, data) where {T,N} - res_mlir_data = Ops.reshape(TracedRArray(data), size(parent(x))...).mlir_data - set_mlir_data!(parent(x), res_mlir_data) - return x -end -function set_mlir_data!(x::LinearAlgebra.Transpose{T,TracedRArray{T,N}}, data) where {T,N} - tdata = TracedRArray(data) - px = parent(x) - px.mlir_data = ( - if ndims(px) == 1 - Ops.reshape(tdata, length(tdata)) - else - Ops.transpose(tdata, [2, 1]) - end - ).mlir_data - return x -end -function set_mlir_data!(x::LinearAlgebra.Adjoint{T,TracedRArray{T,N}}, data) where {T,N} - tdata = TracedRArray(data) - px = parent(x) - transposed_data = - ndims(px) == 1 ? Ops.reshape(tdata, length(tdata)) : Ops.transpose(tdata, [2, 1]) - px.mlir_data = (T <: Real ? transposed_data : Ops.conj(transposed_data)).mlir_data - return x -end -function set_mlir_data!( - x::PermutedDimsArray{T,N,perm,iperm,TracedRArray{T,N}}, data -) where {T,N,perm,iperm} - parent(x).mlir_data = permutedims(TracedRArray(data), iperm).mlir_data - return x -end -function set_mlir_data!(x::LinearAlgebra.Diagonal{T,TracedRArray{T,1}}, data) where {T} - parent(x).mlir_data = LinearAlgebra.diag(TracedRArray(data)).mlir_data - return x -end -function set_mlir_data!(x::AnyTracedRArray, data) - setindex!(x, TracedRArray(data), axes(x)...) - return x -end - -get_ancestor_indices(::TracedRArray, indices...) = indices -function get_ancestor_indices(x::WrappedTracedRArray, indices...) - return get_ancestor_indices(parent(x), Base.reindex(parentindices(x), indices)...) -end function Base.getindex( a::TracedRArray{T,N}, index::Vararg{Union{Int,TracedRNumber{Int}},N} ) where {T,N} GPUArraysCore.assertscalar("getindex(::TracedRArray, ::Vararg{Int, N})") - start_indices = [promote_to(TracedRNumber{Int}, i - 1).mlir_data for i in index] + start_indices = [TracedUtils.promote_to(TracedRNumber{Int}, i - 1).mlir_data for i in index] slice_sizes = [Int64(1) for _ in index] res1 = MLIR.IR.result( @@ -169,7 +72,7 @@ function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N} end start_indices = map(indices) do i - return promote_to(TracedRNumber{Int}, first(i) - 1).mlir_data + return TracedUtils.promote_to(TracedRNumber{Int}, first(i) - 1).mlir_data end slice_sizes = [Int64(length(i)) for i in indices] res = MLIR.IR.result( @@ -199,15 +102,15 @@ function Base.setindex!( indices = map(enumerate(indices)) do (idx, i) i isa Int ? (i:i) : (i isa Colon ? (1:size(a, idx)) : i) end - v = broadcast_to_size(v, length.(indices)) - v = promote_to(TracedRArray{T,N}, v) + v = TracedUtils.broadcast_to_size(v, length.(indices)) + v = TracedUtils.promote_to(TracedRArray{T,N}, v) indices = [ - (promote_to(TracedRNumber{Int}, i isa Colon ? 1 : first(i)) - 1).mlir_data for + (TracedUtils.promote_to(TracedRNumber{Int}, i isa Colon ? 1 : first(i)) - 1).mlir_data for i in indices ] res = MLIR.IR.result( MLIR.Dialects.stablehlo.dynamic_update_slice( - a.mlir_data, get_mlir_data(v), indices + a.mlir_data, TracedUtils.get_mlir_data(v), indices ), 1, ) @@ -250,6 +153,7 @@ Base.conj(A::AnyTracedRArray) = A Base.conj(A::AnyTracedRArray{<:Complex}) = Ops.conj(materialize_traced_array(A)) Base.conj!(A::AnyTracedRArray) = A + function Base.conj!(A::AnyTracedRArray{<:Complex}) set_mlir_data!(A, Ops.conj(materialize_traced_array(A)).mlir_data) return A @@ -261,101 +165,8 @@ Base.real(A::AnyTracedRArray{<:Complex}) = Ops.real(materialize_traced_array(A)) Base.imag(A::AnyTracedRArray) = zero(A) Base.imag(A::AnyTracedRArray{<:Complex}) = Ops.imag(materialize_traced_array(A)) -promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N} = TracedRArray{T,N}(rhs) - -promote_to(::TracedRArray{T,N}, rhs) where {T,N} = promote_to(TracedRArray{T,N}, rhs) - -elem_apply(::Type{T}, x::TracedRArray{T}) where {T<:ReactantPrimitive} = x -function elem_apply( - ::Type{T}, x::TracedRArray{T2} -) where {T<:ReactantPrimitive,T2<:ReactantPrimitive} - # Special Path to prevent going down a despecialized path - return elem_apply(TypeCast{T}(), x) -end - -function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs} - if all(iszero ∘ ndims, args) - scalar_args = map(args) do arg - return promote_to(TracedRNumber{eltype(arg)}, arg) - end - return f(scalar_args...) - end - - fnwrap, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = make_mlir_fn( - f, args, (), string(f) * "_broadcast_scalar", false; toscalar=true - ) - - invmap = IdDict() - for (k, v) in seen_args - invmap[v] = k - end - - keys_seen = [k for k in keys(seen_args) if k isa TracedType] - input_shapes = size.(keys_seen) - # by the time we reach here all args must have same size - @assert allequal(input_shapes) "input shapes are $(input_shapes)" - OutShape = isempty(seen_args) ? nothing : first(input_shapes) - @assert !isnothing(OutShape) - - in_tys2 = [mlir_type(invmap[arg]) for arg in linear_args] - - out_tys2 = [ - MLIR.IR.TensorType(OutShape, MLIR.IR.Type(eltype(arg))) for arg in linear_results - ] - - fname = get_attribute_by_name(func2, "sym_name") - fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) - - batch_inputs = MLIR.IR.Value[] - - for a in linear_args - idx, path = get_argidx(a) - if idx == 1 && fnwrap - push_val!(batch_inputs, f, path[3:end]) - else - if fnwrap - idx -= 1 - end - push_val!(batch_inputs, args[idx], path[3:end]) - end - end - - res = MLIR.Dialects.enzyme.batch( - batch_inputs; - outputs=out_tys2, - fn=fname, - batch_shape=MLIR.IR.DenseArrayAttribute([Int64(i) for i in OutShape]), - ) - - residx = 1 - - for a in linear_results - if has_residx(a) - path = get_residx(a) - set!(result, path[2:end], MLIR.IR.result(res, residx)) - residx += 1 - else - idx, path = get_argidx(a) - if idx == 1 && fnwrap - set!(f, path[3:end], MLIR.IR.result(res, residx)) - residx += 1 - else - if fnwrap - idx -= 1 - end - set!(args[idx], path[3:end], MLIR.IR.result(res, residx)) - residx += 1 - end - end - end - - seen_results = OrderedIdDict() - traced2_result = make_tracer(seen_results, result, (), TracedSetPath; tobatch=OutShape) - - func2.operation = MLIR.API.MlirOperation(C_NULL) - - return traced2_result -end +TracedUtils.promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N} = TracedRArray{T,N}(rhs) +TracedUtils.promote_to(::TracedRArray{T,N}, rhs) where {T,N} = TracedUtils.promote_to(TracedRArray{T,N}, rhs) for (jlop, hloop, hlocomp, merge) in ((:(Base.:(==)), :compare, "EQ", :all), (:(Base.:(!=)), :compare, "NE", :any)) @@ -367,21 +178,6 @@ for (jlop, hloop, hlocomp, merge) in end end -function Enzyme.Compiler.active_reg_inner( - ::Type{TracedRArray{T,N}}, - seen::ST, - world::Union{Nothing,UInt}, - ::Val{justActive}=Val(false), - ::Val{UnionSret}=Val(false), -)::Enzyme.Compiler.ActivityState where {ST,T,N,justActive,UnionSret} - if Enzyme.Compiler.active_reg_inner(T, seen, world, Val(justActive), Val(UnionSret)) == - Enzyme.Compiler.AnyState - return Enzyme.Compiler.AnyState - else - return Enzyme.Compiler.DupState - end -end - function Base.mapreduce( @nospecialize(f), @nospecialize(op), @@ -409,7 +205,7 @@ function Base.mapreduce( init = init::T end - init = [broadcast_to_size(init, ()).mlir_data] + init = [TracedUtils.broadcast_to_size(init, ()).mlir_data] inp = [broadcast(f, A).mlir_data] @@ -431,7 +227,7 @@ function Base.mapreduce( ) res = MLIR.IR.block!(fnbody) do - tmp = broadcast_to_size(op(args...), ()).mlir_data + tmp = TracedUtils.broadcast_to_size(op(args...), ()).mlir_data MLIR.Dialects.stablehlo.return_(MLIR.IR.Value[tmp]) return tmp end @@ -471,19 +267,19 @@ function Base.mapreducedim!( @nospecialize(R::TracedRArray), A::Base.AbstractArrayOrBroadcasted, ) - tmp = broadcast_to_size(Base.mapreduce(f, op, A; dims=1), (1, size(R)[2:end]...)) + tmp = TracedUtils.broadcast_to_size(Base.mapreduce(f, op, A; dims=1), (1, size(R)[2:end]...)) R.mlir_data = broadcast(op, R, tmp).mlir_data return R end function Base.fill!(A::TracedRArray{T,N}, x) where {T,N} - bcast = broadcast_to_size(T(x), size(A)) + bcast = TracedUtils.broadcast_to_size(T(x), size(A)) A.mlir_data = bcast.mlir_data return A end function Base.fill!(A::TracedRArray{T,N}, x::TracedRNumber{T2}) where {T,N,T2} - bcast = broadcast_to_size(promote_to(TracedRNumber{T}, x), size(A)) + bcast = TracedUtils.broadcast_to_size(TracedUtils.promote_to(TracedRNumber{T}, x), size(A)) A.mlir_data = bcast.mlir_data return A end @@ -517,7 +313,7 @@ function Broadcast.copy(bc::Broadcasted{<:AbstractReactantArrayStyle{0}}) return dest[CartesianIndex()] # 0D broadcast needs to unwrap results end -Base.eltype(::Broadcast.Extruded{T}) where {T} = eltype(T) +# Base.eltype(::Broadcast.Extruded{T}) where {T} = eltype(T) # we need to override the outer copy method to make sure we never fall back to scalar # iteration (see, e.g., CUDA.jl#145) @@ -560,77 +356,16 @@ function Base.copyto!(dest::TracedRArray{T,N}, src::TracedRArray{T,N}) where {T, return dest end -broadcast_to_size(arg::AbstractArray, rsize) = broadcast_to_size(Ops.constant(arg), rsize) - -function broadcast_to_size(arg::Base.RefValue, rsize) - # XXX: don't we want to expand here to rsize? - return arg -end - -broadcast_to_size(arg::Number, rsize) = Ops.constant(Base.fill(arg, Tuple(rsize))) - -function broadcast_to_size(arg::TracedRNumber, rsize) - length(rsize) == 0 && return arg - return broadcast_to_size_internal( - TracedRArray{eltype(arg),0}((), arg.mlir_data, ()), rsize - ) -end - -function broadcast_to_size(arg::AnyTracedRArray{T,0}, rsize) where {T} - arg = materialize_traced_array(arg) - return broadcast_to_size(TracedRNumber{T}((), arg.mlir_data), rsize) -end - -function broadcast_to_size(arg::AnyTracedRArray, rsize) - arg = materialize_traced_array(arg) - size(arg) == Tuple(rsize) && return arg - return broadcast_to_size_internal(arg, rsize) -end - -function broadcast_to_size(arg::Broadcast.Extruded, rsize) - rsize2 = (keep ? rsizev : 1 for (keep, rsizev) in zip(arg.keeps, rsize)) - x = broadcast_to_size(arg.x, rsize2) - size(x) == rsize && return x - return broadcast_to_size_internal(x, rsize) -end - -function broadcast_to_size_internal(x::TracedRArray, rsize) - dims = collect(Int64, 0:(length(size(x)) - 1)) - - if length(size(MLIR.IR.type(x.mlir_data))) != length(dims) - @show x - @show arg - @show rsize - @show rsize2 - @show dims - end - @assert length(size(MLIR.IR.type(x.mlir_data))) == length(dims) - mlirty = MLIR.IR.type(x.mlir_data) - - return TracedRArray{eltype(x),Int(length(rsize))}( - (), - MLIR.IR.result( - MLIR.Dialects.stablehlo.broadcast_in_dim( - x.mlir_data; - result_0=MLIR.IR.TensorType([t for t in rsize], eltype(mlirty)), - broadcast_dimensions=MLIR.IR.DenseArrayAttribute(dims), - ), - 1, - ), - collect(rsize), - ) -end - function _copyto!(dest::AnyTracedRArray, bc::Broadcasted) axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc)) isempty(dest) && return dest bc = Broadcast.preprocess(dest, bc) - args = (broadcast_to_size(Base.materialize(a), size(bc)) for a in bc.args) + args = (TracedUtils.broadcast_to_size(Base.materialize(a), size(bc)) for a in bc.args) - res = elem_apply(bc.f, args...) - set_mlir_data!(dest, res.mlir_data) + res = TracedUtils.elem_apply(bc.f, args...) + TracedUtils.set_mlir_data!(dest, res.mlir_data) return dest end @@ -642,6 +377,7 @@ dispatch_val(::Val{D}) where {D} = D ) where {T} return Base._cat_t(Val(1), T, X...) end + @inline function Base._typed_hcat( ::Type{T}, X::Base.AbstractVecOrTuple{<:TracedRArray} ) where {T} @@ -684,6 +420,12 @@ function Base._typed_hvncat( return only(As) end +function maybe_expand_dims(x::AbstractArray{T,N}, dims) where {T,N} + dims = dispatch_val(dims) + dims ≤ N && return x + return reshape(x, ntuple(i -> i ≤ N ? size(x, i) : 1, dims)) +end + function Base._cat_t(dims, ::Type{T}, X::TracedRArray...) where {T} dims = dispatch_val(dims) @assert dims isa Integer "Support for non-integer dimensions is not implemented yet." @@ -696,7 +438,7 @@ function Base._cat_t(dims, ::Type{T}, X::TracedRArray...) where {T} RT = Base.promote_eltype(T, X...) # convert to the target eltype - X = map(Base.Fix1(promote_to, TracedRArray{RT,length(shape)}), X) + X = map(Base.Fix1(TracedUtils.promote_to, TracedRArray{RT,length(shape)}), X) return TracedRArray{RT,length(shape)}( (), @@ -713,12 +455,6 @@ function Base._cat_t(dims, ::Type{T}, X::TracedRArray...) where {T} ) end -function maybe_expand_dims(x::AbstractArray{T,N}, dims) where {T,N} - dims = dispatch_val(dims) - dims ≤ N && return x - return reshape(x, ntuple(i -> i ≤ N ? size(x, i) : 1, dims)) -end - for (minT, maxT) in Iterators.product((Number, TracedRNumber), (Number, TracedRNumber)) @eval function Base.clamp!(x::AnyTracedRArray, min::$(minT), max::$(maxT)) y = Ops.clamp(min, materialize_traced_array(x), max) @@ -731,6 +467,7 @@ Base.all(f::Function, x::AnyTracedRArray) = mapreduce(f, &, x) Base.any(f::Function, x::AnyTracedRArray) = mapreduce(f, |, x) # outer repeat +# Overridden because we don't need to further recur into the definitions here function Base.repeat(x::AnyTracedRArray{T,N}, counts::Vararg{Int,M}) where {T,N,M} P = max(N, M) # potentially padded @@ -744,7 +481,7 @@ function Base.repeat(x::AnyTracedRArray{T,N}, counts::Vararg{Int,M}) where {T,N, broadcast_target_size = interleaved_size broadcast_target_size[2:2:(2M)] .= counts - x_broadcasted = broadcast_to_size(x_interleaved, broadcast_target_size) + x_broadcasted = TracedUtils.broadcast_to_size(x_interleaved, broadcast_target_size) # (d1, r1, d2, r2, ..., dP, rP) -> (d1*r1, d2*r2, ..., dP*rP) final_size = vec(prod(reshape(broadcast_target_size, 2, :); dims=1)) @@ -753,3 +490,5 @@ function Base.repeat(x::AnyTracedRArray{T,N}, counts::Vararg{Int,M}) where {T,N, return x_final end + +end diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index dc7a7ec2a..d9505549b 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -1,36 +1,27 @@ -mutable struct TracedRNumber{T} <: RNumber{T} - paths::Tuple - mlir_data::Union{Nothing,MLIR.IR.Value} +module TracedRNumberOverrides - function TracedRNumber{T}( - paths::Tuple, mlir_data::Union{Nothing,MLIR.IR.Value} - ) where {T} - if !isnothing(mlir_data) - @assert size(MLIR.IR.type(mlir_data)) == () - end - return new{T}(paths, mlir_data) - end -end - -get_mlir_data(x::TracedRNumber) = x.mlir_data -set_mlir_data!(x::TracedRNumber, data) = (x.mlir_data = data; return x) +import ..TracedRNumber +import ..TracedRArray +import ..ReactantPrimitive +using ..TracedUtils +import ..Ops +import ..MLIR +using ReactantCore ReactantCore.is_traced(::TracedRNumber) = true -new_traced_value(::TracedRNumber{T}) where {T} = TracedRNumber{T}((), nothing) - Base.eltype(::Type{TracedRNumber{T}}) where {T} = T Base.getindex(a::TracedRNumber{T}) where {T} = a -Base.zero(::TracedRNumber{T}) where {T} = promote_to(TracedRNumber{T}, zero(T)) -Base.one(::TracedRNumber{T}) where {T} = promote_to(TracedRNumber{T}, one(T)) +Base.zero(::TracedRNumber{T}) where {T} = TracedUtils.promote_to(TracedRNumber{T}, zero(T)) +Base.one(::TracedRNumber{T}) where {T} = TracedUtils.promote_to(TracedRNumber{T}, one(T)) Base.collect(x::TracedRNumber{T}) where {T} = TracedRArray{T,0}((), x.mlir_data, ()) -Base.eps(::Type{TracedRNumber{T}}) where {T} = promote_to(TracedRNumber{T}, eps(T)) +Base.eps(::Type{TracedRNumber{T}}) where {T} = TracedUtils.promote_to(TracedRNumber{T}, eps(T)) function Base.convert(::Type{<:TracedRNumber{T}}, x::Number) where {T} - return promote_to(TracedRNumber{T}, T(x)) + return TracedUtils.promote_to(TracedRNumber{T}, T(x)) end function Base.show(io::IOty, X::TracedRNumber{T}) where {T,IOty<:Union{IO,IOContext}} @@ -57,27 +48,28 @@ function Base.promote_rule(::Type{T}, ::Type{TracedRNumber{S}}) where {T,S} end function Base.convert(::Type{TracedRNumber{T}}, x::Number) where {T} - return promote_to(TracedRNumber{T}, x) + return TracedUtils.promote_to(TracedRNumber{T}, x) end TracedRNumber{T}(x::TracedRNumber{T}) where {T} = x + function TracedRNumber{T}(x::Number) where {T} - return promote_to(TracedRNumber{T}, x) + return TracedUtils.promote_to(TracedRNumber{T}, x) end -function promote_to(::Type{TracedRNumber{T}}, rhs) where {T} +function TracedUtils.promote_to(::Type{TracedRNumber{T}}, rhs) where {T} if rhs isa TracedRNumber rhs isa TracedRNumber{T} && return rhs return Ops.convert(TracedRNumber{T}, rhs) end if rhs isa TracedRArray{<:Any,0} - return promote_to(TracedRNumber{T}, TracedRNumber{eltype(rhs)}((), rhs.mlir_data)) + return TracedUtils.promote_to(TracedRNumber{T}, TracedRNumber{eltype(rhs)}((), rhs.mlir_data)) end - rhs isa Number && return promote_to(TracedRNumber{T}, Ops.constant(fill(T(rhs)))) - return promote_to(TracedRNumber{T}, Ops.constant(collect(rhs))) + rhs isa Number && return TracedUtils.promote_to(TracedRNumber{T}, Ops.constant(fill(T(rhs)))) + return TracedUtils.promote_to(TracedRNumber{T}, Ops.constant(collect(rhs))) end -promote_to(::TracedRNumber{T}, rhs) where {T} = promote_to(TracedRNumber{T}, rhs) +TracedUtils.promote_to(::TracedRNumber{T}, rhs) where {T} = TracedUtils.promote_to(TracedRNumber{T}, rhs) for (jlop, hloop) in ( (:(Base.min), :minimum), @@ -98,7 +90,7 @@ end function Base.div( @nospecialize(lhs::TracedRNumber{T}), rhs, ::typeof(RoundDown) ) where {T<:Integer} - return Ops.divide(lhs, promote_to(TracedRNumber{T}, rhs)) + return Ops.divide(lhs, TracedUtils.promote_to(TracedRNumber{T}, rhs)) end for (jlop, hloop, hlocomp) in ( @@ -117,29 +109,29 @@ for (jlop, hloop, hlocomp) in ( end function $(jlop)(@nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs)) where {T} - return $(jlop)(lhs, promote_to(lhs, rhs)) + return $(jlop)(lhs, TracedUtils.promote_to(lhs, rhs)) end function $(jlop)( @nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs::Number) ) where {T} - return $(jlop)(lhs, promote_to(lhs, rhs)) + return $(jlop)(lhs, TracedUtils.promote_to(lhs, rhs)) end function $(jlop)(@nospecialize(lhs), @nospecialize(rhs::TracedRNumber{T})) where {T} - return $(jlop)(promote_to(rhs, lhs), rhs) + return $(jlop)(TracedUtils.promote_to(rhs, lhs), rhs) end function $(jlop)( @nospecialize(lhs::Number), @nospecialize(rhs::TracedRNumber{T}) ) where {T} - return $(jlop)(promote_to(rhs, lhs), rhs) + return $(jlop)(TracedUtils.promote_to(rhs, lhs), rhs) end function $(jlop)( @nospecialize(lhs::TracedRNumber{T1}), @nospecialize(rhs::TracedRNumber{T2}) ) where {T1,T2} commonTy = TracedRNumber{Base.promote_type(T1, T2)} - lhs = promote_to(commonTy, lhs) - rhs = promote_to(commonTy, rhs) + lhs = TracedUtils.promote_to(commonTy, lhs) + rhs = TracedUtils.promote_to(commonTy, rhs) return $(jlop)(lhs, rhs) end end @@ -154,7 +146,7 @@ function Base.ifelse( element-type to the common type. This is semantically different from the \ behavior of `ifelse` in Base. Use with caution" maxlog = 1 T = promote_type(T1, T2) - return ifelse(pred, promote_to(TracedRNumber{T}, x), promote_to(TracedRNumber{T}, y)) + return ifelse(pred, TracedUtils.promote_to(TracedRNumber{T}, x), TracedUtils.promote_to(TracedRNumber{T}, y)) end function Base.ifelse( @@ -170,12 +162,12 @@ for (T1, T2) in zip((Bool, Integer), (Bool, Integer)) @eval begin function Base.:&(x::TracedRNumber{<:$(T1)}, y::TracedRNumber{<:$(T2)}) return Ops.and( - promote_to(TracedRNumber{$(T)}, x), promote_to(TracedRNumber{$(T)}, y) + TracedUtils.promote_to(TracedRNumber{$(T)}, x), TracedUtils.promote_to(TracedRNumber{$(T)}, y) ) end function Base.:|(x::TracedRNumber{<:$(T1)}, y::TracedRNumber{<:$(T2)}) return Ops.or( - promote_to(TracedRNumber{$(T)}, x), promote_to(TracedRNumber{$(T)}, y) + TracedUtils.promote_to(TracedRNumber{$(T)}, x), TracedUtils.promote_to(TracedRNumber{$(T)}, y) ) end Base.:!(x::TracedRNumber{<:$(T1)}) = Ops.not(x) @@ -222,13 +214,13 @@ end struct TypeCast{T<:ReactantPrimitive} <: Function end -(::TypeCast{T})(x::TracedRNumber{T2}) where {T,T2} = promote_to(TracedRNumber{T}, x) +(::TypeCast{T})(x::TracedRNumber{T2}) where {T,T2} = TracedUtils.promote_to(TracedRNumber{T}, x) function Base.fill(x::TracedRNumber, dims::NTuple{N,Integer}) where {N} return Reactant.broadcast_to_size(x, dims) end -Base.float(x::TracedRNumber{T}) where {T} = promote_to(TracedRNumber{float(T)}, x) +Base.float(x::TracedRNumber{T}) where {T} = TracedUtils.promote_to(TracedRNumber{float(T)}, x) # Concatenation. Numbers in Julia are handled in a much less generic fashion than arrays Base.vcat(x::TracedRNumber...) = Base.typed_vcat(Base.promote_eltypeof(x...), x...) @@ -260,3 +252,5 @@ function Base.typed_hvncat( xs = map(Base.Fix2(broadcast_to_size, (1, 1)), xs) return Base.typed_hvncat(T, dims, row_first, xs...) end + +end diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl new file mode 100644 index 000000000..e18a819bd --- /dev/null +++ b/src/TracedUtils.jl @@ -0,0 +1,511 @@ +# Functions within this module and Ops do not get forcibly re-compiled to be within our interpreter. +# This means that replacements, for example, for autodiff/random/kernels/etc do not get applied here when +# within compilation. However, it means these functions are a _lot_ faster to compile. +module TracedUtils + +import LinearAlgebra +import Adapt +using ..Reactant: RArray, RNumber, TracedRArray, TracedRNumber, WrappedTracedRArray, AnyTracedRArray, MissingTracedValue, OrderedIdDict +import ..Reactant +import ..Reactant.MLIR +import ..ReactantPrimitive +import ..Ops + +materialize_traced_array(x::TracedRArray) = x +materialize_traced_array(x::WrappedTracedRArray) = x[axes(x)...] +function materialize_traced_array( + x::Adapt.WrappedReshapedArray{T,N,<:TracedRArray} +) where {T,N} + return Ops.reshape(materialize_traced_array(parent(x)), size(x)...) +end +function materialize_traced_array( + x::LinearAlgebra.Transpose{T,TracedRArray{T,N}} +) where {T,N} + px = parent(x) + A = ndims(px) == 1 ? reshape(px, :, 1) : px + return permutedims(A, (2, 1)) +end +function materialize_traced_array(x::LinearAlgebra.Adjoint{T,TracedRArray{T,N}}) where {T,N} + return conj(materialize_traced_array(transpose(parent(x)))) +end +function materialize_traced_array( + x::PermutedDimsArray{T,N,perm,iperm,<:TracedRArray{T,N}} +) where {T,N,perm,iperm} + return permutedims(parent(x), perm) +end +function materialize_traced_array(x::LinearAlgebra.Diagonal{T,TracedRArray{T,1}}) where {T} + return LinearAlgebra.diagm(parent(x)) +end + +get_mlir_data(x::TracedRNumber) = x.mlir_data +set_mlir_data!(x::TracedRNumber, data) = (x.mlir_data = data; return x) + +get_mlir_data(x::TracedRArray) = x.mlir_data +get_mlir_data(x::AnyTracedRArray) = get_mlir_data(materialize_traced_array(x)) + +function set_mlir_data!(x::TracedRArray, data) + x.mlir_data = data + return x +end +function set_mlir_data!(x::Adapt.WrappedReshapedArray{T,N,<:TracedRArray}, data) where {T,N} + res_mlir_data = Ops.reshape(TracedRArray(data), size(parent(x))...).mlir_data + set_mlir_data!(parent(x), res_mlir_data) + return x +end +function set_mlir_data!(x::LinearAlgebra.Transpose{T,TracedRArray{T,N}}, data) where {T,N} + tdata = TracedRArray(data) + px = parent(x) + px.mlir_data = ( + if ndims(px) == 1 + Ops.reshape(tdata, length(tdata)) + else + Ops.transpose(tdata, [2, 1]) + end + ).mlir_data + return x +end +function set_mlir_data!(x::LinearAlgebra.Adjoint{T,TracedRArray{T,N}}, data) where {T,N} + tdata = TracedRArray(data) + px = parent(x) + transposed_data = + ndims(px) == 1 ? Ops.reshape(tdata, length(tdata)) : Ops.transpose(tdata, [2, 1]) + px.mlir_data = (T <: Real ? transposed_data : Ops.conj(transposed_data)).mlir_data + return x +end +function set_mlir_data!( + x::PermutedDimsArray{T,N,perm,iperm,TracedRArray{T,N}}, data +) where {T,N,perm,iperm} + parent(x).mlir_data = permutedims(TracedRArray(data), iperm).mlir_data + return x +end +function set_mlir_data!(x::LinearAlgebra.Diagonal{T,TracedRArray{T,1}}, data) where {T} + parent(x).mlir_data = LinearAlgebra.diag(TracedRArray(data)).mlir_data + return x +end +function set_mlir_data!(x::AnyTracedRArray, data) + setindex!(x, TracedRArray(data), axes(x)...) + return x +end + +get_ancestor_indices(::TracedRArray, indices...) = indices +function get_ancestor_indices(x::WrappedTracedRArray, indices...) + return get_ancestor_indices(parent(x), Base.reindex(parentindices(x), indices)...) +end + + +function batch_ty(width, mlirty) + return MLIR.IR.TensorType([width, size(mlirty)...], eltype(mlirty)) +end + +function transpose_ty(mlirty) + return MLIR.IR.TensorType([reverse(size(mlirty))...], eltype(mlirty)) +end +function transpose_val(val) + attr = MLIR.IR.DenseArrayAttribute( + Int64[reverse(0:(length(size(MLIR.IR.type(val))) - 1))...] + ) + return MLIR.IR.result(MLIR.Dialects.stablehlo.transpose(val; permutation=attr), 1) +end + +function make_mlir_fn( + f, + args, + kwargs, + name="main", + concretein=true; + toscalar=false, + return_dialect=:func, + no_args_in_result::Bool=false, + construct_function_without_args::Bool=false, + do_transpose=true, +) + if sizeof(typeof(f)) != 0 || f isa Base.BroadcastFunction + return ( + true, + make_mlir_fn( + apply, + (f, args...), + kwargs, + name, + concretein; + toscalar, + return_dialect, + no_args_in_result, + construct_function_without_args, + do_transpose, + )[2:end]..., + ) + end + + N = length(args) + seen_args = OrderedIdDict() + traced_args = ntuple(N) do i + return Reactant.make_tracer( + seen_args, + args[i], + (:args, i), + concretein ? Reactant.ConcreteToTraced : Reactant.TracedSetPath; + toscalar, + track_numbers=construct_function_without_args ? (Number,) : (), + ) + end + + linear_args = Reactant.TracedType[] + for (k, v) in seen_args + v isa Reactant.TracedType || continue + push!(linear_args, v) + end + + in_tys = if toscalar + [MLIR.IR.TensorType((), MLIR.IR.Type(eltype(arg))) for arg in linear_args] + elseif do_transpose + [transpose_ty(Ops.mlir_type(arg)) for arg in linear_args] + else + [Ops.mlir_type(arg) for arg in linear_args] + end + + sym_visibility = nothing + if !concretein + sym_visibility = MLIR.IR.Attribute("private") + end + + mod = MLIR.IR.mmodule() + func = MLIR.IR.block!(MLIR.IR.body(mod)) do + return MLIR.Dialects.func.func_(; + sym_name=name * "_tmp", + function_type=MLIR.IR.FunctionType(in_tys, []), + body=MLIR.IR.Region(), + ) + end + + if construct_function_without_args + fnbody = MLIR.IR.Block() + else + fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location() for arg in linear_args]) + end + push!(MLIR.IR.region(func, 1), fnbody) + + @assert MLIR.IR._has_block() + + result = MLIR.IR.block!(fnbody) do + for (i, arg) in enumerate(linear_args) + if construct_function_without_args + arg.mlir_data = args[i].mlir_data + else + raw_arg = MLIR.IR.argument(fnbody, i) + row_maj_arg = do_transpose ? transpose_val(raw_arg) : raw_arg + arg.mlir_data = row_maj_arg + end + end + + # TODO fix it for kwargs + #if concretein + Reactant.call_with_reactant(f, traced_args...) + #else + # f(traced_args...) + #end + end + + seen_results = OrderedIdDict() + + traced_result = Reactant.make_tracer( + seen_results, + result, + (:result,), + concretein ? Reactant.TracedTrack : Reactant.TracedSetPath; + track_numbers=construct_function_without_args ? (Number,) : (), + ) + + # marks buffers to be donated + for i in 1:N + Reactant.make_tracer( + seen_results, traced_args[i], concretein ? (:resargs, i) : (), Reactant.TracedTrack + ) + end + + linear_results = Reactant.TracedType[] + + for (k, v) in seen_results + v isa Reactant.TracedType || continue + (no_args_in_result && length(v.paths) > 0 && v.paths[1][1] == :args) && continue + push!(linear_results, v) + end + + out_tys = [transpose_ty(Ops.mlir_type(arg)) for arg in linear_results] + + ret = MLIR.IR.block!(fnbody) do + vals = MLIR.IR.Value[] + for res in linear_results + col_maj = if res isa MissingTracedValue + broadcast_to_size(false, ()).mlir_data + elseif construct_function_without_args || !do_transpose + res.mlir_data + elseif do_transpose + transpose_val(res.mlir_data) + end + push!(vals, col_maj) + end + !no_args_in_result && @assert length(vals) == length(linear_results) + + dialect = getfield(MLIR.Dialects, return_dialect) + return dialect.return_(vals) + end + + name2 = name + + tab = MLIR.IR.SymbolTable(MLIR.IR.Operation(mod)) + for i in 0:10000 + name2 = if i == 0 + name + else + name * string(i) + end + if MLIR.IR.mlirIsNull(MLIR.API.mlirSymbolTableLookup(tab, name2)) + break + end + end + + func2 = MLIR.IR.block!(MLIR.IR.body(mod)) do + return MLIR.Dialects.func.func_(; + sym_name=name2, + function_type=MLIR.IR.FunctionType(in_tys, out_tys), + body=MLIR.IR.Region(), + sym_visibility, + ) + end + MLIR.API.mlirRegionTakeBody(MLIR.IR.region(func2, 1), MLIR.IR.region(func, 1)) + + MLIR.API.mlirOperationDestroy(func.operation) + func.operation = MLIR.API.MlirOperation(C_NULL) + return ( + false, + func2, + traced_result, + result, + seen_args, + ret, + linear_args, + in_tys, + linear_results, + ) +end + +elem_apply(::Type{T}, x::TracedRArray{T}) where {T<:ReactantPrimitive} = x +function elem_apply( + ::Type{T}, x::TracedRArray{T2} +) where {T<:ReactantPrimitive,T2<:ReactantPrimitive} + # Special Path to prevent going down a despecialized path + return elem_apply(TypeCast{T}(), x) +end + +function promote_to end + +function get_attribute_by_name(operation, name) + return MLIR.IR.Attribute(MLIR.API.mlirOperationGetAttributeByName(operation, name)) +end + +function push_val!(ad_inputs, x, path) + for p in path + x = traced_getfield(x, p) + end + x = x.mlir_data + return push!(ad_inputs, x) +end + +function get_argidx(x) + for path in x.paths + if length(path) == 0 + continue + end + if path[1] == :args + return path[2]::Int, path + end + end + throw(AssertionError("No path found for $x")) +end + +function set!(x, path, tostore; emptypath=false) + for p in path + x = traced_getfield(x, p) + end + + x.mlir_data = tostore + + if emptypath + x.paths = () + end +end + +function get_residx(x) + for path in x.paths + if length(path) == 0 + continue + end + if path[1] == :result + return path + end + end + throw(AssertionError("No path found $x")) +end + +function has_residx(x) + for path in x.paths + if length(path) == 0 + continue + end + if path[1] == :result + return true + end + end + return false +end + +function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs} + if all(iszero ∘ ndims, args) + scalar_args = map(args) do arg + return promote_to(TracedRNumber{eltype(arg)}, arg) + end + return f(scalar_args...) + end + + fnwrap, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = make_mlir_fn( + f, args, (), string(f) * "_broadcast_scalar", false; toscalar=true + ) + + invmap = IdDict() + for (k, v) in seen_args + invmap[v] = k + end + + keys_seen = [k for k in keys(seen_args) if k isa Reactant.TracedType] + input_shapes = size.(keys_seen) + # by the time we reach here all args must have same size + @assert allequal(input_shapes) "input shapes are $(input_shapes)" + OutShape = isempty(seen_args) ? nothing : first(input_shapes) + @assert !isnothing(OutShape) + + in_tys2 = [Ops.mlir_type(invmap[arg]) for arg in linear_args] + + out_tys2 = [ + MLIR.IR.TensorType(OutShape, MLIR.IR.Type(eltype(arg))) for arg in linear_results + ] + + fname = get_attribute_by_name(func2, "sym_name") + fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)) + + batch_inputs = MLIR.IR.Value[] + + for a in linear_args + idx, path = TracedUtils.get_argidx(a) + if idx == 1 && fnwrap + push_val!(batch_inputs, f, path[3:end]) + else + if fnwrap + idx -= 1 + end + push_val!(batch_inputs, args[idx], path[3:end]) + end + end + + res = MLIR.Dialects.enzyme.batch( + batch_inputs; + outputs=out_tys2, + fn=fname, + batch_shape=MLIR.IR.DenseArrayAttribute([Int64(i) for i in OutShape]), + ) + + residx = 1 + + for a in linear_results + if TracedUtils.has_residx(a) + path = TracedUtils.get_residx(a) + TracedUtils.set!(result, path[2:end], MLIR.IR.result(res, residx)) + residx += 1 + else + idx, path = TracedUtils.get_argidx(a) + if idx == 1 && fnwrap + TracedUtils.set!(f, path[3:end], MLIR.IR.result(res, residx)) + residx += 1 + else + if fnwrap + idx -= 1 + end + TracedUtils.set!(args[idx], path[3:end], MLIR.IR.result(res, residx)) + residx += 1 + end + end + end + + seen_results = OrderedIdDict() + traced2_result = Reactant.make_tracer(seen_results, result, (), Reactant.TracedSetPath; tobatch=OutShape) + + func2.operation = MLIR.API.MlirOperation(C_NULL) + + return traced2_result +end + +new_traced_value(A::TracedRArray{T,N}) where {T,N} = TracedRArray{T,N}((), nothing, size(A)) +new_traced_value(::TracedRNumber{T}) where {T} = TracedRNumber{T}((), nothing) + +broadcast_to_size(arg::AbstractArray, rsize) = broadcast_to_size(Ops.constant(arg), rsize) + +function broadcast_to_size(arg::Base.RefValue, rsize) + # XXX: don't we want to expand here to rsize? + return arg +end + +broadcast_to_size(arg::Number, rsize) = Ops.constant(Base.fill(arg, Tuple(rsize))) + +function broadcast_to_size(arg::TracedRNumber, rsize) + length(rsize) == 0 && return arg + return broadcast_to_size_internal( + TracedRArray{eltype(arg),0}((), arg.mlir_data, ()), rsize + ) +end + +function broadcast_to_size(arg::AnyTracedRArray{T,0}, rsize) where {T} + arg = materialize_traced_array(arg) + return broadcast_to_size(TracedRNumber{T}((), arg.mlir_data), rsize) +end + +function broadcast_to_size(arg::AnyTracedRArray, rsize) + arg = materialize_traced_array(arg) + size(arg) == Tuple(rsize) && return arg + return broadcast_to_size_internal(arg, rsize) +end + +function broadcast_to_size(arg::Broadcast.Extruded, rsize) + rsize2 = (keep ? rsizev : 1 for (keep, rsizev) in zip(arg.keeps, rsize)) + x = broadcast_to_size(arg.x, rsize2) + size(x) == rsize && return x + return broadcast_to_size_internal(x, rsize) +end + +function broadcast_to_size_internal(x::TracedRArray, rsize) + dims = collect(Int64, 0:(length(size(x)) - 1)) + + if length(size(MLIR.IR.type(x.mlir_data))) != length(dims) + @show x + @show arg + @show rsize + @show rsize2 + @show dims + end + @assert length(size(MLIR.IR.type(x.mlir_data))) == length(dims) + mlirty = MLIR.IR.type(x.mlir_data) + + return TracedRArray{eltype(x),Int(length(rsize))}( + (), + MLIR.IR.result( + MLIR.Dialects.stablehlo.broadcast_in_dim( + x.mlir_data; + result_0=MLIR.IR.TensorType([t for t in rsize], eltype(mlirty)), + broadcast_dimensions=MLIR.IR.DenseArrayAttribute(dims), + ), + 1, + ), + collect(rsize), + ) +end + +end diff --git a/src/utils.jl b/src/utils.jl index f26fe13cf..a1450e7b3 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,37 +1,3 @@ -function mlir_type(x::RArray{T,N}) where {T,N} - return MLIR.IR.TensorType(size(x), MLIR.IR.Type(T)) -end - -mlir_type(::RNumber{T}) where {T} = MLIR.IR.TensorType((), MLIR.IR.Type(T)) - -mlir_type(::MissingTracedValue) = MLIR.IR.TensorType((), MLIR.IR.Type(Bool)) - -function mlir_type(::Type{<:RArray{T,N}}, shape) where {T,N} - @assert length(shape) == N - return MLIR.IR.TensorType(shape, MLIR.IR.Type(T)) -end - -function mlir_type(::Type{<:RNumber{T}}) where {T} - return MLIR.IR.TensorType((), MLIR.IR.Type(T)) -end - -function mlir_type(::Type{<:MissingTracedValue}) - return MLIR.IR.TensorType((), MLIR.IR.Type(Bool)) -end - -function batch_ty(width, mlirty) - return MLIR.IR.TensorType([width, size(mlirty)...], eltype(mlirty)) -end - -function transpose_ty(mlirty) - return MLIR.IR.TensorType([reverse(size(mlirty))...], eltype(mlirty)) -end -function transpose_val(val) - attr = MLIR.IR.DenseArrayAttribute( - Int64[reverse(0:(length(size(MLIR.IR.type(val))) - 1))...] - ) - return MLIR.IR.result(MLIR.Dialects.stablehlo.transpose(val; permutation=attr), 1) -end function apply(f, args...; kwargs...) return f(args...; kwargs...) @@ -39,13 +5,6 @@ end function call_with_reactant end -# generate a LineInfoNode for the current source code location -macro LineInfoNode(method) - return Core.LineInfoNode( - __module__, method, __source__.file, Int32(__source__.line), Int32(0) - ) -end - function maybe_argextype(@nospecialize(x), src) return try Core.Compiler.argextype(x, src) @@ -55,22 +14,6 @@ function maybe_argextype(@nospecialize(x), src) end end -function rewrite_inst(inst, ir) - if Meta.isexpr(inst, :call) - # Even if type unstable we do not want (or need) to replace intrinsic - # calls or builtins with our version. - ft = Core.Compiler.widenconst(maybe_argextype(inst.args[1], ir)) - if !(ft <: Core.IntrinsicFunction) && !(ft <: Core.Builtin) - rep = Expr(:call, call_with_reactant, inst.args...) - return true, rep - end - end - if Meta.isexpr(inst, :invoke) - # return false, Expr(:call, inst.args[2:end]...) - end - return false, inst -end - """ Reactant.REDUB_ARGUMENTS_NAME @@ -113,33 +56,97 @@ end end end +function has_ancestor(query::Module, target::Module) + query == target && return true + while true + next = parentmodule(query) + next == target && return true + next == query && return false + query = next + end +end -# HACK: in all versions of Julia, `jl_new_opaque_closure_from_code_info` doesn't take a world argument -# but instead always generates code for the current world. note that this doesn't -# actually change the world age, but just spoofs the counter `jl_create_native` reads. -# XXX: Base.get_world_counter is supposed to be monotonically increasing and is runtime global. -macro in_world(world, ex) - quote - actual_world = Base.get_world_counter() - world_counter = cglobal(:jl_world_counter, Csize_t) - unsafe_store!(world_counter, $(esc(world))) - try - $(esc(ex)) - finally - unsafe_store!(world_counter, actual_world) +function should_rewrite_ft(@nospecialize(ft)) + # Don't rewrite builtin or intrinsics + if ft <: Core.IntrinsicFunction || ft <: Core.Builtin + return false + end + if ft <: Core.Function + mod = ft.name.module + # Don't rewrite primitive ops, tracing utilities, or any MLIR-based functions + if has_ancestor(mod, Reactant.Ops) || has_ancestor(mod, Reactant.TracedUtils) || has_ancestor(mod, Reactant.MLIR) + return false end end + # Don't rewrite Val + if ft === Type{Base.Val} + return false + end + # Don't rewrite exception constructors + if ft <: Type{<:Core.Exception} + return false + end + + # Default assume all functions need to be reactant-ified + return true end -#define jl_current_task (container_of(jl_get_pgcstack(), jl_task_t, gcstack)) +# Avoid recursively interpreting into methods we define explicitly +# as overloads, which we assume should handle the entirety of the +# translation (and if not they can use call_in_reactant). +function is_reactant_method(mi::Core.MethodInstance) + meth = mi.def + if !isdefined(meth, :external_mt) + return false + end + mt = meth.external_mt + return mt === REACTANT_METHOD_TABLE +end +function rewrite_inst(inst, ir, interp) + if Meta.isexpr(inst, :call) + # Even if type unstable we do not want (or need) to replace intrinsic + # calls or builtins with our version. + ft = Core.Compiler.widenconst(maybe_argextype(inst.args[1], ir)) + if should_rewrite_ft(ft) + rep = Expr(:call, call_with_reactant, inst.args...) + return true, rep + end + end + if Meta.isexpr(inst, :invoke) + omi = inst.args[1]::Core.MethodInstance + sig = omi.specTypes + ft = sig.parameters[1] + + if should_rewrite_ft(ft) && !is_reactant_method(omi) + + min_world = Ref{UInt}(typemin(UInt)) + max_world = Ref{UInt}(typemax(UInt)) + + lookup_result = lookup_world(Tuple{typeof(call_with_reactant), sig.parameters...}, interp.world, Core.Compiler.method_table(interp), min_world, max_world) + + match = lookup_result::Core.MethodMatch + # look up the method and code instance + mi = ccall( + :jl_specializations_get_linfo, + Ref{Core.MethodInstance}, + (Any, Any, Any), + match.method, + match.spec_types, + match.sparams, + ) + rep = Expr(:invoke, mi, call_with_reactant, inst.args[2:end]...) + return true, rep + end + end + return false, inst +end function make_oc(sig, rt, src, nargs, isva, f)::Core.OpaqueClosure ccall(:jl_new_opaque_closure_from_code_info, Any, (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any, Cint), sig, rt, rt, @__MODULE__, src, 0, nothing, nargs, isva, f, true)::Core.OpaqueClosure end - # Generator function which ensures that all calls to the function are executed within the ReactantInterpreter # In particular this entails two pieces: # 1) We enforce the use of the ReactantInterpreter method table when generating the original methodinstance @@ -284,7 +291,7 @@ function call_with_reactant_generator( ir = Core.Compiler.run_passes_ipo_safe(opt.src, opt, caller) Core.Compiler.ipo_dataflow_analysis!(interp, ir, caller) end - + # Rewrite type unstable calls to recurse into call_with_reactant to ensure # they continue to use our interpreter. Reset the derived return type # to Any if our interpreter would change the return type of any result. @@ -293,10 +300,10 @@ function call_with_reactant_generator( any_changed = false for (i, inst) in enumerate(ir.stmts) @static if VERSION < v"1.11" - changed, next = rewrite_inst(inst[:inst], ir) + changed, next = rewrite_inst(inst[:inst], ir, interp) Core.Compiler.setindex!(ir.stmts[i], next, :inst) else - changed, next = rewrite_inst(inst[:stmt], ir) + changed, next = rewrite_inst(inst[:stmt], ir, interp) Core.Compiler.setindex!(ir.stmts[i], next, :stmt) end if changed @@ -307,7 +314,7 @@ function call_with_reactant_generator( Core.Compiler.finish(interp, opt, ir, caller) src = Core.Compiler.ir_to_codeinf!(opt) - + # prepare a new code info code_info = copy(src) static_params = match.sparams @@ -347,6 +354,7 @@ function call_with_reactant_generator( if method.isva iter_args = min(n_actual_args, n_method_args-1) end + for i in 1:iter_args actual_argument = Expr( :call, Core.GlobalRef(Core, :getfield), overdub_args_slot, offset @@ -443,7 +451,7 @@ function call_with_reactant_generator( code_info.codelocs = overdubbed_codelocs code_info.ssavaluetypes = length(overdubbed_code) code_info.ssaflags = [0x00 for _ in 1:length(overdubbed_code)] # XXX we need to copy flags that are set for the original code - + return code_info end @@ -451,215 +459,3 @@ end $(Expr(:meta, :generated_only)) return $(Expr(:meta, :generated, call_with_reactant_generator)) end - -function make_mlir_fn( - f, - args, - kwargs, - name="main", - concretein=true; - toscalar=false, - return_dialect=:func, - no_args_in_result::Bool=false, - construct_function_without_args::Bool=false, - do_transpose=true, -) - if sizeof(typeof(f)) != 0 || f isa BroadcastFunction - return ( - true, - make_mlir_fn( - apply, - (f, args...), - kwargs, - name, - concretein; - toscalar, - return_dialect, - no_args_in_result, - construct_function_without_args, - do_transpose, - )[2:end]..., - ) - end - - N = length(args) - seen_args = OrderedIdDict() - traced_args = ntuple(N) do i - return make_tracer( - seen_args, - args[i], - (:args, i), - concretein ? ConcreteToTraced : TracedSetPath; - toscalar, - track_numbers=construct_function_without_args ? (Number,) : (), - ) - end - - linear_args = TracedType[] - for (k, v) in seen_args - v isa TracedType || continue - push!(linear_args, v) - end - - in_tys = if toscalar - [MLIR.IR.TensorType((), MLIR.IR.Type(eltype(arg))) for arg in linear_args] - elseif do_transpose - [transpose_ty(mlir_type(arg)) for arg in linear_args] - else - [mlir_type(arg) for arg in linear_args] - end - - sym_visibility = nothing - if !concretein - sym_visibility = MLIR.IR.Attribute("private") - end - - mod = MLIR.IR.mmodule() - func = MLIR.IR.block!(MLIR.IR.body(mod)) do - return MLIR.Dialects.func.func_(; - sym_name=name * "_tmp", - function_type=MLIR.IR.FunctionType(in_tys, []), - body=MLIR.IR.Region(), - ) - end - - if construct_function_without_args - fnbody = MLIR.IR.Block() - else - fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location() for arg in linear_args]) - end - push!(MLIR.IR.region(func, 1), fnbody) - - @assert MLIR.IR._has_block() - - result = MLIR.IR.block!(fnbody) do - for (i, arg) in enumerate(linear_args) - if construct_function_without_args - arg.mlir_data = args[i].mlir_data - else - raw_arg = MLIR.IR.argument(fnbody, i) - row_maj_arg = do_transpose ? transpose_val(raw_arg) : raw_arg - arg.mlir_data = row_maj_arg - end - end - - # TODO fix it for kwargs - if concretein - call_with_reactant(f, traced_args...) - else - f(traced_args...) - end - end - - seen_results = OrderedIdDict() - - traced_result = make_tracer( - seen_results, - result, - (:result,), - concretein ? TracedTrack : TracedSetPath; - track_numbers=construct_function_without_args ? (Number,) : (), - ) - - # marks buffers to be donated - for i in 1:N - make_tracer( - seen_results, traced_args[i], concretein ? (:resargs, i) : (), TracedTrack - ) - end - - linear_results = TracedType[] - - for (k, v) in seen_results - v isa TracedType || continue - (no_args_in_result && length(v.paths) > 0 && v.paths[1][1] == :args) && continue - push!(linear_results, v) - end - - out_tys = [transpose_ty(mlir_type(arg)) for arg in linear_results] - - ret = MLIR.IR.block!(fnbody) do - vals = MLIR.IR.Value[] - for res in linear_results - col_maj = if res isa MissingTracedValue - broadcast_to_size(false, ()).mlir_data - elseif construct_function_without_args || !do_transpose - res.mlir_data - elseif do_transpose - transpose_val(res.mlir_data) - end - push!(vals, col_maj) - end - !no_args_in_result && @assert length(vals) == length(linear_results) - - dialect = getfield(MLIR.Dialects, return_dialect) - return dialect.return_(vals) - end - - name2 = name - - tab = MLIR.IR.SymbolTable(MLIR.IR.Operation(mod)) - for i in 0:10000 - name2 = if i == 0 - name - else - name * string(i) - end - if MLIR.IR.mlirIsNull(MLIR.API.mlirSymbolTableLookup(tab, name2)) - break - end - end - - func2 = MLIR.IR.block!(MLIR.IR.body(mod)) do - return MLIR.Dialects.func.func_(; - sym_name=name2, - function_type=MLIR.IR.FunctionType(in_tys, out_tys), - body=MLIR.IR.Region(), - sym_visibility, - ) - end - MLIR.API.mlirRegionTakeBody(MLIR.IR.region(func2, 1), MLIR.IR.region(func, 1)) - - MLIR.API.mlirOperationDestroy(func.operation) - func.operation = MLIR.API.MlirOperation(C_NULL) - return ( - false, - func2, - traced_result, - result, - seen_args, - ret, - linear_args, - in_tys, - linear_results, - ) -end - -const DEBUG_MODE::Ref{Bool} = Ref(false) - -function with_debug(f) - old = DEBUG_MODE[] - DEBUG_MODE[] = true - try - return f() - finally - DEBUG_MODE[] = old - end -end - -function mlir_stacktrace(name, file, line)::MLIR.IR.Location - # calling `stacktrace` can add a lot of time overhead, so let's avoid adding debug info if not used - if DEBUG_MODE[] - return MLIR.IR.Location(name, MLIR.IR.Location(file, line, 0)) - end - - # retrieve current stacktrace, remove this function's frame and translate to MLIR Location - st = stacktrace() - deleteat!(st, 1) - return mapfoldl(MLIR.IR.Location, st) do stackframe - name = string(stackframe.func) - file = stackframe.file - line = stackframe.line - return MLIR.IR.Location(name, MLIR.IR.Location(file, line, 0)) - end -end From 9e1fe6c50c82d45fba5499a6c840bc2f3ec2e27c Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Fri, 13 Dec 2024 22:42:08 -0500 Subject: [PATCH 51/78] fixup --- ext/ReactantNNlibExt.jl | 7 ++++-- ext/ReactantStatisticsExt.jl | 3 ++- ext/ReactantYaoBlocksExt.jl | 7 +++--- src/Interpreter.jl | 4 ++-- src/Reactant.jl | 42 ++++++++++++------------------------ src/TracedRArray.jl | 2 +- src/TracedUtils.jl | 2 +- 7 files changed, 29 insertions(+), 38 deletions(-) diff --git a/ext/ReactantNNlibExt.jl b/ext/ReactantNNlibExt.jl index 8bfa5de02..ca7fff285 100644 --- a/ext/ReactantNNlibExt.jl +++ b/ext/ReactantNNlibExt.jl @@ -7,11 +7,14 @@ using Reactant: Ops, TracedRArray, AnyTracedRArray, - materialize_traced_array, MLIR, - TracedRNumber, + TracedRNumber + +using Reactant.TraceUtils: + materialize_traced_array, get_mlir_data, set_mlir_data! + using ReactantCore: @trace using LinearAlgebra: LinearAlgebra, triu diff --git a/ext/ReactantStatisticsExt.jl b/ext/ReactantStatisticsExt.jl index f733511af..40db81a8e 100644 --- a/ext/ReactantStatisticsExt.jl +++ b/ext/ReactantStatisticsExt.jl @@ -1,6 +1,7 @@ module ReactantStatisticsExt -using Reactant: AnyTracedRArray, materialize_traced_array +using Reactant: AnyTracedRArray +using Reactant.TracedUtils: materialize_traced_array using Statistics: Statistics function Statistics.mean(A::AnyTracedRArray{T,N}; dims=:) where {T,N} diff --git a/ext/ReactantYaoBlocksExt.jl b/ext/ReactantYaoBlocksExt.jl index 2542d8a08..114c8017f 100644 --- a/ext/ReactantYaoBlocksExt.jl +++ b/ext/ReactantYaoBlocksExt.jl @@ -1,12 +1,13 @@ module ReactantYaoBlocksExt using Reactant +using Reactant.TraceUtils: broadcast_to_size using YaoBlocks function YaoBlocks.mat( ::Type{T}, R::RotationGate{D,Reactant.TracedRNumber{S},<:XGate} ) where {D,T,S} - M = Reactant.broadcast_to_size(zero(T), (2, 2)) + M = broadcast_to_size(zero(T), (2, 2)) c = cos(R.theta / 2) s = -im * sin(R.theta / 2) M[1, 1] = c @@ -19,7 +20,7 @@ end function YaoBlocks.mat( ::Type{T}, R::RotationGate{D,Reactant.TracedRNumber{S},<:YGate} ) where {D,T,S} - M = Reactant.broadcast_to_size(zero(T), (2, 2)) + M = broadcast_to_size(zero(T), (2, 2)) c = cos(R.theta / 2) s = sin(R.theta / 2) M[1, 1] = c @@ -32,7 +33,7 @@ end function YaoBlocks.mat( ::Type{T}, R::RotationGate{D,Reactant.TracedRNumber{S},<:ZGate} ) where {D,T,S} - M = Reactant.broadcast_to_size(zero(T), (2, 2)) + M = broadcast_to_size(zero(T), (2, 2)) x = exp(im * R.theta / 2) M[1, 1] = conj(x) M[2, 2] = x diff --git a/src/Interpreter.jl b/src/Interpreter.jl index 50f4f52d1..b3aa5b700 100644 --- a/src/Interpreter.jl +++ b/src/Interpreter.jl @@ -303,7 +303,7 @@ function overload_autodiff( act = act_from_type(A, reverse, needs_primal(CMode)) push!(ret_activity, act) if act == enzyme_out || act == enzyme_outnoneed - attr = fill(MLIR.IR.Attribute(eltype(a)(1)), mlir_type(a)) + attr = fill(MLIR.IR.Attribute(eltype(a)(1)), Ops.mlir_type(a)) cst = MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1) push!(ad_inputs, cst) end @@ -325,7 +325,7 @@ function overload_autodiff( if act != enzyme_out && act != enzyme_outnoneed continue end - TraceUtils.push_val!(ad_inputs, args[idx].dval, path[3:end]) + TracedUtils.push_val!(ad_inputs, args[idx].dval, path[3:end]) end end end diff --git a/src/Reactant.jl b/src/Reactant.jl index 5d670558e..bdbc8aefc 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -57,34 +57,6 @@ abstract type RNumber{T<:ReactantPrimitive} <: Number end Base.collect(A::RArray) = copy(A) -function Enzyme.make_zero( - ::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false) -)::RT where {copy_if_inactive,RT<:RArray} - if haskey(seen, prev) - return seen[prev] - end - if Enzyme.Compiler.guaranteed_const_nongen(RT, nothing) - return copy_if_inactive ? Base.deepcopy_internal(prev, seen) : prev - end - if RT <: ConcreteRArray - res = RT(zeros(eltype(RT), size(prev))) - seen[prev] = res - return res - end - - if RT <: TracedRArray - res = broadcast_to_size(eltype(RT)(0), size(prev)) - seen[prev] = res - return res - end - - attr = fill(MLIR.IR.Attribute(eltype(RT)(0)), mlir_type(prev)) - cst = MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1) - res = RT((), cst) - seen[prev] = res - return res -end - function ancestor(x::AbstractArray) p_x = parent(x) p_x === x && return x @@ -159,6 +131,20 @@ include("ControlFlow.jl") include("Tracing.jl") include("Compiler.jl") +function Enzyme.make_zero( + ::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false) +)::RT where {copy_if_inactive,RT<:RArray} + if haskey(seen, prev) + return seen[prev] + end + if Enzyme.Compiler.guaranteed_const_nongen(eltype(RT), nothing) + return copy_if_inactive ? Base.deepcopy_internal(prev, seen) : prev + end + res = zero(prev) + seen[prev] = res + return res +end + using .Compiler: @compile, @code_hlo, @jit, traced_getfield, create_result, compile export ConcreteRArray, ConcreteRNumber, @compile, @code_hlo, @jit, @trace diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 988d895da..31eefe77e 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -458,7 +458,7 @@ end for (minT, maxT) in Iterators.product((Number, TracedRNumber), (Number, TracedRNumber)) @eval function Base.clamp!(x::AnyTracedRArray, min::$(minT), max::$(maxT)) y = Ops.clamp(min, materialize_traced_array(x), max) - set_mlir_data!(x, y.mlir_data) + TracedUtils.set_mlir_data!(x, y.mlir_data) return x end end diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index e18a819bd..19771318a 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -123,7 +123,7 @@ function make_mlir_fn( return ( true, make_mlir_fn( - apply, + Reactant.apply, (f, args...), kwargs, name, From 0982c09c83b736cb1a1fa90fca09f58ac051b0e9 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Fri, 13 Dec 2024 22:47:44 -0500 Subject: [PATCH 52/78] fix --- src/linear_algebra.jl | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/linear_algebra.jl b/src/linear_algebra.jl index 745195217..a554e4b67 100644 --- a/src/linear_algebra.jl +++ b/src/linear_algebra.jl @@ -1,3 +1,13 @@ +module TracedLinearAlgebra + +using ..Reactant +import ..TracedRArray +import ..AnyTracedRArray +import ..AnyTracedRMatrix +import ..AnyTracedRVector +using ..TracedUtils +using LinearAlgebra + function LinearAlgebra.mul!( @nospecialize(C::TracedRArray{T,1}), @nospecialize(A::AnyTracedRMatrix), @@ -142,3 +152,5 @@ function LinearAlgebra.diagm(m::Integer, n::Integer, v::AnyTracedRArray{T,1}) wh mat, promote_to(TracedRNumber{T}, 0); high=[m - length(v), n - length(v)] ) end + +end From 996d60a4c406f27421094dadc08b772a7cc072b4 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Fri, 13 Dec 2024 22:51:57 -0500 Subject: [PATCH 53/78] fix --- ext/ReactantNNlibExt.jl | 2 +- ext/ReactantYaoBlocksExt.jl | 2 +- src/ConcreteRArray.jl | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ext/ReactantNNlibExt.jl b/ext/ReactantNNlibExt.jl index ca7fff285..65f5677b9 100644 --- a/ext/ReactantNNlibExt.jl +++ b/ext/ReactantNNlibExt.jl @@ -10,7 +10,7 @@ using Reactant: MLIR, TracedRNumber -using Reactant.TraceUtils: +using Reactant.TracedUtils: materialize_traced_array, get_mlir_data, set_mlir_data! diff --git a/ext/ReactantYaoBlocksExt.jl b/ext/ReactantYaoBlocksExt.jl index 114c8017f..cc16e51be 100644 --- a/ext/ReactantYaoBlocksExt.jl +++ b/ext/ReactantYaoBlocksExt.jl @@ -1,7 +1,7 @@ module ReactantYaoBlocksExt using Reactant -using Reactant.TraceUtils: broadcast_to_size +using Reactant.TracedUtils: broadcast_to_size using YaoBlocks function YaoBlocks.mat( diff --git a/src/ConcreteRArray.jl b/src/ConcreteRArray.jl index ba495aef3..e9d9c02d7 100644 --- a/src/ConcreteRArray.jl +++ b/src/ConcreteRArray.jl @@ -99,7 +99,7 @@ end function Base.convert( ::Type{T}, X::WrappedConcreteRArray{ElType,N} ) where {T<:Array,ElType,N} - fn = compile(materialize_traced_array, (X,)) + fn = compile(TracedUtils.materialize_traced_array, (X,)) return convert(Array, fn(X)) end Base.Array(x::AnyConcreteRArray) = convert(Array, x) From 681107f97c2c816aed0abd3527c0166f4ab1ce75 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Fri, 13 Dec 2024 22:55:57 -0500 Subject: [PATCH 54/78] fix --- src/linear_algebra.jl | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/linear_algebra.jl b/src/linear_algebra.jl index a554e4b67..61d82dae1 100644 --- a/src/linear_algebra.jl +++ b/src/linear_algebra.jl @@ -5,7 +5,13 @@ import ..TracedRArray import ..AnyTracedRArray import ..AnyTracedRMatrix import ..AnyTracedRVector -using ..TracedUtils + +using ..TracedUtils: + get_mlir_data, + materialize_traced_array, + set_mlir_data! + +import ..Ops using LinearAlgebra function LinearAlgebra.mul!( From 3a35f5c29a2eda3920f46c65d6152b48978f5f43 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Fri, 13 Dec 2024 22:57:13 -0500 Subject: [PATCH 55/78] fixup --- src/TracedRArray.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 31eefe77e..a909f5a9f 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -445,7 +445,7 @@ function Base._cat_t(dims, ::Type{T}, X::TracedRArray...) where {T} MLIR.IR.result( # TODO maybe we should do some conversion? MLIR.Dialects.stablehlo.concatenate( - collect(get_mlir_data.(X)); + collect(TracedUtils.get_mlir_data.(X)); result_0=MLIR.IR.TensorType(shape, MLIR.IR.Type(RT)), dimension=dims - 1, # stablehlo expects this to be zero-indexed ), From eea1dbaa4a7b1c607352da7c067db452001b1d3a Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Fri, 13 Dec 2024 22:59:39 -0500 Subject: [PATCH 56/78] fix --- src/TracedRNumber.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index d9505549b..8a1e37132 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -225,12 +225,12 @@ Base.float(x::TracedRNumber{T}) where {T} = TracedUtils.promote_to(TracedRNumber # Concatenation. Numbers in Julia are handled in a much less generic fashion than arrays Base.vcat(x::TracedRNumber...) = Base.typed_vcat(Base.promote_eltypeof(x...), x...) function Base.typed_vcat(::Type{T}, x::TracedRNumber...) where {T} - return Base.typed_vcat(T, map(Base.Fix2(broadcast_to_size, (1,)), x)...) + return Base.typed_vcat(T, map(Base.Fix2(TracedUtils.broadcast_to_size, (1,)), x)...) end Base.hcat(x::TracedRNumber...) = Base.typed_hcat(Base.promote_eltypeof(x...), x...) function Base.typed_hcat(::Type{T}, x::TracedRNumber...) where {T} - return Base.typed_hcat(T, map(Base.Fix2(broadcast_to_size, (1, 1)), x)...) + return Base.typed_hcat(T, map(Base.Fix2(TracedUtils.broadcast_to_size, (1, 1)), x)...) end function Base.hvcat(rows::Tuple{Vararg{Int}}, xs::TracedRNumber...) From c65948d56461a14ec78764669d303f95958cfd6d Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Fri, 13 Dec 2024 23:16:42 -0500 Subject: [PATCH 57/78] wip --- src/Ops.jl | 111 +++++++++++++++++++++----------------------- src/TracedRArray.jl | 3 +- 2 files changed, 55 insertions(+), 59 deletions(-) diff --git a/src/Ops.jl b/src/Ops.jl index 4b335e182..f604f16a6 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -45,7 +45,7 @@ function with_debug(f) end end -function mlir_stacktrace(name, file, line)::MLIR.IR.Location +@noinline function mlir_stacktrace(name, file, line)::MLIR.IR.Location # calling `stacktrace` can add a lot of time overhead, so let's avoid adding debug info if not used if DEBUG_MODE[] return MLIR.IR.Location(name, MLIR.IR.Location(file, line, 0)) @@ -67,7 +67,7 @@ struct Token end # constant ops -function constant( +@noinline function constant( x::DenseArray{T,N}; location=mlir_stacktrace("constant", @__FILE__, @__LINE__) ) where {T,N} value = MLIR.IR.DenseElementsAttribute(x) @@ -76,7 +76,7 @@ function constant( return TracedRArray{T,N}((), res, size(x)) end -function constant( +@noinline function constant( x::T; location=mlir_stacktrace("constant", @__FILE__, @__LINE__) ) where {T<:Number} res = constant(fill(x); location) @@ -124,7 +124,7 @@ for (dialect, op) in [ (:chlo, :sinh), ] @eval begin - function $op( + @noinline function $op( x::TracedRArray{T,N}; location=mlir_stacktrace($(string(op)), @__FILE__, @__LINE__), ) where {T,N} @@ -136,7 +136,7 @@ for (dialect, op) in [ return TracedRArray{T,N}((), res, size(x)) end - function $op( + @noinline function $op( x::TracedRNumber{T}; location=mlir_stacktrace($(string(op)), @__FILE__, @__LINE__), ) where {T} @@ -172,7 +172,7 @@ for (dialect, op) in [ (:chlo, :zeta), ] @eval begin - function $op( + @noinline function $op( a::TracedRArray{T,N}, b::TracedRArray{T,N}; location=mlir_stacktrace($(string(op)), @__FILE__, @__LINE__), @@ -188,7 +188,7 @@ for (dialect, op) in [ return TracedRArray{T,N}((), res, size(a)) end - function $op( + @noinline function $op( a::TracedRNumber{T}, b::TracedRNumber{T}; location=mlir_stacktrace($(string(op)), @__FILE__, @__LINE__), @@ -214,7 +214,7 @@ for (dialect, op) in [ (:chlo, :is_pos_inf), ] @eval begin - function $op( + @noinline function $op( x::TracedRArray{T,N}; location=mlir_stacktrace($(string(op)), @__FILE__, @__LINE__), ) where {T,N} @@ -226,7 +226,7 @@ for (dialect, op) in [ return TracedRArray{Bool,N}((), res, size(x)) end - function $op( + @noinline function $op( x::TracedRNumber{T}; location=mlir_stacktrace($(string(op)), @__FILE__, @__LINE__), ) where {T} @@ -240,7 +240,7 @@ for (dialect, op) in [ end end -function is_finite( +@noinline function is_finite( x::TracedRArray{T,N}; location=mlir_stacktrace("is_finite", @__FILE__, @__LINE__) ) where {T,N} res = MLIR.IR.result( @@ -251,7 +251,7 @@ function is_finite( return TracedRArray{Bool,N}((), res, size(x)) end -function is_finite( +@noinline function is_finite( x::TracedRNumber{T}; location=mlir_stacktrace("is_finite", @__FILE__, @__LINE__) ) where {T} res = MLIR.IR.result( @@ -261,7 +261,7 @@ function is_finite( end # fixes to default automated implementations -function abs( +@noinline function abs( x::TracedRArray{Complex{T},N}; location=mlir_stacktrace("abs", @__FILE__, @__LINE__) ) where {T,N} res = MLIR.IR.result( @@ -270,7 +270,7 @@ function abs( return TracedRArray{T,N}((), res, size(x)) end -function abs( +@noinline function abs( x::TracedRNumber{Complex{T}}; location=mlir_stacktrace("abs", @__FILE__, @__LINE__) ) where {T} res = MLIR.IR.result( @@ -284,7 +284,7 @@ function reshape(x::TracedRArray, dims...; kwargs...) return reshape(x, collect(dims); kwargs...) end -function reshape( +@noinline function reshape( x::TracedRArray{T,N}, dims::Vector{Int}; location=mlir_stacktrace("reshape", @__FILE__, @__LINE__), @@ -299,7 +299,7 @@ function reshape( return transpose(result, Int64[length(dims):-1:1...]) end -function get_dimension_size( +@noinline function get_dimension_size( x::TracedRArray{T,N}, dim; location=mlir_stacktrace("get_dimension_size", @__FILE__, @__LINE__), @@ -313,7 +313,7 @@ function get_dimension_size( return TracedRNumber{Int32}((), res) end -function set_dimension_size( +@noinline function set_dimension_size( x::TracedRArray{T,N}, size::TracedRNumber{Int}, dim::Int; @@ -332,7 +332,7 @@ function set_dimension_size( return TracedRArray{T,N}((), res, size(x)) end -function transpose( +@noinline function transpose( x::TracedRArray{T,N}, permutation; location=mlir_stacktrace("transpose", @__FILE__, @__LINE__), @@ -346,7 +346,7 @@ function transpose( end # indexing ops -function pad( +@noinline function pad( x::TracedRArray{T,N}, padding_value::TracedRNumber{T}; low=fill(0, N), @@ -368,7 +368,7 @@ function pad( return TracedRArray{T,N}((), res, rsize) end -function slice( +@noinline function slice( x::TracedRArray{T,N}, start_indices, limit_indices; @@ -394,7 +394,7 @@ function slice( end # numerics -function complex( +@noinline function complex( real::TracedRArray{T,N}, imag::TracedRArray{T,N}; location=mlir_stacktrace("complex", @__FILE__, @__LINE__), @@ -410,7 +410,7 @@ function complex( return TracedRArray{Complex{T},N}((), res, size(real)) end -function complex( +@noinline function complex( real::TracedRNumber{T}, imag::TracedRNumber{T}; location=mlir_stacktrace("complex", @__FILE__, @__LINE__), @@ -426,7 +426,7 @@ function complex( return TracedRNumber{Complex{T}}((), res) end -function real( +@noinline function real( x::TracedRArray{Complex{T},N}; location=mlir_stacktrace("real", @__FILE__, @__LINE__) ) where {T,N} res = MLIR.IR.result( @@ -435,7 +435,7 @@ function real( return TracedRArray{T,N}((), res, size(x)) end -function real( +@noinline function real( x::TracedRNumber{Complex{T}}; location=mlir_stacktrace("real", @__FILE__, @__LINE__) ) where {T} res = MLIR.IR.result( @@ -444,7 +444,7 @@ function real( return TracedRNumber{T}((), res) end -function imag( +@noinline function imag( x::TracedRArray{Complex{T},N}; location=mlir_stacktrace("imag", @__FILE__, @__LINE__) ) where {T,N} res = MLIR.IR.result( @@ -453,7 +453,7 @@ function imag( return TracedRArray{T,N}((), res, size(x)) end -function imag( +@noinline function imag( x::TracedRNumber{Complex{T}}; location=mlir_stacktrace("imag", @__FILE__, @__LINE__) ) where {T} res = MLIR.IR.result( @@ -477,7 +477,7 @@ end # return TracedRArray{T,N}((), res, size(x)) # end -function fft( +@noinline function fft( x::TracedRArray{T,N}; type::String, length, @@ -519,7 +519,7 @@ function fft( return TracedRArray{Tout,N}((), res, rsize) end -function cholesky( +@noinline function cholesky( x::TracedRArray{T,N}; lower::Bool=false, location=mlir_stacktrace("cholesky", @__FILE__, @__LINE__), @@ -533,7 +533,7 @@ function cholesky( return TracedRArray{T,N}((), res, size(x)) end -function clamp( +@noinline function clamp( min::Union{TracedRNumber{T},TracedRArray{T,N}}, x::TracedRArray{T,N}, max::Union{TracedRNumber{T},TracedRArray{T,N}}; @@ -551,7 +551,7 @@ function clamp( return TracedRArray{T,N}((), res, size(x)) end -function clamp( +@noinline function clamp( min::TracedRNumber{T}, x::TracedRNumber{T}, max::TracedRNumber{T}; @@ -569,7 +569,7 @@ function clamp( return TracedRNumber{T}((), res) end -function clamp(min::T, x::Union{TracedRArray{T,N},TracedRNumber{T}}, max::T) where {T,N} +@noinline function clamp(min::T, x::Union{TracedRArray{T,N},TracedRNumber{T}}, max::T) where {T,N} return clamp(constant(min), x, constant(max)) end @@ -603,7 +603,7 @@ end # return TracedRArray{T,N}((), res, size(lhs)) # end -function dot_general( +@noinline function dot_general( lhs::TracedRArray{T}, rhs::TracedRArray{T}; contracting_dimensions, @@ -760,7 +760,7 @@ function dot_general( return TracedRArray{T,length(ressize)}((), res, ressize) end -function einsum( +@noinline function einsum( lhs::TracedRArray{T}, rhs::TracedRArray{T}; equation::String, @@ -818,23 +818,23 @@ end # end # paralell ops -function partition_id(; location=mlir_stacktrace("partition_id", @__FILE__, @__LINE__)) +@noinline function partition_id(; location=mlir_stacktrace("partition_id", @__FILE__, @__LINE__)) res = MLIR.IR.result(stablehlo.partition_id(; location)) return TracedRNumber{UInt32}((), res) end -function replica_id(; location=mlir_stacktrace("replica_id", @__FILE__, @__LINE__)) +@noinline function replica_id(; location=mlir_stacktrace("replica_id", @__FILE__, @__LINE__)) res = MLIR.IR.result(stablehlo.replica_id(; location)) return TracedRNumber{UInt32}((), res) end -function after_all(tokens...; location=mlir_stacktrace("after_all", @__FILE__, @__LINE__)) +@noinline function after_all(tokens...; location=mlir_stacktrace("after_all", @__FILE__, @__LINE__)) tokens = [token.mlir_data for token in tokens] res = MLIR.IR.result(stablehlo.after_all(tokens; location)) return Token(res) end -function optimization_barrier( +@noinline function optimization_barrier( operands::Union{TracedRNumber,TracedRArray}...; location=mlir_stacktrace("optimization_barrier", @__FILE__, @__LINE__), ) @@ -855,7 +855,7 @@ function optimization_barrier( ) end -function outfeed( +@noinline function outfeed( operands::Union{TracedRNumber,TracedRArray}...; token, config="", @@ -869,7 +869,7 @@ function outfeed( return Token(res) end -function send( +@noinline function send( operands::Union{TracedRNumber,TracedRArray}...; token, channel_id::Int, @@ -892,7 +892,7 @@ function send( return Token(res) end -function recv( +@noinline function recv( results::Tuple{Type,Vector{Int}}...; token, channel_id::Int, @@ -971,7 +971,7 @@ end # return TracedRArray{T,N}((), res, size(x)) # end -function top_k( +@noinline function top_k( x::TracedRArray{T,N}, k; location=mlir_stacktrace("top_k", @__FILE__, @__LINE__) ) where {T,N} rsize = [size(x)[1:(end - 1)]..., k] @@ -984,7 +984,7 @@ function top_k( ) end -function iota( +@noinline function iota( T::Type, shape::Vector{Int}; iota_dimension, @@ -997,7 +997,7 @@ function iota( return TracedRArray{T,N}((), res, shape) end -function reverse( +@noinline function reverse( x::TracedRArray{T,N}; dimensions, location=mlir_stacktrace("reverse", @__FILE__, @__LINE__), @@ -1014,7 +1014,7 @@ function reverse( end # random ops -function rng_bit_generator( +@noinline function rng_bit_generator( seed::TracedRArray{UInt64,1}, shape; algorithm::String="DEFAULT", @@ -1030,7 +1030,7 @@ function rng_bit_generator( end # functional ops -function return_( +@noinline function return_( results::Union{TracedRArray,TracedRNumber}...; location=mlir_stacktrace("return_", @__FILE__, @__LINE__), ) @@ -1038,7 +1038,7 @@ function return_( end # control flow ops -function select( +@noinline function select( pred::Union{TracedRArray{Bool,N},TracedRNumber{Bool}}, on_true::TracedRArray{T,N}, on_false::TracedRArray{T,N}, @@ -1057,7 +1057,7 @@ function select( return TracedRArray{T,N}((), res, size(on_true)) end -function select( +@noinline function select( pred::TracedRNumber{Bool}, on_true::TracedRNumber{T}, on_false::TracedRNumber{T} ) where {T} res = MLIR.IR.result( @@ -1072,20 +1072,15 @@ function select( end # comparison -function compare( - lhs::Union{TracedRArray{T},TracedRNumber{T}}, - rhs::Union{TracedRArray{T},TracedRNumber{T}}; +@noinline function compare( + lhs::AT, + rhs::AT; comparison_direction::String, compare_type=nothing, location=mlir_stacktrace("compare", @__FILE__, @__LINE__), -) where {T} + ) where {AT <: Union{TracedRArray,TracedRNumber}} @assert comparison_direction in ("EQ", "NE", "GE", "GT", "LE", "LT") @assert size(lhs) == size(rhs) - if lhs isa TracedRNumber - @assert rhs isa TracedRNumber - else - @assert rhs isa TracedRArray - end res = MLIR.IR.result( stablehlo.compare( @@ -1104,7 +1099,7 @@ function compare( end # eltype conversion -function convert( +@noinline function convert( ::Type{TracedRArray{T,N}}, x::TracedRArray; location=mlir_stacktrace("convert", @__FILE__, @__LINE__), @@ -1121,7 +1116,7 @@ function convert( ) end -function convert( +@noinline function convert( ::Type{TracedRNumber{T}}, x::TracedRNumber; location=mlir_stacktrace("convert", @__FILE__, @__LINE__), @@ -1163,7 +1158,7 @@ julia> Reactant.@jit( (ConcreteRArray{Float32, 1}(Float32[2.0, 4.0, 6.0]),) ``` """ -function hlo_call( +@noinline function hlo_call( code, args...; func_name="main", diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index a909f5a9f..44c0394cb 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -11,6 +11,7 @@ import ..AnyTracedRArray using ..TracedUtils import ..Ops import ..MLIR +import ..ancestor import ReactantCore import ..TracedUtils: materialize_traced_array @@ -313,7 +314,7 @@ function Broadcast.copy(bc::Broadcasted{<:AbstractReactantArrayStyle{0}}) return dest[CartesianIndex()] # 0D broadcast needs to unwrap results end -# Base.eltype(::Broadcast.Extruded{T}) where {T} = eltype(T) +Base.eltype(::Broadcast.Extruded{T}) where {T} = eltype(T) # we need to override the outer copy method to make sure we never fall back to scalar # iteration (see, e.g., CUDA.jl#145) From 2ff00ad3f70319f41babb894e9bf02f476f2c2fa Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Fri, 13 Dec 2024 23:39:33 -0500 Subject: [PATCH 58/78] safe prints --- src/utils.jl | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/utils.jl b/src/utils.jl index a1450e7b3..93cd80a91 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -147,6 +147,10 @@ function make_oc(sig, rt, src, nargs, isva, f)::Core.OpaqueClosure sig, rt, rt, @__MODULE__, src, 0, nothing, nargs, isva, f, true)::Core.OpaqueClosure end +function safe_print(name, x) + ccall(:jl_, Cvoid, (Any,), name*" "*string(x)) +end + # Generator function which ensures that all calls to the function are executed within the ReactantInterpreter # In particular this entails two pieces: # 1) We enforce the use of the ReactantInterpreter method table when generating the original methodinstance @@ -159,6 +163,7 @@ function call_with_reactant_generator( ) @nospecialize args = redub_arguments + safe_print("args", args) stub = Core.GeneratedFunctionStub( identity, Core.svec(:call_with_reactant, REDUB_ARGUMENTS_NAME), Core.svec() @@ -283,6 +288,8 @@ function call_with_reactant_generator( # src = Core.Compiler.codeinfo_for_const(interp, frame.linfo, rt.val) #else opt = Core.Compiler.OptimizationState(frame, interp) + + safe_print("opt.src", opt.src) caller = frame.result @static if VERSION < v"1.11-" @@ -292,6 +299,8 @@ function call_with_reactant_generator( Core.Compiler.ipo_dataflow_analysis!(interp, ir, caller) end + safe_print("ir1", ir) + # Rewrite type unstable calls to recurse into call_with_reactant to ensure # they continue to use our interpreter. Reset the derived return type # to Any if our interpreter would change the return type of any result. @@ -315,6 +324,8 @@ function call_with_reactant_generator( src = Core.Compiler.ir_to_codeinf!(opt) + safe_print("src", src) + # prepare a new code info code_info = copy(src) static_params = match.sparams @@ -452,6 +463,8 @@ function call_with_reactant_generator( code_info.ssavaluetypes = length(overdubbed_code) code_info.ssaflags = [0x00 for _ in 1:length(overdubbed_code)] # XXX we need to copy flags that are set for the original code + safe_print("code_info", code_info) + return code_info end From caad928df5babc44d631d7644e9d04cb6566bd25 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sat, 14 Dec 2024 00:57:05 -0500 Subject: [PATCH 59/78] fix --- src/TracedUtils.jl | 2 +- src/utils.jl | 31 ++++++++++++++++++++++++------- 2 files changed, 25 insertions(+), 8 deletions(-) diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index 19771318a..cebfafbf3 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -481,7 +481,7 @@ function broadcast_to_size(arg::Broadcast.Extruded, rsize) return broadcast_to_size_internal(x, rsize) end -function broadcast_to_size_internal(x::TracedRArray, rsize) +@noinline function broadcast_to_size_internal(x::TracedRArray, rsize) dims = collect(Int64, 0:(length(size(x)) - 1)) if length(size(MLIR.IR.type(x.mlir_data))) != length(dims) diff --git a/src/utils.jl b/src/utils.jl index 93cd80a91..2c62e6ff4 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -119,11 +119,25 @@ function rewrite_inst(inst, ir, interp) ft = sig.parameters[1] if should_rewrite_ft(ft) && !is_reactant_method(omi) + method = omi.def::Core.Method min_world = Ref{UInt}(typemin(UInt)) max_world = Ref{UInt}(typemax(UInt)) - - lookup_result = lookup_world(Tuple{typeof(call_with_reactant), sig.parameters...}, interp.world, Core.Compiler.method_table(interp), min_world, max_world) + + + if !method.isva || !Base.isvarargtype(sig.parameters[end]) + sig2 = Tuple{typeof(call_with_reactant), sig.parameters...} + else + vartup = inst.args[end] + ns = Type[] + eT = sig.parameters[end].T + for i in 1:(length(inst.args) - 1 - (length(sig.parameters) - 1)) + push!(ns, eT) + end + sig2 = Tuple{typeof(call_with_reactant), sig.parameters[1:end-1]..., ns...} + end + + lookup_result = lookup_world(sig2, interp.world, Core.Compiler.method_table(interp), min_world, max_world) match = lookup_result::Core.MethodMatch # look up the method and code instance @@ -135,6 +149,7 @@ function rewrite_inst(inst, ir, interp) match.spec_types, match.sparams, ) + n_method_args = method.nargs rep = Expr(:invoke, mi, call_with_reactant, inst.args[2:end]...) return true, rep end @@ -163,7 +178,7 @@ function call_with_reactant_generator( ) @nospecialize args = redub_arguments - safe_print("args", args) + # safe_print("args", args) stub = Core.GeneratedFunctionStub( identity, Core.svec(:call_with_reactant, REDUB_ARGUMENTS_NAME), Core.svec() @@ -289,7 +304,7 @@ function call_with_reactant_generator( #else opt = Core.Compiler.OptimizationState(frame, interp) - safe_print("opt.src", opt.src) + # safe_print("opt.src", opt.src) caller = frame.result @static if VERSION < v"1.11-" @@ -298,8 +313,9 @@ function call_with_reactant_generator( ir = Core.Compiler.run_passes_ipo_safe(opt.src, opt, caller) Core.Compiler.ipo_dataflow_analysis!(interp, ir, caller) end + - safe_print("ir1", ir) + # safe_print("ir1", ir) # Rewrite type unstable calls to recurse into call_with_reactant to ensure # they continue to use our interpreter. Reset the derived return type @@ -320,11 +336,12 @@ function call_with_reactant_generator( Core.Compiler.setindex!(ir.stmts[i], Any, :type) end end + Core.Compiler.finish(interp, opt, ir, caller) src = Core.Compiler.ir_to_codeinf!(opt) - safe_print("src", src) + # safe_print("src", src) # prepare a new code info code_info = copy(src) @@ -463,7 +480,7 @@ function call_with_reactant_generator( code_info.ssavaluetypes = length(overdubbed_code) code_info.ssaflags = [0x00 for _ in 1:length(overdubbed_code)] # XXX we need to copy flags that are set for the original code - safe_print("code_info", code_info) + # safe_print("code_info", code_info) return code_info end From 7547699779623470f8e40e9031ebad3c134c85bc Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sat, 14 Dec 2024 01:28:06 -0500 Subject: [PATCH 60/78] fix --- src/TracedRNumber.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index 8a1e37132..49d2ce651 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -239,7 +239,7 @@ end function Base.typed_hvcat( ::Type{T}, rows::Tuple{Vararg{Int}}, xs::TracedRNumber... ) where {T} - xs = map(Base.Fix2(broadcast_to_size, (1, 1)), xs) + xs = map(Base.Fix2(TracedUtils.broadcast_to_size, (1, 1)), xs) return Base.typed_hvcat(T, rows, xs...) end From c425ccb01de053f2e2ea45c8e9e58bf879031ae5 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sat, 14 Dec 2024 01:43:24 -0500 Subject: [PATCH 61/78] stackoverflow --- src/utils.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/utils.jl b/src/utils.jl index 2c62e6ff4..75d5fe163 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -87,6 +87,11 @@ function should_rewrite_ft(@nospecialize(ft)) return false end + # Avoid the 1.10 stackoverflow + if ft <: typeof(Base.typed_hvcat) + return false + end + # Default assume all functions need to be reactant-ified return true end From a6b52f594ef042df33daf7d54cb4d037c1079fe9 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sat, 14 Dec 2024 02:00:25 -0500 Subject: [PATCH 62/78] cleanup --- src/TracedRArray.jl | 6 +++--- src/TracedRNumber.jl | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 44c0394cb..a3dd02e59 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -88,11 +88,11 @@ end # Prevent ambiguity function Base.getindex(a::WrappedTracedRArray, index::Union{Int,TracedRNumber{Int}}...) - return getindex(ancestor(a), get_ancestor_indices(a, index...)...) + return getindex(ancestor(a), TracedUtils.get_ancestor_indices(a, index...)...) end function Base.getindex(a::WrappedTracedRArray, indices...) - return getindex(ancestor(a), get_ancestor_indices(a, indices...)...) + return getindex(ancestor(a), TracedUtils.get_ancestor_indices(a, indices...)...) end function Base.setindex!( @@ -124,7 +124,7 @@ function Base.setindex!( v, indices::Vararg{Union{Base.AbstractUnitRange,Colon,Int,TracedRNumber{Int}},N}, ) where {T,N} - ancestor_indices = get_ancestor_indices(a, indices...) + ancestor_indices = TracedUtils.get_ancestor_indices(a, indices...) setindex!(ancestor(a), v, ancestor_indices...) return a end diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index 49d2ce651..9a1578014 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -217,7 +217,7 @@ struct TypeCast{T<:ReactantPrimitive} <: Function end (::TypeCast{T})(x::TracedRNumber{T2}) where {T,T2} = TracedUtils.promote_to(TracedRNumber{T}, x) function Base.fill(x::TracedRNumber, dims::NTuple{N,Integer}) where {N} - return Reactant.broadcast_to_size(x, dims) + return TracedUtils.broadcast_to_size(x, dims) end Base.float(x::TracedRNumber{T}) where {T} = TracedUtils.promote_to(TracedRNumber{float(T)}, x) @@ -249,7 +249,7 @@ end function Base.typed_hvncat( ::Type{T}, dims::Tuple{Vararg{Int}}, row_first::Bool, xs::TracedRNumber... ) where {T} - xs = map(Base.Fix2(broadcast_to_size, (1, 1)), xs) + xs = map(Base.Fix2(TracedUtils.broadcast_to_size, (1, 1)), xs) return Base.typed_hvncat(T, dims, row_first, xs...) end From b17e75fb477d5f991ee5e0d4fc538b9a6b0fd0f2 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sat, 14 Dec 2024 02:04:27 -0500 Subject: [PATCH 63/78] dyindex --- src/TracedRArray.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index a3dd02e59..4223d6cee 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -14,6 +14,7 @@ import ..MLIR import ..ancestor import ReactantCore import ..TracedUtils: materialize_traced_array +import GPUArraysCore ReactantCore.is_traced(::TracedRArray) = true From 2cff76e6277c691ff09a8ecb22a1dd8f913e3a80 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sat, 14 Dec 2024 02:16:09 -0500 Subject: [PATCH 64/78] rt --- src/utils.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/utils.jl b/src/utils.jl index 75d5fe163..bbcc14535 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -423,6 +423,7 @@ function call_with_reactant_generator( end rt = Base.Experimental.compute_ir_rettype(ir) + @assert code_info.rettype == rt # ocva = method.isva From 1d0cb8eef7158b90bad946a50199d46a1e52729c Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sat, 14 Dec 2024 02:24:13 -0500 Subject: [PATCH 65/78] continue --- test/basic.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/basic.jl b/test/basic.jl index 75859122c..5eff286ed 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -640,7 +640,7 @@ end function f_row_major(x) y = [1 2; 3 4; 5 6] if x isa Reactant.TracedRArray - y = Reactant.promote_to(Reactant.TracedRArray{eltype(x),2}, y) + y = Reactant.TracedUtils.promote_to(Reactant.TracedRArray{eltype(x),2}, y) end return x .+ y end From f4349a9ff19fd34535ddd71d161ab348966a1c96 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sat, 14 Dec 2024 03:01:47 -0500 Subject: [PATCH 66/78] clean --- src/utils.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index bbcc14535..a5fccba50 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -91,6 +91,9 @@ function should_rewrite_ft(@nospecialize(ft)) if ft <: typeof(Base.typed_hvcat) return false end + if ft <: typeof(Base.hvcat) + return false + end # Default assume all functions need to be reactant-ified return true @@ -423,7 +426,6 @@ function call_with_reactant_generator( end rt = Base.Experimental.compute_ir_rettype(ir) - @assert code_info.rettype == rt # ocva = method.isva From a4ae31a6729b0445be68789a43f366cdea0602d9 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sat, 14 Dec 2024 03:28:11 -0500 Subject: [PATCH 67/78] fix --- src/linear_algebra.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/linear_algebra.jl b/src/linear_algebra.jl index 61d82dae1..04bef981c 100644 --- a/src/linear_algebra.jl +++ b/src/linear_algebra.jl @@ -64,10 +64,10 @@ function LinearAlgebra.mul!( ) res = if iszero(β) - isone(α) ? tmp : Ops.multiply(tmp, broadcast_to_size(T(α), size(C))) + isone(α) ? tmp : Ops.multiply(tmp, TracedUtils.broadcast_to_size(T(α), size(C))) else - α_res = Ops.multiply(tmp, broadcast_to_size(T(α), size(C))) - β_C = Ops.multiply(C, broadcast_to_size(T(β), size(C))) + α_res = Ops.multiply(tmp, TracedUtils.broadcast_to_size(T(α), size(C))) + β_C = Ops.multiply(C, TracedUtils.broadcast_to_size(T(β), size(C))) Ops.add(α_res, β_C) end set_mlir_data!(C, get_mlir_data(res)) @@ -77,7 +77,7 @@ end function LinearAlgebra.triu!(@nospecialize(X::TracedRArray{T,2}), k::Integer) where {T} iota_1 = Ops.iota(Int64, [size(X)...]; iota_dimension=1) iota_2 = Ops.subtract( - Ops.iota(Int64, [size(X)...]; iota_dimension=2), broadcast_to_size(k, size(X)) + Ops.iota(Int64, [size(X)...]; iota_dimension=2), TracedUtils.broadcast_to_size(k, size(X)) ) idxs = Ops.compare(iota_1, iota_2; comparison_direction="LE") X.mlir_data = Ops.select(idxs, X, zero(X)).mlir_data @@ -87,7 +87,7 @@ end function LinearAlgebra.tril!(@nospecialize(X::TracedRArray{T,2}), k::Integer) where {T} iota_1 = Ops.iota(Int64, [size(X)...]; iota_dimension=1) iota_2 = Ops.subtract( - Ops.iota(Int64, [size(X)...]; iota_dimension=2), broadcast_to_size(k, size(X)) + Ops.iota(Int64, [size(X)...]; iota_dimension=2), TracedUtils.broadcast_to_size(k, size(X)) ) idxs = Ops.compare(iota_1, iota_2; comparison_direction="GE") X.mlir_data = Ops.select(idxs, X, zero(X)).mlir_data From 0dbe20f3f17bd7656777aa1231342cebf481206a Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sat, 14 Dec 2024 03:30:44 -0500 Subject: [PATCH 68/78] fix --- src/TracedRArray.jl | 2 +- src/TracedRNumber.jl | 4 ---- src/TracedUtils.jl | 5 +++++ test/complex.jl | 2 +- 4 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 4223d6cee..d8ba9b92f 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -157,7 +157,7 @@ Base.conj(A::AnyTracedRArray{<:Complex}) = Ops.conj(materialize_traced_array(A)) Base.conj!(A::AnyTracedRArray) = A function Base.conj!(A::AnyTracedRArray{<:Complex}) - set_mlir_data!(A, Ops.conj(materialize_traced_array(A)).mlir_data) + TracedUtils.set_mlir_data!(A, Ops.conj(materialize_traced_array(A)).mlir_data) return A end diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index 9a1578014..88f6f9e9b 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -212,10 +212,6 @@ for (minT, maxT) in Iterators.product((Number, TracedRNumber), (Number, TracedRN @eval Base.clamp(x::TracedRNumber, min::$(minT), max::$(maxT)) = Ops.clamp(min, x, max) end -struct TypeCast{T<:ReactantPrimitive} <: Function end - -(::TypeCast{T})(x::TracedRNumber{T2}) where {T,T2} = TracedUtils.promote_to(TracedRNumber{T}, x) - function Base.fill(x::TracedRNumber, dims::NTuple{N,Integer}) where {N} return TracedUtils.broadcast_to_size(x, dims) end diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index cebfafbf3..51d25a29f 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -291,6 +291,11 @@ function make_mlir_fn( end elem_apply(::Type{T}, x::TracedRArray{T}) where {T<:ReactantPrimitive} = x + +struct TypeCast{T<:ReactantPrimitive} <: Function end + +(::TypeCast{T})(x::TracedRNumber{T2}) where {T,T2} = TracedUtils.promote_to(TracedRNumber{T}, x) + function elem_apply( ::Type{T}, x::TracedRArray{T2} ) where {T<:ReactantPrimitive,T2<:ReactantPrimitive} diff --git a/test/complex.jl b/test/complex.jl index 3bf19a051..43e3c4f6b 100644 --- a/test/complex.jl +++ b/test/complex.jl @@ -92,7 +92,7 @@ end y = Reactant.ConcreteRNumber(x) f = Reactant.compile((y,)) do z - z + Reactant.promote_to(Reactant.TracedRNumber{ComplexF64}, 1.0 - 3.0im) + z + Reactant.TracedUtils.promote_to(Reactant.TracedRNumber{ComplexF64}, 1.0 - 3.0im) end @test isapprox(f(y), 2.0 - 1.0im) From 1c45d7e3ec25f6c971828c966ea7bcbccb1435ea Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sat, 14 Dec 2024 03:47:39 -0500 Subject: [PATCH 69/78] fix --- src/linear_algebra.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/linear_algebra.jl b/src/linear_algebra.jl index 04bef981c..80929cdec 100644 --- a/src/linear_algebra.jl +++ b/src/linear_algebra.jl @@ -2,6 +2,7 @@ module TracedLinearAlgebra using ..Reactant import ..TracedRArray +import ..TracedRNumber import ..AnyTracedRArray import ..AnyTracedRMatrix import ..AnyTracedRVector From 1ffb3666ee32cf6559c4cc8d8cd16e310beb20a5 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sat, 14 Dec 2024 03:52:11 -0500 Subject: [PATCH 70/78] fix --- lib/ReactantCore/src/ReactantCore.jl | 4 ++-- src/ControlFlow.jl | 8 ++++---- src/linear_algebra.jl | 5 +++-- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/lib/ReactantCore/src/ReactantCore.jl b/lib/ReactantCore/src/ReactantCore.jl index 32c663fce..5b865dd34 100644 --- a/lib/ReactantCore/src/ReactantCore.jl +++ b/lib/ReactantCore/src/ReactantCore.jl @@ -153,7 +153,7 @@ function trace_for(mod, expr) all_syms = Expr(:tuple, counter, external_syms...) args_init = Expr( - :tuple, :(Reactant.promote_to(Reactant.TracedRNumber{Int}, 0)), external_syms... + :tuple, :(Reactant.TracedUtils.promote_to(Reactant.TracedRNumber{Int}, 0)), external_syms... ) reactant_code_block = quote @@ -161,7 +161,7 @@ function trace_for(mod, expr) cond_fn = $(all_syms) -> begin local num_iters = div($limit - $start, $step, RoundDown) - local num_iters = Reactant.promote_to( + local num_iters = Reactant.TracedUtils.promote_to( Reactant.TracedRNumber{Int64}, num_iters ) $counter < num_iters + 1 diff --git a/src/ControlFlow.jl b/src/ControlFlow.jl index 3035e90c3..0209b62b3 100644 --- a/src/ControlFlow.jl +++ b/src/ControlFlow.jl @@ -1,7 +1,7 @@ function ReactantCore.traced_if( cond::TracedRNumber{Bool}, true_fn::TFn, false_fn::FFn, args ) where {TFn,FFn} - (_, true_branch_compiled, true_branch_results, _, _, _, _, _, true_linear_results) = Reactant.make_mlir_fn( + (_, true_branch_compiled, true_branch_results, _, _, _, _, _, true_linear_results) = Reactant.TracedUtils.make_mlir_fn( true_fn, args, (), @@ -12,7 +12,7 @@ function ReactantCore.traced_if( construct_function_without_args=true, ) - (_, false_branch_compiled, false_branch_results, _, _, _, _, _, false_linear_results) = Reactant.make_mlir_fn( + (_, false_branch_compiled, false_branch_results, _, _, _, _, _, false_linear_results) = Reactant.TracedUtils.make_mlir_fn( false_fn, args, (), @@ -88,7 +88,7 @@ function ReactantCore.traced_while( end for v in args ] - (_, cond_fn_compiled, cond_fn_results, _, _, _, _, in_tys, cond_fn_linear_results) = Reactant.make_mlir_fn( + (_, cond_fn_compiled, cond_fn_results, _, _, _, _, in_tys, cond_fn_linear_results) = Reactant.TracedUtils.make_mlir_fn( cond_fn, traced_args, (), @@ -99,7 +99,7 @@ function ReactantCore.traced_while( do_transpose=false, ) - (_, body_fn_compiled, body_fn_results, _, _, _, _, _, body_fn_linear_results) = Reactant.make_mlir_fn( + (_, body_fn_compiled, body_fn_results, _, _, _, _, _, body_fn_linear_results) = Reactant.TracedUtils.make_mlir_fn( body_fn, traced_args, (), diff --git a/src/linear_algebra.jl b/src/linear_algebra.jl index 80929cdec..a129e89e3 100644 --- a/src/linear_algebra.jl +++ b/src/linear_algebra.jl @@ -7,6 +7,7 @@ import ..AnyTracedRArray import ..AnyTracedRMatrix import ..AnyTracedRVector +import ..TracedUtils using ..TracedUtils: get_mlir_data, materialize_traced_array, @@ -116,9 +117,9 @@ function LinearAlgebra.diag(x::AnyTracedRArray{T,2}, k::Integer=0) where {T} # terminate called after throwing an instance of 'xla::XlaRuntimeError' # what(): UNKNOWN: :0: error: 'tensor.empty' op unsupported op for export to XLA # :0: note: see current operation: %0 = "tensor.empty"() : () -> tensor<0xf64> - length(indices) ≤ 0 && return promote_to(TracedRArray{T,1}, T[]) + length(indices) ≤ 0 && return TracedUtils.promote_to(TracedRArray{T,1}, T[]) - idxs = get_mlir_data(Reactant.promote_to(TracedRArray{Int,2}, indices)) + idxs = get_mlir_data(TracedUtils.promote_to(TracedRArray{Int,2}, indices)) #! format: off dimension_numbers = MLIR.API.stablehloGatherDimensionNumbersGet( From 3887575ddc118af65572d861cc44707f752ed25f Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sat, 14 Dec 2024 03:52:56 -0500 Subject: [PATCH 71/78] fixup --- Project.toml | 2 +- lib/ReactantCore/Project.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index dd4d67325..9af7dafef 100644 --- a/Project.toml +++ b/Project.toml @@ -48,7 +48,7 @@ LinearAlgebra = "1.10" NNlib = "0.9.26" OrderedCollections = "1" Preferences = "1.4" -ReactantCore = "0.1.2" +ReactantCore = "0.1.3" Reactant_jll = "0.0.26" Scratch = "1.2" Statistics = "1.10" diff --git a/lib/ReactantCore/Project.toml b/lib/ReactantCore/Project.toml index a11f5c66a..bec50b45e 100644 --- a/lib/ReactantCore/Project.toml +++ b/lib/ReactantCore/Project.toml @@ -1,7 +1,7 @@ name = "ReactantCore" uuid = "a3311ec8-5e00-46d5-b541-4f83e724a433" authors = ["William Moses ", "Valentin Churavy ", "Sergio Sánchez Ramírez ", "Paul Berg ", "Avik Pal "] -version = "0.1.2" +version = "0.1.3" [deps] ExpressionExplorer = "21656369-7473-754a-2065-74616d696c43" From 873e46bf8561c3a299f6880117db5b5ee76b4e4c Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sat, 14 Dec 2024 04:28:05 -0500 Subject: [PATCH 72/78] fix --- ext/ReactantNNlibExt.jl | 12 ++++++------ src/ControlFlow.jl | 2 +- src/Tracing.jl | 4 ++-- src/linear_algebra.jl | 3 ++- 4 files changed, 11 insertions(+), 10 deletions(-) diff --git a/ext/ReactantNNlibExt.jl b/ext/ReactantNNlibExt.jl index 65f5677b9..4325fdf08 100644 --- a/ext/ReactantNNlibExt.jl +++ b/ext/ReactantNNlibExt.jl @@ -241,9 +241,9 @@ function NNlib.batched_mul!( if size(x, 3) != size(y, 3) B = max(size(x, 3), size(y, 3)) if size(x, 3) == 1 - x = Reactant.broadcast_to_size(x, (size(x, 1), size(x, 2), B)) + x = Reactant.TracedUtils.broadcast_to_size(x, (size(x, 1), size(x, 2), B)) elseif size(y, 3) == 1 - y = Reactant.broadcast_to_size(y, (size(y, 1), size(y, 2), B)) + y = Reactant.TracedUtils.broadcast_to_size(y, (size(y, 1), size(y, 2), B)) end end @@ -253,9 +253,9 @@ function NNlib.batched_mul!( if size(x, 1) != size(y, 1) B = max(size(x, 1), size(y, 1)) if size(x, 1) == 1 - x = Reactant.broadcast_to_size(x, (B, size(x, 2), size(x, 3))) + x = Reactant.TracedUtils.broadcast_to_size(x, (B, size(x, 2), size(x, 3))) elseif size(y, 1) == 1 - y = Reactant.broadcast_to_size(y, (B, size(y, 2), size(y, 3))) + y = Reactant.TracedUtils.broadcast_to_size(y, (B, size(y, 2), size(y, 3))) end end @@ -273,7 +273,7 @@ end function NNlib.pad_constant( x::AnyTracedRArray{T,N}, pad::NTuple{N,Tuple{Int,Int}}, value ) where {T,N} - value = Reactant.promote_to(TracedRNumber{T}, value) + value = Reactant.TracedUtils.promote_to(TracedRNumber{T}, value) low = [i[1] for i in pad] high = [i[2] for i in pad] interior = [0 for i in pad] @@ -332,7 +332,7 @@ function NNlib.gather!(dst::TracedRArray, src::AnyTracedRArray, idxs::AbstractAr start_sizes = ntuple(i -> size(src, i), dims) results = map(CartesianIndices(idxs)) do k res = @allowscalar src[colons..., Tuple(idxs[k])...] - res isa TracedRNumber && (res = Reactant.broadcast_to_size(res, (1,))) + res isa TracedRNumber && (res = Reactant.TracedUtils.broadcast_to_size(res, (1,))) return reshape(res, start_sizes..., :) end res = reshape(cat(results...; dims=(dims + 1)), size(dst)) diff --git a/src/ControlFlow.jl b/src/ControlFlow.jl index 0209b62b3..0e0c00195 100644 --- a/src/ControlFlow.jl +++ b/src/ControlFlow.jl @@ -82,7 +82,7 @@ function ReactantCore.traced_while( # We promote all incoming args (is there a better way to do this?) traced_args = [ if v isa Number && !(v isa TracedType) - Reactant.promote_to(TracedRNumber{typeof(v)}, v) + Reactant.TracedUtils.promote_to(TracedRNumber{typeof(v)}, v) else v end for v in args diff --git a/src/Tracing.jl b/src/Tracing.jl index 4ea8172aa..bb7116eb9 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -498,14 +498,14 @@ function make_tracer( return ConcreteRNumber(prev) else if mode == TracedTrack - res = TracedRNumber{RT}((path,), broadcast_to_size(prev, ()).mlir_data) + res = TracedRNumber{RT}((path,), TracedUtils.broadcast_to_size(prev, ()).mlir_data) if !haskey(seen, prev) return seen[prev] = res end return res elseif mode == TracedSetPath haskey(seen, prev) && return seen[prev] - res = TracedRNumber{RT}((path,), broadcast_to_size(prev, ()).mlir_data) + res = TracedRNumber{RT}((path,), TracedUtils.broadcast_to_size(prev, ()).mlir_data) seen[prev] = res return res elseif mode == TracedToConcrete diff --git a/src/linear_algebra.jl b/src/linear_algebra.jl index a129e89e3..f54f19f6c 100644 --- a/src/linear_algebra.jl +++ b/src/linear_algebra.jl @@ -14,6 +14,7 @@ using ..TracedUtils: set_mlir_data! import ..Ops +import ..MLIR using LinearAlgebra function LinearAlgebra.mul!( @@ -157,7 +158,7 @@ function LinearAlgebra.diagm(m::Integer, n::Integer, v::AnyTracedRArray{T,1}) wh mat = (v .+ zero(v)') .* diag_indicator return Ops.pad( - mat, promote_to(TracedRNumber{T}, 0); high=[m - length(v), n - length(v)] + mat, TracedUtils.promote_to(TracedRNumber{T}, 0); high=[m - length(v), n - length(v)] ) end From 21244db79767b4660af3b665c76c2c2788a99fb5 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sat, 14 Dec 2024 05:00:38 -0500 Subject: [PATCH 73/78] fix --- src/linear_algebra.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/linear_algebra.jl b/src/linear_algebra.jl index f54f19f6c..5cb10a861 100644 --- a/src/linear_algebra.jl +++ b/src/linear_algebra.jl @@ -134,7 +134,7 @@ function LinearAlgebra.diag(x::AnyTracedRArray{T,2}, k::Integer=0) where {T} ) #! format: on - slice_sizes = get_mlir_data(Reactant.promote_to(TracedRArray{Int,1}, [1, 1])) + slice_sizes = get_mlir_data(Reactant.TracedUtils.promote_to(TracedRArray{Int,1}, [1, 1])) res = MLIR.IR.result( MLIR.Dialects.stablehlo.dynamic_gather( get_mlir_data(y), idxs, slice_sizes; dimension_numbers From 70c3951e4c5b2ec1b6f7c43bba4eccbc9284d50b Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sat, 14 Dec 2024 10:42:52 -0500 Subject: [PATCH 74/78] capture oc --- src/utils.jl | 68 ++++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 53 insertions(+), 15 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index a5fccba50..d21ff8a64 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -165,15 +165,29 @@ function rewrite_inst(inst, ir, interp) return false, inst end -function make_oc(sig, rt, src, nargs, isva, f)::Core.OpaqueClosure - ccall(:jl_new_opaque_closure_from_code_info, Any, (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any, Cint), +const oc_captures = Dict{Tuple{Type, Type, Core.CodeInfo, Int, Bool, Any}, Core.OpaqueClosure}() + +# Caching is both good to reducing compile times and necessary to work around julia bugs +# in OpaqueClosure's: https://github.com/JuliaLang/julia/issues/56833 +function make_oc(sig::Type, rt::Type, src::Core.CodeInfo, nargs::Int, isva::Bool, f::Any)::Core.OpaqueClosure + key = (sig, rt, src, nargs, isva, f) + if haskey(oc_captures, key) + return oc_captures[key] + else + ores = ccall(:jl_new_opaque_closure_from_code_info, Any, (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any, Cint), sig, rt, rt, @__MODULE__, src, 0, nothing, nargs, isva, f, true)::Core.OpaqueClosure + oc_captures[key] = ores + return ores + end end function safe_print(name, x) ccall(:jl_, Cvoid, (Any,), name*" "*string(x)) end +const DEBUG_INTERP = Ref(false) + + # Generator function which ensures that all calls to the function are executed within the ReactantInterpreter # In particular this entails two pieces: # 1) We enforce the use of the ReactantInterpreter method table when generating the original methodinstance @@ -186,7 +200,9 @@ function call_with_reactant_generator( ) @nospecialize args = redub_arguments - # safe_print("args", args) + if DEBUG_INTERP[] + safe_print("args", args) + end stub = Core.GeneratedFunctionStub( identity, Core.svec(:call_with_reactant, REDUB_ARGUMENTS_NAME), Core.svec() @@ -293,12 +309,12 @@ function call_with_reactant_generator( ) result = Core.Compiler.InferenceResult(mi, Core.Compiler.typeinf_lattice(interp)) - frame = Core.Compiler.InferenceState(result, :local, interp) #=cache_mode=# + frame = Core.Compiler.InferenceState(result, VERSION < v"1.11-" ? :local : :global, interp) #=cache_mode=# @assert frame !== nothing Core.Compiler.typeinf(interp, frame) @static if VERSION >= v"1.11" # `typeinf` doesn't update the cfg. We need to do it manually. - frame.cfg = Core.Compiler.compute_basic_blocks(frame.src.code) + # frame.cfg = Core.Compiler.compute_basic_blocks(frame.src.code) end @assert Core.Compiler.is_inferred(frame) @@ -310,21 +326,28 @@ function call_with_reactant_generator( # rt = frame.result.result::Core.Compiler.Const # src = Core.Compiler.codeinfo_for_const(interp, frame.linfo, rt.val) #else + # opt = Core.Compiler.OptimizationState(frame, interp) - - # safe_print("opt.src", opt.src) + + if DEBUG_INTERP[] + safe_print("opt.src", opt.src) + end caller = frame.result @static if VERSION < v"1.11-" ir = Core.Compiler.run_passes(opt.src, opt, caller) else ir = Core.Compiler.run_passes_ipo_safe(opt.src, opt, caller) - Core.Compiler.ipo_dataflow_analysis!(interp, ir, caller) + @static if VERSION < v"1.12-" + else + Core.Compiler.ipo_dataflow_analysis!(interp, ir, caller) + end end - - - # safe_print("ir1", ir) + if DEBUG_INTERP[] + safe_print("ir1", ir) + end + # Rewrite type unstable calls to recurse into call_with_reactant to ensure # they continue to use our interpreter. Reset the derived return type # to Any if our interpreter would change the return type of any result. @@ -349,8 +372,10 @@ function call_with_reactant_generator( src = Core.Compiler.ir_to_codeinf!(opt) - # safe_print("src", src) - + if DEBUG_INTERP[] + safe_print("src", src) + end + # prepare a new code info code_info = copy(src) static_params = match.sparams @@ -400,6 +425,11 @@ function call_with_reactant_generator( offset += 1 push!(fn_args, Core.SSAValue(length(overdubbed_code))) push!(tys, redub_arguments[i]) + + if DEBUG_INTERP[] + push!(overdubbed_code, Expr(:call, safe_print, "fn arg["*string(length(fn_args))*"]", fn_args[end])) + push!(overdubbed_codelocs, code_info.codelocs[1]) + end end @@ -423,6 +453,11 @@ function call_with_reactant_generator( push!(overdubbed_codelocs, code_info.codelocs[1]) push!(fn_args, Core.SSAValue(length(overdubbed_code))) push!(tys, Tuple{redub_arguments[n_method_args:n_actual_args]...}) + + if DEBUG_INTERP[] + push!(overdubbed_code, Expr(:call, safe_print, "fn arg["*string(length(fn_args))*"]", fn_args[end])) + push!(overdubbed_codelocs, code_info.codelocs[1]) + end end rt = Base.Experimental.compute_ir_rettype(ir) @@ -442,7 +477,8 @@ function call_with_reactant_generator( # Opaque closures also require takign the function argument. We can work around the latter # if the function is stateless. But regardless, to work around this we sadly create/compile the opaque closure oc = if false && Base.issingletontype(args[1]) - Core._call_in_world_total(world, make_oc, octup, rt, src, ocnargs, ocva, args[1].instance)::Core.OpaqueClosure + res = Core._call_in_world_total(world, make_oc, octup, rt, src, ocnargs, ocva, args[1].instance)::Core.OpaqueClosure + else farg = fn_args[1] push!(overdubbed_code, @@ -488,7 +524,9 @@ function call_with_reactant_generator( code_info.ssavaluetypes = length(overdubbed_code) code_info.ssaflags = [0x00 for _ in 1:length(overdubbed_code)] # XXX we need to copy flags that are set for the original code - # safe_print("code_info", code_info) + if DEBUG_INTERP[] + safe_print("code_info", code_info) + end return code_info end From f839a0b09caab77d8224ff6c61e2213f85ab721e Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sat, 14 Dec 2024 11:29:07 -0500 Subject: [PATCH 75/78] compile perf --- src/TracedRArray.jl | 48 +++++++++++++++++++++++++++++---------------- src/utils.jl | 22 ++++++++++++++++++++- 2 files changed, 52 insertions(+), 18 deletions(-) diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index d8ba9b92f..e925b00f3 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -211,36 +211,50 @@ function Base.mapreduce( inp = [broadcast(f, A).mlir_data] - rdims = if dims == (:) - Int64[i for i in 0:(N - 1)] + rdims = Int64[] + + if dims == (:) + for i in 0:(N-1) + push!(rdims, i) + end else - Int64[i - 1 for i in dims] + for i in dims + push!(rdims, i-1) + end end in_tys = [ - MLIR.IR.TensorType(Int64[], eltype(MLIR.IR.type(arg))) for arg in (inp[1], init[1]) + MLIR.IR.TensorType(Int64[], eltype(MLIR.IR.type(inp[1]))), + MLIR.IR.TensorType(Int64[], eltype(MLIR.IR.type(init[1]))), ] - fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location() for arg in in_tys]) + fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location(), MLIR.IR.Location()]) args = ( - TracedRNumber{op_in_T}((), MLIR.IR.argument(fnbody, i)) for - (i, ty) in enumerate(in_tys) + TracedRNumber{op_in_T}((), MLIR.IR.argument(fnbody, 1)), + TracedRNumber{op_in_T}((), MLIR.IR.argument(fnbody, 2)), ) - res = MLIR.IR.block!(fnbody) do - tmp = TracedUtils.broadcast_to_size(op(args...), ()).mlir_data - MLIR.Dialects.stablehlo.return_(MLIR.IR.Value[tmp]) - return tmp + resty = MLIR.IR.block!(fnbody) do + tmp = TracedUtils.broadcast_to_size(op(args...), ()) + Ops.return_(tmp) + return eltype(MLIR.IR.type(tmp.mlir_data)) end - toonedims = [(in(i - 1, rdims) ? 1 : size(A, i)) for i in 1:N] - outdims = [size(A, i) for i in 1:N if (i - 1) ∉ rdims] + toonedims = Int[] + outdims = Int[] + for i in 1:N + tmp = if in(i-1, rdims) + 1 + else + sz = size(A, i) + push!(outdims, sz) + sz + end + push!(toonedims, tmp) + end - TT = [ - MLIR.IR.TensorType(outdims, eltype(MLIR.IR.type(inp0))) for - (inp0, res0) in zip(inp, (res,)) - ] + TT = MLIR.IR.Type[MLIR.IR.TensorType(outdims, resty)] body = MLIR.IR.Region() push!(body, fnbody) diff --git a/src/utils.jl b/src/utils.jl index d21ff8a64..5b075ab97 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -95,6 +95,21 @@ function should_rewrite_ft(@nospecialize(ft)) return false end + # Don't rewrite traced constructors + if ft <: Type{<:TracedRArray} || ft <: Type{<:TracedRNumber} || ft === Type{MLIR.IR.Location} || ft === Type{MLIR.IR.Block} + return false + end + + # Perf optimizations + if ft <: typeof(Core.Compiler.return_type) + return false + end + + # Perf optimizations + if ft <: typeof(Base.typemax) || ft <: typeof(Base.typemin) || ft <: typeof(Base.getproperty) || ft <: typeof(Base.vect) || ft <: typeof(Base.eltype) + return false + end + # Default assume all functions need to be reactant-ified return true end @@ -116,6 +131,9 @@ function rewrite_inst(inst, ir, interp) # Even if type unstable we do not want (or need) to replace intrinsic # calls or builtins with our version. ft = Core.Compiler.widenconst(maybe_argextype(inst.args[1], ir)) + if ft == typeof(Core.kwcall) + ft = Core.Compiler.widenconst(maybe_argextype(inst.args[3], ir)) + end if should_rewrite_ft(ft) rep = Expr(:call, call_with_reactant, inst.args...) return true, rep @@ -125,7 +143,9 @@ function rewrite_inst(inst, ir, interp) omi = inst.args[1]::Core.MethodInstance sig = omi.specTypes ft = sig.parameters[1] - + if ft == typeof(Core.kwcall) + ft = sig.parameters[3] + end if should_rewrite_ft(ft) && !is_reactant_method(omi) method = omi.def::Core.Method From 8073ecdafa139acf84d81ae59a6bbf3d30bab278 Mon Sep 17 00:00:00 2001 From: jumerckx Date: Sat, 14 Dec 2024 20:28:08 +0100 Subject: [PATCH 76/78] v1.11 fix --- src/utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index 5b075ab97..1239eb87d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -329,7 +329,7 @@ function call_with_reactant_generator( ) result = Core.Compiler.InferenceResult(mi, Core.Compiler.typeinf_lattice(interp)) - frame = Core.Compiler.InferenceState(result, VERSION < v"1.11-" ? :local : :global, interp) #=cache_mode=# + frame = Core.Compiler.InferenceState(result, VERSION < v"1.11-" ? :no : :global, interp) #=cache_mode=# @assert frame !== nothing Core.Compiler.typeinf(interp, frame) @static if VERSION >= v"1.11" From 585c4850a5818269fd2fba42fa1f1a8edc6adfe9 Mon Sep 17 00:00:00 2001 From: jumerckx Date: Sat, 14 Dec 2024 20:31:08 +0100 Subject: [PATCH 77/78] other way 'round --- src/utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index 1239eb87d..1d76c638a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -329,7 +329,7 @@ function call_with_reactant_generator( ) result = Core.Compiler.InferenceResult(mi, Core.Compiler.typeinf_lattice(interp)) - frame = Core.Compiler.InferenceState(result, VERSION < v"1.11-" ? :no : :global, interp) #=cache_mode=# + frame = Core.Compiler.InferenceState(result, VERSION < v"1.11-" ? :local : :no, interp) #=cache_mode=# @assert frame !== nothing Core.Compiler.typeinf(interp, frame) @static if VERSION >= v"1.11" From 0c56d358ac9ab772c692c19a58bad24b88d2b246 Mon Sep 17 00:00:00 2001 From: jumerckx Date: Sat, 14 Dec 2024 20:38:55 +0100 Subject: [PATCH 78/78] formatting --- ext/ReactantNNlibExt.jl | 18 +-- lib/ReactantCore/src/ReactantCore.jl | 4 +- src/Compiler.jl | 4 +- src/Interpreter.jl | 67 ++++---- src/Ops.jl | 26 ++-- src/Reactant.jl | 1 - src/TracedRArray.jl | 32 ++-- src/TracedRNumber.jl | 31 +++- src/TracedUtils.jl | 30 +++- src/Tracing.jl | 8 +- src/linear_algebra.jl | 19 ++- src/utils.jl | 221 +++++++++++++++++---------- 12 files changed, 279 insertions(+), 182 deletions(-) diff --git a/ext/ReactantNNlibExt.jl b/ext/ReactantNNlibExt.jl index 4325fdf08..f85bd1d84 100644 --- a/ext/ReactantNNlibExt.jl +++ b/ext/ReactantNNlibExt.jl @@ -2,18 +2,9 @@ module ReactantNNlibExt using NNlib using GPUArraysCore: @allowscalar -using Reactant: - Reactant, - Ops, - TracedRArray, - AnyTracedRArray, - MLIR, - TracedRNumber - -using Reactant.TracedUtils: - materialize_traced_array, - get_mlir_data, - set_mlir_data! +using Reactant: Reactant, Ops, TracedRArray, AnyTracedRArray, MLIR, TracedRNumber + +using Reactant.TracedUtils: materialize_traced_array, get_mlir_data, set_mlir_data! using ReactantCore: @trace using LinearAlgebra: LinearAlgebra, triu @@ -332,7 +323,8 @@ function NNlib.gather!(dst::TracedRArray, src::AnyTracedRArray, idxs::AbstractAr start_sizes = ntuple(i -> size(src, i), dims) results = map(CartesianIndices(idxs)) do k res = @allowscalar src[colons..., Tuple(idxs[k])...] - res isa TracedRNumber && (res = Reactant.TracedUtils.broadcast_to_size(res, (1,))) + res isa TracedRNumber && + (res = Reactant.TracedUtils.broadcast_to_size(res, (1,))) return reshape(res, start_sizes..., :) end res = reshape(cat(results...; dims=(dims + 1)), size(dst)) diff --git a/lib/ReactantCore/src/ReactantCore.jl b/lib/ReactantCore/src/ReactantCore.jl index 5b865dd34..f99d6cab9 100644 --- a/lib/ReactantCore/src/ReactantCore.jl +++ b/lib/ReactantCore/src/ReactantCore.jl @@ -153,7 +153,9 @@ function trace_for(mod, expr) all_syms = Expr(:tuple, counter, external_syms...) args_init = Expr( - :tuple, :(Reactant.TracedUtils.promote_to(Reactant.TracedRNumber{Int}, 0)), external_syms... + :tuple, + :(Reactant.TracedUtils.promote_to(Reactant.TracedRNumber{Int}, 0)), + external_syms..., ) reactant_code_block = quote diff --git a/src/Compiler.jl b/src/Compiler.jl index 44e4f59d9..5f7158d82 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -780,7 +780,9 @@ function compile(f, args; client=nothing, optimize=true, sync=false) end # Compiling within a compile should return simply the original function -Reactant.@reactant_override function Reactant.Compiler.compile(f, args; client=nothing, optimize=true, sync=false) +Reactant.@reactant_override function Reactant.Compiler.compile( + f, args; client=nothing, optimize=true, sync=false +) return f end diff --git a/src/Interpreter.jl b/src/Interpreter.jl index b3aa5b700..72e27c5d8 100644 --- a/src/Interpreter.jl +++ b/src/Interpreter.jl @@ -42,18 +42,8 @@ function set_reactant_abi( # Improve inference by considering call_with_reactant as having the same results as # the original call if f === Reactant.call_with_reactant - arginfo2 = ArgInfo( - fargs isa Nothing ? nothing : - fargs[2:end], - argtypes[2:end], - ) - return abstract_call( - interp, - arginfo2::ArgInfo, - si, - sv, - max_methods, - ) + arginfo2 = ArgInfo(fargs isa Nothing ? nothing : fargs[2:end], argtypes[2:end]) + return abstract_call(interp, arginfo2::ArgInfo, si, sv, max_methods) end return Base.@invoke abstract_call_known( @@ -280,7 +270,12 @@ function overload_autodiff( if width == 1 push!(outtys, TracedUtils.transpose_ty(MLIR.IR.type(a.mlir_data))) else - push!(outtys, TracedUtils.batch_ty(width, TracedUtils.transpose_ty(MLIR.IR.type(a.mlir_data)))) + push!( + outtys, + TracedUtils.batch_ty( + width, TracedUtils.transpose_ty(MLIR.IR.type(a.mlir_data)) + ), + ) end end else @@ -393,13 +388,21 @@ function overload_autodiff( else idx, path = TracedUtils.get_argidx(a) if idx == 1 && fnwrap - TracedUtils.set!(f.val, path[3:end], TracedUtils.transpose_val(MLIR.IR.result(res, residx))) + TracedUtils.set!( + f.val, + path[3:end], + TracedUtils.transpose_val(MLIR.IR.result(res, residx)), + ) residx += 1 else if fnwrap idx -= 1 end - TracedUtils.set!(args[idx].val, path[3:end], TracedUtils.transpose_val(MLIR.IR.result(res, residx))) + TracedUtils.set!( + args[idx].val, + path[3:end], + TracedUtils.transpose_val(MLIR.IR.result(res, residx)), + ) residx += 1 end end @@ -417,7 +420,12 @@ function overload_autodiff( residx += 1 continue end - set_act!(f, path[3:end], reverse, TracedUtils.transpose_val(MLIR.IR.result(res, residx))) + set_act!( + f, + path[3:end], + reverse, + TracedUtils.transpose_val(MLIR.IR.result(res, residx)), + ) else if fnwrap idx -= 1 @@ -437,7 +445,10 @@ function overload_autodiff( continue end set_act!( - args[idx], path[3:end], reverse, TracedUtils.transpose_val(MLIR.IR.result(res, residx)) + args[idx], + path[3:end], + reverse, + TracedUtils.transpose_val(MLIR.IR.result(res, residx)), ) end residx += 1 @@ -470,27 +481,13 @@ function overload_autodiff( end @reactant_override @noinline function Enzyme.autodiff_deferred( - rmode::Enzyme.Mode, - f::FA, - rt::Type{A}, - args::Vararg{Annotation,Nargs}, -) where { - FA<:Annotation, - A<:Annotation, - Nargs, -} + rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs} +) where {FA<:Annotation,A<:Annotation,Nargs} return overload_autodiff(rmode, f, rt, args...) end @reactant_override @noinline function Enzyme.autodiff( - rmode::Enzyme.Mode, - f::FA, - rt::Type{A}, - args::Vararg{Annotation,Nargs}, -) where { - FA<:Annotation, - A<:Annotation, - Nargs, -} + rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs} +) where {FA<:Annotation,A<:Annotation,Nargs} return overload_autodiff(rmode, f, rt, args...) end diff --git a/src/Ops.jl b/src/Ops.jl index f604f16a6..376122e8b 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -4,13 +4,7 @@ module Ops using ..MLIR: MLIR using ..MLIR.Dialects: stablehlo, chlo, enzyme -using ..Reactant: - Reactant, - TracedRArray, - TracedRNumber, - RArray, - RNumber, - MissingTracedValue +using ..Reactant: Reactant, TracedRArray, TracedRNumber, RArray, RNumber, MissingTracedValue function mlir_type(x::RArray{T,N}) where {T,N} return MLIR.IR.TensorType(size(x), MLIR.IR.Type(T)) @@ -569,7 +563,9 @@ end return TracedRNumber{T}((), res) end -@noinline function clamp(min::T, x::Union{TracedRArray{T,N},TracedRNumber{T}}, max::T) where {T,N} +@noinline function clamp( + min::T, x::Union{TracedRArray{T,N},TracedRNumber{T}}, max::T +) where {T,N} return clamp(constant(min), x, constant(max)) end @@ -818,17 +814,23 @@ end # end # paralell ops -@noinline function partition_id(; location=mlir_stacktrace("partition_id", @__FILE__, @__LINE__)) +@noinline function partition_id(; + location=mlir_stacktrace("partition_id", @__FILE__, @__LINE__) +) res = MLIR.IR.result(stablehlo.partition_id(; location)) return TracedRNumber{UInt32}((), res) end -@noinline function replica_id(; location=mlir_stacktrace("replica_id", @__FILE__, @__LINE__)) +@noinline function replica_id(; + location=mlir_stacktrace("replica_id", @__FILE__, @__LINE__) +) res = MLIR.IR.result(stablehlo.replica_id(; location)) return TracedRNumber{UInt32}((), res) end -@noinline function after_all(tokens...; location=mlir_stacktrace("after_all", @__FILE__, @__LINE__)) +@noinline function after_all( + tokens...; location=mlir_stacktrace("after_all", @__FILE__, @__LINE__) +) tokens = [token.mlir_data for token in tokens] res = MLIR.IR.result(stablehlo.after_all(tokens; location)) return Token(res) @@ -1078,7 +1080,7 @@ end comparison_direction::String, compare_type=nothing, location=mlir_stacktrace("compare", @__FILE__, @__LINE__), - ) where {AT <: Union{TracedRArray,TracedRNumber}} +) where {AT<:Union{TracedRArray,TracedRNumber}} @assert comparison_direction in ("EQ", "NE", "GE", "GT", "LE", "LT") @assert size(lhs) == size(rhs) diff --git a/src/Reactant.jl b/src/Reactant.jl index bdbc8aefc..ba2da588d 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -122,7 +122,6 @@ include("TracedRArray.jl") include("ConcreteRArray.jl") - include("linear_algebra.jl") const TracedType = Union{TracedRArray,TracedRNumber,MissingTracedValue} diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index e925b00f3..90135e320 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -12,9 +12,9 @@ using ..TracedUtils import ..Ops import ..MLIR import ..ancestor -import ReactantCore +using ReactantCore: ReactantCore import ..TracedUtils: materialize_traced_array -import GPUArraysCore +using GPUArraysCore: GPUArraysCore ReactantCore.is_traced(::TracedRArray) = true @@ -31,13 +31,14 @@ end TracedRArray{T,N}(x::AbstractArray) where {T,N} = convert(TracedRArray{T,N}, x) - function Base.getindex( a::TracedRArray{T,N}, index::Vararg{Union{Int,TracedRNumber{Int}},N} ) where {T,N} GPUArraysCore.assertscalar("getindex(::TracedRArray, ::Vararg{Int, N})") - start_indices = [TracedUtils.promote_to(TracedRNumber{Int}, i - 1).mlir_data for i in index] + start_indices = [ + TracedUtils.promote_to(TracedRNumber{Int}, i - 1).mlir_data for i in index + ] slice_sizes = [Int64(1) for _ in index] res1 = MLIR.IR.result( @@ -107,8 +108,9 @@ function Base.setindex!( v = TracedUtils.broadcast_to_size(v, length.(indices)) v = TracedUtils.promote_to(TracedRArray{T,N}, v) indices = [ - (TracedUtils.promote_to(TracedRNumber{Int}, i isa Colon ? 1 : first(i)) - 1).mlir_data for - i in indices + ( + TracedUtils.promote_to(TracedRNumber{Int}, i isa Colon ? 1 : first(i)) - 1 + ).mlir_data for i in indices ] res = MLIR.IR.result( MLIR.Dialects.stablehlo.dynamic_update_slice( @@ -168,7 +170,9 @@ Base.imag(A::AnyTracedRArray) = zero(A) Base.imag(A::AnyTracedRArray{<:Complex}) = Ops.imag(materialize_traced_array(A)) TracedUtils.promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N} = TracedRArray{T,N}(rhs) -TracedUtils.promote_to(::TracedRArray{T,N}, rhs) where {T,N} = TracedUtils.promote_to(TracedRArray{T,N}, rhs) +function TracedUtils.promote_to(::TracedRArray{T,N}, rhs) where {T,N} + return TracedUtils.promote_to(TracedRArray{T,N}, rhs) +end for (jlop, hloop, hlocomp, merge) in ((:(Base.:(==)), :compare, "EQ", :all), (:(Base.:(!=)), :compare, "NE", :any)) @@ -214,12 +218,12 @@ function Base.mapreduce( rdims = Int64[] if dims == (:) - for i in 0:(N-1) + for i in 0:(N - 1) push!(rdims, i) end else for i in dims - push!(rdims, i-1) + push!(rdims, i - 1) end end @@ -244,7 +248,7 @@ function Base.mapreduce( toonedims = Int[] outdims = Int[] for i in 1:N - tmp = if in(i-1, rdims) + tmp = if in(i - 1, rdims) 1 else sz = size(A, i) @@ -283,7 +287,9 @@ function Base.mapreducedim!( @nospecialize(R::TracedRArray), A::Base.AbstractArrayOrBroadcasted, ) - tmp = TracedUtils.broadcast_to_size(Base.mapreduce(f, op, A; dims=1), (1, size(R)[2:end]...)) + tmp = TracedUtils.broadcast_to_size( + Base.mapreduce(f, op, A; dims=1), (1, size(R)[2:end]...) + ) R.mlir_data = broadcast(op, R, tmp).mlir_data return R end @@ -295,7 +301,9 @@ function Base.fill!(A::TracedRArray{T,N}, x) where {T,N} end function Base.fill!(A::TracedRArray{T,N}, x::TracedRNumber{T2}) where {T,N,T2} - bcast = TracedUtils.broadcast_to_size(TracedUtils.promote_to(TracedRNumber{T}, x), size(A)) + bcast = TracedUtils.broadcast_to_size( + TracedUtils.promote_to(TracedRNumber{T}, x), size(A) + ) A.mlir_data = bcast.mlir_data return A end diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index 88f6f9e9b..df664031e 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -18,7 +18,9 @@ Base.zero(::TracedRNumber{T}) where {T} = TracedUtils.promote_to(TracedRNumber{T Base.one(::TracedRNumber{T}) where {T} = TracedUtils.promote_to(TracedRNumber{T}, one(T)) Base.collect(x::TracedRNumber{T}) where {T} = TracedRArray{T,0}((), x.mlir_data, ()) -Base.eps(::Type{TracedRNumber{T}}) where {T} = TracedUtils.promote_to(TracedRNumber{T}, eps(T)) +function Base.eps(::Type{TracedRNumber{T}}) where {T} + return TracedUtils.promote_to(TracedRNumber{T}, eps(T)) +end function Base.convert(::Type{<:TracedRNumber{T}}, x::Number) where {T} return TracedUtils.promote_to(TracedRNumber{T}, T(x)) @@ -63,13 +65,18 @@ function TracedUtils.promote_to(::Type{TracedRNumber{T}}, rhs) where {T} return Ops.convert(TracedRNumber{T}, rhs) end if rhs isa TracedRArray{<:Any,0} - return TracedUtils.promote_to(TracedRNumber{T}, TracedRNumber{eltype(rhs)}((), rhs.mlir_data)) + return TracedUtils.promote_to( + TracedRNumber{T}, TracedRNumber{eltype(rhs)}((), rhs.mlir_data) + ) end - rhs isa Number && return TracedUtils.promote_to(TracedRNumber{T}, Ops.constant(fill(T(rhs)))) + rhs isa Number && + return TracedUtils.promote_to(TracedRNumber{T}, Ops.constant(fill(T(rhs)))) return TracedUtils.promote_to(TracedRNumber{T}, Ops.constant(collect(rhs))) end -TracedUtils.promote_to(::TracedRNumber{T}, rhs) where {T} = TracedUtils.promote_to(TracedRNumber{T}, rhs) +function TracedUtils.promote_to(::TracedRNumber{T}, rhs) where {T} + return TracedUtils.promote_to(TracedRNumber{T}, rhs) +end for (jlop, hloop) in ( (:(Base.min), :minimum), @@ -146,7 +153,11 @@ function Base.ifelse( element-type to the common type. This is semantically different from the \ behavior of `ifelse` in Base. Use with caution" maxlog = 1 T = promote_type(T1, T2) - return ifelse(pred, TracedUtils.promote_to(TracedRNumber{T}, x), TracedUtils.promote_to(TracedRNumber{T}, y)) + return ifelse( + pred, + TracedUtils.promote_to(TracedRNumber{T}, x), + TracedUtils.promote_to(TracedRNumber{T}, y), + ) end function Base.ifelse( @@ -162,12 +173,14 @@ for (T1, T2) in zip((Bool, Integer), (Bool, Integer)) @eval begin function Base.:&(x::TracedRNumber{<:$(T1)}, y::TracedRNumber{<:$(T2)}) return Ops.and( - TracedUtils.promote_to(TracedRNumber{$(T)}, x), TracedUtils.promote_to(TracedRNumber{$(T)}, y) + TracedUtils.promote_to(TracedRNumber{$(T)}, x), + TracedUtils.promote_to(TracedRNumber{$(T)}, y), ) end function Base.:|(x::TracedRNumber{<:$(T1)}, y::TracedRNumber{<:$(T2)}) return Ops.or( - TracedUtils.promote_to(TracedRNumber{$(T)}, x), TracedUtils.promote_to(TracedRNumber{$(T)}, y) + TracedUtils.promote_to(TracedRNumber{$(T)}, x), + TracedUtils.promote_to(TracedRNumber{$(T)}, y), ) end Base.:!(x::TracedRNumber{<:$(T1)}) = Ops.not(x) @@ -216,7 +229,9 @@ function Base.fill(x::TracedRNumber, dims::NTuple{N,Integer}) where {N} return TracedUtils.broadcast_to_size(x, dims) end -Base.float(x::TracedRNumber{T}) where {T} = TracedUtils.promote_to(TracedRNumber{float(T)}, x) +function Base.float(x::TracedRNumber{T}) where {T} + return TracedUtils.promote_to(TracedRNumber{float(T)}, x) +end # Concatenation. Numbers in Julia are handled in a much less generic fashion than arrays Base.vcat(x::TracedRNumber...) = Base.typed_vcat(Base.promote_eltypeof(x...), x...) diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index 51d25a29f..d4fc11e94 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -3,9 +3,17 @@ # within compilation. However, it means these functions are a _lot_ faster to compile. module TracedUtils -import LinearAlgebra -import Adapt -using ..Reactant: RArray, RNumber, TracedRArray, TracedRNumber, WrappedTracedRArray, AnyTracedRArray, MissingTracedValue, OrderedIdDict +using LinearAlgebra: LinearAlgebra +using Adapt: Adapt +using ..Reactant: + RArray, + RNumber, + TracedRArray, + TracedRNumber, + WrappedTracedRArray, + AnyTracedRArray, + MissingTracedValue, + OrderedIdDict import ..Reactant import ..Reactant.MLIR import ..ReactantPrimitive @@ -92,7 +100,6 @@ function get_ancestor_indices(x::WrappedTracedRArray, indices...) return get_ancestor_indices(parent(x), Base.reindex(parentindices(x), indices)...) end - function batch_ty(width, mlirty) return MLIR.IR.TensorType([width, size(mlirty)...], eltype(mlirty)) end @@ -200,7 +207,7 @@ function make_mlir_fn( # TODO fix it for kwargs #if concretein - Reactant.call_with_reactant(f, traced_args...) + Reactant.call_with_reactant(f, traced_args...) #else # f(traced_args...) #end @@ -219,7 +226,10 @@ function make_mlir_fn( # marks buffers to be donated for i in 1:N Reactant.make_tracer( - seen_results, traced_args[i], concretein ? (:resargs, i) : (), Reactant.TracedTrack + seen_results, + traced_args[i], + concretein ? (:resargs, i) : (), + Reactant.TracedTrack, ) end @@ -294,7 +304,9 @@ elem_apply(::Type{T}, x::TracedRArray{T}) where {T<:ReactantPrimitive} = x struct TypeCast{T<:ReactantPrimitive} <: Function end -(::TypeCast{T})(x::TracedRNumber{T2}) where {T,T2} = TracedUtils.promote_to(TracedRNumber{T}, x) +function (::TypeCast{T})(x::TracedRNumber{T2}) where {T,T2} + return TracedUtils.promote_to(TracedRNumber{T}, x) +end function elem_apply( ::Type{T}, x::TracedRArray{T2} @@ -442,7 +454,9 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs} end seen_results = OrderedIdDict() - traced2_result = Reactant.make_tracer(seen_results, result, (), Reactant.TracedSetPath; tobatch=OutShape) + traced2_result = Reactant.make_tracer( + seen_results, result, (), Reactant.TracedSetPath; tobatch=OutShape + ) func2.operation = MLIR.API.MlirOperation(C_NULL) diff --git a/src/Tracing.jl b/src/Tracing.jl index bb7116eb9..62bb71f69 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -498,14 +498,18 @@ function make_tracer( return ConcreteRNumber(prev) else if mode == TracedTrack - res = TracedRNumber{RT}((path,), TracedUtils.broadcast_to_size(prev, ()).mlir_data) + res = TracedRNumber{RT}( + (path,), TracedUtils.broadcast_to_size(prev, ()).mlir_data + ) if !haskey(seen, prev) return seen[prev] = res end return res elseif mode == TracedSetPath haskey(seen, prev) && return seen[prev] - res = TracedRNumber{RT}((path,), TracedUtils.broadcast_to_size(prev, ()).mlir_data) + res = TracedRNumber{RT}( + (path,), TracedUtils.broadcast_to_size(prev, ()).mlir_data + ) seen[prev] = res return res elseif mode == TracedToConcrete diff --git a/src/linear_algebra.jl b/src/linear_algebra.jl index 5cb10a861..c011f8aec 100644 --- a/src/linear_algebra.jl +++ b/src/linear_algebra.jl @@ -8,10 +8,7 @@ import ..AnyTracedRMatrix import ..AnyTracedRVector import ..TracedUtils -using ..TracedUtils: - get_mlir_data, - materialize_traced_array, - set_mlir_data! +using ..TracedUtils: get_mlir_data, materialize_traced_array, set_mlir_data! import ..Ops import ..MLIR @@ -80,7 +77,8 @@ end function LinearAlgebra.triu!(@nospecialize(X::TracedRArray{T,2}), k::Integer) where {T} iota_1 = Ops.iota(Int64, [size(X)...]; iota_dimension=1) iota_2 = Ops.subtract( - Ops.iota(Int64, [size(X)...]; iota_dimension=2), TracedUtils.broadcast_to_size(k, size(X)) + Ops.iota(Int64, [size(X)...]; iota_dimension=2), + TracedUtils.broadcast_to_size(k, size(X)), ) idxs = Ops.compare(iota_1, iota_2; comparison_direction="LE") X.mlir_data = Ops.select(idxs, X, zero(X)).mlir_data @@ -90,7 +88,8 @@ end function LinearAlgebra.tril!(@nospecialize(X::TracedRArray{T,2}), k::Integer) where {T} iota_1 = Ops.iota(Int64, [size(X)...]; iota_dimension=1) iota_2 = Ops.subtract( - Ops.iota(Int64, [size(X)...]; iota_dimension=2), TracedUtils.broadcast_to_size(k, size(X)) + Ops.iota(Int64, [size(X)...]; iota_dimension=2), + TracedUtils.broadcast_to_size(k, size(X)), ) idxs = Ops.compare(iota_1, iota_2; comparison_direction="GE") X.mlir_data = Ops.select(idxs, X, zero(X)).mlir_data @@ -134,7 +133,9 @@ function LinearAlgebra.diag(x::AnyTracedRArray{T,2}, k::Integer=0) where {T} ) #! format: on - slice_sizes = get_mlir_data(Reactant.TracedUtils.promote_to(TracedRArray{Int,1}, [1, 1])) + slice_sizes = get_mlir_data( + Reactant.TracedUtils.promote_to(TracedRArray{Int,1}, [1, 1]) + ) res = MLIR.IR.result( MLIR.Dialects.stablehlo.dynamic_gather( get_mlir_data(y), idxs, slice_sizes; dimension_numbers @@ -158,7 +159,9 @@ function LinearAlgebra.diagm(m::Integer, n::Integer, v::AnyTracedRArray{T,1}) wh mat = (v .+ zero(v)') .* diag_indicator return Ops.pad( - mat, TracedUtils.promote_to(TracedRNumber{T}, 0); high=[m - length(v), n - length(v)] + mat, + TracedUtils.promote_to(TracedRNumber{T}, 0); + high=[m - length(v), n - length(v)], ) end diff --git a/src/utils.jl b/src/utils.jl index 1d76c638a..b65077c03 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -33,21 +33,44 @@ function throw_method_error(argtys) throw(MethodError(argtys[1], argtys[2:end])) end - - -@inline function lookup_world(@nospecialize(sig::Type), world::UInt, mt::Union{Nothing,Core.MethodTable}, min_world::Ref{UInt}, max_world::Ref{UInt}) - res = ccall(:jl_gf_invoke_lookup_worlds, Any, - (Any, Any, Csize_t, Ref{Csize_t}, Ref{Csize_t}), - sig, mt, world, min_world, max_world) +@inline function lookup_world( + @nospecialize(sig::Type), + world::UInt, + mt::Union{Nothing,Core.MethodTable}, + min_world::Ref{UInt}, + max_world::Ref{UInt}, +) + res = ccall( + :jl_gf_invoke_lookup_worlds, + Any, + (Any, Any, Csize_t, Ref{Csize_t}, Ref{Csize_t}), + sig, + mt, + world, + min_world, + max_world, + ) return res end -@inline function lookup_world(@nospecialize(sig::Type), world::UInt, mt::Core.Compiler.InternalMethodTable, min_world::Ref{UInt}, max_world::Ref{UInt}) +@inline function lookup_world( + @nospecialize(sig::Type), + world::UInt, + mt::Core.Compiler.InternalMethodTable, + min_world::Ref{UInt}, + max_world::Ref{UInt}, +) res = lookup_world(sig, mt.world, nothing, min_world, max_world) return res end -@inline function lookup_world(@nospecialize(sig::Type), world::UInt, mt::Core.Compiler.OverlayMethodTable, min_world::Ref{UInt}, max_world::Ref{UInt}) +@inline function lookup_world( + @nospecialize(sig::Type), + world::UInt, + mt::Core.Compiler.OverlayMethodTable, + min_world::Ref{UInt}, + max_world::Ref{UInt}, +) res = lookup_world(sig, mt.world, mt.mt, min_world, max_world) if res !== nothing return res @@ -74,7 +97,9 @@ function should_rewrite_ft(@nospecialize(ft)) if ft <: Core.Function mod = ft.name.module # Don't rewrite primitive ops, tracing utilities, or any MLIR-based functions - if has_ancestor(mod, Reactant.Ops) || has_ancestor(mod, Reactant.TracedUtils) || has_ancestor(mod, Reactant.MLIR) + if has_ancestor(mod, Reactant.Ops) || + has_ancestor(mod, Reactant.TracedUtils) || + has_ancestor(mod, Reactant.MLIR) return false end end @@ -96,7 +121,10 @@ function should_rewrite_ft(@nospecialize(ft)) end # Don't rewrite traced constructors - if ft <: Type{<:TracedRArray} || ft <: Type{<:TracedRNumber} || ft === Type{MLIR.IR.Location} || ft === Type{MLIR.IR.Block} + if ft <: Type{<:TracedRArray} || + ft <: Type{<:TracedRNumber} || + ft === Type{MLIR.IR.Location} || + ft === Type{MLIR.IR.Block} return false end @@ -104,9 +132,13 @@ function should_rewrite_ft(@nospecialize(ft)) if ft <: typeof(Core.Compiler.return_type) return false end - + # Perf optimizations - if ft <: typeof(Base.typemax) || ft <: typeof(Base.typemin) || ft <: typeof(Base.getproperty) || ft <: typeof(Base.vect) || ft <: typeof(Base.eltype) + if ft <: typeof(Base.typemax) || + ft <: typeof(Base.typemin) || + ft <: typeof(Base.getproperty) || + ft <: typeof(Base.vect) || + ft <: typeof(Base.eltype) return false end @@ -151,10 +183,9 @@ function rewrite_inst(inst, ir, interp) min_world = Ref{UInt}(typemin(UInt)) max_world = Ref{UInt}(typemax(UInt)) - if !method.isva || !Base.isvarargtype(sig.parameters[end]) - sig2 = Tuple{typeof(call_with_reactant), sig.parameters...} + sig2 = Tuple{typeof(call_with_reactant),sig.parameters...} else vartup = inst.args[end] ns = Type[] @@ -162,11 +193,15 @@ function rewrite_inst(inst, ir, interp) for i in 1:(length(inst.args) - 1 - (length(sig.parameters) - 1)) push!(ns, eT) end - sig2 = Tuple{typeof(call_with_reactant), sig.parameters[1:end-1]..., ns...} + sig2 = Tuple{ + typeof(call_with_reactant),sig.parameters[1:(end - 1)]...,ns... + } end - lookup_result = lookup_world(sig2, interp.world, Core.Compiler.method_table(interp), min_world, max_world) - + lookup_result = lookup_world( + sig2, interp.world, Core.Compiler.method_table(interp), min_world, max_world + ) + match = lookup_result::Core.MethodMatch # look up the method and code instance mi = ccall( @@ -185,28 +220,43 @@ function rewrite_inst(inst, ir, interp) return false, inst end -const oc_captures = Dict{Tuple{Type, Type, Core.CodeInfo, Int, Bool, Any}, Core.OpaqueClosure}() +const oc_captures = Dict{Tuple{Type,Type,Core.CodeInfo,Int,Bool,Any},Core.OpaqueClosure}() # Caching is both good to reducing compile times and necessary to work around julia bugs # in OpaqueClosure's: https://github.com/JuliaLang/julia/issues/56833 -function make_oc(sig::Type, rt::Type, src::Core.CodeInfo, nargs::Int, isva::Bool, f::Any)::Core.OpaqueClosure +function make_oc( + sig::Type, rt::Type, src::Core.CodeInfo, nargs::Int, isva::Bool, f::Any +)::Core.OpaqueClosure key = (sig, rt, src, nargs, isva, f) if haskey(oc_captures, key) return oc_captures[key] else - ores = ccall(:jl_new_opaque_closure_from_code_info, Any, (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any, Cint), - sig, rt, rt, @__MODULE__, src, 0, nothing, nargs, isva, f, true)::Core.OpaqueClosure + ores = ccall( + :jl_new_opaque_closure_from_code_info, + Any, + (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any, Cint), + sig, + rt, + rt, + @__MODULE__, + src, + 0, + nothing, + nargs, + isva, + f, + true, + )::Core.OpaqueClosure oc_captures[key] = ores return ores end end function safe_print(name, x) - ccall(:jl_, Cvoid, (Any,), name*" "*string(x)) + return ccall(:jl_, Cvoid, (Any,), name * " " * string(x)) end -const DEBUG_INTERP = Ref(false) - +const DEBUG_INTERP = Ref(false) # Generator function which ensures that all calls to the function are executed within the ReactantInterpreter # In particular this entails two pieces: @@ -236,7 +286,9 @@ function call_with_reactant_generator( if args[1] <: Core.Builtin return stub(world, source, builtin_error) end - method_error = :(throw(MethodError($REDUB_ARGUMENTS_NAME[1], $REDUB_ARGUMENTS_NAME[2:end], $world))) + method_error = :(throw( + MethodError($REDUB_ARGUMENTS_NAME[1], $REDUB_ARGUMENTS_NAME[2:end], $world) + )) interp = ReactantInterpreter(; world) @@ -245,8 +297,10 @@ function call_with_reactant_generator( min_world = Ref{UInt}(typemin(UInt)) max_world = Ref{UInt}(typemax(UInt)) - lookup_result = lookup_world(sig, world, Core.Compiler.method_table(interp), min_world, max_world) - + lookup_result = lookup_world( + sig, world, Core.Compiler.method_table(interp), min_world, max_world + ) + overdubbed_code = Any[] overdubbed_codelocs = Int32[] @@ -255,21 +309,36 @@ function call_with_reactant_generator( return stub(world, source, method_error) tmp_min_world = Ref{UInt}(typemin(UInt)) tmp_max_world = Ref{UInt}(typemax(UInt)) - match = ccall(:jl_gf_invoke_lookup_worlds, Any, - (Any, Any, Csize_t, Ref{Csize_t}, Ref{Csize_t}), - Tuple{typeof(throw_method_error), sig}, #=mt=# nothing, world, tmp_min_world, tmp_max_world) + match = ccall( + :jl_gf_invoke_lookup_worlds, + Any, + (Any, Any, Csize_t, Ref{Csize_t}, Ref{Csize_t}), + Tuple{typeof(throw_method_error),sig}, + nothing, + world, + tmp_min_world, + tmp_max_world, + ) #=mt=# @assert match !== nothing # look up the method and code instance - mi = ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance}, - (Any, Any, Any), match.method, match.spec_types, match.sparams) - + mi = ccall( + :jl_specializations_get_linfo, + Ref{Core.MethodInstance}, + (Any, Any, Any), + match.method, + match.spec_types, + match.sparams, + ) + ci = Core.Compiler.retrieve_code_info(mi, world)::Core.Compiler.CodeInfo src = copy(ci) src.slotnames = Any[:call_with_reactant, REDUB_ARGUMENTS_NAME] - src.edges = Any[ccall(:jl_method_table_for, Any, (Any,), sig)::Core.MethodTable, sig] + src.edges = Any[ + ccall(:jl_method_table_for, Any, (Any,), sig)::Core.MethodTable, sig + ] src.min_world = min_world[] src.max_world = max_world[] @@ -278,14 +347,12 @@ function call_with_reactant_generator( expr_fn = Core.SSAValue(length(overdubbed_code)) - push!(overdubbed_code, :($(Base.lastindex)($(Core.Argument(2))))) push!(overdubbed_codelocs, 0) expr_lastindex = Core.SSAValue(length(overdubbed_code)) - - push!(overdubbed_code, :(2:$expr_lastindex)) + push!(overdubbed_code, :(2:($expr_lastindex))) push!(overdubbed_codelocs, 0) expr_slice = Core.SSAValue(length(overdubbed_code)) @@ -303,10 +370,7 @@ function call_with_reactant_generator( push!(overdubbed_code, :($(Base.throw)($expr_method))) push!(overdubbed_codelocs, 0) - push!( - overdubbed_code, - Core.ReturnNode(Core.SSAValue(length(overdubbed_code))) - ) + push!(overdubbed_code, Core.ReturnNode(Core.SSAValue(length(overdubbed_code)))) push!(overdubbed_codelocs, 0) src.code = overdubbed_code @@ -363,7 +427,7 @@ function call_with_reactant_generator( Core.Compiler.ipo_dataflow_analysis!(interp, ir, caller) end end - + if DEBUG_INTERP[] safe_print("ir1", ir) end @@ -387,11 +451,11 @@ function call_with_reactant_generator( Core.Compiler.setindex!(ir.stmts[i], Any, :type) end end - + Core.Compiler.finish(interp, opt, ir, caller) src = Core.Compiler.ir_to_codeinf!(opt) - + if DEBUG_INTERP[] safe_print("src", src) end @@ -430,12 +494,12 @@ function call_with_reactant_generator( n_actual_args = length(redub_arguments) tys = [] - + iter_args = n_actual_args if method.isva - iter_args = min(n_actual_args, n_method_args-1) + iter_args = min(n_actual_args, n_method_args - 1) end - + for i in 1:iter_args actual_argument = Expr( :call, Core.GlobalRef(Core, :getfield), overdub_args_slot, offset @@ -445,14 +509,21 @@ function call_with_reactant_generator( offset += 1 push!(fn_args, Core.SSAValue(length(overdubbed_code))) push!(tys, redub_arguments[i]) - + if DEBUG_INTERP[] - push!(overdubbed_code, Expr(:call, safe_print, "fn arg["*string(length(fn_args))*"]", fn_args[end])) + push!( + overdubbed_code, + Expr( + :call, + safe_print, + "fn arg[" * string(length(fn_args)) * "]", + fn_args[end], + ), + ) push!(overdubbed_codelocs, code_info.codelocs[1]) end end - # If `method` is a varargs method, we have to restructure the original method call's # trailing arguments into a tuple and assign that tuple to the expected argument slot. if method.isva @@ -467,21 +538,27 @@ function call_with_reactant_generator( offset += 1 end - push!( - overdubbed_code, trailing_arguments - ) + push!(overdubbed_code, trailing_arguments) push!(overdubbed_codelocs, code_info.codelocs[1]) push!(fn_args, Core.SSAValue(length(overdubbed_code))) push!(tys, Tuple{redub_arguments[n_method_args:n_actual_args]...}) - + if DEBUG_INTERP[] - push!(overdubbed_code, Expr(:call, safe_print, "fn arg["*string(length(fn_args))*"]", fn_args[end])) + push!( + overdubbed_code, + Expr( + :call, + safe_print, + "fn arg[" * string(length(fn_args)) * "]", + fn_args[end], + ), + ) push!(overdubbed_codelocs, code_info.codelocs[1]) end end rt = Base.Experimental.compute_ir_rettype(ir) - + # ocva = method.isva ocva = false # method.isva @@ -497,40 +574,22 @@ function call_with_reactant_generator( # Opaque closures also require takign the function argument. We can work around the latter # if the function is stateless. But regardless, to work around this we sadly create/compile the opaque closure oc = if false && Base.issingletontype(args[1]) - res = Core._call_in_world_total(world, make_oc, octup, rt, src, ocnargs, ocva, args[1].instance)::Core.OpaqueClosure + res = Core._call_in_world_total( + world, make_oc, octup, rt, src, ocnargs, ocva, args[1].instance + )::Core.OpaqueClosure else farg = fn_args[1] - push!(overdubbed_code, - Expr(:call, - make_oc, - octup, - rt, - src, - ocnargs, - ocva, - farg - ) - ) + push!(overdubbed_code, Expr(:call, make_oc, octup, rt, src, ocnargs, ocva, farg)) push!(overdubbed_codelocs, code_info.codelocs[1]) Core.SSAValue(length(overdubbed_code)) end - push!( - overdubbed_code, - Expr( - :(call), - oc, - fn_args[2:end]... - ), - ) + push!(overdubbed_code, Expr(:(call), oc, fn_args[2:end]...)) push!(overdubbed_codelocs, code_info.codelocs[1]) - push!( - overdubbed_code, - Core.ReturnNode(Core.SSAValue(length(overdubbed_code))) - ) + push!(overdubbed_code, Core.ReturnNode(Core.SSAValue(length(overdubbed_code)))) push!(overdubbed_codelocs, code_info.codelocs[1]) #=== set `code_info`/`reflection` fields accordingly ===# @@ -543,7 +602,7 @@ function call_with_reactant_generator( code_info.codelocs = overdubbed_codelocs code_info.ssavaluetypes = length(overdubbed_code) code_info.ssaflags = [0x00 for _ in 1:length(overdubbed_code)] # XXX we need to copy flags that are set for the original code - + if DEBUG_INTERP[] safe_print("code_info", code_info) end