diff --git a/src/Compiler.jl b/src/Compiler.jl index 5f7158d8..cc32a90b 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -779,13 +779,6 @@ function compile(f, args; client=nothing, optimize=true, sync=false) return register_thunk(fname, body) end -# Compiling within a compile should return simply the original function -Reactant.@reactant_override function Reactant.Compiler.compile( - f, args; client=nothing, optimize=true, sync=false -) - return f -end - # inspired by RuntimeGeneratedFunction.jl const __thunk_body_cache = Dict{Symbol,Expr}() diff --git a/src/Interpreter.jl b/src/Interpreter.jl index 72e27c5d..cd265c9e 100644 --- a/src/Interpreter.jl +++ b/src/Interpreter.jl @@ -479,15 +479,3 @@ function overload_autodiff( end end end - -@reactant_override @noinline function Enzyme.autodiff_deferred( - rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs} -) where {FA<:Annotation,A<:Annotation,Nargs} - return overload_autodiff(rmode, f, rt, args...) -end - -@reactant_override @noinline function Enzyme.autodiff( - rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs} -) where {FA<:Annotation,A<:Annotation,Nargs} - return overload_autodiff(rmode, f, rt, args...) -end diff --git a/src/Overrides.jl b/src/Overrides.jl new file mode 100644 index 00000000..0df144d3 --- /dev/null +++ b/src/Overrides.jl @@ -0,0 +1,24 @@ +# NOTE: We are placing all the reactant_overrides here to avoid incompatibilities with +# Revise.jl. Essentially files that contain reactant_overrides cannot be revised +# correctly. Once that (https://github.com/timholy/Revise.jl/issues/646) is resolved +# we should move all the reactant_overrides to relevant files. + +# Compiling within a compile should return simply the original function +@reactant_override function Compiler.compile( + f, args; client=nothing, optimize=true, sync=false +) + return f +end + +# Enzyme overrides +@reactant_override @noinline function Enzyme.autodiff_deferred( + rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs} +) where {FA<:Annotation,A<:Annotation,Nargs} + return overload_autodiff(rmode, f, rt, args...) +end + +@reactant_override @noinline function Enzyme.autodiff( + rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs} +) where {FA<:Annotation,A<:Annotation,Nargs} + return overload_autodiff(rmode, f, rt, args...) +end diff --git a/src/Reactant.jl b/src/Reactant.jl index ba2da588..9830965d 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -130,6 +130,8 @@ include("ControlFlow.jl") include("Tracing.jl") include("Compiler.jl") +include("Overrides.jl") + function Enzyme.make_zero( ::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false) )::RT where {copy_if_inactive,RT<:RArray}