Make jax_sparse_solve more generic - addresses #19
tvercaut committed Nov 30, 2022
1 parent 70b5148 commit d5a6d4c
Showing 2 changed files with 18 additions and 15 deletions.
torchsparsegradutils/jax/jax_sparse_solve.py (28 changes: 15 additions & 13 deletions)
@@ -10,14 +10,22 @@
 from .jax_bindings import t2j_csr as _t2j_csr


-def sparse_solve_j4t(A, B):
-    return SparseSolveJ4T.apply(A, B)
+def sparse_solve_j4t(A, B, solve=None, transpose_solve=None):
+    if solve == None or transpose_solve == None:
+        # Use bicgstab by default
+        if solve == None:
+            solve = jax.scipy.sparse.linalg.bicgstab
+        if transpose_solve == None:
+            transpose_solve = lambda A, B: jax.scipy.sparse.linalg.bicgstab(A.transpose(), B)
+
+    return SparseSolveJ4T.apply(A, B, solve, transpose_solve)


 class SparseSolveJ4T(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, A, B):
+    def forward(ctx, A, B, solve, transpose_solve):
         grad_flag = A.requires_grad or B.requires_grad
+        ctx.transpose_solve = transpose_solve

         if A.layout == torch.sparse_coo:
             A_j = _t2j_coo(A.detach())
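For illustration, a minimal usage sketch of the new pluggable-solver API (my own example, not part of the commit; the import path is assumed, and the toy system is symmetric so the transpose convention is immaterial). Any callable with the jax.scipy.sparse.linalg convention solve(A, B) -> (x, info) can be plugged in:

import torch
import jax.scipy.sparse.linalg as jsla
from torchsparsegradutils.jax import sparse_solve_j4t  # assumed import path

# Toy symmetric positive definite system in sparse COO layout.
A = (torch.eye(4) * 2.0).to_sparse_coo()
B = torch.ones(4)

x_default = sparse_solve_j4t(A, B)  # falls back to the bicgstab defaults

# Swap in GMRES; it also returns an (x, info) pair, matching what
# forward() unpacks via `x_j, exit_code = solve(A_j, B_j)`.
x_gmres = sparse_solve_j4t(
    A, B,
    solve=jsla.gmres,
    transpose_solve=jsla.gmres,  # fine here: A is symmetric
)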
@@ -27,30 +35,24 @@ def forward(ctx, A, B):
             raise TypeError(f"Unsupported layout type: {A.layout}")
         B_j = _t2j(B.detach())

-        x_j, exit_code = jax.scipy.sparse.linalg.cg(A_j, B_j)
+        x_j, exit_code = solve(A_j, B_j)

         x = _j2t(x_j)

         ctx.save_for_backward(A, x)
+        ctx.A_j = A_j
         x.requires_grad = grad_flag
         return x

     @staticmethod
     def backward(ctx, grad):
         A, x = ctx.saved_tensors

-        if A.layout == torch.sparse_coo:
-            A_j = _t2j_coo(A.detach())
-        elif A.layout == torch.sparse_csr:
-            A_j = _t2j_csr(A.detach())
-        else:
-            raise TypeError(f"Unsupported layout type: {A.layout}")
-
         x_j = _t2j(x.detach())
         grad_j = _t2j(grad.detach())

         # Backprop rule: gradB = A^{-T} grad
-        gradB_j, exit_code = jax.scipy.sparse.linalg.cg(A_j.transpose(), grad_j)
+        gradB_j, exit_code = ctx.transpose_solve(ctx.A_j.transpose(), grad_j)
         gradB = _j2t(gradB_j)

         # The gradient with respect to the matrix A seen as a dense matrix would
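The backprop rule in the comment above follows from x = A^{-1} B: for a scalar loss L with g = dL/dx, one gets dL/dB = A^{-T} g and, for A seen as a dense matrix, dL/dA = -(A^{-T} g) x^T, which the code then restricts to the sparsity pattern of A. A dense sanity check of both identities (my illustration, not part of the commit):

import torch

torch.manual_seed(0)
A = torch.randn(4, 4, dtype=torch.float64) + 4 * torch.eye(4, dtype=torch.float64)
A.requires_grad_(True)
B = torch.randn(4, 1, dtype=torch.float64, requires_grad=True)

x = torch.linalg.solve(A, B)
x.sum().backward()

g = torch.ones_like(x)                       # dL/dx for L = sum(x)
gradB = torch.linalg.solve(A.detach().T, g)  # gradB = A^{-T} g
gradA = -gradB @ x.detach().T                # dense dL/dA = -gradB x^T
assert torch.allclose(B.grad, gradB)
assert torch.allclose(A.grad, gradA)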
@@ -86,4 +88,4 @@ def backward(ctx, grad):
         else:
             gradA = torch.sparse_csr_tensor(A_crow_idx, A_col_idx, gradA, A.shape)

-        return gradA, gradB
+        return gradA, gradB, None, None
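The two extra None values line up with the new solve and transpose_solve arguments: torch.autograd.Function.backward must return one entry per forward() input, and non-tensor inputs receive None. A generic pattern sketch (my own, not from this file):

import torch

class MulWithCallback(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, y, log_fn):  # log_fn is a plain callable, not a tensor
        ctx.save_for_backward(x, y)
        log_fn(x)  # side effect only; no gradient flows through it
        return x * y

    @staticmethod
    def backward(ctx, grad):
        x, y = ctx.saved_tensors
        # One return value per forward() argument after ctx;
        # non-tensor arguments such as log_fn map to None.
        return grad * y, grad * x, None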
torchsparsegradutils/sparse_solve.py (5 changes: 3 additions & 2 deletions)

@@ -102,7 +102,9 @@ def sparse_generic_solve(A, B, solve=None, transpose_solve=None):
     if solve == None:
         solve = minres
     if transpose_solve == None:
-        transpose_solve = lambda A, B: minres(torch.t(A), B)
+        # MINRES assumes A to be symmetric -> no need to transpose A
+        transpose_solve = minres

     return SparseGenericSolve.apply(A, B, solve, transpose_solve)
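Since the new default hands transpose_solve the untransposed A (MINRES needs no transpose for symmetric A), a custom transpose_solve is responsible for doing the transposition itself when A is not symmetric. A usage sketch (my example; the import path is assumed and the dense fallback is for illustration only):

import torch
from torchsparsegradutils import sparse_generic_solve  # assumed import path

def dense_solve(A, B):
    # Illustrative fallback; a real use would plug in an iterative sparse solver.
    return torch.linalg.solve(A.to_dense(), B)

def dense_transpose_solve(A, B):
    return torch.linalg.solve(A.to_dense().t(), B)

A = torch.tensor([[2.0, 1.0], [0.0, 3.0]]).to_sparse_csr()  # not symmetric
B = torch.ones(2, 1)

x = sparse_generic_solve(A, B, solve=dense_solve, transpose_solve=dense_transpose_solve)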


@@ -124,7 +126,6 @@ class SparseGenericSolve(torch.autograd.Function):
     @staticmethod
     def forward(ctx, A, B, solve, transpose_solve):
         grad_flag = A.requires_grad or B.requires_grad
-        ctx.solve = solve
         ctx.transpose_solve = transpose_solve

         x = solve(A.detach(), B.detach())
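Both operands are detached before the pluggable solver runs, so whatever the solver does internally stays out of the autograd graph and gradients flow only through the hand-written backward (the now-removed ctx.solve was never used there). A minimal illustration of the detach effect (my sketch, not from the file):

import torch

A = torch.eye(2, requires_grad=True)
B = torch.ones(2, 1)

x = torch.linalg.solve(A.detach(), B)  # solver-internal ops are not traced
print(x.requires_grad)  # False; the custom Function re-attaches gradients itself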