Support DeepSpeed when using auto find batch size (#28088)
Fixup test
muellerzr authored Jan 10, 2024
1 parent a777f52 commit 6015d0a
Showing 3 changed files with 84 additions and 17 deletions.

src/transformers/integrations/deepspeed.py (10 additions, 3 deletions)
@@ -129,7 +129,7 @@ def fill_match(self, ds_key_long, hf_val, hf_key=None, must_match=True):
 
     fill_only = partialmethod(fill_match, must_match=False)
 
-    def trainer_config_process(self, args):
+    def trainer_config_process(self, args, auto_find_batch_size=False):
         """
         Adjust the config with `TrainingArguments` values. This stage is run during `TrainingArguments` object
         creation.
@@ -138,10 +138,15 @@ def trainer_config_process(self, args):
         # train_batch_size = world_size * train_micro_batch_size_per_gpu * gradient_accumulation_steps
         train_batch_size = args.world_size * args.per_device_train_batch_size * args.gradient_accumulation_steps
         self.fill_match(
-            "train_micro_batch_size_per_gpu", args.per_device_train_batch_size, "per_device_train_batch_size"
+            "train_micro_batch_size_per_gpu",
+            args.per_device_train_batch_size,
+            "per_device_train_batch_size",
+            not auto_find_batch_size,
         )
         self.fill_match("gradient_accumulation_steps", args.gradient_accumulation_steps, "gradient_accumulation_steps")
-        self.fill_match("train_batch_size", train_batch_size, "train_batch_size (calculated)")
+        self.fill_match(
+            "train_batch_size", train_batch_size, "train_batch_size (calculated)", not auto_find_batch_size
+        )
         self.fill_match("gradient_clipping", args.max_grad_norm, "max_grad_norm")
 
         self.fill_match("optimizer.params.lr", args.learning_rate, "learning_rate")
@@ -336,6 +341,8 @@ def deepspeed_init(trainer, num_training_steps, inference=False):
         num_training_steps: per single gpu
         resume_from_checkpoint: path to a checkpoint if to resume from after normal DeepSpeedEngine load
         inference: launch in inference mode (no optimizer and no lr scheduler)
+        auto_find_batch_size: whether to ignore the `train_micro_batch_size_per_gpu` argument as it's being
+            set automatically by the auto batch size finder
 
     Returns: optimizer, lr_scheduler
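For context, passing `not auto_find_batch_size` as the `must_match` argument above means the two batch-size keys are filled without mismatch checking once the auto batch-size finder is in play, so a lowered Trainer-side batch size is not reported as a conflict with the DeepSpeed config. A minimal sketch of the kind of user config this path targets (the same shape as the new test below; the ZeRO stage is illustrative):

# Hypothetical user-side DeepSpeed config: the batch-size keys are deferred to the
# Trainer via "auto" and filled in by trainer_config_process() from TrainingArguments;
# with auto_find_batch_size=True they are filled with must_match=False.
ds_config = {
    "zero_optimization": {"stage": 1},
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
}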

src/transformers/trainer.py (28 additions, 14 deletions)
@@ -1506,13 +1506,9 @@ def train(
             if resume_from_checkpoint is None:
                 raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")
 
-        if (
-            resume_from_checkpoint is not None
-            and not is_sagemaker_mp_enabled()
-            and not self.is_deepspeed_enabled
-            and not self.is_fsdp_enabled
-        ):
-            self._load_from_checkpoint(resume_from_checkpoint)
+        if resume_from_checkpoint is not None:
+            if not is_sagemaker_mp_enabled() and not self.is_deepspeed_enabled and not self.is_fsdp_enabled:
+                self._load_from_checkpoint(resume_from_checkpoint)
             # In case of repeating the find_executable_batch_size, set `self._train_batch_size` properly
             state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
             if state.train_batch_size is not None:
@@ -1553,6 +1549,19 @@ def _inner_training_loop(
         self.accelerator.free_memory()
         self._train_batch_size = batch_size
         if self.args.auto_find_batch_size:
+            if self.state.train_batch_size != self._train_batch_size:
+                from accelerate.utils import release_memory
+
+                (self.model_wrapped,) = release_memory(self.model_wrapped)
+                self.model_wrapped = self.model
+
+                # Check for DeepSpeed *after* the initial pass and modify the config
+                if self.is_deepspeed_enabled:
+                    # Temporarily unset `self.args.train_batch_size`
+                    original_bs = self.args.per_device_train_batch_size
+                    self.args.per_device_train_batch_size = self._train_batch_size // max(1, self.args.n_gpu)
+                    self.propagate_args_to_deepspeed(True)
+                    self.args.per_device_train_batch_size = original_bs
             self.state.train_batch_size = self._train_batch_size
         logger.debug(f"Currently training with a batch size of: {self._train_batch_size}")
         # Data loader and number of training steps
@@ -3944,12 +3953,17 @@ def create_accelerator_and_postprocess(self):
                 "when using FSDP."
             )
 
-        if self.is_deepspeed_enabled:
-            if getattr(self.args, "hf_deepspeed_config", None) is None:
-                from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig
-
-                ds_plugin = self.accelerator.state.deepspeed_plugin
-
-                ds_plugin.hf_ds_config = HfTrainerDeepSpeedConfig(ds_plugin.hf_ds_config.config)
-                ds_plugin.deepspeed_config = ds_plugin.hf_ds_config.config
-                ds_plugin.hf_ds_config.trainer_config_process(self.args)
+        if self.is_deepspeed_enabled and getattr(self.args, "hf_deepspeed_config", None) is None:
+            self.propagate_args_to_deepspeed()
+
+    def propagate_args_to_deepspeed(self, auto_find_batch_size=False):
+        """
+        Sets values in the deepspeed plugin based on the Trainer args
+        """
+        from transformers.integrations.deepspeed import HfTrainerDeepSpeedConfig
+
+        ds_plugin = self.accelerator.state.deepspeed_plugin
+
+        ds_plugin.hf_ds_config = HfTrainerDeepSpeedConfig(ds_plugin.hf_ds_config.config)
+        ds_plugin.deepspeed_config = ds_plugin.hf_ds_config.config
+        ds_plugin.hf_ds_config.trainer_config_process(self.args, auto_find_batch_size)
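For context, the re-sync in `_inner_training_loop` above exists because `auto_find_batch_size` retries the inner loop with a halved batch size after a CUDA out-of-memory error (via `find_executable_batch_size`), and on each retry the new per-device value has to be pushed back into the DeepSpeed plugin through `propagate_args_to_deepspeed(True)`. A minimal standalone sketch of that retry mechanism, using Accelerate's public utility (illustrative only, not code from this commit):

from accelerate.utils import find_executable_batch_size


@find_executable_batch_size(starting_batch_size=16)
def training_loop(batch_size):
    # Build the dataloaders and train with `batch_size` here; if a CUDA OOM is
    # raised, the decorator retries this function with the batch size halved.
    print(f"trying batch_size={batch_size}")


training_loop()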

tests/trainer/test_trainer.py (46 additions, 0 deletions)
@@ -57,6 +57,7 @@
     get_tests_dir,
     is_staging_test,
     require_accelerate,
+    require_deepspeed,
     require_intel_extension_for_pytorch,
     require_optuna,
     require_ray,
@@ -1551,6 +1552,51 @@ def test_auto_batch_size_finder(self):
         with patch.object(sys, "argv", testargs):
             run_glue.main()
 
+    @require_deepspeed
+    def test_auto_batch_size_with_resume_from_checkpoint_with_deepspeed(self):
+        train_dataset = RegressionDataset(length=128)
+
+        config = RegressionModelConfig(a=0, b=2)
+        model = RegressionRandomPreTrainedModel(config)
+
+        tmp_dir = self.get_auto_remove_tmp_dir()
+
+        class MockCudaOOMCallback(TrainerCallback):
+            def on_step_end(self, args, state, control, **kwargs):
+                # simulate OOM on the first step
+                if state.train_batch_size >= 16:
+                    raise RuntimeError("CUDA out of memory.")
+
+        deepspeed = {
+            "zero_optimization": {
+                "stage": 1,
+            },
+            "train_batch_size": "auto",
+            "train_micro_batch_size_per_gpu": "auto",
+        }
+
+        args = RegressionTrainingArguments(
+            tmp_dir,
+            do_train=True,
+            max_steps=2,
+            save_steps=1,
+            per_device_train_batch_size=16,
+            auto_find_batch_size=True,
+            deepspeed=deepspeed,
+        )
+        trainer = Trainer(model, args, train_dataset=train_dataset, callbacks=[MockCudaOOMCallback()])
+        trainer.train()
+        # After `auto_find_batch_size` is run, we should now be at 8
+        self.assertEqual(trainer._train_batch_size, 8)
+
+        # We can then make a new Trainer
+        trainer = Trainer(model, args, train_dataset=train_dataset)
+        # Check we are at 16 to start
+        self.assertEqual(trainer._train_batch_size, 16 * max(trainer.args.n_gpu, 1))
+        trainer.train(resume_from_checkpoint=True)
+        # We should be back to 8 again, picking up from the last run Trainer
+        self.assertEqual(trainer._train_batch_size, 8)
+
     def test_auto_batch_size_with_resume_from_checkpoint(self):
         train_dataset = RegressionDataset(length=128)
 
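A side note on the resume assertions in the new test: they rely on the effective batch size being persisted in the checkpoint's trainer_state.json and restored via `state.train_batch_size` in the `train()` hunk above, so the second Trainer drops back to 8 even though its arguments still request 16. A rough sketch of what gets read back (the checkpoint path is hypothetical):

import json
import os

# Hypothetical checkpoint directory; the Trainer writes trainer_state.json
# (TRAINER_STATE_NAME) into every checkpoint-* folder it saves.
with open(os.path.join("output_dir", "checkpoint-1", "trainer_state.json")) as f:
    state = json.load(f)

# After auto_find_batch_size halved 16 -> 8, this is the value a resumed
# Trainer adopts for self._train_batch_size.
print(state["train_batch_size"])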
