Skip to content

Commit

Permalink
interim commit
Browse files Browse the repository at this point in the history
  • Loading branch information
theo-barfoot committed Jun 2, 2023
1 parent 12efd6b commit 11f0d58
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 50 deletions.
115 changes: 71 additions & 44 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,66 +2,93 @@
import unittest
from parameterized import parameterized_class, parameterized
from torchsparsegradutils.utils.random_sparse import generate_random_sparse_coo_matrix, generate_random_sparse_strictly_triangular_coo_matrix
from torchsparsegradutils.utils.utils import compress_row_indices, demcompress_crow_indices
from torchsparsegradutils.utils.utils import _compress_row_indices, demcompress_crow_indices, convert_coo_to_csr

from torchsparsegradutils.utils.utils import (
compress_row_indices,
_compress_row_indices,
demcompress_crow_indices,
)

@parameterized_class(('name', 'device',), [
("CPU", torch.device("cpu")),
("CUDA", torch.device("cuda"),),
])
class TestRowIndicesCompressionDecompression(unittest.TestCase):
class TestCOOtoCSR(unittest.TestCase):
def setUp(self) -> None:
self.A_coo = generate_random_sparse_coo_matrix(torch.Size([8, 8]), 12, device=self.device)
self.A_coo_tril = generate_random_sparse_strictly_triangular_coo_matrix(torch.Size([8, 8]), 12, device=self.device)
self.A_csr = self.A_coo.to_sparse_csr()
self.A_csr_tril = self.A_coo_tril.to_sparse_csr()
if not torch.cuda.is_available() and self.device == torch.device("cuda"):
self.skipTest(f"Skipping {self.__class__.__name__} since CUDA is not available")


def batched_coo_to_csr(self, A_coo):
"""
Converts batched sparse COO matrix A with shape [B, N, M]
to batched sparse CSR matrix with shape [B, N, M]
Inneficient implementation as the COO tensors only support int64 indices.
Meaning that int32 indices cannot be maintained if a CSR matrix is created via to_sparse_csr() from COO.
"""
A_crow_indices_list = []
A_row_indices_list = []
A_values_list = []

# row compression cannot be done without sorting the col indices and applying the same sort change to the values
size = A_coo.size()

for a_coo in A_coo:
a_csr = a_coo.detach().to_sparse_csr() # detach to prevent grad on indices
A_crow_indices_list.append(a_csr.crow_indices())
A_row_indices_list.append(a_csr.col_indices())
A_values_list.append(a_csr.values())

A_crow_indices = torch.stack(A_crow_indices_list, dim=0)
A_row_indices = torch.stack(A_row_indices_list, dim=0)
A_values = torch.stack(A_values_list, dim=0)

# TODO: these unit tests need to build a CSR from the compressed and then convert back
# as they do not check either the values or the column indices
return torch.sparse_csr_tensor(A_crow_indices, A_row_indices, A_values, size=size)

# TODO: let's also check the other way around, i.e. CSR to COO
# I could just use to .to_sparse_coo() and .to_sparse_csr() methods
# however, that would force conversion to int64, which won't be ideal for my use case
# Having a way to convert COO - CSR indices, strictly in int32 is useful and will prevent the memory errors I have been having
# Would also be nice to be able to do this batched.

def test_compress_row_indices(self):
row_idx, col_idx = self.A_coo.indices()
crow_idx = compress_row_indices(row_idx, self.A_coo.shape[0])
self.assertTrue(torch.allclose(crow_idx, self.A_csr.crow_indices()))

def test_demcompress_crow_indices(self):
crow_idx = self.A_csr.crow_indices()
row_idx = demcompress_crow_indices(crow_idx, self.A_coo.shape[0])
self.assertTrue(torch.allclose(row_idx, self.A_coo.indices()[0]))
@parameterized.expand([
("4x4_12n", torch.Size([4, 4]), 12),
("8x16_32n", torch.Size([8, 16]), 32),
])
def test_compress_row_indices(self, _, size, nnz):
A_coo = generate_random_sparse_coo_matrix(size, nnz, device=self.device)
A_csr = A_coo.to_sparse_csr()
A_csr_crow_indices = A_csr.crow_indices()
crow_indices = _compress_row_indices(A_coo.indices()[0], A_coo.size()[0])
self.assertTrue(torch.equal(A_csr_crow_indices, crow_indices))


@parameterized.expand([
("4x4_12n", torch.Size([4, 4]), 12),
("8x16_32n", torch.Size([8, 16]), 32),
])
def test_coo_to_csr_indices(self, _, size, nnz):
A_coo = generate_random_sparse_coo_matrix(size, nnz, device=self.device)
A_csr = convert_coo_to_csr(A_coo)
if len(size) == 2:
A_csr_2 = A_coo.to_sparse_csr()
elif len(size) == 3:
A_csr_2 = self.batched_coo_to_csr(A_coo)
else:
raise ValueError(f"Size {size} not supported")

def test_compress_row_indices_tril(self):
row_idx, col_idx = self.A_coo_tril.indices()
crow_idx = compress_row_indices(row_idx, self.A_coo_tril.shape[0])
self.assertTrue(torch.allclose(crow_idx, self.A_csr_tril.crow_indices()))
self.assertTrue(torch.equal(A_csr.crow_indices(), A_csr_2.crow_indices()))
self.assertTrue(torch.equal(A_csr.col_indices(), A_csr_2.col_indices()))

def test_demcompress_crow_indices_tril(self):
crow_idx = self.A_csr_tril.crow_indices()
row_idx = demcompress_crow_indices(crow_idx, self.A_coo_tril.shape[0])
self.assertTrue(torch.allclose(row_idx, self.A_coo_tril.indices()[0]))

@parameterized.expand([
("int32", torch.int32),
("int64", torch.int64),
("4x4_12n", torch.Size([4, 4]), 12),
("8x16_32n", torch.Size([8, 16]), 32),
])
def test_indices_dtype(self, _, indices_dtype):
num_rows = 4
row_idx = torch.tensor([0, 0, 0, 1, 2, 2, 2, 2, 3, 3, 3, 3], dtype=indices_dtype, device=self.device)
crow_idx = compress_row_indices(row_idx, num_rows)
self.assertEqual(crow_idx.dtype, indices_dtype)
def test_coo_to_csr_values(self, _, size, nnz):
A_coo = generate_random_sparse_coo_matrix(size, nnz, device=self.device)
A_csr = convert_coo_to_csr(A_coo)
if len(size) == 2:
A_csr_2 = A_coo.to_sparse_csr()
elif len(size) == 3:
A_csr_2 = self.batched_coo_to_csr(A_coo)
else:
raise ValueError(f"Size {size} not supported")

self.assertTrue(torch.equal(A_csr.values(), A_csr_2.values()))

def test_device(self):
num_rows = 4
row_idx = torch.tensor([0, 0, 0, 1, 2, 2, 2, 2, 3, 3, 3, 3], dtype=torch.int32, device=self.device)
crow_idx = compress_row_indices(row_idx, num_rows)
self.assertEqual(crow_idx.device.type, self.device.type)
6 changes: 3 additions & 3 deletions torchsparsegradutils/utils/random_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import warnings
import torch
import random
from torchsparsegradutils.utils.utils import compress_row_indices
from torchsparsegradutils.utils.utils import _compress_row_indices

def _gen_indices_2d_coo(nr, nc, nnz, *, dtype=torch.int64, device=torch.device("cpu")):
"""Generates nnz random unique coordinates in COO format.
Expand Down Expand Up @@ -111,7 +111,7 @@ def generate_random_sparse_csr_matrix(size, nnz, *, indices_dtype=torch.int64, v
warnings.warn(f"A bit depth of less than torch.int32 is not recommended for sparse CSR tensors", UserWarning)

row_indices, col_indices = _gen_indices_2d_coo(size[-2], size[-1], nnz, dtype=indices_dtype, device=device)
crow_indices = compress_row_indices(row_indices, size[-2])
crow_indices = _compress_row_indices(row_indices, size[-2])

if len(size) == 2:
values = torch.rand(nnz, dtype=values_dtype, device=device)
Expand Down Expand Up @@ -232,7 +232,7 @@ def generate_random_sparse_strictly_triangular_csr_matrix(size, nnz, *, upper=Tr
warnings.warn(f"A bit depth of less than torch.int32 is not recommended for sparse CSR tensors", UserWarning)

row_indices, col_indices = _gen_indices_2d_coo_strictly_tri(size[-2], nnz, upper=upper, dtype=indices_dtype, device=device)
crow_indices = compress_row_indices(row_indices, size[-2])
crow_indices = _compress_row_indices(row_indices, size[-2])

if len(size) == 2:
values = torch.rand(nnz, dtype=values_dtype, device=device)
Expand Down
64 changes: 61 additions & 3 deletions torchsparsegradutils/utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,95 @@
import torch

def compress_row_indices(row_indices, num_rows):

def _compress_row_indices(row_indices, num_rows):
"""Compresses COO row indices to CSR crow indices.
Args:
row_indices (torch.Tensor): Tensor of COO row indices.
num_rows (int): Number of rows in the matrix.
Returns:
torch.Tensor: Compressed CSR crow indices.
"""
# Compute the number of non-zero elements in each row
counts = torch.bincount(row_indices, minlength=num_rows).to(row_indices.dtype)


# Compute the cumulative sum of counts to get CSR indices
crow_indices = torch.cat([torch.zeros(1, dtype=row_indices.dtype, device=counts.device), counts.cumsum(dim=0, dtype=row_indices.dtype)])


return crow_indices


# TODO: add support for batched
def covert_coo_to_csr_indices_values(coo_indices, num_rows, values=None):
"""Converts COO row and column indices to CSR crow and col indices.
This function compressed the row indices and sorts the column indices.
If values are provided, the tensor is permuted according to the sorted column indices.
If no values are provided, the permutation indices are returned.
Args:
row_indices (torch.Tensor): Tensor of COO row indices.
col_indices (torch.Tensor): Tensor of COO column indices.
num_rows (int): Number of rows in the matrix.
Returns:
torch.Tensor: Compressed CSR crow indices.
torch.Tensor: Compressed CSR col indices.
torch.Tensor: Permutation indices from sorting the col indices. Or permuted values if values are provided.
"""
row_indices, col_indices = coo_indices
crow_indices = _compress_row_indices(row_indices, num_rows)
return crow_indices, col_indices, values

# col_indices, permutation = torch.sort(col_indices)

# if values == None:
# return crow_indices, col_indices, permutation
# else:
# return crow_indices, col_indices, values[permutation]


# TODO: add support for batched
def convert_coo_to_csr(sparse_coo_tensor):
"""Converts a COO sparse tensor to CSR format.
Args:
sparse_coo_tensor (torch.Tensor): COO sparse tensor.
Returns:
torch.Tensor: CSR sparse tensor.
"""
if sparse_coo_tensor.layout == torch.sparse_coo:
crow_indices, col_indices, values = covert_coo_to_csr_indices_values(sparse_coo_tensor.indices(), sparse_coo_tensor.size()[0], sparse_coo_tensor.values())
return torch.sparse_csr_tensor(crow_indices, col_indices, values, sparse_coo_tensor.size())
else:
raise ValueError(f"Unsupported layout: {sparse_coo_tensor.layout}")



def demcompress_crow_indices(crow_indices, num_rows):
"""Decompresses CSR crow indices to COO row indices.
Args:
csr_crow_indices (torch.Tensor): Tensor of CSR crow indices.
num_rows (int): Number of rows in the matrix.
Returns:
torch.Tensor: Decompressed COO row indices.
"""

row_indices = torch.repeat_interleave(
torch.arange(num_rows, dtype=crow_indices.dtype, device=crow_indices.device), crow_indices[1:] - crow_indices[:-1]
)

return row_indices

0 comments on commit 11f0d58

Please sign in to comment.