Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Autodiff implementation (experimental) #494

Draft
wants to merge 17 commits into
base: master
Choose a base branch
from

Conversation

ordabayevy
Copy link
Member

@ordabayevy ordabayevy commented Mar 17, 2021

This is an implementation of autodiff. The goal is to address issues in computing expectations in TraceEnum_ELBO and TraceMarkovEnum_ELBO (#493). As of now it seems to fix nan gradients under eager interpretation in TraceEnum_ELBO.

The algorithm implements equivalents of linearize(), transpose() functions, and is tape-free (#446).

  1. Linearize. Variables that need to be linearized are replaced by primal- tangent tuple JVP(primal, tangent) and then pattern matched to propagate tangents, e.g.:
JVP(x, dx) + JVP(y, dy) = JVP(x+y, dx+dy)
JVP(x, dx) * JVP(y, dy) = JVP(x*y, ydx + xdy)
JVP(x, dx) * y = JVP(x*y, ydx)

Out tangent is a linear function of in tangents. JVP is used for (add,mul) semiring and LJVP is used for (logaddexp,add) semiring.

  1. Transpose of a linear function. Transpose is implemented simply by inverting the order of function execution and transposing matrices, in this case swapping more primitive operations .reduce(sum_op, "i") and .expand("i") (broadcasting does this automatically).

@ordabayevy ordabayevy added the WIP label Mar 17, 2021
inputs = OrderedDict([(var.name, var.output) for var in expanded_vars])
inputs.update(arg.inputs)
output = arg.output
fresh = frozenset()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this be

fresh = frozenset(v.name for v in expanded_vars)

from funsor.terms import Binary, Funsor, Lambda, Number, Tuple, Variable, lazy
from funsor.testing import assert_close, random_tensor

funsor.set_backend("torch")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test files should read but not write the global backend. Instead you can decorate each test with

@pytest.mark.skipif(get_backend() != "torch", reason="backend-specific")

and then run tests with

FUNSOR_BACKEND=torch pytest test/test_autodiff.py

@@ -994,6 +1001,46 @@ def die_binary(op, lhs, rhs):
raise NotImplementedError(f"Missing pattern for {repr(expr)}")


class Expand(Funsor):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm I'd like to better understand the need for this.

We've been trying to preserve the extensionality property in Funsor, which states that: if under every grounding substitution subs a pair of funsors f,g satisfy f(**subs) == g(**subs), then it should be permissible for an optimizer to replace funsor f with funsor g in any expression. IIUC this Expand funsor would break extensionality because f.expand(...) behaves as f under every grounding substitution.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants