-
Notifications
You must be signed in to change notification settings - Fork 20
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add a compiler and tracer, each creating OpPrograms #557
Conversation
Looks great, and definitely feasible in some form. What about tracing individual |
I guess that approach would require strict use of |
🤔 Interesting, I guess that would avoid both a lowering stage and the need to refactor Gaussian to be symbolic.
Yes, I guess we'd need to manually desugar all math e.g. |
We could also experiment with Refactoring |
@eb8680 WDYT about merging this partial implementation (w/o |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good per zoom review
Addresses pyro-ppl/pyro#2929
pair coded with @eb8680
This aims to work around backend jit issues and improve speed of Funsor used in Pyro. The approach is to first perform symbolic computations among Funsors, then lower to simple Funsor expressions, then compile the lowered Funsor expression to a straight-line program involving only backend ops, then optionally convert to Python code. The final program depends only on
funsor.ops
. This eliminates interpreter overhead, but does not eliminate op dispatch overhead.The immediate application is to speed up Pyro's
AutoGaussian
guide pyro-ppl/pyro#2929, but this will also require symbolic Gaussians #556.Tasks
Contraction
toBinary
compile_function()
for functions with multiple outputs (e.g. forward filter backward precondition).Tasks deferred to follow-up PRs
Lambda
orops.einsum
? E.g. to eliminate the bound variablei
:funsor.ops
internally (notopt_einsum
with direct backends), or makeopt_einsum
an op or sthfunsor.ops
internallytrace_function()
to support multiple outputsTested
compile_funsor()
trace_function()