Skip to content

Commit

Permalink
Added one and zero
Browse files Browse the repository at this point in the history
  • Loading branch information
fverdugo committed May 14, 2019
1 parent 3925fb6 commit f914f6e
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 7 deletions.
1 change: 1 addition & 0 deletions src/TensorValues.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ export MultiValue
export TensorValue
export VectorValue
import Base: show
import Base: zero, one

include("Types.jl")

Expand Down
18 changes: 13 additions & 5 deletions src/Types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,22 @@ function VectorValue(args::Vararg{T,D}) where {T,D}
VectorValue{D,T}(args)
end

# Custom type printing
# Initializers

function zero(::Type{<:MultiValue{S,T,N,L}}) where {S,T,N,L}
z = zero(SArray{S,T,N,L})
MultiValue{S,T,N,L}(z)
end

function show(io::IO,::Type{<:MultiValue{S,T}}) where {S,T}
print(io,"MultiValue{$S,$T}")
function one(::Type{<:MultiValue{S,T,N,L}}) where {S,T,N,L}
z = one(SArray{S,T,N,L})
MultiValue{S,T,N,L}(z)
end

function show(io::IO,::Type{<:TensorValue{D,T}}) where {D,T}
print(io,"TensorValue{$D,$T}")
# Custom type printing

function show(io::IO,::Type{<:TensorValue{D,T,L}}) where {D,T,L}
print(io,"TensorValue{$D,$T,$L}")
end

function show(io::IO,::Type{<:VectorValue{D,T}}) where {D,T}
Expand Down
22 changes: 20 additions & 2 deletions test/TypesTests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,30 @@ g = VectorValue(1,2,3,4)
@test isa(g,VectorValue{4,Int})
@test g.array == [1,2,3,4]

# Initializers

z = zero(MultiValue{Tuple{3,2},Int,2,6})
@test isa(z,MultiValue{Tuple{3,2},Int,2,6})
@test z.array == zeros(Int,(3,2))

z = zero(TensorValue{3,Int,9})
@test isa(z,TensorValue{3,Int,9})
@test z.array == zeros(Int,(3,3))

z = zero(VectorValue{3,Int})
@test isa(z,VectorValue{3,Int})
@test z.array == zeros(Int,3)

z = one(TensorValue{3,Int,9})
@test isa(z,TensorValue{3,Int,9})
@test z.array == [1 0 0; 0 1 0; 0 0 1]

# Custom type printing

s = "MultiValue{Tuple{3,2},Float64}(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)"
s = "MultiValue{Tuple{3,2},Float64,2,6}(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)"
@test string(v) == s

s = "TensorValue{2,Int64}(1, 2, 3, 4)"
s = "TensorValue{2,Int64,4}(1, 2, 3, 4)"
@test string(t) == s

s = "VectorValue{4,Int64}(1, 2, 3, 4)"
Expand Down

0 comments on commit f914f6e

Please sign in to comment.