From 6adb2700a528209d141461666ea0e9c37aa94403 Mon Sep 17 00:00:00 2001 From: Stephen Spencer Date: Mon, 3 Apr 2023 03:15:54 -0700 Subject: [PATCH] Fix generic dataclasses with bound parameters. This alters the way in which mappable dataclasses are created in order to fix a crash when a mappable, frozen generic dataclass is instantiated with a bound type parameter. PiperOrigin-RevId: 521409933 --- chex/_src/dataclass.py | 70 +++++++++++++++++++++++++++++++------ chex/_src/dataclass_test.py | 13 ++++--- 2 files changed, 69 insertions(+), 14 deletions(-) diff --git a/chex/_src/dataclass.py b/chex/_src/dataclass.py index 94c9279d..249d73c4 100644 --- a/chex/_src/dataclass.py +++ b/chex/_src/dataclass.py @@ -83,6 +83,37 @@ def new_init(self, *orig_args, **orig_kwargs): return cls +def make_mappable(cls): + """Exposes dataclass as ``collections.abc.Mapping`` descendent. + + Allows to traverse dataclasses in methods from `dm-tree` library. + + NOTE: changes dataclasses constructor to dict-type + (i.e. positional args aren't supported; however can use generators/iterables). + + Args: + cls: A dataclass to mutate. + + Returns: + Mutated dataclass implementing ``collections.abc.Mapping`` interface. + """ + # Define methods for compatibility with `collections.abc.Mapping`. + setattr(cls, "__getitem__", lambda self, x: self.__dict__[x]) + setattr(cls, "__len__", lambda self: len(self.__dict__)) + setattr(cls, "__iter__", lambda self: iter(self.__dict__)) + + # Update base class to derive from Mapping + dct = dict(cls.__dict__) + if "__dict__" in dct: + dct.pop("__dict__") # Avoid self-references. + + # Remove object from the sequence of base classes. Deriving from both Mapping + # and object will cause a failure to create a MRO for the updated class + bases = tuple(b for b in cls.__bases__ if b != object) + cls = type(cls.__name__, bases + (collections.abc.Mapping,), dct) + return cls + + @dataclass_transform() def dataclass( cls=None, @@ -155,6 +186,14 @@ def __init__( def __call__(self, cls): """Forwards class to dataclasses's wrapper and registers it with JAX.""" + if self.mappable_dataclass: + cls = make_mappable(cls) + # We remove `collection.abc.Mapping` mixin methods here to allow + # fields with these names. + for attr in ("values", "keys", "get", "items"): + setattr(cls, attr, None) # redefine + delattr(cls, attr) # delete + # Remove once https://github.com/python/cpython/pull/24484 is merged. for base in cls.__bases__: if (dataclasses.is_dataclass(base) and @@ -169,6 +208,7 @@ def __call__(self, cls): eq=self.eq, order=self.order, unsafe_hash=self.unsafe_hash, + kw_only=True, frozen=self.frozen) # pytype: enable=wrong-keyword-args @@ -178,16 +218,8 @@ def __call__(self, cls): raise ValueError(f"The following dataclass fields are disallowed: " f"{invalid_fields} ({dcls}).") - if self.mappable_dataclass: - dcls = mappable_dataclass(dcls) - # We remove `collection.abc.Mapping` mixin methods here to allow - # fields with these names. - for attr in ("values", "keys", "get", "items"): - setattr(dcls, attr, None) # redefine - delattr(dcls, attr) # delete - def _from_tuple(args): - return dcls(zip(dcls.__dataclass_fields__.keys(), args)) + return dcls(**dict(zip(dcls.__dataclass_fields__.keys(), args))) def _to_tuple(self): return tuple(getattr(self, k) for k in self.__dataclass_fields__.keys()) @@ -209,6 +241,9 @@ def _setstate(self, state): orig_init = dcls.__init__ + all_fields = set(f.name for f in cls.__dataclass_fields__.values()) + init_fields = [f.name for f in dcls.__dataclass_fields__.values() if f.init] + # Patch object's __init__ such that the class is registered on creation if # it is not registered on deserialization. @functools.wraps(orig_init) @@ -216,7 +251,22 @@ def _init(self, *args, **kwargs): if not class_self.registered: register_dataclass_type_with_jax_tree_util(dcls) class_self.registered = True - return orig_init(self, *args, **kwargs) + + if self.mappable_dataclass: + if (args and kwargs) or len(args) > 1: + raise ValueError( + "Mappable dataclass constructor doesn't support positional args." + "(it has the same constructor as python dict)") + all_kwargs = dict(*args, **kwargs) + unknown_kwargs = set(all_kwargs.keys()) - all_fields + if unknown_kwargs: + raise ValueError( + f"__init__() got unexpected kwargs: {unknown_kwargs}." + ) + valid_kwargs = {k: v for k, v in all_kwargs.items() if k in init_fields} + return orig_init(self, **valid_kwargs) + else: + return orig_init(self, *args, **kwargs) setattr(dcls, "from_tuple", _from_tuple) setattr(dcls, "to_tuple", _to_tuple) diff --git a/chex/_src/dataclass_test.py b/chex/_src/dataclass_test.py index b3fb0215..27ad3744 100644 --- a/chex/_src/dataclass_test.py +++ b/chex/_src/dataclass_test.py @@ -521,19 +521,24 @@ def _is_leaf(value) -> bool: jax.tree_util.tree_map(lambda x: x, dcls, is_leaf=_is_leaf), dcls) @parameterized.named_parameters( - ('mappable', True), - ('not_mappable', False), + ('mappable_frozen', True, True), + ('not_mappable_frozen', False, True), + ('mappable_not_frozen', True, False), + ('not_mappable_not_frozen', False, False), ) - def test_generic_dataclass(self, mappable): + def test_generic_dataclass(self, mappable, frozen): T = TypeVar('T') - @chex_dataclass(mappable_dataclass=mappable) + @chex_dataclass(mappable_dataclass=mappable, frozen=frozen) class GenericDataclass(Generic[T]): a: T # pytype: disable=invalid-annotation # enable-bare-annotations obj = GenericDataclass(a=np.array([1.0, 1.0])) asserts.assert_tree_all_close(obj.a, 1.0) + obj = GenericDataclass[np.array](a=np.array([1.0, 1.0])) + asserts.assert_tree_all_close(obj.a, 1.0) + def test_mappable_eq_override(self): @chex_dataclass(mappable_dataclass=True)