Skip to content

Commit

Permalink
Add serialized type name to pytrees
Browse files Browse the repository at this point in the history
  • Loading branch information
angelayi committed Dec 14, 2023
1 parent 3060899 commit 0ee9678
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/transformers/utils/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,7 @@ def __init_subclass__(cls) -> None:
cls,
_model_output_flatten,
_model_output_unflatten,
serialized_type_name=f"{cls.__module__}.{cls.__name__}",
)

def __init__(self, *args, **kwargs):
Expand Down Expand Up @@ -442,6 +443,7 @@ def _model_output_unflatten(values: Iterable[Any], context: "_torch_pytree.Conte
ModelOutput,
_model_output_flatten,
_model_output_unflatten,
serialized_type_name=f"{ModelOutput.__module__}.{ModelOutput.__name__}",
)


Expand Down

0 comments on commit 0ee9678

Please sign in to comment.