Skip to content

Commit

Permalink
Merge pull request #520 from lnccbrown/bambi-014-fix-sampling
Browse files Browse the repository at this point in the history
Fix sampler specification
  • Loading branch information
digicosmos86 authored Aug 5, 2024
2 parents e52ac63 + 6f7b4fd commit 83a3017
Showing 1 changed file with 30 additions and 21 deletions.
51 changes: 30 additions & 21 deletions src/hssm/hssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,7 @@ def sample(
) = None,
init: str | None = None,
initvals: str | dict | None = None,
include_response_params: bool = False,
**kwargs,
) -> az.InferenceData | pm.Approximation:
"""Perform sampling using the `fit` method via bambi.Model.
Expand All @@ -453,6 +454,10 @@ def sample(
values for parameters of the model, or a string "map" to use initialization
at the MAP estimate. If "map" is used, the MAP estimate will be computed if
not already attached to the base class from prior call to 'find_MAP`.
include_response_params: optional
Include parameters of the response distribution in the output. These usually
take more space than other parameters as there's one of them per
observation. Defaults to False.
kwargs
Other arguments passed to bmb.Model.fit(). Please see [here]
(https://bambinos.github.io/bambi/api_reference.html#bambi.models.Model.fit)
Expand Down Expand Up @@ -539,35 +544,39 @@ def sample(
# If sampler is finally `numpyro` make sure
# the jitter argument is set to False
if sampler == "nuts_numpyro":
if kwargs.get("jitter", None):
_logger.warning(
"The jitter argument is set to True. "
+ "This argument is not supported "
+ "by the numpyro backend. "
+ "The jitter argument will be set to False."
)
kwargs["jitter"] = False
else:
if "jitter" in kwargs:
_logger.warning(
"The jitter keyword argument is "
+ "supported only by the nuts_numpyro sampler. \n"
+ "The jitter argument will be ignored."
)
del kwargs["jitter"]
if "nuts_sampler_kwargs" in kwargs:
if kwargs["nuts_sampler_kwargs"].get("jitter"):
_logger.warning(
"The jitter argument is set to True. "
+ "This argument is not supported "
+ "by the numpyro backend. "
+ "The jitter argument will be set to False."
)
kwargs["nuts_sampler_kwargs"]["jitter"] = False
else:
kwargs["nuts_sampler_kwargs"] = {"jitter": False}

if "include_mean" not in kwargs:
# If not specified, include the mean prediction in
# kwargs to be passed to the model.fit() method
kwargs["include_mean"] = True
if self._inference_obj is not None:
_logger.warning(
"The model has already been sampled. Overwriting the previous "
+ "inference object. Any previous reference to the inference object "
+ "will still point to the old object."
)

if "nuts_sampler" not in kwargs:
if sampler in ["mcmc", "nuts_numpyro", "nuts_blackjax"]:
kwargs["nuts_sampler"] = (
"pymc" if sampler == "mcmc" else sampler.split("_")[1]
)
print(kwargs["nuts_sampler"])

self._inference_obj = self.model.fit(
inference_method=sampler, init=init, **kwargs
inference_method="mcmc"
if sampler in ["mcmc", "nuts_numpyro", "nuts_blackjax"]
else sampler,
init=init,
include_response_params=include_response_params,
**kwargs,
)

# The parent was previously not part of deterministics --> compute it via
Expand Down

0 comments on commit 83a3017

Please sign in to comment.