Fix is_fsdp_enabled() logic and make low_cpu_mem_usage by default
helloworld1 committed Apr 1, 2024
1 parent c9f6e5e commit ef76f46
Showing 1 changed file with 4 additions and 2 deletions.
src/transformers/modeling_utils.py: 4 additions, 2 deletions
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -134,7 +134,6 @@ def is_fsdp_enabled():
         torch.distributed.is_available()
         and torch.distributed.is_initialized()
         and strtobool(os.environ.get("ACCELERATE_USE_FSDP", "False")) == 1
-        and strtobool(os.environ.get("FSDP_CPU_RAM_EFFICIENT_LOADING", "False")) == 1
     )


@@ -2869,7 +2868,10 @@ def from_pretrained(
         adapter_name = kwargs.pop("adapter_name", "default")
         use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)
 
-        if is_fsdp_enabled():
+        # FSDP_CPU_RAM_EFFICIENT_LOADING is set by Accelerate launcher not by TrainingArguments.
+        # Unless it is turned off explicitly low_cpu_mem_usage should be turned on by default
+        # when FSDP is on since sync_module_states is also default on.
+        if is_fsdp_enabled() and strtobool(os.environ.get("FSDP_CPU_RAM_EFFICIENT_LOADING", "True")) == 1:
             low_cpu_mem_usage = True
 
         if use_auth_token is not None:
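To illustrate the behavioral change, here is a minimal, self-contained sketch of how the new default resolves. It is not library code: resolve_low_cpu_mem_usage and _strtobool are hypothetical stand-ins for the from_pretrained logic above and for the strtobool helper used in modeling_utils.py.

import os


def _strtobool(value: str) -> int:
    # Stand-in for the strtobool helper used in modeling_utils.py:
    # 1 for truthy strings, 0 for everything else.
    return 1 if value.strip().lower() in ("y", "yes", "t", "true", "on", "1") else 0


def resolve_low_cpu_mem_usage(fsdp_enabled: bool) -> bool:
    # After this commit, low_cpu_mem_usage is forced on whenever FSDP is active,
    # unless FSDP_CPU_RAM_EFFICIENT_LOADING is explicitly set to a falsy value.
    # Before the commit, is_fsdp_enabled() itself also required the variable to be
    # truthy, so leaving it unset meant the flag stayed off.
    return fsdp_enabled and _strtobool(os.environ.get("FSDP_CPU_RAM_EFFICIENT_LOADING", "True")) == 1


# With FSDP on and the variable unset, the new default is True.
os.environ.pop("FSDP_CPU_RAM_EFFICIENT_LOADING", None)
assert resolve_low_cpu_mem_usage(fsdp_enabled=True) is True

# Opting out still works by exporting FSDP_CPU_RAM_EFFICIENT_LOADING=false before launch.
os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = "false"
assert resolve_low_cpu_mem_usage(fsdp_enabled=True) is False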
