From a9ab949a456bea98d78d58e366dd084c6d817ac9 Mon Sep 17 00:00:00 2001 From: Theo Barfoot Date: Tue, 13 Jun 2023 18:47:21 +0100 Subject: [PATCH] switch to pytest --- requirements-dev.txt | 3 +- tests/test_sparse_matmul.py | 246 ++++++++++++++++++++---------------- 2 files changed, 140 insertions(+), 109 deletions(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index 0bf71d9..eb00893 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,4 +2,5 @@ torch>=1.13.0 numpy scipy jax[cpu] -parameterized \ No newline at end of file +parameterized +pytest \ No newline at end of file diff --git a/tests/test_sparse_matmul.py b/tests/test_sparse_matmul.py index 94a58e2..b7ec41e 100644 --- a/tests/test_sparse_matmul.py +++ b/tests/test_sparse_matmul.py @@ -1,109 +1,139 @@ import torch -import unittest -from random import randrange -from torchsparsegradutils import sparse_mm - - -def gencoordinates(nr, nc, ni, device="cuda"): - """Used to genererate ni random unique coordinates for sparse matrix with size [nr, nc]""" - coordinates = set() - while True: - r, c = randrange(nr), randrange(nc) - coordinates.add((r, c)) - if len(coordinates) == ni: - return torch.stack([torch.tensor(co) for co in coordinates], dim=-1).to(device) - - -class SparseMatMulTest(unittest.TestCase): - """Test Sparse x Dense matrix multiplication with back propagation for COO and CSR matrices""" - - def setUp(self) -> None: - # The device can be specialised by a daughter class - if not hasattr(self, "device"): - self.device = torch.device("cpu") - self.A_shape = (8, 16) - self.B_shape = (self.A_shape[1], 10) - self.A_nnz = 32 - self.A_idx = gencoordinates(*self.A_shape, self.A_nnz, device=self.device) - self.A_val = torch.randn(self.A_nnz, dtype=torch.float64, device=self.device) - self.As_coo = torch.sparse_coo_tensor(self.A_idx, self.A_val, self.A_shape, requires_grad=True).coalesce() - self.As_csr = self.As_coo.to_sparse_csr() - self.Ad = self.As_coo.to_dense() - - self.Bd = torch.randn(*self.B_shape, dtype=torch.float64, requires_grad=True, device=self.device) - self.matmul = sparse_mm - - def test_matmul_forward_coo(self): - x = self.matmul(self.As_coo, self.Bd) - self.assertIsInstance(x, torch.Tensor) - - def test_matmul_forward_csr(self): - x = self.matmul(self.As_csr, self.Bd) - self.assertIsInstance(x, torch.Tensor) - - def test_matmul_gradient_coo(self): - # Sparse matmul: - As1 = self.As_coo.detach().clone() - As1.requires_grad = True - Bd1 = self.Bd.detach().clone() - Bd1.requires_grad = True - As1.retain_grad() - Bd1.retain_grad() - x = self.matmul(As1, Bd1) - loss = x.sum() - loss.backward() - - # torch dense matmul: - Ad2 = self.Ad.detach().clone() - Ad2.requires_grad = True - Bd2 = self.Bd.detach().clone() - Bd2.requires_grad = True - Ad2.retain_grad() - Bd2.retain_grad() - x_torch = Ad2 @ Bd2 - loss_torch = x_torch.sum() - loss_torch.backward() - - nz_mask = As1.grad.to_dense() != 0.0 - self.assertTrue(torch.isclose(As1.grad.to_dense()[nz_mask], Ad2.grad[nz_mask]).all()) - self.assertTrue(torch.isclose(Bd1.grad, Bd2.grad).all()) - - def test_matmul_gradient_csr(self): - # Sparse solver: - As1 = self.As_csr.detach().clone() - As1.requires_grad = True - Bd1 = self.Bd.detach().clone() - Bd1.requires_grad = True - As1.retain_grad() - Bd1.retain_grad() - x = self.matmul(As1, Bd1) - loss = x.sum() - loss.backward() - - # torch dense solver: - Ad2 = self.Ad.detach().clone() - Ad2.requires_grad = True - Bd2 = self.Bd.detach().clone() - Bd2.requires_grad = True - Ad2.retain_grad() - Bd2.retain_grad() - x_torch = Ad2 @ Bd2 - loss_torch = x_torch.sum() - loss_torch.backward() - nz_mask = As1.grad.to_dense() != 0.0 - self.assertTrue(torch.isclose(As1.grad.to_dense()[nz_mask], Ad2.grad[nz_mask]).all()) - self.assertTrue(torch.isclose(Bd1.grad, Bd2.grad).all()) - - -class SparseMatMulTestCUDA(SparseMatMulTest): - """Override superclass setUp to run on GPU""" - - def setUp(self) -> None: - if not torch.cuda.is_available(): - self.skipTest(f"Skipping {self.__class__.__name__} since CUDA is not available") - self.device = torch.device("cuda") - super().setUp() - - -if __name__ == "__main__": - unittest.main() +import pytest + +from torchsparsegradutils import sparse_mm#, sparse_bmm +from torchsparsegradutils.utils import rand_sparse, rand_sparse_tri + +# Identify Testing Parameters +DEVICES = [torch.device("cpu")] +if torch.cuda.is_available(): + DEVICES.append(torch.device("cuda")) + +TEST_DATA = [ + # name A_shape, B_shape, A_nnz + ("unbat", (4, 6), (6, 2), 8), # unbatched + ("unbat", (8, 16), (16, 10), 32), # - + ("unbat", (7, 4), (4, 9), 14), # - + + ("bat", (1, 4, 6), (1, 6, 2), 8), # batched + ("bat", (4, 8, 16), (4, 16, 10), 32), # - + ("bat", (11, 7, 4), (11, 4, 9), 14), # - + +] + +INDEX_DTYPES = [torch.int32, torch.int64] +VALUE_DTYPES = [torch.float32, torch.float64] + +ATOL=1e-6 # relaxed tolerance to allow for float32 +RTOL=1e-4 + +# Define Test Names: +def data_id(shapes): + return shapes[0] + +def device_id(device): + return str(device) + +def dtype_id(dtype): + return str(dtype).split('.')[-1] + +# Define Fixtures + +@pytest.fixture(params=TEST_DATA, ids=[data_id(d) for d in TEST_DATA]) +def shapes(request): + return request.param + +@pytest.fixture(params=VALUE_DTYPES, ids=[dtype_id(d) for d in VALUE_DTYPES]) +def value_dtype(request): + return request.param + +@pytest.fixture(params=INDEX_DTYPES, ids=[dtype_id(d) for d in INDEX_DTYPES]) +def index_dtype(request): + return request.param + +@pytest.fixture(params=DEVICES, ids=[device_id(d) for d in DEVICES]) +def device(request): + return request.param + +# Define Tests + +def forward_routine(op_test, op_ref, layout, device, value_dtype, index_dtype, shapes): + if index_dtype == torch.int32 and layout is torch.sparse_coo: + pytest.skip("Skipping test as sparse COO tensors with int32 indices are not supported") + + _, A_shape, B_shape, A_nnz = shapes + A = rand_sparse(A_shape, A_nnz, layout, indices_dtype=index_dtype, values_dtype=value_dtype, device=device) + B = torch.rand(*B_shape, dtype=value_dtype, device=device) + Ad = A.to_dense() + + res_sparse = op_test(A, B) # both results are dense + res_dense = op_ref(Ad, B) + + torch.allclose(res_sparse, res_dense, atol=ATOL, rtol=RTOL) + +def backward_routine(op_test, op_ref, layout, device, value_dtype, index_dtype, shapes, is_backward=False): + if index_dtype == torch.int32 and layout is torch.sparse_coo: + pytest.skip("Skipping test as sparse COO tensors with int32 indices are not supported") + + _, A_shape, B_shape, A_nnz = shapes + As1 = rand_sparse(A_shape, A_nnz, layout, indices_dtype=index_dtype, values_dtype=value_dtype, device=device) + Ad2 = As1.detach().clone().to_dense() # detach and clone to create seperate graph + + Bd1 = torch.rand(*B_shape, dtype=value_dtype, device=device) + Bd2 = Bd1.detach().clone() + + As1.requires_grad_() + Ad2.requires_grad_() + Bd1.requires_grad_() + Bd2.requires_grad_() + + res1 = op_test(As1, Bd1) # both results are dense + res2 = op_ref(Ad2, Bd2) + + # Generate random gradients for the backward pass + grad_output = torch.rand_like(res1, dtype=value_dtype, device=device) + + res1.backward(grad_output) + res2.backward(grad_output) + + nz_mask = As1.grad.to_dense() != 0.0 + + assert torch.allclose(As1.grad.to_dense()[nz_mask], Ad2.grad[nz_mask], atol=ATOL, rtol=RTOL) + assert torch.allclose(Bd1.grad, Bd2.grad, atol=ATOL, rtol=RTOL) + + +def test_sparse_mm_forward_result_coo(device, value_dtype, index_dtype, shapes): + forward_routine(sparse_mm, torch.matmul, torch.sparse_coo, device, value_dtype, index_dtype, shapes) + +def test_sparse_mm_forward_result_csr(device, value_dtype, index_dtype, shapes): + forward_routine(sparse_mm, torch.matmul, torch.sparse_csr, device, value_dtype, index_dtype, shapes) + +def test_sparse_mm_backward_result_coo(device, value_dtype, index_dtype, shapes): + backward_routine(sparse_mm, torch.matmul, torch.sparse_coo, device, value_dtype, index_dtype, shapes, is_backward=True) + +def test_sparse_mm_backward_result_csr(device, value_dtype, index_dtype, shapes): + backward_routine(sparse_mm, torch.matmul, torch.sparse_csr, device, value_dtype, index_dtype, shapes, is_backward=True) + + +# Additional Testing Parameters +BAD_TEST_DATA = [ + # name, A, B, expected_error, error_msg + ("bad_tensor", 5, torch.rand(6, 2), ValueError, "Both A and B should be instances of torch.Tensor"), + ("bad_dim_A", torch.tensor([0,1]).to_sparse(), torch.rand(6, 2), ValueError, "Both A and B should be at least 2-dimensional tensors"), + ("bad_dim_B", torch.rand(4, 6).to_sparse(), torch.rand(6), ValueError, "Both A and B should be at least 2-dimensional tensors"), + ("bad_dim_mismatch", torch.rand(4, 6).to_sparse(), torch.rand(1, 6, 2), ValueError, "Both A and B should have the same number of dimensions"), + ("bad_format", torch.rand(4, 6).to_sparse_csc(), torch.rand(6, 2), ValueError, "A should be in either COO or CSR format"), + ("bad_batch", torch.stack([torch.rand(4, 6).to_sparse(), torch.rand(4, 6).to_sparse()]), torch.rand(1, 6, 2), ValueError, "If A and B have a leading batch dimension, they should have the same batch size"), +] + +# Additional Fixture +@pytest.fixture(params=BAD_TEST_DATA, ids=[data_id(d) for d in BAD_TEST_DATA]) +def bad_inputs(request): + return request.param + +# Additional Test +def test_sparse_mm_error(bad_inputs): + _, A, B, expected_error, error_msg = bad_inputs + with pytest.raises(expected_error) as e: + sparse_mm(A, B) + assert str(e.value) == error_msg \ No newline at end of file