Commit
rough bindings for cupy solver - see #5
Showing 5 changed files with 192 additions and 3 deletions.
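
For orientation, a minimal usage sketch of the `sparse_solve_c4t` binding introduced by this commit, mirroring the added unit test. This is illustrative only and not part of the diff; without cupy installed, the solve should fall back to the numpy/scipy path on CPU.

# Illustrative usage sketch (not part of the diff), mirroring the added unit test.
import torch
import torchsparsegradutils.cupy as tsgucupy

A = torch.randn(4, 4, dtype=torch.float32)
A = A + A.t()                        # small symmetric test system, as in the test
A_csr = A.to_sparse_csr()            # sparse CSR operand expected by the binding
B = torch.randn(4, 2, dtype=torch.float32)

x = tsgucupy.sparse_solve_c4t(A_csr, B)   # solves A x = B
print(torch.allclose(x, torch.linalg.solve(A, B), rtol=1e-3))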
@@ -0,0 +1,86 @@
import torch
import unittest
import torchsparsegradutils as tsgu
import torchsparsegradutils.cupy as tsgucupy

import warnings

if tsgucupy.have_cupy:
    import cupy as cp
    import cupyx.scipy.sparse as csp
else:
    warnings.warn(
        "Importing optional cupy-related module failed to find cupy -> cupy-related tests running as numpy only."
    )

import numpy as np
import scipy.sparse as nsp


class SparseSolveTestC4T(unittest.TestCase):
    def setUp(self) -> None:
        # The device can be specialised by a daughter class
        if not hasattr(self, "device"):
            self.device = torch.device("cpu")
            self.xp = np
            self.xsp = nsp

        self.RTOL = 1e-3

        self.A_shape = (4, 4)
        self.A = torch.randn(self.A_shape, dtype=torch.float64, device=self.device)
        self.A = self.A + self.A.t()
        self.A_csr = self.A.to_sparse_csr()
        self.B_shape = (4, 2)
        self.B = torch.randn(self.B_shape, dtype=torch.float64, device=self.device)

        self.x_ref = torch.linalg.solve(self.A, self.B)

    def test_solver_c4t(self):
        x = tsgucupy.sparse_solve_c4t(self.A_csr.to(torch.float32), self.B.to(torch.float32))
        self.assertTrue(torch.isclose(x, self.x_ref.to(torch.float32), rtol=self.RTOL).all())

    def test_solver_gradient_c4t(self):
        # Sparse solver:
        As1 = self.A_csr.detach().to(torch.float32).clone()
        As1.requires_grad = True
        Bd1 = self.B.detach().to(torch.float32).clone()
        Bd1.requires_grad = True
        As1.retain_grad()
        Bd1.retain_grad()
        x = tsgucupy.sparse_solve_c4t(As1, Bd1)
        loss = x.sum()
        loss.backward()

        # torch dense solver:
        Ad2 = self.A.detach().to(torch.float32).clone()
        Ad2.requires_grad = True
        Bd2 = self.B.detach().to(torch.float32).clone()
        Bd2.requires_grad = True
        Ad2.retain_grad()
        Bd2.retain_grad()
        x2 = torch.linalg.solve(Ad2, Bd2)
        loss_torch = x2.sum()
        loss_torch.backward()

        self.assertTrue(torch.isclose(x, x2, rtol=self.RTOL).all())

        nz_mask = As1.grad.to_dense() != 0.0
        self.assertTrue(torch.isclose(As1.grad.to_dense()[nz_mask], Ad2.grad[nz_mask], rtol=self.RTOL).all())
        self.assertTrue(torch.isclose(Bd1.grad, Bd2.grad, rtol=self.RTOL).all())


class SparseSolveTestC4TCUDA(SparseSolveTestC4T):
    """Override superclass setUp to run on GPU"""

    def setUp(self) -> None:
        if not torch.cuda.is_available():
            self.skipTest(f"Skipping {self.__class__.__name__} since CUDA is not available")
        self.device = torch.device("cuda")
        self.xp = cp
        self.xsp = csp
        super().setUp()


if __name__ == "__main__":
    unittest.main()
@@ -0,0 +1,99 @@
import torchsparsegradutils.cupy as tsgucupy
import torch


def sparse_solve_c4t(A, B):
    return SparseSolveC4T.apply(A, B)


class SparseSolveC4T(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, B):
        xp, xsp = tsgucupy._get_array_modules(A.data)
        grad_flag = A.requires_grad or B.requires_grad

        # Transfer data to cupy/scipy
        if A.layout == torch.sparse_coo:
            A_c = tsgucupy.t2c_coo(A.detach())
        elif A.layout == torch.sparse_csr:
            A_c = tsgucupy.t2c_csr(A.detach())
        else:
            raise TypeError(f"Unsupported layout type: {A.layout}")
        B_c = xp.asarray(B.detach())

        # Solve the sparse system
        ctx.factorisedsolver = None
        if (B.ndim == 1) or (B.shape[1] == 1):
            # xp.sparse.linalg.spsolve only works if B is a vector but is fully on GPU with cupy
            x_c = xsp.linalg.spsolve(A_c, B_c)
        else:
            # Make use of a factorisation (only the solver is then on the GPU with cupy)
            # We store it in ctx to reuse it in the backward pass
            ctx.factorisedsolver = xsp.linalg.factorized(A_c)
            x_c = ctx.factorisedsolver(B_c)

        x = torch.as_tensor(x_c, device=A.device)

        ctx.save_for_backward(A, x)
        x.requires_grad = grad_flag
        return x

    @staticmethod
    def backward(ctx, grad):
        A, x = ctx.saved_tensors
        xp, xsp = tsgucupy._get_array_modules(A.data)

        if A.layout == torch.sparse_coo:
            A_c = tsgucupy.t2c_coo(A.detach())
        elif A.layout == torch.sparse_csr:
            A_c = tsgucupy.t2c_csr(A.detach())
        else:
            raise TypeError(f"Unsupported layout type: {A.layout}")

        x_c = xp.asarray(x.detach())
        grad_c = xp.asarray(grad.detach())

        # Backprop rule: gradB = A^{-T} grad
        if ctx.factorisedsolver is None:
            gradB_c = xsp.linalg.spsolve(xp.transpose(A_c), grad_c)
        else:
            # Re-use factorised solver from forward pass
            grad_c = xp.asarray(grad)
            gradB_c = ctx.factorisedsolver(grad_c, trans="T")

        gradB = torch.as_tensor(gradB_c, device=A.device)

        # The gradient with respect to the matrix A seen as a dense matrix would
        # lead to a backprop rule as follows
        # gradA = -(A^{-T} grad)(A^{-1} B) = - gradB @ x.T
        # but we are only interested in the gradient with respect to
        # the (non-zero) values of A. To save memory, instead of computing the full
        # dense matrix gradB @ x.T and then subsampling at the nnz locations in A,
        # we can directly only compute the required values:
        # gradA[i,j] = - dotprod(gradB[i,:], x[j,:])

        # We start by getting the i and j indices:
        if A.layout == torch.sparse_coo:
            A_row_idx = A.indices()[0, :]
            A_col_idx = A.indices()[1, :]
        else:
            A_col_idx = A.col_indices()
            A_crow_idx = A.crow_indices()
            # Uncompress row indices:
            A_row_idx = torch.repeat_interleave(
                torch.arange(A.size()[0], device=A.device), A_crow_idx[1:] - A_crow_idx[:-1]
            )

        mgradbselect = -gradB.index_select(0, A_row_idx)  # -gradB[i, :]
        xselect = x.index_select(0, A_col_idx)  # x[j, :]

        # Dot product:
        mgbx = mgradbselect * xselect
        gradA = torch.sum(mgbx, dim=1)

        if A.layout == torch.sparse_coo:
            gradA = torch.sparse_coo_tensor(torch.stack([A_row_idx, A_col_idx]), gradA, A.shape)
        else:
            gradA = torch.sparse_csr_tensor(A_crow_idx, A_col_idx, gradA, A.shape)

        return gradA, gradB
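
To make the memory-saving argument in the backward comments concrete, here is a small self-contained check (illustrative only, not part of the commit): the per-nonzero rule gradA[i, j] = -dot(gradB[i, :], x[j, :]) used above should agree with the dense rule gradA = -gradB @ x.T sampled at the sparsity pattern of A.

# Sanity-check sketch (not from the commit): compare the nnz-only gradient of A
# against the dense backprop rule, sampled at A's sparsity pattern.
import torch

torch.manual_seed(0)
A = torch.randn(4, 4, dtype=torch.float64)
A = A + A.t()
B = torch.randn(4, 2, dtype=torch.float64)
grad = torch.randn(4, 2, dtype=torch.float64)   # stand-in for dL/dx

x = torch.linalg.solve(A, B)                    # forward solution x = A^{-1} B
gradB = torch.linalg.solve(A.t(), grad)         # dL/dB = A^{-T} dL/dx
gradA_dense = -gradB @ x.t()                    # dense dL/dA

# nnz-only version, as in backward(): gradA[i, j] = -dot(gradB[i, :], x[j, :])
A_csr = A.to_sparse_csr()
col = A_csr.col_indices()
crow = A_csr.crow_indices()
row = torch.repeat_interleave(torch.arange(4), crow[1:] - crow[:-1])
gradA_vals = -(gradB.index_select(0, row) * x.index_select(0, col)).sum(dim=1)

print(torch.allclose(gradA_vals, gradA_dense[row, col]))   # expected: True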