attempted to add unit diagonal for cov
theo-barfoot committed Jun 21, 2023
1 parent de5764c commit e06483c
Showing 1 changed file with 25 additions and 9 deletions.
34 changes: 25 additions & 9 deletions torchsparsegradutils/distributions/sparse_multivariate_normal.py
@@ -7,6 +7,7 @@

from torchsparsegradutils import sparse_mm as spmm
from torchsparsegradutils import sparse_triangular_solve as spts
from torchsparsegradutils.utils import sparse_eye

# from .contraints import sparse_strictly_lower_triangular

@@ -51,6 +52,7 @@ class SparseMultivariateNormal(Distribution):
    LDLt decomposition of precision matrix: P = L @ D @ L.T
    This implementation only supports a single batch dimension.
    """
    # TODO: It is confusing that scale_tril and precision_tril return unit triangular and strictly lower triangular matrices respectively
    # TODO: add in constraints
    # arg_constraints = {'loc': constraints.real_vector,
    #                    'diag': constraints.independent(constraints.positive, 1),
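The class docstring above describes the LDLt parameterisation: with a unit lower-triangular factor L and a positive diagonal D = diag(d), either the covariance or the precision is L @ D @ L.T. As a small dense sanity sketch of why the two sampling branches later in this diff look the way they do (my own illustration, not code from the repository):

import torch

n, m = 4, 200_000
L = torch.eye(n) + torch.tril(torch.randn(n, n), diagonal=-1)  # unit lower triangular
d = torch.rand(n) * 0.5 + 0.5                                  # positive diagonal of D
eps = torch.randn(m, n)

# Covariance form: x = L @ (sqrt(d) * eps)  =>  Cov(x) = L @ diag(d) @ L.T
x = (eps * d.sqrt()) @ L.T
print((x.T @ x / m - L @ torch.diag(d) @ L.T).abs().max())  # small, sampling noise only

# Precision form: solve L.T @ y = eps / sqrt(d)  =>  inverse covariance of y is L @ diag(d) @ L.T
y = torch.linalg.solve_triangular(L.T, (eps / d.sqrt()).T, upper=True).T
print((torch.linalg.inv(y.T @ y / m) - L @ torch.diag(d) @ L.T).abs().max())  # small

This is only meant to make the rsample changes further down easier to follow.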
@@ -89,8 +91,9 @@ def __init__(self, loc, diagonal, scale_tril=None, precision_tril=None, validate
        if scale_tril is not None:
            if scale_tril.layout == torch.sparse_coo:
                scale_tril = scale_tril.coalesce() if not scale_tril.is_coalesced() else scale_tril
                indices_dtype = scale_tril.indices().dtype
            elif scale_tril.layout == torch.sparse_csr:
                pass
                indices_dtype = scale_tril.crow_indices().dtype
            else:
                raise ValueError("scale_tril must be sparse COO or CSR, instead of {}".format(scale_tril.layout))

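A side note on the two indices_dtype lines above: COO and CSR tensors expose their integer index buffers through different accessors, which is why the constructor has to branch on layout. A tiny standalone illustration (my own, not from the repository):

import torch

a_coo = torch.eye(3).to_sparse()        # COO stores a (ndim, nnz) index tensor
a_csr = torch.eye(3).to_sparse_csr()    # CSR stores compressed row pointers and column indices
print(a_coo.indices().dtype)            # torch.int64 by default
print(a_csr.crow_indices().dtype)       # torch.int64 by default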
@@ -102,7 +105,19 @@
                raise ValueError("scale_tril can only have 1 batch dimension, but has {}".format(scale_tril.dim() - 2))

            batch_shape = torch.broadcast_shapes(loc.shape[:-1], diagonal.shape[:-1], scale_tril.shape[:-2])
            self._scale_tril = scale_tril

            # add unit diagonal to scale_tril, as this is required for LDLt decomposition and sampling
            Id = sparse_eye(
                scale_tril.shape,
                layout=scale_tril.layout,
                values_dtype=scale_tril.dtype,
                indices_dtype=indices_dtype,
                device=scale_tril.device,
            )
            if len(batch_shape) == 0:
                self._scale_tril = scale_tril + Id
            else:  # BUG: sparse tensors do not support batched addition (see the workaround sketch after this hunk)
                pass

        else:  # precision_tril is not None
            if precision_tril.layout == torch.sparse_coo:
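The BUG note in the hunk above leaves the batched case unhandled because elementwise addition of batched sparse tensors is not generally supported in PyTorch. One possible workaround, sketched here under my own assumptions (COO layout only, coalesced input, hypothetical helper name; not part of this commit), is to append the identity's indices and values per batch element and coalesce:

import torch

def add_batched_unit_diagonal(t):
    """Sketch: return t + I for a coalesced sparse COO tensor of shape (B, N, N)."""
    assert t.layout == torch.sparse_coo and t.dim() == 3
    B, N = t.shape[0], t.shape[-1]
    idx_dtype = t.indices().dtype
    b = torch.arange(B, dtype=idx_dtype, device=t.device).repeat_interleave(N)
    i = torch.arange(N, dtype=idx_dtype, device=t.device).repeat(B)
    eye_indices = torch.stack([b, i, i])                        # (3, B * N) diagonal entries
    eye_values = torch.ones(B * N, dtype=t.dtype, device=t.device)
    indices = torch.cat([t.indices(), eye_indices], dim=1)
    values = torch.cat([t.values(), eye_values])
    return torch.sparse_coo_tensor(indices, values, t.shape).coalesce()

Because scale_tril is strictly lower triangular, none of the appended diagonal entries collide with existing ones, so the final coalesce never sums overlapping values.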
@@ -122,6 +137,7 @@
                )

            batch_shape = torch.broadcast_shapes(loc.shape[:-1], diagonal.shape[:-1], precision_tril.shape[:-2])

            self._precision_tril = precision_tril

        super().__init__(batch_shape, event_shape, validate_args=validate_args)
@@ -151,20 +167,20 @@ def mode(self):
        return self._loc

    def rsample(self, sample_shape=torch.Size()):
        # if len(sample_shape) == 0:
        #     sample_shape = torch.Size((1,))
        # elif len(sample_shape) > 1:
        #     raise ValueError("only 1D sample shapes are currently supported")

        shape = self._extended_shape(sample_shape)
        eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)

        if "_scale_tril" in self.__dict__:
            x = _batch_sparse_mv(spmm, self.scale_tril, self.diagonal.sqrt() * eps)
            x = _batch_sparse_mv(spmm, self._scale_tril, self._diagonal.sqrt() * eps)

        else:  # 'precision_tril' in self.__dict__
            x = _batch_sparse_mv(
                spts, self.precision_tril, eps / (self.diagonal.sqrt()), upper=False, unitriangular=True, transpose=True
                spts,
                self._precision_tril,
                eps / (self._diagonal.sqrt()),
                upper=False,
                unitriangular=True,
                transpose=True,
            )

        return self.loc + x
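Finally, a hypothetical end-to-end usage sketch of the class after this commit; the import path and constructor signature are read off the diff above, and the example itself is untested:

import torch
from torchsparsegradutils.distributions.sparse_multivariate_normal import SparseMultivariateNormal

n = 4
loc = torch.zeros(n)
diagonal = torch.rand(n) + 0.1                 # positive diagonal of D
# strictly lower-triangular off-diagonal entries of the unit-triangular factor L
indices = torch.tensor([[1, 2, 3, 3],
                        [0, 0, 1, 2]])
values = torch.randn(4)
scale_tril = torch.sparse_coo_tensor(indices, values, (n, n)).coalesce()

dist = SparseMultivariateNormal(loc, diagonal, scale_tril=scale_tril)
samples = dist.rsample(torch.Size([1000]))     # shape (1000, n)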
