fully working triangular solve
theo-barfoot committed Jun 15, 2023
1 parent 84e00af commit b0636f4
Showing 1 changed file with 75 additions and 10 deletions.
torchsparsegradutils/sparse_solve.py
@@ -3,6 +3,68 @@


def sparse_triangular_solve(A, B, upper=True, unitriangular=False):
    """
    Solves a system of equations given by AX = B, where A is a sparse triangular matrix,
    and B is a dense right-hand side matrix.

    This function accepts both batched and unbatched inputs, and works with either upper
    or lower triangular matrices. A can be in either COO (Coordinate) or CSR (Compressed
    Sparse Row) format; a COO input is converted to CSR before solving, because the
    triangular solve operation does not support COO inputs.

    This function supports backpropagation, and preserves the sparsity of the gradient
    during the backward pass.

    Args:
        A (torch.Tensor): The left-hand side sparse triangular matrix. Must be a 2-dimensional
                          (matrix) or 3-dimensional (batch of matrices) tensor, and must be in
                          either COO or CSR format.
        B (torch.Tensor): The right-hand side dense matrix. Must be a 2-dimensional (matrix) or
                          3-dimensional (batch of matrices) tensor.
        upper (bool, optional): If True, A is assumed to be an upper triangular matrix. If False,
                          A is assumed to be a lower triangular matrix. Default is True.
        unitriangular (bool, optional): If True, the diagonal elements of A are assumed to be 1
                          and are not used in the solve operation. Default is False.

    Returns:
        torch.Tensor: The solution of the system of equations.

    Raises:
        ValueError: If A and B are not both torch.Tensor instances, or if they don't have the
                    same number of dimensions, or if they are not at least 2-dimensional, or if
                    A is not in COO or CSR format, or if A and B are batched but don't have the
                    same batch size.

    Note:
        The gradient with respect to the sparse matrix A is computed only for its
        non-zero values, to save memory.

        For backpropagation, a workaround is implemented for a known issue with
        torch.triangular_solve on the CPU for lower triangular matrices. This issue and the
        workaround are relevant only for PyTorch versions below 2.0 (see PyTorch
        issue #88890).

    References:
        https://github.com/pytorch/pytorch/issues/87358
        https://github.com/pytorch/pytorch/issues/88890
    """

    if not isinstance(A, torch.Tensor) or not isinstance(B, torch.Tensor):
        raise ValueError("Both A and B should be instances of torch.Tensor")

    if A.dim() < 2 or B.dim() < 2:
        raise ValueError("Both A and B should be at least 2-dimensional tensors")

    if A.dim() != B.dim():
        raise ValueError("Both A and B should have the same number of dimensions")

    if A.layout not in {torch.sparse_coo, torch.sparse_csr}:
        raise ValueError("A should be in either COO or CSR format")

    if A.dim() == 3 and A.size(0) != B.size(0):
        raise ValueError("If A and B have a leading batch dimension, they should have the same batch size")

    return SparseTriangularSolve.apply(A, B, upper, unitriangular)
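
A minimal usage sketch of the function above; it assumes a PyTorch build where torch.triangular_solve accepts sparse CSR inputs (the point of this module), and uses a small hypothetical system:

import torch
from torchsparsegradutils.sparse_solve import sparse_triangular_solve

# Upper-triangular A in CSR format:
# [[2, 1, 0],
#  [0, 3, 1],
#  [0, 0, 4]]
crow = torch.tensor([0, 2, 4, 5])
col = torch.tensor([0, 1, 1, 2, 2])
val = torch.tensor([2.0, 1.0, 3.0, 1.0, 4.0])
A = torch.sparse_csr_tensor(crow, col, val, (3, 3)).requires_grad_()
B = torch.randn(3, 2, requires_grad=True)

X = sparse_triangular_solve(A, B, upper=True)  # solves A X = B
assert torch.allclose(A.detach().to_dense() @ X.detach(), B.detach(), atol=1e-5)

X.sum().backward()  # A's gradient is sparse: only the 5 stored values participate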


@@ -30,9 +92,9 @@ def forward(ctx, A, B, upper, unitriangular):
        ctx.csr = True
        ctx.upper = upper
        ctx.ut = unitriangular

        grad_flag = A.requires_grad or B.requires_grad

        if ctx.batch_size is not None:
            A = sparse_block_diag(*A)
            B = torch.cat([*B])
@@ -43,23 +105,26 @@ def forward(ctx, A, B, upper, unitriangular):

        x = torch.triangular_solve(B.detach(), A.detach(), upper=upper, unitriangular=unitriangular).solution

        x.requires_grad = grad_flag
        ctx.save_for_backward(A, x.detach())

        if ctx.batch_size is not None:
            x = x.view(ctx.batch_size, ctx.A_shape[-2], ctx.B_shape[-1])

        return x

    @staticmethod
    def backward(ctx, grad):
        if ctx.batch_size is not None:
            grad = torch.cat([*grad])

        A, x = ctx.saved_tensors

        # Backprop rule: gradB = A^{-T} grad
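        # Derivation (sketch): differentiating A X = B gives A dX = dB - dA X, so
        #   dL/dB = A^{-T} (dL/dX)   and   dL/dA = -(dL/dB) X^T,
        # with dL/dA evaluated only at A's stored non-zero positions.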
        # Check if a workaround for https://github.com/pytorch/pytorch/issues/88890 is needed
        workaround88890 = (
            A.device == torch.device("cpu") and (not ctx.upper) and ctx.ut and (int(torch.__version__[0]) < 2)
        )
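        # Note: int(torch.__version__[0]) reads the leading major-version digit, so the
        # CPU workaround is only taken on PyTorch 1.x builds (the bug is fixed in 2.0).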
        if not workaround88890:
            gradB = torch.triangular_solve(grad, A, upper=ctx.upper, transpose=True, unitriangular=ctx.ut).solution
        else:
@@ -106,16 +171,16 @@ def backward(ctx, grad):
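            # Pack the masked gradient (-gradB x^T sampled at A's stored positions)
            # back into a sparse tensor with the same layout as the original input.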
            gradA = torch.sparse_coo_tensor(torch.stack([A_row_idx, A_col_idx]), gradA, A.shape)
        else:
            gradA = torch.sparse_csr_tensor(A_crow_idx, A_col_idx, gradA, A.shape)

        if ctx.batch_size is not None:
            shapes = ctx.A_shape[0] * (ctx.A_shape[-2:],)
            gradA = sparse_block_diag_split(gradA, *shapes)
            if not ctx.csr:
                gradA = torch.stack([*gradA])
            else:
                gradA = stack_csr([*gradA])

        gradB = gradB.view(ctx.B_shape)

        return gradA, gradB, None, None

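For batched inputs, forward flattens the batch into a single block-diagonal system (via sparse_block_diag) and solves it in one call. A dense sketch of that equivalence, using torch.block_diag and torch.linalg.solve_triangular (PyTorch >= 1.11) rather than the repository's sparse helpers:

import torch

# Two independent upper-triangular systems A1 X1 = B1 and A2 X2 = B2
A1 = torch.triu(torch.rand(3, 3)) + 3 * torch.eye(3)
A2 = torch.triu(torch.rand(3, 3)) + 3 * torch.eye(3)
B1, B2 = torch.rand(3, 2), torch.rand(3, 2)

# Solving the block-diagonal system on the concatenated right-hand side
# is equivalent to solving each batch element separately.
X = torch.linalg.solve_triangular(torch.block_diag(A1, A2), torch.cat([B1, B2]), upper=True)
X1, X2 = X.view(2, 3, 2)

assert torch.allclose(X1, torch.linalg.solve_triangular(A1, B1, upper=True), atol=1e-6)
assert torch.allclose(X2, torch.linalg.solve_triangular(A2, B2, upper=True), atol=1e-6)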
