Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable more functionality for matrices over NCRing #1499

Merged
merged 1 commit into from
Nov 9, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 47 additions & 46 deletions src/Matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -428,16 +428,16 @@ end
#
###############################################################################

Base.ndims(::MatrixElem{T}) where T <: RingElement = 2
Base.ndims(::MatrixElem{T}) where T <: NCRingElement = 2

# Cartesian indexing

Base.eachindex(a::MatrixElem{T}) where T <: RingElement = CartesianIndices((nrows(a), ncols(a)))
Base.eachindex(a::MatrixElem{T}) where T <: NCRingElement = CartesianIndices((nrows(a), ncols(a)))

Base.@propagate_inbounds Base.getindex(a::MatrixElem{T}, I::CartesianIndex) where T <: RingElement =
Base.@propagate_inbounds Base.getindex(a::MatrixElem{T}, I::CartesianIndex) where T <: NCRingElement =
a[I[1], I[2]]

Base.@propagate_inbounds function Base.setindex!(a::MatrixElem{T}, x, I::CartesianIndex) where T <: RingElement
Base.@propagate_inbounds function Base.setindex!(a::MatrixElem{T}, x, I::CartesianIndex) where T <: NCRingElement
a[I[1], I[2]] = x
a
end
Expand Down Expand Up @@ -488,7 +488,7 @@ Base.IteratorEltype(::Type{<:MatrixElem}) = Base.HasEltype() # default
#
###############################################################################

function setindex!(a::MatrixElem{T}, b::Union{MatrixElem, Matrix}, r::AbstractUnitRange{Int}, c::AbstractUnitRange{Int}) where T <: RingElement
function setindex!(a::MatrixElem{T}, b::Union{MatrixElem, Matrix}, r::AbstractUnitRange{Int}, c::AbstractUnitRange{Int}) where T <: NCRingElement
_checkbounds(a, r, c)
size(b) == (length(r), length(c)) || throw(DimensionMismatch("tried to assign a $(size(b, 1))x$(size(b, 2)) matrix to a $(length(r))x$(length(c)) destination"))
startr = first(r)
Expand All @@ -500,7 +500,7 @@ function setindex!(a::MatrixElem{T}, b::Union{MatrixElem, Matrix}, r::AbstractUn
end
end

function setindex!(a::MatrixElem{T}, b::Vector, r::AbstractUnitRange{Int}, c::AbstractUnitRange{Int}) where T <: RingElement
function setindex!(a::MatrixElem{T}, b::Vector, r::AbstractUnitRange{Int}, c::AbstractUnitRange{Int}) where T <: NCRingElement
_checkbounds(a, r, c)
if !((length(r) == 1 && length(c) == length(b)) || length(c) == 1 && length(r) == length(b))
throw(DimensionMismatch("tried to assign vector of length $(length(b)) to a $(length(r))x$(length(c)) destination"))
Expand All @@ -515,35 +515,35 @@ function setindex!(a::MatrixElem{T}, b::Vector, r::AbstractUnitRange{Int}, c::Ab
end

# AbstractUnitRange{Int}, Colon
setindex!(a::MatrixElem{T}, b::Union{MatrixElem, Matrix, Vector}, r::AbstractUnitRange{Int}, ::Colon) where T <: RingElement = setindex!(a, b, r, 1:ncols(a))
setindex!(a::MatrixElem{T}, b::Union{MatrixElem, Matrix, Vector}, r::AbstractUnitRange{Int}, ::Colon) where T <: NCRingElement = setindex!(a, b, r, 1:ncols(a))

# Colon, AbstractUnitRange{Int}
setindex!(a::MatrixElem{T}, b::Union{MatrixElem, Matrix, Vector}, ::Colon, c::AbstractUnitRange{Int}) where T <: RingElement = setindex!(a, b, 1:nrows(a), c)
setindex!(a::MatrixElem{T}, b::Union{MatrixElem, Matrix, Vector}, ::Colon, c::AbstractUnitRange{Int}) where T <: NCRingElement = setindex!(a, b, 1:nrows(a), c)

# Colon, Colon
setindex!(a::MatrixElem{T}, b::Union{MatrixElem, Matrix, Vector}, ::Colon, ::Colon) where T <: RingElement = setindex!(a, b, 1:nrows(a), 1:ncols(a))
setindex!(a::MatrixElem{T}, b::Union{MatrixElem, Matrix, Vector}, ::Colon, ::Colon) where T <: NCRingElement = setindex!(a, b, 1:nrows(a), 1:ncols(a))

# Int, AbstractUnitRange{Int}
setindex!(a::MatrixElem{T}, b::Union{MatrixElem, Matrix, Vector}, r::Int, c::AbstractUnitRange{Int}) where T <: RingElement = setindex!(a, b, r:r, c)
setindex!(a::MatrixElem{T}, b::Union{MatrixElem, Matrix, Vector}, r::Int, c::AbstractUnitRange{Int}) where T <: NCRingElement = setindex!(a, b, r:r, c)

# AbstractUnitRange{Int}, Int
setindex!(a::MatrixElem{T}, b::Union{MatrixElem, Matrix, Vector}, r::AbstractUnitRange{Int}, c::Int) where T <: RingElement = setindex!(a, b, r, c:c)
setindex!(a::MatrixElem{T}, b::Union{MatrixElem, Matrix, Vector}, r::AbstractUnitRange{Int}, c::Int) where T <: NCRingElement = setindex!(a, b, r, c:c)

# Int, Colon
setindex!(a::MatrixElem{T}, b::Union{MatrixElem, Matrix, Vector}, r::Int, ::Colon) where T <: RingElement = setindex!(a, b, r:r, 1:ncols(a))
setindex!(a::MatrixElem{T}, b::Union{MatrixElem, Matrix, Vector}, r::Int, ::Colon) where T <: NCRingElement = setindex!(a, b, r:r, 1:ncols(a))

# Colon, Int
setindex!(a::MatrixElem{T}, b::Union{MatrixElem, Matrix, Vector}, ::Colon, c::Int) where T <: RingElement = setindex!(a, b, 1:nrows(a), c:c)
setindex!(a::MatrixElem{T}, b::Union{MatrixElem, Matrix, Vector}, ::Colon, c::Int) where T <: NCRingElement = setindex!(a, b, 1:nrows(a), c:c)

function _setindex!(a::MatrixElem{T}, b, r, c) where T <: RingElement
function _setindex!(a::MatrixElem{T}, b, r, c) where T <: NCRingElement
for (i, i2) in enumerate(r)
for (j, j2) in enumerate(c)
a[i2, j2] = b[i, j]
end
end
end

function _setindex!(a::MatrixElem{T}, b::Vector, r, c) where T <: RingElement
function _setindex!(a::MatrixElem{T}, b::Vector, r, c) where T <: NCRingElement
for (i, i2) in enumerate(r)
for (j, j2) in enumerate(c)
a[i2, j2] = b[i + j - 1]
Expand All @@ -552,25 +552,25 @@ function _setindex!(a::MatrixElem{T}, b::Vector, r, c) where T <: RingElement
end

# Vector{Int}, Vector{Int}
setindex!(a::MatrixElem{T}, b::Union{MatrixElem, Matrix, Vector}, r::Vector{Int}, c::Vector{Int}) where T <: RingElement = _setindex!(a, b, r, c)
setindex!(a::MatrixElem{T}, b::Union{MatrixElem, Matrix, Vector}, r::Vector{Int}, c::Vector{Int}) where T <: NCRingElement = _setindex!(a, b, r, c)

# Vector{Int}, AbstractUnitRange{Int}
setindex!(a::MatrixElem{T}, b::Union{MatrixElem, Matrix, Vector}, r::Vector{Int}, c::AbstractUnitRange{Int}) where T <: RingElement = _setindex!(a, b, r, c)
setindex!(a::MatrixElem{T}, b::Union{MatrixElem, Matrix, Vector}, r::Vector{Int}, c::AbstractUnitRange{Int}) where T <: NCRingElement = _setindex!(a, b, r, c)

# AbstractUnitRange{Int}, Vector{Int}
setindex!(a::MatrixElem{T}, b::Union{MatrixElem, Matrix, Vector}, r::AbstractUnitRange{Int}, c::Vector{Int}) where T <: RingElement = _setindex!(a, b, r, c)
setindex!(a::MatrixElem{T}, b::Union{MatrixElem, Matrix, Vector}, r::AbstractUnitRange{Int}, c::Vector{Int}) where T <: NCRingElement = _setindex!(a, b, r, c)

# Vector{Int}, Colon
setindex!(a::MatrixElem{T}, b::Union{MatrixElem, Matrix, Vector}, r::Vector{Int}, ::Colon) where T <: RingElement = _setindex!(a, b, r, 1:ncols(a))
setindex!(a::MatrixElem{T}, b::Union{MatrixElem, Matrix, Vector}, r::Vector{Int}, ::Colon) where T <: NCRingElement = _setindex!(a, b, r, 1:ncols(a))

# Colon, Vector{Int}
setindex!(a::MatrixElem{T}, b::Union{MatrixElem, Matrix, Vector}, ::Colon, c::Vector{Int}) where T <: RingElement = _setindex!(a, b, 1:nrows(a), c)
setindex!(a::MatrixElem{T}, b::Union{MatrixElem, Matrix, Vector}, ::Colon, c::Vector{Int}) where T <: NCRingElement = _setindex!(a, b, 1:nrows(a), c)

# Int, Vector{Int}
setindex!(a::MatrixElem{T}, b::Union{MatrixElem, Matrix, Vector}, r::Int, c::Vector{Int}) where T <: RingElement = setindex!(a, b, r:r, c)
setindex!(a::MatrixElem{T}, b::Union{MatrixElem, Matrix, Vector}, r::Int, c::Vector{Int}) where T <: NCRingElement = setindex!(a, b, r:r, c)

# Vector{Int}, Int
setindex!(a::MatrixElem{T}, b::Union{MatrixElem, Matrix, Vector}, r::Vector{Int}, c::Int) where T <: RingElement = setindex!(a, b, r, c:c)
setindex!(a::MatrixElem{T}, b::Union{MatrixElem, Matrix, Vector}, r::Vector{Int}, c::Int) where T <: NCRingElement = setindex!(a, b, r, c:c)

################################################################################
#
Expand Down Expand Up @@ -1162,7 +1162,7 @@ end
###############################################################################

@doc raw"""
==(x::MatrixElem{T}, y::Union{Integer, Rational, AbstractFloat}) where T <: RingElement
==(x::MatrixElem{T}, y::Union{Integer, Rational, AbstractFloat}) where T <: NCRingElement

Return `true` if $x == S(y)$ arithmetically, where $S$ is the parent of $x$,
otherwise return `false`.
Expand All @@ -1184,7 +1184,7 @@ function ==(x::MatrixElem{T}, y::Union{Integer, Rational, AbstractFloat}) where
end

@doc raw"""
==(x::Union{Integer, Rational, AbstractFloat}, y::MatrixElem{T}) where T <: RingElement
==(x::Union{Integer, Rational, AbstractFloat}, y::MatrixElem{T}) where T <: NCRingElement

Return `true` if $S(x) == y$ arithmetically, where $S$ is the parent of $y$,
otherwise return `false`.
Expand Down Expand Up @@ -1312,7 +1312,7 @@ end


@doc raw"""
transpose(x::MatrixElem{T}) where T <: RingElement
transpose(x::MatrixElem{T}) where T <: NCRingElement

Return the transpose of the given matrix.

Expand All @@ -1337,7 +1337,8 @@ julia> B = transpose(A)
[ 1 t t^2 + t + 1]

```
""" transpose(x::MatrixElem{T}) where T <: RingElement
"""
transpose(x::MatrixElem{T}) where T <: NCRingElement


###############################################################################
Expand Down Expand Up @@ -1418,7 +1419,7 @@ end
###############################################################################

@doc raw"""
tr(x::MatrixElem{T}) where T <: RingElement
tr(x::MatrixElem{T}) where T <: NCRingElement

Return the trace of the matrix $a$, i.e. the sum of the diagonal elements. We
require the matrix to be square.
Expand All @@ -1443,7 +1444,7 @@ t^2 + 3*t + 2

```
"""
function tr(x::MatrixElem{T}) where T <: RingElement
function tr(x::MatrixElem{T}) where T <: NCRingElement
!is_square(x) && error("Not a square matrix in trace")
d = zero(base_ring(x))
for i = 1:nrows(x)
Expand Down Expand Up @@ -1504,7 +1505,7 @@ end
###############################################################################

@doc raw"""
*(P::perm, x::MatrixElem{T}) where T <: RingElement
*(P::perm, x::MatrixElem{T}) where T <: NCRingElement

Apply the pemutation $P$ to the rows of the matrix $x$ and return the result.

Expand Down Expand Up @@ -1536,7 +1537,7 @@ julia> B = P*A

```
"""
function *(P::Perm, x::MatrixElem{T}) where T <: RingElement
function *(P::Perm, x::MatrixElem{T}) where T <: NCRingElement
z = similar(x)
m = nrows(x)
n = ncols(x)
Expand Down Expand Up @@ -6165,7 +6166,7 @@ end
###############################################################################

@doc raw"""
swap_rows(a::MatrixElem{T}, i::Int, j::Int) where T <: RingElement
swap_rows(a::MatrixElem{T}, i::Int, j::Int) where T <: NCRingElement

Return a matrix $b$ with the entries of $a$, where the $i$th and $j$th
row are swapped.
Expand All @@ -6188,15 +6189,15 @@ julia> M # was not modified
[0 0 1]
```
"""
function swap_rows(a::MatrixElem{T}, i::Int, j::Int) where T <: RingElement
function swap_rows(a::MatrixElem{T}, i::Int, j::Int) where T <: NCRingElement
(1 <= i <= nrows(a) && 1 <= j <= nrows(a)) || throw(BoundsError())
b = deepcopy(a)
swap_rows!(b, i, j)
return b
end

@doc raw"""
swap_rows!(a::MatrixElem{T}, i::Int, j::Int) where T <: RingElement
swap_rows!(a::MatrixElem{T}, i::Int, j::Int) where T <: NCRingElement

Swap the $i$th and $j$th row of $a$ in place. The function returns the mutated
matrix (since matrices are assumed to be mutable in AbstractAlgebra.jl).
Expand All @@ -6219,7 +6220,7 @@ julia> M # was modified
[0 0 1]
```
"""
function swap_rows!(a::MatrixElem{T}, i::Int, j::Int) where T <: RingElement
function swap_rows!(a::MatrixElem{T}, i::Int, j::Int) where T <: NCRingElement
(1 <= i <= nrows(a) && 1 <= j <= nrows(a)) || throw(BoundsError())
if i != j
for k = 1:ncols(a)
Expand All @@ -6232,25 +6233,25 @@ function swap_rows!(a::MatrixElem{T}, i::Int, j::Int) where T <: RingElement
end

@doc raw"""
swap_cols(a::MatrixElem{T}, i::Int, j::Int) where T <: RingElement
swap_cols(a::MatrixElem{T}, i::Int, j::Int) where T <: NCRingElement

Return a matrix $b$ with the entries of $a$, where the $i$th and $j$th
row are swapped.
"""
function swap_cols(a::MatrixElem{T}, i::Int, j::Int) where T <: RingElement
function swap_cols(a::MatrixElem{T}, i::Int, j::Int) where T <: NCRingElement
(1 <= i <= ncols(a) && 1 <= j <= ncols(a)) || throw(BoundsError())
b = deepcopy(a)
swap_cols!(b, i, j)
return b
end

@doc raw"""
swap_cols!(a::MatrixElem{T}, i::Int, j::Int) where T <: RingElement
swap_cols!(a::MatrixElem{T}, i::Int, j::Int) where T <: NCRingElement

Swap the $i$th and $j$th column of $a$ in place. The function returns the mutated
matrix (since matrices are assumed to be mutable in AbstractAlgebra.jl).
"""
function swap_cols!(a::MatrixElem{T}, i::Int, j::Int) where T <: RingElement
function swap_cols!(a::MatrixElem{T}, i::Int, j::Int) where T <: NCRingElement
if i != j
for k = 1:nrows(a)
x = a[k, i]
Expand All @@ -6262,12 +6263,12 @@ function swap_cols!(a::MatrixElem{T}, i::Int, j::Int) where T <: RingElement
end

@doc raw"""
reverse_rows!(a::MatrixElem{T}) where T <: RingElement
reverse_rows!(a::MatrixElem{T}) where T <: NCRingElement

Swap the $i$th and $r - i$th row of $a$ for $1 \leq i \leq r/2$,
where $r$ is the number of rows of $a$.
"""
function reverse_rows!(a::MatrixElem{T}) where T <: RingElement
function reverse_rows!(a::MatrixElem{T}) where T <: NCRingElement
k = div(nrows(a), 2)
for i in 1:k
swap_rows!(a, i, nrows(a) - i + 1)
Expand All @@ -6276,24 +6277,24 @@ function reverse_rows!(a::MatrixElem{T}) where T <: RingElement
end

@doc raw"""
reverse_rows(a::MatrixElem{T}) where T <: RingElement
reverse_rows(a::MatrixElem{T}) where T <: NCRingElement

Return a matrix $b$ with the entries of $a$, where the $i$th and $r - i$th
row is swapped for $1 \leq i \leq r/2$. Here $r$ is the number of rows of
$a$.
"""
function reverse_rows(a::MatrixElem{T}) where T <: RingElement
function reverse_rows(a::MatrixElem{T}) where T <: NCRingElement
b = deepcopy(a)
return reverse_rows!(b)
end

@doc raw"""
reverse_cols!(a::MatrixElem{T}) where T <: RingElement
reverse_cols!(a::MatrixElem{T}) where T <: NCRingElement

Swap the $i$th and $r - i$th column of $a$ for $1 \leq i \leq c/2$,
where $c$ is the number of columns of $a$.
"""
function reverse_cols!(a::MatrixElem{T}) where T <: RingElement
function reverse_cols!(a::MatrixElem{T}) where T <: NCRingElement
k = div(ncols(a), 2)
for i in 1:k
swap_cols!(a, i, ncols(a) - i + 1)
Expand All @@ -6302,13 +6303,13 @@ function reverse_cols!(a::MatrixElem{T}) where T <: RingElement
end

@doc raw"""
reverse_cols(a::MatrixElem{T}) where T <: RingElement
reverse_cols(a::MatrixElem{T}) where T <: NCRingElement

Return a matrix $b$ with the entries of $a$, where the $i$th and $r - i$th
column is swapped for $1 \leq i \leq c/2$. Here $c$ is the number of columns
of$a$.
"""
function reverse_cols(a::MatrixElem{T}) where T <: RingElement
function reverse_cols(a::MatrixElem{T}) where T <: NCRingElement
b = deepcopy(a)
return reverse_cols!(b)
end
Expand Down
Loading