diff --git a/torchsparsegradutils/jax/jax_sparse_solve.py b/torchsparsegradutils/jax/jax_sparse_solve.py
index 547fefe..4a18f05 100644
--- a/torchsparsegradutils/jax/jax_sparse_solve.py
+++ b/torchsparsegradutils/jax/jax_sparse_solve.py
@@ -10,14 +10,22 @@
 from .jax_bindings import t2j_csr as _t2j_csr
 
 
-def sparse_solve_j4t(A, B):
-    return SparseSolveJ4T.apply(A, B)
+def sparse_solve_j4t(A, B, solve=None, transpose_solve=None):
+    if solve == None or transpose_solve == None:
+        # Use bicgstab by default
+        if solve == None:
+            solve = jax.scipy.sparse.linalg.bicgstab
+        if transpose_solve == None:
+            transpose_solve = lambda A, B: jax.scipy.sparse.linalg.bicgstab(A.transpose(), B)
+
+    return SparseSolveJ4T.apply(A, B, solve, transpose_solve)
 
 
 class SparseSolveJ4T(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, A, B):
+    def forward(ctx, A, B, solve, transpose_solve):
         grad_flag = A.requires_grad or B.requires_grad
+        ctx.transpose_solve = transpose_solve
 
         if A.layout == torch.sparse_coo:
             A_j = _t2j_coo(A.detach())
@@ -27,11 +35,12 @@ def forward(ctx, A, B):
             raise TypeError(f"Unsupported layout type: {A.layout}")
 
         B_j = _t2j(B.detach())
-        x_j, exit_code = jax.scipy.sparse.linalg.cg(A_j, B_j)
+        x_j, exit_code = solve(A_j, B_j)
         x = _j2t(x_j)
 
         ctx.save_for_backward(A, x)
+        ctx.A_j = A_j
         x.requires_grad = grad_flag
         return x
 
@@ -39,18 +48,11 @@ def forward(ctx, A, B):
     def backward(ctx, grad):
         A, x = ctx.saved_tensors
 
-        if A.layout == torch.sparse_coo:
-            A_j = _t2j_coo(A.detach())
-        elif A.layout == torch.sparse_csr:
-            A_j = _t2j_csr(A.detach())
-        else:
-            raise TypeError(f"Unsupported layout type: {A.layout}")
-
         x_j = _t2j(x.detach())
         grad_j = _t2j(grad.detach())
 
         # Backprop rule: gradB = A^{-T} grad
-        gradB_j, exit_code = ctx.transpose_solve(ctx.A_j, grad_j)
+        gradB_j, exit_code = ctx.transpose_solve(ctx.A_j, grad_j)
         gradB = _j2t(gradB_j)
 
         # The gradient with respect to the matrix A seen as a dense matrix would
@@ -86,4 +88,4 @@ def backward(ctx, grad):
         else:
             gradA = torch.sparse_csr_tensor(A_crow_idx, A_col_idx, gradA, A.shape)
 
-        return gradA, gradB
+        return gradA, gradB, None, None
diff --git a/torchsparsegradutils/sparse_solve.py b/torchsparsegradutils/sparse_solve.py
index 3434220..91f25f8 100644
--- a/torchsparsegradutils/sparse_solve.py
+++ b/torchsparsegradutils/sparse_solve.py
@@ -102,7 +102,9 @@ def sparse_generic_solve(A, B, solve=None, transpose_solve=None):
     if solve == None:
         solve = minres
     if transpose_solve == None:
-        transpose_solve = lambda A, B: minres(torch.t(A), B)
+        # MINRES assumes A is symmetric -> no need to transpose A
+        transpose_solve = minres
+
     return SparseGenericSolve.apply(A, B, solve, transpose_solve)
 
 
@@ -124,7 +126,6 @@ class SparseGenericSolve(torch.autograd.Function):
     @staticmethod
     def forward(ctx, A, B, solve, transpose_solve):
         grad_flag = A.requires_grad or B.requires_grad
-        ctx.solve = solve
         ctx.transpose_solve = transpose_solve
 
         x = solve(A.detach(), B.detach())
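
Not part of the patch, but a minimal usage sketch of the new `solve` / `transpose_solve` hooks on the JAX path. It assumes `sparse_solve_j4t` is imported directly from `torchsparsegradutils/jax/jax_sparse_solve.py`, that the example matrix is well conditioned enough for the iterative solvers to converge, and that custom callables follow the JAX solver convention: they return a `(solution, info)` pair, and `transpose_solve` receives `A` itself and is responsible for solving the transposed system. Injecting the solvers as callables keeps the autograd `Function` tied only to the backprop rule `gradB = A^{-T} grad`, not to a particular Krylov method.

```python
import torch
import jax
from torchsparsegradutils.jax.jax_sparse_solve import sparse_solve_j4t

# Small symmetric positive definite system in COO layout (arbitrary example data).
idx = torch.tensor([[0, 0, 1, 1, 2], [0, 1, 0, 1, 2]])
val = torch.tensor([4.0, 1.0, 1.0, 4.0, 4.0])
A = torch.sparse_coo_tensor(idx, val, (3, 3)).coalesce().requires_grad_()
B = torch.randn(3, 2, requires_grad=True)

# Default hooks: BiCGSTAB for the forward solve and for the transposed backward solve.
x = sparse_solve_j4t(A, B)

# Custom hooks: GMRES instead of BiCGSTAB. Both callables return (solution, info),
# and transpose_solve applies the transpose itself before solving.
x = sparse_solve_j4t(
    A,
    B,
    solve=jax.scipy.sparse.linalg.gmres,
    transpose_solve=lambda A_j, B_j: jax.scipy.sparse.linalg.gmres(A_j.transpose(), B_j),
)

# Gradients flow back through the hooks: gradB via transpose_solve,
# gradA restricted to the sparsity pattern of A.
x.sum().backward()
```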
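
On the torch side, the MINRES default for `transpose_solve` is only valid when `A` is symmetric; for a non-symmetric system the caller should pass an explicit `transpose_solve` that performs the transpose itself. The sketch below (not part of the patch) only illustrates the expected callback signatures: custom callables return the solution tensor directly, `transpose_solve` receives `A` rather than `A^T`, and the dense direct solves are hypothetical placeholders standing in for a real iterative solver.

```python
import torch
from torchsparsegradutils.sparse_solve import sparse_generic_solve

def dense_solve(A, B):
    # Placeholder solver: densify and solve A x = B directly.
    return torch.linalg.solve(A.to_dense(), B)

def dense_transpose_solve(A, B):
    # transpose_solve is handed A and must solve A^T x = B itself.
    return torch.linalg.solve(A.to_dense().t(), B)

# Non-symmetric, invertible example system in COO layout (arbitrary data).
idx = torch.tensor([[0, 0, 1, 2], [0, 1, 1, 2]])
val = torch.tensor([4.0, 1.0, 3.0, 2.0])
A = torch.sparse_coo_tensor(idx, val, (3, 3)).coalesce().requires_grad_()
B = torch.randn(3, 2, requires_grad=True)

x = sparse_generic_solve(A, B, solve=dense_solve, transpose_solve=dense_transpose_solve)
x.sum().backward()
```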