From 44d885f6d297fa94b147fa074a3cdbfb774206f9 Mon Sep 17 00:00:00 2001
From: Paul Xu <yang_xu@brown.edu>
Date: Wed, 24 Jul 2024 12:04:31 -0400
Subject: [PATCH 1/2] fix sampler specification

---
 src/hssm/hssm.py | 51 ++++++++++++++++++++++++++++--------------------
 1 file changed, 30 insertions(+), 21 deletions(-)

diff --git a/src/hssm/hssm.py b/src/hssm/hssm.py
index 12cd2fd3..c923405a 100644
--- a/src/hssm/hssm.py
+++ b/src/hssm/hssm.py
@@ -432,6 +432,7 @@ def sample(
         ) = None,
         init: str | None = None,
         initvals: str | dict | None = None,
+        include_response_params: bool = False,
         **kwargs,
     ) -> az.InferenceData | pm.Approximation:
         """Perform sampling using the `fit` method via bambi.Model.
@@ -453,6 +454,10 @@ def sample(
             values for parameters of the model, or a string "map" to use initialization
             at the MAP estimate. If "map" is used, the MAP estimate will be computed if
             not already attached to the base class from prior call to 'find_MAP`.
+        include_response_params: optional
+            Include parameters of the response distribution in the output. These usually
+            take more space than other parameters as there's one of them per
+            observation. Defaults to False.
         kwargs
             Other arguments passed to bmb.Model.fit(). Please see [here]
             (https://bambinos.github.io/bambi/api_reference.html#bambi.models.Model.fit)
@@ -539,35 +544,39 @@ def sample(
         # If sampler is finally `numpyro` make sure
         # the jitter argument is set to False
         if sampler == "nuts_numpyro":
-            if kwargs.get("jitter", None):
-                _logger.warning(
-                    "The jitter argument is set to True. "
-                    + "This argument is not supported "
-                    + "by the numpyro backend. "
-                    + "The jitter argument will be set to False."
-                )
-            kwargs["jitter"] = False
-        else:
-            if "jitter" in kwargs:
-                _logger.warning(
-                    "The jitter keyword argument is "
-                    + "supported only by the nuts_numpyro sampler. \n"
-                    + "The jitter argument will be ignored."
-                )
-                del kwargs["jitter"]
+            if "nuts_numpyro_kwargs" in kwargs:
+                if kwargs["nuts_sampler_kwargs"].get("jitter"):
+                    _logger.warning(
+                        "The jitter argument is set to True. "
+                        + "This argument is not supported "
+                        + "by the numpyro backend. "
+                        + "The jitter argument will be set to False."
+                    )
+                kwargs["nuts_sampler_kwargs"]["jitter"] = False
+            else:
+                kwargs["nuts_sampler_kwargs"] = {"jitter": False}
 
-        if "include_mean" not in kwargs:
-            # If not specified, include the mean prediction in
-            # kwargs to be passed to the model.fit() method
-            kwargs["include_mean"] = True
         if self._inference_obj is not None:
             _logger.warning(
                 "The model has already been sampled. Overwriting the previous "
                 + "inference object. Any previous reference to the inference object "
                 + "will still point to the old object."
             )
+
+        if "nuts_sampler" not in kwargs:
+            if sampler in ["mcmc", "nuts_numpyro", "nuts_blackjax"]:
+                kwargs["nuts_sampler"] = (
+                    "pymc" if sampler == "mcmc" else sampler.split("_")[1]
+                )
+            print(kwargs["nuts_sampler"])
+
         self._inference_obj = self.model.fit(
-            inference_method=sampler, init=init, **kwargs
+            inference_method="mcmc"
+            if sampler in ["mcmc", "nuts_numpyro", "nuts_blackjax"]
+            else sampler,
+            init=init,
+            include_response_params=include_response_params,
+            **kwargs,
         )
 
         # The parent was previously not part of deterministics --> compute it via

From 6f7b4fd4a79de0cfcf706162220b969e8c7c28e0 Mon Sep 17 00:00:00 2001
From: Paul Xu <yang_xu@brown.edu>
Date: Wed, 31 Jul 2024 09:29:06 -0400
Subject: [PATCH 2/2] fix typo

---
 src/hssm/hssm.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/hssm/hssm.py b/src/hssm/hssm.py
index c923405a..57bf3155 100644
--- a/src/hssm/hssm.py
+++ b/src/hssm/hssm.py
@@ -544,7 +544,7 @@ def sample(
         # If sampler is finally `numpyro` make sure
         # the jitter argument is set to False
         if sampler == "nuts_numpyro":
-            if "nuts_numpyro_kwargs" in kwargs:
+            if "nuts_sampler_kwargs" in kwargs:
                 if kwargs["nuts_sampler_kwargs"].get("jitter"):
                     _logger.warning(
                         "The jitter argument is set to True. "