Commit: Update wrt feedback
muellerzr committed May 6, 2024
1 parent 257e47a commit b5fd489
Showing 2 changed files with 12 additions and 1 deletion.
4 changes: 4 additions & 0 deletions src/transformers/trainer_pt_utils.py
@@ -1250,6 +1250,10 @@ class AcceleratorConfig:
             Whether to use non-blocking CUDA calls to help minimize synchronization during
             distributed training with prepared `DataLoader` inputs being moved to device.
             Best if used with `pin_memory=True` in the `TrainingArguments`.
+        use_configured_state (`bool`, *optional*, defaults to `False`):
+            Whether or not to use a pre-configured `AcceleratorState` or `PartialState` defined
+            before calling `TrainingArguments`. If `True`, an `Accelerator` or `PartialState`
+            must be initialized. May lead to issues when using sweeps or hyperparameter tuning.
     """

9 changes: 8 additions & 1 deletion src/transformers/training_args.py
@@ -2067,7 +2067,14 @@ def _setup_devices(self) -> "torch.device":
"Passing `'use_configured_state':True` to the AcceleratorConfig requires a pre-configured "
"`AcceleratorState` or `PartialState` to be defined before calling `TrainingArguments`. "
)
self.distributed_state = PartialState()
# We rely on `PartialState` to yell if there's issues here (which it will)
self.distributed_state = PartialState(cpu=self.use_cpu)
if self.deepspeed and self.distributed_state.distributed_type != DistributedType.DEEPSPEED:
raise RuntimeError(
"Tried to use an already configured `Accelerator` or `PartialState` that was not initialized for DeepSpeed, "
"but also passed in a `deepspeed` configuration to the `TrainingArguments`. Please set "
"`use_configured_state:False` instead or setup your `Accelerator` or `PartialState` properly."
)
else:
AcceleratorState._reset_state(reset_partial_state=True)
self.distributed_state = None
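To illustrate the new guard, a hedged sketch of the failing combination (the `output_dir` and `ds_config.json` values are hypothetical): a state that was not configured for DeepSpeed, combined with a `deepspeed` argument, should now raise the `RuntimeError` added above during `TrainingArguments` construction rather than failing later in training.

```python
from accelerate import PartialState
from transformers import TrainingArguments

# A plain state, not initialized for DeepSpeed.
PartialState()

try:
    args = TrainingArguments(
        output_dir="out",            # hypothetical
        deepspeed="ds_config.json",  # hypothetical DeepSpeed config path
        accelerator_config={"use_configured_state": True},
    )
except RuntimeError as err:
    # The new check fires here, since the configured state's
    # distributed_type is not DistributedType.DEEPSPEED.
    print(err)
```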
