Fin refactor
muellerzr committed Apr 25, 2024
1 parent 3aa1053 commit fdb147d
Showing 1 changed file with 14 additions and 11 deletions.
src/transformers/training_args.py (25 changes: 14 additions & 11 deletions)
@@ -2036,16 +2036,17 @@ def _setup_devices(self) -> "torch.device":
                     "Please run `pip install transformers[torch]` or `pip install accelerate -U`"
                 )
         # We delay the init of `PartialState` to the end for clarity
-        accelerator_state_kwargs = {"enabled": True}
-        use_configured_accelerator_state = False
+        accelerator_state_kwargs = {"enabled": True, "use_configured_state": False}
         if isinstance(self.accelerator_config, AcceleratorConfig):
-            use_configured_accelerator_state = self.accelerator_config.pop("use_configured_state", False)
-        if use_configured_accelerator_state and PartialState._shared_state == {}:
-            raise ValueError(
-                "Passing `'use_configured_state':True` to the AcceleratorConfig requires a pre-configured "
-                "`AcceleratorState` or `PartialState` to be defined before calling `TrainingArguments`. "
+            accelerator_state_kwargs["use_configured_state"] = self.accelerator_config.pop(
+                "use_configured_state", False
             )
-        if use_configured_accelerator_state:
+        if accelerator_state_kwargs["use_configured_state"]:
+            if PartialState._shared_state == {}:
+                raise ValueError(
+                    "Passing `'use_configured_state':True` to the AcceleratorConfig requires a pre-configured "
+                    "`AcceleratorState` or `PartialState` to be defined before calling `TrainingArguments`. "
+                )
             self.distributed_state = PartialState()
         else:
             AcceleratorState._reset_state(reset_partial_state=True)
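
The hunk above also spells out the contract behind `use_configured_state`: an Accelerate `AcceleratorState` or `PartialState` must already exist when `TrainingArguments` is built, otherwise the ValueError fires. Below is a minimal usage sketch of that contract, assuming `AcceleratorConfig` is importable from `transformers.trainer_pt_utils` and exposes a `use_configured_state` field; every other argument is illustrative rather than taken from this commit.

# Hypothetical sketch (not part of the commit): reuse an already-configured state.
from accelerate import PartialState
from transformers import TrainingArguments
from transformers.trainer_pt_utils import AcceleratorConfig

# 1) Configure the distributed state first; this populates PartialState._shared_state.
state = PartialState()

# 2) Opt in through the accelerator config so _setup_devices reuses that state
#    instead of resetting and re-creating it.
args = TrainingArguments(
    output_dir="out",  # illustrative
    accelerator_config=AcceleratorConfig(use_configured_state=True),
)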
@@ -2072,10 +2073,12 @@ def _setup_devices(self) -> "torch.device":
             accelerator_state_kwargs["backend"] = self.ddp_backend
             accelerator_state_kwargs["timeout"] = timedelta(seconds=self.ddp_timeout)
 
-        accelerator_state_enabled = accelerator_state_kwargs.pop("enabled", False)
-        use_deepspeed = accelerator_state_kwargs.pop("use_deepspeed", False)
-        if accelerator_state_enabled:
+        # Now we pop everything
+        if accelerator_state_kwargs.pop("enabled", False) and not accelerator_state_kwargs.pop(
+            "use_configured_state", False
+        ):
             # We need to patch this env var when enabling to detect deepspeed
+            use_deepspeed = accelerator_state_kwargs.pop("use_deepspeed", False)
             if use_deepspeed:
                 os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
             self.distributed_state = PartialState(**accelerator_state_kwargs)
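
The second hunk settles on a pop-then-construct pattern: the bookkeeping flags (`enabled`, `use_configured_state`, `use_deepspeed`) are popped off `accelerator_state_kwargs`, the DeepSpeed env var is patched when needed, and only genuine constructor arguments remain by the time `PartialState(**accelerator_state_kwargs)` runs. A standalone sketch of that pattern, with illustrative dict contents rather than values lifted from transformers:

# Illustrative sketch of the pop-then-construct gating used in the hunk above.
import os
from datetime import timedelta

accelerator_state_kwargs = {
    "enabled": True,
    "use_configured_state": False,
    "use_deepspeed": False,
    "backend": "nccl",                   # illustrative constructor kwargs
    "timeout": timedelta(seconds=1800),
}

# Pop the control flags first so they never reach the constructor.
if accelerator_state_kwargs.pop("enabled", False) and not accelerator_state_kwargs.pop(
    "use_configured_state", False
):
    # Patch the env var before building the state so DeepSpeed is detected.
    if accelerator_state_kwargs.pop("use_deepspeed", False):
        os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
    # Only constructor arguments (backend, timeout, ...) are left at this point.
    print("PartialState would receive:", accelerator_state_kwargs)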
