Merge pull request #42 from cai4cai/tb_distributions
Sparse Multivariate Normal Distribution
theo-barfoot authored Jun 22, 2023
2 parents 0f0b2c8 + b1f7809 commit 5137ad6
Showing 14 changed files with 863 additions and 53 deletions.
5 changes: 3 additions & 2 deletions README.md
@@ -2,12 +2,13 @@
 A collection of utility functions to work with PyTorch sparse tensors. This is work-in-progress, here be dragons.

 Currently available features with backprop include:
-- Memory efficient sparse mm (workaround for https://github.com/pytorch/pytorch/issues/41128)
-- Sparse triangular solver (see discussion in https://github.com/pytorch/pytorch/issues/87358)
+- Memory efficient sparse mm with batch support (workaround for https://github.com/pytorch/pytorch/issues/41128)
+- Sparse triangular solver with batch support (see discussion in https://github.com/pytorch/pytorch/issues/87358)
 - Generic sparse linear solver (requires a non-differentiable backbone sparse solver)
 - Generic sparse linear least-squares solver (requires a non-differentiable backbone sparse linear least-squares solver)
 - Wrappers around [cupy sparse solvers](https://docs.cupy.dev/en/stable/reference/scipy_sparse_linalg.html#solving-linear-problems) (see discussion in https://github.com/pytorch/pytorch/issues/69538)
 - Wrappers around [jax sparse solvers](https://jax.readthedocs.io/en/latest/jax.scipy.html#module-jax.scipy.sparse.linalg)
+- Sparse multivariate normal distribution with sparse covariance and precision parameterisation, with reparameterised sampling (rsample)

 Additional backbone solvers implemented in pytorch with no additional dependencies include:
 - BICGSTAB (ported from [pykrylov](https://github.com/PythonOptimizers/pykrylov))
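For orientation, here is a minimal usage sketch of the new distribution, assembled from the calls made in `tests/test_distributions.py` below. The constructor arguments (`loc`, a diagonal, and either `scale_tril=` or `precision_tril=`) are inferred from that test code rather than from documented API, so treat this as a sketch:

```python
import torch
from torchsparsegradutils.distributions import SparseMultivariateNormal
from torchsparsegradutils.utils import rand_sparse_tri

n = 4  # event size

loc = torch.randn(n)  # mean
diag = torch.rand(n)  # positive diagonal D of the LDL^T factorisation

# Strictly lower-triangular sparse factor; the unit diagonal of L is implicit
# (the tests below add torch.eye before forming the reference covariance).
tril = rand_sparse_tri((n, n), 4, torch.sparse_csr, upper=False, strict=True)

# Covariance parameterisation: cov = L @ diag(D) @ L.T
dist = SparseMultivariateNormal(loc, diag, scale_tril=tril)

samples = dist.rsample((1000,))  # reparameterised sampling, so gradients can flow
print(samples.shape)             # torch.Size([1000, 4])
```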
1 change: 1 addition & 0 deletions requirements-dev.txt
@@ -4,3 +4,4 @@ scipy
 jax[cpu]
 parameterized
 pytest
+pytest-rerunfailures
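(`pytest-rerunfailures` supplies the `@pytest.mark.flaky(reruns=...)` marker used by the new distribution tests below.)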
2 changes: 1 addition & 1 deletion setup.py
@@ -8,7 +8,7 @@ def readme():

 setuptools.setup(
     name="torchsparsegradutils",
-    version="0.0.1",
+    version="0.0.2",
     description="A collection of utility functions to work with PyTorch sparse tensors",
     long_description=readme(),
     long_description_content_type="text/markdown",
247 changes: 247 additions & 0 deletions tests/test_distributions.py
@@ -0,0 +1,247 @@
import torch
import pytest

from torch.distributions.multivariate_normal import _batch_mv
from torchsparsegradutils.distributions.sparse_multivariate_normal import _batch_sparse_mv
from torchsparsegradutils.distributions import SparseMultivariateNormal
from torchsparsegradutils.utils import rand_sparse_tri
from torchsparsegradutils import sparse_mm, sparse_triangular_solve

# Identify Testing Parameters
DEVICES = [torch.device("cpu")]
if torch.cuda.is_available():
    DEVICES.append(torch.device("cuda"))

TEST_DATA = [
    # name, batch size, event size, sparsity
    ("unbat", None, 4, 0.5),
    ("bat", 4, 4, 0.5),
    ("bat2", 4, 64, 0.01),
]

INDEX_DTYPES = [torch.int32, torch.int64]
VALUE_DTYPES = [torch.float32, torch.float64]
SPARSE_LAYOUTS = [torch.sparse_coo, torch.sparse_csr]

# DISTRIBUTIONS = [SparseMultivariateNormal]


# Define Test Names:
def data_id(sizes):
    return sizes[0]


def device_id(device):
    return str(device)


def dtype_id(dtype):
    return str(dtype).split(".")[-1]


def layout_id(layout):
    return str(layout).split(".")[-1].split("_")[-1].upper()


# def dist_id(dist):
#     return dist.__name__


# Define Fixtures


@pytest.fixture(params=TEST_DATA, ids=[data_id(d) for d in TEST_DATA])
def sizes(request):
    return request.param


@pytest.fixture(params=VALUE_DTYPES, ids=[dtype_id(d) for d in VALUE_DTYPES])
def value_dtype(request):
    return request.param


@pytest.fixture(params=INDEX_DTYPES, ids=[dtype_id(d) for d in INDEX_DTYPES])
def index_dtype(request):
    return request.param


@pytest.fixture(params=DEVICES, ids=[device_id(d) for d in DEVICES])
def device(request):
    return request.param


@pytest.fixture(params=SPARSE_LAYOUTS, ids=[layout_id(lay) for lay in SPARSE_LAYOUTS])
def layout(request):
    return request.param


# @pytest.fixture(params=DISTRIBUTIONS, ids=[dist_id(d) for d in DISTRIBUTIONS])
# def distribution(request):
#     return request.param


# Convenience functions:


def construct_distribution(sizes, layout, var, value_dtype, index_dtype, device, requires_grad=False):
    _, batch_size, event_size, sparsity = sizes
    loc = torch.randn(event_size, device=device, dtype=value_dtype, requires_grad=requires_grad)
    diag = torch.rand(event_size, device=device, dtype=value_dtype, requires_grad=requires_grad)

    tril_size = (batch_size, event_size, event_size) if batch_size else (event_size, event_size)
    nnz = int(sparsity * event_size * (event_size + 1) / 2)

    tril = rand_sparse_tri(
        tril_size,
        nnz,
        layout,
        upper=False,
        strict=True,
        indices_dtype=index_dtype,
        values_dtype=value_dtype,
        device=device,
    )

    tril.requires_grad = requires_grad

    if var == "cov":
        return SparseMultivariateNormal(loc, diag, scale_tril=tril)
    elif var == "prec":
        return SparseMultivariateNormal(loc, diag, precision_tril=tril)
    else:
        raise ValueError(f"var must be one of 'cov' or 'prec', but got {var}")


def check_covariance_within_tolerance(
    covariance_test, covariance_ref, absolute_tolerance=None, relative_tolerance=None, desired_threshold=0
):
    # Calculate absolute difference if absolute_tolerance is provided
    if absolute_tolerance is not None:
        abs_diff = torch.abs(covariance_test - covariance_ref)
        abs_within_tolerance = abs_diff <= absolute_tolerance
    else:
        abs_diff = None
        abs_within_tolerance = torch.ones_like(covariance_test, dtype=torch.bool)

    # Calculate relative difference if relative_tolerance is provided
    if relative_tolerance is not None:
        if abs_diff is None:
            abs_diff = torch.abs(covariance_test - covariance_ref)
        rel_diff = abs_diff / torch.abs(covariance_ref)
        rel_within_tolerance = rel_diff <= relative_tolerance
    else:
        rel_within_tolerance = torch.ones_like(covariance_test, dtype=torch.bool)

    # Determine values within both absolute and relative tolerance
    within_tolerance = abs_within_tolerance & rel_within_tolerance

    # Calculate percentage within tolerance
    percentage_within_tolerance = (within_tolerance.sum() / within_tolerance.numel()) * 100

    # Check if the percentage meets the desired threshold
    is_within_desired_threshold = percentage_within_tolerance >= desired_threshold

    print(f"Percentage within tolerance: {percentage_within_tolerance.item():.2f}%")
    print(f"Is within desired threshold? {is_within_desired_threshold}")

    assert is_within_desired_threshold


# Define Tests


@pytest.mark.flaky(reruns=5)  # probably not needed, but seems sensible for CI
def test_rsample_forward_cov(device, layout, sizes, value_dtype, index_dtype):
    if layout == torch.sparse_coo and index_dtype == torch.int32:
        pytest.skip("Sparse COO with int32 indices is not supported")

    dist = construct_distribution(sizes, layout, "cov", value_dtype, index_dtype, device)
    samples = dist.rsample((100000,))

    scale_tril = dist.scale_tril.to_dense()
    scale_tril = scale_tril + torch.eye(*dist.event_shape, dtype=scale_tril.dtype, device=scale_tril.device)  # L matrix
    diagonal = dist.diagonal  # D matrix
    # Compute covariance from LDL^T decomposition
    covariance_ref = torch.matmul(scale_tril @ torch.diag_embed(diagonal), scale_tril.transpose(-1, -2))

    assert torch.allclose(samples.mean(0), dist.loc, atol=0.1)

    if len(samples.shape) == 2:
        covariance_test = torch.cov(samples.T)
    else:
        covariance_test = torch.stack([torch.cov(sample.T) for sample in samples.permute(1, 0, 2)])

    check_covariance_within_tolerance(covariance_test, covariance_ref, absolute_tolerance=0.1, desired_threshold=99.0)


@pytest.mark.flaky(reruns=5)
# NOTE: This test often fails, hence the flaky marker and reruns
def test_rsample_forward_prec(device, layout, sizes, value_dtype, index_dtype):
    if layout == torch.sparse_coo and index_dtype == torch.int32:
        pytest.skip("Sparse COO with int32 indices is not supported")

    dist = construct_distribution(sizes, layout, "prec", value_dtype, index_dtype, device)
    samples = dist.rsample((100000,))

    precision_tril = dist.precision_tril.to_dense()
    precision_tril = precision_tril + torch.eye(
        *dist.event_shape, dtype=precision_tril.dtype, device=precision_tril.device
    )  # L matrix
    diagonal = dist.diagonal  # D matrix
    # Compute precision matrix from LDL^T decomposition
    precision_ref = torch.matmul(precision_tril @ torch.diag_embed(diagonal), precision_tril.transpose(-1, -2))
    # Compute covariance from precision
    covariance_ref = torch.linalg.inv(precision_ref)

    assert torch.allclose(samples.mean(0), dist.loc, atol=0.1)

    if len(samples.shape) == 2:
        covariance_test = torch.cov(samples.T)
    else:
        covariance_test = torch.stack([torch.cov(sample.T) for sample in samples.permute(1, 0, 2)])

    # NOTE: higher atol due to larger covariance values after inversion
    # NOTE: lower threshold due to numerical instability of inversion
    check_covariance_within_tolerance(covariance_test, covariance_ref, absolute_tolerance=1, desired_threshold=95.0)


def test_rsample_backward_cov(device, layout, sizes, value_dtype, index_dtype):
    if layout == torch.sparse_coo and index_dtype == torch.int32:
        pytest.skip("Sparse COO with int32 indices is not supported")

    dist = construct_distribution(sizes, layout, "cov", value_dtype, index_dtype, device, requires_grad=True)
    samples = dist.rsample((10,))

    samples.sum().backward()


def test_rsample_backward_prec(device, layout, sizes, value_dtype, index_dtype):
    if layout == torch.sparse_coo and index_dtype == torch.int32:
        pytest.skip("Sparse COO with int32 indices is not supported")

    dist = construct_distribution(sizes, layout, "prec", value_dtype, index_dtype, device, requires_grad=True)
    samples = dist.rsample((10,))

    samples.sum().backward()


BATCH_MV_TEST_DATA = [
    # 3 = event_size, 4 = batch_size, 5 = sample_size
    # bmat, bvec
    (torch.randn(3, 3), torch.randn(3)),
    (torch.randn(3, 3), torch.randn(5, 3)),
    (torch.randn(4, 3, 3), torch.randn(4, 3)),
    (torch.randn(4, 3, 3), torch.randn(5, 4, 3)),
]


@pytest.fixture(params=BATCH_MV_TEST_DATA)
def batch_mv_test_data(request):
    return request.param


def test_sparse_batch_mv(batch_mv_test_data):
    bmat, bvec = batch_mv_test_data
    res_ref = _batch_mv(bmat, bvec)
    res_test = _batch_sparse_mv(sparse_mm, bmat.to_sparse(), bvec)
    assert torch.allclose(res_ref, res_test)
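A reader's note on the parameterisation these tests check (my reading of the reference computations above, since the distribution's implementation is not shown in this diff): both variants store a strictly lower-triangular sparse factor plus a separate positive diagonal, combined as a unit-diagonal LDL^T product. With $L$ unit lower triangular and $D = \operatorname{diag}(d)$:

```latex
% scale_tril parameterisation: reference covariance in test_rsample_forward_cov
\Sigma = L D L^{\top}
% precision_tril parameterisation: reference in test_rsample_forward_prec
\Sigma = \left( L D L^{\top} \right)^{-1}
% A plausible reparameterised sample (an assumption, not shown in this diff),
% with \epsilon \sim \mathcal{N}(0, I):
x = \mu + L D^{1/2} \epsilon                          % covariance case
x = \mu + \left( D^{1/2} L^{\top} \right)^{-1} \epsilon % precision case, a sparse triangular solve
```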
2 changes: 1 addition & 1 deletion tests/test_sparse_matmul.py
@@ -77,7 +77,7 @@ def forward_routine(op_test, op_ref, layout, device, value_dtype, index_dtype, s
     res_sparse = op_test(A, B)  # both results are dense
     res_dense = op_ref(Ad, B)

-    torch.allclose(res_sparse, res_dense, atol=ATOL, rtol=RTOL)
+    assert torch.allclose(res_sparse, res_dense, atol=ATOL, rtol=RTOL)


 def backward_routine(op_test, op_ref, layout, device, value_dtype, index_dtype, shapes, is_backward=False):
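(Without the `assert`, the boolean returned by `torch.allclose` was silently discarded, so this forward check could never fail.)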