Skip to content

Commit

Permalink
issue with row compression
Browse files Browse the repository at this point in the history
  • Loading branch information
theo-barfoot committed Jun 1, 2023
1 parent 97c7d74 commit 12efd6b
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 5 deletions.
27 changes: 25 additions & 2 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
import unittest
from parameterized import parameterized_class, parameterized
from torchsparsegradutils.utils.random_sparse import generate_random_sparse_coo_matrix
from torchsparsegradutils.utils.random_sparse import generate_random_sparse_coo_matrix, generate_random_sparse_strictly_triangular_coo_matrix
from torchsparsegradutils.utils.utils import compress_row_indices, demcompress_crow_indices

from torchsparsegradutils.utils.utils import (
Expand All @@ -14,9 +14,22 @@
])
class TestRowIndicesCompressionDecompression(unittest.TestCase):
def setUp(self) -> None:
self.A_coo = generate_random_sparse_coo_matrix(torch.Size([4, 4]), 12, device=self.device)
self.A_coo = generate_random_sparse_coo_matrix(torch.Size([8, 8]), 12, device=self.device)
self.A_coo_tril = generate_random_sparse_strictly_triangular_coo_matrix(torch.Size([8, 8]), 12, device=self.device)
self.A_csr = self.A_coo.to_sparse_csr()
self.A_csr_tril = self.A_coo_tril.to_sparse_csr()

# row compression cannot be done without sorting the col indices and applying the same sort change to the values

# TODO: these unit tests need to build a CSR from the compressed and then convert back
# as they do not check either the values or the column indices

# TODO: let's also check the other way around, i.e. CSR to COO
# I could just use the .to_sparse_coo() and .to_sparse_csr() methods;
# however, that would force conversion to int64, which won't be ideal for my use case
# Having a way to convert between COO and CSR indices strictly in int32 is useful and will prevent the memory errors I have been having.
# Would also be nice to be able to do this batched.

def test_compress_row_indices(self):
row_idx, col_idx = self.A_coo.indices()
crow_idx = compress_row_indices(row_idx, self.A_coo.shape[0])
Expand All @@ -27,6 +40,16 @@ def test_demcompress_crow_indices(self):
row_idx = demcompress_crow_indices(crow_idx, self.A_coo.shape[0])
self.assertTrue(torch.allclose(row_idx, self.A_coo.indices()[0]))

def test_compress_row_indices_tril(self):
    """Compressing the strictly triangular COO row indices must reproduce the CSR crow indices.

    The expected value is the ``crow_indices()`` of the CSR tensor that
    ``setUp`` derives from the same strictly triangular COO fixture.
    """
    coo_rows, _ = self.A_coo_tril.indices()
    num_rows = self.A_coo_tril.shape[0]
    compressed = compress_row_indices(coo_rows, num_rows)
    expected = self.A_csr_tril.crow_indices()
    self.assertTrue(torch.allclose(compressed, expected))

def test_demcompress_crow_indices_tril(self):
    """Decompressing the CSR crow indices must reproduce the COO row indices.

    The crow indices come from the CSR tensor that ``setUp`` derives from the
    strictly triangular COO fixture; the expected value is the first row of
    that fixture's ``indices()``.
    """
    compressed = self.A_csr_tril.crow_indices()
    num_rows = self.A_coo_tril.shape[0]
    recovered_rows = demcompress_crow_indices(compressed, num_rows)
    expected_rows = self.A_coo_tril.indices()[0]
    self.assertTrue(torch.allclose(recovered_rows, expected_rows))

@parameterized.expand([
("int32", torch.int32),
("int64", torch.int64),
Expand Down
8 changes: 5 additions & 3 deletions torchsparsegradutils/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@ def compress_row_indices(row_indices, num_rows):
Returns:
torch.Tensor: Compressed CSR crow indices.
"""
counts = torch.bincount(row_indices)
crow_indices = torch.zeros(num_rows + 1, dtype=row_indices.dtype, device=row_indices.device)
crow_indices[1:] = torch.cumsum(counts, dim=0)
# Compute the number of non-zero elements in each row
counts = torch.bincount(row_indices, minlength=num_rows).to(row_indices.dtype)

# Compute the cumulative sum of counts to get CSR indices
crow_indices = torch.cat([torch.zeros(1, dtype=row_indices.dtype, device=counts.device), counts.cumsum(dim=0, dtype=row_indices.dtype)])

return crow_indices

Expand Down

0 comments on commit 12efd6b

Please sign in to comment.