Skip to content

Commit

Permalink
Modify context
Browse files Browse the repository at this point in the history
  • Loading branch information
angelayi committed Jan 16, 2024
1 parent b3f6e0e commit 88ec282
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 32 deletions.
41 changes: 13 additions & 28 deletions src/transformers/utils/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
from contextlib import ExitStack, contextmanager
from dataclasses import fields, is_dataclass
from enum import Enum
from typing import Any, ContextManager, Dict, Iterable, List, Tuple, Type
from functools import partial
from typing import Any, ContextManager, Iterable, List, Tuple

import numpy as np
from packaging import version
Expand Down Expand Up @@ -311,16 +312,14 @@ def __init_subclass__(cls) -> None:
_torch_pytree.register_pytree_node(
cls,
_model_output_flatten,
_model_output_unflatten,
partial(_model_output_unflatten, output_type=cls),
serialized_type_name=f"{cls.__module__}.{cls.__name__}",
from_dumpable_context=_model_output_from_dumpable_context,
to_dumpable_context=_model_output_to_dumpable_context,
)
else:
_torch_pytree._register_pytree_node(
cls,
_model_output_flatten,
_model_output_unflatten,
partial(_model_output_unflatten, output_type=cls),
)

def __init__(self, *args, **kwargs):
Expand Down Expand Up @@ -443,41 +442,27 @@ def to_tuple(self) -> Tuple[Any]:
import torch.utils._pytree as _torch_pytree

def _model_output_flatten(output: ModelOutput) -> Tuple[List[Any], "_torch_pytree.Context"]:
return list(output.values()), (type(output), list(output.keys()))
return list(output.values()), list(output.keys())

def _model_output_unflatten(values: Iterable[Any], context: "_torch_pytree.Context") -> ModelOutput:
output_type, keys = context
return output_type(**dict(zip(keys, values)))
def _model_output_unflatten(
values: Iterable[Any],
context: "_torch_pytree.Context",
output_type=None,
) -> ModelOutput:
return output_type(**dict(zip(context, values)))

if version.parse(get_torch_version()) >= version.parse("2.2"):
SERIALIZED_CLASS_TO_PYTHON_CLASS: Dict[str, Type[Any]] = {}

def _model_output_to_dumpable_context(context: "_torch_pytree.Context") -> "_torch_pytree.DumpableContext":
python_class, keys = context
serialized_class = f"{python_class.__module__}.{python_class.__name__}"
SERIALIZED_CLASS_TO_PYTHON_CLASS[serialized_class] = python_class
return (serialized_class, keys)

def _model_output_from_dumpable_context(
dumpable_context: "_torch_pytree.DumpableContext"
) -> "_torch_pytree.Context":
serialized_class, keys = dumpable_context
python_class = SERIALIZED_CLASS_TO_PYTHON_CLASS[serialized_class]
return (python_class, keys)

_torch_pytree.register_pytree_node(
ModelOutput,
_model_output_flatten,
_model_output_unflatten,
partial(_model_output_unflatten, output_type=ModelOutput),
serialized_type_name=f"{ModelOutput.__module__}.{ModelOutput.__name__}",
from_dumpable_context=_model_output_from_dumpable_context,
to_dumpable_context=_model_output_to_dumpable_context,
)
else:
_torch_pytree._register_pytree_node(
ModelOutput,
_model_output_flatten,
_model_output_unflatten,
partial(_model_output_unflatten, output_type=ModelOutput),
)


Expand Down
6 changes: 2 additions & 4 deletions tests/utils/test_model_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,7 @@ def test_torch_pytree(self):
self.assertFalse(pytree._is_leaf(x))

expected_flat_outs = [1.0, 2.0]
expected_tree_spec = pytree.TreeSpec(
ModelOutputTest, (ModelOutputTest, ["a", "c"]), [pytree.LeafSpec(), pytree.LeafSpec()]
)
expected_tree_spec = pytree.TreeSpec(ModelOutputTest, ["a", "c"], [pytree.LeafSpec(), pytree.LeafSpec()])

actual_flat_outs, actual_tree_spec = pytree.tree_flatten(x)
self.assertEqual(expected_flat_outs, actual_flat_outs)
Expand All @@ -151,7 +149,7 @@ def test_torch_pytree(self):
if is_torch_greater_or_equal_than_2_2:
self.assertEqual(
pytree.treespec_dumps(actual_tree_spec),
'[1, {"type": "tests.utils.test_model_output.ModelOutputTest", "context": ["tests.utils.test_model_output.ModelOutputTest", ["a", "c"]], "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}]',
'[1, {"type": "tests.utils.test_model_output.ModelOutputTest", "context": ["a", "c"], "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}]',
)


Expand Down

0 comments on commit 88ec282

Please sign in to comment.