Skip to content

Commit

Permalink
feat: support var.
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jul 26, 2024
1 parent 454b3a8 commit a93e795
Showing 1 changed file with 37 additions and 11 deletions.
48 changes: 37 additions & 11 deletions src/overloads.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,24 +59,43 @@ for (jlop, hloop, RT) in (
)
end

function $jlop(lhs::TracedRArray{ElType,Shape,N}, rhs) where {ElType,Shape,N}
rhs = promote_to(lhs, rhs)
return TracedRArray{$RT,Shape,N}(
function $jlop(
lhs::TracedRArray{ElType,(),0}, rhs::TracedRArray{ElType,(),0}
) where {ElType}
return TracedRArray{$RT,(),0}(
(),
MLIR.IR.result(
MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
),
)
end
end

function $jlop(lhs, rhs::TracedRArray{ElType,Shape,N}) where {ElType,Shape,N}
lhs = promote_to(rhs, lhs)
return TracedRArray{$RT,Shape,N}(
(),
MLIR.IR.result(
MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
),
)
for otherType in (Number, Any, TracedRArray{S,(),0} where {S})
@eval begin
function $jlop(
lhs::TracedRArray{ElType,Shape,N}, rhs::$otherType
) where {ElType,Shape,N}
rhs = promote_to(lhs, rhs)
return TracedRArray{$RT,Shape,N}(
(),
MLIR.IR.result(
MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
),
)
end

function $jlop(
lhs::$otherType, rhs::TracedRArray{ElType,Shape,N}
) where {ElType,Shape,N}
lhs = promote_to(rhs, lhs)
return TracedRArray{$RT,Shape,N}(
(),
MLIR.IR.result(
MLIR.Dialects.stablehlo.$hloop(lhs.mlir_data, rhs.mlir_data), 1
),
)
end
end
end
end
Expand Down Expand Up @@ -203,6 +222,13 @@ function Statistics.mean(A::TracedRArray{T,Shape,N}; dims=:) where {T,Shape,N}
denom = dims isa Colon ? length(A) : prod(Base.Fix1(size, A), dims)
return mapreduce(identity, +, A; dims) / denom
end
function Statistics.var(
A::TracedRArray{T,Shape,N}; dims=:, mean=nothing, corrected=true
) where {T,Shape,N}
mean === nothing && (mean = Statistics.mean(A; dims))
denom = (dims isa Colon ? length(A) : prod(Base.Fix1(size, A), dims)) - corrected
return mapreduce(abs2, +, A .- mean; dims) / denom
end

function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}
fnwrap, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = make_mlir_fn(
Expand Down

0 comments on commit a93e795

Please sign in to comment.