From e06483c053b8c9a62e05965dd7438e53111c18d2 Mon Sep 17 00:00:00 2001
From: Theo Barfoot
Date: Wed, 21 Jun 2023 17:36:25 +0100
Subject: [PATCH] attempted to add unit diagonal for cov

---
 .../sparse_multivariate_normal.py             | 34 ++++++++++++++-----
 1 file changed, 25 insertions(+), 9 deletions(-)

diff --git a/torchsparsegradutils/distributions/sparse_multivariate_normal.py b/torchsparsegradutils/distributions/sparse_multivariate_normal.py
index 4264015..b6913cb 100644
--- a/torchsparsegradutils/distributions/sparse_multivariate_normal.py
+++ b/torchsparsegradutils/distributions/sparse_multivariate_normal.py
@@ -7,6 +7,7 @@
 from torchsparsegradutils import sparse_mm as spmm
 from torchsparsegradutils import sparse_triangular_solve as spts
+from torchsparsegradutils.utils import sparse_eye
 
 # from .contraints import sparse_strictly_lower_triangular
@@ -51,6 +52,7 @@ class SparseMultivariateNormal(Distribution):
     LDLt decomposition of precision matrix: P = L @ D @ L.T
     This implementation only supports a single batch dimension.
     """
+    # TODO: It is confusing that scale_tril and precision_tril return unit triangular and strictly lower triangular matrices respectively
     # TODO: add in constraints
     # arg_constraints = {'loc': constraints.real_vector,
     #                    'diag': constraints.independent(constraints.positive, 1),
@@ -89,8 +91,9 @@ def __init__(self, loc, diagonal, scale_tril=None, precision_tril=None, validate
         if scale_tril is not None:
             if scale_tril.layout == torch.sparse_coo:
                 scale_tril = scale_tril.coalesce() if not scale_tril.is_coalesced() else scale_tril
+                indices_dtype = scale_tril.indices().dtype
             elif scale_tril.layout == torch.sparse_csr:
-                pass
+                indices_dtype = scale_tril.crow_indices().dtype
             else:
                 raise ValueError("scale_tril must be sparse COO or CSR, instead of {}".format(scale_tril.layout))
@@ -102,7 +105,19 @@
                 raise ValueError("scale_tril can only have 1 batch dimension, but has {}".format(scale_tril.dim() - 2))
 
             batch_shape = torch.broadcast_shapes(loc.shape[:-1], diagonal.shape[:-1], scale_tril.shape[:-2])
-            self._scale_tril = scale_tril
+
+            # add unit diagonal to scale_tril, as this is required for LDLt decomposition and sampling
+            Id = sparse_eye(
+                scale_tril.shape,
+                layout=scale_tril.layout,
+                values_dtype=scale_tril.dtype,
+                indices_dtype=indices_dtype,
+                device=scale_tril.device,
+            )
+            if len(batch_shape) == 0:
+                self._scale_tril = scale_tril + Id
+            else:  # BUG: sparse tensors do not support batched addition
+                pass
 
         else:  # precision_tril is not None
             if precision_tril.layout == torch.sparse_coo:
@@ -122,6 +137,7 @@
             )
 
             batch_shape = torch.broadcast_shapes(loc.shape[:-1], diagonal.shape[:-1], precision_tril.shape[:-2])
+
             self._precision_tril = precision_tril
 
         super().__init__(batch_shape, event_shape, validate_args=validate_args)
@@ -151,20 +167,20 @@
     def mode(self):
         return self._loc
 
     def rsample(self, sample_shape=torch.Size()):
-        # if len(sample_shape) == 0:
-        #     sample_shape = torch.Size((1,))
-        # elif len(sample_shape) > 1:
-        #     raise ValueError("only 1D sample shapes are currently supported")
-
         shape = self._extended_shape(sample_shape)
         eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
 
         if "_scale_tril" in self.__dict__:
-            x = _batch_sparse_mv(spmm, self.scale_tril, self.diagonal.sqrt() * eps)
+            x = _batch_sparse_mv(spmm, self._scale_tril, self._diagonal.sqrt() * eps)
         else:  # 'precision_tril' in self.__dict__
             x = _batch_sparse_mv(
-                spts, self.precision_tril, eps / (self.diagonal.sqrt()), upper=False, unitriangular=True, transpose=True
+                spts,
+                self._precision_tril,
+                eps / (self._diagonal.sqrt()),
+                upper=False,
+                unitriangular=True,
+                transpose=True,
             )
 
         return self.loc + x
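
Note on the `BUG` branch above: `scale_tril + Id` fails when `scale_tril` carries a batch dimension, because PyTorch sparse addition does not operate over batched sparse operands. A minimal sketch of one possible workaround for the COO case, which appends the diagonal entries to the indices/values directly instead of calling `+`; `add_unit_diagonal_batched_coo` is a hypothetical helper, not part of torchsparsegradutils:

```python
import torch

def add_unit_diagonal_batched_coo(a: torch.Tensor) -> torch.Tensor:
    # Assumes `a` is a strictly lower triangular sparse COO tensor of
    # shape (B, N, N), so the final coalesce() cannot sum the new unit
    # entries into any pre-existing diagonal values.
    a = a if a.is_coalesced() else a.coalesce()
    B, N = a.shape[0], a.shape[-1]
    # Diagonal coordinates for every batch as a (3, B * N) index matrix.
    batch_idx = torch.arange(B, device=a.device).repeat_interleave(N)
    diag_idx = torch.arange(N, device=a.device).repeat(B)
    eye_indices = torch.stack([batch_idx, diag_idx, diag_idx])
    eye_values = torch.ones(B * N, dtype=a.dtype, device=a.device)
    # Concatenate the identity entries onto the existing sparsity pattern.
    indices = torch.cat([a.indices(), eye_indices], dim=1)
    values = torch.cat([a.values(), eye_values])
    return torch.sparse_coo_tensor(indices, values, a.shape).coalesce()
```

Usage, on a batch of two 3x3 strictly lower triangular factors:

```python
i = torch.tensor([[0, 0, 1], [1, 2, 2], [0, 1, 0]])  # (batch, row, col)
v = torch.tensor([0.5, -0.3, 0.2])
L = torch.sparse_coo_tensor(i, v, (2, 3, 3))
L_unit = add_unit_diagonal_batched_coo(L)  # unit diagonal in every batch
```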