Skip to content

Commit

Permalink
Adding more operations
Browse files Browse the repository at this point in the history
  • Loading branch information
fverdugo committed Oct 14, 2019
1 parent 6c0aab2 commit 0f74b46
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 2 deletions.
33 changes: 33 additions & 0 deletions src/Operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ function (*)(a::MultiValue,b::MultiValue)
MultiValue(r)
end

@inline dot(u::VectorValue,v::VectorValue) = inner(u,v)

# Inner product (full contraction)

inner(a::Real,b::Real) = a*b
Expand Down Expand Up @@ -105,7 +107,38 @@ meas(a::VectorValue) = sqrt(inner(a,a))

meas(a::TensorValue) = abs(det(a))

@inline norm(u::VectorValue) = sqrt(inner(u,u))

# conj

conj(a::MultiValue) = MultiValue(conj(a.array))

# Trace

@generated function trace(v::TensorValue{D}) where D
str = join([" v.array.data[$i+$((i-1)*D)] +" for i in 1:D ])
Meta.parse(str[1:(end-1)])
end

@inline tr(v::TensorValue) = trace(v)

# Adjoint

function adjoint(v::TensorValue)
t = adjoint(v.array)
TensorValue(t)
end

# Symmetric part

@generated function symmetic_part(v::TensorValue{D}) where D
str = "("
for j in 1:D
for i in 1:D
str *= "0.5*v.array.data[$i+$((j-1)*D)] + 0.5*v.array.data[$j+$((i-1)*D)], "
end
end
str *= ")"
Meta.parse("TensorValue($str)")
end

7 changes: 5 additions & 2 deletions src/TensorValues.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@ export TensorValue
export VectorValue

export inner, outer, meas
export det, inv
export det, inv, tr, dot, norm
export mutable
export trace
export symmetic_part

import Base: show
import Base: zero, one
Expand All @@ -22,8 +24,9 @@ import Base: reinterpret
import Base: convert
import Base: CartesianIndices
import Base: LinearIndices
import Base: adjoint

import LinearAlgebra: det, inv
import LinearAlgebra: det, inv, tr, dot, norm

include("Types.jl")

Expand Down
20 changes: 20 additions & 0 deletions test/OperationsTests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -196,4 +196,24 @@ R = fill(r,(4,5))
v = VectorValue(1,0)
@test v == v'

t = TensorValue(1,2,3,4)
@test trace(t) == 5
@test tr(t) == 5

t = TensorValue(1,2,3,4,5,6,7,8,9)
@test trace(t) == 15
@test tr(t) == 15

@test symmetic_part(t) == TensorValue(1.0, 3.0, 5.0, 3.0, 5.0, 7.0, 5.0, 7.0, 9.0)

a = TensorValue(1,2,3,4)
b = a'
@test b == TensorValue(1,3,2,4)
@test a*b == TensorValue(10,14,14,20)

u = VectorValue(1.0,2.0)
v = VectorValue(2.0,3.0)
@test dot(u,v) inner(u,v)
@test norm(u) sqrt(inner(u,u))

end # module OperationsTests

0 comments on commit 0f74b46

Please sign in to comment.