diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py
index 9defa91b2b8bc8..8ac0281912ce19 100644
--- a/src/transformers/trainer_pt_utils.py
+++ b/src/transformers/trainer_pt_utils.py
@@ -1250,6 +1250,10 @@ class AcceleratorConfig:
             Whether to use non-blocking CUDA calls to help minimize synchronization during
             distributed training with prepared `DataLoader` inputs being moved to device.
             Best if used with `pin_memory=True` in the `TrainingArguments`.
+        use_configured_state (`bool`, *optional*, defaults to `False`):
+            Whether or not to use a pre-configured `AcceleratorState` or `PartialState` defined
+            before calling `TrainingArguments`. If `True`, an `Accelerator` or `PartialState`
+            must be initialized. May lead to issues when using sweeps or hyperparameter tuning.
 
     """
 
@@ -1312,6 +1316,13 @@ class AcceleratorConfig:
             " The [`accelerate.utils.GradientAccumulationPlugin`] default is `False`."
         },
     )
+    use_configured_state: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether or not to use a pre-configured `AcceleratorState` or `PartialState` defined before calling `TrainingArguments`."
+            " If `True`, an `Accelerator` or `PartialState` must be initialized. May lead to issues when using sweeps or hyperparameter tuning."
+        },
+    )
 
     @classmethod
     def from_json_file(cls, json_file):
@@ -1331,6 +1342,9 @@ def from_json_file(cls, json_file):
     def to_dict(self):
         return copy.deepcopy(self.__dict__)
 
+    def pop(self, key, default=None):
+        return self.__dict__.pop(key, default)
+
 
 class LayerWiseDummyOptimizer(torch.optim.Optimizer):
     """
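The new `pop` helper lets `_setup_devices` strip `use_configured_state` out of the config before the remaining fields are handed to Accelerate. A minimal sketch of that interaction, standing alone outside any `Trainer` (the setup below is illustrative, not part of the diff):

```python
from transformers.trainer_pt_utils import AcceleratorConfig

config = AcceleratorConfig(use_configured_state=True)
# `pop` removes the key from the dataclass's __dict__, so downstream consumers
# that do not accept it (e.g. the `Accelerator` constructor) never see it
use_configured_state = config.pop("use_configured_state", False)
assert use_configured_state is True
assert "use_configured_state" not in config.to_dict()
```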
" + "Please pass in a fully constructed `AcceleratorConfig` object instead." + ) + else: + self.accelerator_config = AcceleratorConfig.from_json_file(self.accelerator_config) + + if self.dispatch_batches is not None: + warnings.warn( + "Using `--dispatch_batches` is deprecated and will be removed in version 4.41 of 🤗 Transformers. Use" + " `--accelerator_config {'dispatch_batches':VALUE} instead", + FutureWarning, + ) + self.accelerator_config.dispatch_batches = self.dispatch_batches + + if self.split_batches is not None: + warnings.warn( + "Using `--split_batches` is deprecated and will be removed in version 4.41 of 🤗 Transformers. Use" + " `--accelerator_config {'split_batches':VALUE} instead", + FutureWarning, + ) + self.accelerator_config.split_batches = self.split_batches + if ( self.framework == "pt" and is_torch_available() @@ -1873,37 +1910,6 @@ def __post_init__(self): os.environ[f"{prefix}USE_ORIG_PARAMS"] = str(self.fsdp_config.get("use_orig_params", "true")).lower() - if is_accelerate_available(): - if not isinstance(self.accelerator_config, (AcceleratorConfig)): - if self.accelerator_config is None: - self.accelerator_config = AcceleratorConfig() - elif isinstance(self.accelerator_config, dict): - self.accelerator_config = AcceleratorConfig(**self.accelerator_config) - # Check that a user didn't pass in the class instantiator - # such as `accelerator_config = AcceleratorConfig` - elif isinstance(self.accelerator_config, type): - raise NotImplementedError( - "Tried passing in a callable to `accelerator_config`, but this is not supported. " - "Please pass in a fully constructed `AcceleratorConfig` object instead." - ) - else: - self.accelerator_config = AcceleratorConfig.from_json_file(self.accelerator_config) - if self.dispatch_batches is not None: - warnings.warn( - "Using `--dispatch_batches` is deprecated and will be removed in version 4.41 of 🤗 Transformers. Use" - " `--accelerator_config {'dispatch_batches':VALUE} instead", - FutureWarning, - ) - self.accelerator_config.dispatch_batches = self.dispatch_batches - - if self.split_batches is not None: - warnings.warn( - "Using `--split_batches` is deprecated and will be removed in version 4.41 of 🤗 Transformers. Use" - " `--accelerator_config {'split_batches':VALUE} instead", - FutureWarning, - ) - self.accelerator_config.split_batches = self.split_batches - if self.tpu_metrics_debug: warnings.warn( "using `--tpu_metrics_debug` is deprecated and will be removed in version 5 of 🤗 Transformers. Use" @@ -2056,32 +2062,62 @@ def _setup_devices(self) -> "torch.device": f"Using the `Trainer` with `PyTorch` requires `accelerate>={ACCELERATE_MIN_VERSION}`: " "Please run `pip install transformers[torch]` or `pip install accelerate -U`" ) + # We delay the init of `PartialState` to the end for clarity + accelerator_state_kwargs = {"enabled": True, "use_configured_state": False} + if isinstance(self.accelerator_config, AcceleratorConfig): + accelerator_state_kwargs["use_configured_state"] = self.accelerator_config.pop( + "use_configured_state", False + ) + if accelerator_state_kwargs["use_configured_state"]: + if PartialState._shared_state == {}: + raise ValueError( + "Passing `'use_configured_state':True` to the AcceleratorConfig requires a pre-configured " + "`AcceleratorState` or `PartialState` to be defined before calling `TrainingArguments`. 
" + ) + # We rely on `PartialState` to yell if there's issues here (which it will) + self.distributed_state = PartialState(cpu=self.use_cpu) + if self.deepspeed and self.distributed_state.distributed_type != DistributedType.DEEPSPEED: + raise RuntimeError( + "Tried to use an already configured `Accelerator` or `PartialState` that was not initialized for DeepSpeed, " + "but also passed in a `deepspeed` configuration to the `TrainingArguments`. Please set " + "`use_configured_state:False` instead or setup your `Accelerator` or `PartialState` properly." + ) + else: AcceleratorState._reset_state(reset_partial_state=True) - self.distributed_state = None + self.distributed_state = None if not self.use_ipex and "ACCELERATE_USE_IPEX" not in os.environ: os.environ["ACCELERATE_USE_IPEX"] = "false" + + self._n_gpu = 1 if self.use_cpu or strtobool(os.environ.get("ACCELERATE_USE_CPU", "False")): - self.distributed_state = PartialState(cpu=True, backend=self.ddp_backend) + accelerator_state_kwargs["cpu"] = True + accelerator_state_kwargs["backend"] = self.ddp_backend self._n_gpu = 0 elif is_sagemaker_mp_enabled(): + accelerator_state_kwargs["enabled"] = False local_rank = smp.local_rank() device = torch.device("cuda", local_rank) - self._n_gpu = 1 torch.cuda.set_device(device) elif is_sagemaker_dp_enabled(): - self.distributed_state = PartialState(_use_sagemaker_dp=True) - self._n_gpu = 1 + accelerator_state_kwargs["_use_sagemaker_dp"] = True elif self.deepspeed: - # Need to do similar for Accelerator init - os.environ["ACCELERATE_USE_DEEPSPEED"] = "true" - self.distributed_state = PartialState(timeout=timedelta(seconds=self.ddp_timeout)) - del os.environ["ACCELERATE_USE_DEEPSPEED"] - self._n_gpu = 1 + accelerator_state_kwargs["use_deepspeed"] = True + accelerator_state_kwargs["timeout"] = timedelta(seconds=self.ddp_timeout) else: - self.distributed_state = PartialState( - backend=self.ddp_backend, timeout=timedelta(seconds=self.ddp_timeout) - ) - self._n_gpu = 1 + accelerator_state_kwargs["backend"] = self.ddp_backend + accelerator_state_kwargs["timeout"] = timedelta(seconds=self.ddp_timeout) + + # Now we pop everything + if accelerator_state_kwargs.pop("enabled", False) and not accelerator_state_kwargs.pop( + "use_configured_state", False + ): + # We need to patch this env var when enabling to detect deepspeed + use_deepspeed = accelerator_state_kwargs.pop("use_deepspeed", False) + if use_deepspeed: + os.environ["ACCELERATE_USE_DEEPSPEED"] = "true" + self.distributed_state = PartialState(**accelerator_state_kwargs) + if use_deepspeed: + del os.environ["ACCELERATE_USE_DEEPSPEED"] if not is_sagemaker_mp_enabled(): device = self.distributed_state.device self.local_rank = self.distributed_state.local_process_index @@ -2108,23 +2144,17 @@ def _setup_devices(self) -> "torch.device": "Either you do not have an MPS-enabled device on this machine or MacOS version is not 12.3+ " "or current PyTorch install was not built with MPS enabled." ) - if device.type == "mps": - self._n_gpu = 1 - elif self.use_cpu: + if self.use_cpu: device = torch.device("cpu") - self._n_gpu = 0 elif is_torch_xpu_available(): device = torch.device("xpu:0") torch.xpu.set_device(device) - self._n_gpu = 1 elif is_torch_mlu_available(): device = torch.device("mlu:0") torch.mlu.set_device(device) - self._n_gpu = 1 elif is_torch_npu_available(): device = torch.device("npu:0") torch.npu.set_device(device) - self._n_gpu = 1 else: # if n_gpu is > 1 we'll use nn.DataParallel. 
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index c420da4052f186..1711f600cebbc8 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -131,6 +131,10 @@
 # for version specific tests in TrainerIntegrationTest
 require_accelerate_version_min_0_28 = partial(require_accelerate, min_version="0.28")
 GRAD_ACCUM_KWARGS_VERSION_AVAILABLE = is_accelerate_available("0.28")
+if is_accelerate_available():
+    from accelerate import Accelerator
+    from accelerate.state import AcceleratorState
+
 
 PATH_SAMPLE_TEXT = f"{get_tests_dir()}/fixtures/sample_text.txt"
 
@@ -3266,6 +3270,16 @@ def test_accelerator_config_only_deprecated_args(self):
             trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
             self.assertEqual(trainer.accelerator.split_batches, True)
 
+    def test_accelerator_custom_state(self):
+        AcceleratorState._reset_state(reset_partial_state=True)
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            with self.assertRaises(ValueError) as cm:
+                _ = RegressionTrainingArguments(output_dir=tmp_dir, accelerator_config={"use_configured_state": True})
+            self.assertIn("requires a pre-configured `AcceleratorState`", str(cm.exception))
+            _ = Accelerator()
+            _ = RegressionTrainingArguments(output_dir=tmp_dir, accelerator_config={"use_configured_state": True})
+        AcceleratorState._reset_state(reset_partial_state=True)
+
     @require_accelerate_version_min_0_28
     def test_accelerator_config_from_dict_grad_accum_num_steps(self):
         with tempfile.TemporaryDirectory() as tmp_dir:
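One subtlety the test encodes: `AcceleratorState` is process-global, so it is reset on both entry and exit to avoid leaking state into neighboring tests. A sketch of the same hygiene for a notebook or script that constructs `TrainingArguments` repeatedly (assumes accelerate is installed; not part of the diff):

```python
from accelerate.state import AcceleratorState, PartialState

# Drop any previously configured global state, including the PartialState singleton
AcceleratorState._reset_state(reset_partial_state=True)
assert PartialState._shared_state == {}
```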