Skip to content

Commit

Permalink
revert summon_full_params_context
Browse files Browse the repository at this point in the history
Signed-off-by: Kyle Sayers <[email protected]>
  • Loading branch information
kylesayrs committed Nov 21, 2024
1 parent 50e881f commit 10dc0fe
Showing 1 changed file with 17 additions and 14 deletions.
31 changes: 17 additions & 14 deletions src/llmcompressor/utils/fsdp/context.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from accelerate import Accelerator
try:
from accelerate import Accelerator
except ImportError:
Accelerator = None

try:
from torch.distributed.fsdp import FullyShardedDataParallel
from torch.distributed.fsdp._common_utils import FSDP_WRAPPED_MODULE, TrainingState
Expand All @@ -15,19 +19,18 @@


def summon_full_params_context(model, offload_to_cpu: bool = False):
    """Return a context manager that gathers the full parameters of ``model``.

    :param model: module whose parameters should be summoned; it is passed
        unchanged to ``FullyShardedDataParallel.summon_full_params``
    :param offload_to_cpu: forwarded as the ``offload_to_cpu`` keyword of
        ``FullyShardedDataParallel.summon_full_params``
    :return: the context manager produced by
        ``FullyShardedDataParallel.summon_full_params``, or a no-op
        ``nullcontext()`` when ``FullyShardedDataParallel`` is ``None`` or
        when ``model`` is already in the ``SUMMON_FULL_PARAMS`` state
    """
    # FullyShardedDataParallel is None when FSDP is unavailable
    # (presumably set by the guarded import at the top of the file)
    if FullyShardedDataParallel is None:
        return nullcontext()

    # avoid nesting summon_full_params contexts: if the model reports that it
    # is already inside one, there is nothing further to gather
    if getattr(model, "training_state", None) is TrainingState.SUMMON_FULL_PARAMS:
        return nullcontext()

    return FullyShardedDataParallel.summon_full_params(
        model, offload_to_cpu=offload_to_cpu
    )


def main_process_first_context():
Expand Down

0 comments on commit 10dc0fe

Please sign in to comment.