diff --git a/src/overloads.jl b/src/overloads.jl index 5d85c638e..69eaef3cb 100644 --- a/src/overloads.jl +++ b/src/overloads.jl @@ -59,24 +59,43 @@ for (jlop, hloop, RT) in ( ) end - function $jlop(lhs::TracedRArray{ElType,Shape,N}, rhs) where {ElType,Shape,N} - rhs = promote_to(lhs, rhs) - return TracedRArray{$RT,Shape,N}( + function $jlop( + lhs::TracedRArray{ElType,(),0}, rhs::TracedRArray{ElType,(),0} + ) where {ElType} + return TracedRArray{$RT,(),0}( (), MLIR.IR.result( MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1 ), ) end + end - function $jlop(lhs, rhs::TracedRArray{ElType,Shape,N}) where {ElType,Shape,N} - lhs = promote_to(rhs, lhs) - return TracedRArray{$RT,Shape,N}( - (), - MLIR.IR.result( - MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1 - ), - ) + for otherType in (Number, Any, TracedRArray{S,(),0} where {S}) + @eval begin + function $jlop( + lhs::TracedRArray{ElType,Shape,N}, rhs::$otherType + ) where {ElType,Shape,N} + rhs = promote_to(lhs, rhs) + return TracedRArray{$RT,Shape,N}( + (), + MLIR.IR.result( + MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1 + ), + ) + end + + function $jlop( + lhs::$otherType, rhs::TracedRArray{ElType,Shape,N} + ) where {ElType,Shape,N} + lhs = promote_to(rhs, lhs) + return TracedRArray{$RT,Shape,N}( + (), + MLIR.IR.result( + MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1 + ), + ) + end end end end @@ -203,6 +222,13 @@ function Statistics.mean(A::TracedRArray{T,Shape,N}; dims=:) where {T,Shape,N} denom = dims isa Colon ? length(A) : prod(Base.Fix1(size, A), dims) return mapreduce(identity, +, A; dims) / denom end +function Statistics.var( + A::TracedRArray{T,Shape,N}; dims=:, mean=nothing, corrected=true +) where {T,Shape,N} + mean === nothing && (mean = Statistics.mean(A; dims)) + denom = (dims isa Colon ? length(A) : prod(Base.Fix1(size, A), dims)) - corrected + return mapreduce(abs2, +, A .- mean; dims) / denom +end function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs} fnwrap, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = make_mlir_fn(