Skip to content

Commit

Permalink
Merge pull request #86 from longemen3000/LinearAlgebraExt
Browse files Browse the repository at this point in the history
add LinearAlgebraExt
  • Loading branch information
jw3126 authored Aug 21, 2024
2 parents 71fb5a5 + 6947d65 commit 3f1e80e
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 37 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@ ConstructionBaseStaticArraysExt = "StaticArrays"
IntervalSets = "0.5, 0.6, 0.7"
StaticArrays = "1"
julia = "1"
LinearAlgebra = "<0.0.1,1"

[extras]
IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[targets]
test = ["IntervalSets","StaticArrays","Test"]
test = ["IntervalSets","LinearAlgebra","StaticArrays","Test"]
46 changes: 46 additions & 0 deletions ext/ConstructionBaseLinearAlgebraExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
module ConstructionBaseLinearAlgebraExt

import ConstructionBase
import LinearAlgebra

### Tridiagonal

function tridiagonal_constructor(dl::V, d::V, du::V) where {V<:AbstractVector{T}} where T
LinearAlgebra.Tridiagonal{T,V}(dl, d, du)
end
function tridiagonal_constructor(dl::V, d::V, du::V, du2::V) where {V<:AbstractVector{T}} where T
LinearAlgebra.Tridiagonal{T,V}(dl, d, du, du2)
end

# `du2` may be undefined, so we need a custom `getfields` that checks `isdefined`
function ConstructionBase.getfields(o::LinearAlgebra.Tridiagonal)
if isdefined(o, :du2)
(dl=o.dl, d=o.d, du=o.du, du2=o.du2)
else
(dl=o.dl, d=o.d, du=o.du)
end
end

ConstructionBase.constructorof(::Type{<:LinearAlgebra.Tridiagonal}) = tridiagonal_constructor

### Cholesky

ConstructionBase.setproperties(C::LinearAlgebra.Cholesky, patch::NamedTuple{()}) = C

function ConstructionBase.setproperties(C::LinearAlgebra.Cholesky, patch::NamedTuple{(:L,),<:Tuple{<:LinearAlgebra.LowerTriangular}})
return LinearAlgebra.Cholesky(C.uplo === 'U' ? copy(patch.L.data') : patch.L.data, C.uplo, C.info)
end
function ConstructionBase.setproperties(C::LinearAlgebra.Cholesky, patch::NamedTuple{(:U,),<:Tuple{<:LinearAlgebra.UpperTriangular}})
return LinearAlgebra.Cholesky(C.uplo === 'L' ? copy(patch.U.data') : patch.U.data, C.uplo, C.info)
end
function ConstructionBase.setproperties(
C::LinearAlgebra.Cholesky,
patch::NamedTuple{(:UL,),<:Tuple{<:Union{LinearAlgebra.LowerTriangular,LinearAlgebra.UpperTriangular}}}
)
return LinearAlgebra.Cholesky(patch.UL.data, C.uplo, C.info)
end
function ConstructionBase.setproperties(C::LinearAlgebra.Cholesky, patch::NamedTuple)
throw(ArgumentError("Invalid patch for `Cholesky`: $(patch)"))
end

end #module
3 changes: 3 additions & 0 deletions src/ConstructionBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -211,4 +211,7 @@ end
include("nonstandard.jl")
include("functions.jl")

#unconditionally include the extension for now
include("../ext/ConstructionBaseLinearAlgebraExt.jl")

end # module
36 changes: 0 additions & 36 deletions src/nonstandard.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
using LinearAlgebra

### SubArray
# `offset1` and `stride1` fields are calculated from parent indices.
Expand Down Expand Up @@ -28,24 +27,7 @@ end
constructorof(::Type{<:PermutedDimsArray{<:Any,N,perm,iperm,<:Any}}) where {N,perm,iperm} =
PermutedDimsArrayConstructor{N,perm,iperm}()

### Tridiagonal
function tridiagonal_constructor(dl::V, d::V, du::V) where {V<:AbstractVector{T}} where T
Tridiagonal{T,V}(dl, d, du)
end
function tridiagonal_constructor(dl::V, d::V, du::V, du2::V) where {V<:AbstractVector{T}} where T
Tridiagonal{T,V}(dl, d, du, du2)
end

# `du2` may be undefined, so we need a custom `getfields` that checks `isdefined`
function getfields(o::Tridiagonal)
if isdefined(o, :du2)
(dl=o.dl, d=o.d, du=o.du, du2=o.du2)
else
(dl=o.dl, d=o.d, du=o.du)
end
end

constructorof(::Type{<:LinearAlgebra.Tridiagonal}) = tridiagonal_constructor

### LinRange
# `lendiv` is a calculated field
Expand All @@ -56,21 +38,3 @@ constructorof(::Type{<:LinRange}) = linrange_constructor
### Expr: args get splatted
# ::Expr annotation is to make it type-stable on Julia 1.3-
constructorof(::Type{<:Expr}) = (head, args) -> Expr(head, args...)::Expr

### Cholesky
setproperties(C::LinearAlgebra.Cholesky, patch::NamedTuple{()}) = C
function setproperties(C::LinearAlgebra.Cholesky, patch::NamedTuple{(:L,),<:Tuple{<:LinearAlgebra.LowerTriangular}})
return LinearAlgebra.Cholesky(C.uplo === 'U' ? copy(patch.L.data') : patch.L.data, C.uplo, C.info)
end
function setproperties(C::LinearAlgebra.Cholesky, patch::NamedTuple{(:U,),<:Tuple{<:LinearAlgebra.UpperTriangular}})
return LinearAlgebra.Cholesky(C.uplo === 'L' ? copy(patch.U.data') : patch.U.data, C.uplo, C.info)
end
function setproperties(
C::LinearAlgebra.Cholesky,
patch::NamedTuple{(:UL,),<:Tuple{<:Union{LinearAlgebra.LowerTriangular,LinearAlgebra.UpperTriangular}}}
)
return LinearAlgebra.Cholesky(patch.UL.data, C.uplo, C.info)
end
function setproperties(C::LinearAlgebra.Cholesky, patch::NamedTuple)
throw(ArgumentError("Invalid patch for `Cholesky`: $(patch)"))
end

0 comments on commit 3f1e80e

Please sign in to comment.