Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Internal changes #162

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 36 additions & 3 deletions chex/_src/dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from absl import logging
import jax
import tree


FrozenInstanceError = dataclasses.FrozenInstanceError
Expand Down Expand Up @@ -62,7 +63,8 @@ def new_init(self, *orig_args, **orig_kwargs):
all_kwargs = dict(*orig_args, **orig_kwargs)
unknown_kwargs = set(all_kwargs.keys()) - all_fields
if unknown_kwargs:
raise ValueError(f"__init__() got unexpected kwargs: {unknown_kwargs}.")
raise ValueError(
f"__init__() got unexpected keyword arguments: {unknown_kwargs}.")

# Pass only arguments corresponding to fields with `init=True`.
valid_kwargs = {k: v for k, v in all_kwargs.items() if k in init_fields}
Expand Down Expand Up @@ -91,7 +93,7 @@ def dataclass(
order=False,
unsafe_hash=False,
frozen=False,
mappable_dataclass=True, # pylint: disable=redefined-outer-name
mappable_dataclass=False, # pylint: disable=redefined-outer-name
):
"""JAX-friendly wrapper for :py:func:`dataclasses.dataclass`.

Expand Down Expand Up @@ -185,7 +187,7 @@ def __call__(self, cls):
delattr(dcls, attr) # delete

def _from_tuple(args):
return dcls(zip(dcls.__dataclass_fields__.keys(), args))
return dcls(**dict(zip(dcls.__dataclass_fields__.keys(), args)))

def _to_tuple(self):
return tuple(getattr(self, k) for k in self.__dataclass_fields__.keys())
Expand All @@ -202,6 +204,8 @@ def _getstate(self):
def _setstate(self, state):
if not class_self.registered:
register_dataclass_type_with_jax_tree_util(dcls)
if not class_self.mappable_dataclass:
register_dataclass_type_with_dm_tree(dcls)
class_self.registered = True
self.__dict__.update(state)

Expand All @@ -213,6 +217,8 @@ def _setstate(self, state):
def _init(self, *args, **kwargs):
if not class_self.registered:
register_dataclass_type_with_jax_tree_util(dcls)
if not class_self.mappable_dataclass:
register_dataclass_type_with_dm_tree(dcls)
class_self.registered = True
return orig_init(self, *args, **kwargs)

Expand Down Expand Up @@ -246,3 +252,30 @@ def register_dataclass_type_with_jax_tree_util(data_class):
nodetype=data_class, flatten_func=flatten, unflatten_func=unflatten)
except ValueError:
logging.info("%s is already registered as JAX PyTree node.", data_class)


def register_dataclass_type_with_dm_tree(data_class):
"""Register an existing dataclass with dm_tree node registry.

This will mean that functions in dm_tree will operate over fields of the
dataclass.

Args:
data_class: A class created using dataclasses.dataclass. It must be
constructable from keyword arguments corresponding to the members exposed
in instance.__dict__.
"""

def to_iterable(d):
keys, values = jax.util.unzip2(sorted(d.__dict__.items()))
return values, keys, keys

def from_iterable(keys, values):
return data_class(**dict(zip(keys, values)))

try:
tree.register_node(data_class, to_iterable, from_iterable)
except ValueError:
logging.log_first_n(logging.INFO,
"%s is already registered as dm_tree node.", 1,
data_class)
11 changes: 8 additions & 3 deletions chex/_src/dataclass_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,12 +365,17 @@ class SimpleDataclass:
b: int = 2

SimpleDataclass(a=1, b=3)
with self.assertRaisesRegex(ValueError, 'init.*got unexpected kwargs'):
with self.assertRaisesRegex((ValueError, TypeError),
'.*unexpected keyword argument.*'):
SimpleDataclass(a=1, b=3, c=4)

def test_tuple_conversion(self):
@parameterized.named_parameters(
('non_mappable', False),
('mappable', True),
)
def test_tuple_conversion(self, mappable):

@chex_dataclass()
@chex_dataclass(mappable_dataclass=mappable)
class SimpleDataclass:
b: int
a: int
Expand Down