Add more tests for collapse #2702

Open
wants to merge 1 commit into
base: dev

Conversation

fehiepsi
Member

@fehiepsi fehiepsi commented Nov 26, 2020

Ported failing tests from pyro-ppl/numpyro#809 to discuss.

Here in Pyro, we don't face the same issues as in NumPyro. The only issue is that the final log_prob is a Contraction, not a Tensor. I guess we need some pattern here to make it work. Edit: the same issues do happen (I got the wrong impression because of the traceback mechanism in collapse, which is fixed in this PR).

with pyro.plate("data", T, dim=-1):
    expand_shape = (d,) if num_particles == 1 else (num_particles, 1, d)
    y = pyro.sample("y", dist.Normal(x, 1.).expand(expand_shape).to_event(1))
    pyro.sample("z", dist.Normal(y, 1.).to_event(1), obs=data)

Member Author

This fails for two reasons:

  • y.output is Reals[d], which raises an error in funsor's infer_param_domain: Output mismatch: Reals[2] vs Real
  • to_event is not available for Funsor distributions.

Member

Hmm I wonder if there is always enough information available in the args of expanded distributions to automatically determine the event dim, so that we could make .to_event() a no-op on funsors. For example here we could deduce the event shape from y.output.

@eb8680 would it be possible to support this in to_funsor() and to_data()?

Member

> would it be possible to support this in to_funsor() and to_data()?

It should be possible, at least in principle, and if we want collapse to work seamlessly with models that use .to_event we'll need something like that. I think it will require changing the way type inference works in funsor.distribution.Distribution, though.

The point of failure in Funsor is the to_funsor conversion of parameters in funsor.distribution.DistributionMeta.__call__:

class DistributionMeta(FunsorMeta):
    def __call__(cls, *args, **kwargs):
        kwargs.update(zip(cls._ast_fields, args))
        value = kwargs.pop('value', 'value')

        # Failure occurs here -------v
        kwargs = OrderedDict(
            (k, to_funsor(kwargs[k], output=cls._infer_param_domain(k, getattr(kwargs[k], "shape", ()))))
            for k in cls._ast_fields if k != 'value')

        # this is also incorrect
        value = to_funsor(value, output=cls._infer_value_domain(**{k: v.output for k, v in kwargs.items()}))

        args = numbers_to_tensors(*(tuple(kwargs.values()) + (value,)))
        return super(DistributionMeta, cls).__call__(*args)

In this test case, kwargs[k] is loc = Variable("y", Reals[2]), and Normal._infer_param_domain(...) thinks the output should be Real instead of Reals[2].
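
The mismatch can be reproduced with a toy model of the inference step. This is only a sketch, not funsor's code: `EVENT_DIM` and `infer_param_event_shape` are hypothetical stand-ins, and the sketch assumes a parameter's domain is read off the trailing `event_dim` entries of its raw shape.

```python
# Toy model of parameter-domain inference; names are hypothetical, not funsor's API.
# Assumption: a parameter's output shape is the trailing `event_dim` entries of
# its raw shape, where `event_dim` comes from the distribution's constraints.
EVENT_DIM = {"Normal": {"loc": 0, "scale": 0}}


def infer_param_event_shape(dist_name, param, raw_shape):
    """Keep only the trailing `event_dim` sizes of `raw_shape`."""
    event_dim = EVENT_DIM[dist_name][param]
    return tuple(raw_shape[len(raw_shape) - event_dim:])


declared = (2,)  # output shape of Variable("y", Reals[2])
inferred = infer_param_event_shape("Normal", "loc", declared)
mismatch = inferred != declared  # () stands for Real, (2,) for Reals[2]
print(f"inferred={inferred} declared={declared} mismatch={mismatch}")
```

Because a scalar constraint on `loc` gives `event_dim == 0`, the inferred shape is always empty, whatever output the variable declares.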

I think a general solution (at least when all parameters have the same broadcasted .output) is to compute unbroadcasted parameter and value shapes up front, then broadcast:

# compute unbroadcasted domains
domains = {k: cls._infer_param_domain(k, getattr(kwargs[k], "shape", ())) for k in kwargs if k != "value"}
domains["value"] = cls._infer_value_domain(**domains)

# broadcast individual param domains with Funsor inputs
# this avoids .expand-ing underlying parameter tensors
for k, v in kwargs.items():
    if isinstance(v, Funsor):
        domains[k] = Reals[broadcast_shapes(v.shape, domains[k].shape)[0]]  # assume all Reals for exposition

# broadcast value domain, which depends on param domains, with broadcasted param domains
domains["value"] = Reals[broadcast_shapes(domains["value"].shape, *(d.shape for d in domains.values()))[0]]

Now the previously incorrect to_funsor conversions reduce to

kwargs = OrderedDict((k, to_funsor(kwargs[k], output=domains[k])) for k in kwargs)
value = kwargs["value"]
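
The same three steps can be walked through on plain shape tuples. This is a minimal sketch under stated assumptions: `broadcast_shapes` and `infer_domains` are illustrative helpers written here, a tuple such as `(2,)` stands for `Reals[2]`, and the example is restricted to a univariate distribution whose unbroadcasted value domain is scalar.

```python
from itertools import zip_longest


def broadcast_shapes(*shapes):
    """NumPy-style broadcast of right-aligned shape tuples."""
    result = []
    for sizes in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        non_unit = {n for n in sizes if n != 1}
        if len(non_unit) > 1:
            raise ValueError(f"cannot broadcast {shapes}")
        result.append(non_unit.pop() if non_unit else 1)
    return tuple(reversed(result))


def infer_domains(param_shapes, param_event_dims, declared_outputs):
    # Step 1: unbroadcasted param domains from the trailing event dims.
    domains = {k: s[len(s) - param_event_dims[k]:] for k, s in param_shapes.items()}
    # Step 2: widen each param domain to cover the output its funsor declares.
    for k, out in declared_outputs.items():
        domains[k] = broadcast_shapes(out, domains[k])
    # Step 3: the value domain (scalar before broadcasting in this univariate
    # example) is broadcast against all of the widened param domains.
    domains["value"] = broadcast_shapes((), *domains.values())
    return domains


# Normal(loc=Variable("y", Reals[2]), scale=1.0)
domains = infer_domains(
    param_shapes={"loc": (2,), "scale": ()},
    param_event_dims={"loc": 0, "scale": 0},
    declared_outputs={"loc": (2,)},
)
print(domains)
```

Here `loc` keeps its declared `(2,)` output and the value domain follows it, while `scale` stays scalar, so no underlying parameter has to be expanded.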

WDYT?

Member Author

Sounds reasonable to me.

Member

Ok, I put up a draft Funsor PR here: pyro-ppl/funsor#402

Member Author

Thanks, Eli! testing now...

Member

I believe what's in that PR after my recent push is sufficient for this test case, although it's not ready to merge because of some edge cases in the distributions. Let me know if it's not working.

beta0 = pyro.sample("beta0", dist.Normal(x, 1.).expand(expand_shape).to_event(1))
beta = pyro.sample("beta", dist.MultivariateNormal(beta0, torch.eye(S)))

mean = torch.ones((T, d)) @ beta

Member Author

This fails because beta.output is Reals[S] while we need it to be Reals[d, S].
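
A shape-only sketch of the two views may help. `matmul_shape` is a hypothetical helper written for this illustration, and the sizes `T`, `d`, `S` are arbitrary, chosen with `d != S` so that the failure is visible.

```python
def matmul_shape(lhs, rhs):
    """Output shape of a 2-D @ (1-D or 2-D) product, NumPy-style."""
    if len(lhs) != 2 or len(rhs) not in (1, 2):
        raise ValueError("sketch only handles 2-D @ 1-D or 2-D @ 2-D")
    if lhs[1] != rhs[0]:
        raise ValueError(f"inner dimensions differ: {lhs} @ {rhs}")
    return lhs[:1] + rhs[1:]


T, d, S = 5, 3, 2

# NumPy-style view: the plate dimension is part of beta's shape, (d, S).
numpy_style = matmul_shape((T, d), (d, S))

# Funsor-style view: beta.output is Reals[S]; d is a named batch input that
# the matmul never sees, so the inner dimensions do not line up.
try:
    matmul_shape((T, d), (S,))
    funsor_style_ok = True
except ValueError:
    funsor_style_ok = False

print(numpy_style, funsor_style_ok)
```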

Member

Yeah this seems like a major incompatibility between funsor-style operations which have a clear batch/event split and numpy-style operations where everything is effectively an event dim. In particular there's no way for ops.matmul to touch the batch dimension "d".

One workaround in this model would be to move beta0 and beta out of the plate and instead use .to_event(1), and I think that kind of workaround will be needed whenever we exit a plate and treat a formerly-batch dimension as an event dimension. Conversely I conjecture that no workarounds are needed in tree-plated models, that is in models where "no variable outside a plate ever depends on an upstream variable inside a plate"; this class is similar to our condition for tractable tensor variable elimination.
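
The tree-plated condition can be stated as a check over dependency edges: every plate containing a parent must also contain the child. The sketch below is one reading of that conjectured condition; `plate_violations` and the dictionaries describing the model are hypothetical, and the plate name `plate_var` is assumed for illustration.

```python
def plate_violations(plates, parents):
    """Return (parent, child, escaped_plates) for each edge that leaves a plate."""
    violations = []
    for child, deps in parents.items():
        for parent in deps:
            escaped = plates[parent] - plates[child]
            if escaped:
                violations.append((parent, child, escaped))
    return violations


# Dependency structure of the failing test: beta0 -> beta -> mean -> obs.
parents = {"beta0": [], "beta": ["beta0"], "mean": ["beta"], "obs": ["mean"]}

# beta0 and beta sampled inside a plate, mean computed outside it.
plated = {"beta0": {"plate_var"}, "beta": {"plate_var"}, "mean": set(), "obs": {"data"}}
violations = plate_violations(plated, parents)

# Workaround: move beta0 and beta out of the plate and use .to_event(1).
moved_out = {"beta0": set(), "beta": set(), "mean": set(), "obs": {"data"}}
fixed = plate_violations(moved_out, parents)

print(violations, fixed)
```

Checking direct edges is enough, because any dependency path that starts inside a plate and ends outside it must contain at least one edge that crosses the plate boundary.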

Member

@eb8680 eb8680 Nov 30, 2020

It seems like if we had a way of knowing the plate context of a value both at the time of its creation and each time it was accessed, we could handle this smoothly using funsor.terms.Lambda. Suppose we could overload variable access/assignment, e.g. by using an overloadable environment data structure env rather than locals():

with plate("plate_var", d, dim=-1):
    beta0 = pyro.sample("beta0", dist.Normal(x, 1.).expand(expand_shape).to_event(1))
    # setting env.beta also records the current plate context of beta
    env.beta = pyro.sample("beta", dist.MultivariateNormal(beta0, torch.eye(S)))

# use funsor.terms.Lambda to convert the dead plate dimension to an output:
# now reading env.beta returns funsor.Lambda(plate_var, beta)
# where plate_var is the difference between the current and original plate contexts
mean = torch.ones((T, d)) @ env.beta
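
The bookkeeping such an environment would do can be sketched on shapes alone. Everything here is hypothetical: `Env` is a toy class, not an existing Pyro or funsor object, values are represented only by their output shapes, and the sketch assumes that converting a dead plate to an output prepends the plate size to the output shape, outermost plate first.

```python
from contextlib import contextmanager


class Env:
    """Toy environment recording the plate context of each assignment."""

    def __init__(self):
        # Bypass the overloaded __setattr__ for internal state.
        object.__setattr__(self, "_active", [])  # stack of (plate name, size)
        object.__setattr__(self, "_values", {})  # name -> (plates, output shape)

    @contextmanager
    def plate(self, name, size):
        self._active.append((name, size))
        try:
            yield
        finally:
            self._active.pop()

    def __setattr__(self, name, output_shape):
        # Record the plates that are active when the value is created.
        self._values[name] = (tuple(self._active), tuple(output_shape))

    def __getattr__(self, name):
        try:
            created_in, output_shape = self._values[name]
        except KeyError:
            raise AttributeError(name) from None
        # Plates active at creation but not at access are "dead":
        # their sizes become leading output dimensions.
        dead = [p for p in created_in if p not in self._active]
        return tuple(size for _, size in dead) + output_shape


d, S = 3, 2
env = Env()
with env.plate("plate_var", d):
    env.beta = (S,)       # beta.output is Reals[S] inside the plate
    inside = env.beta     # plate still active: shape unchanged
outside = env.beta        # plate exited: d becomes an output dimension
print(inside, outside)
```

Read outside the plate, `beta` reports the `(d, S)` output that the matmul needs, without the model code naming the plate explicitly.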

Member Author

Oh, I didn't know about Lambda. I am not sure why we need env, but from your code it seems that in the collapse code we can use Lambda whenever the output does not match the plate info.

Member

> I am not sure why we need env

I was suggesting it as a layer of automation on top of Lambda that would help keep track of which plate dimensions need to be converted to event dimensions via Lambda, just like contrib.funsor.to_funsor/to_data automatically track the name_to_dim mapping for enumeration. You're right that it's also possible to use Lambda directly in user code.


mean = torch.ones((T, d)) @ beta
with pyro.plate("data", T, dim=-1):
    pyro.sample("obs", dist.MultivariateNormal(mean, torch.eye(S)), obs=data)

Member Author

This fails for a reason similar to the diag_normal_plate_normal test above:

  • funsor requires mean.output to be Reals[S], but after the matmul the output of mean is Reals[T, S]


3 participants