Skip to content

Commit

Permalink
add check invariants and fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
theo-barfoot committed Jun 6, 2023
1 parent b221088 commit 49ffb59
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 35 deletions.
29 changes: 11 additions & 18 deletions tests/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
generate_random_sparse_strictly_triangular_csr_matrix,
)

# https://pytorch.org/docs/stable/generated/torch.sparse.check_sparse_tensor_invariants.html#torch.sparse.check_sparse_tensor_invariants
torch.sparse.check_sparse_tensor_invariants.enable()

@parameterized_class(
(
Expand Down Expand Up @@ -48,7 +50,7 @@ def test_too_many_nnz(self):
]
)
def test_incompatible_indices_dtype(self, _, indices_dtype):
with self.assertWarns(UserWarning):
with self.assertRaises(ValueError):
generate_random_sparse_coo_matrix(torch.Size([4, 4]), 12, indices_dtype=indices_dtype)

# basic properties:
Expand All @@ -58,15 +60,12 @@ def test_device(self):

@parameterized.expand(
[
("int08", torch.int8),
("int16", torch.int16),
("int32", torch.int32),
("int64", torch.int64),
]
) # NOTE: only torch.int64 is supported for COO indices, all other dtypes are casted to torch.int64
def test_indices_dtype(self, _, indices_dtype):
A = generate_random_sparse_coo_matrix(torch.Size([4, 4]), 12, indices_dtype=indices_dtype, device=self.device)
self.assertEqual(A.indices().dtype, torch.int64)
self.assertEqual(A.indices().dtype, indices_dtype)

@parameterized.expand(
[
Expand All @@ -79,6 +78,7 @@ def test_values_dtype(self, _, values_dtype):
A = generate_random_sparse_coo_matrix(torch.Size([4, 4]), 12, values_dtype=values_dtype, device=self.device)
self.assertEqual(A.values().dtype, values_dtype)


@parameterized.expand(
[
("4x4", torch.Size([4, 4]), 12),
Expand Down Expand Up @@ -143,7 +143,7 @@ def test_too_many_nnz(self):
]
)
def test_incompatible_indices_dtype(self, _, indices_dtype):
with self.assertWarns(UserWarning):
with self.assertRaises(ValueError):
generate_random_sparse_coo_matrix(torch.Size([4, 4]), 12, indices_dtype=indices_dtype)

# basic properties:
Expand All @@ -153,12 +153,10 @@ def test_device(self):

@parameterized.expand(
[
("int08", torch.int8),
("int16", torch.int16),
("int32", torch.int32),
("int64", torch.int64),
]
) # NOTE: All integer dtypes are supported for CSR indices, although below int32 is not recommended
) # NOTE: Only int32 and int64 are supported for CSR indices
def test_indices_dtype(self, _, indices_dtype):
A = generate_random_sparse_csr_matrix(torch.Size([4, 4]), 12, indices_dtype=indices_dtype, device=self.device)
self.assertEqual(A.crow_indices().dtype, indices_dtype)
Expand Down Expand Up @@ -241,7 +239,7 @@ def test_too_many_nnz(self):
]
)
def test_incompatible_indices_dtype(self, _, indices_dtype):
with self.assertWarns(UserWarning):
with self.assertRaises(ValueError):
generate_random_sparse_strictly_triangular_coo_matrix(torch.Size([4, 4]), 5, indices_dtype=indices_dtype)

# basic properties:
Expand All @@ -251,17 +249,14 @@ def test_device(self):

@parameterized.expand(
[
("int08", torch.int8),
("int16", torch.int16),
("int32", torch.int32),
("int64", torch.int64),
]
) # NOTE: only torch.int64 is supported for COO indices, all other dtypes are casted to torch.int64
def test_indices_dtype(self, _, indices_dtype):
A = generate_random_sparse_strictly_triangular_coo_matrix(
torch.Size([4, 4]), 5, indices_dtype=indices_dtype, device=self.device
)
self.assertEqual(A.indices().dtype, torch.int64)
self.assertEqual(A.indices().dtype, indices_dtype)

@parameterized.expand(
[
Expand Down Expand Up @@ -381,7 +376,7 @@ def test_too_many_nnz(self):
]
)
def test_incompatible_indices_dtype(self, _, indices_dtype):
with self.assertWarns(UserWarning):
with self.assertRaises(ValueError):
generate_random_sparse_strictly_triangular_csr_matrix(torch.Size([4, 4]), 5, indices_dtype=indices_dtype)

# basic properties:
Expand All @@ -391,12 +386,10 @@ def test_device(self):

@parameterized.expand(
[
("int08", torch.int8),
("int16", torch.int16),
("int32", torch.int32),
("int64", torch.int64),
]
) # NOTE: only torch.int64 is supported for COO indices, all other dtypes are casted to torch.int64
)
def test_indices_dtype(self, _, indices_dtype):
A = generate_random_sparse_strictly_triangular_csr_matrix(
torch.Size([4, 4]), 5, indices_dtype=indices_dtype, device=self.device
Expand Down
2 changes: 2 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
convert_coo_to_csr,
)

# https://pytorch.org/docs/stable/generated/torch.sparse.check_sparse_tensor_invariants.html#torch.sparse.check_sparse_tensor_invariants
torch.sparse.check_sparse_tensor_invariants.enable()

@parameterized_class(
(
Expand Down
26 changes: 10 additions & 16 deletions torchsparsegradutils/utils/random_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def generate_random_sparse_coo_matrix(
ValueError: Raised if size has less than 2 dimensions.
ValueError: Raised if size has more than 3 dimensions, as this implementation only supports 1 batch dimension.
ValueError: Raised if nnz is greater than the total number of elements (size[-2] * size[-3]).
UserWarning: Raised when `indices_dtype` is not `torch.int64`, as this is the only indices dtype supported for sparse COO tensors. Any other index dtype will be converted to `torch.int64`.
ValueError: Raised if indices_dtype is not torch.int64 for sparse COO tensors.
Returns:
torch.Tensor: Returns a sparse COO tensor of shape size with nnz non-zero elements.
Expand All @@ -73,10 +73,7 @@ def generate_random_sparse_coo_matrix(
raise ValueError("nnz must be less than or equal to nr * nc")

if indices_dtype != torch.int64:
warnings.warn(
f"Only indices of type torch.int64 supported for sparse COO tensors. Indices of type {indices_dtype} will be cast to torch.int64.",
UserWarning,
)
raise ValueError("indices_dtype must be torch.int64 for sparse COO tensors")

if len(size) == 2:
coo_indices = _gen_indices_2d_coo(size[-2], size[-1], nnz, dtype=indices_dtype, device=device)
Expand Down Expand Up @@ -111,7 +108,7 @@ def generate_random_sparse_csr_matrix(
ValueError: Raised if size has less than 2 dimensions.
ValueError: Raised if size has more than 3 dimensions, as this implementation only supports 1 batch dimension.
ValueError: Raised if nnz is greater than the total number of elements (size[-2] * size[-3]).
UserWarning: Raised when `indices_dtype` has a bit depth less than `torch.int32`, as this is not recommended.
ValueError: Raised if indices_dtype is not torch.int64 or torch.int32, as these are the only indices dtypes supported for sparse CSR tensors.
Returns:
torch.Tensor: Returns a sparse CSR tensor of shape size with nnz non-zero elements.
Expand All @@ -125,7 +122,7 @@ def generate_random_sparse_csr_matrix(
raise ValueError("nnz must be less than or equal to nr * nc")

if (indices_dtype != torch.int64) and (indices_dtype != torch.int32):
warnings.warn(f"A bit depth of less than torch.int32 is not recommended for sparse CSR tensors", UserWarning)
raise ValueError("indices_dtype must be torch.int64 or torch.int32 for sparse CSR tensors")

coo_indices = _gen_indices_2d_coo(size[-2], size[-1], nnz, dtype=indices_dtype, device=device)
crow_indices, col_indices, _ = convert_coo_to_csr_indices_values(coo_indices, size[-2], values=None)
Expand All @@ -135,7 +132,7 @@ def generate_random_sparse_csr_matrix(
else:
crow_indices = crow_indices.repeat(size[0], 1)
col_indices = col_indices.repeat(size[0], 1)
values = torch.rand(nnz * size[0], dtype=values_dtype, device=device)
values = torch.rand((size[0], nnz), dtype=values_dtype, device=device)

return torch.sparse_csr_tensor(crow_indices, col_indices, values, size, device=device)

Expand Down Expand Up @@ -187,7 +184,7 @@ def generate_random_sparse_strictly_triangular_coo_matrix(
ValueError: Raised if size has more than 3 dimensions, as this implementation only supports 1 batch dimension.
ValueError: Raised if size is not a square matrix (n, n) or batched square matrix (b, n, n).
ValueError: Raised if nnz is greater than (n * n-1)/2, where n is the number of rows or columns.
UserWarning: Raised when `indices_dtype` is not `torch.int64`, as this is the only indices dtype supported for sparse COO tensors. Any other index dtype will be converted to `torch.int64`.
ValueError: Raised if indices_dtype is not torch.int64 for sparse COO tensors.
Returns:
torch.Tensor: Returns a square strictly upper or lower sparse COO tensor of shape size with nnz non-zero elements.
Expand All @@ -204,10 +201,7 @@ def generate_random_sparse_strictly_triangular_coo_matrix(
raise ValueError("nnz must be less than or equal to (n * n-1)/2, where n is the number of rows or columns")

if indices_dtype != torch.int64:
warnings.warn(
f"Only indices of type torch.int64 supported for sparse COO tensors. Indices of type {indices_dtype} will be cast to torch.int64.",
UserWarning,
)
raise ValueError("indices_dtype must be torch.int64 for sparse COO tensors")

if len(size) == 2:
coo_indices = _gen_indices_2d_coo_strictly_tri(size[-2], nnz, upper=upper, dtype=indices_dtype, device=device)
Expand Down Expand Up @@ -247,7 +241,7 @@ def generate_random_sparse_strictly_triangular_csr_matrix(
ValueError: Raised if size has more than 3 dimensions, as this implementation only supports 1 batch dimension.
ValueError: Raised if size is not a square matrix (n, n) or batched square matrix (b, n, n).
ValueError: Raised if nnz is greater than (n * n-1)/2, where n is the number of rows or columns.
UserWarning: Raised when `indices_dtype` has a bit depth less than `torch.int32`, as this is not recommended.
ValueError: Raised if indices_dtype is not torch.int64 or torch.int32, as these are the only indices dtypes supported for sparse CSR tensors.
Returns:
torch.Tensor: Returns a square strictly upper or lower sparse CSR tensor of shape size with nnz non-zero elements.
Expand All @@ -264,7 +258,7 @@ def generate_random_sparse_strictly_triangular_csr_matrix(
raise ValueError("nnz must be less than or equal to (n * n-1)/2, where n is the number of rows or columns")

if (indices_dtype != torch.int64) and (indices_dtype != torch.int32):
warnings.warn(f"A bit depth of less than torch.int32 is not recommended for sparse CSR tensors", UserWarning)
raise ValueError("indices_dtype must be torch.int64 or torch.int32 for sparse CSR tensors")

coo_indices = _gen_indices_2d_coo_strictly_tri(size[-2], nnz, upper=upper, dtype=indices_dtype, device=device)
crow_indices, col_indices, _ = convert_coo_to_csr_indices_values(coo_indices, size[-2], values=None)
Expand All @@ -274,6 +268,6 @@ def generate_random_sparse_strictly_triangular_csr_matrix(
else:
crow_indices = crow_indices.repeat(size[0], 1)
col_indices = col_indices.repeat(size[0], 1)
values = torch.rand(nnz * size[0], dtype=values_dtype, device=device)
values = torch.rand((size[0], nnz), dtype=values_dtype, device=device)

return torch.sparse_csr_tensor(crow_indices, col_indices, values, size, device=device)
2 changes: 1 addition & 1 deletion torchsparsegradutils/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def _sort_coo_indices(indices):
torch.Tensor: A permutation tensor that contains the indices in the original tensor that give the sorted tensor.
"""
indices_sorted, permutation = torch.unique(indices, dim=-1, sorted=True, return_inverse=True)
return indices_sorted, torch.argsort(permutation)
return indices_sorted.contiguous(), torch.argsort(permutation)


def _compress_row_indices(row_indices, num_rows):
Expand Down

0 comments on commit 49ffb59

Please sign in to comment.