diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 12ae77908ebfae..18bf004efa05b0 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -1840,12 +1840,12 @@ def __post_init__(self):
                 )
             prefetch_policy = self.fsdp_config.get("backward_prefetch", "NO_PREFETCH")
             os.environ[f"{prefix}BACKWARD_PREFETCH"] = prefetch_policy.upper()
-            os.environ[f"{prefix}FORWARD_PREFETCH"] = self.fsdp_config.get("forward_prefetch", "false")
+            os.environ[f"{prefix}FORWARD_PREFETCH"] = str(self.fsdp_config.get("forward_prefetch", "false")).lower()

-            sync_module_states = self.fsdp_config.get("sync_module_states", "true")
-            cpu_ram_efficient_loading = self.fsdp_config.get("cpu_ram_efficient_loading", "false")
+            sync_module_states = str(self.fsdp_config.get("sync_module_states", "true")).lower()
+            cpu_ram_efficient_loading = str(self.fsdp_config.get("cpu_ram_efficient_loading", "false")).lower()

-            if str(sync_module_states).lower() == "false" and str(cpu_ram_efficient_loading).lower() == "true":
+            if sync_module_states == "false" and cpu_ram_efficient_loading == "true":
                 # In this case, all the processes except the main process would have random weights leading
                 # to unexpected behaviour during training, thus throwing error here to prevent it.
                 raise ValueError('`sync_module_states` must be `"True"` if `cpu_ram_efficient_loading` is `"True"`')
@@ -1853,7 +1853,7 @@ def __post_init__(self):
             os.environ[f"{prefix}SYNC_MODULE_STATES"] = sync_module_states
             os.environ[f"{prefix}CPU_RAM_EFFICIENT_LOADING"] = cpu_ram_efficient_loading

-            os.environ[f"{prefix}USE_ORIG_PARAMS"] = self.fsdp_config.get("use_orig_params", "true")
+            os.environ[f"{prefix}USE_ORIG_PARAMS"] = str(self.fsdp_config.get("use_orig_params", "true")).lower()

         if is_accelerate_available():
             if not isinstance(self.accelerator_config, (AcceleratorConfig)):
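
A minimal sketch (not part of the patch) of the behavior this change targets: `os.environ` only accepts string values, so passing a Python boolean such as `forward_prefetch: True` in `fsdp_config` would previously raise a `TypeError`, and `str(True)` would otherwise produce `"True"` rather than the lowercase `"true"` the comparisons above expect. Normalizing each value with `str(...).lower()` up front also lets the `sync_module_states` / `cpu_ram_efficient_loading` check compare the strings directly. The helper `_to_env_bool`, the example config, and the hard-coded `FSDP_` prefix below are illustrative assumptions, not code from `training_args.py`.

```python
import os

def _to_env_bool(value) -> str:
    # Mirrors the patch: accept a bool or a str and normalize to "true"/"false",
    # since os.environ values must be strings and the later checks compare
    # against lowercase literals.
    return str(value).lower()

# Hypothetical user-supplied config mixing booleans and strings.
fsdp_config = {"forward_prefetch": True, "use_orig_params": "False"}

prefix = "FSDP_"  # assumed to match the prefix used in __post_init__
os.environ[f"{prefix}FORWARD_PREFETCH"] = _to_env_bool(fsdp_config.get("forward_prefetch", "false"))
os.environ[f"{prefix}USE_ORIG_PARAMS"] = _to_env_bool(fsdp_config.get("use_orig_params", "true"))

assert os.environ[f"{prefix}FORWARD_PREFETCH"] == "true"
assert os.environ[f"{prefix}USE_ORIG_PARAMS"] == "false"

# Without the normalization, the first assignment would fail:
# os.environ[f"{prefix}FORWARD_PREFETCH"] = True  -> TypeError: str expected, not bool
```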