Skip to content

Commit

Permalink
refactor: move overrides into a separate file
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 15, 2024
1 parent 65e9976 commit 3990373
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 7 deletions.
7 changes: 0 additions & 7 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -779,13 +779,6 @@ function compile(f, args; client=nothing, optimize=true, sync=false)
return register_thunk(fname, body)
end

# Compiling within a compile should return simply the original function
Reactant.@reactant_override function Reactant.Compiler.compile(
f, args; client=nothing, optimize=true, sync=false
)
return f
end

# inspired by RuntimeGeneratedFunction.jl
const __thunk_body_cache = Dict{Symbol,Expr}()

Expand Down
24 changes: 24 additions & 0 deletions src/Overrides.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# NOTE: We are placing all the reactant_overrides here to avoid incompatibilities with
# Revise.jl. Essentially files that contain reactant_overrides cannot be revised
# correctly. Once that (https://github.com/timholy/Revise.jl/issues/646) is resolved
# we should move all the reactant_overrides to relevant files.

# Compiling within a compile should return simply the original function
@reactant_override function Compiler.compile(
f, args; client=nothing, optimize=true, sync=false
)
return f
end

# Enzyme overrides
@reactant_override @noinline function Enzyme.autodiff_deferred(
rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs}
) where {FA<:Annotation,A<:Annotation,Nargs}
return overload_autodiff(rmode, f, rt, args...)
end

@reactant_override @noinline function Enzyme.autodiff(
rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs}
) where {FA<:Annotation,A<:Annotation,Nargs}
return overload_autodiff(rmode, f, rt, args...)
end
2 changes: 2 additions & 0 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ include("ControlFlow.jl")
include("Tracing.jl")
include("Compiler.jl")

include("Overrides.jl")

function Enzyme.make_zero(
::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false)
)::RT where {copy_if_inactive,RT<:RArray}
Expand Down

0 comments on commit 3990373

Please sign in to comment.