diff --git a/src/llmcompressor/utils/fsdp/context.py b/src/llmcompressor/utils/fsdp/context.py
index 177b2c02f..8cc062c19 100644
--- a/src/llmcompressor/utils/fsdp/context.py
+++ b/src/llmcompressor/utils/fsdp/context.py
@@ -1,10 +1,13 @@
 try:
     from accelerate import Accelerator
+except ImportError:
+    Accelerator = None
+
+try:
     from torch.distributed.fsdp import FullyShardedDataParallel
-    from torch.distributed.fsdp._common_utils import TrainingState
+    from torch.distributed.fsdp._common_utils import FSDP_WRAPPED_MODULE, TrainingState
 except ImportError:
     FullyShardedDataParallel = None
-    Accelerator = None
 
 from contextlib import nullcontext
 
@@ -14,8 +17,6 @@
     "fix_fsdp_module_name",
 ]
 
-FSDP_WRAPPER_NAME = "_fsdp_wrapped_module"
-
 
 def summon_full_params_context(model, offload_to_cpu: bool = False):
     if FullyShardedDataParallel is not None:
@@ -46,12 +47,15 @@ def main_process_first_context():
 def fix_fsdp_module_name(name: str) -> str:
     """
     Remove FSDP wrapper prefixes from a module name.
-    Accounts for scenario where FSDP_WRAPPER_NAME is
+    Accounts for scenario where FSDP_WRAPPED_MODULE is
     at the end of the name, as well as in the middle.
 
     :param name: name to strip
     :return: stripped name
     """
-    return name.replace(FSDP_WRAPPER_NAME + ".", "").replace(
-        "." + FSDP_WRAPPER_NAME, ""
+    if FullyShardedDataParallel is None:
+        return name
+
+    return name.replace(FSDP_WRAPPED_MODULE + ".", "").replace(
+        "." + FSDP_WRAPPED_MODULE, ""
     )
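
Below is a minimal usage sketch of the patched fix_fsdp_module_name, assuming torch's FSDP_WRAPPED_MODULE constant resolves to "_fsdp_wrapped_module" (the same value as the removed local FSDP_WRAPPER_NAME):

# Minimal sketch; assumes FSDP_WRAPPED_MODULE == "_fsdp_wrapped_module",
# matching the removed FSDP_WRAPPER_NAME constant.
from llmcompressor.utils.fsdp.context import fix_fsdp_module_name

# A wrapper segment in the middle of a dotted name is stripped.
assert (
    fix_fsdp_module_name("model._fsdp_wrapped_module.layers.0.attn")
    == "model.layers.0.attn"
)

# A wrapper segment at the end of the name is stripped as well.
assert fix_fsdp_module_name("model._fsdp_wrapped_module") == "model"

# If torch FSDP failed to import (FullyShardedDataParallel is None),
# the new guard returns the name unchanged rather than referencing the
# now-undefined FSDP_WRAPPED_MODULE constant.

The early-return guard is needed because FSDP_WRAPPED_MODULE is now bound inside the try/except import block, so it may be undefined when torch FSDP is unavailable.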