Skip to content

Commit

Permalink
feat: support randexp
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 15, 2024
1 parent 1d6c894 commit 1e41b36
Showing 1 changed file with 19 additions and 5 deletions.
24 changes: 19 additions & 5 deletions src/stdlibs/Random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,19 @@ function Random.randn!(rng::TracedRNG, A::AnyTracedRArray{T,N}) where {T,N}
return A
end

for randfun in (:rand, :randn)
function Random.randexp!(rng::TracedRNG, A::AnyTracedRArray{T,N}) where {T,N}
length(A) == 0 && return A
Random.rand!(rng, A)
TracedUtils.set_mlir_data!(
A,
Ops.negate(
Ops.log_plus_one(Ops.negate(TracedUtils.materialize_traced_array(A)))
).mlir_data,
)
return A
end

for randfun in (:rand, :randn, :randexp)
randfun! = Symbol(randfun, :!)
@eval begin
function Random.$(randfun)(rng::TracedRNG, ::Type{T}, dims::Dims) where {T}
Expand Down Expand Up @@ -97,10 +109,12 @@ for randfun in (:rand, :randn)
end

# resolve ambiguities
function Random.randn(rng::TracedRNG, T::Random.BitFloatType)
A = promote_to(TracedRArray{T,0}, fill(T(0)))
Random.randn!(rng, A)
return A[]
for randfun in (:randn, :randexp)
@eval function Random.$(randfun)(rng::TracedRNG, T::Random.BitFloatType)
A = promote_to(TracedRArray{T,0}, fill(T(0)))
Random.randn!(rng, A)
return A[]
end
end

# TODO: At some later point we might want to implement the sampler API as well since it
Expand Down

0 comments on commit 1e41b36

Please sign in to comment.