Skip to content

Commit

Permalink
sort coo indices correctly
Browse files Browse the repository at this point in the history
  • Loading branch information
theo-barfoot committed Jun 5, 2023
1 parent f87f821 commit fba138e
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 9 deletions.
41 changes: 41 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,49 @@
from torchsparsegradutils.utils.utils import (
_compress_row_indices,
demcompress_crow_indices,
_sort_coo_indices,
)


@parameterized_class(
    ("name", "device"),
    [
        ("CPU", torch.device("cpu")),
        ("CUDA", torch.device("cuda")),
    ],
)
class TestSortCOOIndices(unittest.TestCase):
    """Compare ``_sort_coo_indices`` against ``coalesce()`` on each parameterized device."""

    def setUp(self) -> None:
        # The CUDA variant of the class is skipped on machines without a GPU.
        wants_cuda = self.device == torch.device("cuda")
        if wants_cuda and not torch.cuda.is_available():
            self.skipTest(f"Skipping {self.__class__.__name__} since CUDA is not available")

    def _check_against_coalesce(self, coo_indices):
        """Assert sorted indices and permutation equal what ``coalesce()`` produces.

        The values are ``0..nnz-1``, so after coalescing a tensor with unique
        coordinates the values tensor is exactly the permutation that was applied.
        """
        nnz = coo_indices.shape[-1]
        positions = torch.arange(nnz, device=self.device)
        reference = torch.sparse_coo_tensor(coo_indices, positions).coalesce()

        sorted_indices, permutation = _sort_coo_indices(coo_indices)

        self.assertTrue(torch.equal(reference.indices(), sorted_indices))
        self.assertTrue(torch.equal(reference.values(), permutation))

    def test_unbatched_sort(self):
        nr, nc = 4, 4
        # A random permutation of the flat positions gives every (row, col) exactly once.
        flat = torch.randperm(nr * nc, device=self.device)
        rows = flat // nc
        cols = flat % nc
        self._check_against_coalesce(torch.stack([rows, cols]))

    def test_batched_sort(self):
        nr, nc = 4, 4
        batch_size = 3
        flat = torch.randperm(nr * nc, device=self.device)
        pattern = torch.stack([flat // nc, flat % nc])

        # The sparsity pattern is tiled batch_size times along the nnz dimension and
        # paired with a cycling batch index, giving a (3, nnz) index tensor.
        tiled_pattern = torch.cat([pattern] * batch_size, dim=-1)
        batch_row = torch.arange(batch_size, device=self.device).repeat(nr * nc).unsqueeze(0)
        self._check_against_coalesce(torch.cat([batch_row, tiled_pattern]))


@parameterized_class(('name', 'device',), [
("CPU", torch.device("cpu")),
("CUDA", torch.device("cuda"),),
Expand Down
35 changes: 26 additions & 9 deletions torchsparsegradutils/utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,25 @@
import torch


def _sort_coo_indices(indices):
"""
Sorts COO (Coordinate List Format) indices in ascending order and returns a permutation tensor that indicates
the indices in the original data that result in a sorted tensor.
This function can support both unbatched and batched COO indices, and essentially performs the same operation as .coalesce() called on a COO tensor.
The advantage is that COO coordinates can be sorted prior to conversion to CSR, without having to use the torch.sparse_coo_tensor object which only supports int64 indices.
Args:
indices (torch.Tensor): The input indices in COO format to be sorted.
Returns:
torch.Tensor: A tensor containing sorted indices.
torch.Tensor: A permutation tensor that contains the indices in the original tensor that give the sorted tensor.
"""
indices_sorted, permutation = torch.unique(indices, dim=-1, sorted=True, return_inverse=True)
return indices_sorted, torch.argsort(permutation)


def _compress_row_indices(row_indices, num_rows):
"""Compresses COO row indices to CSR crow indices.
Expand All @@ -27,8 +46,8 @@ def _compress_row_indices(row_indices, num_rows):
# TODO: add support for batched
def covert_coo_to_csr_indices_values(coo_indices, num_rows, values=None):
"""Converts COO row and column indices to CSR crow and col indices.
This function compressed the row indices and sorts the column indices.
If values are provided, the tensor is permuted according to the sorted column indices.
This function sorts the row and column indices (similar to torch.sparse_coo_tensor.coalesce()) and then compresses the row indices to CSR crow indices.
If values are provided, the tensor is permuted according to the sorted COO indices.
If no values are provided, the permutation indices are returned.
Expand All @@ -43,16 +62,14 @@ def covert_coo_to_csr_indices_values(coo_indices, num_rows, values=None):
torch.Tensor: Compressed CSR col indices.
torch.Tensor: Permutation indices from sorting the col indices. Or permuted values if values are provided.
"""
coo_indices, permutation = _sort_coo_indices(coo_indices)
row_indices, col_indices = coo_indices
crow_indices = _compress_row_indices(row_indices, num_rows)
return crow_indices, col_indices, values

# col_indices, permutation = torch.sort(col_indices)

# if values == None:
# return crow_indices, col_indices, permutation
# else:
# return crow_indices, col_indices, values[permutation]
if values == None:
return crow_indices, col_indices, permutation
else:
return crow_indices, col_indices, values[permutation]


# TODO: add support for batched
Expand Down

0 comments on commit fba138e

Please sign in to comment.