Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add serialization logic to pytree types #27871

Merged
merged 3 commits into from
Jan 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/transformers/pytorch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

# Torch version parsed once at import time. ``base_version`` drops any local/pre/dev
# suffix (e.g. "2.1.0+cu118" -> "2.1.0") so the release comparisons below are stable.
parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version)

# Feature-gate flags: True when the installed torch is at least the given release.
is_torch_greater_or_equal_than_2_2 = parsed_torch_version_base >= version.parse("2.2")
is_torch_greater_or_equal_than_2_1 = parsed_torch_version_base >= version.parse("2.1")
is_torch_greater_or_equal_than_2_0 = parsed_torch_version_base >= version.parse("2.0")
is_torch_greater_or_equal_than_1_13 = parsed_torch_version_base >= version.parse("1.13")
Expand Down
57 changes: 37 additions & 20 deletions src/transformers/utils/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,13 @@
from contextlib import ExitStack, contextmanager
from dataclasses import fields, is_dataclass
from enum import Enum
from functools import partial
from typing import Any, ContextManager, Iterable, List, Tuple

import numpy as np
from packaging import version

from .import_utils import is_flax_available, is_tf_available, is_torch_available, is_torch_fx_proxy
from .import_utils import get_torch_version, is_flax_available, is_tf_available, is_torch_available, is_torch_fx_proxy


if is_flax_available():
Expand Down Expand Up @@ -306,11 +308,19 @@ def __init_subclass__(cls) -> None:
`static_graph=True` with modules that output `ModelOutput` subclasses.
"""
if is_torch_available():
torch_pytree_register_pytree_node(
cls,
_model_output_flatten,
_model_output_unflatten,
)
if version.parse(get_torch_version()) >= version.parse("2.2"):
amyeroberts marked this conversation as resolved.
Show resolved Hide resolved
_torch_pytree.register_pytree_node(
cls,
_model_output_flatten,
partial(_model_output_unflatten, output_type=cls),
serialized_type_name=f"{cls.__module__}.{cls.__name__}",
)
else:
_torch_pytree._register_pytree_node(
cls,
_model_output_flatten,
partial(_model_output_unflatten, output_type=cls),
angelayi marked this conversation as resolved.
Show resolved Hide resolved
)

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
Expand Down Expand Up @@ -432,21 +442,28 @@ def to_tuple(self) -> Tuple[Any]:
import torch.utils._pytree as _torch_pytree

def _model_output_flatten(output: ModelOutput) -> Tuple[List[Any], "_torch_pytree.Context"]:
    """Flatten a ``ModelOutput`` into its leaf values plus a pytree context.

    The context is only the list of set field names; the concrete output class is
    bound at unflatten time via ``functools.partial`` instead of being stored in
    the context, which keeps the context JSON-serializable for ``torch.export``.
    """
    # NOTE(review): this span of the rendered diff interleaved the pre-change
    # implementation (context carried ``type(output)``) with the merged one;
    # this is the post-merge version.
    return list(output.values()), list(output.keys())

def _model_output_unflatten(
    values: Iterable[Any],
    context: "_torch_pytree.Context",
    output_type=None,
) -> ModelOutput:
    """Inverse of ``_model_output_flatten``: rebuild ``output_type`` from leaves.

    ``context`` holds the field names produced at flatten time; ``output_type``
    is expected to be pre-bound with ``functools.partial`` at registration time.
    """
    field_values = dict(zip(context, values))
    return output_type(**field_values)

# Register the ModelOutput base class itself as a torch pytree node so torch
# utilities (tracing, export, functorch) can traverse instances transparently.
# torch >= 2.2 exposes the public ``register_pytree_node`` and accepts a
# ``serialized_type_name`` (required for torch.export (de)serialization);
# older releases only have the private ``_register_pytree_node``.
# NOTE(review): this span of the rendered diff interleaved pre- and post-change
# lines; this is the post-merge version.
if version.parse(get_torch_version()) >= version.parse("2.2"):
    _torch_pytree.register_pytree_node(
        ModelOutput,
        _model_output_flatten,
        partial(_model_output_unflatten, output_type=ModelOutput),
        serialized_type_name=f"{ModelOutput.__module__}.{ModelOutput.__name__}",
    )
else:
    _torch_pytree._register_pytree_node(
        ModelOutput,
        _model_output_flatten,
        partial(_model_output_unflatten, output_type=ModelOutput),
    )


class ExplicitEnum(str, Enum):
Expand Down
41 changes: 37 additions & 4 deletions tests/utils/test_model_output.py
ydshieh marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import unittest
from dataclasses import dataclass
from typing import Optional

from transformers import AlbertForMaskedLM
from transformers.testing_utils import require_torch
from transformers.utils import ModelOutput
from transformers.utils import ModelOutput, is_torch_available


if is_torch_available():
import torch

from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_2
ydshieh marked this conversation as resolved.
Show resolved Hide resolved


@dataclass
Expand Down Expand Up @@ -135,9 +143,7 @@ def test_torch_pytree(self):
self.assertFalse(pytree._is_leaf(x))

expected_flat_outs = [1.0, 2.0]
expected_tree_spec = pytree.TreeSpec(
ModelOutputTest, (ModelOutputTest, ["a", "c"]), [pytree.LeafSpec(), pytree.LeafSpec()]
)
expected_tree_spec = pytree.TreeSpec(ModelOutputTest, ["a", "c"], [pytree.LeafSpec(), pytree.LeafSpec()])

actual_flat_outs, actual_tree_spec = pytree.tree_flatten(x)
self.assertEqual(expected_flat_outs, actual_flat_outs)
Expand All @@ -146,6 +152,33 @@ def test_torch_pytree(self):
unflattened_x = pytree.tree_unflatten(actual_flat_outs, actual_tree_spec)
self.assertEqual(x, unflattened_x)

if is_torch_greater_or_equal_than_2_2:
self.assertEqual(
pytree.treespec_dumps(actual_tree_spec),
'[1, {"type": "tests.utils.test_model_output.ModelOutputTest", "context": ["a", "c"], "children_spec": [{"type": null, "context": null, "children_spec": []}, {"type": null, "context": null, "children_spec": []}]}]',
)

@require_torch
def test_export_serialization(self):
    """Round-trip an exported model through ``torch.export`` save/load and check
    that the deserialized program produces the same logits as the eager model."""
    # Pytree serialization for ModelOutput subclasses needs torch >= 2.2.
    if not is_torch_greater_or_equal_than_2_2:
        return

    config = AlbertForMaskedLM.config_class()
    model = AlbertForMaskedLM(config)

    def random_inputs():
        # Random token ids in the Albert vocab range; shape (batch=1, seq=512).
        return {"input_ids": torch.randint(0, 30000, (1, 512), dtype=torch.int64, requires_grad=False)}

    exported_program = torch.export.export(model, (), random_inputs())

    # Serialize to an in-memory buffer and load it back.
    buffer = io.BytesIO()
    torch.export.save(exported_program, buffer)
    buffer.seek(0)
    reloaded_program = torch.export.load(buffer)

    inputs = random_inputs()
    assert torch.allclose(model(**inputs).logits, reloaded_program(**inputs).logits)


class ModelOutputTestNoDataclass(ModelOutput):
"""Invalid test subclass of ModelOutput where @dataclass decorator is not used"""
Expand Down
Loading