Skip to content

Commit

Permalink
refactor: move things into a module
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 16, 2024
1 parent 6a1d4bd commit 30d2732
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 4 deletions.
5 changes: 5 additions & 0 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,11 @@ include("TracedRArray.jl")

include("ConcreteRArray.jl")

mutable struct TracedRNG <: Random.AbstractRNG
seed::Union{ConcreteRArray{UInt64,1},TracedRArray{UInt64,1}}
const algorithm::String
end

# StdLib Overloads
include("stdlibs/LinearAlgebra.jl")
include("stdlibs/Random.jl")
Expand Down
20 changes: 16 additions & 4 deletions src/stdlibs/Random.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,21 @@
module TracedRandom

# Implementation based on the following:
# 1. https://github.com/JuliaGPU/CUDA.jl/blob/master/src/random.jl
# 2. https://github.com/JuliaRandom/Random123.jl/blob/master/src/common.jl#L125

mutable struct TracedRNG <: Random.AbstractRNG
seed::Union{ConcreteRArray{UInt64,1},TracedRArray{UInt64,1}}
const algorithm::String
end
using ..Reactant:
Reactant,
TracedRArray,
TracedRNumber,
TracedRNG,
AnyTracedRArray,
Reactant,
TracedUtils,
@reactant_override,
Ops,
ConcreteRArray
using Random: Random

function Random.seed!(rng::TracedRNG, seed::Number)
seed = reinterpret(UInt64, Random.hash_seed(seed))
Expand Down Expand Up @@ -128,3 +138,5 @@ end
# confirm that the dynamic_update_slice calls are optimized away into a single
# `stablehlo.rng_bit_generator` call -- confirm that this should be the case based on
# how the seeding should work?

end

0 comments on commit 30d2732

Please sign in to comment.