Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add FSDP config for CPU RAM efficient loading through accelerate #30002

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion src/transformers/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,11 @@ class TrainingArguments:
- sync_module_states (`bool`, *optional*, defaults to `True`)
If `"True"`, each individually wrapped FSDP unit will broadcast module parameters from rank 0 to
ensure they are the same across all ranks after initialization
- cpu_ram_efficient_loading (`bool`, *optional*, defaults to `False`)
If `"True"`, only the first process loads the pretrained model checkpoint while all other processes
have empty weights. When this setting as `"True"`, `sync_module_states` also must to be `"True"`,
otherwise all the processes except the main process would have random weights leading to unexpected
behaviour during training.
- activation_checkpointing (`bool`, *optional*, defaults to `False`):
If `"True"`, activation checkpointing is a technique to reduce memory usage by clearing activations of
certain layers and recomputing them during a backward pass. Effectively, this trades extra
Expand Down Expand Up @@ -1826,7 +1831,18 @@ def __post_init__(self):
prefetch_policy = self.fsdp_config.get("backward_prefetch", "NO_PREFETCH")
os.environ[f"{prefix}BACKWARD_PREFETCH"] = prefetch_policy.upper()
os.environ[f"{prefix}FORWARD_PREFETCH"] = self.fsdp_config.get("forward_prefetch", "false")
os.environ[f"{prefix}SYNC_MODULE_STATES"] = self.fsdp_config.get("sync_module_states", "true")

sync_module_states = self.fsdp_config.get("sync_module_states", "true")
cpu_ram_efficient_loading = self.fsdp_config.get("cpu_ram_efficient_loading", "false")

if str(sync_module_states).lower() == "false" and str(cpu_ram_efficient_loading).lower() == "true":
# In this case, all the processes except the main process would have random weights leading
# to unexpected behaviour during training, thus throwing error here to prevent it.
raise ValueError('`sync_module_states` must be `"True"` if `cpu_ram_efficient_loading` is `"True"`')

os.environ[f"{prefix}SYNC_MODULE_STATES"] = sync_module_states
os.environ[f"{prefix}CPU_RAM_EFFICIENT_LOADING"] = cpu_ram_efficient_loading

os.environ[f"{prefix}USE_ORIG_PARAMS"] = self.fsdp_config.get("use_orig_params", "true")

if is_accelerate_available():
Expand Down
4 changes: 4 additions & 0 deletions tests/fsdp/test_fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ def setUp(self):
"limit_all_gathers": "False",
"use_orig_params": "True",
"sync_module_states": "True",
"cpu_ram_efficient_loading": "True",
"activation_checkpointing": "False",
"min_num_params": 1,
}
Expand Down Expand Up @@ -208,6 +209,9 @@ def test_fsdp_config_transformers_auto_wrap(self, sharding_strategy, dtype):
self.assertEqual(os.environ[f"{prefix}FORWARD_PREFETCH"], fsdp_config["forward_prefetch"])
self.assertEqual(os.environ[f"{prefix}USE_ORIG_PARAMS"], fsdp_config["use_orig_params"])
self.assertEqual(os.environ[f"{prefix}SYNC_MODULE_STATES"], fsdp_config["sync_module_states"])
self.assertEqual(
os.environ[f"{prefix}CPU_RAM_EFFICIENT_LOADING"], fsdp_config["cpu_ram_efficient_loading"]
)
self.assertEqual(os.environ.get("ACCELERATE_USE_FSDP", "false"), "true")

@parameterized.expand(params, name_func=_parameterized_custom_name_func)
Expand Down
Loading