Skip to content

Commit

Permalink
feat: make Vector{fmpz} * fmpz_mat faster
Browse files Browse the repository at this point in the history
  • Loading branch information
thofma committed Nov 20, 2024
1 parent 80d1172 commit f81da0b
Showing 1 changed file with 90 additions and 10 deletions.
100 changes: 90 additions & 10 deletions src/flint/fmpz_mat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1784,16 +1784,6 @@ end
# Julia `Integer` scalar: pass `b` through `flintify` and forward to the
# `addmul!` method that accepts the converted scalar.
function addmul!(z::ZZMatrixOrPtr, a::ZZMatrixOrPtr, b::Integer)
  return addmul!(z, a, flintify(b))
end

# Scalar-first argument order: swap the last two arguments and forward to the
# matrix-first `addmul!` method.
function addmul!(z::ZZMatrixOrPtr, a::IntegerUnionOrPtr, b::ZZMatrixOrPtr)
  return addmul!(z, b, a)
end

# Matrix * vector: hand `z`, `a`, `b` and `length(b)` to FLINT's
# `fmpz_mat_mul_fmpz_vec_ptr`, which writes its result into `z`.
# Returns `z`.
function mul!(z::Vector{ZZRingElem}, a::ZZMatrixOrPtr, b::Vector{ZZRingElem})
  ccall((:fmpz_mat_mul_fmpz_vec_ptr, libflint), Nothing,
        (Ptr{Ref{ZZRingElem}}, Ref{ZZMatrix}, Ptr{Ref{ZZRingElem}}, Int),
        z, a, b, length(b))
  return z
end

# Vector * matrix: hand `z`, `a`, `length(a)` and `b` to FLINT's
# `fmpz_mat_fmpz_vec_mul_ptr`, which writes its result into `z`.
# Returns `z`.
function mul!(z::Vector{ZZRingElem}, a::Vector{ZZRingElem}, b::ZZMatrixOrPtr)
  ccall((:fmpz_mat_fmpz_vec_mul_ptr, libflint), Nothing,
        (Ptr{Ref{ZZRingElem}}, Ptr{Ref{ZZRingElem}}, Int, Ref{ZZMatrix}),
        z, a, length(a), b)
  return z
end

function Generic.add_one!(a::ZZMatrix, i::Int, j::Int)
@boundscheck _checkbounds(a, i, j)
GC.@preserve a begin
Expand All @@ -1819,6 +1809,96 @@ function shift!(g::ZZMatrix, l::Int)
return g
end

################################################################################
#
# Vector * Matrix and Matrix * Vector
#
################################################################################

# Vector{fmpz} * fmpz_mat can be performed using
# - fmpz_mat_fmpz_vec_mul_ptr
# - or conversion + fmpz_mat_mul
#
# The fmpz_mat_fmpz_vec_mul_ptr variants are not optimized.
# Thus, if the conversion cost is negligible, we convert and call fmpz_mat_mul.
# The conversion is done on the julia side, trying to reduce the number of
# allocations and objects tracked by the GC.

# Wrap the entries of `a` in a temporary `ZZMatrix` without copying the
# underlying integers: only the raw `d` word of every entry is copied.
#
# Arguments:
# - `a`:   the vector to wrap. It must be kept alive by the caller
#          (`GC.@preserve`) for as long as the returned matrix is used.
# - `row`: `true`  -> the result is a 1 x length(a) matrix (a row),
#          `false` -> the result is a length(a) x 1 matrix (a column).
#
# Returns `(M, Me, Mep)`:
# - `M`:   the matrix; `M.entries` and `M.rows` are raw pointers into `Me`
#          and `Mep` respectively.
# - `Me`:  `Vector{Int}` holding one word per entry, initialised to `a[i].d`.
# - `Mep`: vector of row pointers into `Me`.
# Because `M` only stores raw pointers, the caller must `GC.@preserve` both
# `Me` and `Mep` while `M` is in use.
#
# NOTE(review): `M` is created via `Nemo.@new_struct`; presumably no finalizer
# is attached, so `Me`/`Mep` are the sole owners of the memory -- confirm.
function _very_unsafe_convert(::Type{ZZMatrix}, a::Vector{ZZRingElem}, row = true)
  n = length(a)
  M = Nemo.@new_struct(ZZMatrix)

  # Entry storage: one machine word per entry, filled with the raw words of `a`.
  Me = zeros(Int, n)
  for i in 1:n
    Me[i] = a[i].d
  end
  M.entries = reinterpret(Ptr{ZZRingElem}, pointer(Me))

  if row
    # A single row starting at the first entry.
    Mep = [pointer(Me)]
    M.r = 1
    M.c = n
  else
    # One row per entry. `pointer(Me, i)` uses the element size of `Me` as
    # stride instead of a hard-coded byte count.
    Mep = [pointer(Me, i) for i in 1:n]
    M.r = n
    M.c = 1
  end
  M.rows = reinterpret(Ptr{Ptr{ZZRingElem}}, pointer(Mep))

  return M, Me, Mep
end

# Matrix * vector: compute `a * b` into `z` and return `z`.
#
# Two strategies are used:
# - `mul!_flint` (FLINT's vector routine) for small matrices with small
#   entries, and whenever the vector lengths do not match the shape of `a`;
# - otherwise `z` and `b` are wrapped as column matrices (sharing the entries,
#   see `_very_unsafe_convert`) and the matrix-matrix `mul!` is used.
function mul!(z::Vector{ZZRingElem}, a::ZZMatrixOrPtr, b::Vector{ZZRingElem})
  # The conversion path below presents `z` and `b` to the matrix product as
  # length(z) x 1 and length(b) x 1 matrices, so it is only memory-safe when
  # these lengths agree with the dimensions of `a`. The FLINT vector routine
  # is documented to truncate/zero-extend `b`, so use it for mismatched input
  # instead of reading or writing out of bounds.
  if length(b) != ncols(a) || length(z) != nrows(a)
    return mul!_flint(z, a, b)
  end

  # cutoff for the flint method
  if nrows(a) < 50 && maximum(nbits, a) < 10
    return mul!_flint(z, a, b)
  end

  GC.@preserve z b begin
    bb, bb_entries, bb_rows = _very_unsafe_convert(ZZMatrix, b, false)
    zz, zz_entries, zz_rows = _very_unsafe_convert(ZZMatrix, z, false)
    # `bb` and `zz` only hold raw pointers into these buffers; keep them alive.
    GC.@preserve bb_entries bb_rows zz_entries zz_rows begin
      mul!(zz, a, bb)
      # Copy the (possibly changed) raw words of the result back into `z`.
      for i in 1:length(z)
        z[i].d = unsafe_load(zz.entries, i).d
      end
    end
  end
  return z
end

# Matrix * vector through FLINT's `fmpz_mat_mul_fmpz_vec_ptr`: passes `z`, `a`,
# `b` and `length(b)` to the C routine, which writes its result into `z`.
# Returns `z`.
function mul!_flint(z::Vector{ZZRingElem}, a::ZZMatrixOrPtr, b::Vector{ZZRingElem})
  @ccall libflint.fmpz_mat_mul_fmpz_vec_ptr(
    z::Ptr{Ref{ZZRingElem}},
    a::Ref{ZZMatrix},
    b::Ptr{Ref{ZZRingElem}},
    length(b)::Int
  )::Nothing
  return z
end

# Vector * matrix: compute `a * b` into `z` and return `z`.
#
# Two strategies are used:
# - `mul!_flint` (FLINT's vector routine) for small matrices with small
#   entries, and whenever the vector lengths do not match the shape of `b`;
# - otherwise `z` and `a` are wrapped as row matrices (sharing the entries,
#   see `_very_unsafe_convert`) and the matrix-matrix `mul!` is used.
function mul!(z::Vector{ZZRingElem}, a::Vector{ZZRingElem}, b::ZZMatrixOrPtr)
  # The conversion path below presents `z` and `a` to the matrix product as
  # 1 x length(z) and 1 x length(a) matrices, so it is only memory-safe when
  # these lengths agree with the dimensions of `b`. The FLINT vector routine
  # is documented to truncate/zero-extend `a`, so use it for mismatched input
  # instead of reading or writing out of bounds.
  if length(a) != nrows(b) || length(z) != ncols(b)
    return mul!_flint(z, a, b)
  end

  # cutoff for the flint method
  if nrows(b) < 50 && maximum(nbits, b) < 10
    return mul!_flint(z, a, b)
  end

  GC.@preserve z a begin
    aa, aa_entries, aa_rows = _very_unsafe_convert(ZZMatrix, a)
    zz, zz_entries, zz_rows = _very_unsafe_convert(ZZMatrix, z)
    # `aa` and `zz` only hold raw pointers into these buffers; keep them alive.
    GC.@preserve aa_entries aa_rows zz_entries zz_rows begin
      mul!(zz, aa, b)
      # Copy the (possibly changed) raw words of the result back into `z`.
      for i in 1:length(z)
        z[i].d = unsafe_load(zz.entries, i).d
      end
    end
  end
  return z
end

# Vector * matrix through FLINT's `fmpz_mat_fmpz_vec_mul_ptr`: passes `z`, `a`,
# `length(a)` and `b` to the C routine, which writes its result into `z`.
# Returns `z`.
function mul!_flint(z::Vector{ZZRingElem}, a::Vector{ZZRingElem}, b::ZZMatrixOrPtr)
  @ccall libflint.fmpz_mat_fmpz_vec_mul_ptr(
    z::Ptr{Ref{ZZRingElem}},
    a::Ptr{Ref{ZZRingElem}},
    length(a)::Int,
    b::Ref{ZZMatrix}
  )::Nothing
  return z
end

###############################################################################
#
# Parent object call overloads
Expand Down

0 comments on commit f81da0b

Please sign in to comment.