Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a compiler and tracer, each creating OpPrograms #557

Merged
merged 17 commits into from
Oct 4, 2021
18 changes: 18 additions & 0 deletions docs/source/compiler.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
Compiler & Tracer
-----------------

.. automodule:: funsor.compiler
:members:
:show-inheritance:
:member-order: bysource

.. automodule:: funsor.ops.tracer
:members:
:show-inheritance:
:member-order: bysource

.. automodule:: funsor.ops.program
:members:
:show-inheritance:
:member-order: bysource
:special-members: __call__
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ Funsor is a tensor-like library for functions and distributions
distributions
minipyro
einsum
compiler

.. toctree::
:maxdepth: 1
Expand Down
2 changes: 2 additions & 0 deletions funsor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
affine,
approximations,
cnf,
compiler,
constant,
delta,
distribution,
Expand Down Expand Up @@ -78,6 +79,7 @@
"backward",
"bint",
"cnf",
"compiler",
"constant",
"delta",
"distribution",
Expand Down
142 changes: 142 additions & 0 deletions funsor/compiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import functools

import funsor

from .cnf import Contraction
from .ops.program import OpProgram, make_tuple
from .tensor import Tensor
from .terms import Binary, Funsor, Number, Tuple, Unary, Variable


def compile_funsor(expr: Funsor) -> OpProgram:
"""
Compiles a symbolic :class:`~funsor.terms.Funsor` to an
:class:`~funsor.ops.program.OpProgram` that runs on backend values.

Example::

# Create a lazy expression.
a = Variable("a", Reals[3, 3])
b = Variable("b", Reals[3])
x = Variable("x", Reals[3])
expr = a @ x + b

# Evaluate via Funsor substitution.
data = dict(a=randn(3, 3), b=randn(3), x=randn(3))
expected = expr(**data).data

# Alternatively evaluate via a program.
program = compile_funsor(expr)
actual = program(**data)
assert (acutal == expected).all()

:param Funsor expr: A funsor expression to evaluate.
:returns: An op program.
:rtype: ~funsor.ops.program.OpProgram
"""
assert isinstance(expr, Funsor)

# Lower and convert to A-normal form.
lowered_expr = lower(expr)
anf = list(funsor.interpreter.anf(lowered_expr))
ids = {}

# Collect constants (leaves).
constants = []
for f in anf:
if isinstance(f, (Number, Tensor)):
ids[f] = len(ids)
constants.append(f.data)

# Collect input variables (leaves).
inputs = []
for k, d in expr.inputs.items():
f = Variable(k, d)
ids[f] = len(ids)
inputs.append(k)

# Collect operations to be computed (internal nodes).
operations = []
for f in anf:
if f in ids:
continue # constant or free variable
ids[f] = len(ids)
if isinstance(f, Unary):
arg_ids = (ids[f.arg],)
operations.append((f.op, arg_ids))
elif isinstance(f, Binary):
arg_ids = (ids[f.lhs], ids[f.rhs])
operations.append((f.op, arg_ids))
elif isinstance(f, Tuple):
arg_ids = tuple(ids[arg] for arg in f.args)
operations.append((make_tuple, arg_ids))
elif isinstance(f, tuple):
continue # Skip from Tuple directly to its elements.
else:
raise NotImplementedError(type(f).__name__)

return OpProgram(constants, inputs, operations)


def lower(expr: Funsor) -> Funsor:
"""
Lower a funsor expression:
- eliminate bound variables
- convert Contraction to Binary

:param Funsor expr: An arbitrary funsor expression.
:returns: A lowered funsor expression.
:rtype: Funsor
"""
# FIXME should this be lazy? What about Lambda?
with funsor.interpretations.reflect:
return _lower(expr)


@functools.singledispatch
def _lower(x):
raise NotImplementedError(type(x).__name__)


@_lower.register(Number)
@_lower.register(Tensor)
@_lower.register(Variable)
def _lower_atom(x):
return x


@_lower.register(Tuple)
def _lower_tuple(x):
args = tuple(_lower(arg) for arg in x.args)
return Tuple(args)


@_lower.register(Unary)
def _lower_unary(x):
arg = _lower(x.arg)
return Unary(x.op, arg)


@_lower.register(Binary)
def _lower_binary(x):
lhs = _lower(x.lhs)
rhs = _lower(x.rhs)
return Binary(x.op, lhs, rhs)


@_lower.register(Contraction)
def _lower_contraction(x):
if x.reduced_vars:
raise NotImplementedError("TODO")

terms = [_lower(term) for term in x.terms]
bin_op = functools.partial(Binary, x.bin_op)
return functools.reduce(bin_op, terms)


__all__ = [
"lower",
]
4 changes: 3 additions & 1 deletion funsor/ops/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from . import array, builtin, op
from . import array, builtin, op, program, tracer
from .array import *
from .builtin import *
from .op import *
from .program import *
from .tracer import *
35 changes: 34 additions & 1 deletion funsor/ops/op.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import contextlib
import functools
import inspect
import math
import operator
import weakref
from collections import OrderedDict

from funsor.registry import PartialDispatcher
from funsor.util import methodof
Expand Down Expand Up @@ -60,6 +62,23 @@ def __call__(self, *args, **kwargs):
return self.fn(arg, *args, **kwargs)


_TRACE = None
_TRACE_FILTER_ARGS = None


@contextlib.contextmanager
def trace_ops(filter_args):
global _TRACE, _TRACE_FILTER_ARGS
assert _TRACE is None, "not reentrant"
try:
_TRACE = OrderedDict()
_TRACE_FILTER_ARGS = filter_args
yield _TRACE
finally:
_TRACE = None
_TRACE_FILTER_ARGS = None


class OpMeta(type):
"""
Metaclass for :class:`Op` classes.
Expand Down Expand Up @@ -159,6 +178,9 @@ def __str__(self):
return self.__name__

def __call__(self, *args, **kwargs):
global _TRACE, _TRACE_FILTER_ARGS
raw_args = args

# Normalize args, kwargs.
cls = type(self)
bound = cls.signature.bind_partial(*args, **kwargs)
Expand All @@ -170,7 +192,18 @@ def __call__(self, *args, **kwargs):

# Dispatch.
fn = cls.dispatcher.partial_call(*args[: cls.arity])
return fn(*args, **kwargs)
if _TRACE is None or not _TRACE_FILTER_ARGS(raw_args):
result = fn(*args, **kwargs)
else:
# Trace this op but avoid tracing internal ops.
try:
trace, _TRACE = _TRACE, None
result = fn(*args, **kwargs)
trace.setdefault(id(result), (result, self, raw_args))
finally:
_TRACE = trace

return result

def register(self, *pattern):
if len(pattern) != self.arity:
Expand Down
103 changes: 103 additions & 0 deletions funsor/ops/program.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from funsor.util import get_backend, set_backend


class OpProgram:
"""
Backend program for evaluating a symbolic funsor expression.

Programs depend on the funsor library only via ``funsor.ops`` and op
registrations; program evaluation does not involve funsor interpretation or
rewriting. Programs can be pickled and unpickled.

:param iterable expr: A list of built-in constants (leaves).
:param iterable inputs: A list of string names of program inputs (leaves).
:param iterable operations: A list of program operations defining
non-leaf nodes in the program dag. Each operations is a tuple ``(op,
arg_ids)`` where op is a funsor op and ``arg_ids`` is a tuple of
positions of values, starting from zero and counting: constants,
inputs, and operation outputs.
"""

def __init__(self, constants, inputs, operations):
super().__init__()
self.constants = tuple(constants)
self.inputs = tuple(inputs)
self.operations = tuple(operations)
self.backend = get_backend()

def __call__(self, **kwargs):
set_backend(self.backend)

# Initialize environment with constants.
env = list(self.constants)

# Read inputs from kwargs.
for name in self.inputs:
value = kwargs.pop(name, None)
if value is None:
raise ValueError(f"Missing kwarg: {repr(name)}")
env.append(value)
if kwargs:
raise ValueError(f"Unrecognized kwargs: {set(kwargs)}")

# Sequentially compute ops.
for op, arg_ids in self.operations:
args = tuple(env[i] for i in arg_ids)
value = op(*args)
env.append(value)

result = env[-1]
return result

def as_code(self, name="program"):
"""
Returns Python code text defining a straight-line function equivalent
to this program.

:param str name: Optional name for the function, defaults to "program".
:returns: A string defining a python function equivalent to this program.
:rtype: str
"""
lines = [
"# Automatically generated by funsor.compiler.FunsorProgram.as_code().",
"def {}({}):".format(name, ", ".join(self.inputs)),
" from funsor import set_backend, ops",
f" set_backend({repr(self.backend)})",
]
start = len(lines)

def let(body):
i = len(lines) - start
lines.append(f" v{i} = {body}")

for c in self.constants:
let(c)
for name in self.inputs:
let(name)
for op, arg_ids in self.operations:
op = _print_op(op)
args = ", ".join(f"v{arg_id}" for arg_id in arg_ids)
let(f"{op}({args},)")
lines.append(f" return v{len(lines) - start - 1}")
return "\n".join(lines)


def make_tuple(*args):
return args


def _print_op(op):
if op is make_tuple:
return ""
if op.defaults and op.defaults != type(op)().defaults:
args = ", ".join(map(str, op.defaults.values()))
return f"ops.{type(op).__name__}({args})"
return repr(op)


__all__ = [
"OpProgram",
]
Loading