From 10dc0fea40da9adc3547de0dc732b1766b6f519a Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Thu, 21 Nov 2024 16:20:56 +0000
Subject: [PATCH] revert summon_full_params_context

Signed-off-by: Kyle Sayers
---
 src/llmcompressor/utils/fsdp/context.py | 31 ++++++++++++++-----------
 1 file changed, 17 insertions(+), 14 deletions(-)

diff --git a/src/llmcompressor/utils/fsdp/context.py b/src/llmcompressor/utils/fsdp/context.py
index a359530ed..8cc062c19 100644
--- a/src/llmcompressor/utils/fsdp/context.py
+++ b/src/llmcompressor/utils/fsdp/context.py
@@ -1,4 +1,8 @@
-from accelerate import Accelerator
+try:
+    from accelerate import Accelerator
+except ImportError:
+    Accelerator = None
+
 try:
     from torch.distributed.fsdp import FullyShardedDataParallel
     from torch.distributed.fsdp._common_utils import FSDP_WRAPPED_MODULE, TrainingState
@@ -15,19 +19,18 @@
 
 
 def summon_full_params_context(model, offload_to_cpu: bool = False):
-    if FullyShardedDataParallel is None:
-        return nullcontext()
-
-    # do not call from within summon_full_param context
-    if (
-        hasattr(model, "training_state")
-        and model.training_state is TrainingState.SUMMON_FULL_PARAMS
-    ):
-        return nullcontext()
-
-    return FullyShardedDataParallel.summon_full_params(
-        model, offload_to_cpu=offload_to_cpu
-    )
+    if FullyShardedDataParallel is not None:
+        # avoid nested summon_full_param context
+        if (
+            hasattr(model, "training_state")
+            and model.training_state is TrainingState.SUMMON_FULL_PARAMS
+        ):
+            return nullcontext()
+        return FullyShardedDataParallel.summon_full_params(
+            model, offload_to_cpu=offload_to_cpu
+        )
+
+    return nullcontext()
 
 
 def main_process_first_context():
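
A minimal usage sketch of the context manager touched above (not part of the patch itself), showing the behavior this revert restores: when FSDP is available the context gathers full parameters, otherwise the call degrades to a no-op nullcontext. The toy `torch.nn.Linear` model below is an illustrative assumption; in practice the model may be FSDP-wrapped.

    # Illustrative sketch only; the toy model is an assumption, not part of the patch.
    import torch

    from llmcompressor.utils.fsdp.context import summon_full_params_context

    model = torch.nn.Linear(8, 8)  # stand-in; may be an FSDP-wrapped module in real use

    # Gathers full (unsharded) parameters while FSDP is active; otherwise a no-op.
    with summon_full_params_context(model, offload_to_cpu=False):
        total = sum(p.numel() for p in model.parameters())
        print(f"visible parameters: {total}")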