From d00e70d7c6e1d7683d00292a1d8177374285478c Mon Sep 17 00:00:00 2001 From: theo-barfoot Date: Fri, 16 Jun 2023 13:02:02 +0100 Subject: [PATCH] spmvn cov sampling good --- tests/test_distributions.py | 27 +++++++++++++------ .../sparse_multivariate_normal.py | 19 +++++++------ 2 files changed, 30 insertions(+), 16 deletions(-) diff --git a/tests/test_distributions.py b/tests/test_distributions.py index 84181a9..1a6c280 100644 --- a/tests/test_distributions.py +++ b/tests/test_distributions.py @@ -126,19 +126,30 @@ def test_rsample_forward(device, layout, var, sizes, value_dtype, index_dtype): pytest.skip("Sparse COO with int32 indices is not supported") dist = construct_distribution(sizes, layout, var, value_dtype, index_dtype, device) - samples = dist.rsample((10000,)) + samples = dist.rsample((100000,)) if var == "cov": - scale_tril = dist.scale_tril.to_dense() - # scale_tril_with_diag = scale_tril + torch.diag_embed(dist.diagonal) - covariance = torch.matmul(scale_tril, scale_tril.t()) + torch.diag_embed(dist.diagonal) + scale_tril = dist.scale_tril.to_dense() # L matrix + diagonal = dist.diagonal # D matrix + # Compute covariance from LDL^T decomposition + covariance = torch.matmul(scale_tril @ torch.diag_embed(diagonal), scale_tril.transpose(-1, -2)) else: - precision_tril = dist.precision_tril.to_dense() - covariance = torch.inverse(torch.matmul(precision_tril, precision_tril.transpose(-1, -2))) + precision_tril = dist.precision_tril.to_dense() # L matrix + diagonal = dist.diagonal # D matrix + # Compute precision matrix from LDL^T decomposition + precision = torch.matmul(precision_tril @ torch.diag_embed(diagonal), precision_tril.t()) + # Compute covariance from precision + Id = torch.eye(precision.size(-1), dtype=precision.dtype, device=precision.device) # identity matrix + covariance = torch.linalg.solve( + Id, precision + ) # solves for X in AX=B, where A is precision, B is I, and X is covariance - # TODO: getting closer but still not there assert torch.allclose(samples.mean(0), dist.loc, atol=0.1) - assert torch.allclose(torch.cov(samples.T), covariance, atol=0.1) + if len(samples.shape) == 2: + assert torch.allclose(torch.cov(samples.T), covariance, atol=0.1) + else: + covariance_ = torch.stack([torch.cov(sample.T) for sample in samples.permute(1, 0, 2)]) + assert torch.allclose(covariance_, covariance, atol=0.1) # def test_rsample_backward(device, layout, var, sizes, value_dtype, index_dtype): diff --git a/torchsparsegradutils/distributions/sparse_multivariate_normal.py b/torchsparsegradutils/distributions/sparse_multivariate_normal.py index c17d0b0..ce11409 100644 --- a/torchsparsegradutils/distributions/sparse_multivariate_normal.py +++ b/torchsparsegradutils/distributions/sparse_multivariate_normal.py @@ -13,7 +13,7 @@ __all__ = ["SparseMultivariateNormal"] -def _batch_sparse_mv(op, bmat, bvec): +def _batch_sparse_mv(op, bmat, bvec, **kwargs): """Performs batched matrix-vector operation between a sparse matrix and a dense vector. bmat can have 0 or 1 batch dimension @@ -33,13 +33,13 @@ def _batch_sparse_mv(op, bmat, bvec): torch.Tensor: Dense matrix vector product """ if bmat.dim() == 2 and bvec.dim() == 1: - return op(bmat, bvec.unsqueeze(-1)).squeeze(-1) + return op(bmat, bvec.unsqueeze(-1), **kwargs).squeeze(-1) elif bmat.dim() == 2 and bvec.dim() == 2: - return op(bmat, bvec.t()).t() + return op(bmat, bvec.t(), **kwargs).t() elif bmat.dim() == 3 and bvec.dim() == 2: - return op(bmat, bvec.unsqueeze(-1)).squeeze(-1) + return op(bmat, bvec.unsqueeze(-1), **kwargs).squeeze(-1) elif bmat.dim() == 3 and bvec.dim() == 3: - return op(bmat, bvec.permute(1, 2, 0)).permute(2, 0, 1) + return op(bmat, bvec.permute(1, 2, 0), **kwargs).permute(2, 0, 1) else: raise ValueError("Invalid dimensions for bmat and bvec") @@ -160,9 +160,12 @@ def rsample(self, sample_shape=torch.Size()): eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device) if "_scale_tril" in self.__dict__: - cov_sqrt_m_eps = _batch_sparse_mv(spmm, self.scale_tril, eps) + x = _batch_sparse_mv(spmm, self.scale_tril, self.diagonal.sqrt() * eps) else: # 'precision_tril' in self.__dict__ - cov_sqrt_m_eps = _batch_sparse_mv(spts, self.precision_tril, eps) + # TODO: check if this is correct + x = _batch_sparse_mv( + spts, self.precision_tril, eps / (self.diagonal.sqrt()), upper=False, unitriangular=True + ) - return self.loc + cov_sqrt_m_eps * self.diagonal.sqrt() + return self.loc + x