From a93e79551db1bbdd3e6b55446118307829dc1647 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 26 Jul 2024 15:20:55 -0700 Subject: [PATCH] feat: support `var`. --- src/overloads.jl | 48 +++++++++++++++++++++++++++++++++++++----------- 1 file changed, 37 insertions(+), 11 deletions(-) diff --git a/src/overloads.jl b/src/overloads.jl index 5d85c638e..69eaef3cb 100644 --- a/src/overloads.jl +++ b/src/overloads.jl @@ -59,24 +59,43 @@ for (jlop, hloop, RT) in ( ) end - function $jlop(lhs::TracedRArray{ElType,Shape,N}, rhs) where {ElType,Shape,N} - rhs = promote_to(lhs, rhs) - return TracedRArray{$RT,Shape,N}( + function $jlop( + lhs::TracedRArray{ElType,(),0}, rhs::TracedRArray{ElType,(),0} + ) where {ElType} + return TracedRArray{$RT,(),0}( (), MLIR.IR.result( MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1 ), ) end + end - function $jlop(lhs, rhs::TracedRArray{ElType,Shape,N}) where {ElType,Shape,N} - lhs = promote_to(rhs, lhs) - return TracedRArray{$RT,Shape,N}( - (), - MLIR.IR.result( - MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1 - ), - ) + for otherType in (Number, Any, TracedRArray{S,(),0} where {S}) + @eval begin + function $jlop( + lhs::TracedRArray{ElType,Shape,N}, rhs::$otherType + ) where {ElType,Shape,N} + rhs = promote_to(lhs, rhs) + return TracedRArray{$RT,Shape,N}( + (), + MLIR.IR.result( + MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1 + ), + ) + end + + function $jlop( + lhs::$otherType, rhs::TracedRArray{ElType,Shape,N} + ) where {ElType,Shape,N} + lhs = promote_to(rhs, lhs) + return TracedRArray{$RT,Shape,N}( + (), + MLIR.IR.result( + MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1 + ), + ) + end end end end @@ -203,6 +222,13 @@ function Statistics.mean(A::TracedRArray{T,Shape,N}; dims=:) where {T,Shape,N} denom = dims isa Colon ? length(A) : prod(Base.Fix1(size, A), dims) return mapreduce(identity, +, A; dims) / denom end +function Statistics.var( + A::TracedRArray{T,Shape,N}; dims=:, mean=nothing, corrected=true +) where {T,Shape,N} + mean === nothing && (mean = Statistics.mean(A; dims)) + denom = (dims isa Colon ? length(A) : prod(Base.Fix1(size, A), dims)) - corrected + return mapreduce(abs2, +, A .- mean; dims) / denom +end function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs} fnwrap, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = make_mlir_fn(