-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Move
NamedArrayPartition
back to StartUpDG.jl (temporarily) (#175)
- Loading branch information
Showing
5 changed files
with
173 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
|
||
using RecursiveArrayTools: RecursiveArrayTools, ArrayPartition, npartitions, unpack | ||
|
||
""" | ||
NamedArrayPartition(; kwargs...) | ||
NamedArrayPartition(x::NamedTuple) | ||
Similar to an `ArrayPartition` but the individual arrays can be accessed via the | ||
constructor-specified names. However, unlike `ArrayPartition`, each individual array | ||
must have the same element type. | ||
""" | ||
struct NamedArrayPartition{T, A<:ArrayPartition{T}, NT<:NamedTuple} <: AbstractVector{T} | ||
array_partition::A | ||
names_to_indices::NT | ||
end | ||
NamedArrayPartition(; kwargs...) = NamedArrayPartition(NamedTuple(kwargs)) | ||
function NamedArrayPartition(x::NamedTuple) | ||
names_to_indices = NamedTuple(Pair(symbol, index) for (index, symbol) in enumerate(keys(x))) | ||
|
||
# enforce homogeneity of eltypes | ||
@assert all(eltype.(values(x)) .== eltype(first(x))) | ||
T = eltype(first(x)) | ||
S = typeof(values(x)) | ||
return NamedArrayPartition(ArrayPartition{T, S}(values(x)), names_to_indices) | ||
end | ||
|
||
# note that overloading `getproperty` means we cannot access `NamedArrayPartition` | ||
# fields except through `getfield` and accessor functions. | ||
ArrayPartition(x::NamedArrayPartition) = getfield(x, :array_partition) | ||
|
||
Base.Array(x::NamedArrayPartition) = Array(ArrayPartition(x)) | ||
|
||
Base.zero(x::NamedArrayPartition{T, S, TN}) where {T, S, TN} = | ||
NamedArrayPartition{T, S, TN}(zero(ArrayPartition(x)), getfield(x, :names_to_indices)) | ||
Base.zero(A::NamedArrayPartition, dims::NTuple{N, Int}) where {N} = zero(A) # ignore dims since named array partitions are vectors | ||
|
||
|
||
Base.propertynames(x::NamedArrayPartition) = propertynames(getfield(x, :names_to_indices)) | ||
Base.getproperty(x::NamedArrayPartition, s::Symbol) = | ||
getindex(ArrayPartition(x).x, getproperty(getfield(x, :names_to_indices), s)) | ||
|
||
# !!! this won't work if `v` isn't the same size as | ||
@inline function Base.setproperty!(x::NamedArrayPartition, s::Symbol, v) | ||
index = getproperty(getfield(x, :names_to_indices), s) | ||
ArrayPartition(x).x[index] .= v | ||
end | ||
|
||
# print out NamedArrayPartition as a NamedTuple | ||
Base.summary(x::NamedArrayPartition) = string(typeof(x), " with arrays:") | ||
Base.show(io::IO, m::MIME"text/plain", x::NamedArrayPartition) = | ||
show(io, m, NamedTuple(Pair.(keys(getfield(x, :names_to_indices)), ArrayPartition(x).x))) | ||
|
||
Base.size(x::NamedArrayPartition) = size(ArrayPartition(x)) | ||
Base.length(x::NamedArrayPartition) = length(ArrayPartition(x)) | ||
Base.getindex(x::NamedArrayPartition, args...) = getindex(ArrayPartition(x), args...) | ||
|
||
Base.setindex!(x::NamedArrayPartition, args...) = setindex!(ArrayPartition(x), args...) | ||
Base.map(f, x::NamedArrayPartition) = NamedArrayPartition(map(f, ArrayPartition(x)), getfield(x, :names_to_indices)) | ||
Base.mapreduce(f, op, x::NamedArrayPartition) = mapreduce(f, op, ArrayPartition(x)) | ||
# Base.filter(f, x::NamedArrayPartition) = filter(f, ArrayPartition(x)) | ||
|
||
Base.similar(x::NamedArrayPartition{T, S, NT}) where {T, S, NT} = | ||
NamedArrayPartition{T, S, NT}(similar(ArrayPartition(x)), getfield(x, :names_to_indices)) | ||
|
||
# # return NamedArrayPartition when possible, otherwise next best thing of the correct size | ||
# function Base.similar(x::ArrayPartition, dims::NTuple{N,Int}) where {N} | ||
# if dims == size(x) | ||
# return similar(x) | ||
# else | ||
# return similar(ArrayPartition(x).x[1], eltype(x), dims) | ||
# end | ||
# end | ||
|
||
# # similar array partition of common type | ||
# @inline function Base.similar(A::ArrayPartition, ::Type{T}) where {T} | ||
# N = npartitions(A) | ||
# ArrayPartition(i->similar(A.x[i], T), N) | ||
# end | ||
|
||
# broadcasting | ||
Base.BroadcastStyle(::Type{<:NamedArrayPartition}) = Broadcast.ArrayStyle{NamedArrayPartition}() | ||
function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}}, | ||
::Type{ElType}) where {ElType} | ||
x = find_NamedArrayPartition(bc) | ||
return NamedArrayPartition(similar(ArrayPartition(x)), getfield(x, :names_to_indices)) | ||
end | ||
|
||
# when broadcasting with ArrayPartition + another array type, the output is the other array tupe | ||
Base.BroadcastStyle(::Broadcast.ArrayStyle{NamedArrayPartition}, ::Broadcast.DefaultArrayStyle{1}) = | ||
Broadcast.DefaultArrayStyle{1}() | ||
|
||
# hook into ArrayPartition broadcasting routines | ||
@inline RecursiveArrayTools.npartitions(x::NamedArrayPartition) = npartitions(ArrayPartition(x)) | ||
@inline RecursiveArrayTools.unpack(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}}, i) = | ||
Broadcast.Broadcasted(bc.f, RecursiveArrayTools.unpack_args(i, bc.args)) | ||
@inline RecursiveArrayTools.unpack(x::NamedArrayPartition, i) = unpack(ArrayPartition(x), i) | ||
|
||
Base.copy(A::NamedArrayPartition{T,S,NT}) where {T,S,NT} = | ||
NamedArrayPartition{T,S,NT}(copy(ArrayPartition(A)), getfield(A, :names_to_indices)) | ||
|
||
@inline NamedArrayPartition(f::F, N, names_to_indices) where F<:Function = | ||
NamedArrayPartition(ArrayPartition(ntuple(f, Val(N))), names_to_indices) | ||
|
||
@inline function Base.copy(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}}) | ||
N = npartitions(bc) | ||
@inline function f(i) | ||
copy(unpack(bc, i)) | ||
end | ||
x = find_NamedArrayPartition(bc) | ||
NamedArrayPartition(f, N, getfield(x, :names_to_indices)) | ||
end | ||
|
||
@inline function Base.copyto!(dest::NamedArrayPartition, | ||
bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}}) | ||
N = npartitions(dest, bc) | ||
@inline function f(i) | ||
copyto!(ArrayPartition(dest).x[i], unpack(bc, i)) | ||
end | ||
ntuple(f, Val(N)) | ||
return dest | ||
end | ||
|
||
# `x = find_NamedArrayPartition(x)` returns the first `NamedArrayPartition` among broadcast arguments. | ||
find_NamedArrayPartition(bc::Base.Broadcast.Broadcasted) = find_NamedArrayPartition(bc.args) | ||
find_NamedArrayPartition(args::Tuple) = | ||
find_NamedArrayPartition(find_NamedArrayPartition(args[1]), Base.tail(args)) | ||
find_NamedArrayPartition(x) = x | ||
find_NamedArrayPartition(::Tuple{}) = nothing | ||
find_NamedArrayPartition(x::NamedArrayPartition, rest) = x | ||
find_NamedArrayPartition(::Any, rest) = find_NamedArrayPartition(rest) | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
@testset "NamedArrayPartition tests" begin | ||
x = NamedArrayPartition(a = ones(10), b = rand(20)) | ||
@test typeof(@. sin(x * x^2 / x - 1)) <: NamedArrayPartition | ||
@test typeof(x.^2) <: NamedArrayPartition | ||
@test x.a ≈ ones(10) | ||
@test typeof(x .+ x[1:end]) <: Vector # test broadcast precedence | ||
@test all(x .== x[1:end]) | ||
y = copy(x) | ||
@test zero(x, (10, 20)) == zero(x) # test that ignoring dims works | ||
@test typeof(zero(x)) <: NamedArrayPartition | ||
@test (y .*= 2).a[1] ≈ 2 # test in-place bcast | ||
|
||
@test length(Array(x))==30 | ||
@test typeof(Array(x)) <: Array | ||
@test propertynames(x) == (:a, :b) | ||
|
||
x = NamedArrayPartition(a = ones(1), b = 2*ones(1)) | ||
@test Base.summary(x) == string(typeof(x), " with arrays:") | ||
@test (@capture_out Base.show(stdout, MIME"text/plain"(), x)) == "(a = [1.0], b = [2.0])" | ||
|
||
using StructArrays | ||
using StartUpDG: SVector | ||
x = NamedArrayPartition(a = StructArray{SVector{2, Float64}}((ones(5), 2*ones(5))), | ||
b = StructArray{SVector{2, Float64}}((3 * ones(2,2), 4*ones(2,2)))) | ||
@test typeof(x.a) <: StructVector{<:SVector{2}} | ||
@test typeof(x.b) <: StructArray{<:SVector{2}, 2} | ||
@test typeof((x->x[1]).(x)) <: NamedArrayPartition | ||
@test typeof(map(x->x[1], x)) <: NamedArrayPartition | ||
@test typeof(similar(x)) == typeof(x) | ||
end | ||
|
||
# x = NamedArrayPartition(a = ones(10), b = rand(20)) | ||
# x_ap = ArrayPartition(x) | ||
# @btime @. x_ap * x_ap; # 498.836 ns (5 allocations: 2.77 KiB) | ||
# @btime @. x * x; # 2.032 μs (5 allocations: 2.84 KiB) - 5x slower than ArrayPartition |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters