Skip to content

Commit

Permalink
add csr stack and block diag methods
Browse files Browse the repository at this point in the history
  • Loading branch information
theo-barfoot committed Jun 13, 2023
1 parent a9ab949 commit 67cb9cb
Show file tree
Hide file tree
Showing 3 changed files with 593 additions and 2 deletions.
339 changes: 338 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import torch
import unittest
from unittest.mock import Mock

from parameterized import parameterized_class, parameterized
from torchsparsegradutils.utils.random_sparse import (
generate_random_sparse_coo_matrix,
Expand All @@ -11,12 +13,49 @@
_demcompress_crow_indices,
_sort_coo_indices,
convert_coo_to_csr,
sparse_block_diag,
sparse_block_diag_split,
stack_csr,
)

# https://pytorch.org/docs/stable/generated/torch.sparse.check_sparse_tensor_invariants.html#torch.sparse.check_sparse_tensor_invariants
# https://pytorch.org/docs/stable/generated/torch.sparse.check_sparse_tensor_invariants.html
torch.sparse.check_sparse_tensor_invariants.enable()


@parameterized_class(
(
"name",
"device",
),
[
("CPU", torch.device("cpu")),
(
"CUDA",
torch.device("cuda"),
),
],
)
class TestStackCSR(unittest.TestCase):
def setUp(self) -> None:
if not torch.cuda.is_available() and self.device == torch.device("cuda"):
self.skipTest(f"Skipping {self.__class__.__name__} since CUDA is not available")


@parameterized.expand(
[
("4x4_12n_d0", torch.Size([4, 4]), 12, 0),
("8x16_32n_d0", torch.Size([8, 16]), 32, 0),
("4x4_12n_d-1", torch.Size([4, 4]), 12, -1),
("8x16_32n_d-1", torch.Size([8, 16]), 32, -1),
]
)
def test_stack_csr(self, _, size, nnz, dim):
csr_list = [generate_random_sparse_csr_matrix(size, nnz) for _ in range(3)]
dense_list = [csr.to_dense() for csr in csr_list]
csr_stacked = stack_csr(csr_list)
dense_stacked = torch.stack(dense_list)
self.assertTrue(torch.equal(csr_stacked.to_dense(), dense_stacked))

@parameterized_class(
(
"name",
Expand Down Expand Up @@ -194,3 +233,301 @@ def test_compress_row_indices(self, _, size, nnz):
A_coo_row_indices = A_coo.indices()[0]
row_indices = _demcompress_crow_indices(A_csr.crow_indices(), A_coo.size()[0])
self.assertTrue(torch.equal(A_coo_row_indices, row_indices))

@parameterized_class(
(
"name",
"device",
),
[
("CPU", torch.device("cpu")),
(
"CUDA",
torch.device("cuda"),
),
],
)
class TestSparseBlockDiag(unittest.TestCase):
def setUp(self) -> None:
if not torch.cuda.is_available() and self.device == torch.device("cuda"):
self.skipTest(f"Skipping {self.__class__.__name__} since CUDA is not available")


@parameterized.expand(
[
("1x4x4_12n", torch.Size([1, 4, 4]), 12),
("4x4x4_12n", torch.Size([4, 4, 4]), 12),
("6x8x14_32n", torch.Size([6, 8, 14]), 32),
]
)
def test_sparse_block_diag_coo(self, _, size, nnz):
A_coo = generate_random_sparse_coo_matrix(size, nnz, device=self.device)
A_d = A_coo.to_dense()
A_coo_block_diag = sparse_block_diag(*A_coo)
Ad_block_diag = torch.block_diag(*A_d)
self.assertTrue(torch.equal(A_coo_block_diag.to_dense(), Ad_block_diag))


@parameterized.expand(
[
("1x4x4_12n", torch.Size([1, 4, 4]), 12),
("4x4x4_12n", torch.Size([4, 4, 4]), 12),
("6x8x14_32n", torch.Size([6, 8, 14]), 32),
]
)
def test_sparse_block_diag_coo_backward(self, _, size, nnz):
A_coo = generate_random_sparse_coo_matrix(size, nnz, device=self.device)
A_d = A_coo.detach().clone().to_dense()

A_coo.requires_grad_(True)
A_d.requires_grad_(True)

A_coo_block_diag = sparse_block_diag(*A_coo)
A_d_block_diag = torch.block_diag(*A_d)

A_coo_block_diag.sum().backward()
A_d_block_diag.sum().backward()

nz_mask = A_coo.grad.to_dense() != 0.0

self.assertTrue(torch.allclose(A_coo.grad.to_dense()[nz_mask], A_d.grad[nz_mask]))


@parameterized.expand(
[
("1x4x4_12n", torch.Size([1, 4, 4]), 12),
("4x4x4_12n", torch.Size([4, 4, 4]), 12),
("6x8x14_32n", torch.Size([6, 8, 14]), 32),
]
)
def test_sparse_block_diag_csr(self, _, size, nnz):
A_csr = generate_random_sparse_csr_matrix(size, nnz, device=self.device)
A_d = A_csr.to_dense()
A_csr_block_diag = sparse_block_diag(*A_csr)
Ad_block_diag = torch.block_diag(*A_d)
self.assertTrue(torch.equal(A_csr_block_diag.to_dense(), Ad_block_diag))

@parameterized.expand(
[
("1x4x4_12n", torch.Size([1, 4, 4]), 12), # passes for this case
("4x4x4_12n", torch.Size([4, 4, 4]), 12),
("6x8x14_32n", torch.Size([6, 8, 14]), 32),
]
)
def test_sparse_block_diag_csr_backward_dense_grad(self, _, size, nnz):
self.skipTest(reason="This test is not passing due to a BUG in PyTorch, which tries to differential CSR indices and results in: RuntimeError: isDifferentiableType(variable.scalar_type())")
A_csr = generate_random_sparse_csr_matrix(size, nnz, device=self.device)
A_d = A_csr.detach().clone().to_dense()

A_csr.requires_grad_(True)
A_d.requires_grad_(True)

A_csr_block_diag = sparse_block_diag(*A_csr)
A_d_block_diag = torch.block_diag(*A_d)

A_csr_block_diag.sum().backward()
A_d_block_diag.sum().backward()

nz_mask = A_csr.grad.to_dense() != 0.0

self.assertTrue(torch.allclose(A_csr.grad.to_dense()[nz_mask], A_d.grad[nz_mask]))

@parameterized.expand(
[
("1x4x4_12n", torch.Size([1, 4, 4]), 12), # Fails with is_conitous error
("4x4x4_12n", torch.Size([4, 4, 4]), 12), # Fails with differentiably error
("6x8x14_32n", torch.Size([6, 8, 14]), 32),
]
)
def test_sparse_block_diag_csr_backward(self, _, size, nnz):
self.skipTest(reason="This test is not passing due to a BUG in PyTorch, which tries to differential CSR indices and results in: RuntimeError: isDifferentiableType(variable.scalar_type())")
A_csr = generate_random_sparse_csr_matrix(size, nnz, device=self.device)
A_d = A_csr.detach().clone().to_dense()

A_csr.requires_grad_(True)
A_d.requires_grad_(True)

A_csr_block_diag = sparse_block_diag(*A_csr)
A_d_block_diag = torch.block_diag(*A_d)

# generate a sparse CSR tensor of the same sparsity pattern as A_csr_block_diag, but all values are unique:
grad_output = torch.sparse_csr_tensor(crow_indices=A_csr_block_diag.crow_indices(),
col_indices=A_csr_block_diag.col_indices(),
values=torch.arange(A_csr_block_diag._nnz(),
dtype=torch.float),
size=A_csr_block_diag.shape).to(self.device)

# set the gradient manually
A_csr_block_diag.backward(grad_output)
A_d_block_diag.backward(grad_output.to_dense())

nz_mask = A_csr.grad.to_dense() != 0.0

self.assertTrue(torch.allclose(A_csr.grad.to_dense()[nz_mask], A_d.grad[nz_mask]))

def test_no_arguments(self):
with self.assertRaises(ValueError):
sparse_block_diag()

def test_incorrect_tensor_layout(self):
coo_tensor = Mock(spec=torch.Tensor)
coo_tensor.layout = torch.sparse_coo
csr_tensor = Mock(spec=torch.Tensor)
csr_tensor.layout = torch.sparse_csr
with self.assertRaises(ValueError):
sparse_block_diag(coo_tensor, csr_tensor)

def test_incorrect_sparse_dim(self):
coo_tensor = Mock(spec=torch.Tensor)
coo_tensor.layout = torch.sparse_coo
coo_tensor.sparse_dim.return_value = 1
with self.assertRaises(ValueError):
sparse_block_diag(coo_tensor)

def test_incorrect_dense_dim(self):
coo_tensor = Mock(spec=torch.Tensor)
coo_tensor.layout = torch.sparse_coo
coo_tensor.dense_dim.return_value = 1
with self.assertRaises(ValueError):
sparse_block_diag(coo_tensor)

def test_incorrect_input_type(self):
with self.assertRaises(TypeError):
sparse_block_diag("Not a list or tuple")

def test_incorrect_tensor_type_in_list(self):
tensor1 = torch.randn(5, 5).to_sparse().to(device=self.device)
tensor2 = "Not a tensor"
with self.assertRaises(TypeError):
sparse_block_diag(tensor1, tensor2)

def test_different_shapes_coo(self):
tensor1 = torch.randn(5, 5).to_sparse_coo().to(device=self.device)
tensor2 = torch.randn(3, 3).to_sparse_coo().to(device=self.device)
tensor3 = torch.randn(2, 2).to_sparse_coo().to(device=self.device)
result = sparse_block_diag(tensor1, tensor2, tensor3)
self.assertEqual(result.shape, torch.Size([10, 10]))

def test_different_shapes_csr(self):
tensor1 = torch.randn(5, 5).to_sparse_csr().to(device=self.device)
tensor2 = torch.randn(3, 3).to_sparse_csr().to(device=self.device)
tensor3 = torch.randn(2, 2).to_sparse_csr().to(device=self.device)
result = sparse_block_diag(tensor1, tensor2, tensor3)
self.assertEqual(result.shape, torch.Size([10, 10]))

def test_too_many_tensor_dimensions(self):
tensor1 = torch.randn(5, 5, 5).to_sparse().to(device=self.device)
with self.assertRaises(ValueError):
sparse_block_diag(tensor1)

def test_empty_tensor_coo(self):
tensor1 = torch.sparse_coo_tensor(torch.empty([2, 0]), torch.empty([0])).to(device=self.device)
result = sparse_block_diag(tensor1)
self.assertEqual(result.shape, torch.Size([0, 0]))

# TODO: these commented tests are failing
# def test_empty_tensor_coo_mix(self):
# tensor1 = torch.empty((0, 0)).to_sparse_coo().to(device=self.device)
# tensor2 = torch.randn(3, 3).to_sparse_coo().to(device=self.device)
# result = sparse_block_diag(tensor1, tensor2)
# expected = tensor2.to_dense()
# self.assertTrue(torch.equal(result.to_dense(), expected))

# def test_empty_tensor_csr(self):
# tensor1 = torch.sparse_csr_tensor(torch.empty([2, 0]), torch.empty([0])).to(device=self.device)
# result = sparse_block_diag(tensor1)
# self.assertEqual(result.shape, torch.Size([0, 0]))

# def test_empty_tensor_csr_mix(self):
# tensor1 = torch.empty((0, 0)).to_sparse_csr().to(device=self.device)
# tensor2 = torch.randn(3, 3).to_sparse_csr().to(device=self.device)
# result = sparse_block_diag(tensor1, tensor2)
# expected = tensor2.to_dense()
# self.assertTrue(torch.equal(result.to_dense(), expected))

def test_single_tensor_coo(self):
tensor1 = torch.randn(5, 5).to_sparse_coo().to(device=self.device)
result = sparse_block_diag(tensor1)
self.assertTrue(torch.equal(result.to_dense(), tensor1.to_dense()))

def test_single_tensor_csr(self):
tensor1 = torch.randn(5, 5).to_sparse_csr().to(device=self.device)
result = sparse_block_diag(tensor1)
self.assertTrue(torch.equal(result.to_dense(), tensor1.to_dense()))

def test_zero_tensor_coo(self):
tensor1 = torch.zeros(5, 5).to_sparse().to(device=self.device)
tensor2 = torch.zeros(3, 3).to_sparse().to(device=self.device)
result = sparse_block_diag(tensor1, tensor2)
self.assertEqual(result.shape, torch.Size([8, 8]))
self.assertTrue((result.to_dense() == 0).all())

def test_zero_tensor_csr(self):
tensor1 = torch.zeros(5, 5).to_sparse_csr().to(device=self.device)
tensor2 = torch.zeros(3, 3).to_sparse_csr().to(device=self.device)
result = sparse_block_diag(tensor1, tensor2)
expected = torch.zeros(8, 8).to(device=self.device)
self.assertTrue(torch.equal(result.to_dense(), expected))

def test_non_square_tensor_coo(self):
tensor1 = torch.randn(5, 7).to_sparse().to(device=self.device)
tensor2 = torch.randn(3, 2).to_sparse().to(device=self.device)
result = sparse_block_diag(tensor1, tensor2)
self.assertEqual(result.shape, torch.Size([8, 9]))

def test_non_square_tensor_csr(self):
tensor1 = torch.randn(5, 7).to_sparse_csr().to(device=self.device)
tensor2 = torch.randn(3, 2).to_sparse_csr().to(device=self.device)
result = sparse_block_diag(tensor1, tensor2)
self.assertEqual(result.shape, torch.Size([8, 9]))


@parameterized_class(
(
"name",
"device",
),
[
("CPU", torch.device("cpu")),
(
"CUDA",
torch.device("cuda"),
),
],
)
class TestSparseBlockDiaSplit(unittest.TestCase):
def setUp(self) -> None:
if not torch.cuda.is_available() and self.device == torch.device("cuda"):
self.skipTest(f"Skipping {self.__class__.__name__} since CUDA is not available")


@parameterized.expand(
[
("1x4x4_12n", torch.Size([1, 4, 4]), 12),
("4x4x4_12n", torch.Size([4, 4, 4]), 12),
("6x8x14_32n", torch.Size([6, 8, 14]), 32),
]
)
def test_coo(self, _, shape, nnz):
A_coo = generate_random_sparse_coo_matrix(shape, nnz, device=self.device)
A_coo_block_diag = sparse_block_diag(*A_coo)
shapes = shape[0]*(shape[-2:],)
A_coo_block_diag_split = sparse_block_diag_split(A_coo_block_diag, *shapes)
for i, A in enumerate(A_coo):
self.assertTrue(torch.equal(A.to_dense(), A_coo_block_diag_split[i].to_dense()))


@parameterized.expand(
[
("1x4x4_12n", torch.Size([1, 4, 4]), 12),
("4x4x4_12n", torch.Size([4, 4, 4]), 12),
("6x8x14_32n", torch.Size([6, 8, 14]), 32),
]
)
def test_csr(self, _, shape, nnz):
A_csr = generate_random_sparse_csr_matrix(shape, nnz, device=self.device)
A_csr_block_diag = sparse_block_diag(*A_csr)
shapes = shape[0]*(shape[-2:],)
A_csr_block_diag_split = sparse_block_diag_split(A_csr_block_diag, *shapes)
for i, A in enumerate(A_csr):
self.assertTrue(torch.equal(A.to_dense(), A_csr_block_diag_split[i].to_dense()))
4 changes: 3 additions & 1 deletion torchsparsegradutils/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,7 @@
from .minres import minres, MINRESSettings
from .bicgstab import bicgstab, BICGSTABSettings
from .lsmr import lsmr
from .utils import convert_coo_to_csr_indices_values, convert_coo_to_csr, sparse_block_diag, sparse_block_diag_split, stack_csr
from .random_sparse import rand_sparse, rand_sparse_tri

__all__ = ["linear_cg", "LinearCGSettings", "minres", "MINRESSettings", "bicgstab", "BICGSTABSettings", "lsmr"]
__all__ = ["linear_cg", "LinearCGSettings", "minres", "MINRESSettings", "bicgstab", "BICGSTABSettings", "lsmr", "convert_coo_to_csr_indices_values", "convert_coo_to_csr", "sparse_block_diag", "rand_sparse", "rand_sparse_tri", "sparse_block_diag_split", "stack_csr"]
Loading

0 comments on commit 67cb9cb

Please sign in to comment.