From 5f31db261ae28e0132e60de600563036e2173783 Mon Sep 17 00:00:00 2001 From: Francesc Verdugo Date: Fri, 17 May 2019 09:08:36 +0200 Subject: [PATCH] Added convert and CartesianIndices, LinearIndices --- src/Indexing.jl | 12 ++++++++++-- src/TensorValues.jl | 3 +++ src/Types.jl | 6 ++++++ test/IndexingTests.jl | 6 ++++++ test/TypesTests.jl | 6 ++++++ 5 files changed, 31 insertions(+), 2 deletions(-) diff --git a/src/Indexing.jl b/src/Indexing.jl index 85ec476..99f7912 100644 --- a/src/Indexing.jl +++ b/src/Indexing.jl @@ -14,8 +14,16 @@ end eltype(a::Type{MultiValue{S,T,N,L}}) where {S,T,N,L} = T -iterate(a::MultiValue) = iterate(a.array) +@inline iterate(a::MultiValue) = iterate(a.array) -iterate(a::MultiValue, state) = iterate(a.array, state) +@inline iterate(a::MultiValue, state) = iterate(a.array, state) eachindex(a::MultiValue) = eachindex(a.array) + +function CartesianIndices(a::MultiValue) + CartesianIndices(a.array) +end + +function LinearIndices(a::MultiValue) + LinearIndices(a.array) +end diff --git a/src/TensorValues.jl b/src/TensorValues.jl index 268f410..ad3a3db 100644 --- a/src/TensorValues.jl +++ b/src/TensorValues.jl @@ -16,6 +16,9 @@ import Base: +, -, *, /, \, ==, ≈ import Base: getindex, iterate, eachindex import Base: size, length, eltype import Base: reinterpret +import Base: convert +import Base: CartesianIndices +import Base: LinearIndices import LinearAlgebra: det, inv diff --git a/src/Types.jl b/src/Types.jl index 0ea9327..4001068 100644 --- a/src/Types.jl +++ b/src/Types.jl @@ -121,6 +121,12 @@ function one(::Type{<:MultiValue{S,T,N,L}}) where {S,T,N,L} MultiValue{S,T,N,L}(z) end +# Conversions + +function convert(::Type{<:MultiValue{S,T,N}},a::StaticArray{S,T,N}) where {S,T,N} + MultiValue(a) +end + # Custom type printing function show(io::IO,v::MultiValue) diff --git a/test/IndexingTests.jl b/test/IndexingTests.jl index eccaff5..ff02a9a 100644 --- a/test/IndexingTests.jl +++ b/test/IndexingTests.jl @@ -2,6 +2,7 @@ module IndexingTests using Test using TensorValues +using StaticArrays a = (3,4,5,1) @@ -34,4 +35,9 @@ for (k,ti) in enumerate(t) @test ti == a[k] end +v = @SMatrix zeros(2,3) +w = MultiValue(v) +@test CartesianIndices(w) == CartesianIndices(v) +@test LinearIndices(w) == LinearIndices(v) + end # module IndexingTests diff --git a/test/TypesTests.jl b/test/TypesTests.jl index bf0d7d1..34bb86c 100644 --- a/test/TypesTests.jl +++ b/test/TypesTests.jl @@ -123,6 +123,12 @@ z = one(TensorValue{3,Int,9}) @test isa(z,TensorValue{3,Int,9}) @test z.array == [1 0 0; 0 1 0; 0 0 1] +# Conversions + +a = @SVector ones(Int,3) +b = convert(VectorValue{3,Int},a) +@test isa(b,VectorValue{3,Int}) + # Custom type printing s = "TensorValues.MultiValue{Tuple{3,2},Float64,2,6}(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)"