FEAT [Trainer / bnb]: Add RMSProp from bitsandbytes to HF `Trainer` (#29082)

* add RMSProp to Trainer

* revert some change

* Update src/transformers/trainer.py

Co-authored-by: amyeroberts <[email protected]>

---------

Co-authored-by: amyeroberts <[email protected]>
younesbelkada and amyeroberts authored Feb 20, 2024
1 parent a7ff2f2 commit f7ef7ce
Showing 3 changed files with 67 additions and 2 deletions.
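
In user code, the new optimizers are opted into through `TrainingArguments(optim=...)`, exactly as the tests added below exercise. A minimal end-to-end sketch, assuming bitsandbytes is installed; the tiny GPT-2 config mirrors the tests, while `TinyDataset` and the `out` directory are illustrative stand-ins, not part of the commit:

import torch
from torch.utils.data import Dataset
from transformers import GPT2Config, GPT2LMHeadModel, Trainer, TrainingArguments


class TinyDataset(Dataset):
    # Stand-in for the test suite's RepeatDataset: the same token ids every step.
    def __init__(self, x, length=64):
        self.x, self.length = x, length

    def __len__(self):
        return self.length

    def __getitem__(self, i):
        return {"input_ids": self.x, "labels": self.x}


config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
model = GPT2LMHeadModel(config)
dataset = TinyDataset(torch.randint(0, 100, (128,)))

# Any of the three new strings works: "rmsprop_bnb", "rmsprop_bnb_8bit", "rmsprop_bnb_32bit".
args = TrainingArguments(output_dir="out", max_steps=10, optim="rmsprop_bnb_8bit")
Trainer(model=model, args=args, train_dataset=dataset).train()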
15 changes: 13 additions & 2 deletions src/transformers/trainer.py
@@ -1084,9 +1084,12 @@ def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:
             OptimizerNames.LION_8BIT,
             OptimizerNames.PAGED_LION,
             OptimizerNames.PAGED_LION_8BIT,
+            OptimizerNames.RMSPROP_BNB,
+            OptimizerNames.RMSPROP_8BIT,
+            OptimizerNames.RMSPROP_32BIT,
         ]:
             try:
-                from bitsandbytes.optim import AdamW, Lion
+                from bitsandbytes.optim import AdamW, Lion, RMSprop
 
                 is_paged = False
                 optim_bits = 32
@@ -1101,8 +1104,16 @@ def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:
                 elif "lion" in args.optim:
                     optimizer_cls = Lion
                     additional_optim_kwargs = {"betas": (args.adam_beta1, args.adam_beta2)}
+                elif "rmsprop" in args.optim:
+                    optimizer_cls = RMSprop
+                    # Above we pass all `adam_kwargs` to the optimizer, here
+                    # we only pass `optim_args` which can be passed by the user.
+                    additional_optim_kwargs = optim_args
+
+                bnb_kwargs = {"optim_bits": optim_bits}
+                if "rmsprop" not in args.optim:
+                    bnb_kwargs["is_paged"] = is_paged
 
-                bnb_kwargs = {"is_paged": is_paged, "optim_bits": optim_bits}
                 optimizer_kwargs.update(additional_optim_kwargs)
                 optimizer_kwargs.update(bnb_kwargs)
             except ImportError:
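The kwarg split above is the functional heart of the change: bitsandbytes' `RMSprop` constructor does not accept the `is_paged` flag that the bnb Adam and Lion families take, so `is_paged` is only forwarded for non-RMSprop optimizers. A standalone sketch of that dispatch under the same assumption (`pick_bnb_optimizer` is a hypothetical helper, not transformers API):

from bitsandbytes.optim import AdamW, Lion, RMSprop


def pick_bnb_optimizer(optim: str):
    # Mirrors get_optimizer_cls_and_kwargs for the bnb family (sketch only).
    optim_bits = 8 if "8bit" in optim else 32
    if "adam" in optim:
        cls = AdamW
    elif "lion" in optim:
        cls = Lion
    else:
        cls = RMSprop
    kwargs = {"optim_bits": optim_bits}
    if "rmsprop" not in optim:
        # bnb's RMSprop takes no `is_paged` argument, hence the guard.
        kwargs["is_paged"] = "paged" in optim
    return cls, kwargs


# e.g. pick_bnb_optimizer("rmsprop_bnb_8bit") -> (RMSprop, {"optim_bits": 8})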
3 changes: 3 additions & 0 deletions src/transformers/training_args.py
@@ -157,6 +157,9 @@ class OptimizerNames(ExplicitEnum):
     PAGED_LION = "paged_lion_32bit"
     PAGED_LION_8BIT = "paged_lion_8bit"
     RMSPROP = "rmsprop"
+    RMSPROP_BNB = "rmsprop_bnb"
+    RMSPROP_8BIT = "rmsprop_bnb_8bit"
+    RMSPROP_32BIT = "rmsprop_bnb_32bit"
 
 
 # TODO: `TrainingArguments` users rely on it being fully mutable. In the future see if we can narrow this to a few keys: https://github.com/huggingface/transformers/pull/25903
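Because `ExplicitEnum` members are string-valued, `TrainingArguments(optim=...)` accepts either an enum member or its string. A quick sanity check of the new values (the import path matches the file edited here):

from transformers.training_args import OptimizerNames

# String -> member lookup works because ExplicitEnum subclasses str.
assert OptimizerNames("rmsprop_bnb_8bit") is OptimizerNames.RMSPROP_8BIT
assert OptimizerNames.RMSPROP_BNB.value == "rmsprop_bnb"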
51 changes: 51 additions & 0 deletions tests/trainer/test_trainer.py
@@ -58,6 +58,7 @@
     get_tests_dir,
     is_staging_test,
     require_accelerate,
+    require_bitsandbytes,
     require_deepspeed,
     require_intel_extension_for_pytorch,
     require_optuna,
@@ -872,6 +873,56 @@ def test_number_of_steps_in_training_with_ipex(self):
             train_output = trainer.train()
             self.assertEqual(train_output.global_step, 10)
 
+    @require_bitsandbytes
+    def test_rmsprop_bnb(self):
+        config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
+        tiny_gpt2 = GPT2LMHeadModel(config)
+        x = torch.randint(0, 100, (128,))
+        train_dataset = RepeatDataset(x)
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            # Trainer without inf/nan filter
+            args = TrainingArguments(
+                tmpdir, learning_rate=1e-9, logging_steps=5, logging_nan_inf_filter=False, optim="rmsprop_bnb"
+            )
+            trainer = Trainer(tiny_gpt2, args, train_dataset=train_dataset)
+
+            # Check that it trains without errors
+            trainer.train()
+
+    @require_bitsandbytes
+    def test_rmsprop_bnb_8bit(self):
+        config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
+        tiny_gpt2 = GPT2LMHeadModel(config)
+        x = torch.randint(0, 100, (128,))
+        train_dataset = RepeatDataset(x)
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            # Trainer without inf/nan filter
+            args = TrainingArguments(
+                tmpdir, learning_rate=1e-9, logging_steps=5, logging_nan_inf_filter=False, optim="rmsprop_bnb_8bit"
+            )
+            trainer = Trainer(tiny_gpt2, args, train_dataset=train_dataset)
+
+            # Check that it trains without errors
+            trainer.train()
+
+    @require_bitsandbytes
+    def test_rmsprop_bnb_32bit(self):
+        config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
+        tiny_gpt2 = GPT2LMHeadModel(config)
+        x = torch.randint(0, 100, (128,))
+        train_dataset = RepeatDataset(x)
+        with tempfile.TemporaryDirectory() as tmpdir:
+            # Trainer without inf/nan filter
+            args = TrainingArguments(
+                tmpdir, learning_rate=1e-9, logging_steps=5, logging_nan_inf_filter=False, optim="rmsprop_bnb_32bit"
+            )
+            trainer = Trainer(tiny_gpt2, args, train_dataset=train_dataset)
+
+            # Check that it trains without errors
+            trainer.train()
+
     def test_neftune(self):
         config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
         tiny_gpt2 = GPT2LMHeadModel(config)
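The three new tests differ only in the `optim` string. A hypothetical consolidation with `parameterized` (a sketch assuming the same test-class context and imports; not what the commit does) could read:

from parameterized import parameterized

@parameterized.expand(["rmsprop_bnb", "rmsprop_bnb_8bit", "rmsprop_bnb_32bit"])
@require_bitsandbytes
def test_rmsprop_bnb_variants(self, optim):
    config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
    tiny_gpt2 = GPT2LMHeadModel(config)
    train_dataset = RepeatDataset(torch.randint(0, 100, (128,)))
    with tempfile.TemporaryDirectory() as tmpdir:
        args = TrainingArguments(
            tmpdir, learning_rate=1e-9, logging_steps=5, logging_nan_inf_filter=False, optim=optim
        )
        # Check that each variant trains without errors.
        Trainer(tiny_gpt2, args, train_dataset=train_dataset).train()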
