
Switch to sqrt(precision) representation in Gaussian #568

Merged
merged 46 commits into master from srif
Oct 18, 2021

Conversation

fritzo
Member

@fritzo fritzo commented Oct 7, 2021

Resolves #567
Adapts @fehiepsi's pyro-ppl/pyro#2019

This switches the internal Gaussian representation to a numerically stable and space-efficient representation:

- Gaussian(info_vec, precision, inputs)
+ Gaussian(white_vec, prec_sqrt, inputs)

In the new parametrization, Gaussians represent the log-density function

Gaussian(white_vec, prec_sqrt, OrderedDict(x=Reals[n]))
  = -1/2 || x @ prec_sqrt - white_vec ||^2
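For concreteness, here is a minimal NumPy sketch (not funsor's actual code) of evaluating this log-density; shapes are illustrative:

```python
# Sketch: evaluating the square-root-information-filter (SRIF) log-density
# -1/2 || x @ prec_sqrt - white_vec ||^2 with batched NumPy arrays.
import numpy as np

batch_shape, dim, rank = (3,), 5, 2
rng = np.random.default_rng(0)
prec_sqrt = rng.normal(size=batch_shape + (dim, rank))
white_vec = rng.normal(size=batch_shape + (rank,))
x = rng.normal(size=batch_shape + (dim,))

# residual has shape batch_shape + (rank,)
residual = np.einsum("...d,...dr->...r", x, prec_sqrt) - white_vec
log_density = -0.5 * (residual ** 2).sum(-1)
assert log_density.shape == batch_shape
```

This agrees with the usual information-form density, since expanding the square gives the quadratic term `-1/2 x P x` with `P = prec_sqrt @ prec_sqrt.T`, the linear term `x . (prec_sqrt @ white_vec)`, and a constant.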

These two parameters are shaped to efficiently represent low-rank data:

assert white_vec.shape == batch_shape + (rank,)
assert prec_sqrt.shape == batch_shape + (dim, rank)

reducing space complexity from O(dim(dim+1)) to O(rank(dim+1)). In my real-world example rank=1, dim=2369, and batch_shape=(1343,), so the space reduction is 30GB → 13MB.
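A back-of-envelope check of the quoted figures, assuming float32 (4 bytes per element):

```python
# Memory for the dense (info_vec, precision) vs low-rank (white_vec, prec_sqrt)
# parametrizations, using the batch/dim/rank sizes quoted above.
batch, dim, rank = 1343, 2369, 1
bytes_per = 4  # float32

dense = batch * (dim + dim * dim) * bytes_per       # info_vec + precision
low_rank = batch * (rank + dim * rank) * bytes_per  # white_vec + prec_sqrt

print(f"dense:    {dense / 1e9:.1f} GB")   # 30.2 GB
print(f"low rank: {low_rank / 1e6:.1f} MB")  # 12.7 MB
```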

Computation is cheap in this representation: addition amounts to concatenation, and plate-reduction amounts to a transpose and reshape. Some ops are supported only on full-rank Gaussians, and I've added checks based on the new property .is_full_rank. This partial support is ok because the Gaussian funsors that arise in Bayesian models are all full rank due to priors (notwithstanding numerical loss of rank).
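The addition-as-concatenation claim can be checked directly: summing two SRIF log-densities over the same real inputs is equivalent to stacking their parameters along the rank dimension. A small sketch (assumed semantics, not funsor's code):

```python
# Adding two Gaussian log-densities by concatenating along the rank dim:
# -1/2||x@S1 - w1||^2 + -1/2||x@S2 - w2||^2 == -1/2||x@[S1|S2] - [w1|w2]||^2
import numpy as np

rng = np.random.default_rng(1)
dim, rank1, rank2 = 4, 2, 3
S1, S2 = rng.normal(size=(dim, rank1)), rng.normal(size=(dim, rank2))
w1, w2 = rng.normal(size=(rank1,)), rng.normal(size=(rank2,))
x = rng.normal(size=(dim,))

def logp(x, S, w):
    return -0.5 * ((x @ S - w) ** 2).sum()

S12 = np.concatenate([S1, S2], axis=-1)  # shape (dim, rank1 + rank2)
w12 = np.concatenate([w1, w2], axis=-1)  # shape (rank1 + rank2,)
assert np.isclose(logp(x, S1, w1) + logp(x, S2, w2), logp(x, S12, w12))
```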

Because the Gaussian funsor is internal, this interface change should not break most user code, since most user code goes through to_funsor() and to_data() with backend-specific distributions. One piece of user code that does break is Pyro's AutoGaussianFunsor, which will need an update (and which will be sped up).

As suggested by @eb8680, I've added some optional kwarg parametrizations and properties to support conversion to other Gaussian representations, e.g. g = Gaussian(mean=..., covariance=..., inputs=...) and g._mean, g._covariance. This allows more Gaussian math to live in gaussian.py.
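One such conversion can be sketched as follows; the variable names mirror the kwargs above, but this is an illustrative NumPy derivation, not funsor's internals:

```python
# One way to convert a (mean, covariance) parametrization to
# (white_vec, prec_sqrt) for a full-rank Gaussian.
import numpy as np

rng = np.random.default_rng(2)
dim = 3
A = rng.normal(size=(dim, dim))
covariance = A @ A.T + dim * np.eye(dim)  # a full-rank SPD covariance
mean = rng.normal(size=(dim,))

precision = np.linalg.inv(covariance)
prec_sqrt = np.linalg.cholesky(precision)  # any S with S @ S.T == precision
white_vec = mean @ prec_sqrt

# The resulting log-density peaks at x == mean:
logp = lambda x: -0.5 * ((x @ prec_sqrt - white_vec) ** 2).sum()
assert logp(mean) == 0.0  # residual vanishes exactly at the mean
```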

Tested

  • added new tests
  • pass existing funsor tests (linear algebra)
  • pass existing funsor tests (patterns)
  • pass NumPyro tests (conjugacy)
  • pass Pyro tests on a branch (Update to use funsor's SRIF Gaussian pyro#2943)
  • check computational cost on the pyro-cov model (results: under 12GB memory)

@fritzo
Member Author

fritzo commented Oct 11, 2021

@eb8680 could we pair code on Gaussian patterns this week? I think this PR now has correct linear algebra, but it produces slightly different patterns. Specifically, the new square root representation can no longer be negated, so we'll need to keep Unary(ops.neg, Gaussian) lazy.

@fritzo fritzo marked this pull request as ready for review October 13, 2021 23:33
@fritzo fritzo requested review from fehiepsi and eb8680 and removed request for fehiepsi October 13, 2021 23:34
@fritzo
Member Author

fritzo commented Oct 13, 2021

@eb8680 could you please review this PR in general?
@fehiepsi could you please review the linear algebra?

I'm happy to walk you through the changes over zoom.

@fehiepsi
Member

Whoa, impressive work to make this possible! I haven't looked into the code yet, but my general concerns would be the marginalization and rank-compression logic. I will look into the details later this week.

Member

@fehiepsi fehiepsi left a comment


My initial concerns regarding compress rank and marginalization seem to be already resolved. The math seems correct and the implementation is quite optimal in my opinion.

white_vec seems natural given the implementation, but I'm curious why you use it instead of info_vec?

funsor/gaussian.py (resolved review threads)
def eager_neg(op, arg):
    # Old info_vec/precision representation: negation simply flips both params.
    info_vec = -arg.info_vec
    precision = -arg.precision
    return Gaussian(info_vec, precision, arg.inputs)
Member


Just curious: previously, when was sub needed? And how do you compute Gaussian - Gaussian? (I couldn't derive the math :( )

Member Author


I believe sub is needed only in computing KL divergences. E.g. in fully Gaussian ELBOs, we compute Integrate(guide, model - guide) where both model and guide are Gaussian. Actually, we'll need some more patterns now that ops.neg(Gaussian(...)) is lazy. As discussed with @eb8680, I'll open an issue once this merges.

Member Author


As we group-coded on Friday, full subtraction will be implemented in the follow-up PR #553.

funsor/integrate.py (resolved review thread)
@fritzo
Member Author

fritzo commented Oct 15, 2021

Thanks for reviewing, @fehiepsi! Here are a few weak reasons I chose white_vec instead of info_vec:

  • white_vec seems natural 😄
  • white_vec is space optimal in the low-rank case, with white_vec.shape[-1] == rank versus info_vec.shape[-1] == dim. E.g. in the rank-1 case we get up to a factor of two savings: O(rank(1+dim)) < O(dim(1+rank)).
  • Because white_vec does not depend on the real inputs, it need not be rearranged when interleaving or substituting real variables. This is minor, but it does simplify code a bit.

Member

@eb8680 eb8680 left a comment


Generally looks great. Once AutoGaussian is finished I think we should consider writing a short technical report on the internals of funsor.Gaussian, since there's a lot of cool stuff here and it wasn't really covered at all in the Funsor paper.

funsor/testing.py (resolved review thread)
funsor/pyro/convert.py (review thread)
if prec_sqrt is not None:
    is_tril = False  # a user-provided square root need not be lower-triangular
elif precision is not None:
    prec_sqrt = ops.cholesky(precision)  # convert precision to a Cholesky root
Member


This conversion logic in GaussianMeta seems fine overall, but I wonder whether we should look for a way to restrict actual execution of the possibly expensive linear algebra computations to specific interpretations like eager. If so, that might be another reason to attempt #556, since we could just make them lazy Ops.

funsor/gaussian.py (resolved review thread)
@fritzo
Member Author

fritzo commented Oct 16, 2021

Thanks for reviewing @eb8680 and @fehiepsi! I believe I've addressed all comments.

I plan to add more patterns in subsequent PRs that handle Gaussian variable elimination, e.g. in #553 and https://github.com/pyro-ppl/funsor/tree/tractable-for-gaussians

Member

@fehiepsi fehiepsi left a comment


Looks great to me on the second pass! Thanks, @fritzo!

@fritzo fritzo requested a review from eb8680 October 17, 2021 19:18
Member

@eb8680 eb8680 left a comment


Looks great!

@eb8680 eb8680 merged commit c94a9bd into master Oct 18, 2021
@eb8680 eb8680 deleted the srif branch October 18, 2021 14:39

Successfully merging this pull request may close these issues.

Switch to sqrt(precision) representation in Gaussian?