Skip to content

Commit

Permalink
Fix generic dataclasses with bound parameters.
Browse files Browse the repository at this point in the history
This alters the way in which mappable dataclasses are created in order to fix a
crash when a mappable, frozen generic dataclass is instantiated with a bound
type parameter.

PiperOrigin-RevId: 521409933
  • Loading branch information
stompchicken authored and ChexDev committed Apr 3, 2023
1 parent 06426a1 commit 6adb270
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 14 deletions.
70 changes: 60 additions & 10 deletions chex/_src/dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,37 @@ def new_init(self, *orig_args, **orig_kwargs):
return cls


def make_mappable(cls):
  """Exposes a dataclass as a ``collections.abc.Mapping`` descendant.

  Allows dataclasses to be traversed by methods from the `dm-tree` library.

  NOTE: changes the dataclass constructor to dict-type (i.e. positional args
  aren't supported; however generators/iterables can be used).

  Args:
    cls: A dataclass to mutate. The three ``Mapping`` abstract methods are
      attached to this class in place before it is rebuilt.

  Returns:
    A new class with the same name and namespace as ``cls`` that additionally
    derives from ``collections.abc.Mapping``.
  """
  # Define methods for compatibility with `collections.abc.Mapping`. They are
  # backed by the instance `__dict__`, so keys are the instance attributes.
  setattr(cls, "__getitem__", lambda self, x: self.__dict__[x])
  setattr(cls, "__len__", lambda self: len(self.__dict__))
  setattr(cls, "__iter__", lambda self: iter(self.__dict__))

  # Copy the class namespace so that the class can be rebuilt with a new base.
  namespace = dict(cls.__dict__)
  # `__dict__` and `__weakref__` are getset descriptors bound to the *original*
  # type; reusing them on the rebuilt class makes `obj.__dict__` /
  # `obj.__weakref__` raise TypeError. Drop them so that `type()` generates
  # fresh descriptors for the new class.
  for stale_descriptor in ("__dict__", "__weakref__"):
    namespace.pop(stale_descriptor, None)

  # Remove object from the sequence of base classes. Deriving from both Mapping
  # and object will cause a failure to create a MRO for the updated class.
  bases = tuple(base for base in cls.__bases__ if base is not object)
  return type(cls.__name__, bases + (collections.abc.Mapping,), namespace)


@dataclass_transform()
def dataclass(
cls=None,
Expand Down Expand Up @@ -155,6 +186,14 @@ def __init__(
def __call__(self, cls):
"""Forwards class to dataclasses's wrapper and registers it with JAX."""

if self.mappable_dataclass:
cls = make_mappable(cls)
# We remove `collections.abc.Mapping` mixin methods here to allow
# fields with these names.
for attr in ("values", "keys", "get", "items"):
setattr(cls, attr, None) # redefine
delattr(cls, attr) # delete

# Remove once https://github.com/python/cpython/pull/24484 is merged.
for base in cls.__bases__:
if (dataclasses.is_dataclass(base) and
Expand All @@ -169,6 +208,7 @@ def __call__(self, cls):
eq=self.eq,
order=self.order,
unsafe_hash=self.unsafe_hash,
kw_only=True,
frozen=self.frozen)
# pytype: enable=wrong-keyword-args

Expand All @@ -178,16 +218,8 @@ def __call__(self, cls):
raise ValueError(f"The following dataclass fields are disallowed: "
f"{invalid_fields} ({dcls}).")

if self.mappable_dataclass:
dcls = mappable_dataclass(dcls)
# We remove `collections.abc.Mapping` mixin methods here to allow
# fields with these names.
for attr in ("values", "keys", "get", "items"):
setattr(dcls, attr, None) # redefine
delattr(dcls, attr) # delete

def _from_tuple(args):
return dcls(zip(dcls.__dataclass_fields__.keys(), args))
return dcls(**dict(zip(dcls.__dataclass_fields__.keys(), args)))

def _to_tuple(self):
return tuple(getattr(self, k) for k in self.__dataclass_fields__.keys())
Expand All @@ -209,14 +241,32 @@ def _setstate(self, state):

orig_init = dcls.__init__

all_fields = set(f.name for f in cls.__dataclass_fields__.values())
init_fields = [f.name for f in dcls.__dataclass_fields__.values() if f.init]

# Patch object's __init__ such that the class is registered on creation if
# it is not registered on deserialization.
@functools.wraps(orig_init)
def _init(self, *args, **kwargs):
if not class_self.registered:
register_dataclass_type_with_jax_tree_util(dcls)
class_self.registered = True
return orig_init(self, *args, **kwargs)

if self.mappable_dataclass:
if (args and kwargs) or len(args) > 1:
raise ValueError(
"Mappable dataclass constructor doesn't support positional args."
"(it has the same constructor as python dict)")
all_kwargs = dict(*args, **kwargs)
unknown_kwargs = set(all_kwargs.keys()) - all_fields
if unknown_kwargs:
raise ValueError(
f"__init__() got unexpected kwargs: {unknown_kwargs}."
)
valid_kwargs = {k: v for k, v in all_kwargs.items() if k in init_fields}
return orig_init(self, **valid_kwargs)
else:
return orig_init(self, *args, **kwargs)

setattr(dcls, "from_tuple", _from_tuple)
setattr(dcls, "to_tuple", _to_tuple)
Expand Down
13 changes: 9 additions & 4 deletions chex/_src/dataclass_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,19 +521,24 @@ def _is_leaf(value) -> bool:
jax.tree_util.tree_map(lambda x: x, dcls, is_leaf=_is_leaf), dcls)

@parameterized.named_parameters(
('mappable', True),
('not_mappable', False),
('mappable_frozen', True, True),
('not_mappable_frozen', False, True),
('mappable_not_frozen', True, False),
('not_mappable_not_frozen', False, False),
)
def test_generic_dataclass(self, mappable):
def test_generic_dataclass(self, mappable, frozen):
T = TypeVar('T')

@chex_dataclass(mappable_dataclass=mappable)
@chex_dataclass(mappable_dataclass=mappable, frozen=frozen)
class GenericDataclass(Generic[T]):
a: T # pytype: disable=invalid-annotation # enable-bare-annotations

obj = GenericDataclass(a=np.array([1.0, 1.0]))
asserts.assert_tree_all_close(obj.a, 1.0)

obj = GenericDataclass[np.array](a=np.array([1.0, 1.0]))
asserts.assert_tree_all_close(obj.a, 1.0)

def test_mappable_eq_override(self):

@chex_dataclass(mappable_dataclass=True)
Expand Down

0 comments on commit 6adb270

Please sign in to comment.