Introduce AcceleratorConfig dataclass (#28664)
* Introduce acceleratorconfig dataclass

* Extra second warn

* Move import

* Try moving import under is_accelerate_available

* Quality

* Apply suggestions from code review

Co-authored-by: amyeroberts <[email protected]>

* Clean

* Remove to_kwargs

* Change version

* Improve tests by including dispatch and split batches

* Improve reliability

* Update tests/trainer/test_trainer.py

Co-authored-by: amyeroberts <[email protected]>

* Fixup tests and review nits

* Make tests pass

* protect import

* Protect import

* Empty-Commit

* Make training_args.to_dict handle the AcceleratorConfig

---------

Co-authored-by: amyeroberts <[email protected]>
muellerzr and amyeroberts authored Feb 14, 2024
1 parent 69ca640 commit 0507e69
Showing 4 changed files with 307 additions and 14 deletions.
15 changes: 13 additions & 2 deletions src/transformers/trainer.py
@@ -76,6 +76,7 @@
TrainerState,
)
from .trainer_pt_utils import (
AcceleratorConfig,
DistributedTensorGatherer,
IterableDatasetShard,
LabelSmoother,
@@ -4029,11 +4030,21 @@ def create_accelerator_and_postprocess(self):
gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs)

# create accelerator object
accelerator_kwargs = {}
if self.args.accelerator_config is not None:
accelerator_kwargs = self.args.accelerator_config
# dict and AcceleratorConfigs are parseable, json files are not
if isinstance(accelerator_kwargs, AcceleratorConfig):
accelerator_kwargs = accelerator_kwargs.to_dict()
elif isinstance(accelerator_kwargs, dict):
# Some values may need to go through non-accelerate aligned defaults
# and we need to run the `__post_init__` to set them
accelerator_kwargs = AcceleratorConfig(**accelerator_kwargs).to_dict()

self.accelerator = Accelerator(
dispatch_batches=self.args.dispatch_batches,
split_batches=self.args.split_batches,
deepspeed_plugin=self.args.deepspeed_plugin,
gradient_accumulation_plugin=gradient_accumulation_plugin,
**accelerator_kwargs,
)
# some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag
self.gather_function = self.accelerator.gather_for_metrics
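For reference, the normalization added in this hunk can be exercised outside the `Trainer`. The sketch below is not part of the commit; it assumes a transformers install that already contains this change, and the helper name `normalize_accelerator_config` is invented for illustration.

```python
# Minimal sketch (not part of the diff) of how the block above turns whatever
# `args.accelerator_config` holds into plain kwargs for `Accelerator(...)`.
from transformers.trainer_pt_utils import AcceleratorConfig


def normalize_accelerator_config(accelerator_config):
    accelerator_kwargs = {}
    if accelerator_config is not None:
        accelerator_kwargs = accelerator_config
        # AcceleratorConfig instances and dicts both end up as plain kwargs;
        # a dict is round-tripped through the dataclass to pick up its defaults.
        if isinstance(accelerator_kwargs, AcceleratorConfig):
            accelerator_kwargs = accelerator_kwargs.to_dict()
        elif isinstance(accelerator_kwargs, dict):
            accelerator_kwargs = AcceleratorConfig(**accelerator_kwargs).to_dict()
    return accelerator_kwargs


# A partial dict is filled out with the dataclass defaults for the remaining fields.
print(normalize_accelerator_config({"split_batches": True}))
```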
88 changes: 87 additions & 1 deletion src/transformers/trainer_pt_utils.py
@@ -16,15 +16,17 @@
Torch utilities for the Trainer class.
"""

import copy
import datetime
import io
import json
import math
import os
import sys
import warnings
from collections.abc import Mapping
from contextlib import contextmanager
from dataclasses import dataclass
from dataclasses import dataclass, field
from logging import StreamHandler
from typing import Any, Dict, Iterator, List, Optional, Union

@@ -1140,3 +1142,87 @@ def smp_nested_concat(tensor):
# It doesn't seem possible to check here if `tensor` is a StepOutput because StepOutput lives in `smp.step`
# which is also the name of the decorator so Python is confused.
return tensor.concat().detach().cpu()


@dataclass
class AcceleratorConfig:
"""
A subset of arguments relating to the underlying [`accelerate.Accelerator`]
implementation utilized in the `Trainer` that can be customized.
Mostly relating to data.
Parameters:
split_batches (`bool`, *optional*, defaults to `False`):
Whether or not the accelerator should split the batches yielded by the dataloaders across the devices. If
`True` the actual batch size used will be the same on any kind of distributed processes, but it must be a
round multiple of the `num_processes` you are using. If `False`, actual batch size used will be the one set
in your script multiplied by the number of processes.
dispatch_batches (`bool`, *optional*):
If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process
and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose
underlying dataset is an `IterableDataset`, `False` otherwise.
even_batches (`bool`, *optional*, defaults to `True`):
If set to `True`, in cases where the total batch size across all processes does not exactly divide the
dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among
all workers.
use_seedable_sampler (`bool`, *optional*, defaults to `True`):
Whether or not to use a fully seedable random sampler ([`accelerate.data_loader.SeedableRandomSampler`]). Ensures
training results are fully reproducible using a different sampling technique. While seed-to-seed results
may differ, on average the differences are negligible when using multiple different seeds to compare. Should
also be run with [`~utils.set_seed`] for the best results.
"""

# Data related arguments
split_batches: bool = field(
default=False,
metadata={
"help": "Whether or not the accelerator should split the batches yielded by the dataloaders across the devices. If"
" `True` the actual batch size used will be the same on any kind of distributed processes, but it must be a"
" round multiple of the `num_processes` you are using. If `False`, actual batch size used will be the one set"
" in your script multiplied by the number of processes."
},
)
dispatch_batches: bool = field(
default=None,
metadata={
"help": "If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process"
" and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose"
" underlying dataset is an `IterableDataslet`, `False` otherwise."
},
)
even_batches: bool = field(
default=True,
metadata={
"help": "If set to `True`, in cases where the total batch size across all processes does not exactly divide the"
" dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among"
" all workers."
},
)
use_seedable_sampler: bool = field(
default=True,
metadata={
"help": "Whether or not use a fully seedable random sampler ([`accelerate.data_loader.SeedableRandomSampler`])."
"Ensures training results are fully reproducable using a different sampling technique. "
"While seed-to-seed results may differ, on average the differences are neglible when using"
"multiple different seeds to compare. Should also be ran with [`~utils.set_seed`] for the best results."
},
)

@classmethod
def from_json_file(cls, json_file):
# Check if exists
open_file = io.open if os.path.exists(json_file) else open
with open_file(json_file, "r", encoding="utf-8") as f:
config_dict = json.load(f)
# Check for keys and load sensible defaults
extra_keys = sorted(key for key in config_dict.keys() if key not in cls.__dataclass_fields__.keys())
if len(extra_keys) > 0:
raise ValueError(
f"The config file at {json_file} had unknown keys ({extra_keys}), please try upgrading your `transformers`"
" version or fix (and potentially remove these keys) from your config file."
)
return cls(**config_dict)

def to_dict(self):
return copy.deepcopy(self.__dict__)
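A short, hedged usage sketch of the new dataclass follows. It is not part of the diff; it assumes a transformers release that includes this commit, and it writes a temporary JSON file purely for illustration.

```python
import json
import tempfile

from transformers.trainer_pt_utils import AcceleratorConfig

# Construct directly and inspect the resulting kwargs.
config = AcceleratorConfig(split_batches=True, even_batches=False)
print(config.to_dict())

# Or load from a JSON file via the new classmethod; any file containing only the
# four fields defined above works the same way.
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump({"dispatch_batches": False, "use_seedable_sampler": False}, f)
    json_path = f.name

loaded = AcceleratorConfig.from_json_file(json_path)
print(loaded.dispatch_batches, loaded.use_seedable_sampler)

# Keys that are not dataclass fields raise a ValueError hinting at a version mismatch.
```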
77 changes: 66 additions & 11 deletions src/transformers/training_args.py
@@ -70,6 +70,8 @@
from accelerate.state import AcceleratorState, PartialState
from accelerate.utils import DistributedType

from .trainer_pt_utils import AcceleratorConfig

if is_torch_tpu_available(check_device=False):
import torch_xla.core.xla_model as xm

@@ -487,6 +489,32 @@ class TrainingArguments:
Use [Deepspeed](https://github.com/microsoft/deepspeed). This is an experimental feature and its API may
evolve in the future. The value is either the location of DeepSpeed json config file (e.g.,
`ds_config.json`) or an already loaded json file as a `dict`
accelerator_config (`str`, `dict`, or `AcceleratorConfig`, *optional*):
Config to be used with the internal `Accelerator` implementation. The value is either the location of
an accelerator json config file (e.g., `accelerator_config.json`), an already loaded json file as a `dict`,
or an instance of [`~trainer_pt_utils.AcceleratorConfig`].
The accepted config options are:
- split_batches (`bool`, *optional*, defaults to `False`):
Whether or not the accelerator should split the batches yielded by the dataloaders across the devices. If
`True` the actual batch size used will be the same on any kind of distributed processes, but it must be a
round multiple of the `num_processes` you are using. If `False`, actual batch size used will be the one set
in your script multiplied by the number of processes.
- dispatch_batches (`bool`, *optional*):
If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process
and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose
underlying dataset is an `IterableDataset`, `False` otherwise.
- even_batches (`bool`, *optional*, defaults to `True`):
If set to `True`, in cases where the total batch size across all processes does not exactly divide the
dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among
all workers.
- use_seedable_sampler (`bool`, *optional*, defaults to `True`):
Whether or not to use a fully seedable random sampler ([`accelerate.data_loader.SeedableRandomSampler`]). Ensures
training results are fully reproducible using a different sampling technique. While seed-to-seed results
may differ, on average the differences are negligible when using multiple different seeds to compare. Should
also be run with [`~utils.set_seed`] for the best results.
label_smoothing_factor (`float`, *optional*, defaults to 0.0):
The label smoothing factor to use. Zero means no label smoothing, otherwise the underlying onehot-encoded
labels are changed from 0s and 1s to `label_smoothing_factor/num_labels` and `1 - label_smoothing_factor +
@@ -1085,6 +1113,16 @@ class TrainingArguments:
},
)
# Do not touch this type annotation or it will stop working in CLI
accelerator_config: Optional[str] = field(
default=None,
metadata={
"help": (
"Config to be used with the internal Accelerator object initializtion. The value is either a "
"accelerator json config file (e.g., `accelerator_config.json`) or an already loaded json file as `dict`."
)
},
)
# Do not touch this type annotation or it will stop working in CLI
deepspeed: Optional[str] = field(
default=None,
metadata={
@@ -1282,20 +1320,12 @@ class TrainingArguments:

dispatch_batches: Optional[bool] = field(
default=None,
metadata={
"help": "Whether to dispatch batches across devices in distributed training. If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process "
"and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose"
"underlying dataset is an `IterableDataset`, `False` otherwise."
},
metadata={"help": "Deprecated. Pass {'dispatch_batches':VALUE} to `accelerator_config`."},
)

split_batches: Optional[bool] = field(
default=False,
metadata={
"help": "Whether or not the accelerator should split the batches yielded by the dataloaders across the devices during distributed training. If"
"set to `True`, the actual batch size used will be the same on any kind of distributed processes, but it must be a"
"round multiple of the number of processes you are using (such as GPUs)."
},
default=None,
metadata={"help": "Deprecated. Pass {'split_batches':True} to `accelerator_config`."},
)

include_tokens_per_second: Optional[bool] = field(
@@ -1702,6 +1732,28 @@ def __post_init__(self):
os.environ[f"{prefix}SYNC_MODULE_STATES"] = self.fsdp_config.get("sync_module_states", "true")
os.environ[f"{prefix}USE_ORIG_PARAMS"] = self.fsdp_config.get("use_orig_params", "true")

if is_accelerate_available():
if not isinstance(self.accelerator_config, (AcceleratorConfig, dict)):
if self.accelerator_config is None:
self.accelerator_config = AcceleratorConfig()
else:
self.accelerator_config = AcceleratorConfig.from_json_file(self.accelerator_config)
if self.dispatch_batches is not None:
warnings.warn(
"Using `--dispatch_batches` is deprecated and will be removed in version 4.41 of 🤗 Transformers. Use"
" `--accelerator_config {'dispatch_batches':VALUE} instead",
FutureWarning,
)
self.accelerator_config["dispatch_batches"] = self.dispatch_batches

if self.split_batches is not None:
warnings.warn(
"Using `--split_batches` is deprecated and will be removed in version 4.41 of 🤗 Transformers. Use"
" `--accelerator_config {'split_batches':VALUE} instead",
FutureWarning,
)
self.accelerator_config["split_batches"] = self.split_batches

if self.tpu_metrics_debug:
warnings.warn(
"using `--tpu_metrics_debug` is deprecated and will be removed in version 5 of 🤗 Transformers. Use"
@@ -2156,6 +2208,9 @@ def to_dict(self):
d[k] = [x.value for x in v]
if k.endswith("_token"):
d[k] = f"<{k.upper()}>"
# Handle the accelerator_config if passed
if is_accelerate_available() and isinstance(v, AcceleratorConfig):
d[k] = v.to_dict()
return d

def to_json_string(self):
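Taken together, the training_args.py changes route the Accelerator data options through a single argument. The end-to-end sketch below is not part of the commit; it assumes a transformers release containing this change with accelerate installed, and `tmp_output` is a placeholder directory.

```python
from transformers import TrainingArguments
from transformers.trainer_pt_utils import AcceleratorConfig

# A plain dict is accepted and later normalized by the Trainer...
args = TrainingArguments(
    output_dir="tmp_output",
    accelerator_config={"split_batches": True, "even_batches": False},
)

# ...and so is an AcceleratorConfig instance, which `to_dict()` now serializes cleanly
# thanks to the new branch above.
args = TrainingArguments(
    output_dir="tmp_output",
    accelerator_config=AcceleratorConfig(dispatch_batches=False),
)
print(args.to_dict()["accelerator_config"])

# At the time of this commit, the old flags still work but emit a FutureWarning and are
# folded into `accelerator_config`:
#     TrainingArguments(output_dir="tmp_output", dispatch_batches=False)
```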