switch to pytest
theo-barfoot committed Jun 13, 2023
1 parent a34f2ac commit a9ab949
Showing 2 changed files with 140 additions and 109 deletions.
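This commit migrates the matmul tests from unittest classes to pytest fixtures parametrized over device, dtypes, and shapes. Assuming the repository's standard layout, the updated suite can be run with a plain pytest invocation, for example:

    pip install -r requirements-dev.txt
    pytest tests/test_sparse_matmul.py -v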
3 changes: 2 additions & 1 deletion requirements-dev.txt
@@ -2,4 +2,5 @@ torch>=1.13.0
numpy
scipy
jax[cpu]
parameterized
parameterized
pytest
246 changes: 138 additions & 108 deletions tests/test_sparse_matmul.py
@@ -1,109 +1,139 @@
import torch
import unittest
from random import randrange
from torchsparsegradutils import sparse_mm


def gencoordinates(nr, nc, ni, device="cuda"):
"""Used to genererate ni random unique coordinates for sparse matrix with size [nr, nc]"""
coordinates = set()
while True:
r, c = randrange(nr), randrange(nc)
coordinates.add((r, c))
if len(coordinates) == ni:
return torch.stack([torch.tensor(co) for co in coordinates], dim=-1).to(device)


class SparseMatMulTest(unittest.TestCase):
"""Test Sparse x Dense matrix multiplication with back propagation for COO and CSR matrices"""

def setUp(self) -> None:
        # The device can be overridden by a subclass
if not hasattr(self, "device"):
self.device = torch.device("cpu")
self.A_shape = (8, 16)
self.B_shape = (self.A_shape[1], 10)
self.A_nnz = 32
self.A_idx = gencoordinates(*self.A_shape, self.A_nnz, device=self.device)
self.A_val = torch.randn(self.A_nnz, dtype=torch.float64, device=self.device)
self.As_coo = torch.sparse_coo_tensor(self.A_idx, self.A_val, self.A_shape, requires_grad=True).coalesce()
self.As_csr = self.As_coo.to_sparse_csr()
self.Ad = self.As_coo.to_dense()

self.Bd = torch.randn(*self.B_shape, dtype=torch.float64, requires_grad=True, device=self.device)
self.matmul = sparse_mm

def test_matmul_forward_coo(self):
x = self.matmul(self.As_coo, self.Bd)
self.assertIsInstance(x, torch.Tensor)

def test_matmul_forward_csr(self):
x = self.matmul(self.As_csr, self.Bd)
self.assertIsInstance(x, torch.Tensor)

def test_matmul_gradient_coo(self):
# Sparse matmul:
As1 = self.As_coo.detach().clone()
As1.requires_grad = True
Bd1 = self.Bd.detach().clone()
Bd1.requires_grad = True
As1.retain_grad()
Bd1.retain_grad()
x = self.matmul(As1, Bd1)
loss = x.sum()
loss.backward()

# torch dense matmul:
Ad2 = self.Ad.detach().clone()
Ad2.requires_grad = True
Bd2 = self.Bd.detach().clone()
Bd2.requires_grad = True
Ad2.retain_grad()
Bd2.retain_grad()
x_torch = Ad2 @ Bd2
loss_torch = x_torch.sum()
loss_torch.backward()

nz_mask = As1.grad.to_dense() != 0.0
self.assertTrue(torch.isclose(As1.grad.to_dense()[nz_mask], Ad2.grad[nz_mask]).all())
self.assertTrue(torch.isclose(Bd1.grad, Bd2.grad).all())

def test_matmul_gradient_csr(self):
        # Sparse matmul:
As1 = self.As_csr.detach().clone()
As1.requires_grad = True
Bd1 = self.Bd.detach().clone()
Bd1.requires_grad = True
As1.retain_grad()
Bd1.retain_grad()
x = self.matmul(As1, Bd1)
loss = x.sum()
loss.backward()

        # torch dense matmul:
Ad2 = self.Ad.detach().clone()
Ad2.requires_grad = True
Bd2 = self.Bd.detach().clone()
Bd2.requires_grad = True
Ad2.retain_grad()
Bd2.retain_grad()
x_torch = Ad2 @ Bd2
loss_torch = x_torch.sum()
loss_torch.backward()
nz_mask = As1.grad.to_dense() != 0.0
self.assertTrue(torch.isclose(As1.grad.to_dense()[nz_mask], Ad2.grad[nz_mask]).all())
self.assertTrue(torch.isclose(Bd1.grad, Bd2.grad).all())


class SparseMatMulTestCUDA(SparseMatMulTest):
"""Override superclass setUp to run on GPU"""

def setUp(self) -> None:
if not torch.cuda.is_available():
self.skipTest(f"Skipping {self.__class__.__name__} since CUDA is not available")
self.device = torch.device("cuda")
super().setUp()


if __name__ == "__main__":
unittest.main()
import pytest

from torchsparsegradutils import sparse_mm  # , sparse_bmm
from torchsparsegradutils.utils import rand_sparse, rand_sparse_tri

# Identify Testing Parameters
DEVICES = [torch.device("cpu")]
if torch.cuda.is_available():
DEVICES.append(torch.device("cuda"))

TEST_DATA = [
    # name,  A_shape,   B_shape,   A_nnz
("unbat", (4, 6), (6, 2), 8), # unbatched
("unbat", (8, 16), (16, 10), 32), # -
("unbat", (7, 4), (4, 9), 14), # -

("bat", (1, 4, 6), (1, 6, 2), 8), # batched
("bat", (4, 8, 16), (4, 16, 10), 32), # -
("bat", (11, 7, 4), (11, 4, 9), 14), # -

]
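# The first field of each tuple serves as the pytest id for the case (see data_id below).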

INDEX_DTYPES = [torch.int32, torch.int64]
VALUE_DTYPES = [torch.float32, torch.float64]

ATOL = 1e-6  # relaxed tolerance to allow for float32
RTOL = 1e-4

# Define Test Names:
def data_id(shapes):
return shapes[0]

def device_id(device):
return str(device)

def dtype_id(dtype):
return str(dtype).split('.')[-1]

# Define Fixtures

@pytest.fixture(params=TEST_DATA, ids=[data_id(d) for d in TEST_DATA])
def shapes(request):
return request.param

@pytest.fixture(params=VALUE_DTYPES, ids=[dtype_id(d) for d in VALUE_DTYPES])
def value_dtype(request):
return request.param

@pytest.fixture(params=INDEX_DTYPES, ids=[dtype_id(d) for d in INDEX_DTYPES])
def index_dtype(request):
return request.param

@pytest.fixture(params=DEVICES, ids=[device_id(d) for d in DEVICES])
def device(request):
return request.param

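# Each test below requests several of these fixtures; pytest expands fixture
# params combinatorially, so one test runs len(TEST_DATA) * len(VALUE_DTYPES)
# * len(INDEX_DTYPES) * len(DEVICES) = 6 * 2 * 2 * 2 = 48 times when CUDA is
# available (24 on CPU-only machines), minus any parametrizations skipped below.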
# Define Tests

def forward_routine(op_test, op_ref, layout, device, value_dtype, index_dtype, shapes):
if index_dtype == torch.int32 and layout is torch.sparse_coo:
pytest.skip("Skipping test as sparse COO tensors with int32 indices are not supported")

_, A_shape, B_shape, A_nnz = shapes
A = rand_sparse(A_shape, A_nnz, layout, indices_dtype=index_dtype, values_dtype=value_dtype, device=device)
B = torch.rand(*B_shape, dtype=value_dtype, device=device)
Ad = A.to_dense()

res_sparse = op_test(A, B) # both results are dense
res_dense = op_ref(Ad, B)

    assert torch.allclose(res_sparse, res_dense, atol=ATOL, rtol=RTOL)

def backward_routine(op_test, op_ref, layout, device, value_dtype, index_dtype, shapes):
if index_dtype == torch.int32 and layout is torch.sparse_coo:
pytest.skip("Skipping test as sparse COO tensors with int32 indices are not supported")

_, A_shape, B_shape, A_nnz = shapes
As1 = rand_sparse(A_shape, A_nnz, layout, indices_dtype=index_dtype, values_dtype=value_dtype, device=device)
    Ad2 = As1.detach().clone().to_dense()  # detach and clone to create a separate graph

Bd1 = torch.rand(*B_shape, dtype=value_dtype, device=device)
Bd2 = Bd1.detach().clone()

As1.requires_grad_()
Ad2.requires_grad_()
Bd1.requires_grad_()
Bd2.requires_grad_()

res1 = op_test(As1, Bd1) # both results are dense
res2 = op_ref(Ad2, Bd2)

# Generate random gradients for the backward pass
grad_output = torch.rand_like(res1, dtype=value_dtype, device=device)

res1.backward(grad_output)
res2.backward(grad_output)

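    # The gradient of a sparse tensor lives only on its sparsity pattern, so the
    # dense reference gradient is compared only where the sparse gradient is nonzero.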
nz_mask = As1.grad.to_dense() != 0.0

assert torch.allclose(As1.grad.to_dense()[nz_mask], Ad2.grad[nz_mask], atol=ATOL, rtol=RTOL)
assert torch.allclose(Bd1.grad, Bd2.grad, atol=ATOL, rtol=RTOL)


def test_sparse_mm_forward_result_coo(device, value_dtype, index_dtype, shapes):
forward_routine(sparse_mm, torch.matmul, torch.sparse_coo, device, value_dtype, index_dtype, shapes)

def test_sparse_mm_forward_result_csr(device, value_dtype, index_dtype, shapes):
forward_routine(sparse_mm, torch.matmul, torch.sparse_csr, device, value_dtype, index_dtype, shapes)

def test_sparse_mm_backward_result_coo(device, value_dtype, index_dtype, shapes):
    backward_routine(sparse_mm, torch.matmul, torch.sparse_coo, device, value_dtype, index_dtype, shapes)

def test_sparse_mm_backward_result_csr(device, value_dtype, index_dtype, shapes):
    backward_routine(sparse_mm, torch.matmul, torch.sparse_csr, device, value_dtype, index_dtype, shapes)


# Additional Testing Parameters
BAD_TEST_DATA = [
# name, A, B, expected_error, error_msg
("bad_tensor", 5, torch.rand(6, 2), ValueError, "Both A and B should be instances of torch.Tensor"),
("bad_dim_A", torch.tensor([0,1]).to_sparse(), torch.rand(6, 2), ValueError, "Both A and B should be at least 2-dimensional tensors"),
("bad_dim_B", torch.rand(4, 6).to_sparse(), torch.rand(6), ValueError, "Both A and B should be at least 2-dimensional tensors"),
("bad_dim_mismatch", torch.rand(4, 6).to_sparse(), torch.rand(1, 6, 2), ValueError, "Both A and B should have the same number of dimensions"),
("bad_format", torch.rand(4, 6).to_sparse_csc(), torch.rand(6, 2), ValueError, "A should be in either COO or CSR format"),
("bad_batch", torch.stack([torch.rand(4, 6).to_sparse(), torch.rand(4, 6).to_sparse()]), torch.rand(1, 6, 2), ValueError, "If A and B have a leading batch dimension, they should have the same batch size"),
]

# Additional Fixture
@pytest.fixture(params=BAD_TEST_DATA, ids=[data_id(d) for d in BAD_TEST_DATA])
def bad_inputs(request):
return request.param

# Additional Test
def test_sparse_mm_error(bad_inputs):
_, A, B, expected_error, error_msg = bad_inputs
with pytest.raises(expected_error) as e:
sparse_mm(A, B)
assert str(e.value) == error_msg
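
# Tip: pytest's -k flag selects tests by name or parametrize id, e.g.
#   pytest tests/test_sparse_matmul.py -k "coo and float64"
# runs only the COO tests with float64 values, while -k bad selects the
# error-handling cases above.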
