Skip to content

Commit

Permalink
Added operations
Browse files Browse the repository at this point in the history
  • Loading branch information
fverdugo committed May 14, 2019
1 parent 5697839 commit ee998ef
Show file tree
Hide file tree
Showing 5 changed files with 268 additions and 0 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ authors = ["Francesc Verdugo <[email protected]>"]
version = "0.1.0"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[extras]
Expand Down
91 changes: 91 additions & 0 deletions src/Operations.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@

# Comparison

function (==)(a::MultiValue,b::MultiValue)
a.array == b.array
end

function ()(a::MultiValue,b::MultiValue)
a.array b.array
end

# Addition / subtraction

for op in (:+,:-)
@eval begin

function ($op)(a::MultiValue{S}) where S
r = $op(a.array)
MultiValue(r)
end

function ($op)(a::MultiValue{S},b::MultiValue{S}) where S
r = $op(a.array, b.array)
MultiValue(r)
end

end
end

# Matrix Division

function (\)(a::TensorValue, b::MultiValue)
r = a.array \ b.array
MultiValue(r)
end

# Scaling by a scalar

function (*)(a::MultiValue,b::Real)
r = a.array * b
MultiValue(r)
end

function (*)(a::Real,b::MultiValue)
r = a * b.array
MultiValue(r)
end

# Dot product (simple contraction)

(*)(a::VectorValue{D}, b::VectorValue{D}) where D = inner(a,b)

function (*)(a::MultiValue,b::MultiValue)
r = a.array * b.array
MultiValue(r)
end

# Inner product (full contraction)

inner(a::Real,b::Real) = a*b

@generated function inner(a::MultiValue{S,T,N,L}, b::MultiValue{S,W,N,L}) where {S,T,N,L,W}
str = join([" a.array.data[$i]*b.array.data[$i] +" for i in 1:L ])
Meta.parse(str[1:(end-1)])
end

# Outer product (aka dyadic product)

outer(a::Real,b::Real) = a*b

outer(a::MultiValue,b::Real) = a*b

outer(a::Real,b::MultiValue) = a*b

@generated function outer(a::VectorValue{D},b::VectorValue{Z}) where {D,Z}
str = join(["a.array[$i]*b.array[$j], " for j in 1:Z for i in 1:D])
Meta.parse("MultiValue(SMatrix{$D,$Z}($str))")
end

# Linear Algebra

det(a::TensorValue) = det(a.array)

inv(a::TensorValue) = MultiValue(inv(a.array))

# Measure

meas(a::VectorValue) = sqrt(inner(a,a))

meas(a::TensorValue) = abs(det(a))

8 changes: 8 additions & 0 deletions src/TensorValues.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,17 @@ using StaticArrays
export MultiValue
export TensorValue
export VectorValue

export inner, outer, meas
export det, inv

import Base: show
import Base: zero, one
import Base: +, -, *, /, \, ==,
import LinearAlgebra: det, inv

include("Types.jl")

include("Operations.jl")

end # module
160 changes: 160 additions & 0 deletions test/OperationsTests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
module OperationsTests

using Test
using TensorValues

a = VectorValue(1,2,3)
b = VectorValue(2,1,6)

# Comparison

@test a==a
@test a a
@test a!=b

# Addition / subtraction

c = +a
r = VectorValue(1,2,3)
@test c == r

c = -a
r = VectorValue(-1,-2,-3)
@test c == r

c = a + b
r = VectorValue(3,3,9)
@test c == r

c = a - b
r = VectorValue(-1,1,-3)
@test c == r

# Matrix Division

t = one(TensorValue{3,Int,9})

c = t\a

@test c == a

# Scaling by a scalar

a = VectorValue(1,2,3)
r = VectorValue(2,4,6)

c = 2 * a
@test isa(c,VectorValue{3,Int})
@test c == r

c = a * 2
@test isa(c,VectorValue{3,Int})
@test c == r

# Dot product (simple contraction)

a = VectorValue(1,2,3)
b = VectorValue(2,1,6)

t = TensorValue(1,2,3,4,5,6,7,8,9)
s = TensorValue(9,8,3,4,5,6,7,2,1)

c = a * b
@test isa(c,Int)
@test c == 2+2+18

c = t * a
@test isa(c,VectorValue{3,Int})
r = VectorValue(30,36,42)
@test c == r

c = s * t
@test isa(c,TensorValue{3,Int})
r = TensorValue(38,24,18,98,69,48,158,114,78)
@test c == r

# Inner product (full contraction)

c = inner(2,3)
@test c == 6

c = inner(a,b)
@test isa(c,Int)
@test c == 2+2+18

c = inner(t,s)
@test isa(c,Int)
@test c == 185

# Outer product (aka dyadic product)

a = VectorValue(1,2,3)
e = VectorValue(2,5)

c = outer(2,3)
@test c == 6

r = VectorValue(2,4,6)
c = outer(2,a)
@test isa(c,VectorValue{3,Int})
@test c == r

c = outer(a,2)
@test isa(c,VectorValue{3,Int})
@test c == r

c = outer(a,e)
@test isa(c,MultiValue{Tuple{3,2},Int})
r = MultiValue{Tuple{3,2},Int}(2,4,6,5,10,15)
@test c == r

# Linear Algebra

t = TensorValue(10,2,30,4,5,6,70,8,9)

c = det(t)
@test c -8802.0

c = inv(t)
@test isa(c,TensorValue{3})

# Measure

a = VectorValue(1,2,3)
c = meas(a)
@test c 3.7416573867739413

t = TensorValue(10,2,30,4,5,6,70,8,9)
c = meas(t)
@test c 8802.0

# Broadcasted operations

a = VectorValue(1,2,3)
b = VectorValue(2,1,6)

A = Array{VectorValue{3,Int},2}(undef,(4,5))
A .= a
@test A == fill(a,(4,5))

C = .- A

r = VectorValue(-1,-2,-3)
R = fill(r,(4,5))
@test C == R

B = fill(b,(4,1))

C = A .+ B

r = VectorValue(3,3,9)
R = fill(r,(4,5))
@test C == R

C = A .- B

r = VectorValue(-1,1,-3)
R = fill(r,(4,5))
@test C == R

end # module OperationsTests
8 changes: 8 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
module TensorValuesTests

using TensorValues
using Test

@testset "TypesTests" begin
include("TypesTests.jl")
end

@testset "OperationsTests" begin
include("OperationsTests.jl")
end

end # module TensorValuesTests

0 comments on commit ee998ef

Please sign in to comment.