diff --git a/oryx/core/interpreters/harvest.py b/oryx/core/interpreters/harvest.py index 6d4d4ca..3e1beca 100644 --- a/oryx/core/interpreters/harvest.py +++ b/oryx/core/interpreters/harvest.py @@ -1211,14 +1211,14 @@ def _calc_extra_inps(num_consts, params): def _reap_pjit_rule(trace, *tracers, **params): """Reap pjit rule.""" if params['in_shardings'] and not any( - sharding_impls.is_unspecified(i) for i in params['in_shardings'] + isinstance(i, sharding_impls.UnspecifiedValue) for i in params['in_shardings'] ): raise ValueError( 'oryx only supports pjit which has no in_axis_resources ' f'specified. Got {params["in_shardings"]}' ) if params['out_shardings'] and not any( - sharding_impls.is_unspecified(o) for o in params['out_shardings'] + isinstance(o, sharding_impls.UnspecifiedValue) for o in params['out_shardings'] ): raise ValueError( 'oryx only supports pjit which has no out_axis_resources ' @@ -1648,14 +1648,14 @@ def _plant_checkpoint_rule(trace, *tracers, jaxpr, policy, prevent_cse, def _plant_pjit_rule(trace, *tracers, **params): """Plant pjit rule.""" if params['in_shardings'] and not any( - sharding_impls.is_unspecified(i) for i in params['in_shardings'] + isinstance(i, sharding_impls.UnspecifiedValue) for i in params['in_shardings'] ): raise ValueError( 'oryx only supports pjit which has no in_axis_resources ' f'specified. Got {params["in_shardings"]}' ) if params['out_shardings'] and not any( - sharding_impls.is_unspecified(o) for o in params['out_shardings'] + isinstance(o, sharding_impls.UnspecifiedValue) for o in params['out_shardings'] ): raise ValueError( 'oryx only supports pjit which has no out_axis_resources ' diff --git a/oryx/core/interpreters/propagate.py b/oryx/core/interpreters/propagate.py index bcf5b3b..7b0a84b 100644 --- a/oryx/core/interpreters/propagate.py +++ b/oryx/core/interpreters/propagate.py @@ -38,7 +38,6 @@ from jax._src import sharding_impls from jax.extend import linear_util as lu from jax.interpreters import partial_eval as pe -from jax.interpreters import pxla from oryx.core import pytree from oryx.core import trace_util @@ -367,10 +366,10 @@ def _pjit_propagate_rule(incells, outcells, **params): """Propagate rule for pjit primitive.""" # TODO(https://github.com/jax-ml/oryx/issues/29): Fix this rule so that it # pylint: disable=g-bad-todo # works correct for in_sharding, out_shardings and donated_invars. - if not any(pxla._is_unspecified(i) for i in params['in_shardings']): # pylint: disable=protected-access + if not any(isinstance(i, sharding_impls.UnspecifiedValue) for i in params['in_shardings']): # pylint: disable=protected-access raise ValueError('oryx only supports pjit which has no in_axis_resources ' 'specified.') - if not any(pxla._is_unspecified(o) for o in params['out_shardings']): # pylint: disable=protected-access + if not any(isinstance(o, sharding_impls.UnspecifiedValue) for o in params['out_shardings']): # pylint: disable=protected-access raise ValueError('oryx only supports pjit which has no out_axis_resources ' 'specified.')