From 88ec282c15473d3f2fce71d82f4b66e5784090a7 Mon Sep 17 00:00:00 2001 From: angelayi Date: Tue, 16 Jan 2024 13:21:41 -0800 Subject: [PATCH] Modify context --- src/transformers/utils/generic.py | 41 ++++++++++--------------------- tests/utils/test_model_output.py | 6 ++--- 2 files changed, 15 insertions(+), 32 deletions(-) diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py index 8acd2baec16eb4..034e39fc1ab740 100644 --- a/src/transformers/utils/generic.py +++ b/src/transformers/utils/generic.py @@ -22,7 +22,8 @@ from contextlib import ExitStack, contextmanager from dataclasses import fields, is_dataclass from enum import Enum -from typing import Any, ContextManager, Dict, Iterable, List, Tuple, Type +from functools import partial +from typing import Any, ContextManager, Iterable, List, Tuple import numpy as np from packaging import version @@ -311,16 +312,14 @@ def __init_subclass__(cls) -> None: _torch_pytree.register_pytree_node( cls, _model_output_flatten, - _model_output_unflatten, + partial(_model_output_unflatten, output_type=cls), serialized_type_name=f"{cls.__module__}.{cls.__name__}", - from_dumpable_context=_model_output_from_dumpable_context, - to_dumpable_context=_model_output_to_dumpable_context, ) else: _torch_pytree._register_pytree_node( cls, _model_output_flatten, - _model_output_unflatten, + partial(_model_output_unflatten, output_type=cls), ) def __init__(self, *args, **kwargs): @@ -443,41 +442,27 @@ def to_tuple(self) -> Tuple[Any]: import torch.utils._pytree as _torch_pytree def _model_output_flatten(output: ModelOutput) -> Tuple[List[Any], "_torch_pytree.Context"]: - return list(output.values()), (type(output), list(output.keys())) + return list(output.values()), list(output.keys()) - def _model_output_unflatten(values: Iterable[Any], context: "_torch_pytree.Context") -> ModelOutput: - output_type, keys = context - return output_type(**dict(zip(keys, values))) + def _model_output_unflatten( + values: Iterable[Any], + context: "_torch_pytree.Context", + output_type=None, + ) -> ModelOutput: + return output_type(**dict(zip(context, values))) if version.parse(get_torch_version()) >= version.parse("2.2"): - SERIALIZED_CLASS_TO_PYTHON_CLASS: Dict[str, Type[Any]] = {} - - def _model_output_to_dumpable_context(context: "_torch_pytree.Context") -> "_torch_pytree.DumpableContext": - python_class, keys = context - serialized_class = f"{python_class.__module__}.{python_class.__name__}" - SERIALIZED_CLASS_TO_PYTHON_CLASS[serialized_class] = python_class - return (serialized_class, keys) - - def _model_output_from_dumpable_context( - dumpable_context: "_torch_pytree.DumpableContext" - ) -> "_torch_pytree.Context": - serialized_class, keys = dumpable_context - python_class = SERIALIZED_CLASS_TO_PYTHON_CLASS[serialized_class] - return (python_class, keys) - _torch_pytree.register_pytree_node( ModelOutput, _model_output_flatten, - _model_output_unflatten, + partial(_model_output_unflatten, output_type=ModelOutput), serialized_type_name=f"{ModelOutput.__module__}.{ModelOutput.__name__}", - from_dumpable_context=_model_output_from_dumpable_context, - to_dumpable_context=_model_output_to_dumpable_context, ) else: _torch_pytree._register_pytree_node( ModelOutput, _model_output_flatten, - _model_output_unflatten, + partial(_model_output_unflatten, output_type=ModelOutput), ) diff --git a/tests/utils/test_model_output.py b/tests/utils/test_model_output.py index 065197b27027c2..ded05286803e67 100644 --- a/tests/utils/test_model_output.py +++ b/tests/utils/test_model_output.py @@ -135,9 +135,7 @@ def test_torch_pytree(self): self.assertFalse(pytree._is_leaf(x)) expected_flat_outs = [1.0, 2.0] - expected_tree_spec = pytree.TreeSpec( - ModelOutputTest, (ModelOutputTest, ["a", "c"]), [pytree.LeafSpec(), pytree.LeafSpec()] - ) + expected_tree_spec = pytree.TreeSpec(ModelOutputTest, ["a", "c"], [pytree.LeafSpec(), pytree.LeafSpec()]) actual_flat_outs, actual_tree_spec = pytree.tree_flatten(x) self.assertEqual(expected_flat_outs, actual_flat_outs) @@ -151,7 +149,7 @@ def test_torch_pytree(self): if is_torch_greater_or_equal_than_2_2: self.assertEqual( pytree.treespec_dumps(actual_tree_spec), - '[1, {"type": "tests.utils.test_model_output.ModelOutputTest", "context": ["tests.utils.test_model_output.ModelOutputTest", ["a", "c"]], "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}]', + '[1, {"type": "tests.utils.test_model_output.ModelOutputTest", "context": ["a", "c"], "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}]', )