Skip to content

Commit

Permalink
Style fix
Browse files Browse the repository at this point in the history
  • Loading branch information
helloworld1 committed Apr 15, 2024
1 parent 8b5c9a2 commit cf40b24
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
4 changes: 3 additions & 1 deletion src/transformers/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -1773,7 +1773,9 @@ def __post_init__(self):
os.environ[f"{prefix}BACKWARD_PREFETCH"] = prefetch_policy.upper()
os.environ[f"{prefix}FORWARD_PREFETCH"] = self.fsdp_config.get("forward_prefetch", "false")
os.environ[f"{prefix}SYNC_MODULE_STATES"] = self.fsdp_config.get("sync_module_states", "true")
os.environ[f"{prefix}CPU_RAM_EFFICIENT_LOADING"] = self.fsdp_config.get("cpu_ram_efficient_loading", "false")
os.environ[f"{prefix}CPU_RAM_EFFICIENT_LOADING"] = self.fsdp_config.get(
"cpu_ram_efficient_loading", "false"
)
os.environ[f"{prefix}USE_ORIG_PARAMS"] = self.fsdp_config.get("use_orig_params", "true")

if is_accelerate_available():
Expand Down
4 changes: 3 additions & 1 deletion tests/fsdp/test_fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,9 @@ def test_fsdp_config_transformers_auto_wrap(self, sharding_strategy, dtype):
self.assertEqual(os.environ[f"{prefix}FORWARD_PREFETCH"], fsdp_config["forward_prefetch"])
self.assertEqual(os.environ[f"{prefix}USE_ORIG_PARAMS"], fsdp_config["use_orig_params"])
self.assertEqual(os.environ[f"{prefix}SYNC_MODULE_STATES"], fsdp_config["sync_module_states"])
self.assertEqual(os.environ[f"{prefix}CPU_RAM_EFFICIENT_LOADING"], fsdp_config["cpu_ram_efficient_loading"])
self.assertEqual(
os.environ[f"{prefix}CPU_RAM_EFFICIENT_LOADING"], fsdp_config["cpu_ram_efficient_loading"]
)
self.assertEqual(os.environ.get("ACCELERATE_USE_FSDP", "false"), "true")

@parameterized.expand(params, name_func=_parameterized_custom_name_func)
Expand Down

0 comments on commit cf40b24

Please sign in to comment.