From f914f6eae61167616ebac44b343059ced45365dc Mon Sep 17 00:00:00 2001 From: fverdugo Date: Tue, 14 May 2019 14:19:43 +0200 Subject: [PATCH] Added one and zero --- src/TensorValues.jl | 1 + src/Types.jl | 18 +++++++++++++----- test/TypesTests.jl | 22 ++++++++++++++++++++-- 3 files changed, 34 insertions(+), 7 deletions(-) diff --git a/src/TensorValues.jl b/src/TensorValues.jl index ec8121e..b7641d9 100644 --- a/src/TensorValues.jl +++ b/src/TensorValues.jl @@ -6,6 +6,7 @@ export MultiValue export TensorValue export VectorValue import Base: show +import Base: zero, one include("Types.jl") diff --git a/src/Types.jl b/src/Types.jl index ec10584..5e49ff8 100644 --- a/src/Types.jl +++ b/src/Types.jl @@ -89,14 +89,22 @@ function VectorValue(args::Vararg{T,D}) where {T,D} VectorValue{D,T}(args) end -# Custom type printing +# Initializers + +function zero(::Type{<:MultiValue{S,T,N,L}}) where {S,T,N,L} + z = zero(SArray{S,T,N,L}) + MultiValue{S,T,N,L}(z) +end -function show(io::IO,::Type{<:MultiValue{S,T}}) where {S,T} - print(io,"MultiValue{$S,$T}") +function one(::Type{<:MultiValue{S,T,N,L}}) where {S,T,N,L} + z = one(SArray{S,T,N,L}) + MultiValue{S,T,N,L}(z) end -function show(io::IO,::Type{<:TensorValue{D,T}}) where {D,T} - print(io,"TensorValue{$D,$T}") +# Custom type printing + +function show(io::IO,::Type{<:TensorValue{D,T,L}}) where {D,T,L} + print(io,"TensorValue{$D,$T,$L}") end function show(io::IO,::Type{<:VectorValue{D,T}}) where {D,T} diff --git a/test/TypesTests.jl b/test/TypesTests.jl index 71d276d..bc4bfb0 100644 --- a/test/TypesTests.jl +++ b/test/TypesTests.jl @@ -72,12 +72,30 @@ g = VectorValue(1,2,3,4) @test isa(g,VectorValue{4,Int}) @test g.array == [1,2,3,4] +# Initializers + +z = zero(MultiValue{Tuple{3,2},Int,2,6}) +@test isa(z,MultiValue{Tuple{3,2},Int,2,6}) +@test z.array == zeros(Int,(3,2)) + +z = zero(TensorValue{3,Int,9}) +@test isa(z,TensorValue{3,Int,9}) +@test z.array == zeros(Int,(3,3)) + +z = zero(VectorValue{3,Int}) +@test isa(z,VectorValue{3,Int}) +@test z.array == zeros(Int,3) + +z = one(TensorValue{3,Int,9}) +@test isa(z,TensorValue{3,Int,9}) +@test z.array == [1 0 0; 0 1 0; 0 0 1] + # Custom type printing -s = "MultiValue{Tuple{3,2},Float64}(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)" +s = "MultiValue{Tuple{3,2},Float64,2,6}(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)" @test string(v) == s -s = "TensorValue{2,Int64}(1, 2, 3, 4)" +s = "TensorValue{2,Int64,4}(1, 2, 3, 4)" @test string(t) == s s = "VectorValue{4,Int64}(1, 2, 3, 4)"