diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 9680ed4c4606b9..77c592a1c700b9 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -514,8 +514,8 @@ class TrainingArguments:
                 If `"True"`, each individually wrapped FSDP unit will broadcast module parameters from rank 0
                 to ensure they are the same across all ranks after initialization
             - cpu_ram_efficient_loading (`bool`, *optional*, defaults to `False`)
-                If `True`, only the first process loads the pretrained model checkpoint while all other processes
-                have empty weights. When this setting as `True`, `sync_module_states` also must to be `True`,
+                If `"True"`, only the first process loads the pretrained model checkpoint while all other processes
+                have empty weights. When this setting is `"True"`, `sync_module_states` must also be `"True"`,
                 otherwise all the processes except the main process would have random weights leading to unexpected
                 behaviour during training.
             - activation_checkpointing (`bool`, *optional*, defaults to `False`):
@@ -1831,10 +1831,18 @@ def __post_init__(self):
             prefetch_policy = self.fsdp_config.get("backward_prefetch", "NO_PREFETCH")
             os.environ[f"{prefix}BACKWARD_PREFETCH"] = prefetch_policy.upper()
             os.environ[f"{prefix}FORWARD_PREFETCH"] = self.fsdp_config.get("forward_prefetch", "false")
-            os.environ[f"{prefix}SYNC_MODULE_STATES"] = self.fsdp_config.get("sync_module_states", "true")
-            os.environ[f"{prefix}CPU_RAM_EFFICIENT_LOADING"] = self.fsdp_config.get(
-                "cpu_ram_efficient_loading", "false"
-            )
+
+            sync_module_states = self.fsdp_config.get("sync_module_states", "true")
+            cpu_ram_efficient_loading = self.fsdp_config.get("cpu_ram_efficient_loading", "false")
+
+            if str(sync_module_states).lower() == "false" and str(cpu_ram_efficient_loading).lower() == "true":
+                # In this case, all the processes except the main process would have random weights,
+                # leading to unexpected behaviour during training, so we raise an error here to prevent it.
+                raise ValueError('`sync_module_states` must be `"True"` if `cpu_ram_efficient_loading` is `"True"`')
+
+            os.environ[f"{prefix}SYNC_MODULE_STATES"] = sync_module_states
+            os.environ[f"{prefix}CPU_RAM_EFFICIENT_LOADING"] = cpu_ram_efficient_loading
+
             os.environ[f"{prefix}USE_ORIG_PARAMS"] = self.fsdp_config.get("use_orig_params", "true")
 
         if is_accelerate_available():
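
For context, a minimal sketch of how the new guard surfaces to users. This reproduction is not part of the patch; it assumes `torch` and `accelerate` are installed so that `__post_init__` reaches the FSDP branch, and the `output_dir`/option values are illustrative only:

# Hypothetical repro sketch (not from this PR): the flagged combination now
# fails fast at argument construction instead of training with random weights
# on non-main ranks.
from transformers import TrainingArguments

try:
    args = TrainingArguments(
        output_dir="out",
        fsdp="full_shard auto_wrap",
        fsdp_config={
            "sync_module_states": "false",  # non-main ranks keep their (random) init weights
            "cpu_ram_efficient_loading": "true",  # only rank 0 loads the checkpoint
        },
    )
except ValueError as e:
    print(e)  # `sync_module_states` must be `"True"` if `cpu_ram_efficient_loading` is `"True"`

# The fix on the user side is to enable both flags together
# ("true" is already the default for sync_module_states):
args = TrainingArguments(
    output_dir="out",
    fsdp="full_shard auto_wrap",
    fsdp_config={
        "sync_module_states": "true",
        "cpu_ram_efficient_loading": "true",
    },
)

Note that the values are the lowercase strings "true"/"false" rather than booleans because the settings are forwarded through environment variables; the `str(...).lower()` normalization in the new check accepts either form.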