diff --git a/src/hssm/hssm.py b/src/hssm/hssm.py index 12cd2fd3..57bf3155 100644 --- a/src/hssm/hssm.py +++ b/src/hssm/hssm.py @@ -432,6 +432,7 @@ def sample( ) = None, init: str | None = None, initvals: str | dict | None = None, + include_response_params: bool = False, **kwargs, ) -> az.InferenceData | pm.Approximation: """Perform sampling using the `fit` method via bambi.Model. @@ -453,6 +454,10 @@ def sample( values for parameters of the model, or a string "map" to use initialization at the MAP estimate. If "map" is used, the MAP estimate will be computed if not already attached to the base class from prior call to 'find_MAP`. + include_response_params: optional + Include parameters of the response distribution in the output. These usually + take more space than other parameters as there's one of them per + observation. Defaults to False. kwargs Other arguments passed to bmb.Model.fit(). Please see [here] (https://bambinos.github.io/bambi/api_reference.html#bambi.models.Model.fit) @@ -539,35 +544,39 @@ def sample( # If sampler is finally `numpyro` make sure # the jitter argument is set to False if sampler == "nuts_numpyro": - if kwargs.get("jitter", None): - _logger.warning( - "The jitter argument is set to True. " - + "This argument is not supported " - + "by the numpyro backend. " - + "The jitter argument will be set to False." - ) - kwargs["jitter"] = False - else: - if "jitter" in kwargs: - _logger.warning( - "The jitter keyword argument is " - + "supported only by the nuts_numpyro sampler. \n" - + "The jitter argument will be ignored." - ) - del kwargs["jitter"] + if "nuts_sampler_kwargs" in kwargs: + if kwargs["nuts_sampler_kwargs"].get("jitter"): + _logger.warning( + "The jitter argument is set to True. " + + "This argument is not supported " + + "by the numpyro backend. " + + "The jitter argument will be set to False." + ) + kwargs["nuts_sampler_kwargs"]["jitter"] = False + else: + kwargs["nuts_sampler_kwargs"] = {"jitter": False} - if "include_mean" not in kwargs: - # If not specified, include the mean prediction in - # kwargs to be passed to the model.fit() method - kwargs["include_mean"] = True if self._inference_obj is not None: _logger.warning( "The model has already been sampled. Overwriting the previous " + "inference object. Any previous reference to the inference object " + "will still point to the old object." ) + + if "nuts_sampler" not in kwargs: + if sampler in ["mcmc", "nuts_numpyro", "nuts_blackjax"]: + kwargs["nuts_sampler"] = ( + "pymc" if sampler == "mcmc" else sampler.split("_")[1] + ) + print(kwargs["nuts_sampler"]) + self._inference_obj = self.model.fit( - inference_method=sampler, init=init, **kwargs + inference_method="mcmc" + if sampler in ["mcmc", "nuts_numpyro", "nuts_blackjax"] + else sampler, + init=init, + include_response_params=include_response_params, + **kwargs, ) # The parent was previously not part of deterministics --> compute it via