Introduce configured_state arg for accelerator_config (#29781)
* Introduce configured_state

* Include note on tuning

* Allow for users to have defined a state already

* Include tests

* Add note on hpam tune

* Guard a bit better

* Update src/transformers/training_args.py

Co-authored-by: amyeroberts <[email protected]>

* Update src/transformers/training_args.py

Co-authored-by: amyeroberts <[email protected]>

* Finish rebase

* Finish rebase

* Guard carefully

* Fixup test

* Refactor

* Fin refactor

* Comment

* Update wrt feedback

---------

Co-authored-by: amyeroberts <[email protected]>
2 people authored and Ita Zaporozhets committed May 24, 2024
1 parent 028aa58 commit 88cc715
Showing 3 changed files with 110 additions and 52 deletions.
14 changes: 14 additions & 0 deletions src/transformers/trainer_pt_utils.py
@@ -1250,6 +1250,10 @@ class AcceleratorConfig:
Whether to use non-blocking CUDA calls to help minimize synchronization during
distributed training with prepared `DataLoader` inputs being moved to device.
Best if used with `pin_memory=True` in the `TrainingArguments`.
use_configured_state (`bool`, *optional*, defaults to `False`):
Whether or not to use a pre-configured `AcceleratorState` or `PartialState` defined
before calling `TrainingArguments`. If `True`, an `Accelerator` or `PartialState`
must be initialized. May lead to issues when using sweeps or hyperparameter tuning.
"""

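For orientation, a minimal usage sketch of the new flag, assuming `accelerate` is installed (the `output_dir` value is illustrative):

from accelerate import PartialState
from transformers import TrainingArguments

# The state must exist *before* TrainingArguments is constructed; otherwise
# constructing TrainingArguments with use_configured_state=True raises a ValueError.
state = PartialState()

args = TrainingArguments(
    output_dir="out",
    accelerator_config={"use_configured_state": True},
)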
@@ -1312,6 +1316,13 @@ class AcceleratorConfig:
" The [`accelerate.utils.GradientAccumulationPlugin`] default is `False`."
},
)
use_configured_state: bool = field(
default=False,
metadata={
    "help": "Whether or not to use a pre-configured `AcceleratorState` or `PartialState` defined before calling `TrainingArguments`. "
    "If `True`, an `Accelerator` or `PartialState` must be initialized. May lead to issues when using sweeps or hyperparameter tuning."
},
)

@classmethod
def from_json_file(cls, json_file):
@@ -1331,6 +1342,9 @@ def from_json_file(cls, json_file):
def to_dict(self):
return copy.deepcopy(self.__dict__)

def pop(self, key, default=None):
return self.__dict__.pop(key, default)
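A quick illustration of the new helper, which simply mirrors `dict.pop` on the dataclass `__dict__` (a sketch; values shown are the field defaults):

from transformers.trainer_pt_utils import AcceleratorConfig

cfg = AcceleratorConfig()
cfg.pop("use_configured_state", False)  # returns False (the field default) and removes the key from cfg.__dict__
cfg.pop("use_configured_state", False)  # the key is gone, so the supplied default (False) is returned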


class LayerWiseDummyOptimizer(torch.optim.Optimizer):
"""
134 changes: 82 additions & 52 deletions src/transformers/training_args.py
@@ -572,6 +572,10 @@ class TrainingArguments:
training results are fully reproducible using a different sampling technique. While seed-to-seed results
may differ, on average the differences are negligible when using multiple different seeds to compare. Should
also be run with [`~utils.set_seed`] for the best results.
- use_configured_state (`bool`, *optional*, defaults to `False`):
Whether or not to use a pre-configured `AcceleratorState` or `PartialState` defined before calling `TrainingArguments`.
If `True`, an `Accelerator` or `PartialState` must be initialized. Note that doing so may lead to issues
with hyperparameter tuning.
label_smoothing_factor (`float`, *optional*, defaults to 0.0):
The label smoothing factor to use. Zero means no label smoothing, otherwise the underlying onehot-encoded
@@ -1635,6 +1639,39 @@ def __post_init__(self):
if version.parse(version.parse(torch.__version__).base_version) == version.parse("2.0.0") and self.fp16:
raise ValueError("--optim adamw_torch_fused with --fp16 requires PyTorch>2.0")

# We need to setup the accelerator config here *before* the first call to `self.device`
if is_accelerate_available():
if not isinstance(self.accelerator_config, (AcceleratorConfig)):
if self.accelerator_config is None:
self.accelerator_config = AcceleratorConfig()
elif isinstance(self.accelerator_config, dict):
self.accelerator_config = AcceleratorConfig(**self.accelerator_config)
# Check that a user didn't pass in the class instantiator
# such as `accelerator_config = AcceleratorConfig`
elif isinstance(self.accelerator_config, type):
raise NotImplementedError(
"Tried passing in a callable to `accelerator_config`, but this is not supported. "
"Please pass in a fully constructed `AcceleratorConfig` object instead."
)
else:
self.accelerator_config = AcceleratorConfig.from_json_file(self.accelerator_config)
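For illustration, the shapes of `accelerator_config` that this normalization accepts (a hedged sketch; the JSON file name is hypothetical):

from transformers import TrainingArguments
from transformers.trainer_pt_utils import AcceleratorConfig

TrainingArguments(output_dir="out")                                               # None -> default AcceleratorConfig()
TrainingArguments(output_dir="out", accelerator_config={"split_batches": True})   # dict -> AcceleratorConfig(**dict)
TrainingArguments(output_dir="out", accelerator_config=AcceleratorConfig(non_blocking=True))
# A JSON file path is also accepted: accelerator_config="accelerator_config.json" -> AcceleratorConfig.from_json_file(...)
# Passing the class itself (accelerator_config=AcceleratorConfig) raises NotImplementedError.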

if self.dispatch_batches is not None:
warnings.warn(
    "Using `--dispatch_batches` is deprecated and will be removed in version 4.41 of 🤗 Transformers. Use"
    " `--accelerator_config {'dispatch_batches':VALUE}` instead",
FutureWarning,
)
self.accelerator_config.dispatch_batches = self.dispatch_batches

if self.split_batches is not None:
warnings.warn(
    "Using `--split_batches` is deprecated and will be removed in version 4.41 of 🤗 Transformers. Use"
    " `--accelerator_config {'split_batches':VALUE}` instead",
FutureWarning,
)
self.accelerator_config.split_batches = self.split_batches
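As a hedged sketch, the deprecated flags map onto the new `accelerator_config` entries roughly like this (values are illustrative):

from transformers import TrainingArguments

# Deprecated style: emits a FutureWarning and is forwarded to accelerator_config.
args_old = TrainingArguments(output_dir="out", dispatch_batches=False, split_batches=True)
# Equivalent explicit style:
args_new = TrainingArguments(output_dir="out", accelerator_config={"dispatch_batches": False, "split_batches": True})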

if (
self.framework == "pt"
and is_torch_available()
@@ -1873,37 +1910,6 @@ def __post_init__(self):

os.environ[f"{prefix}USE_ORIG_PARAMS"] = str(self.fsdp_config.get("use_orig_params", "true")).lower()

if is_accelerate_available():
if not isinstance(self.accelerator_config, (AcceleratorConfig)):
if self.accelerator_config is None:
self.accelerator_config = AcceleratorConfig()
elif isinstance(self.accelerator_config, dict):
self.accelerator_config = AcceleratorConfig(**self.accelerator_config)
# Check that a user didn't pass in the class instantiator
# such as `accelerator_config = AcceleratorConfig`
elif isinstance(self.accelerator_config, type):
raise NotImplementedError(
"Tried passing in a callable to `accelerator_config`, but this is not supported. "
"Please pass in a fully constructed `AcceleratorConfig` object instead."
)
else:
self.accelerator_config = AcceleratorConfig.from_json_file(self.accelerator_config)
if self.dispatch_batches is not None:
warnings.warn(
"Using `--dispatch_batches` is deprecated and will be removed in version 4.41 of 🤗 Transformers. Use"
" `--accelerator_config {'dispatch_batches':VALUE} instead",
FutureWarning,
)
self.accelerator_config.dispatch_batches = self.dispatch_batches

if self.split_batches is not None:
warnings.warn(
"Using `--split_batches` is deprecated and will be removed in version 4.41 of 🤗 Transformers. Use"
" `--accelerator_config {'split_batches':VALUE} instead",
FutureWarning,
)
self.accelerator_config.split_batches = self.split_batches

if self.tpu_metrics_debug:
warnings.warn(
"using `--tpu_metrics_debug` is deprecated and will be removed in version 5 of 🤗 Transformers. Use"
@@ -2056,32 +2062,62 @@ def _setup_devices(self) -> "torch.device":
f"Using the `Trainer` with `PyTorch` requires `accelerate>={ACCELERATE_MIN_VERSION}`: "
"Please run `pip install transformers[torch]` or `pip install accelerate -U`"
)
# We delay the init of `PartialState` to the end for clarity
accelerator_state_kwargs = {"enabled": True, "use_configured_state": False}
if isinstance(self.accelerator_config, AcceleratorConfig):
accelerator_state_kwargs["use_configured_state"] = self.accelerator_config.pop(
"use_configured_state", False
)
if accelerator_state_kwargs["use_configured_state"]:
if PartialState._shared_state == {}:
raise ValueError(
"Passing `'use_configured_state':True` to the AcceleratorConfig requires a pre-configured "
"`AcceleratorState` or `PartialState` to be defined before calling `TrainingArguments`. "
)
# We rely on `PartialState` to yell if there's issues here (which it will)
self.distributed_state = PartialState(cpu=self.use_cpu)
if self.deepspeed and self.distributed_state.distributed_type != DistributedType.DEEPSPEED:
raise RuntimeError(
"Tried to use an already configured `Accelerator` or `PartialState` that was not initialized for DeepSpeed, "
"but also passed in a `deepspeed` configuration to the `TrainingArguments`. Please set "
"`use_configured_state:False` instead or setup your `Accelerator` or `PartialState` properly."
)
else:
AcceleratorState._reset_state(reset_partial_state=True)
self.distributed_state = None
self.distributed_state = None
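A small sketch of the precondition this guard enforces, mirroring the new test further down (the `output_dir` value is illustrative):

from accelerate.state import AcceleratorState
from transformers import TrainingArguments

AcceleratorState._reset_state(reset_partial_state=True)  # make sure no state has been configured
try:
    TrainingArguments(output_dir="out", accelerator_config={"use_configured_state": True})
except ValueError as err:
    print(err)  # asks for a pre-configured `AcceleratorState` or `PartialState`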
if not self.use_ipex and "ACCELERATE_USE_IPEX" not in os.environ:
os.environ["ACCELERATE_USE_IPEX"] = "false"

self._n_gpu = 1
if self.use_cpu or strtobool(os.environ.get("ACCELERATE_USE_CPU", "False")):
self.distributed_state = PartialState(cpu=True, backend=self.ddp_backend)
accelerator_state_kwargs["cpu"] = True
accelerator_state_kwargs["backend"] = self.ddp_backend
self._n_gpu = 0
elif is_sagemaker_mp_enabled():
accelerator_state_kwargs["enabled"] = False
local_rank = smp.local_rank()
device = torch.device("cuda", local_rank)
self._n_gpu = 1
torch.cuda.set_device(device)
elif is_sagemaker_dp_enabled():
self.distributed_state = PartialState(_use_sagemaker_dp=True)
self._n_gpu = 1
accelerator_state_kwargs["_use_sagemaker_dp"] = True
elif self.deepspeed:
# Need to do similar for Accelerator init
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
self.distributed_state = PartialState(timeout=timedelta(seconds=self.ddp_timeout))
del os.environ["ACCELERATE_USE_DEEPSPEED"]
self._n_gpu = 1
accelerator_state_kwargs["use_deepspeed"] = True
accelerator_state_kwargs["timeout"] = timedelta(seconds=self.ddp_timeout)
else:
self.distributed_state = PartialState(
backend=self.ddp_backend, timeout=timedelta(seconds=self.ddp_timeout)
)
self._n_gpu = 1
accelerator_state_kwargs["backend"] = self.ddp_backend
accelerator_state_kwargs["timeout"] = timedelta(seconds=self.ddp_timeout)

# Now we pop everything
if accelerator_state_kwargs.pop("enabled", False) and not accelerator_state_kwargs.pop(
"use_configured_state", False
):
# We need to patch this env var when enabling to detect deepspeed
use_deepspeed = accelerator_state_kwargs.pop("use_deepspeed", False)
if use_deepspeed:
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
self.distributed_state = PartialState(**accelerator_state_kwargs)
if use_deepspeed:
del os.environ["ACCELERATE_USE_DEEPSPEED"]
if not is_sagemaker_mp_enabled():
device = self.distributed_state.device
self.local_rank = self.distributed_state.local_process_index
@@ -2108,23 +2144,17 @@ def _setup_devices(self) -> "torch.device":
"Either you do not have an MPS-enabled device on this machine or MacOS version is not 12.3+ "
"or current PyTorch install was not built with MPS enabled."
)
if device.type == "mps":
self._n_gpu = 1
elif self.use_cpu:
if self.use_cpu:
device = torch.device("cpu")
self._n_gpu = 0
elif is_torch_xpu_available():
device = torch.device("xpu:0")
torch.xpu.set_device(device)
self._n_gpu = 1
elif is_torch_mlu_available():
device = torch.device("mlu:0")
torch.mlu.set_device(device)
self._n_gpu = 1
elif is_torch_npu_available():
device = torch.device("npu:0")
torch.npu.set_device(device)
self._n_gpu = 1
else:
# if n_gpu is > 1 we'll use nn.DataParallel.
# If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`
14 changes: 14 additions & 0 deletions tests/trainer/test_trainer.py
@@ -131,6 +131,10 @@
# for version specific tests in TrainerIntegrationTest
require_accelerate_version_min_0_28 = partial(require_accelerate, min_version="0.28")
GRAD_ACCUM_KWARGS_VERSION_AVAILABLE = is_accelerate_available("0.28")
if is_accelerate_available():
from accelerate import Accelerator
from accelerate.state import AcceleratorState


PATH_SAMPLE_TEXT = f"{get_tests_dir()}/fixtures/sample_text.txt"

@@ -3266,6 +3270,16 @@ def test_accelerator_config_only_deprecated_args(self):
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
self.assertEqual(trainer.accelerator.split_batches, True)

def test_accelerator_custom_state(self):
AcceleratorState._reset_state(reset_partial_state=True)
with tempfile.TemporaryDirectory() as tmp_dir:
with self.assertRaises(ValueError) as cm:
    _ = RegressionTrainingArguments(output_dir=tmp_dir, accelerator_config={"use_configured_state": True})
self.assertIn("requires a pre-configured `AcceleratorState` or `PartialState`", str(cm.exception))
_ = Accelerator()
_ = RegressionTrainingArguments(output_dir=tmp_dir, accelerator_config={"use_configured_state": True})
AcceleratorState._reset_state(reset_partial_state=True)

@require_accelerate_version_min_0_28
def test_accelerator_config_from_dict_grad_accum_num_steps(self):
with tempfile.TemporaryDirectory() as tmp_dir:
