From 454b3a86bf8b8a3b895d5fab69144cf48bed06ca Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 26 Jul 2024 15:09:14 -0700 Subject: [PATCH] feat: support `mean`. --- Project.toml | 9 +++------ src/Reactant.jl | 1 + src/overloads.jl | 16 ++++++++++++++++ 3 files changed, 20 insertions(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index ab4fad078..fadbd9f31 100644 --- a/Project.toml +++ b/Project.toml @@ -1,11 +1,6 @@ name = "Reactant" uuid = "3c362404-f566-11ee-1572-e11a4b42c853" -authors = [ - "William Moses ", - "Valentin Churavy ", - "Sergio Sánchez Ramírez ", - "Paul Berg ", -] +authors = ["William Moses ", "Valentin Churavy ", "Sergio Sánchez Ramírez ", "Paul Berg "] version = "0.1.8" [deps] @@ -15,6 +10,7 @@ Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [weakdeps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -35,6 +31,7 @@ NNlib = "0.9" PackageExtensionCompat = "1" Preferences = "1.4" Reactant_jll = "0.0.12" +Statistics = "1.11.1" julia = "1.9" [extras] diff --git a/src/Reactant.jl b/src/Reactant.jl index dfca13907..73d4fa39e 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -1,6 +1,7 @@ module Reactant using PackageExtensionCompat +using Statistics: Statistics include("mlir/MLIR.jl") include("XLA.jl") diff --git a/src/overloads.jl b/src/overloads.jl index 593d6e22b..5d85c638e 100644 --- a/src/overloads.jl +++ b/src/overloads.jl @@ -130,6 +130,17 @@ for (jlop, hloop, RT) in ( ) end + # Base defines ::AbstractArray / ::Number, so we need this to avoid ambiguity + function $jlop(lhs::TracedRArray{ElType,Shape,0}, rhs::Number) where {ElType,Shape} + rhs = promote_to(lhs, rhs) + return TracedRArray{$RT,Shape,0}( + (), + MLIR.IR.result( + MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1 + ), + ) + end + function $jlop(lhs, rhs::TracedRArray{ElType,Shape,0}) where {ElType,Shape} lhs = promote_to(rhs, lhs) return TracedRArray{$RT,Shape,0}( @@ -188,6 +199,11 @@ for (jlop, hloop) in ( end end +function Statistics.mean(A::TracedRArray{T,Shape,N}; dims=:) where {T,Shape,N} + denom = dims isa Colon ? length(A) : prod(Base.Fix1(size, A), dims) + return mapreduce(identity, +, A; dims) / denom +end + function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs} fnwrap, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = make_mlir_fn( f, args, (), string(f) * "_broadcast_scalar", false; toscalar=true