From e6dcf8abd6f65bb4b6dfc1831b20d9ba49ce00e2 Mon Sep 17 00:00:00 2001 From: cyyever Date: Sun, 17 Dec 2023 18:13:42 +0800 Subject: [PATCH] Fix the deprecation warning of _torch_pytree._register_pytree_node (#27803) --- src/transformers/utils/generic.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py index 34dac8bea70cfc..155b2941f197db 100644 --- a/src/transformers/utils/generic.py +++ b/src/transformers/utils/generic.py @@ -306,7 +306,7 @@ def __init_subclass__(cls) -> None: `static_graph=True` with modules that output `ModelOutput` subclasses. """ if is_torch_available(): - _torch_pytree._register_pytree_node( + torch_pytree_register_pytree_node( cls, _model_output_flatten, _model_output_unflatten, @@ -438,7 +438,11 @@ def _model_output_unflatten(values: Iterable[Any], context: "_torch_pytree.Conte output_type, keys = context return output_type(**dict(zip(keys, values))) - _torch_pytree._register_pytree_node( + if hasattr(_torch_pytree, "register_pytree_node"): + torch_pytree_register_pytree_node = _torch_pytree.register_pytree_node + else: + torch_pytree_register_pytree_node = _torch_pytree._register_pytree_node + torch_pytree_register_pytree_node( ModelOutput, _model_output_flatten, _model_output_unflatten,