Skip to content

Commit

Permalink
Add mutating arithmetic for SRows (#1659)
Browse files Browse the repository at this point in the history
* Make `scale_row!` follow its docstring

i.e. don't throw on zero scalars, and coerce scalars if needed

* Fix a docstring typo

* Let `add_scaled_row` coerce the scalar

* Add mutating arithmetics for SRow

* Skip deepcopy in addmul! in case of aliasing

* Add `submul!`

* Add tests

* Some fixes (tests run now)

* Comment out tests

* Adapt to AA changes

* Bump AbstractAlgebra compat
  • Loading branch information
lgoettgens authored Nov 13, 2024
1 parent 5341ab1 commit ad39e8f
Show file tree
Hide file tree
Showing 4 changed files with 205 additions and 21 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ GAPExt = "GAP"
PolymakeExt = "Polymake"

[compat]
AbstractAlgebra = "^0.43.1"
AbstractAlgebra = "^0.43.10"
Dates = "1.6"
Distributed = "1.6"
GAP = "0.9.6, 0.10, 0.11, 0.12"
Expand Down
184 changes: 165 additions & 19 deletions src/Sparse/Row.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,14 @@ function Base.empty!(A::SRow)
return A
end

function Base.empty(A::SRow)
return sparse_row(base_ring(A))
end

function zero(A::SRow)
return empty(A)
end

function swap!(A::SRow, B::SRow)
A.pos, B.pos = B.pos, A.pos
A.values, B.values = B.values, A.values
Expand Down Expand Up @@ -447,15 +455,17 @@ end
# Inplace scaling
#
################################################################################

@doc raw"""
scale_row!(a::SRow, b::NCRingElem) -> SRow
Returns the (left) product of $b \times a$ and reassigns the value of $a$ to this product.
For rows, the standard multiplication is from the left.
"""
function scale_row!(a::SRow{T}, b::T) where T
@assert !iszero(b)
if isone(b)
if iszero(b)
return empty!(a)
elseif isone(b)
return a
end
i = 1
Expand All @@ -465,20 +475,23 @@ function scale_row!(a::SRow{T}, b::T) where T
deleteat!(a.values, i)
deleteat!(a.pos, i)
else
i += 1
i += 1
end
end
return a
end

scale_row!(a::SRow, b) = scale_row!(a, base_ring(a)(b))

@doc raw"""
scale_row_right!(a::SRow, b::NCRingElem) -> SRow
Returns the (right) product of $a \times b$ and modifies $a$ to this product.
"""
function scale_row_right!(a::SRow{T}, b::T) where T
@assert !iszero(b)
if isone(b)
if iszero(b)
return empty!(a)
elseif isone(b)
return a
end
i = 1
Expand All @@ -488,16 +501,20 @@ function scale_row_right!(a::SRow{T}, b::T) where T
deleteat!(a.values, i)
deleteat!(a.pos, i)
else
i += 1
i += 1
end
end
return a
end

scale_row_right!(a::SRow, b) = scale_row_right!(a, base_ring(a)(b))

function scale_row_left!(a::SRow{T}, b::T) where T
return scale_row!(a,b)
end

scale_row_left!(a::SRow, b) = scale_row_left!(a, base_ring(a)(b))

################################################################################
#
# Addition
Expand All @@ -506,22 +523,22 @@ end

function +(A::SRow{T}, B::SRow{T}) where T
if length(A.values) == 0
return B
return deepcopy(B)
elseif length(B.values) == 0
return A
return deepcopy(A)
end
return add_scaled_row(A, B, one(base_ring(A)))
end

function -(A::SRow{T}, B::SRow{T}) where T
if length(A) == 0
if length(B) == 0
return A
return deepcopy(A)
else
return add_scaled_row(B, A, base_ring(B)(-1))
return add_scaled_row(B, A, -1)
end
end
return add_scaled_row(B, A, base_ring(A)(-1))
return add_scaled_row(B, A, -1)
end

function -(A::SRow{T}) where {T}
Expand Down Expand Up @@ -683,10 +700,10 @@ end
Returns the row $c A + B$.
"""
add_scaled_row(a::SRow{T}, b::SRow{T}, c::T) where {T} = add_scaled_row!(a, deepcopy(b), c)
add_scaled_row(a::SRow{T}, b::SRow{T}, c) where {T} = add_scaled_row!(a, deepcopy(b), c)

add_left_scaled_row(a::SRow{T}, b::SRow{T}, c::T) where {T} = add_left_scaled_row!(a, deepcopy(b), c)
add_right_scaled_row(a::SRow{T}, b::SRow{T}, c::T) where {T} = add_right_scaled_row!(a, deepcopy(b), c)
add_left_scaled_row(a::SRow{T}, b::SRow{T}, c) where {T} = add_left_scaled_row!(a, deepcopy(b), c)
add_right_scaled_row(a::SRow{T}, b::SRow{T}, c) where {T} = add_right_scaled_row!(a, deepcopy(b), c)



Expand All @@ -696,7 +713,9 @@ add_right_scaled_row(a::SRow{T}, b::SRow{T}, c::T) where {T} = add_right_scaled_
Adds the left scaled row $c A$ to $B$.
"""
function add_scaled_row!(a::SRow{T}, b::SRow{T}, c::T, ::Val{left_side} = Val(true)) where {T, left_side}
@assert a !== b
if a === b
a = deepcopy(a)
end
i = 1
j = 1
t = base_ring(a)()
Expand Down Expand Up @@ -735,17 +754,144 @@ function add_scaled_row!(a::SRow{T}, b::SRow{T}, c::T, ::Val{left_side} = Val(tr
return b
end

add_scaled_row!(a::SRow{T}, b::SRow{T}, c) where {T} = add_scaled_row!(a, b, base_ring(a)(c))

add_scaled_row!(a::SRow{T}, b::SRow{T}, c, side::Val) where {T} = add_scaled_row!(a, b, base_ring(a)(c), side)

# ignore tmp argument
add_scaled_row!(a::SRow{T}, b::SRow{T}, c::T, tmp::SRow{T}) where T = add_scaled_row!(a, b, c)
add_scaled_row!(a::SRow{T}, b::SRow{T}, c, tmp::SRow{T}) where T = add_scaled_row!(a, b, c)

add_left_scaled_row!(a::SRow{T}, b::SRow{T}, c::T) where T = add_scaled_row!(a, b, c)
add_left_scaled_row!(a::SRow{T}, b::SRow{T}, c) where T = add_scaled_row!(a, b, c)

@doc raw"""
add_right_scaled_row!(A::SRow{T}, B::SRow{T}, c::T) -> SRow{T}
Return the right scaled row $c A$ to $B$ by changing $B$ in place.
Return the right scaled row $A c$ to $B$ by changing $B$ in place.
"""
add_right_scaled_row!(a::SRow{T}, b::SRow{T}, c::T) where T = add_scaled_row!(a, b, c, Val(false))
add_right_scaled_row!(a::SRow{T}, b::SRow{T}, c) where T = add_scaled_row!(a, b, c, Val(false))


################################################################################
#
# Mutating arithmetics
#
################################################################################

function zero!(z::SRow)
return empty!(z)
end

function neg!(z::SRow{T}, x::SRow{T}) where T
if z === x
return neg!(x)
end
swap!(z, -x)
return z
end

function neg!(z::SRow)
for i in 1:length(z)
z.values[i] = neg!(z.values[i])
end
return z
end

function add!(z::SRow{T}, x::SRow{T}, y::SRow{T}) where T
if z === x
return add!(x, y)
elseif z === y
return add!(y, x)
end
swap!(z, x + y)
return z
end

function add!(z::SRow{T}, x::SRow{T}) where T
if z === x
return scale_row!(z, 2)
end
return add_scaled_row!(x, z, one(base_ring(x)))
end

function sub!(z::SRow{T}, x::SRow{T}, y::SRow{T}) where T
if z === x
return sub!(x, y)
elseif z === y
return neg!(sub!(y, x))
end
swap!(z, x - y)
return z
end

function sub!(z::SRow{T}, x::SRow{T}) where T
if z === x
return empty!(z)
end
return add_scaled_row!(x, z, -1)
end

function mul!(z::SRow{T}, x::SRow{T}, y::SRow{T}) where T
error("Not implemented")
end

function mul!(z::SRow{T}, x::SRow{T}, c) where T
if z === x
return scale_row_right!(x, c)
end
swap!(z, x * c)
return z
end

function mul!(z::SRow{T}, c, y::SRow{T}) where T
if z === y
return scale_row_left!(y, c)
end
swap!(z, c * y)
return z
end

function addmul!(z::SRow{T}, x::SRow{T}, y::SRow{T}) where T
error("Not implemented")
end

function addmul!(z::SRow{T}, x::SRow{T}, y) where T
if z === x
return scale_row_right!(x, y+1)
end
return add_right_scaled_row!(x, z, y)
end

function addmul!(z::SRow{T}, x, y::SRow{T}) where T
if z === x
return scale_row_left!(y, x+1)
end
return add_left_scaled_row!(y, z, x)
end

function submul!(z::SRow{T}, x::SRow{T}, y::SRow{T}) where T
error("Not implemented")
end

function submul!(z::SRow{T}, x::SRow{T}, y) where T
if z === x
return scale_row_right!(x, -y+1)
end
return add_right_scaled_row!(x, z, -y)
end

function submul!(z::SRow{T}, x, y::SRow{T}) where T
if z === x
return scale_row_left!(y, -x+1)
end
return add_left_scaled_row!(y, z, -x)
end


# ignore temp variable
addmul!(z::SRow{T}, x::SRow{T}, y, t) where T = addmul!(z, x, y)
addmul!(z::SRow{T}, x, y::SRow{T}, t) where T = addmul!(z, x, y)
submul!(z::SRow{T}, x::SRow{T}, y, t) where T = submul!(z, x, y)
submul!(z::SRow{T}, x, y::SRow{T}, t) where T = submul!(z, x, y)


################################################################################
Expand Down
4 changes: 3 additions & 1 deletion src/Sparse/ZZRow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,6 @@ end

function add_scaled_row(Ai::SRow{ZZRingElem}, Aj::SRow{ZZRingElem}, c::ZZRingElem, sr::SRow{ZZRingElem} = sparse_row(ZZ))
empty!(sr)
@assert c != 0
n = ZZRingElem()
pi = 1
pj = 1
Expand Down Expand Up @@ -323,6 +322,9 @@ function add_scaled_row(Ai::SRow{ZZRingElem}, Aj::SRow{ZZRingElem}, c::ZZRingEle
end

function add_scaled_row!(Ai::SRow{ZZRingElem}, Aj::SRow{ZZRingElem}, c::ZZRingElem, sr::SRow{ZZRingElem} = sparse_row(ZZ))
if iszero(c)
return Aj
end
_t = sr
sr = add_scaled_row(Ai, Aj, c, sr)
@assert _t === sr
Expand Down
36 changes: 36 additions & 0 deletions test/Sparse/Row.jl
Original file line number Diff line number Diff line change
Expand Up @@ -204,4 +204,40 @@
B = sparse_row(F,[1],[y])
C = add_scaled_row(A,B,F(1))
@test C == A+B

# mutating arithmetic
randcoeff() = begin
n = rand((1,1,1,2,5,7,15))
return rand(-2^n:2^n)
end
Main.equality(A::SRow, B::SRow) = A == B
@testset "mutating arithmetic; R = $R" for R in (ZZ, QQ)
for _ in 1:10
maxind_A = rand(0:10)
inds_A = Hecke.Random.randsubseq(1:maxind_A, rand())
vals_A = elem_type(R)[R(rand((-1, 1)) * rand(1:10)) for _ in 1:length(inds_A)]
A = sparse_row(R, inds_A, vals_A)

maxind_B = rand(0:10)
inds_B = Hecke.Random.randsubseq(1:maxind_B, rand())
vals_B = elem_type(R)[R(rand((-1, 1)) * rand(1:10)) for _ in 1:length(inds_B)]
B = sparse_row(R, inds_B, vals_B)

test_mutating_op_like_zero(zero, zero!, A)

test_mutating_op_like_neg(-, neg!, A)

test_mutating_op_like_add(+, add!, A, B)
test_mutating_op_like_add(-, sub!, A, B)
test_mutating_op_like_add(*, mul!, A, randcoeff(), SRow)
test_mutating_op_like_add(*, mul!, randcoeff(), A, SRow)
test_mutating_op_like_add(*, mul!, A, ZZ(randcoeff()), SRow)
test_mutating_op_like_add(*, mul!, ZZ(randcoeff()), A, SRow)

test_mutating_op_like_addmul((a, b, c) -> a + b*c, addmul!, A, B, randcoeff(), SRow)
test_mutating_op_like_addmul((a, b, c) -> a + b*c, addmul!, A, randcoeff(), B, SRow)
test_mutating_op_like_addmul((a, b, c) -> a - b*c, submul!, A, B, randcoeff(), SRow)
test_mutating_op_like_addmul((a, b, c) -> a - b*c, submul!, A, randcoeff(), B, SRow)
end
end
end

0 comments on commit ad39e8f

Please sign in to comment.