diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index ec261c66..2d9e3710 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -1,7 +1,7 @@ steps: - group: ":test_tube: Tests" steps: - - label: "CUDA Julia v{{matrix.version}} -- {{matrix.group}}" + - label: ":julia: :linux: CUDA Julia v{{matrix.version}} -- {{matrix.group}}" matrix: setup: version: @@ -33,7 +33,7 @@ steps: env: REACTANT_TEST_GROUP: "{{matrix.group}}" if: build.message !~ /\[skip tests\]/ - timeout_in_minutes: 60 + timeout_in_minutes: 120 - label: ":julia: :linux: aarch64 - Julia v{{matrix.version}} -- {{matrix.group}}" matrix: @@ -70,7 +70,7 @@ steps: env: REACTANT_TEST_GROUP: "{{matrix.group}}" if: build.message !~ /\[skip tests\]/ - timeout_in_minutes: 60 + timeout_in_minutes: 120 - group: ":racehorse: Benchmarks" steps: diff --git a/src/Compiler.jl b/src/Compiler.jl index 5f7158d8..cc32a90b 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -779,13 +779,6 @@ function compile(f, args; client=nothing, optimize=true, sync=false) return register_thunk(fname, body) end -# Compiling within a compile should return simply the original function -Reactant.@reactant_override function Reactant.Compiler.compile( - f, args; client=nothing, optimize=true, sync=false -) - return f -end - # inspired by RuntimeGeneratedFunction.jl const __thunk_body_cache = Dict{Symbol,Expr}() diff --git a/src/Interpreter.jl b/src/Interpreter.jl index 72e27c5d..4b71a134 100644 --- a/src/Interpreter.jl +++ b/src/Interpreter.jl @@ -23,7 +23,7 @@ import Core.Compiler: Base.Experimental.@MethodTable(REACTANT_METHOD_TABLE) -function var"@reactant_override"(__source__::LineNumberNode, __module__::Module, def) +function var"@reactant_overlay"(__source__::LineNumberNode, __module__::Module, def) return Base.Experimental.var"@overlay"( __source__, __module__, :(Reactant.REACTANT_METHOD_TABLE), def ) @@ -479,15 +479,3 @@ function overload_autodiff( end end end - -@reactant_override @noinline function Enzyme.autodiff_deferred( - rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs} -) where {FA<:Annotation,A<:Annotation,Nargs} - return overload_autodiff(rmode, f, rt, args...) -end - -@reactant_override @noinline function Enzyme.autodiff( - rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs} -) where {FA<:Annotation,A<:Annotation,Nargs} - return overload_autodiff(rmode, f, rt, args...) -end diff --git a/src/Overlay.jl b/src/Overlay.jl new file mode 100644 index 00000000..6d4752ac --- /dev/null +++ b/src/Overlay.jl @@ -0,0 +1,24 @@ +# NOTE: We are placing all the reactant_overrides here to avoid incompatibilities with +# Revise.jl. Essentially files that contain reactant_overrides cannot be revised +# correctly. Once that (https://github.com/timholy/Revise.jl/issues/646) is resolved +# we should move all the reactant_overrides to relevant files. + +# Compiling within a compile should return simply the original function +@reactant_overlay function Compiler.compile( + f, args; client=nothing, optimize=true, sync=false +) + return f +end + +# Enzyme overrides +@reactant_overlay @noinline function Enzyme.autodiff_deferred( + rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs} +) where {FA<:Annotation,A<:Annotation,Nargs} + return overload_autodiff(rmode, f, rt, args...) +end + +@reactant_overlay @noinline function Enzyme.autodiff( + rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs} +) where {FA<:Annotation,A<:Annotation,Nargs} + return overload_autodiff(rmode, f, rt, args...) +end diff --git a/src/Reactant.jl b/src/Reactant.jl index ba2da588..e7c8805d 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -130,6 +130,8 @@ include("ControlFlow.jl") include("Tracing.jl") include("Compiler.jl") +include("Overlay.jl") + function Enzyme.make_zero( ::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false) )::RT where {copy_if_inactive,RT<:RArray}