Skip to content

Commit

Permalink
feedbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
glou-nes committed Dec 16, 2024
1 parent 3289169 commit deb935f
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 25 deletions.
12 changes: 6 additions & 6 deletions ext/ReactantSpecialFunctionsExt.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
module ReactantSpecialFunctionsExt
using SpecialFunctions
using Reactant: Ops, Reactant, ReactantFloat, TracedRNumber
using Reactant.TracedUtils: promote_to
using Reactant.TracedRNumberOverrides: float

for fn in [:gamma, :loggamma, :digamma, :erf, :erfc]
for fn in [:gamma, :loggamma, :digamma, :trigamma, :erf, :erfc]
@eval(function SpecialFunctions.$fn(x::TracedRNumber{<:Number})
return $fn(promote_to(TracedRNumber{Float64}, x))
return $fn(float(x))
end)
end

Expand Down Expand Up @@ -41,7 +41,7 @@ end

# SpecialFunctions.invdigamma

function SpecialFunctions.trigamma(x::TracedRNumber{T}) where {T}
function SpecialFunctions.trigamma(x::TracedRNumber{T}) where {T<: ReactantFloat}
return Ops.polygamma(Ops.constant(T(1)), x)
end

Expand All @@ -52,7 +52,7 @@ function SpecialFunctions.polygamma(
end

function SpecialFunctions.polygamma(n::TracedRNumber{T}, x::TracedRNumber{T}) where {T}
x = promote_to(TracedRNumber{Float64}, x)
x = promote_to(TracedRNumber{T}, x)
return polygamma(n, x)
end

Expand Down Expand Up @@ -101,7 +101,7 @@ function SpecialFunctions.logerf(x::TracedRNumber{T}, y::TracedRNumber{T}) where
end

function SpecialFunctions.erfcx(x::TracedRNumber{T}) where {T}
return exp(x^2) * erfc(x)
return exp(float(x^2)) * erfc(x)
end

function SpecialFunctions.logerfc(x::TracedRNumber{T}) where {T}
Expand Down
41 changes: 22 additions & 19 deletions test/integration/special_functions.jl
Original file line number Diff line number Diff line change
@@ -1,28 +1,31 @@
using SpecialFunctions, Reactant
@testset "Generic" begin
values = [0.5, 0.6]
for (op, n_args) in [
(:gamma, 1),
(:loggamma, 1),
(:loggamma1p, 1),
(:digamma, 1),
(:trigamma, 1),
(:beta, 2),
(:logbeta, 2),
(:erf, 1),
(:erf, 2),
(:erfc, 1),
(:logerf, 2),
(:erfcx, 1),
(:logerfc, 1),
(:logerfcx, 1),
]
x = values[1:n_args]
@testset "$op" for (op, n_args) in [
(:gamma, 1),
(:loggamma, 1),
(:digamma, 1),
(:trigamma, 1),
(:beta, 2),
(:logbeta, 2),
(:erf, 1),
(:erf, 2),
(:erfc, 1),
(:logerf, 2),
(:erfcx, 1),
(:logerfc, 1),
(:logerfcx, 1),
]
for data in ([0.5, 0.6], [2, 4])
x = data[1:n_args]
@eval @test @jit(SpecialFunctions.$op(ConcreteRNumber.($x)...))
SpecialFunctions.$op($x...)
end
end

@testset "loggamma1p" begin
@test SpecialFunctions.loggamma1p(0.5)
@jit SpecialFunctions.loggamma1p(ConcreteRNumber(0.5))
end

@testset "loggammadiv" begin
@test SpecialFunctions.loggammadiv(150, 20)
@jit SpecialFunctions.loggammadiv(ConcreteRNumber(150), ConcreteRNumber(20))
Expand Down

0 comments on commit deb935f

Please sign in to comment.