Commit

Refactor
muellerzr committed Apr 25, 2024
1 parent 1988035 commit 3aa1053
Showing 1 changed file with 69 additions and 63 deletions.
132 changes: 69 additions & 63 deletions src/transformers/training_args.py
@@ -1619,6 +1619,39 @@ def __post_init__(self):
if version.parse(version.parse(torch.__version__).base_version) == version.parse("2.0.0") and self.fp16:
raise ValueError("--optim adamw_torch_fused with --fp16 requires PyTorch>2.0")

# We need to setup the accelerator config here
if is_accelerate_available():
if not isinstance(self.accelerator_config, (AcceleratorConfig)):
if self.accelerator_config is None:
self.accelerator_config = AcceleratorConfig()
elif isinstance(self.accelerator_config, dict):
self.accelerator_config = AcceleratorConfig(**self.accelerator_config)
# Check that a user didn't pass in the class instantiator
# such as `accelerator_config = AcceleratorConfig`
elif isinstance(self.accelerator_config, type):
raise NotImplementedError(
"Tried passing in a callable to `accelerator_config`, but this is not supported. "
"Please pass in a fully constructed `AcceleratorConfig` object instead."
)
else:
self.accelerator_config = AcceleratorConfig.from_json_file(self.accelerator_config)

if self.dispatch_batches is not None:
warnings.warn(
"Using `--dispatch_batches` is deprecated and will be removed in version 4.41 of 🤗 Transformers. Use"
" `--accelerator_config {'dispatch_batches':VALUE} instead",
FutureWarning,
)
self.accelerator_config.dispatch_batches = self.dispatch_batches

if self.split_batches is not None:
warnings.warn(
"Using `--split_batches` is deprecated and will be removed in version 4.41 of 🤗 Transformers. Use"
" `--accelerator_config {'split_batches':VALUE} instead",
FutureWarning,
)
self.accelerator_config.split_batches = self.split_batches

if (
self.framework == "pt"
and is_torch_available()
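
For reference, a minimal sketch (not part of this diff) of the call the deprecation warnings above point to: dispatch/split behaviour goes through `accelerator_config` instead of the deprecated `dispatch_batches`/`split_batches` arguments. The `output_dir` value is a placeholder and the import path for `AcceleratorConfig` is assumed.

from transformers import TrainingArguments
from transformers.trainer_pt_utils import AcceleratorConfig  # import path assumed

# Option 1: a plain dict, coerced via AcceleratorConfig(**...) in __post_init__
args = TrainingArguments(
    output_dir="out",
    accelerator_config={"dispatch_batches": False, "split_batches": True},
)

# Option 2: a fully constructed AcceleratorConfig object, accepted as-is.
# Passing the bare class (accelerator_config=AcceleratorConfig) raises
# NotImplementedError, per the isinstance(..., type) check above.
args = TrainingArguments(
    output_dir="out",
    accelerator_config=AcceleratorConfig(dispatch_batches=False, split_batches=True),
)
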
@@ -1850,37 +1883,6 @@ def __post_init__(self):

os.environ[f"{prefix}USE_ORIG_PARAMS"] = self.fsdp_config.get("use_orig_params", "true")

if is_accelerate_available():
if not isinstance(self.accelerator_config, (AcceleratorConfig)):
if self.accelerator_config is None:
self.accelerator_config = AcceleratorConfig()
elif isinstance(self.accelerator_config, dict):
self.accelerator_config = AcceleratorConfig(**self.accelerator_config)
# Check that a user didn't pass in the class instantiator
# such as `accelerator_config = AcceleratorConfig`
elif isinstance(self.accelerator_config, type):
raise NotImplementedError(
"Tried passing in a callable to `accelerator_config`, but this is not supported. "
"Please pass in a fully constructed `AcceleratorConfig` object instead."
)
else:
self.accelerator_config = AcceleratorConfig.from_json_file(self.accelerator_config)
if self.dispatch_batches is not None:
warnings.warn(
"Using `--dispatch_batches` is deprecated and will be removed in version 4.41 of 🤗 Transformers. Use"
" `--accelerator_config {'dispatch_batches':VALUE} instead",
FutureWarning,
)
self.accelerator_config.dispatch_batches = self.dispatch_batches

if self.split_batches is not None:
warnings.warn(
"Using `--split_batches` is deprecated and will be removed in version 4.41 of 🤗 Transformers. Use"
" `--accelerator_config {'split_batches':VALUE} instead",
FutureWarning,
)
self.accelerator_config.split_batches = self.split_batches

if self.tpu_metrics_debug:
warnings.warn(
"using `--tpu_metrics_debug` is deprecated and will be removed in version 5 of 🤗 Transformers. Use"
@@ -2033,42 +2035,52 @@ def _setup_devices(self) -> "torch.device":
f"Using the `Trainer` with `PyTorch` requires `accelerate>={ACCELERATE_MIN_VERSION}`: "
"Please run `pip install transformers[torch]` or `pip install accelerate -U`"
)
use_configured_state = False
if isinstance(self.accelerator_config, AcceleratorConfig):
use_configured_state = self.accelerator_config.pop("use_configured_state", False)
if use_configured_state:
if AcceleratorState._shared_state == {}:
raise ValueError(
"Passing `'use_configured_state':True` to the AcceleratorConfig requires a pre-configured "
"`AcceleratorState` or `PartialState` to be defined before calling `TrainingArguments`. "
)
else:
AcceleratorState._reset_state(reset_partial_state=True)
self.distributed_state = None
# We delay the init of `PartialState` to the end for clarity
accelerator_state_kwargs = {"enabled": True}
use_configured_accelerator_state = False
if isinstance(self.accelerator_config, AcceleratorConfig):
use_configured_accelerator_state = self.accelerator_config.pop("use_configured_state", False)
if use_configured_accelerator_state and PartialState._shared_state == {}:
raise ValueError(
"Passing `'use_configured_state':True` to the AcceleratorConfig requires a pre-configured "
"`AcceleratorState` or `PartialState` to be defined before calling `TrainingArguments`. "
)
if use_configured_accelerator_state:
self.distributed_state = PartialState()
else:
AcceleratorState._reset_state(reset_partial_state=True)
self.distributed_state = None
if not self.use_ipex and "ACCELERATE_USE_IPEX" not in os.environ:
os.environ["ACCELERATE_USE_IPEX"] = "false"

self._n_gpu = 1
if self.use_cpu or strtobool(os.environ.get("ACCELERATE_USE_CPU", "False")):
self.distributed_state = PartialState(cpu=True, backend=self.ddp_backend)
accelerator_state_kwargs["cpu"] = True
accelerator_state_kwargs["backend"] = self.ddp_backend
self._n_gpu = 0
elif is_sagemaker_mp_enabled():
accelerator_state_kwargs["enabled"] = False
local_rank = smp.local_rank()
device = torch.device("cuda", local_rank)
self._n_gpu = 1
torch.cuda.set_device(device)
elif is_sagemaker_dp_enabled():
self.distributed_state = PartialState(_use_sagemaker_dp=True)
self._n_gpu = 1
accelerator_state_kwargs["_use_sagemaker_dp"] = True
elif self.deepspeed:
# Need to do similar for Accelerator init
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
self.distributed_state = PartialState(timeout=timedelta(seconds=self.ddp_timeout))
del os.environ["ACCELERATE_USE_DEEPSPEED"]
self._n_gpu = 1
accelerator_state_kwargs["use_deepspeed"] = True
accelerator_state_kwargs["timeout"] = timedelta(seconds=self.ddp_timeout)
else:
self.distributed_state = PartialState(
backend=self.ddp_backend, timeout=timedelta(seconds=self.ddp_timeout)
)
self._n_gpu = 1
accelerator_state_kwargs["backend"] = self.ddp_backend
accelerator_state_kwargs["timeout"] = timedelta(seconds=self.ddp_timeout)

accelerator_state_enabled = accelerator_state_kwargs.pop("enabled", False)
use_deepspeed = accelerator_state_kwargs.pop("use_deepspeed", False)
if accelerator_state_enabled:
# We need to patch this env var when enabling to detect deepspeed
if use_deepspeed:
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
self.distributed_state = PartialState(**accelerator_state_kwargs)
if use_deepspeed:
del os.environ["ACCELERATE_USE_DEEPSPEED"]
if not is_sagemaker_mp_enabled():
device = self.distributed_state.device
self.local_rank = self.distributed_state.local_process_index
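
The `use_configured_state` branch above requires the shared state to exist before the arguments are built. A hedged sketch of that flow, assuming `use_configured_state` is an accepted `AcceleratorConfig` key on this branch and using a placeholder `output_dir`:

from accelerate import PartialState
from transformers import TrainingArguments

# Pre-configure the shared state; without this, the ValueError above is raised.
PartialState()

args = TrainingArguments(
    output_dir="out",
    accelerator_config={"use_configured_state": True},
)
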
@@ -2095,23 +2107,17 @@ def _setup_devices(self) -> "torch.device":
"Either you do not have an MPS-enabled device on this machine or MacOS version is not 12.3+ "
"or current PyTorch install was not built with MPS enabled."
)
if device.type == "mps":
self._n_gpu = 1
elif self.use_cpu:
if self.use_cpu:
device = torch.device("cpu")
self._n_gpu = 0
elif is_torch_xpu_available():
device = torch.device("xpu:0")
torch.xpu.set_device(device)
self._n_gpu = 1
elif is_torch_mlu_available():
device = torch.device("mlu:0")
torch.mlu.set_device(device)
self._n_gpu = 1
elif is_torch_npu_available():
device = torch.device("npu:0")
torch.npu.set_device(device)
self._n_gpu = 1
else:
# if n_gpu is > 1 we'll use nn.DataParallel.
# If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`
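The core pattern of this refactor, distilled into a standalone hedged sketch: each branch only records keyword arguments, and `PartialState` is constructed exactly once at the end, with the DeepSpeed environment variable patched only around that single call. Flag values and the timeout are placeholders, not the Trainer's defaults.

import os
from datetime import timedelta

from accelerate import PartialState

accelerator_state_kwargs = {"enabled": True}
launching_with_deepspeed = False  # placeholder for `self.deepspeed`

if launching_with_deepspeed:
    accelerator_state_kwargs["use_deepspeed"] = True
    accelerator_state_kwargs["timeout"] = timedelta(seconds=1800)
else:
    accelerator_state_kwargs["backend"] = None  # let Accelerate pick the backend
    accelerator_state_kwargs["timeout"] = timedelta(seconds=1800)

enabled = accelerator_state_kwargs.pop("enabled", False)
use_deepspeed = accelerator_state_kwargs.pop("use_deepspeed", False)
if enabled:
    if use_deepspeed:
        # Set only around the single init so DeepSpeed is detected, then removed.
        os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
    state = PartialState(**accelerator_state_kwargs)
    if use_deepspeed:
        del os.environ["ACCELERATE_USE_DEEPSPEED"]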
