Skip to content

Commit

Permalink
Refactor validate_args method on distributions (fixes #1865). (#1866)
Browse files Browse the repository at this point in the history
  • Loading branch information
tillahoffmann authored Sep 24, 2024
1 parent 94f4b99 commit 8e9313f
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 15 deletions.
41 changes: 26 additions & 15 deletions numpyro/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,23 +229,34 @@ def __init__(self, batch_shape=(), event_shape=(), *, validate_args=None):
if validate_args is not None:
self._validate_args = validate_args
if self._validate_args:
for param, constraint in self.arg_constraints.items():
if param not in self.__dict__ and isinstance(
getattr(type(self), param), lazy_property
):
continue
if constraints.is_dependent(constraint):
continue # skip constraints that cannot be checked
is_valid = constraint(getattr(self, param))
if not_jax_tracer(is_valid):
if not np.all(is_valid):
raise ValueError(
"{} distribution got invalid {} parameter.".format(
self.__class__.__name__, param
)
)
self.validate_args(strict=False)
super(Distribution, self).__init__()

def validate_args(self, strict: bool = True) -> None:
"""
Validate the arguments of the distribution.
:param strict: Require strict validation, raising an error if the function is
called inside jitted code.
"""
for param, constraint in self.arg_constraints.items():
if param not in self.__dict__ and isinstance(
getattr(type(self), param), lazy_property
):
continue
if constraints.is_dependent(constraint):
continue # skip constraints that cannot be checked
is_valid = constraint(getattr(self, param))
if not_jax_tracer(is_valid):
if not np.all(is_valid):
raise ValueError(
"{} distribution got invalid {} parameter.".format(
self.__class__.__name__, param
)
)
elif strict:
raise RuntimeError("Cannot validate arguments inside jitted code.")

@property
def batch_shape(self):
"""
Expand Down
19 changes: 19 additions & 0 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3322,6 +3322,25 @@ def test_vmap_validate_args():
assert not v_dist._validate_args


def test_explicit_validate_args():
# Check validation passes for valid parameters.
d = dist.Normal(0, 1)
d.validate_args()

# Check validation fails for invalid parameters.
d = dist.Normal(0, -1)
with pytest.raises(ValueError, match="got invalid scale parameter"):
d.validate_args()

# Check validation is skipped for strict=False and raises an error for strict=True.
jitted = jax.jit(
lambda d, strict: d.validate_args(strict), static_argnames=["strict"]
)
jitted(d, False)
with pytest.raises(RuntimeError, match="Cannot validate arguments"):
jitted(d, True)


def test_multinomial_abstract_total_count():
probs = jnp.array([0.2, 0.5, 0.3])
key = random.PRNGKey(0)
Expand Down

0 comments on commit 8e9313f

Please sign in to comment.