Skip to content

Commit

Permalink
simplify if elif else
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderFengler committed Aug 5, 2024
1 parent 2ba38e5 commit 1d58e74
Showing 1 changed file with 41 additions and 50 deletions.
91 changes: 41 additions & 50 deletions src/hssm/hssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def __init__(
) = None,
loglik_kind: LoglikKind | None = None,
p_outlier: float | dict | bmb.Prior | None = 0.05,
lapse: dict | bmb.Prior | None = bmb.Prior("Uniform", lower=0.0, upper=10.0),
lapse: dict | bmb.Prior | None = bmb.Prior("Uniform", lower=0.0, upper=20.0),
hierarchical: bool = False,
link_settings: Literal["log_logit"] | None = None,
prior_settings: Literal["safe"] | None = "safe",
Expand Down Expand Up @@ -975,40 +975,36 @@ def plot_trace(
Whether to call plt.tight_layout() after plotting. Defaults to True.
"""
data = data or self.traces
assert isinstance(data, az.InferenceData)

if data is not None:
if not include_deterministic:
var_names = list(
set([var.name for var in self.pymc_model.free_RVs]).intersection(
set(list(data["posterior"].data_vars.keys()))
)
assert isinstance(
data, az.InferenceData
), "data must be an InferenceData object."

if not include_deterministic:
var_names = list(
set([var.name for var in self.pymc_model.free_RVs]).intersection(
set(list(data["posterior"].data_vars.keys()))
)
# var_names = self._get_deterministic_var_names(data)
if var_names:
if "var_names" in kwargs:
if isinstance(kwargs["var_names"], str):
if kwargs["var_names"] not in var_names:
var_names.append(kwargs["var_names"])
kwargs["var_names"] = var_names
elif isinstance(kwargs["var_names"], list):
kwargs["var_names"] = list(
set(var_names) | set(kwargs["var_names"])
)
elif kwargs["var_names"] is None:
kwargs["var_names"] = var_names
else:
raise ValueError(
"`var_names` must be a string, a list of strings"
", or None."
)
else:
)
# var_names = self._get_deterministic_var_names(data)
if var_names:
if "var_names" in kwargs:
if isinstance(kwargs["var_names"], str):
if kwargs["var_names"] not in var_names:
var_names.append(kwargs["var_names"])
kwargs["var_names"] = var_names
elif (not isinstance(data, az.InferenceData)) and (data is not None):
raise ValueError("data must be an InferenceData object.")
elif data is None:
raise ValueError("Please sample to model first.")

elif isinstance(kwargs["var_names"], list):
kwargs["var_names"] = list(
set(var_names) | set(kwargs["var_names"])
)
elif kwargs["var_names"] is None:
kwargs["var_names"] = var_names
else:
raise ValueError(
"`var_names` must be a string, a list of strings"
", or None."
)
else:
kwargs["var_names"] = var_names
az.plot_trace(data, **kwargs)

if tight_layout:
Expand Down Expand Up @@ -1046,24 +1042,19 @@ def summary(
A pandas DataFrame or xarray Dataset containing the summary statistics.
"""
data = data or self.traces
assert isinstance(data, az.InferenceData)
if data is not None:
if not include_deterministic:
var_names = list(
set([var.name for var in self.pymc_model.free_RVs]).intersection(
set(list(data["posterior"].data_vars.keys()))
)
assert isinstance(
data, az.InferenceData
), "data must be an InferenceData object."

if not include_deterministic:
var_names = list(
set([var.name for var in self.pymc_model.free_RVs]).intersection(
set(list(data["posterior"].data_vars.keys()))
)
# var_names = self._get_deterministic_var_names(data)
if var_names:
kwargs["var_names"] = list(
set(var_names + kwargs.get("var_names", []))
)
elif (not isinstance(data, az.InferenceData)) and (data is not None):
raise ValueError("data must be an InferenceData object.")
elif data is None:
raise ValueError("Please sample to model first.")

)
# var_names = self._get_deterministic_var_names(data)
if var_names:
kwargs["var_names"] = list(set(var_names + kwargs.get("var_names", [])))
return az.summary(data, **kwargs)

def initial_point(self, transformed: bool = False) -> dict[str, np.ndarray]:
Expand Down

0 comments on commit 1d58e74

Please sign in to comment.