diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index cfac51a09d62c7..4994aef3af8133 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1084,9 +1084,12 @@ def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]: OptimizerNames.LION_8BIT, OptimizerNames.PAGED_LION, OptimizerNames.PAGED_LION_8BIT, + OptimizerNames.RMSPROP_BNB, + OptimizerNames.RMSPROP_8BIT, + OptimizerNames.RMSPROP_32BIT, ]: try: - from bitsandbytes.optim import AdamW, Lion + from bitsandbytes.optim import AdamW, Lion, RMSprop is_paged = False optim_bits = 32 @@ -1101,8 +1104,16 @@ def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]: elif "lion" in args.optim: optimizer_cls = Lion additional_optim_kwargs = {"betas": (args.adam_beta1, args.adam_beta2)} + elif "rmsprop" in args.optim: + optimizer_cls = RMSprop + # Above we pass all `adam_kwargs` to the optimizer, here + # we only pass `optim_args` which can be passed by the user. + additional_optim_kwargs = optim_args + + bnb_kwargs = {"optim_bits": optim_bits} + if "rmsprop" not in args.optim: + bnb_kwargs["is_paged"] = is_paged - bnb_kwargs = {"is_paged": is_paged, "optim_bits": optim_bits} optimizer_kwargs.update(additional_optim_kwargs) optimizer_kwargs.update(bnb_kwargs) except ImportError: diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 4ec9424396178f..19ab24c205cf72 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -157,6 +157,9 @@ class OptimizerNames(ExplicitEnum): PAGED_LION = "paged_lion_32bit" PAGED_LION_8BIT = "paged_lion_8bit" RMSPROP = "rmsprop" + RMSPROP_BNB = "rmsprop_bnb" + RMSPROP_8BIT = "rmsprop_bnb_8bit" + RMSPROP_32BIT = "rmsprop_bnb_32bit" # TODO: `TrainingArguments` users rely on it being fully mutable. In the future see if we can narrow this to a few keys: https://github.com/huggingface/transformers/pull/25903 diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 87e95a7ea396f7..b64e93a2d17494 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -58,6 +58,7 @@ get_tests_dir, is_staging_test, require_accelerate, + require_bitsandbytes, require_deepspeed, require_intel_extension_for_pytorch, require_optuna, @@ -872,6 +873,56 @@ def test_number_of_steps_in_training_with_ipex(self): train_output = trainer.train() self.assertEqual(train_output.global_step, 10) + @require_bitsandbytes + def test_rmsprop_bnb(self): + config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4) + tiny_gpt2 = GPT2LMHeadModel(config) + x = torch.randint(0, 100, (128,)) + train_dataset = RepeatDataset(x) + + with tempfile.TemporaryDirectory() as tmpdir: + # Trainer without inf/nan filter + args = TrainingArguments( + tmpdir, learning_rate=1e-9, logging_steps=5, logging_nan_inf_filter=False, optim="rmsprop_bnb" + ) + trainer = Trainer(tiny_gpt2, args, train_dataset=train_dataset) + + # Check that it trains without errors + trainer.train() + + @require_bitsandbytes + def test_rmsprop_bnb_8bit(self): + config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4) + tiny_gpt2 = GPT2LMHeadModel(config) + x = torch.randint(0, 100, (128,)) + train_dataset = RepeatDataset(x) + + with tempfile.TemporaryDirectory() as tmpdir: + # Trainer without inf/nan filter + args = TrainingArguments( + tmpdir, learning_rate=1e-9, logging_steps=5, logging_nan_inf_filter=False, optim="rmsprop_bnb_8bit" + ) + trainer = Trainer(tiny_gpt2, args, train_dataset=train_dataset) + + # Check that it trains without errors + trainer.train() + + @require_bitsandbytes + def test_rmsprop_bnb_32bit(self): + config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4) + tiny_gpt2 = GPT2LMHeadModel(config) + x = torch.randint(0, 100, (128,)) + train_dataset = RepeatDataset(x) + with tempfile.TemporaryDirectory() as tmpdir: + # Trainer without inf/nan filter + args = TrainingArguments( + tmpdir, learning_rate=1e-9, logging_steps=5, logging_nan_inf_filter=False, optim="rmsprop_bnb_32bit" + ) + trainer = Trainer(tiny_gpt2, args, train_dataset=train_dataset) + + # Check that it trains without errors + trainer.train() + def test_neftune(self): config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4) tiny_gpt2 = GPT2LMHeadModel(config)