Skip to content

Commit

Permalink
Merge pull request #521 from lnccbrown/bambi-014-fix-graphing
Browse files (browse the repository at this point in the history)
Fix graphing
  • Loading branch information
digicosmos86 authored Aug 5, 2024
2 parents 83a3017 + a52d16f commit 1a7cadc
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 153 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ repos:
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.5.2
rev: v0.5.4
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ pre-commit = "^2.20.0"
jupyterlab = "^4.2.3"
ipykernel = "^6.29.4"
ipywidgets = "^8.1.2"
ruff = "^0.5.2"
ruff = "^0.5.4"
graphviz = "^0.20.3"
pytest-xdist = "^3.6.1"
onnxruntime = "^1.17.1"
Expand Down
4 changes: 2 additions & 2 deletions src/hssm/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,8 +393,8 @@ def show_defaults(model: SupportedModels, loglik_kind=Optional[LoglikKind]) -> s
output += _show_defaults_helper(model, loglik_kind)

else:
for loglik_kind in model_config["likelihoods"].keys():
output += _show_defaults_helper(model, loglik_kind)
for loglik_kind_ in model_config["likelihoods"]:
output += _show_defaults_helper(model, loglik_kind_)
output.append("")

output = output[:-1]
Expand Down
37 changes: 13 additions & 24 deletions src/hssm/hssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
_make_default_prior,
)
from hssm.utils import (
HSSMModelGraph,
_get_alias_dict,
_print_prior,
_process_param_in_kwargs,
Expand Down Expand Up @@ -568,7 +567,6 @@ def sample(
kwargs["nuts_sampler"] = (
"pymc" if sampler == "mcmc" else sampler.split("_")[1]
)
print(kwargs["nuts_sampler"])

self._inference_obj = self.model.fit(
inference_method="mcmc"
Expand Down Expand Up @@ -897,8 +895,7 @@ def response_str(self) -> str:
"""Return the response variable names in string format."""
return ",".join(self.response)

# NOTE: can't annotate return type because the graphviz dependency is
# optional
# NOTE: can't annotate return type because the graphviz dependency is optional
def graph(self, formatting="plain", name=None, figsize=None, dpi=300, fmt="png"):
"""Produce a graphviz Digraph from a built HSSM model.
Expand Down Expand Up @@ -929,30 +926,22 @@ def graph(self, formatting="plain", name=None, figsize=None, dpi=300, fmt="png")
-------
graphviz.Graph
The graph
Note
----
The code is largely copied from
https://github.com/bambinos/bambi/blob/main/bambi/models.py
Credit for the code goes to Bambi developers.
"""
self.model._check_built()

graphviz = HSSMModelGraph(
model=self.pymc_model, parent=self._parent_param
).make_graph(formatting=formatting, response_str=self.response_str)
graph = self.model.graph(formatting, name, figsize, dpi, fmt)

width, height = (None, None) if figsize is None else figsize
parent_param = self._parent_param
if parent_param.is_regression:
return graph

if name is not None:
graphviz_ = graphviz.copy()
graphviz_.graph_attr.update(size=f"{width},{height}!")
graphviz_.graph_attr.update(dpi=str(dpi))
graphviz_.render(filename=name, format=fmt, cleanup=True)

return graphviz_
# Modify the graph
# 1. Remove all nodes and edges related to `{parent}_mean`:
graph.body = [
item for item in graph.body if f"{parent_param.name}_mean" not in item
]
# 2. Add a new edge from parent to response
graph.edge(parent_param.name, self.response_str)

return graphviz
return graph

def plot_trace(
self,
Expand Down
126 changes: 1 addition & 125 deletions src/hssm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"""

import logging
from typing import Any, Iterable, Literal, NewType, cast
from typing import Any, Literal, cast

import bambi as bmb
import jax
Expand All @@ -20,8 +20,6 @@
import xarray as xr
from bambi.terms import CommonTerm, GroupSpecificTerm, HSGPTerm, OffsetTerm
from huggingface_hub import hf_hub_download
from pymc.model_graph import DEFAULT_NODE_FORMATTERS, ModelGraph
from pytensor import function

from .param import Param

Expand Down Expand Up @@ -146,128 +144,6 @@ def _get_alias_dict(
return alias_dict


def fast_eval(var):
    """Compile a variable with no inputs and return its evaluated value.

    Notes
    -----
    Helper used by the model-graph code in this module. The variable is
    compiled in ``FAST_COMPILE`` mode and the compiled function is called
    immediately.
    """
    compiled = function([], var, mode="FAST_COMPILE")
    return compiled()


# Distinct ``str``-based type for variable names; used in the ``var_names`` annotation.
VarName = NewType("VarName", str)


class HSSMModelGraph(ModelGraph):
    """PyMC ``ModelGraph`` variant that draws the missing parent parameter.

    Notes
    -----
    This is a workaround rather than a principled solution; there may be
    better ways to handle the parent parameter issue.
    """

    def __init__(self, model, parent):
        # Keep the parent parameter around before the base class sets itself up.
        self.parent = parent
        super().__init__(model)

    def make_graph(
        self,
        var_names: Iterable[VarName] | None = None,
        formatting: str = "plain",
        response_str: str = "rt,response",
    ):
        """Build a graphviz Digraph for the PyMC model.

        Parameters
        ----------
        var_names
            Forwarded to ``get_plates`` and ``make_compute_graph``.
        formatting
            Forwarded to ``_make_node`` for every node.
        response_str
            Name of the response variable. When the parent parameter is a
            regression, edges into this node are redirected to the parent.

        Returns
        -------
        graphviz.Digraph

        Raises
        ------
        ImportError
            If the ``graphviz`` python library cannot be imported.

        Notes
        -----
        Slightly adapted from
        https://github.com/pymc-devs/pymc/blob/main/pymc/model_graph.py
        Credit for this code goes to PyMC developers.
        """
        try:
            import graphviz  # pylint: disable=C0415
        except ImportError as e:
            e.msg = "".join(
                (
                    "This function requires the python library graphviz, ",
                    "along with binaries. ",
                    "The easiest way to install all of this is by running\n\n",
                    "\tconda install -c conda-forge python-graphviz",
                )
            )
            raise e

        dot = graphviz.Digraph(self.model.name)
        node_kwargs = {
            "formatting": formatting,
            "node_formatters": DEFAULT_NODE_FORMATTERS,
        }

        for plate_label, members in self.get_plates(var_names).items():
            if not plate_label:
                # Variables outside any plate go straight onto the top-level graph.
                for member in members:
                    self._make_node(member, dot, **node_kwargs)
                continue

            # The subgraph name needs the 'cluster' prefix to get a box drawn.
            with dot.subgraph(name="cluster" + plate_label) as plate:
                for member in members:
                    self._make_node(member, plate, **node_kwargs)
                # Put the plate label in the bottom-right corner.
                plate.attr(
                    label=plate_label,
                    labeljust="r",
                    labelloc="b",
                    style="rounded",
                )

        is_regression = self.parent.is_regression

        if is_regression:
            # The parent parameter is absent from the graph, so add it by hand.
            with dot.subgraph(name="cluster" + self.parent.name) as parent_box:
                parent_box.node(
                    self.parent.name,
                    label=f"{self.parent.name}\n~\nDeterministic",
                    shape="box",
                )
                shape = fast_eval(self.model[response_str].shape)
                parent_box.attr(
                    label=f"{response_str}_obs({shape[0]})",
                    labeljust="r",
                    labelloc="b",
                    style="rounded",
                )

        compute_graph = self.make_compute_graph(var_names=var_names)
        for child, parents in compute_graph.items():
            # ``parents`` holds the rv names that precede the ``child`` rv node.
            for parent in parents:
                source = parent.replace(":", "&")
                redirect = (
                    is_regression
                    and parent.startswith(f"{self.parent.name}_")
                    and child == response_str
                )
                # Redirected edges end at the parent parameter, not the response.
                target = self.parent.name if redirect else child.replace(":", "&")
                dot.edge(source, target)

        if is_regression:
            dot.edge(self.parent.name, response_str)

        return dot


def set_floatX(dtype: Literal["float32", "float64"], update_jax: bool = True):
"""Set float types for pytensor and Jax.
Expand Down
9 changes: 9 additions & 0 deletions tests/test_graphing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import hssm


def test_simple_graphing(data_ddm):
    """Check that a basic DDM model graph omits ``{parent}_mean`` entries."""
    ddm_model = hssm.HSSM(data=data_ddm, model="ddm")
    rendered = ddm_model.graph()

    assert rendered is not None

    # No entry of the graph body may mention the parent's ``_mean`` term.
    for entry in rendered.body:
        assert f"{ddm_model._parent}_mean" not in entry

0 comments on commit 1a7cadc

Please sign in to comment.