From e99089ee97b1ad498d10934080137f686fd20d4c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 15 Dec 2024 08:59:18 +0530 Subject: [PATCH] feat: use the override macro --- src/Interpreter.jl | 12 ------------ src/stdlibs/Random.jl | 9 +++++---- 2 files changed, 5 insertions(+), 16 deletions(-) diff --git a/src/Interpreter.jl b/src/Interpreter.jl index 5e15916e2..72e27c5d8 100644 --- a/src/Interpreter.jl +++ b/src/Interpreter.jl @@ -46,18 +46,6 @@ function set_reactant_abi( return abstract_call(interp, arginfo2::ArgInfo, si, sv, max_methods) end - # ensures we are not generating a constant array in the trace - # https://github.com/EnzymeAD/Reactant.jl/issues/356 - if (f === Random.default_rng || f === default_rng) && length(argtypes) == 1 - arginfo2 = ArgInfo( - fargs isa Nothing ? nothing : Any[:($(default_rng_inside_interpreter))], - Any[Core.Const(default_rng_inside_interpreter)], - ) - return abstract_call_known( - interp, default_rng_inside_interpreter, arginfo2, si, sv, max_methods - ) - end - return Base.@invoke abstract_call_known( interp::AbstractInterpreter, f::Any, diff --git a/src/stdlibs/Random.jl b/src/stdlibs/Random.jl index 090b72bb3..656d61383 100644 --- a/src/stdlibs/Random.jl +++ b/src/stdlibs/Random.jl @@ -27,16 +27,17 @@ TracedRNG(seed::ConcreteRArray{UInt64,1}) = TracedRNG(seed, "DEFAULT") default_rng() = TracedRNG() function default_rng_inside_interpreter() - return TracedRNG(promote_to(TracedRArray{UInt64,1}, make_seed()), "DEFAULT") + return TracedRNG(TracedUtils.promote_to(TracedRArray{UInt64,1}, make_seed()), "DEFAULT") end -# XXX: Currently we get an illegal instruction if we don't call Random.default_rng() +@reactant_override @noinline Random.default_rng() = default_rng_inside_interpreter() +@reactant_override @noinline default_rng() = default_rng_inside_interpreter() function Random.rand!(rng::TracedRNG, A::AnyTracedRArray{T,N}) where {T,N} length(A) == 0 && return A res = Ops.rng_bit_generator(T, rng.seed, [size(A)...]; rng.algorithm) rng.seed = res.output_state - set_mlir_data!(A, res.output.mlir_data) + TracedUtils.set_mlir_data!(A, res.output.mlir_data) return A end @@ -49,7 +50,7 @@ function Random.randn!(rng::TracedRNG, A::AnyTracedRArray{T,N}) where {T,N} ) probit = Ops.erf_inv(scaled_uniform) rand_normal = Ops.multiply(probit, Ops.constant(fill(sqrt(T(2)), size(A)))) - set_mlir_data!(A, rand_normal.mlir_data) + TracedUtils.set_mlir_data!(A, rand_normal.mlir_data) return A end