From 709252e11530b4aabee99e19c1cae1ad3cb3ace6 Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Wed, 24 Jul 2024 13:34:49 -0400 Subject: [PATCH 1/4] Remove HSSMModelGraph --- src/hssm/hssm.py | 28 +---------- src/hssm/utils.py | 126 +--------------------------------------------- 2 files changed, 3 insertions(+), 151 deletions(-) diff --git a/src/hssm/hssm.py b/src/hssm/hssm.py index c923405a..73c39781 100644 --- a/src/hssm/hssm.py +++ b/src/hssm/hssm.py @@ -45,7 +45,6 @@ _make_default_prior, ) from hssm.utils import ( - HSSMModelGraph, _get_alias_dict, _print_prior, _process_param_in_kwargs, @@ -897,8 +896,7 @@ def response_str(self) -> str: """Return the response variable names in string format.""" return ",".join(self.response) - # NOTE: can't annotate return type because the graphviz dependency is - # optional + # NOTE: can't annotate return type because the graphviz dependency is optional def graph(self, formatting="plain", name=None, figsize=None, dpi=300, fmt="png"): """Produce a graphviz Digraph from a built HSSM model. @@ -929,30 +927,8 @@ def graph(self, formatting="plain", name=None, figsize=None, dpi=300, fmt="png") ------- graphviz.Graph The graph - - Note - ---- - The code is largely copied from - https://github.com/bambinos/bambi/blob/main/bambi/models.py - Credit for the code goes to Bambi developers. """ - self.model._check_built() - - graphviz = HSSMModelGraph( - model=self.pymc_model, parent=self._parent_param - ).make_graph(formatting=formatting, response_str=self.response_str) - - width, height = (None, None) if figsize is None else figsize - - if name is not None: - graphviz_ = graphviz.copy() - graphviz_.graph_attr.update(size=f"{width},{height}!") - graphviz_.graph_attr.update(dpi=str(dpi)) - graphviz_.render(filename=name, format=fmt, cleanup=True) - - return graphviz_ - - return graphviz + return self.model.graph(formatting, name, figsize, dpi, fmt) def plot_trace( self, diff --git a/src/hssm/utils.py b/src/hssm/utils.py index b6b4eabc..c995afde 100644 --- a/src/hssm/utils.py +++ b/src/hssm/utils.py @@ -10,7 +10,7 @@ """ import logging -from typing import Any, Iterable, Literal, NewType, cast +from typing import Any, Literal, cast import bambi as bmb import jax @@ -20,8 +20,6 @@ import xarray as xr from bambi.terms import CommonTerm, GroupSpecificTerm, HSGPTerm, OffsetTerm from huggingface_hub import hf_hub_download -from pymc.model_graph import DEFAULT_NODE_FORMATTERS, ModelGraph -from pytensor import function from .param import Param @@ -146,128 +144,6 @@ def _get_alias_dict( return alias_dict -def fast_eval(var): - """Fast evaluation of a variable. - - Notes - ----- - This is a helper function required for one of the functions below. - """ - return function([], var, mode="FAST_COMPILE")() - - -VarName = NewType("VarName", str) - - -class HSSMModelGraph(ModelGraph): - """Customize PyMC's ModelGraph class to inject the missing parent parameter. - - Notes - ----- - This is really a hack. There might be better ways to get around the - parent parameter issue. - """ - - def __init__(self, model, parent): - self.parent = parent - super().__init__(model) - - def make_graph( - self, - var_names: Iterable[VarName] | None = None, - formatting: str = "plain", - response_str: str = "rt,response", - ): - """Make graphviz Digraph of PyMC model. - - Returns - ------- - graphviz.Digraph - - Notes - ----- - This is a slightly modified version of the code in: - https://github.com/pymc-devs/pymc/blob/main/pymc/model_graph.py - - Credit for this code goes to PyMC developers. - """ - try: - import graphviz # pylint: disable=C0415 - except ImportError as e: - e.msg = ( - "This function requires the python library graphviz, " - + "along with binaries. " - + "The easiest way to install all of this is by running\n\n" - + "\tconda install -c conda-forge python-graphviz" - ) - raise e - graph = graphviz.Digraph(self.model.name) - for plate_label, all_var_names in self.get_plates(var_names).items(): - if plate_label: - # must be preceded by 'cluster' to get a box around it - with graph.subgraph(name="cluster" + plate_label) as sub: - for var_name in all_var_names: - self._make_node( - var_name, - sub, - formatting=formatting, - node_formatters=DEFAULT_NODE_FORMATTERS, - ) - # plate label goes bottom right - sub.attr( - label=plate_label, - labeljust="r", - labelloc="b", - style="rounded", - ) - - else: - for var_name in all_var_names: - self._make_node( - var_name, - graph, - formatting=formatting, - node_formatters=DEFAULT_NODE_FORMATTERS, - ) - - if self.parent.is_regression: - # Insert the parent parameter that's not included in the graph - with graph.subgraph(name="cluster" + self.parent.name) as sub: - sub.node( - self.parent.name, - label=f"{self.parent.name}\n~\nDeterministic", - shape="box", - ) - shape = fast_eval(self.model[response_str].shape) - plate_label = f"{response_str}_obs({shape[0]})" - - sub.attr( - label=plate_label, - labeljust="r", - labelloc="b", - style="rounded", - ) - - for child, parents in self.make_compute_graph(var_names=var_names).items(): - # parents is a set of rv names that precede child rv nodes - for parent in parents: - if ( - self.parent.is_regression - and parent.startswith(f"{self.parent.name}_") - and child == response_str - ): - # Modify the edges so that they point to the - # parent parameter - graph.edge(parent.replace(":", "&"), self.parent.name) - else: - graph.edge(parent.replace(":", "&"), child.replace(":", "&")) - - if self.parent.is_regression: - graph.edge(self.parent.name, response_str) - - return graph - - def set_floatX(dtype: Literal["float32", "float64"], update_jax: bool = True): """Set float types for pytensor and Jax. From abc40e27aef57e3d5bac371f763f4f13b2e8b190 Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Wed, 24 Jul 2024 13:59:46 -0400 Subject: [PATCH 2/4] added a hack to produce clean graph in base case --- src/hssm/hssm.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/src/hssm/hssm.py b/src/hssm/hssm.py index 73c39781..2b8e945d 100644 --- a/src/hssm/hssm.py +++ b/src/hssm/hssm.py @@ -567,7 +567,6 @@ def sample( kwargs["nuts_sampler"] = ( "pymc" if sampler == "mcmc" else sampler.split("_")[1] ) - print(kwargs["nuts_sampler"]) self._inference_obj = self.model.fit( inference_method="mcmc" @@ -928,7 +927,21 @@ def graph(self, formatting="plain", name=None, figsize=None, dpi=300, fmt="png") graphviz.Graph The graph """ - return self.model.graph(formatting, name, figsize, dpi, fmt) + graph = self.model.graph(formatting, name, figsize, dpi, fmt) + + parent_param = self._parent_param + if parent_param.is_regression: + return graph + + # Modify the graph + # 1. Remove all nodes and edges related to `{parent}_mean`: + graph.body = [ + item for item in graph.body if f"{parent_param.name}_mean" not in item + ] + # 2. Add a new edge from parent to response + graph.edge(parent_param.name, self.response_str) + + return graph def plot_trace( self, From 666090de8da049cb275fa9081af155b491118787 Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Wed, 24 Jul 2024 13:59:57 -0400 Subject: [PATCH 3/4] add tests for graphing --- tests/test_graphing.py | 9 +++++++++ 1 file changed, 9 insertions(+) create mode 100644 tests/test_graphing.py diff --git a/tests/test_graphing.py b/tests/test_graphing.py new file mode 100644 index 00000000..9eb0e5e6 --- /dev/null +++ b/tests/test_graphing.py @@ -0,0 +1,9 @@ +import hssm + + +def test_simple_graphing(data_ddm): + model = hssm.HSSM(data=data_ddm, model="ddm") + graph = model.graph() + + assert graph is not None + assert all(f"{model._parent}_mean" not in node for node in graph.body) From a52d16f9bc967d85aa984ba3c871951550f68740 Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Wed, 24 Jul 2024 14:05:42 -0400 Subject: [PATCH 4/4] fix ruff errors --- .pre-commit-config.yaml | 2 +- pyproject.toml | 2 +- src/hssm/defaults.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 09071995..04b3becd 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,7 +5,7 @@ repos: - id: end-of-file-fixer - id: trailing-whitespace - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.5.2 + rev: v0.5.4 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] diff --git a/pyproject.toml b/pyproject.toml index afe17974..fa3fcf81 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ pre-commit = "^2.20.0" jupyterlab = "^4.2.3" ipykernel = "^6.29.4" ipywidgets = "^8.1.2" -ruff = "^0.5.2" +ruff = "^0.5.4" graphviz = "^0.20.3" pytest-xdist = "^3.6.1" onnxruntime = "^1.17.1" diff --git a/src/hssm/defaults.py b/src/hssm/defaults.py index c5c5d075..267f3083 100644 --- a/src/hssm/defaults.py +++ b/src/hssm/defaults.py @@ -393,8 +393,8 @@ def show_defaults(model: SupportedModels, loglik_kind=Optional[LoglikKind]) -> s output += _show_defaults_helper(model, loglik_kind) else: - for loglik_kind in model_config["likelihoods"].keys(): - output += _show_defaults_helper(model, loglik_kind) + for loglik_kind_ in model_config["likelihoods"]: + output += _show_defaults_helper(model, loglik_kind_) output.append("") output = output[:-1]