From 47d85a1dcc469d2f07011d4095515a0634fc4be3 Mon Sep 17 00:00:00 2001 From: fverdugo Date: Tue, 14 May 2019 17:21:51 +0200 Subject: [PATCH] Added indexing --- src/Indexing.jl | 21 +++++++++++++++++++++ src/TensorValues.jl | 6 ++++++ test/IndexingTests.jl | 33 +++++++++++++++++++++++++++++++++ test/runtests.jl | 4 ++++ 4 files changed, 64 insertions(+) create mode 100644 src/Indexing.jl create mode 100644 test/IndexingTests.jl diff --git a/src/Indexing.jl b/src/Indexing.jl new file mode 100644 index 0000000..85ec476 --- /dev/null +++ b/src/Indexing.jl @@ -0,0 +1,21 @@ + +size(a::MultiValue) = size(a.array) + +length(a::MultiValue) = length(a.array) + +@propagate_inbounds function getindex( + a::MultiValue{S,T,N}, I::Vararg{Integer,N}) where {S,T,N} + @inbounds a.array[I...] +end + +@propagate_inbounds function getindex(a::MultiValue, i::Integer) + @inbounds a.array[i] +end + +eltype(a::Type{MultiValue{S,T,N,L}}) where {S,T,N,L} = T + +iterate(a::MultiValue) = iterate(a.array) + +iterate(a::MultiValue, state) = iterate(a.array, state) + +eachindex(a::MultiValue) = eachindex(a.array) diff --git a/src/TensorValues.jl b/src/TensorValues.jl index 409dedc..bac92dd 100644 --- a/src/TensorValues.jl +++ b/src/TensorValues.jl @@ -1,6 +1,7 @@ module TensorValues using StaticArrays +using Base: @propagate_inbounds export MultiValue export TensorValue @@ -12,10 +13,15 @@ export det, inv import Base: show import Base: zero, one import Base: +, -, *, /, \, ==, ≈ +import Base: getindex, iterate, eachindex +import Base: size, length, eltype + import LinearAlgebra: det, inv include("Types.jl") +include("Indexing.jl") + include("Operations.jl") end # module diff --git a/test/IndexingTests.jl b/test/IndexingTests.jl new file mode 100644 index 0000000..b9e5b29 --- /dev/null +++ b/test/IndexingTests.jl @@ -0,0 +1,33 @@ +module IndexingTests + +using Test +using TensorValues + +a = (3,4,5,1) + +v = VectorValue{4}(a) + +@test eltype(v) == Int +@test eltype(typeof(v)) == Int + +@test size(v) == (4,) +@test length(v) == 4 + +for (k,i) in enumerate(eachindex(v)) + @test v[i] == a[k] +end + +t = TensorValue{2}(a) + +@test size(t) == (2,2) +@test length(t) == 4 + +for (k,i) in enumerate(eachindex(t)) + @test t[i] == a[k] +end + +@test t[2,1] == 4 + +@test t[2] == 4 + +end # module IndexingTests diff --git a/test/runtests.jl b/test/runtests.jl index a630cec..fdc187d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -11,4 +11,8 @@ end include("OperationsTests.jl") end +@testset "IndexingTests" begin + include("IndexingTests.jl") +end + end # module TensorValuesTests