diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index 3bb5e81bff9..d3ef525413c 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -97,6 +97,7 @@
     wait_for_everyone,
 )
 from .utils.constants import FSDP_PYTORCH_VERSION
+from .utils.other import is_compiled_module
 
 
 if is_deepspeed_available():
@@ -1221,7 +1222,12 @@ def prepare(self, *args, device_placement=None):
             for obj in args:
                 if isinstance(obj, torch.nn.Module):
                     model_count += 1
-                    is_type_fsdp = type(obj) == FSDP
+                    # if the model is compiled using PyTorch 2.0,
+                    # check that the wrapped model is FSDP or not;
+                    # else check if it is FSDP or not;
+                    is_type_fsdp = isinstance(obj, FSDP) or (
+                        is_compiled_module(obj) and isinstance(obj._orig_mod, FSDP)
+                    )
                 if isinstance(obj, torch.optim.Optimizer):
                     optimizer_present = True
             if model_count > 1 and optimizer_present:
@@ -1377,7 +1383,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
         elif device_placement and not self.verify_device_map(model):
             model = model.to(self.device)
 
-        if self.native_amp and self.distributed_type != DistributedType.FSDP:
+        if self.native_amp:
             model._original_forward = model.forward
             model_forward_func = model.forward.__func__ if hasattr(model.forward, "__func__") else model.forward
             autocast_context = get_mixed_precision_context_manager(self.native_amp, self.autocast_handler)
@@ -1429,7 +1435,13 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
 
             # Check if the model is already a FSDP model due to `Manual Wrapping` and if so,
             # don't wrap it again
-            if type(model) != FSDP:
+            # In case the model is already compiled using PyTorch 2.0 and the wrapped model in it
+            # is a FSDP model, don't wrap it again
+            is_type_fsdp = isinstance(model, FSDP) or (
+                is_compiled_module(model) and isinstance(model._orig_mod, FSDP)
+            )
+
+            if not is_type_fsdp:
                 self.state.fsdp_plugin.set_auto_wrap_policy(model)
                 fsdp_plugin = self.state.fsdp_plugin
                 kwargs = {
@@ -1462,15 +1474,17 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
                         ),
                         auto_wrap_policy=fsdp_plugin.auto_wrap_policy,
                     )
-
+            # if the previous and current models are same, delete the previous one
+            if len(self._models) > 1 and (self._models[-2] is self._models[-1]):
+                del self._models[-2]
             self._models[-1] = model
         elif self.distributed_type == DistributedType.MULTI_CPU:
             kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler is not None else {}
             model = torch.nn.parallel.DistributedDataParallel(model, **kwargs)
         elif self.distributed_type == DistributedType.TPU and self.state.fork_launched:
             model = xmp.MpModelWrapper(model).to(self.device)
-        # torch.compile should be called last.
-        if self.state.dynamo_plugin.backend != DynamoBackend.NO:
+        # torch.compile should be called last and only if the model isn't already compiled.
+        if self.state.dynamo_plugin.backend != DynamoBackend.NO and not is_compiled_module(model):
             if not is_torch_version(">=", "2.0"):
                 raise ValueError("Using `torch.compile` requires PyTorch 2.0 or higher.")
             model = torch.compile(model, **self.state.dynamo_plugin.to_kwargs())
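
A minimal standalone sketch (not part of the patch) of the unwrap-then-check pattern the diff relies on: a module compiled with torch.compile() is wrapped in an OptimizedModule that keeps the original module under `_orig_mod`, so a plain `type(model) == FSDP` test no longer works. The names `looks_compiled` and `is_fsdp_model` are hypothetical, and `looks_compiled` only approximates the real `accelerate.utils.other.is_compiled_module` helper (which may check the OptimizedModule type directly rather than the attribute).

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def looks_compiled(module: torch.nn.Module) -> bool:
    # Approximation of `is_compiled_module`: a torch.compile()-wrapped module
    # exposes the original module as `_orig_mod`.
    return hasattr(module, "_orig_mod")


def is_fsdp_model(module: torch.nn.Module) -> bool:
    # Same idea as the patch's `is_type_fsdp`: look through the torch.compile
    # wrapper (if any) before testing for FSDP, instead of comparing types.
    unwrapped = module._orig_mod if looks_compiled(module) else module
    return isinstance(unwrapped, FSDP)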