Skip to content

Commit

Permalink
feat: support mean.
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jul 26, 2024
1 parent dda3a9c commit 454b3a8
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 6 deletions.
9 changes: 3 additions & 6 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
name = "Reactant"
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
authors = [
"William Moses <[email protected]>",
"Valentin Churavy <[email protected]>",
"Sergio Sánchez Ramírez <[email protected]>",
"Paul Berg <[email protected]>",
]
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>", "Sergio Sánchez Ramírez <[email protected]>", "Paul Berg <[email protected]>"]
version = "0.1.8"

[deps]
Expand All @@ -15,6 +10,7 @@ Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[weakdeps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand All @@ -35,6 +31,7 @@ NNlib = "0.9"
PackageExtensionCompat = "1"
Preferences = "1.4"
Reactant_jll = "0.0.12"
Statistics = "1.11.1"
julia = "1.9"

[extras]
Expand Down
1 change: 1 addition & 0 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module Reactant

using PackageExtensionCompat
using Statistics: Statistics

include("mlir/MLIR.jl")
include("XLA.jl")
Expand Down
16 changes: 16 additions & 0 deletions src/overloads.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,17 @@ for (jlop, hloop, RT) in (
)
end

# Base defines ::AbstractArray / ::Number, so we need this to avoid ambiguity
function $jlop(lhs::TracedRArray{ElType,Shape,0}, rhs::Number) where {ElType,Shape}
rhs = promote_to(lhs, rhs)
return TracedRArray{$RT,Shape,0}(
(),
MLIR.IR.result(
MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
),
)
end

function $jlop(lhs, rhs::TracedRArray{ElType,Shape,0}) where {ElType,Shape}
lhs = promote_to(rhs, lhs)
return TracedRArray{$RT,Shape,0}(
Expand Down Expand Up @@ -188,6 +199,11 @@ for (jlop, hloop) in (
end
end

function Statistics.mean(A::TracedRArray{T,Shape,N}; dims=:) where {T,Shape,N}
denom = dims isa Colon ? length(A) : prod(Base.Fix1(size, A), dims)
return mapreduce(identity, +, A; dims) / denom
end

function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}
fnwrap, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = make_mlir_fn(
f, args, (), string(f) * "_broadcast_scalar", false; toscalar=true
Expand Down

0 comments on commit 454b3a8

Please sign in to comment.