diff --git a/src/Overlay.jl b/src/Overlay.jl index 6d4752ac..18da97f9 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -3,6 +3,15 @@ # correctly. Once that (https://github.com/timholy/Revise.jl/issues/646) is resolved # we should move all the reactant_overrides to relevant files. +# Helper Function to determine if we are inside the ReactantInterpreter +""" + within_reactant_interpreter() + +Returns `true` if we are currently inside the ReactantInterpreter. +""" +@noinline within_reactant_interpreter() = false +@reactant_overlay @noinline within_reactant_interpreter() = true + # Compiling within a compile should return simply the original function @reactant_overlay function Compiler.compile( f, args; client=nothing, optimize=true, sync=false @@ -10,7 +19,7 @@ return f end -# Enzyme overrides +# Enzyme.jl overlays @reactant_overlay @noinline function Enzyme.autodiff_deferred( rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs} ) where {FA<:Annotation,A<:Annotation,Nargs} @@ -22,3 +31,8 @@ end ) where {FA<:Annotation,A<:Annotation,Nargs} return overload_autodiff(rmode, f, rt, args...) end + +# Random.jl overlays +@reactant_overlay @noinline function Random.default_rng() + return call_with_reactant(TracedRandom.default_rng) +end diff --git a/src/stdlibs/Random.jl b/src/stdlibs/Random.jl index eb652b16..d5bda428 100644 --- a/src/stdlibs/Random.jl +++ b/src/stdlibs/Random.jl @@ -12,21 +12,18 @@ using ..Reactant: AnyTracedRArray, Reactant, TracedUtils, - @reactant_override, Ops, ConcreteRArray using Random: Random function Random.seed!(rng::TracedRNG, seed::Number) seed = reinterpret(UInt64, Random.hash_seed(seed)) - return Random.seed!(rng, ConcreteRArray(seed[1:length(rng.seed)])) -end - -@reactant_override @noinline function Random.seed!(rng::TracedRNG, seed::Number) - seed = reinterpret(UInt64, Random.hash_seed(seed)) - return Random.seed!( - rng, TracedUtils.promote_to(TracedRArray{UInt64,1}, seed[1:length(rng.seed)]) - ) + seed = if Reactant.within_reactant_interpreter() + TracedUtils.promote_to(TracedRArray{UInt64,1}, seed[1:length(rng.seed)]) + else + ConcreteRArray(seed[1:length(rng.seed)]) + end + return Random.seed!(rng, seed) end function Random.seed!( @@ -41,14 +38,11 @@ make_seed() = rand(Random.RandomDevice(), UInt64, 2) TracedRNG() = TracedRNG(ConcreteRArray(make_seed())) TracedRNG(seed::ConcreteRArray{UInt64,1}) = TracedRNG(seed, "DEFAULT") -default_rng() = TracedRNG() -function default_rng_inside_interpreter() +function default_rng() + Reactant.within_reactant_interpreter() || return TracedRNG() return TracedRNG(TracedUtils.promote_to(TracedRArray{UInt64,1}, make_seed()), "DEFAULT") end -@reactant_override @noinline Random.default_rng() = default_rng_inside_interpreter() -@reactant_override @noinline default_rng() = default_rng_inside_interpreter() - function Random.rand!(rng::TracedRNG, A::AnyTracedRArray{T,N}) where {T,N} length(A) == 0 && return A res = Ops.rng_bit_generator(T, rng.seed, [size(A)...]; rng.algorithm)