
Commit

Initial implementation of backward pass for sparse linear least squares (#28)

* Initial implementation of backward pass for sparse linear least squares

* ran black
tvercaut authored Dec 12, 2022
1 parent acbbf9e commit 4c6f47a
Showing 3 changed files with 216 additions and 1 deletion.
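
The commit adds a sparse_generic_lstsq entry point (exported from torchsparsegradutils) together with a custom autograd backward for the sparse matrix A. A minimal usage sketch, mirroring the new test below (hypothetical example code, not part of the commit):

import torch
from torchsparsegradutils import sparse_generic_lstsq

A_csr = torch.randn(7, 4, dtype=torch.float64).to_sparse_csr()  # tall, full-rank A in CSR format
A_csr.requires_grad = True
B = torch.randn(7, 1, dtype=torch.float64, requires_grad=True)

x = sparse_generic_lstsq(A_csr, B)  # solves min_x ||Ax - B||^2
x.sum().backward()  # A_csr.grad is sparse, with the same sparsity pattern as A_csr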
78 changes: 78 additions & 0 deletions tests/test_sparse_lstsq.py
@@ -0,0 +1,78 @@
import torch
import unittest

from torchsparsegradutils import sparse_generic_lstsq


class SparseGenericLstsqTest(unittest.TestCase):
def setUp(self) -> None:
        # The device can be specialised by a subclass
if not hasattr(self, "device"):
self.device = torch.device("cpu")

self.RTOL = 1e-2

self.A_shape = (7, 4)
self.A = torch.randn(self.A_shape, dtype=torch.float64, device=self.device)
self.A_csr = self.A.to_sparse_csr()
self.B_shape = (7, 1)
self.B = torch.randn(self.B_shape, dtype=torch.float64, device=self.device)

self.x_ref = torch.linalg.lstsq(self.A, self.B).solution

def test_generic_lstsq_default(self):
x = sparse_generic_lstsq(self.A_csr, self.B)
# print("x",x)
# print("self.x_ref",self.x_ref)
self.assertTrue(torch.isclose(x, self.x_ref, rtol=self.RTOL).all())

def test_generic_lstsq_gradient_default(self):
# Sparse lstsq:
As1 = self.A_csr.detach().clone()
As1.requires_grad = True
Bd1 = self.B.detach().clone()
Bd1.requires_grad = True
As1.retain_grad()
Bd1.retain_grad()
x = sparse_generic_lstsq(As1, Bd1)
loss = x.sum()
loss.backward()

# torch dense lstsq:
Ad2 = self.A.detach().clone()
Ad2.requires_grad = True
Bd2 = self.B.detach().clone()
Bd2.requires_grad = True
Ad2.retain_grad()
Bd2.retain_grad()
x2 = torch.linalg.lstsq(Ad2, Bd2).solution
loss_torch = x2.sum()
loss_torch.backward()

# print("x",x)
# print("x2",x2)

self.assertTrue(torch.isclose(x, x2, rtol=self.RTOL).all())

# print("Bd1.grad",Bd1.grad)
# print("Bd2.grad",Bd2.grad)
# print("As1.grad.to_dense()",As1.grad.to_dense())
# print("Ad2.grad",Ad2.grad)

nz_mask = As1.grad.to_dense() != 0.0
self.assertTrue(torch.isclose(As1.grad.to_dense()[nz_mask], Ad2.grad[nz_mask], rtol=self.RTOL).all())
self.assertTrue(torch.isclose(Bd1.grad, Bd2.grad, rtol=self.RTOL).all())


class SparseGenericLstsqTestCUDA(SparseGenericLstsqTest):
"""Override superclass setUp to run on GPU"""

def setUp(self) -> None:
if not torch.cuda.is_available():
self.skipTest(f"Skipping {self.__class__.__name__} since CUDA is not available")
self.device = torch.device("cuda")
super().setUp()


if __name__ == "__main__":
unittest.main()
3 changes: 2 additions & 1 deletion torchsparsegradutils/__init__.py
@@ -1,4 +1,5 @@
from .sparse_matmul import sparse_mm
from .sparse_solve import sparse_triangular_solve, sparse_generic_solve
from .sparse_lstsq import sparse_generic_lstsq

__all__ = ["sparse_mm", "sparse_triangular_solve", "sparse_generic_solve"]
__all__ = ["sparse_mm", "sparse_triangular_solve", "sparse_generic_solve", "sparse_generic_lstsq"]
136 changes: 136 additions & 0 deletions torchsparsegradutils/sparse_lstsq.py
@@ -0,0 +1,136 @@
import torch


def sparse_generic_lstsq(A, B, lstsq=None, transpose_lstsq=None):
    if lstsq is None or transpose_lstsq is None:
        from .utils import lsmr

    if lstsq is None:
        lstsq = lambda AA, BB: lsmr(AA, BB)[0]
    if transpose_lstsq is None:
        # Default transpose solver: LSMR applied to the adjoint of A
        transpose_lstsq = lambda AA, BB: lsmr(torch.adjoint(AA), BB, AA)[0]

return SparseGenericLstsq.apply(A, B, lstsq, transpose_lstsq)


class SparseGenericLstsq(torch.autograd.Function):
"""
Solves a linear least squares problem with a full-rank, tall
sparse matrix A and dense right-hand side matrix B,
with backpropagation support
Solves: min_x || Ax - B ||^2
A can be in either COO or CSR format.
lstsq: higher level function that solves for the linear least squares problem. This function need not be differentiable.
transpose_lstsq: higher level function for solving the transpose linear least squares problem. This function need not be differentiable.
This implementation preserves the sparsity of the gradient. We make use of the derivation in
Golub GH, Pereyra V. The differentiation of pseudo-inverses and nonlinear least squares problems whose variables separate.
SIAM Journal on numerical analysis. 1973 Apr;10(2):413-32.
We also assume that A is tall and full-rank so that A^+ A = Id where A^+ is the pseudo-inverse of A
"""

@staticmethod
def forward(ctx, A, B, lstsq, transpose_lstsq):
grad_flag = A.requires_grad or B.requires_grad
ctx.lstsq = lstsq
ctx.transpose_lstsq = transpose_lstsq

x = lstsq(A.detach(), B.detach())

x.requires_grad = grad_flag

if B.dim() == 1:
if x.dim() == 2:
x = x.squeeze()
else:
if x.dim() == 1:
x = x.unsqueeze(1)

ctx.save_for_backward(A.detach(), B.detach(), x.detach())
return x

@staticmethod
def backward(ctx, grad):
A, B, x = ctx.saved_tensors
if B.dim() == 1:
B = B.unsqueeze(1)
if x.dim() == 1:
x = x.unsqueeze(1)

# Backprop rule: gradB = (A^T)^{+} grad
gradB = ctx.transpose_lstsq(A, grad)
if gradB.dim() == 1:
gradB = gradB.unsqueeze(1)

# We make use of equation 4.12 in https://www.jstor.org/stable/2156365
# but assume A is tall and full rank to get A^+ A = Id and simplify the derivation.
        # For computational reasons, we do not try to compute the rank of A, but we at least
        # check that A is a tall matrix.
if A.shape[1] > A.shape[0]:
raise ValueError(f"A should be a tall full-rank matrix. Got A.shape={A.shape}")
# Following the derivation in https://blog.flaport.net/solving-sparse-linear-systems-in-pytorch.html
# but using the pseudo-inverse instead of the inverse:
        # If A were treated as a dense matrix, the gradient with respect to A
        # would be given by the following backprop rule:
        # gradA = -((A^T)^{+} grad)(A^{+} B)^T - (Ax-B)(A^+ (A^T)^{+} grad)^T
        #       = - gradB @ x.T - (Ax-B) @ (A^+ gradB).T
        # but we are only interested in the gradient with respect to
        # the (non-zero) values of A. To save memory, instead of computing the full
        # dense matrices gradB @ x.T and (Ax-B) @ (A^+ gradB).T
        # and then subsampling at the nnz locations in A,
        # we can directly compute only the required values:
        # gradA_u1[i,j] = - dotprod(gradB[i,:], x[j,:])
        # gradA_u2[i,j] = - dotprod((Ax-B)[i,:], (A^+ gradB)[j,:])

# Dense equivalent
# gradA_u1 = - gradB @ torch.t(x)
# mresiduals = B - A@x
# Apgb = ctx.lstsq(A,gradB)
# if Apgb.dim() == 1:
# Apgb = Apgb.unsqueeze(1)
# gradA_u2 = mresiduals @ torch.t(Apgb)
# gradA = gradA_u1 + gradA_u2
# return gradA, gradB, None, None

# We start by getting the i and j indices:
if A.layout == torch.sparse_coo:
A_row_idx = A.indices()[0, :]
A_col_idx = A.indices()[1, :]
else:
A_col_idx = A.col_indices()
A_crow_idx = A.crow_indices()
# Uncompress row indices:
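            # e.g. crow_indices = [0, 2, 3, 3] -> row indices [0, 0, 1] (the last row has no stored values)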
A_row_idx = torch.repeat_interleave(
torch.arange(A.size()[0], device=A.device), A_crow_idx[1:] - A_crow_idx[:-1]
)

mgradbselect = -gradB.index_select(0, A_row_idx) # -gradB[i, :]
xselect = x.index_select(0, A_col_idx) # x[j, :]

# Dot product:
mgbx = mgradbselect * xselect
gradA_u1 = torch.sum(mgbx, dim=1)

# residuals
mresiduals = B - A @ x
mresidualsselect = mresiduals.index_select(0, A_row_idx)
Apgb = ctx.lstsq(A, gradB)
if Apgb.dim() == 1:
Apgb = Apgb.unsqueeze(1)
Apgbselect = Apgb.index_select(0, A_col_idx)

# Dot product:
mresApgb = mresidualsselect * Apgbselect
gradA_u2 = torch.sum(mresApgb, dim=1)

gradA = gradA_u1 + gradA_u2

if A.layout == torch.sparse_coo:
gradA = torch.sparse_coo_tensor(torch.stack([A_row_idx, A_col_idx]), gradA, A.shape)
else:
gradA = torch.sparse_csr_tensor(A_crow_idx, A_col_idx, gradA, A.shape)

return gradA, gradB, None, None
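
As a side note (not part of the commit): the dense backprop rule quoted in the comments above can be checked against PyTorch's own torch.linalg.lstsq backward on a small dense example. A hypothetical sanity-check sketch:

import torch

torch.manual_seed(0)
A = torch.randn(7, 4, dtype=torch.float64, requires_grad=True)  # tall, full-rank dense A
B = torch.randn(7, 1, dtype=torch.float64, requires_grad=True)

x = torch.linalg.lstsq(A, B).solution
x.sum().backward()  # reference gradients from PyTorch autograd

with torch.no_grad():
    grad = torch.ones_like(x)  # d(loss)/dx for loss = x.sum()
    Ap = torch.linalg.pinv(A)  # A^+
    gradB = Ap.T @ grad  # gradB = (A^T)^+ grad
    gradA = -gradB @ x.T + (B - A @ x) @ (Ap @ gradB).T

assert torch.allclose(gradB, B.grad)
assert torch.allclose(gradA, A.grad)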
