From 41adee8947c27574185cbdcedc5f75ffa273cba7 Mon Sep 17 00:00:00 2001 From: theo-barfoot Date: Mon, 5 Jun 2023 21:37:59 +0100 Subject: [PATCH] batched coo to csr and crow decomp --- tests/test_utils.py | 31 +++++++++++++-- torchsparsegradutils/utils/utils.py | 61 +++++++++++++++++++---------- 2 files changed, 69 insertions(+), 23 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index e59afc8..3a20068 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,13 +1,13 @@ import torch import unittest from parameterized import parameterized_class, parameterized -from torchsparsegradutils.utils.random_sparse import generate_random_sparse_coo_matrix, generate_random_sparse_strictly_triangular_coo_matrix -from torchsparsegradutils.utils.utils import _compress_row_indices, demcompress_crow_indices, convert_coo_to_csr +from torchsparsegradutils.utils.random_sparse import generate_random_sparse_coo_matrix, generate_random_sparse_csr_matrix from torchsparsegradutils.utils.utils import ( _compress_row_indices, - demcompress_crow_indices, + _demcompress_crow_indices, _sort_coo_indices, + convert_coo_to_csr ) @@ -102,6 +102,8 @@ def test_compress_row_indices(self, _, size, nnz): @parameterized.expand([ ("4x4_12n", torch.Size([4, 4]), 12), ("8x16_32n", torch.Size([8, 16]), 32), + ("4x4x4_12n", torch.Size([4, 4, 4]), 12), + ("6x8x14_32n", torch.Size([6, 8, 14]), 32), ]) def test_coo_to_csr_indices(self, _, size, nnz): A_coo = generate_random_sparse_coo_matrix(size, nnz, device=self.device) @@ -120,6 +122,8 @@ def test_coo_to_csr_indices(self, _, size, nnz): @parameterized.expand([ ("4x4_12n", torch.Size([4, 4]), 12), ("8x16_32n", torch.Size([8, 16]), 32), + ("4x4x4_12n", torch.Size([4, 4, 4]), 12), + ("6x8x14_32n", torch.Size([6, 8, 14]), 32), ]) def test_coo_to_csr_values(self, _, size, nnz): A_coo = generate_random_sparse_coo_matrix(size, nnz, device=self.device) @@ -133,3 +137,24 @@ def test_coo_to_csr_values(self, _, size, nnz): 
self.assertTrue(torch.equal(A_csr.values(), A_csr_2.values())) + +@parameterized_class(('name', 'device',), [ + ("CPU", torch.device("cpu")), + ("CUDA", torch.device("cuda"),), +]) +class TestCSRtoCOO(unittest.TestCase): + def setUp(self) -> None: + if not torch.cuda.is_available() and self.device == torch.device("cuda"): + self.skipTest(f"Skipping {self.__class__.__name__} since CUDA is not available") + + + @parameterized.expand([ + ("4x4_12n", torch.Size([4, 4]), 12), + ("8x16_32n", torch.Size([8, 16]), 32), + ]) + def test_compress_row_indices(self, _, size, nnz): + A_csr = generate_random_sparse_csr_matrix(size, nnz, device=self.device) + A_coo = A_csr.to_sparse_coo() + A_coo_row_indices = A_coo.indices()[0] + row_indices = _demcompress_crow_indices(A_csr.crow_indices(), A_coo.size()[0]) + self.assertTrue(torch.equal(A_coo_row_indices, row_indices)) \ No newline at end of file diff --git a/torchsparsegradutils/utils/utils.py b/torchsparsegradutils/utils/utils.py index d62440d..910243d 100644 --- a/torchsparsegradutils/utils/utils.py +++ b/torchsparsegradutils/utils/utils.py @@ -3,8 +3,7 @@ def _sort_coo_indices(indices): """ - Sorts COO (Coordinate List Format) indices in ascending order and returns a permutation tensor that indicates - the indices in the original data that result in a sorted tensor. + Sorts COO (Coordinate List Format) indices in ascending order and returns a permutation tensor that indicates the indices in the original data that result in a sorted tensor. This function can support both unbatched and batched COO indices, and essentially performs the same operation as .coalesce() called on a COO tensor. The advantage is that COO coordinates can be sorted prior to conversion to CSR, without having to use the torch.sparse_coo_tensor object which only supports int64 indices. 
@@ -37,45 +36,67 @@ def _compress_row_indices(row_indices, num_rows): # Compute the cumulative sum of counts to get CSR indices - crow_indices = torch.cat([torch.zeros(1, dtype=row_indices.dtype, device=counts.device), counts.cumsum(dim=0, dtype=row_indices.dtype)]) + crow_indices = torch.cat([torch.zeros(1, dtype=row_indices.dtype, device=counts.device), counts.cumsum_(dim=0)]) return crow_indices -# TODO: add support for batched -def covert_coo_to_csr_indices_values(coo_indices, num_rows, values=None): +def convert_coo_to_csr_indices_values(coo_indices, num_rows, values=None): """Converts COO row and column indices to CSR crow and col indices. + Supports batched indices, which would have shape [3, nnz]. Or, [2, nnz] for unbatched indices. This function sorts the row and column indices (similar to torch.sparse_coo_tensor.coalesce()) and then compresses the row indices to CSR crow indices. - If values are provided, the tensor is permuted according to the sorted COO indices. + If values are provided, the values tensor is permuted according to the sorted COO indices. If no values are provided, the permutation indices are returned. Args: - row_indices (torch.Tensor): Tensor of COO row indices. - col_indices (torch.Tensor): Tensor of COO column indices. + coo_indices (torch.Tensor): Tensor of COO indices num_rows (int): Number of rows in the matrix. Returns: torch.Tensor: Compressed CSR crow indices. - torch.Tensor: Compressed CSR col indices. - torch.Tensor: Permutation indices from sorting the col indices. Or permuted values if values are provided. + torch.Tensor: CSR Col indices. + torch.Tensor: Permutation indices from sorting COO indices. Or permuted values if values are provided. """ + if coo_indices.shape[0] < 2: + raise ValueError(f"Indices tensor must have at least 2 rows (row and column indices). 
Got {coo_indices.shape[0]} rows.") + elif coo_indices.shape[0] > 3: + raise ValueError(f"Current implementation only supports single batch dimension, therefore indices tensor must have at most 3 rows (batch, row and column indices). Got {coo_indices.shape[0]} rows.") + + if coo_indices[-2].max() >= num_rows: + raise ValueError(f"Row indices must be less than num_rows ({num_rows}). Got max row index {coo_indices[-2].max()}") + + if values != None and values.shape[0] != coo_indices.shape[1]: + raise ValueError(f"Number of values ({values.shape[0]}) does not match number of indices ({coo_indices.shape[1]})") + coo_indices, permutation = _sort_coo_indices(coo_indices) - row_indices, col_indices = coo_indices - crow_indices = _compress_row_indices(row_indices, num_rows) - if values == None: - return crow_indices, col_indices, permutation + if coo_indices.shape[0] == 2: + row_indices, col_indices = coo_indices + crow_indices = _compress_row_indices(row_indices, num_rows) + + values = values[permutation] if values is not None else permutation + else: - return crow_indices, col_indices, values[permutation] + batch_indices, row_indices, col_indices = coo_indices + crow_indices = torch.cat([_compress_row_indices(row_indices[batch_indices == batch], num_rows) for batch in torch.unique(batch_indices)]) + num_batches = torch.unique(batch_indices).shape[0] + + crow_indices = crow_indices.reshape(num_batches, -1) + col_indices = col_indices.reshape(num_batches, -1) + + values = values[permutation] if values is not None else permutation + + values = values.reshape(num_batches, -1) + + return crow_indices, col_indices, values -# TODO: add support for batched -def convert_coo_to_csr(sparse_coo_tensor): - """Converts a COO sparse tensor to CSR format. +def convert_coo_to_csr(sparse_coo_tensor): + """Converts a COO sparse tensor to CSR format. COO tensor can have optional single leading batch dimension. Args: sparse_coo_tensor (torch.Tensor): COO sparse tensor. 
@@ -85,14 +106,14 @@ def convert_coo_to_csr(sparse_coo_tensor): torch.Tensor: CSR sparse tensor. """ if sparse_coo_tensor.layout == torch.sparse_coo: - crow_indices, col_indices, values = covert_coo_to_csr_indices_values(sparse_coo_tensor.indices(), sparse_coo_tensor.size()[0], sparse_coo_tensor.values()) + crow_indices, col_indices, values = convert_coo_to_csr_indices_values(sparse_coo_tensor.indices(), sparse_coo_tensor.size()[-2], sparse_coo_tensor.values()) return torch.sparse_csr_tensor(crow_indices, col_indices, values, sparse_coo_tensor.size()) else: raise ValueError(f"Unsupported layout: {sparse_coo_tensor.layout}") -def demcompress_crow_indices(crow_indices, num_rows): +def _demcompress_crow_indices(crow_indices, num_rows): """Decompresses CSR crow indices to COO row indices.