diff --git a/oryx/core/interpreters/harvest.py b/oryx/core/interpreters/harvest.py index 33f9981..f787ce5 100644 --- a/oryx/core/interpreters/harvest.py +++ b/oryx/core/interpreters/harvest.py @@ -1097,7 +1097,7 @@ def _check_branch_metadata(branch_metadatas): raise ValueError(f'Mismatched dtype between branches: \'{name}\'.') -def _reap_cond_rule(trace, *tracers, branches, linear): +def _reap_cond_rule(trace, *tracers, branches, linear=None): """Reaps each path of the `cond`.""" index_tracer, ops_tracers = tracers[0], tracers[1:] index_val, ops_vals = tree_util.tree_map(lambda x: x.val, @@ -1122,11 +1122,17 @@ def _reap_cond_rule(trace, *tracers, branches, linear): new_branch_jaxprs, consts, out_trees = ( lcf._initial_style_jaxprs_with_common_consts( # pylint: disable=protected-access reaped_branches, in_tree, ops_avals, lax.cond_p.name)) - out = lax.cond_p.bind( - index_val, - *(tuple(consts) + ops_vals), - branches=tuple(new_branch_jaxprs), - linear=(False,) * len(tuple(consts) + linear)) + if linear is None: + out = lax.cond_p.bind( + index_val, + *(tuple(consts) + ops_vals), + branches=tuple(new_branch_jaxprs)) + else: + out = lax.cond_p.bind( + index_val, + *(tuple(consts) + ops_vals), + branches=tuple(new_branch_jaxprs), + linear=(False,) * len(tuple(consts) + linear)) out = jax_util.safe_map(trace.pure, out) out, reaps, preds = tree_util.tree_unflatten(out_trees[0], out) for k, v in reaps.items(): @@ -1558,7 +1564,7 @@ def new_body(*carry): plant_custom_rules[lcf.while_p] = _plant_while_rule -def _plant_cond_rule(trace, *tracers, branches, linear): +def _plant_cond_rule(trace, *tracers, branches, linear=None): """Injects the same values into both branches of a conditional.""" index_tracer, ops_tracers = tracers[0], tracers[1:] index_val, ops_vals = tree_util.tree_map(lambda x: x.val, @@ -1584,11 +1590,17 @@ def _plant_cond_rule(trace, *tracers, branches, linear): new_branch_jaxprs, consts, _ = ( lcf._initial_style_jaxprs_with_common_consts( # pylint: disable=protected-access planted_branches, in_tree, ops_avals, lax.cond_p.name)) - out = lax.cond_p.bind( - index_val, - *(tuple(consts) + ops_vals), - branches=tuple(new_branch_jaxprs), - linear=(False,) * len(tuple(consts) + linear)) + if linear is None: + out = lax.cond_p.bind( + index_val, + *(tuple(consts) + ops_vals), + branches=tuple(new_branch_jaxprs)) + else: + out = lax.cond_p.bind( + index_val, + *(tuple(consts) + ops_vals), + branches=tuple(new_branch_jaxprs), + linear=(False,) * len(tuple(consts) + linear)) return jax_util.safe_map(trace.pure, out)