diff --git a/oryx/core/interpreters/harvest.py b/oryx/core/interpreters/harvest.py index 3e1beca..7d4d6d6 100644 --- a/oryx/core/interpreters/harvest.py +++ b/oryx/core/interpreters/harvest.py @@ -142,6 +142,7 @@ def f(x): from __future__ import annotations import collections +import contextlib import dataclasses import functools from typing import Any, Callable, Dict, FrozenSet, Hashable, Iterable, List, Optional, Tuple, Union @@ -163,13 +164,10 @@ def f(x): from jax.interpreters import mlir from jax.interpreters import partial_eval as pe import jax.numpy as jnp -from oryx.core import primitive as prim -from oryx.core import trace_util __all__ = [ 'HarvestTrace', - 'HarvestTracer', 'call_and_reap', 'harvest', 'nest', @@ -251,7 +249,8 @@ def sow(value, *, tag: Hashable, name: str, mode: str = 'strict', key=None): """ if mode == 'cond_clobber': raise ValueError("For 'cond_clobber' mode, use `sow_cond`.'") - return _sow(value, tag=tag, name=name, mode=mode, key=key) + with jax_core.take_current_trace() as trace: + return _sow(trace, value, tag=tag, name=name, mode=mode, key=key) def sow_cond( @@ -288,18 +287,20 @@ def sow_cond( """ if mode != 'cond_clobber': raise ValueError("`sow_cond` only supports 'cond_clobber' mode.") - return _sow(value, tag=tag, name=name, mode=mode, key=key, pred=pred)[0] + with jax_core.take_current_trace() as trace: + return _sow(trace, value, tag=tag, name=name, + mode=mode, key=key, pred=pred)[0] -def _sow(value, *, tag, name, mode, key=None, pred=None): +def _sow(trace, value, *, tag, name, mode, key=None, pred=None): + del key assert (pred is not None) == (mode == 'cond_clobber') if pred is not None: value = value, pred - value = tree_util.tree_map(jax_core.raise_as_much_as_possible, value) - if key is not None: - value = prim.tie_in(key, value) flat_args, in_tree = tree_util.tree_flatten(value) - out_flat = sow_p.bind(*flat_args, name=name, tag=tag, mode=mode, tree=in_tree) + out_flat = sow_p.bind_with_trace( + trace, flat_args, + dict(name=name, tag=tag, mode=mode, tree=in_tree)) return tree_util.tree_unflatten(in_tree, out_flat) @@ -307,8 +308,7 @@ def _sow(value, *, tag, name, mode, key=None, pred=None): def _nest_impl(f, *args, **_): - with jax_core.new_sublevel(): - return f.call_wrapped(*args) + return f.call_wrapped(*args) nest_p.def_impl(_nest_impl) @@ -378,110 +378,56 @@ def wrapped(*args, **kwargs): class HarvestTrace(jax_core.Trace): """An evaluating trace that dispatches to a dynamic context.""" - def pure(self, val: Value) -> HarvestTracer: - return HarvestTracer(self, val) - - def sublift(self, tracer: HarvestTracer) -> HarvestTracer: - return self.pure(tracer.val) - - def lift(self, val: Value) -> HarvestTracer: - return self.pure(val) + def __init__(self, parent_trace, context): + self.parent_trace = parent_trace + self.context = context def process_primitive( - self, primitive: jax_core.Primitive, tracers: List[HarvestTracer], - params: Dict[str, Any]) -> Union[HarvestTracer, List[HarvestTracer]]: - context = trace_util.get_dynamic_context(self) - custom_rule = context.get_custom_rule(primitive) + self, primitive: jax_core.Primitive, vals: List[Any], + params: Dict[str, Any]) -> Union[Any, List[Any]]: + custom_rule = self.context.get_custom_rule(primitive) if custom_rule: - return custom_rule(self, *tracers, **params) - return self.default_process_primitive(primitive, tracers, params) + return custom_rule(self, *vals, **params) + return self.default_process_primitive(primitive, vals, params) def default_process_primitive( - self, primitive: jax_core.Primitive, tracers: List[HarvestTracer], - params: Dict[str, Any]) -> Union[HarvestTracer, List[HarvestTracer]]: - context = trace_util.get_dynamic_context(self) - vals = [t.val for t in tracers] + self, primitive: jax_core.Primitive, vals: List[Any], + params: Dict[str, Any]) -> Union[Any, List[Any]]: if primitive is sow_p: - outvals = context.process_sow(*vals, **params) - return jax_util.safe_map(self.pure, outvals) - outvals = primitive.bind(*vals, **params) + with jax_core.set_current_trace(self.parent_trace): + return self.context.process_sow(*vals, **params) + outvals = primitive.bind_with_trace(self.parent_trace, vals, params) if not primitive.multiple_results: outvals = [outvals] - out_tracers = jax_util.safe_map(self.pure, outvals) if primitive.multiple_results: - return out_tracers - return out_tracers[0] + return outvals + return outvals[0] def process_call(self, call_primitive: jax_core.Primitive, f: Any, - tracers: List[HarvestTracer], params: Dict[str, Any]): - context = trace_util.get_dynamic_context(self) + vals: List[Any], params: Dict[str, Any]): + context = self.context if call_primitive is nest_p: - return context.process_nest(self, f, *tracers, **params) + return context.process_nest(self, f, *vals, **params) return context.process_higher_order_primitive(self, call_primitive, f, - tracers, params, False) - - def post_process_call(self, call_primitive, out_tracers, params): - vals = tuple(t.val for t in out_tracers) - master = self.main - - def todo(x): - trace = HarvestTrace(master, jax_core.cur_sublevel()) - return jax_util.safe_map(functools.partial(HarvestTracer, trace), x) - - return vals, todo + vals, params, False) def process_map(self, call_primitive: jax_core.Primitive, f: Any, - tracers: List[HarvestTracer], params: Dict[str, Any]): - context = trace_util.get_dynamic_context(self) - return context.process_higher_order_primitive(self, call_primitive, f, - tracers, params, True) - - post_process_map = post_process_call + vals: List[Any], params: Dict[str, Any]): + return self.context.process_higher_order_primitive( + self, call_primitive, f, vals, params, True) - def process_custom_jvp_call(self, primitive, fun, jvp, tracers, *, + def process_custom_jvp_call(self, primitive, fun, jvp, vals, *, symbolic_zeros): - context = trace_util.get_dynamic_context(self) - return context.process_custom_jvp_call(self, primitive, fun, jvp, tracers, - symbolic_zeros=symbolic_zeros) + return self.context.process_custom_jvp_call( + self, primitive, fun, jvp, vals, symbolic_zeros=symbolic_zeros) - def process_shard_map(self, primitive, f, tracers, **params): - out_flat = primitive.bind(f, *[t.val for t in tracers], **params) - out_tracers = map(self.pure, out_flat) - return out_tracers - - def post_process_custom_jvp_call(self, out_tracers, jvp_was_run): - context = trace_util.get_dynamic_context(self) - return context.post_process_custom_jvp_call(self, out_tracers, jvp_was_run) + def process_shard_map(self, primitive, f, vals, **params): + return primitive.bind_with_trace(self.parent_trace, (f, *vals), params) - def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, + def process_custom_vjp_call(self, primitive, fun, fwd, bwd, vals, out_trees, symbolic_zeros): - context = trace_util.get_dynamic_context(self) - return context.process_custom_vjp_call(self, primitive, fun, fwd, bwd, - tracers, out_trees, symbolic_zeros) - - def post_process_custom_vjp_call(self, out_tracers, params): - context = trace_util.get_dynamic_context(self) - return context.post_process_custom_vjp_call(self, out_tracers, params) - - def post_process_custom_vjp_call_fwd(self, out_tracers, out_trees): - context = trace_util.get_dynamic_context(self) - return context.post_process_custom_vjp_call_fwd(self, out_tracers, - out_trees) - - -class HarvestTracer(jax_core.Tracer): - """A `HarvestTracer` just encapsulates a single value.""" - - def __init__(self, trace: 'HarvestTrace', val: Value): - self._trace = trace - self.val = val - - @property - def aval(self): - return jax_core.raise_to_shaped(jax_core.get_aval(self.val)) - - def full_lower(self): - return self + return self.context.process_custom_vjp_call( + self, primitive, fun, fwd, bwd, vals, out_trees, symbolic_zeros) @dataclasses.dataclass(frozen=True) @@ -519,32 +465,23 @@ def get_custom_rule(self, primitive): def handle_sow(self, *values, name, tag, mode, tree): raise NotImplementedError - def process_nest(self, trace, f, *tracers, scope, name): + def process_nest(self, trace, f, *vals, scope, name): raise NotImplementedError def process_higher_order_primitive(self, trace: HarvestTrace, call_primitive: jax_core.Primitive, f: Any, - tracers: List[HarvestTracer], + vals: List[Any], params: Dict[str, Any], is_map: bool): raise NotImplementedError - def process_custom_jvp_call(self, trace, primitive, fun, jvp, tracers, *, + def process_custom_jvp_call(self, trace, primitive, fun, jvp, vals, *, symbolic_zeros): raise NotImplementedError - def post_process_custom_jvp_call(self, trace, out_tracers, jvp_was_run): - raise NotImplementedError - - def process_custom_vjp_call(self, trace, primitive, fun, fwd, bwd, tracers, + def process_custom_vjp_call(self, trace, primitive, fun, fwd, bwd, vals, out_trees, symbolic_zeros): raise NotImplementedError - def post_process_custom_vjp_call(self, trace, out_tracers, params): - raise NotImplementedError - - def post_process_custom_vjp_call_fwd(self, trace, out_tracers, out_trees): - raise NotImplementedError - reap_custom_rules = {} @@ -589,12 +526,11 @@ def handle_sow(self, *values, name, tag, tree, mode): self.reaps[name] = Reap(vals, pred, metadata) return values - def reap_higher_order_primitive(self, trace, call_primitive, f, tracers, + def reap_higher_order_primitive(self, trace, call_primitive, f, vals, params, is_map): """Wraps the inner function with a reap trace.""" name = jax_util.wrap_name(params.pop('name', f.__name__), 'reap') - vals = [t.val for t in tracers] - f, aux = reap_eval(f, trace, self.settings) + f, aux = reap_eval(f, self.settings) if is_map: out_axes_thunk = params['out_axes_thunk'] @@ -607,17 +543,15 @@ def new_out_axes_thunk(): return (0,) * out_tree.num_leaves params = dict(params, out_axes_thunk=new_out_axes_thunk) - out_flat = call_primitive.bind(f, *vals, name=name, **params) + out_flat = call_primitive.bind_with_trace( + trace.parent_trace, (f, *vals), dict(params, name=name)) out_tree, metadata = aux() out_vals, reaps, preds = tree_util.tree_unflatten(out_tree, out_flat) - out_tracers = jax_util.safe_map(trace.pure, out_vals) - reap_tracers = tree_util.tree_map(trace.pure, reaps) - pred_tracers = tree_util.tree_map(trace.pure, preds) - return out_tracers, reap_tracers, pred_tracers, metadata + return out_vals, reaps, preds, metadata - def process_nest(self, trace, f, *tracers, scope, name, **params): + def process_nest(self, trace, f, *vals, scope, name, **params): out_tracers, reap_tracers, _, _ = self.reap_higher_order_primitive( - trace, nest_p, f, tracers, dict(params, name=name, scope=scope), False) + trace, nest_p, f, vals, dict(params, name=name, scope=scope), False) tag = self.settings.tag if reap_tracers: flat_reap_tracers, reap_tree = tree_util.tree_flatten(reap_tracers) @@ -626,11 +560,11 @@ def process_nest(self, trace, f, *tracers, scope, name, **params): dict(name=scope, tag=tag, tree=reap_tree, mode='strict')) return out_tracers - def process_higher_order_primitive(self, trace, call_primitive, f, tracers, + def process_higher_order_primitive(self, trace, call_primitive, f, vals, params, is_map): out_tracers, reap_tracers, pred_tracers, metadata = ( self.reap_higher_order_primitive( - trace, call_primitive, f, tracers, params, is_map + trace, call_primitive, f, vals, params, is_map ) ) tag = self.settings.tag @@ -643,113 +577,77 @@ def process_higher_order_primitive(self, trace, call_primitive, f, tracers, dict(name=k, tag=tag, tree=reap_tree, mode=metadata[k]['mode'])) return out_tracers - def process_custom_jvp_call(self, trace, primitive, fun, jvp, tracers, *, + def process_custom_jvp_call(self, trace, primitive, fun, jvp, vals, *, symbolic_zeros): - context = trace_util.get_dynamic_context(trace) - vals_in = [t.val for t in tracers] - fun, aux1 = reap_eval(fun, trace, context.settings) + fun, aux1 = reap_eval(fun, self.settings) @lu.transformation_with_aux - def _jvp_subtrace(main, *args): - trace = main.with_cur_sublevel() - in_tracers = jax_util.safe_map(trace.pure, args) - outs = yield in_tracers, {} - out_tracers = jax_util.safe_map(trace.full_raise, outs) - yield out_tracers, (None, None) - - jvp, aux2 = _jvp_subtrace(jvp, trace.main) - out_flat = primitive.bind(fun, jvp, *vals_in, symbolic_zeros=symbolic_zeros) + def _jvp_subtrace(context, *args): + with harvest_trace(context): + outs = yield args, {} + yield outs, (None, None) + + jvp, aux2 = _jvp_subtrace(jvp, self) + out_flat = primitive.bind_with_trace( + trace.parent_trace, (fun, jvp, *vals), + dict(symbolic_zeros=symbolic_zeros)) fst, (out_tree, metadata) = lu.merge_linear_aux(aux1, aux2) if fst: out, reaps, preds = tree_util.tree_unflatten(out_tree, out_flat) - out_tracers, reap_tracers, pred_tracers = tree_util.tree_map( - trace.pure, (out, reaps, preds) - ) - tag = context.settings.tag - for k, v in reap_tracers.items(): + tag = self.settings.tag + for k, v in reaps.items(): if metadata[k]['mode'] == 'cond_clobber': - v = (v, pred_tracers[k]) + v = (v, preds[k]) flat_reap_tracers, reap_tree = tree_util.tree_flatten(v) trace.process_primitive( sow_p, flat_reap_tracers, dict(name=k, tag=tag, tree=reap_tree, mode=metadata[k]['mode'])) else: - out_tracers = jax_util.safe_map(trace.pure, out_flat) - return out_tracers - - def post_process_custom_jvp_call(self, trace, out_tracers, jvp_was_run): - del jvp_was_run - vals = [t.val for t in out_tracers] - return vals, lambda vals: vals + out = out_flat + return out - def process_custom_vjp_call(self, trace, primitive, fun, fwd, bwd, tracers, + def process_custom_vjp_call(self, trace, primitive, fun, fwd, bwd, vals, out_trees, symbolic_zeros): - context = trace_util.get_dynamic_context(trace) - vals_in = [t.val for t in tracers] - fun, aux1 = reap_eval(fun, trace, context.settings) + fun, aux1 = reap_eval(fun, self.settings) @lu.transformation_with_aux - def _fwd_subtrace(main, *args): - trace = main.with_cur_sublevel() - in_tracers = jax_util.safe_map(trace.pure, args) - outs = yield in_tracers, {} - out_tracers = jax_util.safe_map(trace.full_raise, outs) - yield out_tracers, (None, None) - - fwd, aux2 = _fwd_subtrace(fwd, trace.main) - bwd_ = reap_function(lu.wrap_init(bwd), trace.main, context.settings, True) - bwd = reap_wrapper_drop_aux(bwd_, trace).call_wrapped - out_flat = primitive.bind(fun, fwd, bwd, *vals_in, out_trees=out_trees, - symbolic_zeros=symbolic_zeros) + def _fwd_subtrace(context, *args): + with harvest_trace(context): + outs = yield args, {} + yield outs, (None, None) + + fwd, aux2 = _fwd_subtrace(fwd, self) + bwd_ = reap_function(lu.wrap_init(bwd), self.settings, True) + bwd = reap_wrapper_drop_aux(bwd_).call_wrapped + out_flat = primitive.bind_with_trace( + trace.parent_trace, (fun, fwd, bwd, *vals), + dict(out_trees=out_trees, symbolic_zeros=symbolic_zeros)) fst, (out_tree, metadata) = lu.merge_linear_aux(aux1, aux2) if fst: out, reaps, preds = tree_util.tree_unflatten(out_tree, out_flat) - out_tracers, reap_tracers, pred_tracers = tree_util.tree_map( - trace.pure, (out, reaps, preds) - ) - tag = context.settings.tag - for k, v in reap_tracers.items(): + tag = self.settings.tag + for k, v in reaps.items(): if metadata[k]['mode'] == 'cond_clobber': - v = (v, pred_tracers[k]) + v = (v, preds[k]) flat_reap_tracers, reap_tree = tree_util.tree_flatten(v) trace.process_primitive( sow_p, flat_reap_tracers, dict(name=k, tag=tag, tree=reap_tree, mode=metadata[k]['mode'])) else: - out_tracers = jax_util.safe_map(trace.pure, out_flat) - return out_tracers - - def post_process_custom_vjp_call(self, trace, out_tracers, params): - del params - vals = [t.val for t in out_tracers] - return vals, lambda vals: vals - - def post_process_custom_vjp_call_fwd(self, trace, out_tracers, out_trees): - vals = [t.val for t in out_tracers] - todo = lambda vals: vals - bwd_transform = lambda bwd: bwd - return vals, todo, bwd_transform + out = out_flat + return out @lu.transformation -def reap_function(main: jax_core.MainTrace, settings: HarvestSettings, +def reap_function(settings: HarvestSettings, return_metadata: bool, args: Iterable[Any]): """A function transformation that returns reap values and predicates.""" - trace = HarvestTrace(main, jax_core.cur_sublevel()) - in_tracers = jax_util.safe_map(trace.pure, args) context = ReapContext(settings, {}) - with trace_util.new_dynamic_context(main, context): - ans = yield in_tracers, {} - out_tracers = jax_util.safe_map(trace.full_raise, ans) - reap_tracers = tree_util.tree_map( - lambda x: tree_util.tree_map(trace.full_raise, x.value), context.reaps) - pred_tracers = tree_util.tree_map( - lambda x: trace.full_raise(x.pred), context.reaps) + with harvest_trace(context): + out_values = yield args, {} + reap_values = tree_util.tree_map(lambda x: x.value, context.reaps) + pred_values = tree_util.tree_map(lambda x: x.pred, context.reaps) reap_metadata = tree_util.tree_map(lambda x: x.metadata, context.reaps) - del main - out_values, reap_values, pred_values = tree_util.tree_map( - lambda x: x.val, (out_tracers, reap_tracers, pred_tracers) - ) if return_metadata: out = (out_values, reap_values, pred_values, reap_metadata) else: @@ -758,23 +656,21 @@ def reap_function(main: jax_core.MainTrace, settings: HarvestSettings, def reap_eval( - f: lu.WrappedFun, trace: HarvestTrace, + f: lu.WrappedFun, settings: HarvestSettings) -> Tuple[lu.WrappedFun, Callable[[], Any]]: - f = reap_function(f, trace.main, settings, True) - return reap_wrapper(f, trace) + f = reap_function(f, settings, True) + return reap_wrapper(f) @lu.transformation_with_aux -def reap_wrapper(trace: HarvestTrace, *args): - del trace +def reap_wrapper(*args): out, reaps, preds, metadata = yield (args,), {} out_flat, out_tree = tree_util.tree_flatten((out, reaps, preds)) yield out_flat, (out_tree, metadata) @lu.transformation -def reap_wrapper_drop_aux(trace: HarvestTrace, *args): - del trace +def reap_wrapper_drop_aux(*args): out, reaps, preds, _ = yield (args,), {} out_flat, _ = tree_util.tree_flatten((out, reaps, preds)) yield out_flat @@ -839,10 +735,8 @@ def wrapped(*args, **kwargs): fun = lu.wrap_init(f, kwargs) flat_args, in_tree = tree_util.tree_flatten(args) flat_fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree) - with jax_core.new_main(HarvestTrace) as main: - flat_fun = reap_function(flat_fun, main, settings, False) - out_flat, reaps, preds = flat_fun.call_wrapped(flat_args) - del main + flat_fun = reap_function(flat_fun, settings, False) + out_flat, reaps, preds = flat_fun.call_wrapped(flat_args) return tree_util.tree_unflatten(out_tree(), out_flat), reaps, preds return wrapped @@ -890,19 +784,19 @@ def _reap_metadata_wrapper(*args): def _get_harvest_metadata(closed_jaxpr, settings, *args): """Probes a jaxpr for metadata like its sown values.""" fun = lu.wrap_init(jax_core.jaxpr_as_fun(closed_jaxpr)) - with jax_core.new_main(HarvestTrace) as main: - settings = HarvestSettings(settings.tag, settings.blocklist, - settings.allowlist, True) - fun = reap_function(fun, main, settings, True) - fun, aux = _reap_metadata_wrapper(fun) - flat_args, in_tree = tree_util.tree_flatten(args) - flat_fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree) - in_avals = jax_util.safe_map( - lambda a: jax_core.raise_to_shaped(jax_core.get_aval(a)), - flat_args) - pe.trace_to_jaxpr_final(flat_fun, in_avals) - metadata = aux() - out_tree() + + settings = HarvestSettings(settings.tag, settings.blocklist, + settings.allowlist, True) + fun = reap_function(fun, settings, True) + fun, aux = _reap_metadata_wrapper(fun) + flat_args, in_tree = tree_util.tree_flatten(args) + flat_fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree) + in_avals = jax_util.safe_map( + lambda a: jax_core.raise_to_shaped(jax_core.get_aval(a)), + flat_args) + pe.trace_to_jaxpr_dynamic(flat_fun, in_avals) + metadata = aux() + out_tree() return metadata @@ -917,21 +811,19 @@ def _update_clobber_carry(carry_reaps, carry_preds, name, val, preds, mode): carry_reaps[name] = val -def _reap_scan_rule(trace: HarvestTrace, *tracers, length, reverse, jaxpr, +def _reap_scan_rule(trace: HarvestTrace, *vals, length, reverse, jaxpr, num_consts, num_carry, linear, unroll, _split_transpose): """Reaps the body of a scan to pull out `clobber` and `append` sows.""" - const_tracers, carry_tracers, xs_tracers = jax_util.split_list( - tracers, [num_consts, num_carry]) + const_vals, carry_vals, xs_vals = jax_util.split_list( + vals, [num_consts, num_carry]) _, carry_avals, xs_avals = tree_util.tree_map( - lambda x: x.aval, (const_tracers, carry_tracers, xs_tracers)) - const_vals, carry_vals, xs_vals = tree_util.tree_map( - lambda x: x.val, (const_tracers, carry_tracers, xs_tracers)) - context = trace_util.get_dynamic_context(trace) - settings = context.settings - x_tracers = [t[0] if hasattr(t, '_getitem') else t for t in xs_tracers] - x_avals = [t.aval for t in x_tracers] - x_vals = [t.val for t in x_tracers] + lambda x: jax_core.get_aval(x), (const_vals, carry_vals, xs_vals)) # pylint: disable=unnecessary-lambda + settings = trace.context.settings + with jax_core.set_current_trace(trace.parent_trace): + x_vals = [t[0] if hasattr(jax_core.get_aval(t), '_getitem') else t + for t in xs_vals] + x_avals = [jax_core.get_aval(t) for t in x_vals] metadata = _get_harvest_metadata(jaxpr, settings, *(const_vals + carry_vals + x_vals)) @@ -982,32 +874,34 @@ def new_body(carry, x): new_body_jaxpr, consts, out_tree = lcf._initial_style_jaxpr( # pylint: disable=protected-access new_body, reap_carry_in_tree, tuple(carry_avals + reap_carry_flat_avals + x_avals)) - dummy_reap_carry_vals = tree_util.tree_map( - lambda x: jnp.zeros(x.shape, x.dtype), - reap_carry_flat_avals, - ) - out = lax.scan_p.bind( - *(consts + carry_vals + dummy_reap_carry_vals + xs_vals), - reverse=reverse, - length=length, - jaxpr=new_body_jaxpr, - num_consts=len(consts), - num_carry=len(carry_vals + dummy_reap_carry_vals), - linear=(linear[:len(consts)] + (False,) * len(dummy_reap_carry_vals) + - linear[len(consts):]), - unroll=unroll, - _split_transpose=_split_transpose) + + with jax_core.set_current_trace(trace.parent_trace): + dummy_reap_carry_vals = tree_util.tree_map( + lambda x: jnp.zeros(x.shape, x.dtype), + reap_carry_flat_avals, + ) + out = lax.scan_p.bind_with_trace( + trace.parent_trace, + (consts + carry_vals + dummy_reap_carry_vals + xs_vals), + dict(reverse=reverse, + length=length, + jaxpr=new_body_jaxpr, + num_consts=len(consts), + num_carry=len(carry_vals + dummy_reap_carry_vals), + linear=( + linear[:len(consts)] + (False,) * len(dummy_reap_carry_vals) + + linear[len(consts):]), + unroll=unroll, + _split_transpose=_split_transpose)) (carry_out, carry_reaps, carry_preds), (ys, ys_reaps) = ( tree_util.tree_unflatten(out_tree, out) ) - (carry_out, carry_reaps, carry_preds), (ys, ys_reaps) = tree_util.tree_map( - trace.pure, ((carry_out, carry_reaps, carry_preds), (ys, ys_reaps)) - ) for k, v in carry_reaps.items(): mode = metadata[k]['mode'] - _sow(v, tag=settings.tag, mode=mode, name=k, pred=carry_preds[k]) + _sow(trace, v, tag=settings.tag, mode=mode, name=k, pred=carry_preds[k]) for k, v in ys_reaps.items(): - sow(v, tag=settings.tag, mode=metadata[k]['mode'], name=k) + mode = metadata[k]['mode'] + _sow(trace, v, tag=settings.tag, mode=mode, name=k) return carry_out + ys @@ -1017,16 +911,13 @@ def new_body(carry, x): def _reap_while_rule(trace: HarvestTrace, *tracers, cond_jaxpr, body_jaxpr, cond_nconsts, body_nconsts): """Reaps the body of a while loop to get the reaps of `clobber` sows.""" - cond_const_tracers, body_const_tracers, init_tracers = jax_util.split_list( + cond_const_vals, body_const_vals, init_vals = jax_util.split_list( tracers, [cond_nconsts, body_nconsts]) - _, init_avals = tree_util.tree_map(lambda x: x.aval, - (body_const_tracers, init_tracers)) - cond_const_vals, body_const_vals, init_vals = tree_util.tree_map( - lambda x: x.val, (cond_const_tracers, body_const_tracers, init_tracers)) - context = trace_util.get_dynamic_context(trace) - settings = context.settings + _, init_avals = tree_util.tree_map(lambda x: jax_core.get_aval(x), # pylint: disable=unnecessary-lambda + (body_const_vals, init_vals)) + settings = trace.context.settings body_metadata = _get_harvest_metadata(body_jaxpr, settings, - *(body_const_tracers + init_tracers)) + *(body_const_vals + init_vals)) reap_avals = {} cond_avals = collections.defaultdict(lambda: None) for k, meta in body_metadata.items(): @@ -1070,17 +961,17 @@ def new_body(carry, carry_reaps, carry_preds): dummy_reap_vals = tree_util.tree_map(lambda x: jnp.zeros(x.shape, x.dtype), (reap_avals, cond_avals)) new_in_vals = tree_util.tree_leaves((init_vals, dummy_reap_vals)) - out = lax.while_p.bind( - *(cond_consts + body_consts + new_in_vals), - cond_nconsts=len(cond_consts), - body_nconsts=len(body_consts), - cond_jaxpr=new_cond_jaxpr, - body_jaxpr=new_body_jaxpr) - out = jax_util.safe_map(trace.pure, out) + out = lax.while_p.bind_with_trace( + trace.parent_trace, + (cond_consts + body_consts + new_in_vals), + dict(cond_nconsts=len(cond_consts), + body_nconsts=len(body_consts), + cond_jaxpr=new_cond_jaxpr, + body_jaxpr=new_body_jaxpr)) out, reaps, preds = tree_util.tree_unflatten(out_tree, out) for k, v in reaps.items(): mode = body_metadata[k]['mode'] - _sow(v, name=k, tag=settings.tag, mode=mode, pred=preds[k]) + _sow(trace, v, name=k, tag=settings.tag, mode=mode, pred=preds[k]) return out @@ -1105,20 +996,17 @@ def _check_branch_metadata(branch_metadatas): def _reap_cond_rule(trace, *tracers, branches, linear=None): """Reaps each path of the `cond`.""" - index_tracer, ops_tracers = tracers[0], tracers[1:] - index_val, ops_vals = tree_util.tree_map(lambda x: x.val, - (index_tracer, ops_tracers)) - _, ops_avals = tree_util.tree_map(lambda x: x.aval, - (index_tracer, ops_tracers)) - context = trace_util.get_dynamic_context(trace) - settings = context.settings + index_val, ops_vals = tracers[0], tracers[1:] + _, ops_avals = tree_util.tree_map(lambda x: jax_core.get_aval(x), # pylint: disable=unnecessary-lambda + (index_val, ops_vals)) + settings = trace.context.settings reap_settings = dict( tag=settings.tag, allowlist=settings.allowlist, blocklist=settings.blocklist, exclusive=settings.exclusive) branch_metadatas = tuple( - _get_harvest_metadata(branch, settings, *ops_tracers) + _get_harvest_metadata(branch, settings, *ops_vals) for branch in branches) _check_branch_metadata(branch_metadatas) branch_funs = tuple(map(jax_core.jaxpr_as_fun, branches)) @@ -1129,57 +1017,53 @@ def _reap_cond_rule(trace, *tracers, branches, linear=None): lcf._initial_style_jaxprs_with_common_consts( # pylint: disable=protected-access reaped_branches, in_tree, ops_avals, lax.cond_p.name)) if linear is None: - out = lax.cond_p.bind( - index_val, - *(tuple(consts) + ops_vals), - branches=tuple(new_branch_jaxprs)) + out = lax.cond_p.bind_with_trace( + trace.parent_trace, + (index_val, *consts, *ops_vals), + dict(branches=tuple(new_branch_jaxprs))) else: - out = lax.cond_p.bind( - index_val, - *(tuple(consts) + ops_vals), - branches=tuple(new_branch_jaxprs), - linear=(False,) * len(tuple(consts) + linear)) - out = jax_util.safe_map(trace.pure, out) + out = lax.cond_p.bind_with_trace( + trace.parent_trace, + (index_val, *consts, *ops_vals), + dict(branches=tuple(new_branch_jaxprs), + linear=(False,) * len(tuple(consts) + linear))) out, reaps, preds = tree_util.tree_unflatten(out_trees[0], out) for k, v in reaps.items(): mode = branch_metadatas[0][k]['mode'] - _sow(v, name=k, tag=settings.tag, mode=mode, pred=preds[k]) + _sow(trace, v, name=k, tag=settings.tag, mode=mode, pred=preds[k]) return out reap_custom_rules[lcf.cond_p] = _reap_cond_rule -def _reap_checkpoint_rule(trace, *tracers, jaxpr, policy, prevent_cse, +def _reap_checkpoint_rule(trace, *invals, jaxpr, policy, prevent_cse, differentiated): """Reap checkpoint rule.""" - invals = [t.val for t in tracers] - context = trace_util.get_dynamic_context(trace) - settings = context.settings + settings = trace.context.settings reap_settings = dict( tag=settings.tag, allowlist=settings.allowlist, blocklist=settings.blocklist, exclusive=settings.exclusive) closed_jaxpr = jax_core.ClosedJaxpr(jaxpr, ()) - reap_metadata = _get_harvest_metadata(closed_jaxpr, settings, *tracers) + reap_metadata = _get_harvest_metadata(closed_jaxpr, settings, *invals) remat_fun = jax_core.jaxpr_as_fun(closed_jaxpr) reaped_remat_fun = _call_and_reap(remat_fun, **reap_settings) reap_jaxpr, consts, out_tree = lcf._initial_style_jaxpr( # pylint: disable=protected-access reaped_remat_fun, tree_util.tree_structure(invals), - tuple(t.aval for t in tracers)) - outvals = ad_checkpoint.remat_p.bind( - *consts, - *invals, - jaxpr=reap_jaxpr.jaxpr, - policy=policy, - prevent_cse=prevent_cse, - differentiated=differentiated) - outvals = jax_util.safe_map(trace.pure, outvals) + tuple(jax_core.get_aval(t) for t in invals)) + outvals = ad_checkpoint.remat_p.bind_with_trace( + trace.parent_trace, + (*consts, *invals), + dict(jaxpr=reap_jaxpr.jaxpr, + policy=policy, + prevent_cse=prevent_cse, + differentiated=differentiated)) out, reaps, preds = tree_util.tree_unflatten(out_tree, outvals) for k, v in reaps.items(): mode = reap_metadata[k]['mode'] - _sow(v, name=k, tag=settings.tag, mode=mode, pred=preds[k]) + _sow(trace, v, name=k, tag=settings.tag, mode=mode, pred=preds[k]) return out @@ -1208,7 +1092,7 @@ def _calc_extra_inps(num_consts, params): return in_shardings, donated_invars, in_layouts -def _reap_pjit_rule(trace, *tracers, **params): +def _reap_pjit_rule(trace, *invals, **params): """Reap pjit rule.""" if params['in_shardings'] and not any( isinstance(i, sharding_impls.UnspecifiedValue) for i in params['in_shardings'] @@ -1225,23 +1109,21 @@ def _reap_pjit_rule(trace, *tracers, **params): f'specified. Got {params["out_shardings"]}' ) - invals = [t.val for t in tracers] - context = trace_util.get_dynamic_context(trace) - settings = context.settings + settings = trace.context.settings reap_settings = dict( tag=settings.tag, allowlist=settings.allowlist, blocklist=settings.blocklist, exclusive=settings.exclusive) closed_jaxpr = params['jaxpr'] - reap_metadata = _get_harvest_metadata(closed_jaxpr, settings, *tracers) + reap_metadata = _get_harvest_metadata(closed_jaxpr, settings, *invals) pjit_fun = jax_core.jaxpr_as_fun(closed_jaxpr) reaped_pjit_fun = lu.wrap_init(_call_and_reap(pjit_fun, **reap_settings)) in_tree = tree_util.tree_structure(invals) flat_fun, out_tree = api_util.flatten_fun_nokwargs(reaped_pjit_fun, in_tree) reap_jaxpr, final_consts, out_avals = _oryx_pjit_jaxpr( - flat_fun, tuple(t.aval for t in tracers)) + flat_fun, tuple(jax_core.get_aval(t) for t in invals)) in_shardings, donated_invars, in_layouts = _calc_extra_inps( len(final_consts), params) @@ -1254,13 +1136,13 @@ def _reap_pjit_rule(trace, *tracers, **params): 'in_layouts': in_layouts, 'out_layouts': (None,) * len(out_avals) } - outvals = pjit.pjit_p.bind(*final_consts, *invals, **new_params) + outvals = pjit.pjit_p.bind_with_trace( + trace.parent_trace, (*final_consts, *invals), new_params) - outvals = jax_util.safe_map(trace.pure, outvals) out, reaps, preds = tree_util.tree_unflatten(out_tree(), outvals) for k, v in reaps.items(): mode = reap_metadata[k]['mode'] - _sow(v, name=k, tag=settings.tag, mode=mode, pred=preds[k]) + _sow(trace, v, name=k, tag=settings.tag, mode=mode, pred=preds[k]) return out @@ -1299,13 +1181,11 @@ def process_nest(self, trace, f, *tracers, scope, name, **params): return self.process_higher_order_primitive( trace, nest_p, f, tracers, dict(params, name=name, scope=scope), False) - def process_higher_order_primitive(self, trace, call_primitive, f, tracers, + def process_higher_order_primitive(self, trace, call_primitive, f, vals, params, is_map): del is_map name = jax_util.wrap_name(params.pop('name', f.__name__), 'reap') - context = trace_util.get_dynamic_context(trace) - vals = [t.val for t in tracers] - plants = context.plants + plants = trace.context.plants if 'in_axes' in params: # TODO(b/199459308): figure out if invars are mapped or unmapped params = dict( @@ -1319,79 +1199,61 @@ def process_higher_order_primitive(self, trace, call_primitive, f, tracers, elif call_primitive is nest_p: plants = plants.get(params['scope'], {}) all_vals, all_tree = tree_util.tree_flatten((plants, vals)) - f = plant_eval(f, trace, self.settings, all_tree) - out_vals = call_primitive.bind(f, *all_vals, name=name, **params) - return jax_util.safe_map(trace.pure, out_vals) + f = plant_eval(f, self.settings, all_tree) + return call_primitive.bind_with_trace( + trace.parent_trace, (f, *all_vals), dict(name=name, **params)) - def process_custom_jvp_call(self, trace, primitive, fun, jvp, tracers, *, + def process_custom_jvp_call(self, trace, primitive, fun, jvp, vals, *, symbolic_zeros): - vals_in = [t.val for t in tracers] - - @lu.transformation - def _subtrace(main: jax_core.MainTrace, *args: Iterable[Any]): - trace = main.with_cur_sublevel() - in_tracers = jax_util.safe_map(trace.pure, args) - outs = yield in_tracers, {} - yield jax_util.safe_map(trace.full_raise, outs) - - fun = _subtrace(fun, trace.main) - jvp = _subtrace(jvp, trace.main) - out_flat = primitive.bind(fun, jvp, *vals_in, symbolic_zeros=symbolic_zeros) - out_tracers = jax_util.safe_map(trace.pure, out_flat) - return out_tracers - - def post_process_custom_jvp_call(self, trace, out_tracers, jvp_was_run): - vals = [t.val for t in out_tracers] - return vals, lambda vals: vals - - def process_custom_vjp_call(self, trace, primitive, fun, fwd, bwd, tracers, + fun = _subtrace(fun, trace.context) + jvp = _subtrace(jvp, trace.context) + out_flat = primitive.bind_with_trace( + trace.parent_trace, + (fun, jvp) + tuple(vals), + dict(symbolic_zeros=symbolic_zeros)) + return out_flat + + def process_custom_vjp_call(self, trace, primitive, fun, fwd, bwd, vals, out_trees, symbolic_zeros): - vals_in = [t.val for t in tracers] + fun = _subtrace(fun, trace.context) + fwd = _subtrace(fwd, trace.context) + # We don't need to subtrace the `bwd` since it's triggered in another trace. + out_flat = primitive.bind_with_trace( + trace.parent_trace, + (fun, fwd, bwd) + tuple(vals), + dict(out_trees=out_trees, symbolic_zeros=symbolic_zeros)) + return out_flat - @lu.transformation - def _subtrace(main: jax_core.MainTrace, *args: Iterable[Any]): - trace = main.with_cur_sublevel() - in_tracers = jax_util.safe_map(trace.pure, args) - outs = yield in_tracers, {} - yield jax_util.safe_map(trace.full_raise, outs) - fun = _subtrace(fun, trace.main) - fwd = _subtrace(fwd, trace.main) - # We don't need to subtrace the `bwd` since it's triggered in another trace. - out_flat = primitive.bind(fun, fwd, bwd, *vals_in, out_trees=out_trees, - symbolic_zeros=symbolic_zeros) - return jax_util.safe_map(trace.pure, out_flat) +@contextlib.contextmanager +def harvest_trace(context: HarvestContext): + with jax_core.take_current_trace() as parent_trace: + trace = HarvestTrace(parent_trace, context) + with jax_core.set_current_trace(trace): + yield - def post_process_custom_vjp_call(self, trace, out_tracers, params): - del params - vals = [t.val for t in out_tracers] - return vals, lambda vals: vals - def post_process_custom_vjp_call_fwd(self, trace, out_tracers, out_trees): - vals = [t.val for t in out_tracers] - todo = lambda vals: vals - bwd_transform = lambda bwd: bwd - return vals, todo, bwd_transform +@lu.transformation +def _subtrace(context: HarvestContext, *args: Iterable[Any]): + with harvest_trace(context): + outs = yield args, {} + yield outs @lu.transformation -def plant_function(main: jax_core.MainTrace, settings: HarvestSettings, +def plant_function(settings: HarvestSettings, in_tree: Any, args: Iterable[Any]): """A function transformation that injects values in place of sows.""" - trace = HarvestTrace(main, jax_core.cur_sublevel()) plants, args = tree_util.tree_unflatten(in_tree, args) - args = jax_util.safe_map(trace.pure, args) context = PlantContext(settings, plants) - with trace_util.new_dynamic_context(main, context): + with harvest_trace(context): ans = yield args, {} - out_tracers = jax_util.safe_map(trace.full_raise, ans) - del main - yield [t.val for t in out_tracers] + yield ans -def plant_eval(f: lu.WrappedFun, trace: HarvestTrace, settings: HarvestSettings, +def plant_eval(f: lu.WrappedFun, settings: HarvestSettings, all_tree: Any) -> Tuple[lu.WrappedFun, Callable[[], Any]]: - f = plant_function(f, trace.main, settings, all_tree) + f = plant_function(f, settings, all_tree) return plant_wrapper(f) @@ -1435,10 +1297,8 @@ def wrapped(plants, *args, **kwargs): flat_args, in_tree = tree_util.tree_flatten(args) flat_fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree) all_args, all_tree = tree_util.tree_flatten((plants, flat_args)) - with jax_core.new_main(HarvestTrace) as main: - flat_fun = plant_function(flat_fun, main, settings, all_tree) - out_flat = flat_fun.call_wrapped(all_args) - del main + flat_fun = plant_function(flat_fun, settings, all_tree) + out_flat = flat_fun.call_wrapped(all_args) return tree_util.tree_unflatten(out_tree(), out_flat) return wrapped @@ -1448,20 +1308,20 @@ def _plant_scan_rule(trace: HarvestTrace, *tracers, length, reverse, jaxpr, num_consts, num_carry, linear, unroll, _split_transpose): """Injects values into a scan according to their sow mode.""" - const_tracers, carry_tracers, xs_tracers = jax_util.split_list( + const_vals, carry_vals, xs_vals = jax_util.split_list( tracers, [num_consts, num_carry]) - carry_avals, xs_avals = tree_util.tree_map(lambda x: x.aval, - (carry_tracers, xs_tracers)) - const_vals, carry_vals, xs_vals = tree_util.tree_map( - lambda x: x.val, (const_tracers, carry_tracers, xs_tracers)) - context = trace_util.get_dynamic_context(trace) - settings = context.settings - x_tracers = [t[0] if hasattr(t, '_getitem') else t for t in xs_tracers] - x_avals = [t.aval for t in x_tracers] + carry_avals, xs_avals = tree_util.tree_map(lambda x: jax_core.get_aval(x), # pylint: disable=unnecessary-lambda + (carry_vals, xs_vals)) + settings = trace.context.settings + + with jax_core.set_current_trace(trace.parent_trace): + x_vals = [t[0] if hasattr(jax_core.get_aval(t), '_getitem') else t + for t in xs_vals] + x_avals = [t.aval for t in x_vals] metadata = _get_harvest_metadata(jaxpr, settings, - *(const_tracers + carry_tracers + x_tracers)) + *(const_vals + carry_vals + x_vals)) - plants = context.plants + plants = trace.context.plants plant_modes = collections.defaultdict(set) plant_xs_avals = {} for name, meta in metadata.items(): @@ -1506,16 +1366,17 @@ def new_body(carry, x): new_body, plant_xs_in_tree, tuple(carry_avals + x_avals + plant_xs_flat_avals)) plant_vals = tree_util.tree_leaves(append_plants) - out = lcf.scan_p.bind( - *(consts + carry_vals + xs_vals + plant_vals), - reverse=reverse, - length=length, - jaxpr=new_body_jaxpr, - num_consts=len(consts), - num_carry=num_carry, - linear=linear + (False,) * len(plant_vals), - unroll=unroll, - _split_transpose=_split_transpose) + out = lcf.scan_p.bind_with_trace( + trace.parent_trace, + (consts + carry_vals + xs_vals + plant_vals), + dict(reverse=reverse, + length=length, + jaxpr=new_body_jaxpr, + num_consts=len(consts), + num_carry=num_carry, + linear=linear + (False,) * len(plant_vals), + unroll=unroll, + _split_transpose=_split_transpose)) return out @@ -1525,15 +1386,12 @@ def new_body(carry, x): def _plant_while_rule(trace: HarvestTrace, *tracers, cond_jaxpr, body_jaxpr, cond_nconsts, body_nconsts): """Injects values into a while loop, overriding values for all iterations.""" - cond_const_tracers, body_const_tracers, init_tracers = jax_util.split_list( + cond_const_vals, body_const_vals, init_vals = jax_util.split_list( tracers, [cond_nconsts, body_nconsts]) - init_avals = tree_util.tree_map(lambda x: x.aval, init_tracers) - cond_const_vals, body_const_vals, init_vals = tree_util.tree_map( - lambda x: x.val, (cond_const_tracers, body_const_tracers, init_tracers)) - context = trace_util.get_dynamic_context(trace) - settings = context.settings + init_avals = tree_util.tree_map(lambda x: jax_core.get_aval(x), init_vals) # pylint: disable=unnecessary-lambda + settings = trace.context.settings body_metadata = _get_harvest_metadata(body_jaxpr, settings, - *(body_const_tracers + init_tracers)) + *(body_const_vals + init_vals)) for k, meta in body_metadata.items(): mode = meta['mode'] if mode not in ['clobber', 'cond_clobber']: @@ -1548,7 +1406,7 @@ def _plant_while_rule(trace: HarvestTrace, *tracers, cond_jaxpr, body_jaxpr, allowlist=settings.allowlist, blocklist=settings.blocklist, exclusive=settings.exclusive) - plants = context.plants + plants = trace.context.plants def new_body(*carry): carry = plant(body_fun, **plant_settings)(plants, @@ -1558,13 +1416,14 @@ def new_body(*carry): in_tree = tree_util.tree_structure(init_avals) new_body_jaxpr, new_body_consts, _ = lcf._initial_style_jaxpr( # pylint: disable=protected-access new_body, in_tree, tuple(init_avals)) - out = lcf.while_p.bind( - *(cond_const_vals + new_body_consts + init_vals), - cond_nconsts=len(cond_const_vals), - body_nconsts=len(new_body_consts), - cond_jaxpr=cond_jaxpr, - body_jaxpr=new_body_jaxpr) - return jax_util.safe_map(trace.pure, out) + out = lcf.while_p.bind_with_trace( + trace.parent_trace, + (cond_const_vals + new_body_consts + init_vals), + dict(cond_nconsts=len(cond_const_vals), + body_nconsts=len(new_body_consts), + cond_jaxpr=cond_jaxpr, + body_jaxpr=new_body_jaxpr)) + return out plant_custom_rules[lcf.while_p] = _plant_while_rule @@ -1572,22 +1431,19 @@ def new_body(*carry): def _plant_cond_rule(trace, *tracers, branches, linear=None): """Injects the same values into both branches of a conditional.""" - index_tracer, ops_tracers = tracers[0], tracers[1:] - index_val, ops_vals = tree_util.tree_map(lambda x: x.val, - (index_tracer, ops_tracers)) - ops_avals = tree_util.tree_map(lambda x: x.aval, ops_tracers) - context = trace_util.get_dynamic_context(trace) - settings = context.settings + index_val, ops_vals = tracers[0], tracers[1:] + ops_avals = tree_util.tree_map(lambda x: jax_core.get_aval(x), ops_vals) # pylint: disable=unnecessary-lambda + settings = trace.context.settings plant_settings = dict( tag=settings.tag, allowlist=settings.allowlist, blocklist=settings.blocklist, exclusive=settings.exclusive) branch_metadatas = tuple( - _get_harvest_metadata(branch, settings, *ops_tracers) + _get_harvest_metadata(branch, settings, *ops_vals) for branch in branches) _check_branch_metadata(branch_metadatas) - plants = context.plants + plants = trace.context.plants branch_funs = tuple(map(jax_core.jaxpr_as_fun, branches)) planted_branches = tuple( functools.partial(plant(f, **plant_settings), plants) @@ -1597,55 +1453,52 @@ def _plant_cond_rule(trace, *tracers, branches, linear=None): lcf._initial_style_jaxprs_with_common_consts( # pylint: disable=protected-access planted_branches, in_tree, ops_avals, lax.cond_p.name)) if linear is None: - out = lax.cond_p.bind( - index_val, - *(tuple(consts) + ops_vals), - branches=tuple(new_branch_jaxprs)) + out = lax.cond_p.bind_with_trace( + trace.parent_trace, + (index_val, *consts, *ops_vals), + dict(branches=tuple(new_branch_jaxprs))) else: - out = lax.cond_p.bind( - index_val, - *(tuple(consts) + ops_vals), - branches=tuple(new_branch_jaxprs), - linear=(False,) * len(tuple(consts) + linear)) - return jax_util.safe_map(trace.pure, out) + out = lax.cond_p.bind_with_trace( + trace.parent_trace, + (index_val, *consts, *ops_vals), + dict(branches=tuple(new_branch_jaxprs), + linear=(False,) * len(tuple(consts) + linear))) + return out plant_custom_rules[lcf.cond_p] = _plant_cond_rule -def _plant_checkpoint_rule(trace, *tracers, jaxpr, policy, prevent_cse, +def _plant_checkpoint_rule(trace, *invals, jaxpr, policy, prevent_cse, differentiated): """Plant checkpoint rule.""" - invals = [t.val for t in tracers] - context = trace_util.get_dynamic_context(trace) - settings = context.settings + settings = trace.context.settings plant_settings = dict( tag=settings.tag, allowlist=settings.allowlist, blocklist=settings.blocklist, exclusive=settings.exclusive) closed_jaxpr = jax_core.ClosedJaxpr(jaxpr, ()) - plants = context.plants + plants = trace.context.plants remat_fun = jax_core.jaxpr_as_fun(closed_jaxpr) planted_remat_fun = functools.partial( plant(remat_fun, **plant_settings), plants) plant_jaxpr, consts, _ = lcf._initial_style_jaxpr( # pylint: disable=protected-access planted_remat_fun, tree_util.tree_structure(invals), - tuple(t.aval for t in tracers)) - outvals = ad_checkpoint.remat_p.bind( - *consts, - *invals, - jaxpr=plant_jaxpr.jaxpr, - policy=policy, - prevent_cse=prevent_cse, - differentiated=differentiated) - return jax_util.safe_map(trace.pure, outvals) + tuple(jax_core.get_aval(t) for t in invals)) + return ad_checkpoint.remat_p.bind_with_trace( + trace.parent_trace, + (*consts, *invals), + dict(jaxpr=plant_jaxpr.jaxpr, + policy=policy, + prevent_cse=prevent_cse, + differentiated=differentiated)) plant_custom_rules[ad_checkpoint.remat_p] = _plant_checkpoint_rule -def _plant_pjit_rule(trace, *tracers, **params): +def _plant_pjit_rule(trace, *invals, **params): """Plant pjit rule.""" if params['in_shardings'] and not any( isinstance(i, sharding_impls.UnspecifiedValue) for i in params['in_shardings'] @@ -1662,16 +1515,14 @@ def _plant_pjit_rule(trace, *tracers, **params): f'specified. Got {params["out_shardings"]}' ) - invals = [t.val for t in tracers] - context = trace_util.get_dynamic_context(trace) - settings = context.settings + settings = trace.context.settings plant_settings = dict( tag=settings.tag, allowlist=settings.allowlist, blocklist=settings.blocklist, exclusive=settings.exclusive) closed_jaxpr = params['jaxpr'] - plants = context.plants + plants = trace.context.plants pjit_fun = jax_core.jaxpr_as_fun(closed_jaxpr) planted_pjit_fun = lu.wrap_init(functools.partial( @@ -1680,7 +1531,7 @@ def _plant_pjit_rule(trace, *tracers, **params): flat_fun, _ = api_util.flatten_fun_nokwargs(planted_pjit_fun, in_tree) planted_jaxpr, final_consts, out_avals = _oryx_pjit_jaxpr( - flat_fun, tuple(t.aval for t in tracers)) + flat_fun, tuple(jax_core.get_aval(t) for t in invals)) in_shardings, donated_invars, in_layouts = _calc_extra_inps( len(final_consts), params) @@ -1693,9 +1544,10 @@ def _plant_pjit_rule(trace, *tracers, **params): 'in_layouts': in_layouts, 'out_layouts': (None,) * len(out_avals), } - outvals = pjit.pjit_p.bind(*final_consts, *invals, **new_params) + outvals = pjit.pjit_p.bind_with_trace( + trace.parent_trace, (*final_consts, *invals), new_params) - return jax_util.safe_map(trace.pure, outvals) + return outvals plant_custom_rules[pjit.pjit_p] = _plant_pjit_rule diff --git a/oryx/core/interpreters/harvest_test.py b/oryx/core/interpreters/harvest_test.py index 1f22242..b35a00a 100644 --- a/oryx/core/interpreters/harvest_test.py +++ b/oryx/core/interpreters/harvest_test.py @@ -42,7 +42,6 @@ from jax.experimental import shard_map import jax.numpy as jnp import numpy as np -from oryx.core import trace_util from oryx.core.interpreters import harvest from oryx.internal import test_util @@ -485,7 +484,6 @@ def f_jvp(xs, ts): self.assertEqual(plant_variables(f)(dict(y=2.), 3.), 2. + np.sin(3.)) def test_can_plant_into_jvp_of_custom_jvp_function_unimplemented(self): - @jax.custom_jvp def f(x): return jnp.sin(x) @@ -638,15 +636,6 @@ def f(x): } }, 1.), (2., {})) - def test_harvest_should_clean_up_context(self): - - def f(x): - raise ValueError('Intentional error!') - - with self.assertRaisesRegex(ValueError, 'Intentional error!'): - harvest_variables(f)({}, 1.) - self.assertDictEqual(trace_util._thread_local_state.dynamic_contexts, {}) - def test_can_jit_compile_nest(self): def f(x): diff --git a/oryx/core/primitive.py b/oryx/core/primitive.py index d841b2f..493dd18 100644 --- a/oryx/core/primitive.py +++ b/oryx/core/primitive.py @@ -77,7 +77,7 @@ def __init__(self, name): def impl(self, f, *args, **params): del params - with jax_core.new_sublevel(): + with jax_core.eval_context(): return f.call_wrapped(*args) def subcall(self, name): @@ -109,21 +109,11 @@ def rule(ctx, *args, backend, name, call_jaxpr, **_params): register_hop_transformation_rule('mlir', hop_lowering) -def batch_fun(fun: lu.WrappedFun, in_dims): - fun, out_dims = batching.batch_subtrace(fun) - return _batch_fun(fun, in_dims), out_dims - - -@lu.transformation -def _batch_fun(in_dims, *in_vals, **params): - with jax_core.new_main( - batching.BatchTrace, axis_name=jax_core.no_axis_name) as main: - out_vals = yield ( - main, - in_dims, - ) + in_vals, params - del main - yield out_vals +def batch_fun(fun: lu.WrappedFun, axis_data, in_dims): + tag = jax_core.TraceTag() + in_dims = in_dims() if callable(in_dims) else in_dims + batched, out_dims = batching.batch_subtrace(fun, tag, axis_data, in_dims) + return batched, out_dims class FlatPrimitive(jax_core.Primitive): @@ -139,18 +129,19 @@ def _abstract(*flat_avals, **params): self.def_abstract_eval(_abstract) def _jvp(primals, tangents, **params): - primals_out, tangents_out = ad.jvp(lu.wrap_init(self.impl, - params)).call_wrapped( - primals, tangents) + primals_out, tangents_out = ad.jvp( + lu.wrap_init(self.impl, params)).call_wrapped(primals, tangents) + return primals_out, tangents_out ad.primitive_jvps[self] = _jvp - def _batch(args, dims, **params): - batched, out_dims = batch_fun(lu.wrap_init(self.impl, params), dims) + def _batch(axis_data, args, dims, **params): + batched, out_dims = batch_fun( + lu.wrap_init(self.impl, params), axis_data, dims) return batched.call_wrapped(*args), out_dims() - batching.primitive_batchers[self] = _batch + batching.fancy_primitive_batchers[self] = _batch def _mlir(c, *mlir_args, **params): lowering = mlir.lower_fun(self.impl, multiple_results=True) diff --git a/oryx/core/state/function_test.py b/oryx/core/state/function_test.py index 6543e20..95ab061 100644 --- a/oryx/core/state/function_test.py +++ b/oryx/core/state/function_test.py @@ -65,7 +65,7 @@ def f(x): def test_init_stateful_function(self): def f(x, init_key=None): - y = module.variable(np.ones(x.shape), name='y', key=init_key) + y = module.variable(np.ones(np.shape(x)), name='y', key=init_key) return x + y m = api.init(f)(random.PRNGKey(0), 1.) @@ -77,7 +77,7 @@ def f(x, init_key=None): def test_init_stateful_function_with_assign(self): def f(x, init_key=None): - y = module.variable(np.zeros(x.shape), name='y', key=init_key) + y = module.variable(np.zeros(np.shape(x)), name='y', key=init_key) next_y = module.assign(y + 1., name='y') return x + next_y @@ -90,7 +90,7 @@ def f(x, init_key=None): def test_assign_with_no_matching_variable_should_error(self): def f(x, init_key=None): - y = module.variable(np.zeros(x.shape), name='y', key=init_key) + y = module.variable(np.zeros(np.shape(x)), name='y', key=init_key) next_y = module.assign(y + 1., name='z') return x + next_y @@ -102,7 +102,7 @@ def f(x, init_key=None): def test_init_stateful_function_with_tied_in_assign(self): def f(x, init_key=None): - y = module.variable(np.zeros(x.shape), name='y', key=init_key) + y = module.variable(np.zeros(np.shape(x)), name='y', key=init_key) next_y = module.assign(y + 1., name='y') return primitive.tie_in(next_y, x) + y @@ -115,7 +115,7 @@ def f(x, init_key=None): def test_init_of_composed_stateful_functions_should_have_flat_params(self): def f(x, init_key=None): - y = module.variable(np.zeros(x.shape), name='y', key=init_key) + y = module.variable(np.zeros(np.shape(x)), name='y', key=init_key) next_y = module.assign(y + 1., name='y') return primitive.tie_in(next_y, x) + y @@ -131,7 +131,7 @@ def g(x, init_key=None): def test_init_of_nested_init_without_name_should_have_flat_params(self): def f(x, init_key=None): - y = module.variable(np.zeros(x.shape), name='y', key=init_key) + y = module.variable(np.zeros(np.shape(x)), name='y', key=init_key) next_y = module.assign(y + 1., name='y') return primitive.tie_in(next_y, x) + y @@ -147,7 +147,7 @@ def g(x, init_key=None): def test_init_of_nested_init_with_name_should_have_nested_params(self): def f(x, init_key=None): - y = module.variable(np.zeros(x.shape), name='y', key=init_key) + y = module.variable(np.zeros(np.shape(x)), name='y', key=init_key) next_y = module.assign(y + 1., name='y') return primitive.tie_in(next_y, x) + y @@ -186,7 +186,7 @@ class VmapModuleTest(absltest.TestCase): def test_vmap_of_init_should_return_ensemble(self): def f(x, init_key=None): - w = module.variable(random.normal(init_key, x.shape), name='w') + w = module.variable(random.normal(init_key, np.shape(x)), name='w') return np.dot(w, x) ensemble = jax.vmap(api.init(f))( random.split(random.PRNGKey(0)), diff --git a/oryx/core/trace_util.py b/oryx/core/trace_util.py index 0d55533..2ff61fe 100644 --- a/oryx/core/trace_util.py +++ b/oryx/core/trace_util.py @@ -15,7 +15,7 @@ """Module for JAX tracing utility functions.""" import contextlib import threading -from typing import Any, Dict, Generator, List +from typing import Any, Dict, Generator, List, Hashable from jax import api_util from jax import tree_util @@ -31,7 +31,7 @@ 'stage', 'trees', 'new_dynamic_context', - 'get_dynamic_context' + 'get_dynamic_context', ] safe_map = jax_util.safe_map @@ -67,7 +67,7 @@ def wrapped(*args, **kwargs): flat_avals) else: pvals = [pe.PartialVal.unknown(aval) for aval in flat_avals] - jaxpr, _, consts = pe.trace_to_jaxpr( + jaxpr, _, consts = pe.trace_to_jaxpr_nounits( flat_fun, pvals, instantiate=True) @@ -86,19 +86,26 @@ def wrapped(*args, **kwargs): return wrapped +def extract_call_jaxpr(primitive, params): + if not (primitive.call_primitive or primitive.map_primitive): + return None, params + else: + params = dict(params) + return params.pop('call_jaxpr'), params + + class _ThreadLocalState(threading.local): def __init__(self): super().__init__() - self.dynamic_contexts: Dict[jax_core.MainTrace, List[Any]] = {} + self.dynamic_contexts: Dict[Hashable, List[Any]] = {} _thread_local_state = _ThreadLocalState() @contextlib.contextmanager -def new_dynamic_context(master: jax_core.MainTrace, - context: Any) -> Generator[None, None, None]: - """Creates a dynamic context for a trace.""" +def new_dynamic_context(_: Any, context: Any) -> Generator[None, None, None]: + master = jax_core.get_opaque_trace_state('oryx') if master not in _thread_local_state.dynamic_contexts: _thread_local_state.dynamic_contexts[master] = [] _thread_local_state.dynamic_contexts[master].append(context) @@ -110,16 +117,8 @@ def new_dynamic_context(master: jax_core.MainTrace, del _thread_local_state.dynamic_contexts[master] -def get_dynamic_context(trace: jax_core.Trace) -> Any: - """Returns the current active dynamic context for a trace.""" - if trace.main not in _thread_local_state.dynamic_contexts: - raise ValueError(f'No dynamic context registered for trace: {trace}') - return _thread_local_state.dynamic_contexts[trace.main][-1] - - -def extract_call_jaxpr(primitive, params): - if not (primitive.call_primitive or primitive.map_primitive): - return None, params - else: - params = dict(params) - return params.pop('call_jaxpr'), params +def get_dynamic_context(_: Any) -> Any: + master = jax_core.get_opaque_trace_state('oryx') + if master not in _thread_local_state.dynamic_contexts: + raise ValueError(f'No dynamic context registered for state: {master}') + return _thread_local_state.dynamic_contexts[master][-1] diff --git a/oryx/experimental/nn/base.py b/oryx/experimental/nn/base.py index 836a5f9..49cef0c 100644 --- a/oryx/experimental/nn/base.py +++ b/oryx/experimental/nn/base.py @@ -403,8 +403,8 @@ class NoneProxy: not_mapped = NoneProxy() -def custom_layer_cau_batch(vals, dims, *, num_consts, in_tree, out_tree, kwargs, - **params): +def custom_layer_cau_batch(axis_data, vals, dims, *, num_consts, in_tree, + out_tree, kwargs, **params): """Batching rule for layer_cau primitive to handle custom layers.""" if all(dim is batching.not_mapped for dim in dims): return layer_cau_p.bind(*vals, num_consts=num_consts, in_tree=in_tree, @@ -453,9 +453,9 @@ def custom_layer_cau_batch(vals, dims, *, num_consts, in_tree, out_tree, kwargs, batched, out_dims = primitive.batch_fun(lu.wrap_init( layer_cau_p.impl, dict(params, num_consts=num_consts, in_tree=in_tree, out_tree=out_tree, - kwargs=kwargs)), orig_dims) + kwargs=kwargs)), axis_data, orig_dims) return batched.call_wrapped(*orig_vals), out_dims() -batching.primitive_batchers[layer_cau_p] = custom_layer_cau_batch +batching.fancy_primitive_batchers[layer_cau_p] = custom_layer_cau_batch def _layer_cau_batched(layer, *args, **kwargs):