rough bindings for cupy solver - see #5
tvercaut committed Nov 25, 2022
1 parent 015027c commit a50406e
Showing 5 changed files with 192 additions and 3 deletions.
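
For orientation, a minimal usage sketch of the solver these bindings expose (not part of the diff itself). It follows the sparse_solve_c4t call pattern exercised by the new tests and assumes cupy (for CUDA tensors) or scipy (for CPU tensors) is installed:

import torch
import torchsparsegradutils.cupy as tsgucupy

# Small, well-conditioned symmetric system
A = torch.randn(4, 4, dtype=torch.float32)
A = A + A.t() + 4 * torch.eye(4)
A_csr = A.to_sparse_csr()
A_csr.requires_grad = True
B = torch.randn(4, 2, dtype=torch.float32, requires_grad=True)

# Solve A x = B; per the tests below, CPU tensors go through numpy/scipy
# and CUDA tensors through cupy when it is available.
x = tsgucupy.sparse_solve_c4t(A_csr, B)
x.sum().backward()  # populates A_csr.grad (sparse) and B.grad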
3 changes: 2 additions & 1 deletion tests/test_cupy_bindings.py
@@ -16,12 +16,13 @@
import numpy as np
import scipy.sparse as nsp


def _c2n(x_cupy):
if tsgucupy.have_cupy:
return cp.asnumpy(x_cupy)
else:
return np.asarray(x_cupy)


class C2TIOTest(unittest.TestCase):
"""IO conversion tests between torch and cupy"""
86 changes: 86 additions & 0 deletions tests/test_cupy_sparse_solve.py
@@ -0,0 +1,86 @@
import torch
import unittest
import torchsparsegradutils as tsgu
import torchsparsegradutils.cupy as tsgucupy

import warnings

if tsgucupy.have_cupy:
import cupy as cp
import cupyx.scipy.sparse as csp
else:
warnings.warn(
"Importing optional cupy-related module failed to find cupy -> cupy-related tests running as numpy only."
)

import numpy as np
import scipy.sparse as nsp


class SparseSolveTestC4T(unittest.TestCase):
def setUp(self) -> None:
        # The device can be specialised by a subclass
if not hasattr(self, "device"):
self.device = torch.device("cpu")
self.xp = np
self.xsp = nsp

self.RTOL = 1e-3

self.A_shape = (4, 4)
self.A = torch.randn(self.A_shape, dtype=torch.float64, device=self.device)
self.A = self.A + self.A.t()
self.A_csr = self.A.to_sparse_csr()
self.B_shape = (4, 2)
self.B = torch.randn(self.B_shape, dtype=torch.float64, device=self.device)

self.x_ref = torch.linalg.solve(self.A, self.B)

def test_solver_c4t(self):
x = tsgucupy.sparse_solve_c4t(self.A_csr.to(torch.float32), self.B.to(torch.float32))
self.assertTrue(torch.isclose(x, self.x_ref.to(torch.float32), rtol=self.RTOL).all())

def test_solver_gradient_c4t(self):
# Sparse solver:
As1 = self.A_csr.detach().to(torch.float32).clone()
As1.requires_grad = True
Bd1 = self.B.detach().to(torch.float32).clone()
Bd1.requires_grad = True
As1.retain_grad()
Bd1.retain_grad()
x = tsgucupy.sparse_solve_c4t(As1, Bd1)
loss = x.sum()
loss.backward()

# torch dense solver:
Ad2 = self.A.detach().to(torch.float32).clone()
Ad2.requires_grad = True
Bd2 = self.B.detach().to(torch.float32).clone()
Bd2.requires_grad = True
Ad2.retain_grad()
Bd2.retain_grad()
x2 = torch.linalg.solve(Ad2, Bd2)
loss_torch = x2.sum()
loss_torch.backward()

self.assertTrue(torch.isclose(x, x2, rtol=self.RTOL).all())

nz_mask = As1.grad.to_dense() != 0.0
self.assertTrue(torch.isclose(As1.grad.to_dense()[nz_mask], Ad2.grad[nz_mask], rtol=self.RTOL).all())
self.assertTrue(torch.isclose(Bd1.grad, Bd2.grad, rtol=self.RTOL).all())


class SparseSolveTestC4TCUDA(SparseSolveTestC4T):
"""Override superclass setUp to run on GPU"""

def setUp(self) -> None:
if not torch.cuda.is_available():
self.skipTest(f"Skipping {self.__class__.__name__} since CUDA is not available")
self.device = torch.device("cuda")
self.xp = cp
self.xsp = csp
super().setUp()


if __name__ == "__main__":
unittest.main()
5 changes: 3 additions & 2 deletions torchsparsegradutils/cupy/__init__.py
@@ -11,6 +11,7 @@
else:
have_cupy = True

-from .cupy_bindings import c2t_coo, t2c_coo, c2t_csr, t2c_csr
+from .cupy_bindings import c2t_coo, t2c_coo, c2t_csr, t2c_csr, _get_array_modules
+from .cupy_sparse_solve import sparse_solve_c4t

-__all__ = ["c2t_coo", "t2c_coo", "c2t_csr", "t2c_csr"]
+__all__ = ["c2t_coo", "t2c_coo", "c2t_csr", "t2c_csr", "_get_array_modules", "sparse_solve_c4t"]
2 changes: 2 additions & 0 deletions torchsparsegradutils/cupy/cupy_bindings.py
@@ -3,9 +3,11 @@
if tsgucupy.have_cupy:
import cupy as cp
import cupyx.scipy.sparse as csp
import cupyx.scipy.sparse.linalg

import numpy as np
import scipy.sparse as nsp
import scipy.sparse.linalg

import torch

99 changes: 99 additions & 0 deletions torchsparsegradutils/cupy/cupy_sparse_solve.py
@@ -0,0 +1,99 @@
import torchsparsegradutils.cupy as tsgucupy
import torch


def sparse_solve_c4t(A, B):
return SparseSolveC4T.apply(A, B)


class SparseSolveC4T(torch.autograd.Function):
@staticmethod
def forward(ctx, A, B):
xp, xsp = tsgucupy._get_array_modules(A.data)
grad_flag = A.requires_grad or B.requires_grad

# Transfer data to cupy/scipy
if A.layout == torch.sparse_coo:
A_c = tsgucupy.t2c_coo(A.detach())
elif A.layout == torch.sparse_csr:
A_c = tsgucupy.t2c_csr(A.detach())
else:
raise TypeError(f"Unsupported layout type: {A.layout}")
B_c = xp.asarray(B.detach())

# Solve the sparse system
ctx.factorisedsolver = None
if (B.ndim == 1) or (B.shape[1] == 1):
            # xsp.linalg.spsolve only handles a vector right-hand side, but with cupy it runs fully on the GPU
x_c = xsp.linalg.spsolve(A_c, B_c)
else:
# Make use of a factorisation (only the solver is then on the GPU with cupy)
# We store it in ctx to reuse it in the backward pass
ctx.factorisedsolver = xsp.linalg.factorized(A_c)
x_c = ctx.factorisedsolver(B_c)

x = torch.as_tensor(x_c, device=A.device)

ctx.save_for_backward(A, x)
x.requires_grad = grad_flag
return x

@staticmethod
def backward(ctx, grad):
A, x = ctx.saved_tensors
xp, xsp = tsgucupy._get_array_modules(A.data)

if A.layout == torch.sparse_coo:
A_c = tsgucupy.t2c_coo(A.detach())
elif A.layout == torch.sparse_csr:
A_c = tsgucupy.t2c_csr(A.detach())
else:
raise TypeError(f"Unsupported layout type: {A.layout}")

x_c = xp.asarray(x.detach())
grad_c = xp.asarray(grad.detach())

# Backprop rule: gradB = A^{-T} grad
        if ctx.factorisedsolver is None:
            # Transposing via .T works for both scipy and cupyx sparse matrices
            gradB_c = xsp.linalg.spsolve(A_c.T, grad_c)
        else:
            # Re-use the factorised solver from the forward pass on the transposed system
            gradB_c = ctx.factorisedsolver(grad_c, trans="T")

gradB = torch.as_tensor(gradB_c, device=A.device)

# The gradient with respect to the matrix A seen as a dense matrix would
# lead to a backprop rule as follows
# gradA = -(A^{-T} grad)(A^{-1} B) = - gradB @ x.T
# but we are only interested in the gradient with respect to
# the (non-zero) values of A. To save memory, instead of computing the full
        # dense matrix gradB @ x.T and then subsampling at the nnz locations in A,
# we can directly only compute the required values:
# gradA[i,j] = - dotprod(gradB[i,:], x[j,:])

# We start by getting the i and j indices:
if A.layout == torch.sparse_coo:
A_row_idx = A.indices()[0, :]
A_col_idx = A.indices()[1, :]
else:
A_col_idx = A.col_indices()
A_crow_idx = A.crow_indices()
# Uncompress row indices:
A_row_idx = torch.repeat_interleave(
torch.arange(A.size()[0], device=A.device), A_crow_idx[1:] - A_crow_idx[:-1]
)

mgradbselect = -gradB.index_select(0, A_row_idx) # -gradB[i, :]
xselect = x.index_select(0, A_col_idx) # x[j, :]

# Dot product:
mgbx = mgradbselect * xselect
gradA = torch.sum(mgbx, dim=1)

if A.layout == torch.sparse_coo:
gradA = torch.sparse_coo_tensor(torch.stack([A_row_idx, A_col_idx]), gradA, A.shape)
else:
gradA = torch.sparse_csr_tensor(A_crow_idx, A_col_idx, gradA, A.shape)

return gradA, gradB
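
To illustrate the two solve paths used in forward() above (spsolve for a single right-hand side, a cached factorisation for several), a short scipy-only sketch, not part of the commit:

import numpy as np
import scipy.sparse as nsp
import scipy.sparse.linalg as nspla

n = 50
A = nsp.random(n, n, density=0.1, format="csc") + n * nsp.eye(n, format="csc")
b = np.random.randn(n)     # single right-hand side
B = np.random.randn(n, 3)  # several right-hand sides

x = nspla.spsolve(A, b)      # one-shot solve, as in the vector branch

solve = nspla.factorized(A)  # factorise once (expects CSC), then reuse
X = solve(B)                 # as in the matrix branch; the commit reuses the same object in backward()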

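And a small dense sanity check (torch only, not part of the commit) of the backward rule spelled out in the comments above: gradB = A^{-T} grad and gradA = -gradB x^T, of which the sparse code keeps only the entries at the nonzero locations of A:

import torch

torch.manual_seed(0)
n, k = 5, 3
A = torch.randn(n, n, dtype=torch.float64) + n * torch.eye(n, dtype=torch.float64)
B = torch.randn(n, k, dtype=torch.float64)
A.requires_grad_(True)
B.requires_grad_(True)

x = torch.linalg.solve(A, B)
x.backward(torch.ones_like(x))  # upstream gradient, as if loss = x.sum()

gradB = torch.linalg.solve(A.detach().T, torch.ones_like(x))  # A^{-T} grad
gradA = -gradB @ x.detach().T                                 # dense form of the rule

assert torch.allclose(B.grad, gradB)
assert torch.allclose(A.grad, gradA)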