Skip to content

Commit

Permalink
refactor: rework how the overlays are implemented
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 16, 2024
1 parent 30d2732 commit 67d05aa
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 15 deletions.
16 changes: 15 additions & 1 deletion src/Overlay.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,23 @@
# correctly. Once that (https://github.com/timholy/Revise.jl/issues/646) is resolved
# we should move all the reactant_overrides to relevant files.

# Helper Function to determine if we are inside the ReactantInterpreter
"""
within_reactant_interpreter()
Returns `true` if we are currently inside the ReactantInterpreter.
"""
@noinline within_reactant_interpreter() = false
@reactant_overlay @noinline within_reactant_interpreter() = true

# Compiling within a compile should return simply the original function
@reactant_overlay function Compiler.compile(
f, args; client=nothing, optimize=true, sync=false
)
return f
end

# Enzyme overrides
# Enzyme.jl overlays
@reactant_overlay @noinline function Enzyme.autodiff_deferred(
rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs}
) where {FA<:Annotation,A<:Annotation,Nargs}
Expand All @@ -22,3 +31,8 @@ end
) where {FA<:Annotation,A<:Annotation,Nargs}
return overload_autodiff(rmode, f, rt, args...)
end

# Random.jl overlays
@reactant_overlay @noinline function Random.default_rng()
return call_with_reactant(TracedRandom.default_rng)
end
22 changes: 8 additions & 14 deletions src/stdlibs/Random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,18 @@ using ..Reactant:
AnyTracedRArray,
Reactant,
TracedUtils,
@reactant_override,
Ops,
ConcreteRArray
using Random: Random

function Random.seed!(rng::TracedRNG, seed::Number)
seed = reinterpret(UInt64, Random.hash_seed(seed))
return Random.seed!(rng, ConcreteRArray(seed[1:length(rng.seed)]))
end

@reactant_override @noinline function Random.seed!(rng::TracedRNG, seed::Number)
seed = reinterpret(UInt64, Random.hash_seed(seed))
return Random.seed!(
rng, TracedUtils.promote_to(TracedRArray{UInt64,1}, seed[1:length(rng.seed)])
)
seed = if Reactant.within_reactant_interpreter()
TracedUtils.promote_to(TracedRArray{UInt64,1}, seed[1:length(rng.seed)])
else
ConcreteRArray(seed[1:length(rng.seed)])
end
return Random.seed!(rng, seed)
end

function Random.seed!(
Expand All @@ -41,14 +38,11 @@ make_seed() = rand(Random.RandomDevice(), UInt64, 2)
TracedRNG() = TracedRNG(ConcreteRArray(make_seed()))
TracedRNG(seed::ConcreteRArray{UInt64,1}) = TracedRNG(seed, "DEFAULT")

default_rng() = TracedRNG()
function default_rng_inside_interpreter()
function default_rng()
Reactant.within_reactant_interpreter() || return TracedRNG()
return TracedRNG(TracedUtils.promote_to(TracedRArray{UInt64,1}, make_seed()), "DEFAULT")
end

@reactant_override @noinline Random.default_rng() = default_rng_inside_interpreter()
@reactant_override @noinline default_rng() = default_rng_inside_interpreter()

function Random.rand!(rng::TracedRNG, A::AnyTracedRArray{T,N}) where {T,N}
length(A) == 0 && return A
res = Ops.rng_bit_generator(T, rng.seed, [size(A)...]; rng.algorithm)
Expand Down

0 comments on commit 67d05aa

Please sign in to comment.