Skip to content

Commit

Permalink
refactor: move overrides into a separate file (#379)
Browse files Browse the repository at this point in the history
* refactor: move overrides into a separate file

* fix: traced_getfield move to TracedUtils

* refactor: rename to overlay

* ci: increase build time allowance

* Revert "fix: traced_getfield move to TracedUtils"

This reverts commit f35d911.
  • Loading branch information
avik-pal authored Dec 16, 2024
1 parent 668e2fc commit a98c02b
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 23 deletions.
6 changes: 3 additions & 3 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
steps:
- group: ":test_tube: Tests"
steps:
- label: "CUDA Julia v{{matrix.version}} -- {{matrix.group}}"
- label: ":julia: :linux: CUDA Julia v{{matrix.version}} -- {{matrix.group}}"
matrix:
setup:
version:
Expand Down Expand Up @@ -33,7 +33,7 @@ steps:
env:
REACTANT_TEST_GROUP: "{{matrix.group}}"
if: build.message !~ /\[skip tests\]/
timeout_in_minutes: 60
timeout_in_minutes: 120

- label: ":julia: :linux: aarch64 - Julia v{{matrix.version}} -- {{matrix.group}}"
matrix:
Expand Down Expand Up @@ -70,7 +70,7 @@ steps:
env:
REACTANT_TEST_GROUP: "{{matrix.group}}"
if: build.message !~ /\[skip tests\]/
timeout_in_minutes: 60
timeout_in_minutes: 120

- group: ":racehorse: Benchmarks"
steps:
Expand Down
7 changes: 0 additions & 7 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -779,13 +779,6 @@ function compile(f, args; client=nothing, optimize=true, sync=false)
return register_thunk(fname, body)
end

# Compiling within a compile should return simply the original function
Reactant.@reactant_override function Reactant.Compiler.compile(
f, args; client=nothing, optimize=true, sync=false
)
return f
end

# inspired by RuntimeGeneratedFunction.jl
const __thunk_body_cache = Dict{Symbol,Expr}()

Expand Down
14 changes: 1 addition & 13 deletions src/Interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import Core.Compiler:

Base.Experimental.@MethodTable(REACTANT_METHOD_TABLE)

function var"@reactant_override"(__source__::LineNumberNode, __module__::Module, def)
function var"@reactant_overlay"(__source__::LineNumberNode, __module__::Module, def)
return Base.Experimental.var"@overlay"(
__source__, __module__, :(Reactant.REACTANT_METHOD_TABLE), def
)
Expand Down Expand Up @@ -479,15 +479,3 @@ function overload_autodiff(
end
end
end

@reactant_override @noinline function Enzyme.autodiff_deferred(
rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs}
) where {FA<:Annotation,A<:Annotation,Nargs}
return overload_autodiff(rmode, f, rt, args...)
end

@reactant_override @noinline function Enzyme.autodiff(
rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs}
) where {FA<:Annotation,A<:Annotation,Nargs}
return overload_autodiff(rmode, f, rt, args...)
end
24 changes: 24 additions & 0 deletions src/Overlay.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# NOTE: We are placing all the reactant_overrides here to avoid incompatibilities with
# Revise.jl. Essentially files that contain reactant_overrides cannot be revised
# correctly. Once that (https://github.com/timholy/Revise.jl/issues/646) is resolved
# we should move all the reactant_overrides to relevant files.

# Compiling within a compile should return simply the original function
@reactant_overlay function Compiler.compile(
f, args; client=nothing, optimize=true, sync=false
)
return f
end

# Enzyme overrides
@reactant_overlay @noinline function Enzyme.autodiff_deferred(
rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs}
) where {FA<:Annotation,A<:Annotation,Nargs}
return overload_autodiff(rmode, f, rt, args...)
end

@reactant_overlay @noinline function Enzyme.autodiff(
rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs}
) where {FA<:Annotation,A<:Annotation,Nargs}
return overload_autodiff(rmode, f, rt, args...)
end
2 changes: 2 additions & 0 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ include("ControlFlow.jl")
include("Tracing.jl")
include("Compiler.jl")

include("Overlay.jl")

function Enzyme.make_zero(
::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false)
)::RT where {copy_if_inactive,RT<:RArray}
Expand Down

0 comments on commit a98c02b

Please sign in to comment.