diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index f4a54ecc4dabbd..bbf5d4abf8a924 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -76,6 +76,7 @@
     TrainerState,
 )
 from .trainer_pt_utils import (
+    AcceleratorConfig,
     DistributedTensorGatherer,
     IterableDatasetShard,
     LabelSmoother,
@@ -4029,11 +4030,21 @@ def create_accelerator_and_postprocess(self):
         gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs)
 
         # create accelerator object
+        accelerator_kwargs = {}
+        if self.args.accelerator_config is not None:
+            accelerator_kwargs = self.args.accelerator_config
+            # dict and AcceleratorConfigs are parseable, json files are not
+            if isinstance(accelerator_kwargs, AcceleratorConfig):
+                accelerator_kwargs = accelerator_kwargs.to_dict()
+            elif isinstance(accelerator_kwargs, dict):
+                # Some values may need to go through non-accelerate aligned defaults
+                # and we need to run the `__post_init__` to set them
+                accelerator_kwargs = AcceleratorConfig(**accelerator_kwargs).to_dict()
+
         self.accelerator = Accelerator(
-            dispatch_batches=self.args.dispatch_batches,
-            split_batches=self.args.split_batches,
             deepspeed_plugin=self.args.deepspeed_plugin,
             gradient_accumulation_plugin=gradient_accumulation_plugin,
+            **accelerator_kwargs,
         )
         # some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag
         self.gather_function = self.accelerator.gather_for_metrics
diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py
index b8dfb3124c5e9f..dce0eeaf818604 100644
--- a/src/transformers/trainer_pt_utils.py
+++ b/src/transformers/trainer_pt_utils.py
@@ -16,7 +16,9 @@
 Torch utilities for the Trainer class.
 """
 
+import copy
 import datetime
+import io
 import json
 import math
 import os
@@ -24,7 +26,7 @@
 import warnings
 from collections.abc import Mapping
 from contextlib import contextmanager
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from logging import StreamHandler
 from typing import Any, Dict, Iterator, List, Optional, Union
 
@@ -1140,3 +1142,87 @@ def smp_nested_concat(tensor):
         # It doesn't seem possible to check here if `tensor` is a StepOutput because StepOutput lives in `smp.step`
         # which is also the name of the decorator so Python is confused.
         return tensor.concat().detach().cpu()
+
+
+@dataclass
+class AcceleratorConfig:
+    """
+    A subset of arguments relating to the underlying [`accelerate.Accelerator`]
+    implementation utilized in the `Trainer` that can be customized.
+    Mostly relating to data.
+
+    Parameters:
+        split_batches (`bool`, *optional*, defaults to `False`):
+            Whether or not the accelerator should split the batches yielded by the dataloaders across the devices. If
+            `True` the actual batch size used will be the same on any kind of distributed processes, but it must be a
+            round multiple of the `num_processes` you are using. If `False`, actual batch size used will be the one set
+            in your script multiplied by the number of processes.
+        dispatch_batches (`bool`, *optional*):
+            If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process
+            and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose
+            underlying dataset is an `IterableDataset`, `False` otherwise.
+        even_batches (`bool`, *optional*, defaults to `True`):
+            If set to `True`, in cases where the total batch size across all processes does not exactly divide the
+            dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among
+            all workers.
+        use_seedable_sampler (`bool`, *optional*, defaults to `True`):
+            Whether or not to use a fully seedable random sampler ([`accelerate.data_loader.SeedableRandomSampler`]). Ensures
+            training results are fully reproducible using a different sampling technique. While seed-to-seed results
+            may differ, on average the differences are negligible when using multiple different seeds to compare. Should
+            also be run with [`~utils.set_seed`] for the best results.
+
+    """
+
+    # Data related arguments
+    split_batches: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether or not the accelerator should split the batches yielded by the dataloaders across the devices. If"
+            " `True` the actual batch size used will be the same on any kind of distributed processes, but it must be a"
+            " round multiple of the `num_processes` you are using. If `False`, actual batch size used will be the one set"
+            " in your script multiplied by the number of processes."
+        },
+    )
+    dispatch_batches: Optional[bool] = field(
+        default=None,
+        metadata={
+            "help": "If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process"
+            " and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose"
+            " underlying dataset is an `IterableDataset`, `False` otherwise."
+        },
+    )
+    even_batches: bool = field(
+        default=True,
+        metadata={
+            "help": "If set to `True`, in cases where the total batch size across all processes does not exactly divide the"
+            " dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among"
+            " all workers."
+        },
+    )
+    use_seedable_sampler: bool = field(
+        default=True,
+        metadata={
+            "help": "Whether or not to use a fully seedable random sampler ([`accelerate.data_loader.SeedableRandomSampler`])."
+            " Ensures training results are fully reproducible using a different sampling technique. "
+            "While seed-to-seed results may differ, on average the differences are negligible when using"
+            " multiple different seeds to compare. Should also be run with [`~utils.set_seed`] for the best results."
+        },
+    )
+
+    @classmethod
+    def from_json_file(cls, json_file):
+        # Check if exists
+        open_file = io.open if os.path.exists(json_file) else open
+        with open_file(json_file, "r", encoding="utf-8") as f:
+            config_dict = json.load(f)
+        # Check for keys and load sensible defaults
+        extra_keys = sorted(key for key in config_dict.keys() if key not in cls.__dataclass_fields__.keys())
+        if len(extra_keys) > 0:
+            raise ValueError(
+                f"The config file at {json_file} had unknown keys ({extra_keys}), please try upgrading your `transformers`"
+                " version or fix (and potentially remove) these keys from your config file."
+            )
+        return cls(**config_dict)
+
+    def to_dict(self):
+        return copy.deepcopy(self.__dict__)
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 56f102396e0fe5..e51cf41106ee80 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -70,6 +70,8 @@
     from accelerate.state import AcceleratorState, PartialState
     from accelerate.utils import DistributedType
 
+    from .trainer_pt_utils import AcceleratorConfig
+
 if is_torch_tpu_available(check_device=False):
     import torch_xla.core.xla_model as xm
 
@@ -487,6 +489,32 @@ class TrainingArguments:
             Use [Deepspeed](https://github.com/microsoft/deepspeed). This is an experimental feature and its API may
             evolve in the future. The value is either the location of DeepSpeed json config file (e.g.,
             `ds_config.json`) or an already loaded json file as a `dict`"
+
+        accelerator_config (`str`, `dict`, or `AcceleratorConfig`, *optional*):
+            Config to be used with the internal `Accelerator` implementation. The value is either the location of an
+            accelerator json config file (e.g., `accelerator_config.json`), an already loaded json file as a `dict`,
+            or an instance of [`~trainer_pt_utils.AcceleratorConfig`].
+
+            A list of config options:
+                - split_batches (`bool`, *optional*, defaults to `False`):
+                    Whether or not the accelerator should split the batches yielded by the dataloaders across the devices. If
+                    `True` the actual batch size used will be the same on any kind of distributed processes, but it must be a
+                    round multiple of the `num_processes` you are using. If `False`, actual batch size used will be the one set
+                    in your script multiplied by the number of processes.
+                - dispatch_batches (`bool`, *optional*):
+                    If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process
+                    and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose
+                    underlying dataset is an `IterableDataset`, `False` otherwise.
+                - even_batches (`bool`, *optional*, defaults to `True`):
+                    If set to `True`, in cases where the total batch size across all processes does not exactly divide the
+                    dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among
+                    all workers.
+                - use_seedable_sampler (`bool`, *optional*, defaults to `True`):
+                    Whether or not to use a fully seedable random sampler ([`accelerate.data_loader.SeedableRandomSampler`]). Ensures
+                    training results are fully reproducible using a different sampling technique. While seed-to-seed results
+                    may differ, on average the differences are negligible when using multiple different seeds to compare. Should
+                    also be run with [`~utils.set_seed`] for the best results.
+
         label_smoothing_factor (`float`, *optional*, defaults to 0.0):
             The label smoothing factor to use. Zero means no label smoothing, otherwise the underlying onehot-encoded
             labels are changed from 0s and 1s to `label_smoothing_factor/num_labels` and `1 - label_smoothing_factor +
@@ -1085,6 +1113,16 @@ class TrainingArguments:
         },
     )
     # Do not touch this type annotation or it will stop working in CLI
+    accelerator_config: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Config to be used with the internal Accelerator object initialization. The value is either the "
+                "location of an accelerator json config file (e.g., `accelerator_config.json`) or an already loaded json file as `dict`."
+            )
+        },
+    )
+    # Do not touch this type annotation or it will stop working in CLI
     deepspeed: Optional[str] = field(
         default=None,
         metadata={
@@ -1282,20 +1320,12 @@
 
     dispatch_batches: Optional[bool] = field(
         default=None,
-        metadata={
-            "help": "Whether to dispatch batches across devices in distributed training. If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process "
-            "and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose"
-            "underlying dataset is an `IterableDataset`, `False` otherwise."
-        },
+        metadata={"help": "Deprecated. Pass {'dispatch_batches':VALUE} to `accelerator_config`."},
     )
 
     split_batches: Optional[bool] = field(
-        default=False,
-        metadata={
-            "help": "Whether or not the accelerator should split the batches yielded by the dataloaders across the devices during distributed training. If"
-            "set to `True`, the actual batch size used will be the same on any kind of distributed processes, but it must be a"
-            "round multiple of the number of processes you are using (such as GPUs)."
-        },
+        default=None,
+        metadata={"help": "Deprecated. Pass {'split_batches':VALUE} to `accelerator_config`."},
     )
 
     include_tokens_per_second: Optional[bool] = field(
@@ -1702,6 +1732,28 @@ def __post_init__(self):
             os.environ[f"{prefix}SYNC_MODULE_STATES"] = self.fsdp_config.get("sync_module_states", "true")
             os.environ[f"{prefix}USE_ORIG_PARAMS"] = self.fsdp_config.get("use_orig_params", "true")
 
+        if is_accelerate_available():
+            if not isinstance(self.accelerator_config, (AcceleratorConfig, dict)):
+                if self.accelerator_config is None:
+                    self.accelerator_config = AcceleratorConfig()
+                else:
+                    self.accelerator_config = AcceleratorConfig.from_json_file(self.accelerator_config)
+            if self.dispatch_batches is not None:
+                warnings.warn(
+                    "Using `--dispatch_batches` is deprecated and will be removed in version 4.41 of 🤗 Transformers. Use"
+                    " `--accelerator_config {'dispatch_batches':VALUE}` instead",
+                    FutureWarning,
+                )
+                self.accelerator_config["dispatch_batches"] = self.dispatch_batches
+
+            if self.split_batches is not None:
+                warnings.warn(
+                    "Using `--split_batches` is deprecated and will be removed in version 4.41 of 🤗 Transformers. Use"
+                    " `--accelerator_config {'split_batches':VALUE}` instead",
+                    FutureWarning,
+                )
+                self.accelerator_config["split_batches"] = self.split_batches
+
         if self.tpu_metrics_debug:
             warnings.warn(
                 "using `--tpu_metrics_debug` is deprecated and will be removed in version 5 of 🤗 Transformers. Use"
Use" @@ -2156,6 +2208,9 @@ def to_dict(self): d[k] = [x.value for x in v] if k.endswith("_token"): d[k] = f"<{k.upper()}>" + # Handle the accelerator_config if passed + if is_accelerate_available() and isinstance(v, AcceleratorConfig): + d[k] = v.to_dict() return d def to_json_string(self): diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 2a098007852c87..530d98016142cb 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -118,6 +118,7 @@ TrainerState, ) from transformers.modeling_utils import unwrap_model + from transformers.trainer_pt_utils import AcceleratorConfig if is_safetensors_available(): import safetensors.torch @@ -2412,6 +2413,146 @@ def test_end_to_end_example(self): execute_subprocess_async(command) # successful return here == success - any errors would have caused an error or a timeout in the sub-call + def test_accelerator_config_empty(self): + # Checks that a config can be made with the defaults if not passed + with tempfile.TemporaryDirectory() as tmp_dir: + config = RegressionModelConfig(a=1.5, b=2.5) + model = RegressionPreTrainedModel(config) + eval_dataset = SampleIterableDataset() + + # Leaves one option as something *not* basic + args = RegressionTrainingArguments( + output_dir=tmp_dir, + ) + trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset) + self.assertEqual(trainer.accelerator.split_batches, False) + self.assertEqual(trainer.accelerator.dispatch_batches, None) + self.assertEqual(trainer.accelerator.even_batches, True) + self.assertEqual(trainer.accelerator.use_seedable_sampler, True) + + def test_accelerator_config_from_dict(self): + # Checks that accelerator kwargs can be passed through + # and the accelerator is initialized respectively + with tempfile.TemporaryDirectory() as tmp_dir: + config = RegressionModelConfig(a=1.5, b=2.5) + model = RegressionPreTrainedModel(config) + eval_dataset = SampleIterableDataset() + + # Leaves all options as something *not* basic + args = RegressionTrainingArguments( + output_dir=tmp_dir, + accelerator_config={ + "split_batches": True, + "dispatch_batches": True, + "even_batches": False, + "use_seedable_sampler": True, + }, + ) + trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset) + self.assertEqual(trainer.accelerator.split_batches, True) + self.assertEqual(trainer.accelerator.dispatch_batches, True) + self.assertEqual(trainer.accelerator.even_batches, False) + self.assertEqual(trainer.accelerator.use_seedable_sampler, True) + + def test_accelerator_config_from_yaml(self): + # Checks that accelerator kwargs can be passed through + # and the accelerator is initialized respectively + with tempfile.TemporaryDirectory() as tmp_dir: + path_file = Path(tmp_dir) / "accelerator_config.json" + with open(path_file, "w") as f: + accelerator_config = { + "split_batches": True, + "dispatch_batches": True, + "even_batches": False, + "use_seedable_sampler": False, + } + json.dump(accelerator_config, f) + config = RegressionModelConfig(a=1.5, b=2.5) + model = RegressionPreTrainedModel(config) + eval_dataset = SampleIterableDataset() + + # Leaves all options as something *not* basic + args = RegressionTrainingArguments(output_dir=tmp_dir, accelerator_config=path_file) + trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset) + self.assertEqual(trainer.accelerator.split_batches, True) + self.assertEqual(trainer.accelerator.dispatch_batches, True) + self.assertEqual(trainer.accelerator.even_batches, False) + 
self.assertEqual(trainer.accelerator.use_seedable_sampler, False) + + def test_accelerator_config_from_dataclass(self): + # Checks that accelerator kwargs can be passed through + # and the accelerator is initialized respectively + accelerator_config = AcceleratorConfig( + split_batches=True, dispatch_batches=True, even_batches=False, use_seedable_sampler=False + ) + config = RegressionModelConfig(a=1.5, b=2.5) + model = RegressionPreTrainedModel(config) + eval_dataset = SampleIterableDataset() + with tempfile.TemporaryDirectory() as tmp_dir: + args = RegressionTrainingArguments(output_dir=tmp_dir, accelerator_config=accelerator_config) + trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset) + self.assertEqual(trainer.accelerator.split_batches, True) + self.assertEqual(trainer.accelerator.dispatch_batches, True) + self.assertEqual(trainer.accelerator.even_batches, False) + self.assertEqual(trainer.accelerator.use_seedable_sampler, False) + + def test_accelerator_config_from_partial(self): + # Checks that accelerator kwargs can be passed through + # and the accelerator is initialized respectively + with tempfile.TemporaryDirectory() as tmp_dir: + config = RegressionModelConfig(a=1.5, b=2.5) + model = RegressionPreTrainedModel(config) + eval_dataset = SampleIterableDataset() + + # Leaves one option as something *not* basic + args = RegressionTrainingArguments( + output_dir=tmp_dir, + accelerator_config={ + "split_batches": True, + }, + ) + trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset) + self.assertEqual(trainer.accelerator.split_batches, True) + self.assertEqual(trainer.accelerator.dispatch_batches, None) + self.assertEqual(trainer.accelerator.even_batches, True) + self.assertEqual(trainer.accelerator.use_seedable_sampler, True) + + def test_accelerator_config_from_dict_with_deprecated_args(self): + # Checks that accelerator kwargs can be passed through + # and the accelerator is initialized respectively + # and maintains the deprecated args if passed in + with tempfile.TemporaryDirectory() as tmp_dir: + config = RegressionModelConfig(a=1.5, b=2.5) + model = RegressionPreTrainedModel(config) + eval_dataset = SampleIterableDataset() + + # Leaves all options as something *not* basic + with self.assertWarns(FutureWarning) as cm: + args = RegressionTrainingArguments( + output_dir=tmp_dir, + accelerator_config={ + "split_batches": True, + }, + dispatch_batches=False, + ) + self.assertIn("dispatch_batches", str(cm.warnings[0].message)) + trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset) + self.assertEqual(trainer.accelerator.dispatch_batches, False) + self.assertEqual(trainer.accelerator.split_batches, True) + with self.assertWarns(FutureWarning) as cm: + args = RegressionTrainingArguments( + output_dir=tmp_dir, + accelerator_config={ + "even_batches": False, + }, + split_batches=True, + ) + self.assertIn("split_batches", str(cm.warnings[0].message)) + trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset) + self.assertEqual(trainer.accelerator.split_batches, True) + self.assertEqual(trainer.accelerator.even_batches, False) + self.assertEqual(trainer.accelerator.dispatch_batches, None) + @require_torch @is_staging_test
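Usage sketch (not part of the diff above): with a transformers version that includes this change, the new `accelerator_config` argument can be given to `TrainingArguments` either as a plain dict or as the `AcceleratorConfig` dataclass; the `output_dir` value below is an arbitrary example.

from transformers import TrainingArguments
from transformers.trainer_pt_utils import AcceleratorConfig

# Pass a plain dict; any key left out falls back to the AcceleratorConfig defaults
# (split_batches=False, dispatch_batches=None, even_batches=True, use_seedable_sampler=True).
args = TrainingArguments(
    output_dir="test_output",  # arbitrary example path
    accelerator_config={"split_batches": True, "even_batches": False},
)

# Or pass the dataclass directly; Trainer.create_accelerator_and_postprocess
# forwards these values as keyword arguments to accelerate.Accelerator(...).
args = TrainingArguments(
    output_dir="test_output",
    accelerator_config=AcceleratorConfig(dispatch_batches=True),
)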
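Similarly, a sketch of the JSON-file form and of migrating off the deprecated `dispatch_batches`/`split_batches` training arguments (the file name and values are illustrative):

import json

from transformers import TrainingArguments

# Write the options to a json file and point `accelerator_config` at it;
# TrainingArguments.__post_init__ loads it via AcceleratorConfig.from_json_file.
with open("accelerator_config.json", "w") as f:
    json.dump({"dispatch_batches": False, "use_seedable_sampler": True}, f)

args = TrainingArguments(output_dir="test_output", accelerator_config="accelerator_config.json")

# Passing the old flag still works but emits a FutureWarning; the value is
# copied into the accelerator_config dict (mirrors the deprecation test above).
args = TrainingArguments(
    output_dir="test_output",
    accelerator_config={"split_batches": True},
    dispatch_batches=False,
)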