black and flake8
theo-barfoot committed Jun 14, 2023
1 parent ff6ea18 commit cb06141
Showing 16 changed files with 264 additions and 166 deletions.
4 changes: 2 additions & 2 deletions .flake8
@@ -1,5 +1,5 @@
[flake8]
-ignore = E203, E266, E501, W503, F403, F401
+ignore = E203, E266, E501, W503, F403, F401, E731, E402
max-line-length = 120
-max-complexity = 10
+max-complexity = 25
select = B,C,E,F,W,T4,B9
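For reference, the two newly ignored codes are E731 (a lambda assigned to a name) and E402 (a module-level import that is not at the top of the file). A minimal illustration, not from this repository, of code the relaxed configuration now tolerates:

import sys

sys.path.append(".")  # setup code that runs before a later import

import os  # would normally trigger E402 (module-level import not at top of file)

square = lambda x: x * x  # would normally trigger E731 (lambda assigned to a name)

print(square(4), os.name)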
4 changes: 2 additions & 2 deletions .github/workflows/python-package.yml
@@ -17,6 +17,7 @@ jobs:
fail-fast: false
matrix:
python-version: ["3.8", "3.9", "3.10"]
+torch-version: ["1.13.1", "2.0.1"]

steps:
- uses: actions/checkout@v3
@@ -36,8 +37,7 @@ jobs:
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
-flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
-flake8 . --count --max-complexity=10 --max-line-length=120 --statistics
+flake8 . --count --show-source --statistics
- name: Check code formatting with black
run: |
black --check .
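Two things change here. The consolidated flake8 call no longer overrides max-complexity and max-line-length on the command line, so the .flake8 file above becomes the single source of truth. And the new torch-version axis multiplies the build matrix, pairing each Python version with each listed torch release. A rough sketch of how that matrix expands, assuming the install step consumes matrix.torch-version:

from itertools import product

python_versions = ["3.8", "3.9", "3.10"]
torch_versions = ["1.13.1", "2.0.1"]

# GitHub Actions runs one job per combination: 3 x 2 = 6 jobs in total
for py, pt in product(python_versions, torch_versions):
    print(f"job: python={py}, torch={pt}")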
7 changes: 6 additions & 1 deletion .pre-commit-config.yaml
@@ -8,4 +8,9 @@ repos:
- repo: https://github.com/pycqa/flake8
rev: 6.0.0
hooks:
-- id: flake8
+- id: flake8
+
+- repo: https://github.com/pre-commit/pre-commit-hooks
+  rev: v4.4.0
+  hooks:
+  - id: trailing-whitespace
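The delete/re-add pair on the flake8 hook line above is a whitespace-only change, which is exactly what the new trailing-whitespace hook automates: it rewrites staged files so no line ends in spaces or tabs. A rough Python sketch of the hook's effect (the real implementation ships with pre-commit-hooks):

from pathlib import Path


def strip_trailing_whitespace(path: Path) -> bool:
    """Strip trailing spaces/tabs from each line; return True if the file changed."""
    original = path.read_text()
    cleaned = "".join(line.rstrip() + "\n" for line in original.splitlines())
    if cleaned != original:
        path.write_text(cleaned)
        return True
    return False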
4 changes: 2 additions & 2 deletions tests/test_lsmr.py
@@ -203,7 +203,7 @@ def setUp(self):

def test_unchanged_x0(self):
# x, istop, itn, normr, normar, normA, condA, normx = self.returnValuesX0
-x = self.returnValuesX0[0]
+# x = self.returnValuesX0[0]  # variable unused
self.assertTrue(torch.allclose(self.x00, self.x0, atol=1e-3, rtol=1e-4))

def testNormr(self):
@@ -283,7 +283,7 @@ def lsmrtest(m, n, damp, dtype, device):
btol = 1.0e-7
conlim = 1.0e10
itnlim = 10 * n
-show = 1
+# show = 1  # variable unused

# x, istop, itn, normr, normar, norma, conda, normx \
# = lsmr(A, b, damp, atol, btol, conlim, itnlim, show)
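Commenting out the dead assignments silences flake8's F841 (local variable assigned but never used). Equally valid alternatives, sketched here on a hypothetical stand-in for the lsmr return tuple:

def solver_stub():
    # stand-in for a call such as lsmr(...) that returns several values
    return 1.0, 0, 10


x, *_ = solver_stub()  # unpack only what is used
x = solver_stub()[0]  # or index out the single value of interest
result = solver_stub()  # noqa: F841  (or suppress the warning explicitly)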
110 changes: 79 additions & 31 deletions tests/test_sparse_matmul.py
@@ -1,7 +1,7 @@
import torch
import pytest

-from torchsparsegradutils import sparse_mm#, sparse_bmm
+from torchsparsegradutils import sparse_mm  # , sparse_bmm
from torchsparsegradutils.utils import rand_sparse, rand_sparse_tri

# Identify Testing Parameters
@@ -11,129 +11,177 @@

TEST_DATA = [
# name A_shape, B_shape, A_nnz
("unbat", (4, 6), (6, 2), 8), # unbatched
("unbat", (8, 16), (16, 10), 32), # -
("unbat", (7, 4), (4, 9), 14), # -

("bat", (1, 4, 6), (1, 6, 2), 8), # batched
("unbat", (4, 6), (6, 2), 8), # unbatched
("unbat", (8, 16), (16, 10), 32), # -
("unbat", (7, 4), (4, 9), 14), # -
("bat", (1, 4, 6), (1, 6, 2), 8), # batched
("bat", (4, 8, 16), (4, 16, 10), 32), # -
("bat", (11, 7, 4), (11, 4, 9), 14), # -

("bat", (11, 7, 4), (11, 4, 9), 14), # -
]

INDEX_DTYPES = [torch.int32, torch.int64]
VALUE_DTYPES = [torch.float32, torch.float64]

-ATOL=1e-6 # relaxed tolerance to allow for float32
-RTOL=1e-4
+ATOL = 1e-6  # relaxed tolerance to allow for float32
+RTOL = 1e-4


# Define Test Names:
def data_id(shapes):
return shapes[0]


def device_id(device):
return str(device)


def dtype_id(dtype):
-return str(dtype).split('.')[-1]
+return str(dtype).split(".")[-1]


# Define Fixtures


@pytest.fixture(params=TEST_DATA, ids=[data_id(d) for d in TEST_DATA])
def shapes(request):
return request.param


@pytest.fixture(params=VALUE_DTYPES, ids=[dtype_id(d) for d in VALUE_DTYPES])
def value_dtype(request):
return request.param


@pytest.fixture(params=INDEX_DTYPES, ids=[dtype_id(d) for d in INDEX_DTYPES])
def index_dtype(request):
return request.param


@pytest.fixture(params=DEVICES, ids=[device_id(d) for d in DEVICES])
def device(request):
return request.param


# Define Tests


def forward_routine(op_test, op_ref, layout, device, value_dtype, index_dtype, shapes):
if index_dtype == torch.int32 and layout is torch.sparse_coo:
pytest.skip("Skipping test as sparse COO tensors with int32 indices are not supported")

_, A_shape, B_shape, A_nnz = shapes
A = rand_sparse(A_shape, A_nnz, layout, indices_dtype=index_dtype, values_dtype=value_dtype, device=device)
B = torch.rand(*B_shape, dtype=value_dtype, device=device)
Ad = A.to_dense()

res_sparse = op_test(A, B) # both results are dense
res_dense = op_ref(Ad, B)

assert torch.allclose(res_sparse, res_dense, atol=ATOL, rtol=RTOL)


def backward_routine(op_test, op_ref, layout, device, value_dtype, index_dtype, shapes, is_backward=False):
if index_dtype == torch.int32 and layout is torch.sparse_coo:
pytest.skip("Skipping test as sparse COO tensors with int32 indices are not supported")

_, A_shape, B_shape, A_nnz = shapes
As1 = rand_sparse(A_shape, A_nnz, layout, indices_dtype=index_dtype, values_dtype=value_dtype, device=device)
Ad2 = As1.detach().clone().to_dense()  # detach and clone to create separate graph

Bd1 = torch.rand(*B_shape, dtype=value_dtype, device=device)
Bd2 = Bd1.detach().clone()

As1.requires_grad_()
Ad2.requires_grad_()
Bd1.requires_grad_()
Bd2.requires_grad_()

res1 = op_test(As1, Bd1) # both results are dense
res2 = op_ref(Ad2, Bd2)

# Generate random gradients for the backward pass
grad_output = torch.rand_like(res1, dtype=value_dtype, device=device)

res1.backward(grad_output)
res2.backward(grad_output)

nz_mask = As1.grad.to_dense() != 0.0

assert torch.allclose(As1.grad.to_dense()[nz_mask], Ad2.grad[nz_mask], atol=ATOL, rtol=RTOL)
assert torch.allclose(Bd1.grad, Bd2.grad, atol=ATOL, rtol=RTOL)


def test_sparse_mm_forward_result_coo(device, value_dtype, index_dtype, shapes):
forward_routine(sparse_mm, torch.matmul, torch.sparse_coo, device, value_dtype, index_dtype, shapes)


def test_sparse_mm_forward_result_csr(device, value_dtype, index_dtype, shapes):
forward_routine(sparse_mm, torch.matmul, torch.sparse_csr, device, value_dtype, index_dtype, shapes)


def test_sparse_mm_backward_result_coo(device, value_dtype, index_dtype, shapes):
-backward_routine(sparse_mm, torch.matmul, torch.sparse_coo, device, value_dtype, index_dtype, shapes, is_backward=True)
+backward_routine(
+    sparse_mm, torch.matmul, torch.sparse_coo, device, value_dtype, index_dtype, shapes, is_backward=True
+)


def test_sparse_mm_backward_result_csr(device, value_dtype, index_dtype, shapes):
-backward_routine(sparse_mm, torch.matmul, torch.sparse_csr, device, value_dtype, index_dtype, shapes, is_backward=True)
+backward_routine(
+    sparse_mm, torch.matmul, torch.sparse_csr, device, value_dtype, index_dtype, shapes, is_backward=True
+)


# Additional Testing Parameters
BAD_TEST_DATA = [
# name, A, B, expected_error, error_msg
("bad_tensor", 5, torch.rand(6, 2), ValueError, "Both A and B should be instances of torch.Tensor"),
("bad_dim_A", torch.tensor([0,1]).to_sparse(), torch.rand(6, 2), ValueError, "Both A and B should be at least 2-dimensional tensors"),
("bad_dim_B", torch.rand(4, 6).to_sparse(), torch.rand(6), ValueError, "Both A and B should be at least 2-dimensional tensors"),
("bad_dim_mismatch", torch.rand(4, 6).to_sparse(), torch.rand(1, 6, 2), ValueError, "Both A and B should have the same number of dimensions"),
("bad_format", torch.rand(4, 6).to_sparse_csc(), torch.rand(6, 2), ValueError, "A should be in either COO or CSR format"),
("bad_batch", torch.stack([torch.rand(4, 6).to_sparse(), torch.rand(4, 6).to_sparse()]), torch.rand(1, 6, 2), ValueError, "If A and B have a leading batch dimension, they should have the same batch size"),
("bad_tensor", 5, torch.rand(6, 2), ValueError, "Both A and B should be instances of torch.Tensor"),
(
"bad_dim_A",
torch.tensor([0, 1]).to_sparse(),
torch.rand(6, 2),
ValueError,
"Both A and B should be at least 2-dimensional tensors",
),
(
"bad_dim_B",
torch.rand(4, 6).to_sparse(),
torch.rand(6),
ValueError,
"Both A and B should be at least 2-dimensional tensors",
),
(
"bad_dim_mismatch",
torch.rand(4, 6).to_sparse(),
torch.rand(1, 6, 2),
ValueError,
"Both A and B should have the same number of dimensions",
),
(
"bad_format",
torch.rand(4, 6).to_sparse_csc(),
torch.rand(6, 2),
ValueError,
"A should be in either COO or CSR format",
),
(
"bad_batch",
torch.stack([torch.rand(4, 6).to_sparse(), torch.rand(4, 6).to_sparse()]),
torch.rand(1, 6, 2),
ValueError,
"If A and B have a leading batch dimension, they should have the same batch size",
),
]


# Additional Fixture
@pytest.fixture(params=BAD_TEST_DATA, ids=[data_id(d) for d in BAD_TEST_DATA])
def bad_inputs(request):
return request.param


# Additional Test
def test_sparse_mm_error(bad_inputs):
_, A, B, expected_error, error_msg = bad_inputs
with pytest.raises(expected_error) as e:
sparse_mm(A, B)
-assert str(e.value) == error_msg
+assert str(e.value) == error_msg
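To make the behaviour these tests pin down concrete, here is a minimal usage sketch of the function under test, assuming only what the diff itself shows: sparse A in COO or CSR format, dense B, a dense result, and a gradient for A restricted to its sparsity pattern (the reason backward_routine compares through nz_mask):

import torch

from torchsparsegradutils import sparse_mm

A = torch.rand(4, 6).to_sparse()  # sparse COO operand
A.requires_grad_()
B = torch.rand(6, 2, requires_grad=True)  # dense operand

out = sparse_mm(A, B)  # dense (4, 2) result
out.sum().backward()

print(A.grad.to_dense())  # nonzero only at A's stored positions
print(B.grad.shape)  # torch.Size([6, 2])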
2 changes: 1 addition & 1 deletion tests/test_sparse_solve.py
@@ -203,7 +203,7 @@ def test_solver_gradient_csr_tril(self):
def test_solver_non_triangular_error(self):
"""Test to check that solver throws an ValueError if a diagonal is specified in input
but unitriangular=True in solver arguments"""
-if self.unitriangular == False: # statement to stop test running in SparseUnitTriangularSolveTest
+if self.unitriangular is False:  # statement to stop test running in SparseUnitTriangularSolveTest
As1 = self.As_csr_triu.detach().clone()
As1.requires_grad = True
Bd1 = self.Bd.detach().clone()
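A note on the "is False" fix: it satisfies flake8's E712 (avoid comparing to False with ==), but it is stricter than boolean negation, which matters if the attribute could ever hold a falsy non-bool. An illustration with made-up values:

for unitriangular in (False, 0, None):
    print(unitriangular is False, not unitriangular)
# False -> True, True
# 0     -> False, True
# None  -> False, True

If self.unitriangular is always a genuine bool the two spellings are equivalent, and "if not self.unitriangular:" would be the more idiomatic guard.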