Skip to content

Commit

Permalink
add support for batched mm
Browse files Browse the repository at this point in the history
  • Loading branch information
theo-barfoot committed Jun 13, 2023
1 parent 67cb9cb commit b34f342
Showing 1 changed file with 64 additions and 3 deletions.
67 changes: 64 additions & 3 deletions torchsparsegradutils/sparse_matmul.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,40 @@
import torch
from torchsparsegradutils.utils import sparse_block_diag, sparse_block_diag_split, stack_csr


def sparse_mm(A, B):
"""
Performs a matrix multiplication between a sparse matrix A and a dense matrix B,
preserving the sparsity of the gradient with respect to A, permitting sparse backpropagation.
The sparse matrix A can be in either COO or CSR format, and is expected
to be 2-dimensional, with an optional leading batch dimension. The dense matrix B
should also be 2-dimensional, with a matching optional leading batch dimension.
The batch size must be the same for both A and B.
Args:
A (torch.Tensor): The sparse matrix in COO or CSR format.
B (torch.Tensor): The dense matrix.
Returns:
torch.Tensor: The result of the matrix multiplication.
"""

if not isinstance(A, torch.Tensor) or not isinstance(B, torch.Tensor):
raise ValueError("Both A and B should be instances of torch.Tensor")

if A.dim() < 2 or B.dim() < 2:
raise ValueError("Both A and B should be at least 2-dimensional tensors")

if A.dim() != B.dim():
raise ValueError("Both A and B should have the same number of dimensions")

if A.layout not in {torch.sparse_coo, torch.sparse_csr}:
raise ValueError("A should be in either COO or CSR format")

if A.dim() == 3 and A.size(0) != B.size(0):
raise ValueError("If A and B have a leading batch dimension, they should have the same batch size")

return SparseMatMul.apply(A, B)


Expand All @@ -18,11 +51,26 @@ class SparseMatMul(torch.autograd.Function):

@staticmethod
def forward(ctx, A, B):
ctx.batch_size = B.size()[0] if B.dim() == 3 else None
ctx.A_shape = A.size() # (b), n, m
ctx.B_shape = B.size() # (b), m, p

grad_flag = A.requires_grad or B.requires_grad

A, B = A.detach(), B.detach()

if ctx.batch_size is not None:
A = sparse_block_diag(*A)
B = torch.cat([*B])

x = torch.sparse.mm(A, B)
x.requires_grad = grad_flag

ctx.save_for_backward(A, B)

if ctx.batch_size is not None:
x = x.view(ctx.batch_size, ctx.A_shape[-2], ctx.B_shape[-1])

x.requires_grad = grad_flag
return x

@staticmethod
Expand All @@ -40,7 +88,7 @@ def backward(ctx, grad):
# We start by getting the i and j indices:

if A.layout == torch.sparse_coo:
A_row_idx, A_col_idx = A.indices()
A_row_idx, A_col_idx = A._indices()
elif A.layout == torch.sparse_csr:
A_col_idx = A.col_indices()
A_crow_idx = A.crow_indices()
Expand All @@ -50,6 +98,9 @@ def backward(ctx, grad):
)
else:
raise ValueError(f"Unsupported layout: {A.layout}")

if ctx.batch_size is not None:
grad = torch.cat([*grad])

grad_select = grad.index_select(0, A_row_idx) # grad[i, :]
B_select = B.index_select(0, A_col_idx) # B[j, :]
Expand All @@ -60,10 +111,20 @@ def backward(ctx, grad):

# Create a sparse matrix of the gradient with respect to the nnz of A
if A.layout == torch.sparse_coo:
gradA = torch.sparse_coo_tensor(A.indices(), gradA, A.shape)
gradA = torch.sparse_coo_tensor(A._indices(), gradA, A.shape)
elif A.layout == torch.sparse_csr:
gradA = torch.sparse_csr_tensor(A_crow_idx, A_col_idx, gradA, A.shape)

# Now compute the dense gradient with respect to B
gradB = torch.sparse.mm(A.t(), grad)

if ctx.batch_size is not None:
shapes = ctx.A_shape[0]*(ctx.A_shape[-2:],)
gradA = sparse_block_diag_split(gradA, *shapes)
if A.layout == torch.sparse_coo:
gradA = torch.stack([*gradA])
else:
gradA = stack_csr([*gradA]) # NOTE: torch.stack does not work for csr tensors

gradB = gradB.view(ctx.B_shape)
return gradA, gradB

0 comments on commit b34f342

Please sign in to comment.