diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 4dd69ca72ef908..8ef4f3c54a0a60 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -1773,7 +1773,9 @@ def __post_init__(self): os.environ[f"{prefix}BACKWARD_PREFETCH"] = prefetch_policy.upper() os.environ[f"{prefix}FORWARD_PREFETCH"] = self.fsdp_config.get("forward_prefetch", "false") os.environ[f"{prefix}SYNC_MODULE_STATES"] = self.fsdp_config.get("sync_module_states", "true") - os.environ[f"{prefix}CPU_RAM_EFFICIENT_LOADING"] = self.fsdp_config.get("cpu_ram_efficient_loading", "false") + os.environ[f"{prefix}CPU_RAM_EFFICIENT_LOADING"] = self.fsdp_config.get( + "cpu_ram_efficient_loading", "false" + ) os.environ[f"{prefix}USE_ORIG_PARAMS"] = self.fsdp_config.get("use_orig_params", "true") if is_accelerate_available(): diff --git a/tests/fsdp/test_fsdp.py b/tests/fsdp/test_fsdp.py index d9dea012f13bf8..205d095e0357c9 100644 --- a/tests/fsdp/test_fsdp.py +++ b/tests/fsdp/test_fsdp.py @@ -209,7 +209,9 @@ def test_fsdp_config_transformers_auto_wrap(self, sharding_strategy, dtype): self.assertEqual(os.environ[f"{prefix}FORWARD_PREFETCH"], fsdp_config["forward_prefetch"]) self.assertEqual(os.environ[f"{prefix}USE_ORIG_PARAMS"], fsdp_config["use_orig_params"]) self.assertEqual(os.environ[f"{prefix}SYNC_MODULE_STATES"], fsdp_config["sync_module_states"]) - self.assertEqual(os.environ[f"{prefix}CPU_RAM_EFFICIENT_LOADING"], fsdp_config["cpu_ram_efficient_loading"]) + self.assertEqual( + os.environ[f"{prefix}CPU_RAM_EFFICIENT_LOADING"], fsdp_config["cpu_ram_efficient_loading"] + ) self.assertEqual(os.environ.get("ACCELERATE_USE_FSDP", "false"), "true") @parameterized.expand(params, name_func=_parameterized_custom_name_func)