From cb061410c412a61523beecaa508129fd12e7d126 Mon Sep 17 00:00:00 2001 From: Theo Barfoot Date: Wed, 14 Jun 2023 12:40:29 +0100 Subject: [PATCH] black and flake8 --- .flake8 | 4 +- .github/workflows/python-package.yml | 4 +- .pre-commit-config.yaml | 7 +- tests/test_lsmr.py | 4 +- tests/test_sparse_matmul.py | 110 +++++++++++++------ tests/test_sparse_solve.py | 2 +- tests/test_utils.py | 87 +++++++-------- torchsparsegradutils/jax/jax_sparse_solve.py | 6 +- torchsparsegradutils/sparse_lstsq.py | 6 +- torchsparsegradutils/sparse_matmul.py | 42 +++---- torchsparsegradutils/sparse_solve.py | 8 +- torchsparsegradutils/utils/__init__.py | 25 ++++- torchsparsegradutils/utils/bicgstab.py | 3 +- torchsparsegradutils/utils/lsmr.py | 2 - torchsparsegradutils/utils/random_sparse.py | 45 ++++++-- torchsparsegradutils/utils/utils.py | 75 ++++++------- 16 files changed, 264 insertions(+), 166 deletions(-) diff --git a/.flake8 b/.flake8 index 093ef4a..eba21b3 100644 --- a/.flake8 +++ b/.flake8 @@ -1,5 +1,5 @@ [flake8] - ignore = E203, E266, E501, W503, F403, F401 + ignore = E203, E266, E501, W503, F403, F401, E731, E402 max-line-length = 120 - max-complexity = 10 + max-complexity = 25 select = B,C,E,F,W,T4,B9 \ No newline at end of file diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index a0631a8..96111bb 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -17,6 +17,7 @@ jobs: fail-fast: false matrix: python-version: ["3.8", "3.9", "3.10"] + torch-version: ["1.13.1", "2.0.1"] steps: - uses: actions/checkout@v3 @@ -36,8 +37,7 @@ jobs: - name: Lint with flake8 run: | # stop the build if there are Python syntax errors or undefined names - flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics - flake8 . --count --max-complexity=10 --max-line-length=120 --statistics + flake8 . --count --show-source --statistics - name: Check code formatting with black run: | black --check . diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3e1d5a7..c6b953d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,4 +8,9 @@ repos: - repo: https://github.com/pycqa/flake8 rev: 6.0.0 hooks: - - id: flake8 \ No newline at end of file + - id: flake8 + + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 + hooks: + - id: trailing-whitespace diff --git a/tests/test_lsmr.py b/tests/test_lsmr.py index 7481a95..11f25c5 100644 --- a/tests/test_lsmr.py +++ b/tests/test_lsmr.py @@ -203,7 +203,7 @@ def setUp(self): def test_unchanged_x0(self): # x, istop, itn, normr, normar, normA, condA, normx = self.returnValuesX0 - x = self.returnValuesX0[0] + # x = self.returnValuesX0[0] # variable unused self.assertTrue(torch.allclose(self.x00, self.x0, atol=1e-3, rtol=1e-4)) def testNormr(self): @@ -283,7 +283,7 @@ def lsmrtest(m, n, damp, dtype, device): btol = 1.0e-7 conlim = 1.0e10 itnlim = 10 * n - show = 1 + # show = 1 # variable unused # x, istop, itn, normr, normar, norma, conda, normx \ # = lsmr(A, b, damp, atol, btol, conlim, itnlim, show) diff --git a/tests/test_sparse_matmul.py b/tests/test_sparse_matmul.py index b7ec41e..4466c65 100644 --- a/tests/test_sparse_matmul.py +++ b/tests/test_sparse_matmul.py @@ -1,7 +1,7 @@ import torch import pytest -from torchsparsegradutils import sparse_mm#, sparse_bmm +from torchsparsegradutils import sparse_mm # , sparse_bmm from torchsparsegradutils.utils import rand_sparse, rand_sparse_tri # Identify Testing Parameters @@ -11,93 +11,102 @@ TEST_DATA = [ # name A_shape, B_shape, A_nnz - ("unbat", (4, 6), (6, 2), 8), # unbatched - ("unbat", (8, 16), (16, 10), 32), # - - ("unbat", (7, 4), (4, 9), 14), # - - - ("bat", (1, 4, 6), (1, 6, 2), 8), # batched + ("unbat", (4, 6), (6, 2), 8), # unbatched + ("unbat", (8, 16), (16, 10), 32), # - + ("unbat", (7, 4), (4, 9), 14), # - + ("bat", (1, 4, 6), (1, 6, 2), 8), # batched ("bat", (4, 8, 16), (4, 16, 10), 32), # - - ("bat", (11, 7, 4), (11, 4, 9), 14), # - - + ("bat", (11, 7, 4), (11, 4, 9), 14), # - ] INDEX_DTYPES = [torch.int32, torch.int64] VALUE_DTYPES = [torch.float32, torch.float64] -ATOL=1e-6 # relaxed tolerance to allow for float32 -RTOL=1e-4 +ATOL = 1e-6 # relaxed tolerance to allow for float32 +RTOL = 1e-4 + # Define Test Names: def data_id(shapes): return shapes[0] + def device_id(device): return str(device) + def dtype_id(dtype): - return str(dtype).split('.')[-1] + return str(dtype).split(".")[-1] + # Define Fixtures + @pytest.fixture(params=TEST_DATA, ids=[data_id(d) for d in TEST_DATA]) def shapes(request): return request.param + @pytest.fixture(params=VALUE_DTYPES, ids=[dtype_id(d) for d in VALUE_DTYPES]) def value_dtype(request): return request.param + @pytest.fixture(params=INDEX_DTYPES, ids=[dtype_id(d) for d in INDEX_DTYPES]) def index_dtype(request): return request.param + @pytest.fixture(params=DEVICES, ids=[device_id(d) for d in DEVICES]) def device(request): return request.param + # Define Tests + def forward_routine(op_test, op_ref, layout, device, value_dtype, index_dtype, shapes): if index_dtype == torch.int32 and layout is torch.sparse_coo: pytest.skip("Skipping test as sparse COO tensors with int32 indices are not supported") - + _, A_shape, B_shape, A_nnz = shapes A = rand_sparse(A_shape, A_nnz, layout, indices_dtype=index_dtype, values_dtype=value_dtype, device=device) B = torch.rand(*B_shape, dtype=value_dtype, device=device) Ad = A.to_dense() - + res_sparse = op_test(A, B) # both results are dense res_dense = op_ref(Ad, B) - + torch.allclose(res_sparse, res_dense, atol=ATOL, rtol=RTOL) + def backward_routine(op_test, op_ref, layout, device, value_dtype, index_dtype, shapes, is_backward=False): if index_dtype == torch.int32 and layout is torch.sparse_coo: pytest.skip("Skipping test as sparse COO tensors with int32 indices are not supported") - + _, A_shape, B_shape, A_nnz = shapes As1 = rand_sparse(A_shape, A_nnz, layout, indices_dtype=index_dtype, values_dtype=value_dtype, device=device) Ad2 = As1.detach().clone().to_dense() # detach and clone to create seperate graph - + Bd1 = torch.rand(*B_shape, dtype=value_dtype, device=device) Bd2 = Bd1.detach().clone() - + As1.requires_grad_() Ad2.requires_grad_() Bd1.requires_grad_() Bd2.requires_grad_() - + res1 = op_test(As1, Bd1) # both results are dense res2 = op_ref(Ad2, Bd2) - + # Generate random gradients for the backward pass grad_output = torch.rand_like(res1, dtype=value_dtype, device=device) - + res1.backward(grad_output) res2.backward(grad_output) - + nz_mask = As1.grad.to_dense() != 0.0 - + assert torch.allclose(As1.grad.to_dense()[nz_mask], Ad2.grad[nz_mask], atol=ATOL, rtol=RTOL) assert torch.allclose(Bd1.grad, Bd2.grad, atol=ATOL, rtol=RTOL) @@ -105,35 +114,74 @@ def backward_routine(op_test, op_ref, layout, device, value_dtype, index_dtype, def test_sparse_mm_forward_result_coo(device, value_dtype, index_dtype, shapes): forward_routine(sparse_mm, torch.matmul, torch.sparse_coo, device, value_dtype, index_dtype, shapes) + def test_sparse_mm_forward_result_csr(device, value_dtype, index_dtype, shapes): forward_routine(sparse_mm, torch.matmul, torch.sparse_csr, device, value_dtype, index_dtype, shapes) + def test_sparse_mm_backward_result_coo(device, value_dtype, index_dtype, shapes): - backward_routine(sparse_mm, torch.matmul, torch.sparse_coo, device, value_dtype, index_dtype, shapes, is_backward=True) + backward_routine( + sparse_mm, torch.matmul, torch.sparse_coo, device, value_dtype, index_dtype, shapes, is_backward=True + ) + def test_sparse_mm_backward_result_csr(device, value_dtype, index_dtype, shapes): - backward_routine(sparse_mm, torch.matmul, torch.sparse_csr, device, value_dtype, index_dtype, shapes, is_backward=True) + backward_routine( + sparse_mm, torch.matmul, torch.sparse_csr, device, value_dtype, index_dtype, shapes, is_backward=True + ) # Additional Testing Parameters BAD_TEST_DATA = [ # name, A, B, expected_error, error_msg - ("bad_tensor", 5, torch.rand(6, 2), ValueError, "Both A and B should be instances of torch.Tensor"), - ("bad_dim_A", torch.tensor([0,1]).to_sparse(), torch.rand(6, 2), ValueError, "Both A and B should be at least 2-dimensional tensors"), - ("bad_dim_B", torch.rand(4, 6).to_sparse(), torch.rand(6), ValueError, "Both A and B should be at least 2-dimensional tensors"), - ("bad_dim_mismatch", torch.rand(4, 6).to_sparse(), torch.rand(1, 6, 2), ValueError, "Both A and B should have the same number of dimensions"), - ("bad_format", torch.rand(4, 6).to_sparse_csc(), torch.rand(6, 2), ValueError, "A should be in either COO or CSR format"), - ("bad_batch", torch.stack([torch.rand(4, 6).to_sparse(), torch.rand(4, 6).to_sparse()]), torch.rand(1, 6, 2), ValueError, "If A and B have a leading batch dimension, they should have the same batch size"), + ("bad_tensor", 5, torch.rand(6, 2), ValueError, "Both A and B should be instances of torch.Tensor"), + ( + "bad_dim_A", + torch.tensor([0, 1]).to_sparse(), + torch.rand(6, 2), + ValueError, + "Both A and B should be at least 2-dimensional tensors", + ), + ( + "bad_dim_B", + torch.rand(4, 6).to_sparse(), + torch.rand(6), + ValueError, + "Both A and B should be at least 2-dimensional tensors", + ), + ( + "bad_dim_mismatch", + torch.rand(4, 6).to_sparse(), + torch.rand(1, 6, 2), + ValueError, + "Both A and B should have the same number of dimensions", + ), + ( + "bad_format", + torch.rand(4, 6).to_sparse_csc(), + torch.rand(6, 2), + ValueError, + "A should be in either COO or CSR format", + ), + ( + "bad_batch", + torch.stack([torch.rand(4, 6).to_sparse(), torch.rand(4, 6).to_sparse()]), + torch.rand(1, 6, 2), + ValueError, + "If A and B have a leading batch dimension, they should have the same batch size", + ), ] + # Additional Fixture @pytest.fixture(params=BAD_TEST_DATA, ids=[data_id(d) for d in BAD_TEST_DATA]) def bad_inputs(request): return request.param + # Additional Test def test_sparse_mm_error(bad_inputs): _, A, B, expected_error, error_msg = bad_inputs with pytest.raises(expected_error) as e: sparse_mm(A, B) - assert str(e.value) == error_msg \ No newline at end of file + assert str(e.value) == error_msg diff --git a/tests/test_sparse_solve.py b/tests/test_sparse_solve.py index 7578edc..4017766 100644 --- a/tests/test_sparse_solve.py +++ b/tests/test_sparse_solve.py @@ -203,7 +203,7 @@ def test_solver_gradient_csr_tril(self): def test_solver_non_triangular_error(self): """Test to check that solver throws an ValueError if a diagonal is specified in input but unitriangular=True in solver arguments""" - if self.unitriangular == False: # statement to stop test running in SparseUnitTriangularSolveTest + if self.unitriangular is False: # statement to stop test running in SparseUnitTriangularSolveTest As1 = self.As_csr_triu.detach().clone() As1.requires_grad = True Bd1 = self.Bd.detach().clone() diff --git a/tests/test_utils.py b/tests/test_utils.py index 3f3a50a..982c724 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -39,8 +39,7 @@ class TestStackCSR(unittest.TestCase): def setUp(self) -> None: if not torch.cuda.is_available() and self.device == torch.device("cuda"): self.skipTest(f"Skipping {self.__class__.__name__} since CUDA is not available") - - + @parameterized.expand( [ ("4x4_12n_d0", torch.Size([4, 4]), 12, 0), @@ -56,6 +55,7 @@ def test_stack_csr(self, _, size, nnz, dim): dense_stacked = torch.stack(dense_list) self.assertTrue(torch.equal(csr_stacked.to_dense(), dense_stacked)) + @parameterized_class( ( "name", @@ -234,6 +234,7 @@ def test_compress_row_indices(self, _, size, nnz): row_indices = _demcompress_crow_indices(A_csr.crow_indices(), A_coo.size()[0]) self.assertTrue(torch.equal(A_coo_row_indices, row_indices)) + @parameterized_class( ( "name", @@ -251,10 +252,9 @@ class TestSparseBlockDiag(unittest.TestCase): def setUp(self) -> None: if not torch.cuda.is_available() and self.device == torch.device("cuda"): self.skipTest(f"Skipping {self.__class__.__name__} since CUDA is not available") - - + @parameterized.expand( - [ + [ ("1x4x4_12n", torch.Size([1, 4, 4]), 12), ("4x4x4_12n", torch.Size([4, 4, 4]), 12), ("6x8x14_32n", torch.Size([6, 8, 14]), 32), @@ -266,8 +266,7 @@ def test_sparse_block_diag_coo(self, _, size, nnz): A_coo_block_diag = sparse_block_diag(*A_coo) Ad_block_diag = torch.block_diag(*A_d) self.assertTrue(torch.equal(A_coo_block_diag.to_dense(), Ad_block_diag)) - - + @parameterized.expand( [ ("1x4x4_12n", torch.Size([1, 4, 4]), 12), @@ -278,20 +277,19 @@ def test_sparse_block_diag_coo(self, _, size, nnz): def test_sparse_block_diag_coo_backward(self, _, size, nnz): A_coo = generate_random_sparse_coo_matrix(size, nnz, device=self.device) A_d = A_coo.detach().clone().to_dense() - + A_coo.requires_grad_(True) A_d.requires_grad_(True) - + A_coo_block_diag = sparse_block_diag(*A_coo) A_d_block_diag = torch.block_diag(*A_d) - + A_coo_block_diag.sum().backward() A_d_block_diag.sum().backward() - + nz_mask = A_coo.grad.to_dense() != 0.0 - - self.assertTrue(torch.allclose(A_coo.grad.to_dense()[nz_mask], A_d.grad[nz_mask])) + self.assertTrue(torch.allclose(A_coo.grad.to_dense()[nz_mask], A_d.grad[nz_mask])) @parameterized.expand( [ @@ -306,7 +304,7 @@ def test_sparse_block_diag_csr(self, _, size, nnz): A_csr_block_diag = sparse_block_diag(*A_csr) Ad_block_diag = torch.block_diag(*A_d) self.assertTrue(torch.equal(A_csr_block_diag.to_dense(), Ad_block_diag)) - + @parameterized.expand( [ ("1x4x4_12n", torch.Size([1, 4, 4]), 12), # passes for this case @@ -315,52 +313,57 @@ def test_sparse_block_diag_csr(self, _, size, nnz): ] ) def test_sparse_block_diag_csr_backward_dense_grad(self, _, size, nnz): - self.skipTest(reason="This test is not passing due to a BUG in PyTorch, which tries to differential CSR indices and results in: RuntimeError: isDifferentiableType(variable.scalar_type())") + self.skipTest( + reason="This test is not passing due to a BUG in PyTorch, which tries to differential CSR indices and results in: RuntimeError: isDifferentiableType(variable.scalar_type())" + ) A_csr = generate_random_sparse_csr_matrix(size, nnz, device=self.device) A_d = A_csr.detach().clone().to_dense() - + A_csr.requires_grad_(True) A_d.requires_grad_(True) - + A_csr_block_diag = sparse_block_diag(*A_csr) A_d_block_diag = torch.block_diag(*A_d) - + A_csr_block_diag.sum().backward() A_d_block_diag.sum().backward() - + nz_mask = A_csr.grad.to_dense() != 0.0 - + self.assertTrue(torch.allclose(A_csr.grad.to_dense()[nz_mask], A_d.grad[nz_mask])) - + @parameterized.expand( [ - ("1x4x4_12n", torch.Size([1, 4, 4]), 12), # Fails with is_conitous error - ("4x4x4_12n", torch.Size([4, 4, 4]), 12), # Fails with differentiably error + ("1x4x4_12n", torch.Size([1, 4, 4]), 12), # Fails with is_conitous error + ("4x4x4_12n", torch.Size([4, 4, 4]), 12), # Fails with differentiably error ("6x8x14_32n", torch.Size([6, 8, 14]), 32), ] ) def test_sparse_block_diag_csr_backward(self, _, size, nnz): - self.skipTest(reason="This test is not passing due to a BUG in PyTorch, which tries to differential CSR indices and results in: RuntimeError: isDifferentiableType(variable.scalar_type())") + self.skipTest( + reason="This test is not passing due to a BUG in PyTorch, which tries to differential CSR indices and results in: RuntimeError: isDifferentiableType(variable.scalar_type())" + ) A_csr = generate_random_sparse_csr_matrix(size, nnz, device=self.device) A_d = A_csr.detach().clone().to_dense() - + A_csr.requires_grad_(True) A_d.requires_grad_(True) - + A_csr_block_diag = sparse_block_diag(*A_csr) A_d_block_diag = torch.block_diag(*A_d) - + # generate a sparse CSR tensor of the same sparsity pattern as A_csr_block_diag, but all values are unique: - grad_output = torch.sparse_csr_tensor(crow_indices=A_csr_block_diag.crow_indices(), - col_indices=A_csr_block_diag.col_indices(), - values=torch.arange(A_csr_block_diag._nnz(), - dtype=torch.float), - size=A_csr_block_diag.shape).to(self.device) + grad_output = torch.sparse_csr_tensor( + crow_indices=A_csr_block_diag.crow_indices(), + col_indices=A_csr_block_diag.col_indices(), + values=torch.arange(A_csr_block_diag._nnz(), dtype=torch.float), + size=A_csr_block_diag.shape, + ).to(self.device) # set the gradient manually A_csr_block_diag.backward(grad_output) A_d_block_diag.backward(grad_output.to_dense()) - + nz_mask = A_csr.grad.to_dense() != 0.0 self.assertTrue(torch.allclose(A_csr.grad.to_dense()[nz_mask], A_d.grad[nz_mask])) @@ -480,8 +483,8 @@ def test_non_square_tensor_csr(self): tensor2 = torch.randn(3, 2).to_sparse_csr().to(device=self.device) result = sparse_block_diag(tensor1, tensor2) self.assertEqual(result.shape, torch.Size([8, 9])) - - + + @parameterized_class( ( "name", @@ -499,10 +502,9 @@ class TestSparseBlockDiaSplit(unittest.TestCase): def setUp(self) -> None: if not torch.cuda.is_available() and self.device == torch.device("cuda"): self.skipTest(f"Skipping {self.__class__.__name__} since CUDA is not available") - - + @parameterized.expand( - [ + [ ("1x4x4_12n", torch.Size([1, 4, 4]), 12), ("4x4x4_12n", torch.Size([4, 4, 4]), 12), ("6x8x14_32n", torch.Size([6, 8, 14]), 32), @@ -511,14 +513,13 @@ def setUp(self) -> None: def test_coo(self, _, shape, nnz): A_coo = generate_random_sparse_coo_matrix(shape, nnz, device=self.device) A_coo_block_diag = sparse_block_diag(*A_coo) - shapes = shape[0]*(shape[-2:],) + shapes = shape[0] * (shape[-2:],) A_coo_block_diag_split = sparse_block_diag_split(A_coo_block_diag, *shapes) for i, A in enumerate(A_coo): self.assertTrue(torch.equal(A.to_dense(), A_coo_block_diag_split[i].to_dense())) - - + @parameterized.expand( - [ + [ ("1x4x4_12n", torch.Size([1, 4, 4]), 12), ("4x4x4_12n", torch.Size([4, 4, 4]), 12), ("6x8x14_32n", torch.Size([6, 8, 14]), 32), @@ -527,7 +528,7 @@ def test_coo(self, _, shape, nnz): def test_csr(self, _, shape, nnz): A_csr = generate_random_sparse_csr_matrix(shape, nnz, device=self.device) A_csr_block_diag = sparse_block_diag(*A_csr) - shapes = shape[0]*(shape[-2:],) + shapes = shape[0] * (shape[-2:],) A_csr_block_diag_split = sparse_block_diag_split(A_csr_block_diag, *shapes) for i, A in enumerate(A_csr): self.assertTrue(torch.equal(A.to_dense(), A_csr_block_diag_split[i].to_dense())) diff --git a/torchsparsegradutils/jax/jax_sparse_solve.py b/torchsparsegradutils/jax/jax_sparse_solve.py index 4371110..1920294 100644 --- a/torchsparsegradutils/jax/jax_sparse_solve.py +++ b/torchsparsegradutils/jax/jax_sparse_solve.py @@ -11,11 +11,11 @@ def sparse_solve_j4t(A, B, solve=None, transpose_solve=None): - if solve == None or transpose_solve == None: + if solve is None or transpose_solve is None: # Use bicgstab by default - if solve == None: + if solve is None: solve = jax.scipy.sparse.linalg.bicgstab - if transpose_solve == None: + if transpose_solve is None: transpose_solve = lambda A, B: jax.scipy.sparse.linalg.bicgstab(A.transpose(), B) return SparseSolveJ4T.apply(A, B, solve, transpose_solve) diff --git a/torchsparsegradutils/sparse_lstsq.py b/torchsparsegradutils/sparse_lstsq.py index ed8db53..1bf5cf1 100644 --- a/torchsparsegradutils/sparse_lstsq.py +++ b/torchsparsegradutils/sparse_lstsq.py @@ -2,12 +2,12 @@ def sparse_generic_lstsq(A, B, lstsq=None, transpose_lstsq=None): - if lstsq == None or transpose_lstsq == None: + if lstsq is None or transpose_lstsq is None: from .utils import lsmr - if lstsq == None: + if lstsq is None: lstsq = lambda AA, BB: lsmr(AA, BB)[0] - if transpose_lstsq == None: + if transpose_lstsq is None: # MINRES assumes A to be symmetric -> no need to transpose A transpose_lstsq = lambda AA, BB: lsmr(torch.adjoint(AA), BB, AA)[0] diff --git a/torchsparsegradutils/sparse_matmul.py b/torchsparsegradutils/sparse_matmul.py index 6d301d9..77b23f2 100644 --- a/torchsparsegradutils/sparse_matmul.py +++ b/torchsparsegradutils/sparse_matmul.py @@ -6,35 +6,35 @@ def sparse_mm(A, B): """ Performs a matrix multiplication between a sparse matrix A and a dense matrix B, preserving the sparsity of the gradient with respect to A, permitting sparse backpropagation. - + The sparse matrix A can be in either COO or CSR format, and is expected - to be 2-dimensional, with an optional leading batch dimension. The dense matrix B - should also be 2-dimensional, with a matching optional leading batch dimension. + to be 2-dimensional, with an optional leading batch dimension. The dense matrix B + should also be 2-dimensional, with a matching optional leading batch dimension. The batch size must be the same for both A and B. - + Args: A (torch.Tensor): The sparse matrix in COO or CSR format. B (torch.Tensor): The dense matrix. - + Returns: torch.Tensor: The result of the matrix multiplication. """ - + if not isinstance(A, torch.Tensor) or not isinstance(B, torch.Tensor): raise ValueError("Both A and B should be instances of torch.Tensor") if A.dim() < 2 or B.dim() < 2: raise ValueError("Both A and B should be at least 2-dimensional tensors") - + if A.dim() != B.dim(): raise ValueError("Both A and B should have the same number of dimensions") - + if A.layout not in {torch.sparse_coo, torch.sparse_csr}: raise ValueError("A should be in either COO or CSR format") - + if A.dim() == 3 and A.size(0) != B.size(0): raise ValueError("If A and B have a leading batch dimension, they should have the same batch size") - + return SparseMatMul.apply(A, B) @@ -54,22 +54,22 @@ def forward(ctx, A, B): ctx.batch_size = B.size()[0] if B.dim() == 3 else None ctx.A_shape = A.size() # (b), n, m ctx.B_shape = B.size() # (b), m, p - + grad_flag = A.requires_grad or B.requires_grad - + A, B = A.detach(), B.detach() - + if ctx.batch_size is not None: A = sparse_block_diag(*A) B = torch.cat([*B]) - + x = torch.sparse.mm(A, B) - + ctx.save_for_backward(A, B) - + if ctx.batch_size is not None: x = x.view(ctx.batch_size, ctx.A_shape[-2], ctx.B_shape[-1]) - + x.requires_grad = grad_flag return x @@ -98,7 +98,7 @@ def backward(ctx, grad): ) else: raise ValueError(f"Unsupported layout: {A.layout}") - + if ctx.batch_size is not None: grad = torch.cat([*grad]) @@ -117,14 +117,14 @@ def backward(ctx, grad): # Now compute the dense gradient with respect to B gradB = torch.sparse.mm(A.t(), grad) - + if ctx.batch_size is not None: - shapes = ctx.A_shape[0]*(ctx.A_shape[-2:],) + shapes = ctx.A_shape[0] * (ctx.A_shape[-2:],) gradA = sparse_block_diag_split(gradA, *shapes) if A.layout == torch.sparse_coo: gradA = torch.stack([*gradA]) else: gradA = stack_csr([*gradA]) # NOTE: torch.stack does not work for csr tensors - + gradB = gradB.view(ctx.B_shape) return gradA, gradB diff --git a/torchsparsegradutils/sparse_solve.py b/torchsparsegradutils/sparse_solve.py index 02c9257..562d8dd 100644 --- a/torchsparsegradutils/sparse_solve.py +++ b/torchsparsegradutils/sparse_solve.py @@ -81,7 +81,7 @@ def backward(ctx, grad): xselect = x.index_select(0, A_col_idx) # x[j, :] if ctx.ut is True and torch.any(A_row_idx == A_col_idx): - raise ValueError(f"First input should be strictly triangular (i.e. unit diagonals is implicit)") + raise ValueError("First input should be strictly triangular (i.e. unit diagonals is implicit)") # Dot product: mgbx = mgradbselect * xselect @@ -96,12 +96,12 @@ def backward(ctx, grad): def sparse_generic_solve(A, B, solve=None, transpose_solve=None): - if solve == None or transpose_solve == None: + if solve is None or transpose_solve is None: from .utils import minres - if solve == None: + if solve is None: solve = minres - if transpose_solve == None: + if transpose_solve is None: # MINRES assumes A to be symmetric -> no need to transpose A transpose_solve = minres diff --git a/torchsparsegradutils/utils/__init__.py b/torchsparsegradutils/utils/__init__.py index 8cf1c63..789c046 100644 --- a/torchsparsegradutils/utils/__init__.py +++ b/torchsparsegradutils/utils/__init__.py @@ -2,7 +2,28 @@ from .minres import minres, MINRESSettings from .bicgstab import bicgstab, BICGSTABSettings from .lsmr import lsmr -from .utils import convert_coo_to_csr_indices_values, convert_coo_to_csr, sparse_block_diag, sparse_block_diag_split, stack_csr +from .utils import ( + convert_coo_to_csr_indices_values, + convert_coo_to_csr, + sparse_block_diag, + sparse_block_diag_split, + stack_csr, +) from .random_sparse import rand_sparse, rand_sparse_tri -__all__ = ["linear_cg", "LinearCGSettings", "minres", "MINRESSettings", "bicgstab", "BICGSTABSettings", "lsmr", "convert_coo_to_csr_indices_values", "convert_coo_to_csr", "sparse_block_diag", "rand_sparse", "rand_sparse_tri", "sparse_block_diag_split", "stack_csr"] \ No newline at end of file +__all__ = [ + "linear_cg", + "LinearCGSettings", + "minres", + "MINRESSettings", + "bicgstab", + "BICGSTABSettings", + "lsmr", + "convert_coo_to_csr_indices_values", + "convert_coo_to_csr", + "sparse_block_diag", + "rand_sparse", + "rand_sparse_tri", + "sparse_block_diag_split", + "stack_csr", +] diff --git a/torchsparsegradutils/utils/bicgstab.py b/torchsparsegradutils/utils/bicgstab.py index e2bde38..395acf3 100644 --- a/torchsparsegradutils/utils/bicgstab.py +++ b/torchsparsegradutils/utils/bicgstab.py @@ -118,7 +118,6 @@ def bicgstab( v = torch.zeros(n, dtype=res_dtype, device=res_device) while not finished: - beta = rho_next / rho * alpha / omega rho = rho_next @@ -181,7 +180,7 @@ def bicgstab( finished = True continue - converged = residNorm <= threshold + # converged = residNorm <= threshold # variable unused bestSolution = x residNorm = residNorm diff --git a/torchsparsegradutils/utils/lsmr.py b/torchsparsegradutils/utils/lsmr.py index b89c17e..979e2b9 100644 --- a/torchsparsegradutils/utils/lsmr.py +++ b/torchsparsegradutils/utils/lsmr.py @@ -199,7 +199,6 @@ def lsmr( # Main iteration loop. for itn in range(1, maxiter + 1): - # Perform the next step of the bidiagonalization to obtain the # next beta, u, alpha, v. These satisfy the relations # beta*u = a*v - alpha*u, @@ -287,7 +286,6 @@ def lsmr( # if itn % 10 == 0: if True: - # Compute norms for convergence testing. torch.abs(zetabar, out=normar) torch.norm(x, out=normx) diff --git a/torchsparsegradutils/utils/random_sparse.py b/torchsparsegradutils/utils/random_sparse.py index 17d9ec9..1617e26 100644 --- a/torchsparsegradutils/utils/random_sparse.py +++ b/torchsparsegradutils/utils/random_sparse.py @@ -2,10 +2,10 @@ utility functions for generating random sparse matrices NOTE: sparse COO tensors have indices tensor of size (ndim, nse) and with indices type torch.int64 -NOTE: Sparse CSR The index tensors crow_indices and col_indices should have element type either torch.int64 (default) or torch.int32. - If you want to use MKL-enabled matrix operations, use torch.int32. +NOTE: Sparse CSR The index tensors crow_indices and col_indices should have element type either torch.int64 (default) or torch.int32. + If you want to use MKL-enabled matrix operations, use torch.int32. This is as a result of the default linking of pytorch being with MKL LP64, which uses 32 bit integer indexing -NOTE: The batches of sparse CSR tensors are dependent: the number of specified elements in all batches must be the same. +NOTE: The batches of sparse CSR tensors are dependent: the number of specified elements in all batches must be the same. This somewhat artificial constraint allows efficient storage of the indices of different CSR batches. """ import warnings @@ -14,23 +14,48 @@ from torchsparsegradutils.utils.utils import convert_coo_to_csr_indices_values -def rand_sparse(size, nnz, layout=torch.sparse_coo, *, indices_dtype=torch.int64, values_dtype=torch.float32, device=torch.device("cpu")): +def rand_sparse( + size, + nnz, + layout=torch.sparse_coo, + *, + indices_dtype=torch.int64, + values_dtype=torch.float32, + device=torch.device("cpu"), +): if layout == torch.sparse_coo: - return generate_random_sparse_coo_matrix(size, nnz, indices_dtype=indices_dtype, values_dtype=values_dtype, device=device) + return generate_random_sparse_coo_matrix( + size, nnz, indices_dtype=indices_dtype, values_dtype=values_dtype, device=device + ) elif layout == torch.sparse_csr: - return generate_random_sparse_csr_matrix(size, nnz, indices_dtype=indices_dtype, values_dtype=values_dtype, device=device) + return generate_random_sparse_csr_matrix( + size, nnz, indices_dtype=indices_dtype, values_dtype=values_dtype, device=device + ) else: raise ValueError("Unsupported layout type. It should be either torch.sparse_coo or torch.sparse_csr") -def rand_sparse_tri(size, nnz, layout=torch.sparse_coo, *, upper=True, indices_dtype=torch.int64, values_dtype=torch.float32, device=torch.device("cpu")): +def rand_sparse_tri( + size, + nnz, + layout=torch.sparse_coo, + *, + upper=True, + indices_dtype=torch.int64, + values_dtype=torch.float32, + device=torch.device("cpu"), +): if layout == torch.sparse_coo: - return generate_random_sparse_strictly_triangular_coo_matrix(size, nnz, upper=upper, indices_dtype=indices_dtype, values_dtype=values_dtype, device=device) + return generate_random_sparse_strictly_triangular_coo_matrix( + size, nnz, upper=upper, indices_dtype=indices_dtype, values_dtype=values_dtype, device=device + ) elif layout == torch.sparse_csr: - return generate_random_sparse_strictly_triangular_csr_matrix(size, nnz, upper=upper, indices_dtype=indices_dtype, values_dtype=values_dtype, device=device) + return generate_random_sparse_strictly_triangular_csr_matrix( + size, nnz, upper=upper, indices_dtype=indices_dtype, values_dtype=values_dtype, device=device + ) else: raise ValueError("Unsupported layout type. It should be either torch.sparse_coo or torch.sparse_csr") - + def _gen_indices_2d_coo(nr, nc, nnz, *, dtype=torch.int64, device=torch.device("cpu")): """Generates nnz random unique coordinates in COO format. diff --git a/torchsparsegradutils/utils/utils.py b/torchsparsegradutils/utils/utils.py index d3f374e..d9921fb 100644 --- a/torchsparsegradutils/utils/utils.py +++ b/torchsparsegradutils/utils/utils.py @@ -1,5 +1,6 @@ import torch + def stack_csr(tensors, dim=0): """ Stacks a list of CSR tensors along the batch dimension. @@ -14,7 +15,7 @@ def stack_csr(tensors, dim=0): """ if not isinstance(tensors, (list, tuple)): raise TypeError("Expected a list of tensors, but got {}.".format(type(tensors))) - + if len(tensors) == 0: raise ValueError("Cannot stack empty list of tensors.") @@ -26,7 +27,7 @@ def stack_csr(tensors, dim=0): if not all([tensor.ndim == 2 for tensor in tensors]): raise ValueError("All tensors must be 2D.") - + crow_indices = torch.stack([tensor.crow_indices() for tensor in tensors], dim=dim) col_indices = torch.stack([tensor.col_indices() for tensor in tensors], dim=dim) values = torch.stack([tensor.values() for tensor in tensors], dim=dim) @@ -109,7 +110,7 @@ def convert_coo_to_csr_indices_values(coo_indices, num_rows, values=None): f"Row indices must be less than num_rows ({num_rows}). Got max row index {coo_indices[-2].max()}" ) - if values != None and values.shape[0] != coo_indices.shape[1]: + if values is not None and values.shape[0] != coo_indices.shape[1]: raise ValueError( f"Number of values ({values.shape[0]}) does not match number of indices ({coo_indices.shape[1]})" ) @@ -187,12 +188,12 @@ def sparse_block_diag(*sparse_tensors): """ Function to create a block diagonal sparse matrix from provided sparse tensors. This function is designed to replicate torch.block_diag(), but for sparse tensors, - but only supports 2D sparse tensors + but only supports 2D sparse tensors (whereas torch.block_diag() supports dense tensors of 0, 1 or 2 dimensions). Args: - *sparse_tensors (torch.Tensor): Variable length list of sparse tensors. All input tensors must either be all - sparse_coo or all sparse_csr format. The input sparse tensors must have exactly + *sparse_tensors (torch.Tensor): Variable length list of sparse tensors. All input tensors must either be all + sparse_coo or all sparse_csr format. The input sparse tensors must have exactly two sparse dimensions and no dense dimensions. Returns: @@ -206,42 +207,42 @@ def sparse_block_diag(*sparse_tensors): for i, sparse_tensor in enumerate(sparse_tensors): if not isinstance(sparse_tensor, torch.Tensor): - raise TypeError(f"TypeError: expected Tensor as element {i} in argument 0, but got {type(sparse_tensor).__name__}") + raise TypeError( + f"TypeError: expected Tensor as element {i} in argument 0, but got {type(sparse_tensor).__name__}" + ) if len(sparse_tensors) == 0: raise ValueError("At least one sparse tensor must be provided.") - + if all(sparse_tensor.layout == torch.sparse_coo for sparse_tensor in sparse_tensors): layout = torch.sparse_coo elif all(sparse_tensor.layout == torch.sparse_csr for sparse_tensor in sparse_tensors): layout = torch.sparse_csr else: raise ValueError("Sparse tensors must either be all sparse_coo or all sparse_csr.") - + if not all(sparse_tensor.sparse_dim() == 2 for sparse_tensor in sparse_tensors): raise ValueError("All sparse tensors must have two sparse dimensions.") - + if not all(sparse_tensor.dense_dim() == 0 for sparse_tensor in sparse_tensors): raise ValueError("All sparse tensors must have zero dense dimensions.") if len(sparse_tensors) == 1: return sparse_tensors[0] - + if layout == torch.sparse_coo: - row_indices_list = [] col_indices_list = [] values_list = [] num_row = 0 num_col = 0 - - for i, sparse_tensor in enumerate(sparse_tensors): - + + for i, sparse_tensor in enumerate(sparse_tensors): sparse_tensor = sparse_tensor.coalesce() if not sparse_tensor.is_coalesced() else sparse_tensor - + row_indices, col_indices = sparse_tensor.indices() - + # calculate block offsets # not in-place addition to avoid modifying the original tensor indices row_indices = row_indices + i * sparse_tensor.size()[-2] @@ -256,27 +257,24 @@ def sparse_block_diag(*sparse_tensors): num_row += sparse_tensor.size()[-2] num_col += sparse_tensor.size()[-1] - row_indices = torch.cat(row_indices_list) col_indices = torch.cat(col_indices_list) values = torch.cat(values_list) return torch.sparse_coo_tensor(torch.stack([row_indices, col_indices]), values, torch.Size([num_row, num_col])) - + elif layout == torch.sparse_csr: - crow_indices_list = [] col_indices_list = [] values_list = [] num_row = 0 num_col = 0 - + for i, sparse_tensor in enumerate(sparse_tensors): - crow_indices = sparse_tensor.crow_indices() col_indices = sparse_tensor.col_indices() - + # Calculate block offsets # not in-place addition to avoid modifying the original tensor indices if i > 0: @@ -298,13 +296,13 @@ def sparse_block_diag(*sparse_tensors): values = torch.cat(values_list) return torch.sparse_csr_tensor(crow_indices, col_indices, values, torch.Size([num_row, num_col])) - - + + def sparse_block_diag_split(sparse_block_diag_tensor, *shapes): """ Function to split a block diagonal sparse matrix into original sparse tensors. - - NOTE: Sparse COO tensors are assumed to already by coalesced. + + NOTE: Sparse COO tensors are assumed to already by coalesced. This is because newly created or indexed sparse COO tensors default to is_coalesced=False, and running coalesce() imposes an unnecessary performance penalty. @@ -325,18 +323,22 @@ def sparse_block_diag_split(sparse_block_diag_tensor, *shapes): layout = torch.sparse_csr else: raise ValueError("Input tensor format not supported. Only sparse_coo and sparse_csr are supported.") - + if not all(len(shape) == 2 for shape in shapes): raise ValueError("All shapes must be two-dimensional.") - + if layout == torch.sparse_coo: tensors = [] start_row = 0 start_col = 0 current_val_offset = 0 - sparse_block_diag_tensor = sparse_block_diag_tensor.coalesce() if not sparse_block_diag_tensor.is_coalesced() else sparse_block_diag_tensor - + sparse_block_diag_tensor = ( + sparse_block_diag_tensor.coalesce() + if not sparse_block_diag_tensor.is_coalesced() + else sparse_block_diag_tensor + ) + row_indices, col_indices = sparse_block_diag_tensor.indices() values = sparse_block_diag_tensor.values() @@ -364,7 +366,7 @@ def sparse_block_diag_split(sparse_block_diag_tensor, *shapes): start_row = 0 current_col_offset = 0 current_val_offset = 0 - + crow_indices = sparse_block_diag_tensor.crow_indices() col_indices = sparse_block_diag_tensor.col_indices() values = sparse_block_diag_tensor.values() @@ -377,14 +379,13 @@ def sparse_block_diag_split(sparse_block_diag_tensor, *shapes): # Find the starting and ending points in crow_indices start_ptr = crow_indices[start_row] - end_ptr = crow_indices[start_row + shape[0]] # Apply the pointers to get the sub-block indices and values - col_indices_sub = col_indices[current_val_offset:current_val_offset+values_count] - current_col_offset - values_sub = values[current_val_offset:current_val_offset+values_count] + col_indices_sub = col_indices[current_val_offset : current_val_offset + values_count] - current_col_offset + values_sub = values[current_val_offset : current_val_offset + values_count] # Create the sub-block crow_indices - crow_indices_sub = crow_indices[start_row:start_row + shape[0] + 1] - start_ptr + crow_indices_sub = crow_indices[start_row : start_row + shape[0] + 1] - start_ptr # Construct the sub-block as a CSR tensor tensor_sub = torch.sparse_csr_tensor(crow_indices_sub, col_indices_sub, values_sub, (rows, cols)) @@ -394,4 +395,4 @@ def sparse_block_diag_split(sparse_block_diag_tensor, *shapes): current_col_offset += cols current_val_offset += values_count - return tuple(tensors) \ No newline at end of file + return tuple(tensors)