Skip to content

Commit

Permalink
Add serialized type name to pytrees
Browse files Browse the repository at this point in the history
  • Loading branch information
angelayi committed Dec 14, 2023
1 parent 3060899 commit 4d19a6e
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 12 deletions.
1 change: 1 addition & 0 deletions src/transformers/pytorch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

# Numeric base of the installed torch version; `base_version` strips any local
# or pre-release suffix (e.g. "2.1.0+cu118" -> "2.1.0") so comparisons below
# are not skewed by build metadata.
parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version)

# Version-gate flags: used elsewhere in the codebase to select torch APIs that
# only exist from the given release onward.
is_torch_greater_or_equal_than_2_2 = parsed_torch_version_base >= version.parse("2.2")
is_torch_greater_or_equal_than_2_1 = parsed_torch_version_base >= version.parse("2.1")
is_torch_greater_or_equal_than_2_0 = parsed_torch_version_base >= version.parse("2.0")
is_torch_greater_or_equal_than_1_13 = parsed_torch_version_base >= version.parse("1.13")
Expand Down
60 changes: 48 additions & 12 deletions src/transformers/utils/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,12 @@
from contextlib import ExitStack, contextmanager
from dataclasses import fields, is_dataclass
from enum import Enum
from typing import Any, ContextManager, Iterable, List, Tuple
from typing import Any, ContextManager, Dict, Iterable, List, Tuple, Type

import numpy as np
from packaging import version

from .import_utils import is_flax_available, is_tf_available, is_torch_available, is_torch_fx_proxy
from .import_utils import get_torch_version, is_flax_available, is_tf_available, is_torch_available, is_torch_fx_proxy


if is_flax_available():
Expand Down Expand Up @@ -306,11 +307,21 @@ def __init_subclass__(cls) -> None:
`static_graph=True` with modules that output `ModelOutput` subclasses.
"""
if is_torch_available():
_torch_pytree._register_pytree_node(
cls,
_model_output_flatten,
_model_output_unflatten,
)
if version.parse(get_torch_version()) >= version.parse("2.2"):
_torch_pytree.register_pytree_node(
cls,
_model_output_flatten,
_model_output_unflatten,
serialized_type_name=f"{cls.__module__}.{cls.__name__}",
from_dumpable_context=_model_output_from_dumpable_context,
to_dumpable_context=_model_output_to_dumpable_context,
)
else:
_torch_pytree._register_pytree_node(
cls,
_model_output_flatten,
_model_output_unflatten,
)

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
Expand Down Expand Up @@ -438,11 +449,36 @@ def _model_output_unflatten(values: Iterable[Any], context: "_torch_pytree.Conte
output_type, keys = context
return output_type(**dict(zip(keys, values)))

_torch_pytree._register_pytree_node(
ModelOutput,
_model_output_flatten,
_model_output_unflatten,
)
if version.parse(get_torch_version()) >= version.parse("2.2"):
    # Maps a serialized type name (e.g. "transformers.utils.generic.ModelOutput")
    # back to the Python class. Populated as tree specs are dumped, so a spec
    # can only be reloaded for classes that were serialized in this process —
    # NOTE(review): presumably sufficient because registration happens at
    # import/subclass time; confirm against callers.
    SERIALIZED_CLASS_TO_PYTHON_CLASS: Dict[str, Type[Any]] = {}

    def _model_output_to_dumpable_context(context: "_torch_pytree.Context") -> "_torch_pytree.DumpableContext":
        """Convert a pytree context (class, keys) into a JSON-serializable form.

        Records the class in `SERIALIZED_CLASS_TO_PYTHON_CLASS` so the inverse
        lookup in `_model_output_from_dumpable_context` can succeed.
        """
        python_class, keys = context
        serialized_class = f"{python_class.__module__}.{python_class.__name__}"
        SERIALIZED_CLASS_TO_PYTHON_CLASS[serialized_class] = python_class
        return (serialized_class, keys)

    def _model_output_from_dumpable_context(
        dumpable_context: "_torch_pytree.DumpableContext"
    ) -> "_torch_pytree.Context":
        """Inverse of `_model_output_to_dumpable_context`.

        Raises KeyError if the class was never serialized in this process.
        """
        serialized_class, keys = dumpable_context
        python_class = SERIALIZED_CLASS_TO_PYTHON_CLASS[serialized_class]
        return (python_class, keys)

    # torch >= 2.2: public registration API with serialization hooks, so tree
    # specs containing ModelOutput can be dumped/loaded as strings.
    _torch_pytree.register_pytree_node(
        ModelOutput,
        _model_output_flatten,
        _model_output_unflatten,
        serialized_type_name=f"{ModelOutput.__module__}.{ModelOutput.__name__}",
        from_dumpable_context=_model_output_from_dumpable_context,
        to_dumpable_context=_model_output_to_dumpable_context,
    )
else:
    # torch < 2.2: only the private registration API exists; no serialized
    # type name / dumpable-context support.
    _torch_pytree._register_pytree_node(
        ModelOutput,
        _model_output_flatten,
        _model_output_unflatten,
    )


class ExplicitEnum(str, Enum):
Expand Down
8 changes: 8 additions & 0 deletions tests/utils/test_model_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,14 @@ def test_torch_pytree(self):
unflattened_x = pytree.tree_unflatten(actual_flat_outs, actual_tree_spec)
self.assertEqual(x, unflattened_x)

from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_2

if is_torch_greater_or_equal_than_2_2:
self.assertEqual(
pytree.treespec_dumps(actual_tree_spec),
'[1, {"type": "tests.utils.test_model_output.ModelOutputTest", "context": ["tests.utils.test_model_output.ModelOutputTest", ["a", "c"]], "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}]',
)


class ModelOutputTestNoDataclass(ModelOutput):
"""Invalid test subclass of ModelOutput where @dataclass decorator is not used"""
Expand Down

0 comments on commit 4d19a6e

Please sign in to comment.