From b34f342ea240612c046a98f7251000955dd959f8 Mon Sep 17 00:00:00 2001
From: Theo Barfoot
Date: Tue, 13 Jun 2023 18:48:44 +0100
Subject: [PATCH] add support for batched mm

---
 torchsparsegradutils/sparse_matmul.py | 67 +++++++++++++++++++++++++--
 1 file changed, 64 insertions(+), 3 deletions(-)

diff --git a/torchsparsegradutils/sparse_matmul.py b/torchsparsegradutils/sparse_matmul.py
index 5a4958e..6d301d9 100644
--- a/torchsparsegradutils/sparse_matmul.py
+++ b/torchsparsegradutils/sparse_matmul.py
@@ -1,7 +1,40 @@
 import torch
+from torchsparsegradutils.utils import sparse_block_diag, sparse_block_diag_split, stack_csr
 
 
 def sparse_mm(A, B):
+    """
+    Performs a matrix multiplication between a sparse matrix A and a dense matrix B,
+    preserving the sparsity of the gradient with respect to A, permitting sparse backpropagation.
+
+    The sparse matrix A can be in either COO or CSR format, and is expected
+    to be 2-dimensional, with an optional leading batch dimension. The dense matrix B
+    should also be 2-dimensional, with a matching optional leading batch dimension.
+    The batch size must be the same for both A and B.
+
+    Args:
+        A (torch.Tensor): The sparse matrix in COO or CSR format.
+        B (torch.Tensor): The dense matrix.
+
+    Returns:
+        torch.Tensor: The result of the matrix multiplication.
+    """
+
+    if not isinstance(A, torch.Tensor) or not isinstance(B, torch.Tensor):
+        raise ValueError("Both A and B should be instances of torch.Tensor")
+
+    if A.dim() < 2 or B.dim() < 2:
+        raise ValueError("Both A and B should be at least 2-dimensional tensors")
+
+    if A.dim() != B.dim():
+        raise ValueError("Both A and B should have the same number of dimensions")
+
+    if A.layout not in {torch.sparse_coo, torch.sparse_csr}:
+        raise ValueError("A should be in either COO or CSR format")
+
+    if A.dim() == 3 and A.size(0) != B.size(0):
+        raise ValueError("If A and B have a leading batch dimension, they should have the same batch size")
+
     return SparseMatMul.apply(A, B)
 
 
@@ -18,11 +51,26 @@ class SparseMatMul(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, A, B):
+        ctx.batch_size = B.size()[0] if B.dim() == 3 else None
+        ctx.A_shape = A.size()  # (b), n, m
+        ctx.B_shape = B.size()  # (b), m, p
+
         grad_flag = A.requires_grad or B.requires_grad
+
         A, B = A.detach(), B.detach()
+
+        if ctx.batch_size is not None:
+            A = sparse_block_diag(*A)
+            B = torch.cat([*B])
+
         x = torch.sparse.mm(A, B)
-        x.requires_grad = grad_flag
+        ctx.save_for_backward(A, B)
+
+        if ctx.batch_size is not None:
+            x = x.view(ctx.batch_size, ctx.A_shape[-2], ctx.B_shape[-1])
+
+        x.requires_grad = grad_flag
         return x
 
     @staticmethod
@@ -40,7 +88,7 @@ def backward(ctx, grad):
         # We start by getting the i and j indices:
         if A.layout == torch.sparse_coo:
-            A_row_idx, A_col_idx = A.indices()
+            A_row_idx, A_col_idx = A._indices()
         elif A.layout == torch.sparse_csr:
             A_col_idx = A.col_indices()
             A_crow_idx = A.crow_indices()
             A_row_idx = torch.repeat_interleave(
@@ -50,6 +98,9 @@ def backward(ctx, grad):
             )
         else:
             raise ValueError(f"Unsupported layout: {A.layout}")
+
+        if ctx.batch_size is not None:
+            grad = torch.cat([*grad])
 
         grad_select = grad.index_select(0, A_row_idx)  # grad[i, :]
         B_select = B.index_select(0, A_col_idx)  # B[j, :]
@@ -60,10 +111,20 @@ def backward(ctx, grad):
         # Create a sparse matrix of the gradient with respect to the nnz of A
         if A.layout == torch.sparse_coo:
-            gradA = torch.sparse_coo_tensor(A.indices(), gradA, A.shape)
+            gradA = torch.sparse_coo_tensor(A._indices(), gradA, A.shape)
         elif A.layout == torch.sparse_csr:
             gradA = torch.sparse_csr_tensor(A_crow_idx, A_col_idx, gradA, A.shape)
 
         # Now compute the dense gradient with respect to B
         gradB = torch.sparse.mm(A.t(), grad)
+
+        if ctx.batch_size is not None:
+            shapes = ctx.A_shape[0]*(ctx.A_shape[-2:],)
+            gradA = sparse_block_diag_split(gradA, *shapes)
+            if A.layout == torch.sparse_coo:
+                gradA = torch.stack([*gradA])
+            else:
+                gradA = stack_csr([*gradA])  # NOTE: torch.stack does not work for csr tensors
+
+            gradB = gradB.view(ctx.B_shape)
 
         return gradA, gradB
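
Usage sketch (reviewer note, not part of the patch): a minimal example of how the batched path added here could be exercised. It assumes sparse_mm is importable from torchsparsegradutils.sparse_matmul as defined in this file; the batch size, matrix shapes, and sparsity below are illustrative only.

    import torch
    from torchsparsegradutils.sparse_matmul import sparse_mm

    b, n, m, p = 4, 8, 16, 3  # illustrative sizes

    # Batch of sparse COO matrices A of shape (b, n, m) with roughly 20% non-zeros,
    # and a matching dense batch B of shape (b, m, p).
    A = (torch.rand(b, n, m) * (torch.rand(b, n, m) > 0.8)).to_sparse().requires_grad_()
    B = torch.rand(b, m, p, requires_grad=True)

    x = sparse_mm(A, B)  # dense result of shape (b, n, p)
    x.sum().backward()

    # The gradient w.r.t. A retains A's sparsity pattern; the gradient w.r.t. B is dense.
    print(x.shape, A.grad.layout, B.grad.shape)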