Skip to content

Commit

Permalink
spmvn cov sampling good
Browse files Browse the repository at this point in the history
  • Loading branch information
theo-barfoot committed Jun 16, 2023
1 parent 80a39cf commit d00e70d
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 16 deletions.
27 changes: 19 additions & 8 deletions tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,19 +126,30 @@ def test_rsample_forward(device, layout, var, sizes, value_dtype, index_dtype):
pytest.skip("Sparse COO with int32 indices is not supported")

dist = construct_distribution(sizes, layout, var, value_dtype, index_dtype, device)
samples = dist.rsample((10000,))
samples = dist.rsample((100000,))

if var == "cov":
scale_tril = dist.scale_tril.to_dense()
# scale_tril_with_diag = scale_tril + torch.diag_embed(dist.diagonal)
covariance = torch.matmul(scale_tril, scale_tril.t()) + torch.diag_embed(dist.diagonal)
scale_tril = dist.scale_tril.to_dense() # L matrix
diagonal = dist.diagonal # D matrix
# Compute covariance from LDL^T decomposition
covariance = torch.matmul(scale_tril @ torch.diag_embed(diagonal), scale_tril.transpose(-1, -2))
else:
precision_tril = dist.precision_tril.to_dense()
covariance = torch.inverse(torch.matmul(precision_tril, precision_tril.transpose(-1, -2)))
precision_tril = dist.precision_tril.to_dense() # L matrix
diagonal = dist.diagonal # D matrix
# Compute precision matrix from LDL^T decomposition
precision = torch.matmul(precision_tril @ torch.diag_embed(diagonal), precision_tril.t())
# Compute covariance from precision
Id = torch.eye(precision.size(-1), dtype=precision.dtype, device=precision.device) # identity matrix
covariance = torch.linalg.solve(
Id, precision
) # solves for X in AX=B, where A is precision, B is I, and X is covariance

# TODO: getting closer but still not there
assert torch.allclose(samples.mean(0), dist.loc, atol=0.1)
assert torch.allclose(torch.cov(samples.T), covariance, atol=0.1)
if len(samples.shape) == 2:
assert torch.allclose(torch.cov(samples.T), covariance, atol=0.1)
else:
covariance_ = torch.stack([torch.cov(sample.T) for sample in samples.permute(1, 0, 2)])
assert torch.allclose(covariance_, covariance, atol=0.1)


# def test_rsample_backward(device, layout, var, sizes, value_dtype, index_dtype):
Expand Down
19 changes: 11 additions & 8 deletions torchsparsegradutils/distributions/sparse_multivariate_normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
__all__ = ["SparseMultivariateNormal"]


def _batch_sparse_mv(op, bmat, bvec):
def _batch_sparse_mv(op, bmat, bvec, **kwargs):
"""Performs batched matrix-vector operation between a sparse matrix and a dense vector.
bmat can have 0 or 1 batch dimension
Expand All @@ -33,13 +33,13 @@ def _batch_sparse_mv(op, bmat, bvec):
torch.Tensor: Dense matrix vector product
"""
if bmat.dim() == 2 and bvec.dim() == 1:
return op(bmat, bvec.unsqueeze(-1)).squeeze(-1)
return op(bmat, bvec.unsqueeze(-1), **kwargs).squeeze(-1)
elif bmat.dim() == 2 and bvec.dim() == 2:
return op(bmat, bvec.t()).t()
return op(bmat, bvec.t(), **kwargs).t()
elif bmat.dim() == 3 and bvec.dim() == 2:
return op(bmat, bvec.unsqueeze(-1)).squeeze(-1)
return op(bmat, bvec.unsqueeze(-1), **kwargs).squeeze(-1)
elif bmat.dim() == 3 and bvec.dim() == 3:
return op(bmat, bvec.permute(1, 2, 0)).permute(2, 0, 1)
return op(bmat, bvec.permute(1, 2, 0), **kwargs).permute(2, 0, 1)
else:
raise ValueError("Invalid dimensions for bmat and bvec")

Expand Down Expand Up @@ -160,9 +160,12 @@ def rsample(self, sample_shape=torch.Size()):
eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)

if "_scale_tril" in self.__dict__:
cov_sqrt_m_eps = _batch_sparse_mv(spmm, self.scale_tril, eps)
x = _batch_sparse_mv(spmm, self.scale_tril, self.diagonal.sqrt() * eps)

else: # 'precision_tril' in self.__dict__
cov_sqrt_m_eps = _batch_sparse_mv(spts, self.precision_tril, eps)
# TODO: check if this is correct
x = _batch_sparse_mv(
spts, self.precision_tril, eps / (self.diagonal.sqrt()), upper=False, unitriangular=True
)

return self.loc + cov_sqrt_m_eps * self.diagonal.sqrt()
return self.loc + x

0 comments on commit d00e70d

Please sign in to comment.