diff --git a/numpyro/contrib/hsgp/laplacian.py b/numpyro/contrib/hsgp/laplacian.py index 011c828aa..ae0cbbd6b 100644 --- a/numpyro/contrib/hsgp/laplacian.py +++ b/numpyro/contrib/hsgp/laplacian.py @@ -152,8 +152,8 @@ def eigenfunctions(x: ArrayLike, ell: float | list[float], m: int | list[int]) - x_ = jnp.expand_dims(x, axis=-1) else: x_ = jnp.array(x) - dim = jnp.shape(x)[-1] # others assumed batch dims - n_batch_dims = jnp.ndim(x) - 1 + dim = jnp.shape(x_)[-1] # others assumed batch dims + n_batch_dims = jnp.ndim(x_) - 1 ell_ = _convert_ell(ell, dim) a = jnp.expand_dims(ell_, tuple(range(n_batch_dims))) b = jnp.expand_dims(sqrt_eigenvalues(ell_, m, dim), tuple(range(n_batch_dims)))