diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py index 34dac8bea70cfc..155b2941f197db 100644 --- a/src/transformers/utils/generic.py +++ b/src/transformers/utils/generic.py @@ -306,7 +306,7 @@ def __init_subclass__(cls) -> None: `static_graph=True` with modules that output `ModelOutput` subclasses. """ if is_torch_available(): - _torch_pytree._register_pytree_node( + torch_pytree_register_pytree_node( cls, _model_output_flatten, _model_output_unflatten, @@ -438,7 +438,11 @@ def _model_output_unflatten(values: Iterable[Any], context: "_torch_pytree.Conte output_type, keys = context return output_type(**dict(zip(keys, values))) - _torch_pytree._register_pytree_node( + if hasattr(_torch_pytree, "register_pytree_node"): + torch_pytree_register_pytree_node = _torch_pytree.register_pytree_node + else: + torch_pytree_register_pytree_node = _torch_pytree._register_pytree_node + torch_pytree_register_pytree_node( ModelOutput, _model_output_flatten, _model_output_unflatten,