Add sync_module_states and cpu_ram_efficient_loading validation logic
helloworld1 committed Apr 19, 2024
1 parent 48b2fc4 commit e81cd98
Showing 1 changed file with 14 additions and 6 deletions.
src/transformers/training_args.py (14 additions, 6 deletions)
@@ -514,8 +514,8 @@ class TrainingArguments:
If `"True"`, each individually wrapped FSDP unit will broadcast module parameters from rank 0 to
ensure they are the same across all ranks after initialization
- cpu_ram_efficient_loading (`bool`, *optional*, defaults to `False`)
If `True`, only the first process loads the pretrained model checkpoint while all other processes
have empty weights. When this setting as `True`, `sync_module_states` also must to be `True`,
If `"True"`, only the first process loads the pretrained model checkpoint while all other processes
have empty weights. When this setting as `"True"`, `sync_module_states` also must to be `"True"`,
otherwise all the processes except the main process would have random weights leading to unexpected
behaviour during training.
- activation_checkpointing (`bool`, *optional*, defaults to `False`):
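
To make the documented constraint concrete, here is a minimal sketch of how it surfaces when building training arguments (illustrative values only; assumes accelerate is installed and uses string-typed flags, matching this diff's defaults):

from transformers import TrainingArguments

# Invalid combination: rank 0 would load the checkpoint while every other
# rank keeps random weights that are never synchronized, so construction fails.
try:
    TrainingArguments(
        output_dir="out",  # hypothetical path
        fsdp="full_shard auto_wrap",
        fsdp_config={
            "sync_module_states": "false",
            "cpu_ram_efficient_loading": "true",
        },
    )
except ValueError as err:
    print(err)

# Valid combination: rank 0 loads the checkpoint, then broadcasts the
# weights to all other ranks.
TrainingArguments(
    output_dir="out",
    fsdp="full_shard auto_wrap",
    fsdp_config={
        "sync_module_states": "true",
        "cpu_ram_efficient_loading": "true",
    },
)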
@@ -1831,10 +1831,18 @@ def __post_init__(self):
             prefetch_policy = self.fsdp_config.get("backward_prefetch", "NO_PREFETCH")
             os.environ[f"{prefix}BACKWARD_PREFETCH"] = prefetch_policy.upper()
             os.environ[f"{prefix}FORWARD_PREFETCH"] = self.fsdp_config.get("forward_prefetch", "false")
-            os.environ[f"{prefix}SYNC_MODULE_STATES"] = self.fsdp_config.get("sync_module_states", "true")
-            os.environ[f"{prefix}CPU_RAM_EFFICIENT_LOADING"] = self.fsdp_config.get(
-                "cpu_ram_efficient_loading", "false"
-            )
+
+            sync_module_states = self.fsdp_config.get("sync_module_states", "true")
+            cpu_ram_efficient_loading = self.fsdp_config.get("cpu_ram_efficient_loading", "false")
+
+            if str(sync_module_states).lower() == "false" and str(cpu_ram_efficient_loading).lower() == "true":
+                # In this case, all the processes except the main process would have random weights,
+                # leading to unexpected behaviour during training, so raise an error here to prevent it.
+                raise ValueError('`sync_module_states` must be `"True"` if `cpu_ram_efficient_loading` is `"True"`')
+
+            os.environ[f"{prefix}SYNC_MODULE_STATES"] = sync_module_states
+            os.environ[f"{prefix}CPU_RAM_EFFICIENT_LOADING"] = cpu_ram_efficient_loading
+
             os.environ[f"{prefix}USE_ORIG_PARAMS"] = self.fsdp_config.get("use_orig_params", "true")
 
             if is_accelerate_available():
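Downstream, these environment variables are consumed by the FSDP plugin on the accelerate side. The value of `prefix` is not visible in this excerpt; the sketch below assumes the conventional `FSDP_` prefix and uses a hypothetical helper to show how the string-typed values round-trip back into booleans:

import os

def read_fsdp_flags():
    # Hypothetical consumer mirroring the producer above: the env vars are
    # strings, so they are parsed back into booleans with the same defaults.
    sync_module_states = os.environ.get("FSDP_SYNC_MODULE_STATES", "true").lower() == "true"
    cpu_ram_efficient_loading = os.environ.get("FSDP_CPU_RAM_EFFICIENT_LOADING", "false").lower() == "true"
    return sync_module_states, cpu_ram_efficient_loading

Thanks to the validation added in this commit, any consumer like this one can rely on the invariant that `cpu_ram_efficient_loading` implies `sync_module_states`.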
