Switch to sqrt(precision) representation in Gaussian #568
Conversation
@eb8680 could we pair code on Gaussian patterns this week? I think this PR now has correct linear algebra, but it produces slightly different patterns. Specifically, the new square root representation can no longer be negated, so we'll need to keep …
Whoa, impressive work to make this possible! I haven't looked into the code yet, but my general concern would be the marginalization and compress-rank logic. I will look into the details later this week.
My initial concerns regarding compress rank and marginalization seem to be already resolved. The math seems correct and the implementation is quite optimal in my opinion. `white_vec` seems natural given the implementation, but I'm curious why you use it instead of `info_vec`?
```python
def eager_neg(op, arg):
    info_vec = -arg.info_vec
    precision = -arg.precision
    return Gaussian(info_vec, precision, arg.inputs)
```
Just curious: previously, when was `sub` needed? And how do you compute `Gaussian - Gaussian`? (I couldn't derive the math :( )
I believe `sub` is needed only in computing KL divergences, e.g. in fully Gaussian ELBOs we compute `Integrate(guide, model - guide)` where both `model` and `guide` are Gaussian. Actually we'll need some more patterns now that `ops.neg(Gaussian(...))` is lazy; as discussed with @eb8680, I'll open an issue once this merges.
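(For intuition, a hedged sketch of why negation has no eager form in the sqrt parametrization, writing $S$ for `prec_sqrt`: negating a Gaussian flips the sign of its precision, but a negated precision has no real square root:

$$
-\,\mathrm{precision} \;=\; -\,S S^\top \;\preceq\; 0,
\qquad\text{while}\qquad
T T^\top \;\succeq\; 0 \quad \text{for every real } T,
$$

so no real `prec_sqrt` $T$ satisfies $T T^\top = -\,S S^\top$ unless $S = 0$. The negation must stay lazy until it cancels against another term, as in `model - guide`.)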
As we group-coded on Friday, full subtraction will be implemented in the follow-up PR #553.
Thanks for reviewing, @fehiepsi! Here are a few weak reasons I chose `white_vec`: …
Generally looks great. Once AutoGaussian is finished I think we should consider writing a short technical report on the internals of `funsor.Gaussian`, since there's a lot of cool stuff here and it wasn't really covered at all in the Funsor paper.
```python
if prec_sqrt is not None:
    is_tril = False
elif precision is not None:
    prec_sqrt = ops.cholesky(precision)
```
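For context, a minimal sketch of the identity this conversion relies on, in plain NumPy rather than funsor's backend-generic `ops` (the example matrix is illustrative):

```python
import numpy as np

# A Cholesky factor is one valid precision square root: any S with
# S @ S.T == precision parametrizes the same Gaussian, and the Cholesky
# choice is lower triangular (hence the is_tril bookkeeping above).
precision = np.array([[4.0, 2.0], [2.0, 3.0]])  # symmetric positive definite
S = np.linalg.cholesky(precision)
assert np.allclose(S @ S.T, precision)
```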
This conversion logic in `GaussianMeta` seems fine overall, but I wonder whether we should look for a way to restrict actually performing the possibly expensive linear algebra computations to specific interpretations like `eager`. If so, that might be another reason to attempt #556, since we could just make them lazy `Op`s.
Thanks for reviewing, @eb8680 and @fehiepsi! I believe I've addressed all comments. I plan to add more patterns in subsequent PRs that handle Gaussian variable elimination, e.g. in #553 and https://github.com/pyro-ppl/funsor/tree/tractable-for-gaussians
Looks great to me on the second pass! Thanks, @fritzo!
Looks great!
Resolves #567
Adapts @fehiepsi's pyro-ppl/pyro#2019
This switches the internal Gaussian representation to a numerically stable and space-efficient square-root parametrization, `(white_vec, prec_sqrt)`. In the new parametrization, Gaussians represent the log-density function sketched below.
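Writing $S$ for `prec_sqrt` and $w$ for `white_vec`, the density presumably has the form (a sketch reconstructed from the parameter names used in this PR; the exact normalization convention is an assumption):

$$
\log p(x) \;\propto\; -\tfrac{1}{2}\,\bigl\| x^\top S - w^\top \bigr\|^2
\;=\; -\tfrac{1}{2}\, x^\top S S^\top x \;+\; w^\top S^\top x \;-\; \tfrac{1}{2}\,\|w\|^2,
$$

so the classical parameters are recovered as $\mathrm{precision} = S S^\top$ and $\mathrm{info\_vec} = S\,w$.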
These two parameters are shaped to efficiently represent low-rank data, reducing space complexity from `O(dim(dim+1))` to `O(rank(dim+1))`; a storage sketch follows below. In my real-world example rank=1, dim=2369, and batch_shape=(1343,), so the space reduction is 30GB → 13MB.
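A minimal storage sketch, assuming shapes `white_vec: batch_shape + (rank,)` and `prec_sqrt: batch_shape + (dim, rank)` (these shapes are inferred from the complexity claims above, not quoted from the PR), reproducing the numbers in the example:

```python
# Storage comparison at float32: old (info_vec, precision) versus
# new (white_vec, prec_sqrt), using the real-world sizes quoted above.
# Assumed shapes:
#   info_vec : batch + (dim,)     precision : batch + (dim, dim)
#   white_vec: batch + (rank,)    prec_sqrt : batch + (dim, rank)
batch, dim, rank = 1343, 2369, 1

old_numel = batch * (dim + dim * dim)    # O(dim(dim+1)) per batch element
new_numel = batch * (rank + dim * rank)  # O(rank(dim+1)) per batch element

print(f"old: {4 * old_numel / 1e9:.0f} GB")  # old: 30 GB
print(f"new: {4 * new_numel / 1e6:.0f} MB")  # new: 13 MB
```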
Computation is cheap in this representation: addition amounts to concatenation, and plate-reduction amounts to transpose and reshape. Some ops are only supported on full-rank Gaussians, and I've added checks based on the new property `.is_full_rank`. This partial support is OK because the Gaussian funsors that arise in Bayesian models are all full rank due to priors (notwithstanding numerical loss of rank).

Because the Gaussian funsor is internal, the interface change should not break most user code, since most user code uses `to_funsor()` and `to_data()` with backend-specific distributions. One broken piece of user code is Pyro's AutoGaussianFunsor, which will need an update (and which will be sped up).

As suggested by @eb8680, I've added some optional kwarg parametrizations and properties to support conversion to other Gaussian representations, e.g.
`g = Gaussian(mean=..., covariance=..., inputs=...)` and `g._mean`, `g._covariance`. This allows more Gaussian math to live in gaussian.py.

Tested