Skip to content

Commit

Permalink
Update src/transformers/modeling_utils.py
Browse files Browse the repository at this point in the history
Co-authored-by: Zach Mueller <[email protected]>
  • Loading branch information
SunMarc and muellerzr authored Apr 18, 2024
1 parent a691cd5 commit fac445e
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4817,15 +4817,14 @@ def unwrap_model(model: nn.Module, recursive: bool = False) -> nn.Module:
recursively, not just the top-level distributed containers.
"""
# Use accelerate implementation if available (should always be the case when using torch)
# This is for pytorch, as we also have to handle things like dynamo
if is_accelerate_available():
kwargs = {}
if version.parse(importlib.metadata.version("accelerate")) >= version.parse("0.29.0"):
kwargs["recursive"] = recursive
# Need to update to accelerate>0.29.0 if one uses recursive=True
elif recursive:
logger.error(
"Using recursive=True in unwrap_model requires a version of accelerate >= 0.29.0. Please upgrade your version of accelerate."
)
if recursive:
if not is_accelerate_available("0.29.0"):
raise RuntimeError("Setting `recursive=True` to `unwrap_model` requires `accelerate` v0.29.0. Please upgrade your version of accelerate")
else:
kwargs["recursive"] = recursive
return extract_model_from_parallel(model, **kwargs)

# since there could be multiple levels of wrapping, unwrap recursively
Expand Down

0 comments on commit fac445e

Please sign in to comment.