Skip to content

Commit

Permalink
Format code
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing committed May 28, 2024
1 parent 7e54418 commit 7e31d37
Show file tree
Hide file tree
Showing 34 changed files with 2,270 additions and 1,335 deletions.
24 changes: 15 additions & 9 deletions ext/ReactantNNlibExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,33 @@ using NNlib
using Reactant

function __init__()
for (jlop, hloop) in ((:(NNlib.tanh), :tanh),(:(NNlib.tanh_fast), :tanh),)
@eval begin
if $jlop != Base.tanh && $jlop != Base.FastMath.tanh_fast
function Reactant.elem_apply(::typeof($jlop), lhs::Reactant.TracedRArray{ElType,Shape,N}) where {ElType,Shape,N}
return Reactant.TracedRArray{ElType,Shape,N}((), Reactant.MLIR.IR.result(Reactant.MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data), 1))
for (jlop, hloop) in ((:(NNlib.tanh), :tanh), (:(NNlib.tanh_fast), :tanh))
@eval begin
if $jlop != Base.tanh && $jlop != Base.FastMath.tanh_fast
function Reactant.elem_apply(::typeof($jlop),
lhs::Reactant.TracedRArray{ElType,Shape,N}) where {ElType,
Shape,
N}
return Reactant.TracedRArray{ElType,Shape,N}((),
Reactant.MLIR.IR.result(Reactant.MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data),
1))
end
end
end
end
end
end

# TODO handle non finite cases
function NNlib.softmax!(out::Reactant.TracedRArray{T, Shape, N}, x::AbstractArray; dims = 1) where {T, Shape, N}
function NNlib.softmax!(out::Reactant.TracedRArray{T,Shape,N}, x::AbstractArray;
dims=1) where {T,Shape,N}
max_ = NNlib.fast_maximum(x; dims)
#if all(isfinite, max_)
@fastmath out .= exp.(x .- max_)
@fastmath out .= exp.(x .- max_)
#else
# _zero, _one, _inf = T(0), T(1), T(Inf)
# @fastmath @. out = ifelse(isequal(max_,_inf), ifelse(isequal(x,_inf), _one, _zero), exp(x - max_))
#end
tmp = dims isa Colon ? sum(out) : sum!(max_, out)
out ./= tmp
return out ./= tmp
end
end
Loading

0 comments on commit 7e31d37

Please sign in to comment.