-
Notifications
You must be signed in to change notification settings - Fork 20
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Precondition interpretation for Gaussian TVE #553
Conversation
# TODO Replace this with root + Constant(...) after #548 merges. | ||
root_vars = root.input_vars | batch_vars | ||
|
||
def adjoint(self, sum_op, bin_op, root, targets=None, *, batch_vars=set()): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
note that batch_vars
can now change during the course of adjoint
, e.g. in Precondition where the aux vars aren't know until each Approximate
term is hit.
@eb8680 let me know if you want a zoom tour |
samples = {k: v(**subs) for k, v in samples.items()} | ||
|
||
# Compute log density at each sample, lazily dependent on aux_name. | ||
log_prob = -log_Z | ||
for f in factors.values(): | ||
term = f(**samples) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks like these substitutions are triggering expensive materialize()
operations due to Gaussian.eager_subs().
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sampling math looks right, just some clarifying questions
@@ -883,6 +883,19 @@ def eager_finitary_stack(op, parts): | |||
return Tensor(raw_result, inputs, parts[0].dtype) | |||
|
|||
|
|||
@eager.register(Finitary, ops.CatOp, typing.Tuple[Tensor, ...]) | |||
def eager_finitary_cat(op, parts): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Out of curiosity, do relevant tests still pass if you replace this pattern with
eager.register(Finitary, ops.CatOp, typing.Tuple[Tensor, ...])(
funsor.op_factory.eager_tensor_made_op)
which uses the generic tensor op evaluation pattern in https://github.com/pyro-ppl/funsor/blob/master/funsor/op_factory.py#L19
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't recall, but my guess is that since ops.cat()
doesn't broadcast, we needed to explicitly set expand=True
when calling align_tensors()
.
samples = g2.sample(all_vars, sample_inputs, rng_keys[2]) | ||
actual_mean, actual_cov = compute_moments(samples) | ||
assert_close(actual_mean, expected_mean, atol=1e-1, rtol=1e-1) | ||
assert_close(actual_cov, expected_cov, atol=1e-1, rtol=1e-1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are these tests strong enough to catch bugs with these tolerances? We could also test these computations exactly by preconditioning with shared noise, right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These did indeed catch some bugs, but I'd like to eventually switch to double precision in tests #573
# Substitute noise for the aux value, as would happen each SVI step. | ||
aux_numel = log_prob.inputs["aux"].num_elements | ||
noise = Tensor(randn(num_samples, aux_numel))["particle"] | ||
with memoize(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is this memoize
for?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is the natural evaluation pattern for dags. Since downstream user code will use this pattern, I'd prefer to use this pattern in tests.
shape += tuple(d.size for d in int_inputs.values()) | ||
shape += (dim,) | ||
assert ops.is_numeric_array(prototype) | ||
return ops.randn(prototype, shape, rng_key) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is the motivation for _sample_white_noise
not returning a Funsor
in all cases? It seems like making it consistent would slightly reduce the mental overhead of understanding our fairly complicated sample
code.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it just seemed simpler this way
prec_sqrt = prec_sqrt_b - prec_sqrt_b @ proj_a | ||
white_vec = self.white_vec - _vm(self.white_vec, proj_a) | ||
result += Gaussian(white_vec, prec_sqrt, inputs) | ||
else: # The Gaussian over xa is zero. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Which Gaussian tests exercise this case? I want to verify that the size-0 trick used here works on all backends and doesn't introduce unexpected behavior in downstream code.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't remember, maybe in one of the GaussianHMM tests?
Addresses pyro-ppl/pyro#2813
This adds a
Precondition
interpretation for Gaussian factor graphs.Similar to batched
MonteCarlo(..., sample_inputs)
where samples depend on a discrete sample index,Precondition()
returns samples that depend on a single white noise inputaux : Reals[total_num_elements]
. Notably this new interpretation is deterministic, since the samples are drawn lazily. These lazy samples can then be used for sampling by either substituting white noise (in variational inference) or substituting an HMC-controlled vector. Specifically, this can be used to implement Pyro'sAutoGaussian.get_transform()
for use inNeuTraReparam
.Changes to
Gaussian._sample()
can be seen as a first step towards makingGaussian
depend onfunsor.Tensor
s rather than backend arrays, as suggested by @eb8680 #556.Remaining tasks
Gaussian - Gaussian
Subs._sample()
Gaussian._sample()
for subsets of real variablesops.cat
ops.randn
Gaussian - Gaussian
which turns out to be incorrectTested