From 18da3804954e1433fa16e592c3d1e0c4e004aaa4 Mon Sep 17 00:00:00 2001 From: Claus Fieker Date: Fri, 3 Nov 2023 08:52:53 +0100 Subject: [PATCH 1/7] add some triangular ring solver but what do we want? we have solve_triu for fields Ax = b rings (this PR) Ax = b solve_triu_left rings (this PR) xA = b All three accept large "b", ie. can solve multiple equations in one. However, they require "A" to be non-singular square We also have can_solve_left_reduced_triu which can deal with arbitrary matrices "A" in rref, but can only deal with single row "b" solve_triu for fields also has flags for the diagonal being 1 if b and A are square, this is asymptotically not optimal as it will be a n^3/2 algo I can add test - if we want this "interface" Note: for triangular, one cannot transpose to reduce one case to the other Note: in serious use, should also be supplemented by special implementations in Nemo/ Flint for nmod and/or ZZ and friends --- src/Matrix.jl | 83 +++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 78 insertions(+), 5 deletions(-) diff --git a/src/Matrix.jl b/src/Matrix.jl index d8980ba050..cf77ca7837 100644 --- a/src/Matrix.jl +++ b/src/Matrix.jl @@ -3369,9 +3369,9 @@ end @doc raw""" solve_triu(U::MatElem{T}, b::MatElem{T}, unit::Bool = false) where {T <: FieldElement} -Given a non-singular $n\times n$ matrix over a field which is upper -triangular, and an $n\times m$ matrix over the same field, return an -$n\times m$ matrix $x$ such that $Ax = b$. If $A$ is singular an exception +Given a non-singular $n\times n$ matrix $U$ over a field which is upper +triangular, and an $n\times m$ matrix $b$ over the same field, return an +$n\times m$ matrix $x$ such that $Ux = b$. If $U$ is singular an exception is raised. If unit is true then $U$ is assumed to have ones on its diagonal, and the diagonal will not be read. """ @@ -3411,6 +3411,79 @@ function solve_triu(U::MatElem{T}, b::MatElem{T}, unit::Bool = false) where {T < return X end +@doc raw""" + solve_triu(U::MatElem{T}, b::MatElem{T}) where {T <: RingElement} + +Given a non-singular $n\times n$ matrix $U$ over a field which is upper +triangular, and an $n\times m$ matrix $b$ over the same ring, return an +$n\times m$ matrix $x$ such that $Ux = b$. If this is not possible, an error +will be raised. + +See also [`AbstractAlgebra.solve_triu_left`](@ref) +""" +function solve_triu(U::MatElem{T}, b::MatElem{T}) where {T <: RingElement} + n = nrows(U) + m = ncols(b) + R = base_ring(U) + X = zero(b) + tmp = Vector{elem_type(R)}(undef, n) + t = R() + for i = 1:m + for j = 1:n + tmp[j] = X[j, i] + end + for j = n:-1:1 + for k = j + 1:n +# s = addmul!(s, U[j, k], tmp[k], t) + s = s + U[j, k] * tmp[k] + end + s = b[j, i] - s + tmp[j] = divexact(s, U[j,j]) + end + for j = 1:n + X[j, i] = tmp[j] + end + end + return X +end + +@doc raw""" + solve_triu_left(b::MatElem{T}, U::MatElem{T}) where {T <: RingElement} + +Given a non-singular $n\times n$ matrix $U$ over a field which is upper +triangular, and an $m\times n$ matrix $b$ over the same ring, return an +$m\times n$ matrix $x$ such that $xU = b$. If this is not possible, an error +will be raised. + +See also [`solve_triu`](@ref) or [`can_solve_left_reduced_triu`](@ref) when +$U$ is not square or not of full rank. +""" +function solve_triu_left(b::MatElem{T}, U::MatElem{T}) where {T <: RingElement} + n = ncols(U) + m = nrows(b) + R = base_ring(U) + X = zero(b) + tmp = Vector{elem_type(R)}(undef, n) + t = R() + for i = 1:m + for j = 1:n + tmp[j] = X[i, j] + end + for j = 1:n + s = R() + for k = 1:j-1 + s = addmul!(s, U[k, j], tmp[k], t) + end + s = b[i, j] - s + tmp[j] = divexact(s, U[j,j]) + end + for j = 1:n + X[i, j] = tmp[j] + end + end + return X +end + ############################################################################### # # @@ -3465,8 +3538,8 @@ end Return a tuple `flag, x` where `flag` is set to true if $xM = r$ has a solution, where $M$ is an $m\times n$ matrix in (upper triangular) Hermite normal form or reduced row echelon form and $r$ and $x$ are row vectors with -$m$ columns. If there is no solution, flag is set to `false` and $x$ is set -to the zero row. +$m$ columns (i.e. $1 \times m$ matrices). If there is no solution, flag is set +to `false` and $x$ is set to zero. """ function can_solve_left_reduced_triu(r::MatElem{T}, M::MatElem{T}) where T <: RingElement From 1f81892f557f4bece49f392a432209e6ca542a55 Mon Sep 17 00:00:00 2001 From: Claus Fieker Date: Thu, 30 Nov 2023 17:19:59 +0100 Subject: [PATCH 2/7] add generic fast matrix support in module Strassen --- src/AbstractAlgebra.jl | 2 + src/Matrix-Strassen.jl | 309 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 311 insertions(+) create mode 100644 src/Matrix-Strassen.jl diff --git a/src/AbstractAlgebra.jl b/src/AbstractAlgebra.jl index ae388b33b3..36577dd5f7 100644 --- a/src/AbstractAlgebra.jl +++ b/src/AbstractAlgebra.jl @@ -593,6 +593,7 @@ include("CommonTypes.jl") # types needed by AbstractAlgebra and Generic include("Poly.jl") include("NCPoly.jl") include("Matrix.jl") +include("Matrix-Strassen.jl") include("MatrixAlgebra.jl") include("AbsSeries.jl") include("RelSeries.jl") @@ -1145,6 +1146,7 @@ export solve_triu export solve_with_det export sort_terms! export SparsePolynomialRing +export Strassen export strictly_lower_triangular_matrix export strictly_upper_triangular_matrix export sub diff --git a/src/Matrix-Strassen.jl b/src/Matrix-Strassen.jl new file mode 100644 index 0000000000..95924242ef --- /dev/null +++ b/src/Matrix-Strassen.jl @@ -0,0 +1,309 @@ +""" +Provides generic asymptotically fast matrix methods: + - mul and mul! using the Strassen scheme + - solve_tril! + - lu! + - solve_triu + +Just prefix the function by "Strassen." all 4 functions support a keyword +argument "cutoff" to indicate when the base case should be used. + +The speedup depends on the ring and the entry sizes. + +#Examples: + +```jldoctest +julia> m = matrix(ZZ, rand(-10:10, 1000, 1000)); + +julia> n = similar(m); + +julia> mul!(n, m, m); + +julia> Strassen.mul!(n, m, m); + +julia> Strassen.mul!(n, m, m; cutoff = 100); + +``` +""" +module Strassen +using AbstractAlgebra +import AbstractAlgebra:Perm + +const cutoff = 1500 + +function mul(A::MatElem{T}, B::MatElem{T}; cutoff::Int = cutoff) where {T} + C = zero_matrix(base_ring(A), nrows(A), ncols(B)) + mul!(C, A, B; cutoff) + return C +end + +#scheduling copied from the nmod_mat_mul in Flint +function mul!(C::MatElem{T}, A::MatElem{T}, B::MatElem{T}; cutoff::Int = cutoff) where {T} + sA = size(A) + sB = size(B) + sC = size(C) + a = sA[1] + b = sA[2] + c = sB[2] + + @assert a == sC[1] && b == sB[1] && c == sC[2] + + if (a <= cutoff || b <= cutoff || c <= cutoff) + AbstractAlgebra.mul!(C, A, B) + return + end + + anr = div(a, 2) + anc = div(b, 2) + bnr = anc + bnc = div(c, 2) + + #nmod_mat_window_init(A11, A, 0, 0, anr, anc); + #nmod_mat_window_init(A12, A, 0, anc, anr, 2*anc); + #nmod_mat_window_init(A21, A, anr, 0, 2*anr, anc); + #nmod_mat_window_init(A22, A, anr, anc, 2*anr, 2*anc); + A11 = view(A, 1:anr, 1:anc) + A12 = view(A, 1:anr, anc+1:2*anc) + A21 = view(A, anr+1:2*anr, 1:anc) + A22 = view(A, anr+1:2*anr, anc+1:2*anc) + + #nmod_mat_window_init(B11, B, 0, 0, bnr, bnc); + #nmod_mat_window_init(B12, B, 0, bnc, bnr, 2*bnc); + #nmod_mat_window_init(B21, B, bnr, 0, 2*bnr, bnc); + #nmod_mat_window_init(B22, B, bnr, bnc, 2*bnr, 2*bnc); + B11 = view(B, 1:bnr, 1:bnc) + B12 = view(B, 1:bnr, bnc+1:2*bnc) + B21 = view(B, bnr+1:2*bnr, 1:bnc) + B22 = view(B, bnr+1:2*bnr, bnc+1:2*bnc) + + #nmod_mat_window_init(C11, C, 0, 0, anr, bnc); + #nmod_mat_window_init(C12, C, 0, bnc, anr, 2*bnc); + #nmod_mat_window_init(C21, C, anr, 0, 2*anr, bnc); + #nmod_mat_window_init(C22, C, anr, bnc, 2*anr, 2*bnc); + C11 = view(C, 1:anr, 1:bnc) + C12 = view(C, 1:anr, bnc+1:2*bnc) + C21 = view(C, anr+1:2*anr, 1:bnc) + C22 = view(C, anr+1:2*anr, bnc+1:2*bnc) + + #nmod_mat_init(X1, anr, FLINT_MAX(bnc, anc), A->mod.n); + #nmod_mat_init(X2, anc, bnc, A->mod.n); + + #X1->c = anc; + + #= + See Jean-Guillaume Dumas, Clement Pernet, Wei Zhou; "Memory + efficient scheduling of Strassen-Winograd's matrix multiplication + algorithm"; http://arxiv.org/pdf/0707.2347v3 for reference on the + used operation scheduling. + =# + + X1 = A11 - A21 + X2 = B22 - B12 + #nmod_mat_mul(C21, X1, X2); + mul!(C21, X1, X2; cutoff) + + add!(X1, A21, A22); + sub!(X2, B12, B11); + #nmod_mat_mul(C22, X1, X2); + mul!(C22, X1, X2; cutoff) + + sub!(X1, X1, A11); + sub!(X2, B22, X2); + #nmod_mat_mul(C12, X1, X2); + mul!(C12, X1, X2; cutoff) + + sub!(X1, A12, X1); + #nmod_mat_mul(C11, X1, B22); + mul!(C11, X1, B22; cutoff) + + #X1->c = bnc; + #nmod_mat_mul(X1, A11, B11); + mul!(X1, A11, B11; cutoff) + + add!(C12, X1, C12); + add!(C21, C12, C21); + add!(C12, C12, C22); + add!(C22, C21, C22); + add!(C12, C12, C11); + sub!(X2, X2, B21); + #nmod_mat_mul(C11, A22, X2); + mul!(C11, A22, X2; cutoff) + + sub!(C21, C21, C11); + + #nmod_mat_mul(C11, A12, B21); + mul!(C11, A12, B21; cutoff) + + add!(C11, X1, C11); + + if c > 2*bnc #A by last col of B -> last col of C + #nmod_mat_window_init(Bc, B, 0, 2*bnc, b, c); + Bc = view(B, 1:b, 2*bnc+1:c) + #nmod_mat_window_init(Cc, C, 0, 2*bnc, a, c); + Cc = view(C, 1:a, 2*bnc+1:c) + #nmod_mat_mul(Cc, A, Bc); + AbstractAlgebra.mul!(Cc, A, Bc) + end + + if a > 2*anr #last row of A by B -> last row of C + #nmod_mat_window_init(Ar, A, 2*anr, 0, a, b); + Ar = view(A, 2*anr+1:a, 1:b) + #nmod_mat_window_init(Cr, C, 2*anr, 0, a, c); + Cr = view(C, 2*anr+1:a, 1:c) + #nmod_mat_mul(Cr, Ar, B); + AbstractAlgebra.mul!(Cr, Ar, B) + end + + if b > 2*anc # last col of A by last row of B -> C + #nmod_mat_window_init(Ac, A, 0, 2*anc, 2*anr, b); + Ac = view(A, 1:2*anr, 2*anc+1:b) + #nmod_mat_window_init(Br, B, 2*bnr, 0, b, 2*bnc); + Br = view(B, 2*bnr+1:b, 1:2*bnc) + #nmod_mat_window_init(Cb, C, 0, 0, 2*anr, 2*bnc); + Cb = view(C, 1:2*anr, 1:2*bnc) + #nmod_mat_addmul(Cb, Cb, Ac, Br); + AbstractAlgebra.mul!(Cb, Ac, Br, true) + end +end + +#solve_tril fast, recursive +# A * X = U +# B C Y V +# => X = solve(A, U) +# Y = solve(C, V - B*X) +function solve_tril!(A::MatElem{T}, B::MatElem{T}, C::MatElem{T}, f::Int = 0; cutoff::Int = 2) where T + if nrows(A) < cutoff || ncols(A) < cutoff + return AbstractAlgebra.solve_tril!(A, B, C, f) + end + n = nrows(B) + n2 = div(n, 2) + B11 = view(B, 1:n2, 1:n2) + B21 = view(B, n2+1:n, 1:n2) + B22 = view(B, n2+1:n, n2+1:ncols(B)) + X1 = view(A, 1:n2, 1:ncols(A)) + X2 = view(A, n2+1:n, 1:ncols(A)) + C1 = view(C, 1:n2, 1:ncols(A)) + C2 = view(C, n2+1:n, 1:ncols(A)) + solve_tril!(X1, B11, C1, f; cutoff) + x = B21 * X1 # strassen... + sub!(X2, C2, x) + solve_tril!(X2, B22, X2, f; cutoff) +end + +function apply!(A::MatElem, P::Perm{Int}; offset::Int = 0) + n = length(P.d) + Q = copy(inv(P).d) #the inv is experimentally verified with the other apply + cnt = 0 + start = 0 + while cnt < n + ptr = start = findnext(!iszero, Q, start +1)::Int + next = Q[start] + cnt += 1 + while next != start + swap_rows!(A, ptr, next) + Q[ptr] = 0 + next = Q[next] + cnt += 1 + end + Q[ptr] = 0 + end +end + +function apply!(Q::Perm{Int}, P::Perm{Int}; offset::Int = 0) + n = length(P.d) + t = zeros(Int, n-offset) + for i=1:n-offset + t[i] = Q.d[P.d[i] + offset] + end + for i=1:n-offset + Q.d[i + offset] = t[i] + end +end + +function lu!(P::Perm{Int}, A; cutoff::Int = 300) + m = nrows(A) + + @assert length(P.d) == m + n = ncols(A) + if n < cutoff + return AbstractAlgebra.lu!(P, A) + end + n1 = div(n, 2) + for i=1:m + P.d[i] = i + end + P1 = AbstractAlgebra.Perm(m) + A0 = view(A, 1:m, 1:n1) + r1 = lu!(P1, A0; cutoff) + @assert r1 == n1 + if r1 > 0 + apply!(A, P1) + apply!(P, P1) + end + + A00 = view(A, 1:r1, 1:r1) + A10 = view(A, r1+1:m, 1:r1) + A01 = view(A, 1:r1, n1+1:n) + A11 = view(A, r1+1:m, n1+1:n) + + if r1 > 0 + #Note: A00 is a view of A0 thus a view of A + # A0 is lu!, thus implicitly two triangular matrices giving the + # lu decomosition. solve_tril! looks ONLY at the lower part of A00 + solve_tril!(A01, A00, A01, 1) + X = A10 * A01 + sub!(A11, A11, X) + end + + P1 = Perm(nrows(A11)) + r2 = lu!(P1, A11) + apply!(A, P1, offset = r1) + apply!(P, P1, offset = r1) + + if (r1 != n1) + for i=1:m-r1 + for j=1:min(i, r2) + A[r1+i-1, r1+j-1] = A[r1+i-1, n1+j-1] + A[r1+i-1, n1+j-1] = 0 + end + end + end + return r1 + r2 +end + +function solve_triu(T::MatElem, b::MatElem; cutoff::Int = cutoff) + #b*inv(T), thus solves Tx = b for T upper triangular + n = ncols(T) + if n <= cutoff + R = AbstractAlgebra.solve_triu(T, b) + return R + end + + n2 = div(n, 2) + n % 2 + m = nrows(b) + m2 = div(m, 2) + m % 2 + + U = view(b, 1:m2, 1:n2) + V = view(b, 1:m2, n2+1:n) + X = view(b, m2+1:m, 1:n2) + Y = view(b, m2+1:m, n2+1:n) + + A = view(T, 1:n2, 1:n2) + B = view(T, 1:n2, 1+n2:n) + C = view(T, 1+n2:n, 1+n2:n) + + S = solve_triu(A, U; cutoff) + R = solve_triu(A, X; cutoff) + + SS = mul(S, B; cutoff) + sub!(SS, V, SS) + SS = solve_triu(C, SS; cutoff) + + RR = mul(R, B; cutoff) + sub!(RR, Y, RR) + RR = solve_triu(C, RR; cutoff) + + return [S SS; R RR] +end + +end # module From ccdaa7d1a7880fdda5cc76f5a17dcbd33ebf68ed Mon Sep 17 00:00:00 2001 From: Claus Fieker Date: Thu, 30 Nov 2023 17:21:31 +0100 Subject: [PATCH 3/7] solve_tril! --- src/Matrix.jl | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/src/Matrix.jl b/src/Matrix.jl index cf77ca7837..949047a215 100644 --- a/src/Matrix.jl +++ b/src/Matrix.jl @@ -3433,6 +3433,7 @@ function solve_triu(U::MatElem{T}, b::MatElem{T}) where {T <: RingElement} tmp[j] = X[j, i] end for j = n:-1:1 + s = R(0) for k = j + 1:n # s = addmul!(s, U[j, k], tmp[k], t) s = s + U[j, k] * tmp[k] @@ -3484,6 +3485,35 @@ function solve_triu_left(b::MatElem{T}, U::MatElem{T}) where {T <: RingElement} return X end +#solves A x = B for A intended to be lower triangular +#only the lower part is used. if f is true, then the diagonal is assumed to be 1 +#used to use lu! +#can be combined with Strassen.solve_tril! +function solve_tril!(A::MatElem{T}, B::MatElem{T}, C::MatElem{T}, f::Int = 0) where T + + # a x u ax = u + # b c * y = v bx + cy = v + # d e f z w .... + + @assert ncols(A) == ncols(C) + s = base_ring(A)(0) + for i=1:ncols(A) + for j = 1:nrows(A) + t = C[j, i] + for k = 1:j-1 + mul_red!(s, A[k, i], B[j, k], false) + sub!(t, t, s) + end + reduce!(t) + if f == 1 + A[j,i] = t + else + A[j,i] = divexact(t, B[j, j]) + end + end + end +end + ############################################################################### # # From 16a26512dbd3087a57860f7c1dbab13065d97f2c Mon Sep 17 00:00:00 2001 From: Claus Fieker Date: Thu, 30 Nov 2023 17:22:05 +0100 Subject: [PATCH 4/7] sub! and mul! for generic matrices, INPLACE --- src/generic/Matrix.jl | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/generic/Matrix.jl b/src/generic/Matrix.jl index cbee44fb6b..71a70c697b 100644 --- a/src/generic/Matrix.jl +++ b/src/generic/Matrix.jl @@ -211,3 +211,18 @@ function matrix_space(R::AbstractAlgebra.NCRing, r::Int, c::Int; cached::Bool = T = elem_type(R) return MatSpace{T}(R, r, c) end + +function AbstractAlgebra.sub!(A::MatElem{T}, B::MatElem{T}, C::MatElem{T}) where T + A.entries.= B.entries .- C.entries +end + +#since type(view(MatElem{T})) != MatElem{T} which breaks +# sub!(A::T, B::T, C::T) where T in AA +function AbstractAlgebra.mul!(A::MatElem{T}, B::MatElem{T}, C::MatElem{T}, f::Bool = false) where T + if f + A.entries .+= (B * C).entries + else + A.entries .= (B * C).entries + end +end + From 88df7b06a9a514f330796c1ac535984e0e21bab7 Mon Sep 17 00:00:00 2001 From: Claus Fieker Date: Thu, 30 Nov 2023 17:22:37 +0100 Subject: [PATCH 5/7] add test for Strassen --- test/Matrix-test.jl | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/test/Matrix-test.jl b/test/Matrix-test.jl index fe012a97a0..c7d49e84d5 100644 --- a/test/Matrix-test.jl +++ b/test/Matrix-test.jl @@ -98,3 +98,18 @@ end T = scalar_matrix(QQ, 3, 42) @test T == matrix(QQ, [42 0 0; 0 42 0; 0 0 42]) end + +@testset "Strassen" begin + S = matrix(QQ, rand(-10:10, 100, 100)) + T = S*S + TT = Strassen.mul(S, S; cutoff = 50) + @test T == TT + + P1 = Pern(100) + S1 = deepcopy(S) + r1 = lu!(P1, S1) + P = Perm(100) + r2 = Strassen.lu!(P, S; cutoff = 50) + @test r1 == r2 + @test S1 == S +end From e57a51671e51ae8fea1e6f429f6623f798b3a63d Mon Sep 17 00:00:00 2001 From: Claus Fieker Date: Fri, 1 Dec 2023 14:16:12 +0100 Subject: [PATCH 6/7] fix typo --- test/Matrix-test.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/Matrix-test.jl b/test/Matrix-test.jl index c7d49e84d5..6d50ba0ffa 100644 --- a/test/Matrix-test.jl +++ b/test/Matrix-test.jl @@ -105,7 +105,7 @@ end TT = Strassen.mul(S, S; cutoff = 50) @test T == TT - P1 = Pern(100) + P1 = Perm(100) S1 = deepcopy(S) r1 = lu!(P1, S1) P = Perm(100) From 1d6093b58b49dc13b00c74ae06a3484c420d498d Mon Sep 17 00:00:00 2001 From: Claus Fieker Date: Fri, 1 Dec 2023 16:40:51 +0100 Subject: [PATCH 7/7] fix doctest --- src/Matrix-Strassen.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Matrix-Strassen.jl b/src/Matrix-Strassen.jl index 95924242ef..bcf62330ab 100644 --- a/src/Matrix-Strassen.jl +++ b/src/Matrix-Strassen.jl @@ -12,7 +12,7 @@ The speedup depends on the ring and the entry sizes. #Examples: -```jldoctest +```jldoctest; setup = :(using AbstractAlgebra) julia> m = matrix(ZZ, rand(-10:10, 1000, 1000)); julia> n = similar(m);