Skip to content

Commit

Permalink
refactor: move overrides into a separate file
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 15, 2024
1 parent 65e9976 commit 2120d3b
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 22 deletions.
7 changes: 0 additions & 7 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -779,13 +779,6 @@ function compile(f, args; client=nothing, optimize=true, sync=false)
return register_thunk(fname, body)
end

# Compiling within a compile should return simply the original function
Reactant.@reactant_override function Reactant.Compiler.compile(
f, args; client=nothing, optimize=true, sync=false
)
return f
end

# inspired by RuntimeGeneratedFunction.jl
const __thunk_body_cache = Dict{Symbol,Expr}()

Expand Down
12 changes: 0 additions & 12 deletions src/Interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -479,15 +479,3 @@ function overload_autodiff(
end
end
end

@reactant_override @noinline function Enzyme.autodiff_deferred(
rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs}
) where {FA<:Annotation,A<:Annotation,Nargs}
return overload_autodiff(rmode, f, rt, args...)
end

@reactant_override @noinline function Enzyme.autodiff(
rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs}
) where {FA<:Annotation,A<:Annotation,Nargs}
return overload_autodiff(rmode, f, rt, args...)
end
24 changes: 24 additions & 0 deletions src/Overrides.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# NOTE: We are placing all the reactant_overrides here to avoid incompatibilities with
# Revise.jl. Essentially files that contain reactant_overrides cannot be revised
# correctly. Once that (https://github.com/timholy/Revise.jl/issues/646) is resolved
# we should move all the reactant_overrides to relevant files.

# Compiling within a compile should return simply the original function
@reactant_override function Compiler.compile(
f, args; client=nothing, optimize=true, sync=false
)
return f
end

# Enzyme overrides
@reactant_override @noinline function Enzyme.autodiff_deferred(
rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs}
) where {FA<:Annotation,A<:Annotation,Nargs}
return overload_autodiff(rmode, f, rt, args...)
end

@reactant_override @noinline function Enzyme.autodiff(
rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs}
) where {FA<:Annotation,A<:Annotation,Nargs}
return overload_autodiff(rmode, f, rt, args...)
end
2 changes: 2 additions & 0 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ include("ControlFlow.jl")
include("Tracing.jl")
include("Compiler.jl")

include("Overrides.jl")

function Enzyme.make_zero(
::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false)
)::RT where {copy_if_inactive,RT<:RArray}
Expand Down
7 changes: 4 additions & 3 deletions src/TracedUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ using ..Reactant:
WrappedTracedRArray,
AnyTracedRArray,
MissingTracedValue,
OrderedIdDict
OrderedIdDict,
Compiler
import ..Reactant
import ..Reactant.MLIR
import ..ReactantPrimitive
Expand Down Expand Up @@ -323,7 +324,7 @@ end

function push_val!(ad_inputs, x, path)
for p in path
x = traced_getfield(x, p)
x = Compiler.traced_getfield(x, p)
end
x = x.mlir_data
return push!(ad_inputs, x)
Expand All @@ -343,7 +344,7 @@ end

function set!(x, path, tostore; emptypath=false)
for p in path
x = traced_getfield(x, p)
x = Compiler.traced_getfield(x, p)
end

x.mlir_data = tostore
Expand Down

0 comments on commit 2120d3b

Please sign in to comment.