diff --git a/numpyro/distributions/distribution.py b/numpyro/distributions/distribution.py index ae08eee91..8540b6e5f 100644 --- a/numpyro/distributions/distribution.py +++ b/numpyro/distributions/distribution.py @@ -229,23 +229,34 @@ def __init__(self, batch_shape=(), event_shape=(), *, validate_args=None): if validate_args is not None: self._validate_args = validate_args if self._validate_args: - for param, constraint in self.arg_constraints.items(): - if param not in self.__dict__ and isinstance( - getattr(type(self), param), lazy_property - ): - continue - if constraints.is_dependent(constraint): - continue # skip constraints that cannot be checked - is_valid = constraint(getattr(self, param)) - if not_jax_tracer(is_valid): - if not np.all(is_valid): - raise ValueError( - "{} distribution got invalid {} parameter.".format( - self.__class__.__name__, param - ) - ) + self.validate_args(strict=False) super(Distribution, self).__init__() + def validate_args(self, strict: bool = True) -> None: + """ + Validate the arguments of the distribution. + + :param strict: Require strict validation, raising an error if the function is + called inside jitted code. + """ + for param, constraint in self.arg_constraints.items(): + if param not in self.__dict__ and isinstance( + getattr(type(self), param), lazy_property + ): + continue + if constraints.is_dependent(constraint): + continue # skip constraints that cannot be checked + is_valid = constraint(getattr(self, param)) + if not_jax_tracer(is_valid): + if not np.all(is_valid): + raise ValueError( + "{} distribution got invalid {} parameter.".format( + self.__class__.__name__, param + ) + ) + elif strict: + raise RuntimeError("Cannot validate arguments inside jitted code.") + @property def batch_shape(self): """ diff --git a/test/test_distributions.py b/test/test_distributions.py index b0c379290..20da4165d 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -3322,6 +3322,25 @@ def test_vmap_validate_args(): assert not v_dist._validate_args +def test_explicit_validate_args(): + # Check validation passes for valid parameters. + d = dist.Normal(0, 1) + d.validate_args() + + # Check validation fails for invalid parameters. + d = dist.Normal(0, -1) + with pytest.raises(ValueError, match="got invalid scale parameter"): + d.validate_args() + + # Check validation is skipped for strict=False and raises an error for strict=True. + jitted = jax.jit( + lambda d, strict: d.validate_args(strict), static_argnames=["strict"] + ) + jitted(d, False) + with pytest.raises(RuntimeError, match="Cannot validate arguments"): + jitted(d, True) + + def test_multinomial_abstract_total_count(): probs = jnp.array([0.2, 0.5, 0.3]) key = random.PRNGKey(0)