Skip to content

Commit

Permalink
Switch LayerStackGetAttrKey to a custom dataclass type.
Browse files Browse the repository at this point in the history
This avoids subclassing JAX's GetAttrKey, which will be changing
its implementation in the future.
  • Loading branch information
danieldjohnson committed Nov 20, 2024
1 parent 9aec1da commit 06dee98
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 4 deletions.
16 changes: 15 additions & 1 deletion penzai/core/tree_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,24 @@

from __future__ import annotations

import dataclasses
from typing import Any, Optional

import jax

PyTreeDef = jax.tree_util.PyTreeDef


@dataclasses.dataclass(frozen=True)
class CustomGetAttrKey:
"""Subclass-friendly variant of jax.tree_util.GetAttrKey."""

name: str

def __str__(self):
return f".{self.name}"


def tree_flatten_exactly_one_level(
tree: Any,
) -> Optional[tuple[list[tuple[Any, Any]], PyTreeDef]]:
Expand Down Expand Up @@ -66,7 +77,10 @@ def pretty_keystr(keypath: tuple[Any, ...], tree: Any) -> str:
parts = []
for key in keypath:
if isinstance(
key, jax.tree_util.GetAttrKey | jax.tree_util.FlattenedIndexKey
key,
jax.tree_util.GetAttrKey
| jax.tree_util.FlattenedIndexKey
| CustomGetAttrKey,
):
parts.extend(("/", type(tree).__name__))
split = tree_flatten_exactly_one_level(tree)
Expand Down
5 changes: 3 additions & 2 deletions penzai/nn/layer_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@
from __future__ import annotations

import collections
from collections.abc import Hashable
import copy
import dataclasses
import enum
from typing import Any, Callable, Hashable
from typing import Any, Callable

import jax
from penzai.core import named_axes
Expand All @@ -39,7 +40,7 @@ class LayerStackVarBehavior(enum.Enum):


@dataclasses.dataclass(frozen=True)
class LayerStackGetAttrKey(jax.tree_util.GetAttrKey):
class LayerStackGetAttrKey(pz_tree_util.CustomGetAttrKey):
"""GetAttrKey for LayerStack with extra metadata.
This allows us to identify whether a given PyTree leaf is contained inside a
Expand Down
14 changes: 13 additions & 1 deletion tests/nn/layer_stack_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from typing import Any
from absl.testing import absltest
import chex
import collections
import jax
from penzai import pz

Expand Down Expand Up @@ -155,7 +156,18 @@ def builder(init_base_rng, some_value):
unbound_layer, layer_vars = pz.unbind_variables(layer)
unbound_slot_layer, slot_layer_vars = pz.unbind_variables(slot_layer)

chex.assert_trees_all_equal(unbound_layer, unbound_slot_layer)
# Check as dictionaries to avoid limitations of chex:
unbound_layer_leaves, unbound_layer_treedef = (
jax.tree_util.tree_flatten_with_path(unbound_layer)
)
unbound_slot_layer_leaves, unbound_slot_layer_treedef = (
jax.tree_util.tree_flatten_with_path(unbound_slot_layer)
)
self.assertEqual(unbound_layer_treedef, unbound_slot_layer_treedef)
chex.assert_trees_all_equal(
collections.OrderedDict(unbound_layer_leaves),
collections.OrderedDict(unbound_slot_layer_leaves),
)

slot_layer_vars_dict = {var.label: var for var in slot_layer_vars}
for var in layer_vars:
Expand Down

0 comments on commit 06dee98

Please sign in to comment.