Skip to content

Commit

Permalink
batched coo to csr and crow decomp
Browse files Browse the repository at this point in the history
  • Loading branch information
theo-barfoot committed Jun 5, 2023
1 parent 55d92f7 commit 41adee8
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 23 deletions.
31 changes: 28 additions & 3 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import torch
import unittest
from parameterized import parameterized_class, parameterized
from torchsparsegradutils.utils.random_sparse import generate_random_sparse_coo_matrix, generate_random_sparse_strictly_triangular_coo_matrix
from torchsparsegradutils.utils.utils import _compress_row_indices, demcompress_crow_indices, convert_coo_to_csr
from torchsparsegradutils.utils.random_sparse import generate_random_sparse_coo_matrix, generate_random_sparse_csr_matrix

from torchsparsegradutils.utils.utils import (
_compress_row_indices,
demcompress_crow_indices,
_demcompress_crow_indices,
_sort_coo_indices,
convert_coo_to_csr
)


Expand Down Expand Up @@ -102,6 +102,8 @@ def test_compress_row_indices(self, _, size, nnz):
@parameterized.expand([
("4x4_12n", torch.Size([4, 4]), 12),
("8x16_32n", torch.Size([8, 16]), 32),
("4x4x4_12n", torch.Size([4, 4, 4]), 12),
("6x8x14_32n", torch.Size([6, 8, 14]), 32),
])
def test_coo_to_csr_indices(self, _, size, nnz):
A_coo = generate_random_sparse_coo_matrix(size, nnz, device=self.device)
Expand All @@ -120,6 +122,8 @@ def test_coo_to_csr_indices(self, _, size, nnz):
@parameterized.expand([
("4x4_12n", torch.Size([4, 4]), 12),
("8x16_32n", torch.Size([8, 16]), 32),
("4x4x4_12n", torch.Size([4, 4, 4]), 12),
("6x8x14_32n", torch.Size([6, 8, 14]), 32),
])
def test_coo_to_csr_values(self, _, size, nnz):
A_coo = generate_random_sparse_coo_matrix(size, nnz, device=self.device)
Expand All @@ -133,3 +137,24 @@ def test_coo_to_csr_values(self, _, size, nnz):

self.assertTrue(torch.equal(A_csr.values(), A_csr_2.values()))


@parameterized_class(('name', 'device',), [
    ("CPU", torch.device("cpu")),
    ("CUDA", torch.device("cuda"),),
])
class TestCSRtoCOO(unittest.TestCase):
    """CSR -> COO index checks, run once per device listed in the class parameters."""

    def setUp(self) -> None:
        # The CUDA variant of the class is skipped on machines without a GPU.
        wants_cuda = self.device == torch.device("cuda")
        if wants_cuda and not torch.cuda.is_available():
            self.skipTest(f"Skipping {self.__class__.__name__} since CUDA is not available")

    @parameterized.expand([
        ("4x4_12n", torch.Size([4, 4]), 12),
        ("8x16_32n", torch.Size([8, 16]), 32),
    ])
    def test_compress_row_indices(self, _, size, nnz):
        """Decompressed crow indices must equal the row indices of the COO view.

        NOTE: despite its name, this test exercises ``_demcompress_crow_indices``
        (CSR crow indices -> COO row indices), not the compression direction.
        """
        csr_matrix = generate_random_sparse_csr_matrix(size, nnz, device=self.device)
        coo_matrix = csr_matrix.to_sparse_coo()

        expected_rows = coo_matrix.indices()[0]
        actual_rows = _demcompress_crow_indices(csr_matrix.crow_indices(), coo_matrix.size()[0])

        self.assertTrue(torch.equal(expected_rows, actual_rows))
61 changes: 41 additions & 20 deletions torchsparsegradutils/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@

def _sort_coo_indices(indices):
"""
Sorts COO (Coordinate List Format) indices in ascending order and returns a permutation tensor that indicates
the indices in the original data that result in a sorted tensor.
Sorts COO (Coordinate List Format) indices in ascending order and returns a permutation tensor that indicates the indices in the original data that result in a sorted tensor.
This function can support both unbatched and batched COO indices, and essentially performs the same operation as .coalesce() called on a COO tensor.
The advantage is that COO coordinates can be sorted prior to conversion to CSR, without having to use the torch.sparse_coo_tensor object which only supports int64 indices.
Expand Down Expand Up @@ -37,45 +36,67 @@ def _compress_row_indices(row_indices, num_rows):


# Compute the cumulative sum of counts to get CSR indices
crow_indices = torch.cat([torch.zeros(1, dtype=row_indices.dtype, device=counts.device), counts.cumsum(dim=0, dtype=row_indices.dtype)])
crow_indices = torch.cat([torch.zeros(1, dtype=row_indices.dtype, device=counts.device), counts.cumsum_(dim=0)])


return crow_indices


def convert_coo_to_csr_indices_values(coo_indices, num_rows, values=None):
    """Convert COO indices to CSR crow and col indices.

    Accepts unbatched indices of shape [2, nnz] (row, column) or indices with a
    single leading batch dimension of shape [3, nnz] (batch, row, column).
    The indices are first sorted with ``_sort_coo_indices`` (similar to
    ``torch.sparse_coo_tensor.coalesce()``), then the sorted row indices are
    compressed to crow indices with ``_compress_row_indices``.

    In the batched case the outputs are reshaped to ``[num_batches, -1]``, where
    ``num_batches`` is the number of distinct batch indices present.
    NOTE(review): this assumes every batch holds the same number of non-zero
    elements; the reshape does not verify that.

    Args:
        coo_indices (torch.Tensor): COO indices with 2 rows (unbatched) or
            3 rows (batched).
        num_rows (int): Number of rows in the matrix.
        values (torch.Tensor, optional): Values associated with the indices.
            ``values.shape[0]`` must equal ``coo_indices.shape[1]``.

    Returns:
        torch.Tensor: CSR crow indices.
        torch.Tensor: CSR col indices.
        torch.Tensor: ``values`` permuted according to the sorted COO indices,
            or the permutation indices themselves if ``values`` is None.

    Raises:
        ValueError: If ``coo_indices`` has fewer than 2 or more than 3 rows, if
            a row index is not less than ``num_rows``, or if the number of
            values does not match the number of indices.
    """
    if coo_indices.shape[0] < 2:
        raise ValueError(
            f"Indices tensor must have at least 2 rows (row and column indices). Got {coo_indices.shape[0]} rows."
        )
    if coo_indices.shape[0] > 3:
        raise ValueError(
            "Current implementation only supports single batch dimension, therefore indices tensor must have "
            f"at most 3 rows (batch, row and column indices). Got {coo_indices.shape[0]} rows."
        )

    # Row indices are always the second-to-last row, batched or not.
    max_row_index = coo_indices[-2].max()
    if max_row_index >= num_rows:
        raise ValueError(f"Row indices must be less than num_rows ({num_rows}). Got max row index {max_row_index}")

    # Identity comparison: `!= None` on a tensor relies on operator overloading.
    if values is not None and values.shape[0] != coo_indices.shape[1]:
        raise ValueError(
            f"Number of values ({values.shape[0]}) does not match number of indices ({coo_indices.shape[1]})"
        )

    coo_indices, permutation = _sort_coo_indices(coo_indices)

    # Third output: permuted values when given, otherwise the permutation itself.
    values = permutation if values is None else values[permutation]

    if coo_indices.shape[0] == 2:
        row_indices, col_indices = coo_indices
        crow_indices = _compress_row_indices(row_indices, num_rows)
        return crow_indices, col_indices, values

    batch_indices, row_indices, col_indices = coo_indices

    # Compute the distinct batch ids once and reuse them for the loop and the reshape.
    batches = torch.unique(batch_indices)
    num_batches = batches.shape[0]

    # Compress the row indices of each batch separately, then lay them out one batch per row.
    crow_indices = torch.cat(
        [_compress_row_indices(row_indices[batch_indices == batch], num_rows) for batch in batches]
    ).reshape(num_batches, -1)
    col_indices = col_indices.reshape(num_batches, -1)
    values = values.reshape(num_batches, -1)

    return crow_indices, col_indices, values


def convert_coo_to_csr(sparse_coo_tensor):
    """Convert a COO sparse tensor to CSR format.

    The COO tensor can have an optional single leading batch dimension.

    Args:
        sparse_coo_tensor (torch.Tensor): COO sparse tensor.

    Returns:
        torch.Tensor: CSR sparse tensor with the same size as the input.

    Raises:
        ValueError: If the layout of ``sparse_coo_tensor`` is not ``torch.sparse_coo``.
    """
    if sparse_coo_tensor.layout != torch.sparse_coo:
        raise ValueError(f"Unsupported layout: {sparse_coo_tensor.layout}")

    # The row count is the second-to-last size entry, so a leading batch
    # dimension (if any) is not mistaken for the number of rows.
    crow_indices, col_indices, values = convert_coo_to_csr_indices_values(
        sparse_coo_tensor.indices(),
        sparse_coo_tensor.size()[-2],
        sparse_coo_tensor.values(),
    )
    return torch.sparse_csr_tensor(crow_indices, col_indices, values, sparse_coo_tensor.size())



def demcompress_crow_indices(crow_indices, num_rows):
def _demcompress_crow_indices(crow_indices, num_rows):
"""Decompresses CSR crow indices to COO row indices.
Expand Down

0 comments on commit 41adee8

Please sign in to comment.