diff --git a/Project.toml b/Project.toml
index e0d6f99a1..590ac2dce 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,11 +1,6 @@
 name = "Reactant"
 uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
-authors = [
-    "William Moses ",
-    "Valentin Churavy ",
-    "Sergio Sánchez Ramírez ",
-    "Paul Berg ",
-]
+authors = ["William Moses ", "Valentin Churavy ", "Sergio Sánchez Ramírez ", "Paul Berg "]
 version = "0.1.8"
 
 [deps]
@@ -20,11 +15,13 @@ Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0"
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
+Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 
 [extensions]
 ReactantAdaptExt = "Adapt"
 ReactantArrayInterfaceExt = "ArrayInterface"
 ReactantNNlibExt = "NNlib"
+ReactantStatisticsExt = "Statistics"
 
 [compat]
 Adapt = "4"
@@ -35,6 +32,7 @@ NNlib = "0.9"
 PackageExtensionCompat = "1"
 Preferences = "1.4"
 Reactant_jll = "0.0.14"
+Statistics = "1.9"
 julia = "1.9"
 
 [extras]
diff --git a/ext/ReactantNNlibExt.jl b/ext/ReactantNNlibExt.jl
index 7f9351416..23763e4c4 100644
--- a/ext/ReactantNNlibExt.jl
+++ b/ext/ReactantNNlibExt.jl
@@ -3,23 +3,25 @@ module ReactantNNlibExt
 using NNlib
 using Reactant
 
-for (jlop, hloop) in ((:(NNlib.tanh), :tanh), (:(NNlib.tanh_fast), :tanh))
-    @eval begin
-        if $jlop != Base.tanh && $jlop != Base.FastMath.tanh_fast
-            function Reactant.elem_apply(
-                ::typeof($jlop), lhs::Reactant.TracedRArray{ElType,Shape,N}
-            ) where {ElType,Shape,N}
-                return Reactant.TracedRArray{ElType,Shape,N}(
-                    (),
-                    Reactant.MLIR.IR.result(
-                        Reactant.MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data), 1
-                    ),
-                )
-            end
-        end
+for (jlop, hloop) in (
+    (:(NNlib.tanh_fast), :tanh),
+    (:(NNlib.sigmoid_fast), :logistic),
+    (:(NNlib.sigmoid), :logistic),
+)
+    @eval function $(jlop)(x::Reactant.TracedRArray{T,(),0}) where {T}
+        return Reactant.TracedRArray{T,(),0}(
+            (),
+            Reactant.MLIR.IR.result(
+                Reactant.MLIR.Dialects.stablehlo.$(hloop)(x.mlir_data), 1
+            ),
+        )
     end
 end
 
+NNlib.relu(x::Reactant.TracedRArray{T,(),0}) where {T} = max(x, zero(T))
+
+NNlib.gelu(x::Reactant.TracedRArray{T,(),0}) where {T} = x * sigmoid(T(1.702) * x)
+
 # TODO handle non finite cases
 function NNlib.softmax!(
     out::Reactant.TracedRArray{T,Shape,N}, x::AbstractArray; dims=1
diff --git a/ext/ReactantStatisticsExt.jl b/ext/ReactantStatisticsExt.jl
new file mode 100644
index 000000000..2dd813ae8
--- /dev/null
+++ b/ext/ReactantStatisticsExt.jl
@@ -0,0 +1,19 @@
+module ReactantStatisticsExt
+
+using Reactant: TracedRArray
+using Statistics: Statistics
+
+function Statistics.mean(A::TracedRArray{T,Shape,N}; dims=:) where {T,Shape,N}
+    denom = dims isa Colon ? length(A) : prod(Base.Fix1(size, A), dims)
+    return mapreduce(identity, +, A; dims) / denom
+end
+
+function Statistics.var(
+    A::TracedRArray{T,Shape,N}; dims=:, mean=nothing, corrected=true
+) where {T,Shape,N}
+    mean === nothing && (mean = Statistics.mean(A; dims))
+    denom = (dims isa Colon ? length(A) : prod(Base.Fix1(size, A), dims)) - corrected
+    return mapreduce(abs2, +, A .- mean; dims) / denom
+end
+
+end
diff --git a/src/Reactant.jl b/src/Reactant.jl
index dfca13907..7c70b8cb8 100644
--- a/src/Reactant.jl
+++ b/src/Reactant.jl
@@ -56,13 +56,13 @@ function Base.isapprox(x::ConcreteRArray{ElType,(),0}, y; kwargs...) where {ElTy
 end
 
 function Base.isapprox(x, y::ConcreteRArray{ElType,(),0}; kwargs...) where {ElType}
-    return Base.isapprox(to_float(x), y; kwargs...)
+    return Base.isapprox(x, to_float(y); kwargs...)
 end
 
 function Base.isapprox(
     x::ConcreteRArray{ElType,(),0}, y::ConcreteRArray{ElType2,(),0}; kwargs...
 ) where {ElType,ElType2}
-    return Base.isapprox(to_float(x), y; kwargs...)
+    return Base.isapprox(to_float(x), to_float(y); kwargs...)
 end
 
 function Base.print_array(io::IO, X::ConcreteRArray)
diff --git a/src/overloads.jl b/src/overloads.jl
index 43b73f669..7f98627da 100644
--- a/src/overloads.jl
+++ b/src/overloads.jl
@@ -59,28 +59,49 @@ for (jlop, hloop, RT) in (
             )
         end
 
-        function $jlop(lhs::TracedRArray{ElType,Shape,N}, rhs) where {ElType,Shape,N}
-            rhs = promote_to(lhs, rhs)
-            return TracedRArray{$RT,Shape,N}(
+        function $jlop(
+            lhs::TracedRArray{ElType,(),0}, rhs::TracedRArray{ElType,(),0}
+        ) where {ElType}
+            return TracedRArray{$RT,(),0}(
                 (),
                 MLIR.IR.result(
                     MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
                 ),
             )
         end
+    end
 
-        function $jlop(lhs, rhs::TracedRArray{ElType,Shape,N}) where {ElType,Shape,N}
-            lhs = promote_to(rhs, lhs)
-            return TracedRArray{$RT,Shape,N}(
-                (),
-                MLIR.IR.result(
-                    MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
-                ),
-            )
+    for otherType in (Number, Any, TracedRArray{S,(),0} where {S})
+        @eval begin
+            function $jlop(
+                lhs::TracedRArray{ElType,Shape,N}, rhs::$otherType
+            ) where {ElType,Shape,N}
+                rhs = promote_to(lhs, rhs)
+                return TracedRArray{$RT,Shape,N}(
+                    (),
+                    MLIR.IR.result(
+                        MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
+                    ),
+                )
+            end
+
+            function $jlop(
+                lhs::$otherType, rhs::TracedRArray{ElType,Shape,N}
+            ) where {ElType,Shape,N}
+                lhs = promote_to(rhs, lhs)
+                return TracedRArray{$RT,Shape,N}(
+                    (),
+                    MLIR.IR.result(
+                        MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
+                    ),
+                )
+            end
         end
     end
 end
 
+Base.abs2(x::Reactant.TracedRArray{T,(),0}) where {T} = x * conj(x)
+
 function Base.literal_pow(
     ::Base.RefValue{typeof(^)}, x::Reactant.TracedRArray{T,(),0}, ::Base.RefValue{Val{P}}
 ) where {T,P}
@@ -137,9 +158,41 @@ for (jlop, hloop, RT) in (
                 ),
             )
         end
+
+        # Base defines ::AbstractArray / ::Number, so we need this to avoid ambiguity
+        function $jlop(lhs::TracedRArray{ElType,Shape,0}, rhs::Number) where {ElType,Shape}
+            rhs = promote_to(lhs, rhs)
+            return TracedRArray{$RT,Shape,0}(
+                (),
+                MLIR.IR.result(
+                    MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
+                ),
+            )
+        end
+
+        function $jlop(lhs::Number, rhs::TracedRArray{ElType,Shape,0}) where {ElType,Shape}
+            lhs = promote_to(rhs, lhs)
+            return TracedRArray{$RT,Shape,0}(
+                (),
+                MLIR.IR.result(
+                    MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
+                ),
+            )
+        end
     end
 end
 
+function Base.ifelse(
+    pred::TracedRArray{Bool,(),0}, x::TracedRArray{T1,(),0}, y::TracedRArray{T2,(),0}
+) where {T1,T2}
+    return TracedRArray{promote_type(T1, T2),(),0}(
+        (),
+        MLIR.IR.result(
+            MLIR.Dialects.stablehlo.select(pred.mlir_data, x.mlir_data, y.mlir_data), 1
+        ),
+    )
+end
+
 function Base.:*(
     lhs::TracedRArray{ElType,Shape,2}, rhs::TracedRArray{ElType,Shape2,2}
 ) where {ElType,Shape,Shape2}
diff --git a/test/Project.toml b/test/Project.toml
index 10bc878e9..6a5bf9013 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -5,6 +5,7 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
 Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
 MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
+NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
diff --git a/test/basic.jl b/test/basic.jl
index 07870de97..20ef425d5 100644
--- a/test/basic.jl
+++ b/test/basic.jl
@@ -1,6 +1,7 @@
 using Reactant
 using Test
 using Enzyme
+using Statistics
 
 # Reactant.set_default_backend("gpu")
 
@@ -152,3 +153,38 @@
 
     @test contains(res_repr, "stablehlo.dot_general")
 end
+
+@testset "Statistics: `mean` & `var`" begin
+    x = randn(2, 3, 4)
+    x_ca = Reactant.ConcreteRArray(x)
+
+    mean_fn1(x) = mean(x)
+    mean_fn2(x) = mean(x; dims=1)
+    mean_fn3(x) = mean(x; dims=(1, 2))
+    mean_fn4(x) = mean(x; dims=(1, 3))
+
+    mean_fn1_compiled = Reactant.compile(mean_fn1, (x_ca,))
+    mean_fn2_compiled = Reactant.compile(mean_fn2, (x_ca,))
+    mean_fn3_compiled = Reactant.compile(mean_fn3, (x_ca,))
+    mean_fn4_compiled = Reactant.compile(mean_fn4, (x_ca,))
+
+    @test mean_fn1(x) ≈ mean_fn1_compiled(x_ca)
+    @test mean_fn2(x) ≈ mean_fn2_compiled(x_ca)
+    @test mean_fn3(x) ≈ mean_fn3_compiled(x_ca)
+    @test mean_fn4(x) ≈ mean_fn4_compiled(x_ca)
+
+    var_fn1(x) = var(x)
+    var_fn2(x) = var(x; dims=1)
+    var_fn3(x) = var(x; dims=(1, 2), corrected=false)
+    var_fn4(x) = var(x; dims=(1, 3), corrected=false)
+
+    var_fn1_compiled = Reactant.compile(var_fn1, (x_ca,))
+    var_fn2_compiled = Reactant.compile(var_fn2, (x_ca,))
+    var_fn3_compiled = Reactant.compile(var_fn3, (x_ca,))
+    var_fn4_compiled = Reactant.compile(var_fn4, (x_ca,))
+
+    @test var_fn1(x) ≈ var_fn1_compiled(x_ca)
+    @test var_fn2(x) ≈ var_fn2_compiled(x_ca)
+    @test var_fn3(x) ≈ var_fn3_compiled(x_ca)
+    @test var_fn4(x) ≈ var_fn4_compiled(x_ca)
+end
diff --git a/test/bcast.jl b/test/bcast.jl
index 9d05200ad..f4942e2ec 100644
--- a/test/bcast.jl
+++ b/test/bcast.jl
@@ -1,6 +1,6 @@
 using Reactant
 
-
+using Enzyme, NNlib
 using Reactant.MLIR
 
 @noinline function no(@nospecialize(x))
@@ -56,3 +56,34 @@ function test()
     end
 end
 test()
+
+@testset "Activation Functions" begin
+    sumabs2(f, x) = sum(abs2, f.(x))
+
+    function ∇sumabs2(f, x)
+        dx = Enzyme.make_zero(x)
+        Enzyme.autodiff(Reverse, sumabs2, Active, Const(f), Duplicated(x, dx))
+        return dx
+    end
+
+    x_act = randn(Float32, 10, 10)
+    x_act_ca = Reactant.ConcreteRArray(x_act)
+
+    @testset "Activation: $act" for act in (
+        identity, relu, sigmoid, tanh, tanh_fast, sigmoid_fast, gelu, abs2
+    )
+        f_compile = Reactant.compile(sumabs2, (act, x_act))
+
+        y_simple = sumabs2(act, x_act)
+        y_compile = f_compile(act, x_act_ca)
+
+        ∂x_enz = Enzyme.make_zero(x_act)
+        Enzyme.autodiff(Reverse, sumabs2, Active, Const(act), Duplicated(x_act, ∂x_enz))
+
+        ∇sumabs2_compiled = Reactant.compile(∇sumabs2, (act, x_act_ca))
+
+        ∂x_compile = ∇sumabs2_compiled(act, x_act_ca)
+
+        @test y_simple ≈ y_compile
+    end
+end
diff --git a/test/nn_lux.jl b/test/nn_lux.jl
index e9bf1b204..3521efc38 100644
--- a/test/nn_lux.jl
+++ b/test/nn_lux.jl
@@ -9,6 +9,7 @@ truth = [xor(col[1] > 0.5, col[2] > 0.5) for col in eachcol(noisy)]   # 1000-ele
 # Define our model, a multi-layer perceptron with one hidden layer of size 3:
 model = Lux.Chain(
     Lux.Dense(2 => 3, tanh),   # activation function inside layer
+    Lux.BatchNorm(3, gelu),
     Lux.Dense(3 => 2),
     softmax,
 )
@@ -17,8 +18,7 @@ ps, st = Lux.setup(Xoshiro(123), model)
 using BenchmarkTools
 
 origout, _ = model(noisy, ps, st)
-@show origout[3]
-@btime model($noisy, $ps, $st)  # 52.731 μs (10 allocations: 32.03 KiB)
+@btime model($noisy, $ps, $st)  # 68.444 μs (46 allocations: 45.88 KiB)
 
 cmodel = Reactant.make_tracer(IdDict(), model, (), Reactant.ArrayToConcrete)
 cps = Reactant.make_tracer(IdDict(), ps, (), Reactant.ArrayToConcrete)
@@ -31,8 +31,9 @@ f = Reactant.compile((a, b, c, d) -> first(a(b, c, d)), (cmodel, cnoisy, cps, cs
 # # @show @code_typed f(cmodel,cnoisy)
 # # @show @code_llvm f(cmodel,cnoisy)
 comp = f(cmodel, cnoisy, cps, cst)
-@show comp[3]
-@btime f($cmodel, $cnoisy, $cps, $cst)  # 4.430 μs (5 allocations: 160 bytes)
+@btime f($cmodel, $cnoisy, $cps, $cst)  # 21.790 μs (6 allocations: 224 bytes)
+
+@test comp ≈ origout atol = 1e-5 rtol = 1e-2
 
 # To train the model, we use batches of 64 samples, and one-hot encoding:
 
@@ -81,6 +82,8 @@ compiled_gradient = Reactant.compile(
     gradient_loss_function, (cmodel, cnoisy, ctarget, cps, cst)
 )
 
+@test length(compiled_gradient(cmodel, cnoisy, ctarget, cps, cst)) == 2
+
 # # Training loop, using the whole data set 1000 times:
 # losses = []
 # for epoch in 1:1_000
diff --git a/test/runtests.jl b/test/runtests.jl
index 77ecc0637..bf98d3a10 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -46,4 +46,7 @@ include("nn.jl")
 include("struct.jl")
 include("closure.jl")
 include("compile.jl")
-include("nn_lux.jl")
+
+if VERSION ≥ v"1.10-" # Lux isn't supported on 1.9
+    include("nn_lux.jl")
+end
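
Usage sketch (not part of the diff): once both Reactant and Statistics are loaded, the new ReactantStatisticsExt extension activates, so mean/var over a Reactant array can be traced and compiled the same way the added tests in test/basic.jl do. The snippet below is a minimal illustration built only from entry points exercised in those tests (Reactant.ConcreteRArray, Reactant.compile); the names x_ra and col_mean are placeholders, and the final check is expected, not guaranteed, to hold up to floating-point error.

using Reactant, Statistics

x = randn(2, 3, 4)
x_ra = Reactant.ConcreteRArray(x)      # move the host array into a Reactant array

col_mean(x) = mean(x; dims=1)          # dispatches to Statistics.mean from ReactantStatisticsExt when traced

col_mean_compiled = Reactant.compile(col_mean, (x_ra,))
col_mean_compiled(x_ra) ≈ col_mean(x)  # compiled result should agree with plain Julia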