Fix GA loss bugs and add unit test (#35121)
* fix GA bugs and add unit test

* narrow down model loss unit test diff gap

* format code to make ruff happy

* send num_items_in_batch argument to decoder

* fix GA loss bug in BertLMHeadModel

* use TinyStories-33M to narrow down diff gap

* format code

* missing .config

* avoid add extra args

---------

Co-authored-by: kangsheng <[email protected]>
techkang and kangsheng authored Dec 9, 2024
1 parent c8c8dff commit 1ccca8f
Showing 4 changed files with 108 additions and 24 deletions.
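
For context on what these commits fix: under gradient accumulation, averaging each micro-batch's mean loss is not the same as summing the token losses and dividing by the total number of label tokens, whenever the micro-batches contain different numbers of real tokens. A minimal illustration of the discrepancy with made-up numbers (not code from this commit):

```python
import torch

# Two micro-batches from one accumulated batch, with different numbers of real
# (non-padding) label tokens. The per-token loss values are made up.
per_token_losses = [
    torch.tensor([2.0, 1.0, 3.0]),  # micro-batch with 3 label tokens
    torch.tensor([0.5]),            # micro-batch with 1 label token
]

# Buggy normalization: mean per micro-batch, then mean across accumulation steps.
mean_of_means = torch.stack([t.mean() for t in per_token_losses]).mean()

# Correct normalization: sum all token losses and divide by the total number of
# label tokens -- this is what passing num_items_in_batch enables.
num_items_in_batch = sum(t.numel() for t in per_token_losses)
sum_over_items = torch.cat(per_token_losses).sum() / num_items_in_batch

print(mean_of_means.item())   # 1.25
print(sum_over_items.item())  # 1.625 -- different loss, therefore different gradients
```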
7 changes: 2 additions & 5 deletions src/transformers/models/bert/modeling_bert.py
@@ -1325,6 +1325,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**loss_kwargs,
) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
@@ -1375,11 +1376,7 @@ def forward(

lm_loss = None
if labels is not None:
- # we are doing next-token prediction; shift prediction scores and input ids by one
- shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
- labels = labels[:, 1:].contiguous()
- loss_fct = CrossEntropyLoss()
- lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+ lm_loss = self.loss_function(prediction_scores, labels, self.config.vocab_size, **loss_kwargs)

if not return_dict:
output = (prediction_scores,) + outputs[2:]
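
The `self.loss_function` call that replaces the in-place cross-entropy above delegates to the shared causal-LM loss helper, which still shifts scores and labels for next-token prediction but can normalize by `num_items_in_batch`. A simplified sketch of that behavior (not the exact library implementation), assuming `-100` marks ignored label positions:

```python
import torch.nn.functional as F

def causal_lm_loss(logits, labels, vocab_size, num_items_in_batch=None, ignore_index=-100):
    # Next-token prediction: shift by one position, exactly as the removed in-place code did.
    logits = logits[:, :-1, :].contiguous().view(-1, vocab_size)
    labels = labels[:, 1:].contiguous().view(-1)

    if num_items_in_batch is None:
        # Plain per-token mean: fine without gradient accumulation.
        return F.cross_entropy(logits, labels, ignore_index=ignore_index)

    # Sum over tokens, then divide by the label-token count of the whole
    # accumulated batch so every micro-batch contributes proportionally.
    loss = F.cross_entropy(logits, labels, ignore_index=ignore_index, reduction="sum")
    return loss / num_items_in_batch
```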
@@ -491,6 +491,8 @@ def forward(
kwargs_decoder = {
argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
}
if "num_items_in_batch" in kwargs_encoder:
kwargs_decoder["num_items_in_batch"] = kwargs_encoder.pop("num_items_in_batch", None)

if encoder_outputs is None:
if inputs is None:
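
The two lines added in the second changed file (its name is missing from this capture, but the `kwargs_encoder` / `kwargs_decoder` split marks it as a composite encoder-decoder model) exist because the Trainer passes `num_items_in_batch` without a `decoder_` prefix, so the prefix-based kwargs split would leave it on the encoder side, while it is the decoder that computes the LM loss. A rough sketch of the splitting logic (illustrative helper, not the model's actual code):

```python
def split_composite_kwargs(kwargs):
    """Illustrative: split kwargs between encoder and decoder by the decoder_ prefix."""
    kwargs_encoder = {k: v for k, v in kwargs.items() if not k.startswith("decoder_")}
    kwargs_decoder = {k[len("decoder_"):]: v for k, v in kwargs.items() if k.startswith("decoder_")}

    # The fix: hand the token count to the decoder, which owns the loss computation.
    if "num_items_in_batch" in kwargs_encoder:
        kwargs_decoder["num_items_in_batch"] = kwargs_encoder.pop("num_items_in_batch")

    return kwargs_encoder, kwargs_decoder


enc_kwargs, dec_kwargs = split_composite_kwargs(
    {"output_attentions": True, "decoder_use_cache": False, "num_items_in_batch": 512}
)
print(dec_kwargs)  # {'use_cache': False, 'num_items_in_batch': 512}
```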
9 changes: 1 addition & 8 deletions src/transformers/trainer.py
@@ -3649,10 +3649,7 @@ def training_step(
return loss_mb.reduce_mean().detach().to(self.args.device)

with self.compute_loss_context_manager():
- if self.model_accepts_loss_kwargs:
- loss = self.compute_loss(model, inputs)
- else:
- loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
+ loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)

del inputs
if (
@@ -5132,10 +5129,6 @@ def get_batch_samples(self, epoch_iterator, num_batches):
except StopIteration:
break

- # Keep default behavior the same
- if not self.model_accepts_loss_kwargs:
- return batch_samples, None

if len(batch_samples) > 0 and "labels" in batch_samples[0]:
# For now we don't support object detection
try:
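
With the two trainer changes above, `compute_loss` always receives `num_items_in_batch`, and `get_batch_samples` no longer bails out early for models that do not advertise loss kwargs. The count itself is conceptually simple: the number of real label tokens across all micro-batches of one optimizer step. A minimal sketch, assuming the usual `-100` ignore index (not the trainer's exact code):

```python
import torch

def count_items_in_batch(batch_samples, ignore_index=-100):
    """batch_samples: the micro-batches fetched for one gradient-accumulation cycle."""
    total = 0
    for batch in batch_samples:
        labels = batch.get("labels")
        if labels is not None:
            total += int((labels != ignore_index).sum())
    return total if total > 0 else None  # None when nothing carries labels

samples = [
    {"labels": torch.tensor([[5, 6, -100]])},
    {"labels": torch.tensor([[7, -100, -100]])},
]
print(count_items_in_batch(samples))  # 3
```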
114 changes: 103 additions & 11 deletions tests/trainer/test_trainer.py
@@ -750,11 +750,102 @@ def test_model_init(self):
self.check_trained_model(trainer.model, alternate_seed=True)

@slow
def test_gradient_accumulation_loss_alignment(self):
def test_gradient_accumulation_loss_alignment_with_model_loss(self):
set_seed(42)
import datasets

model_name = "distilgpt2"
model_name = "nickypro/tinyllama-110M"
dataset_name = "wikitext"
dataset_config = "wikitext-2-raw-v1"
dataset = datasets.load_dataset(dataset_name, dataset_config, split="train[:500]")
dataset = dataset.train_test_split(test_size=0.2)
tokenizer = AutoTokenizer.from_pretrained(model_name)

tokenizer.pad_token = tokenizer.eos_token

def tokenize_function(examples):
return tokenizer(examples["text"], max_length=128, padding="max_length", truncation=True)

tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset["train"].column_names)

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

model = AutoModelForCausalLM.from_pretrained(model_name)

base_loss_callback = StoreLossCallback()

args_kwargs = {
"report_to": "none",
"logging_steps": 1,
"max_steps": 20,
"learning_rate": 3e-4,
"disable_tqdm": True,
}

args = TrainingArguments(
"./generation",
**args_kwargs,
)
trainer = Trainer(
model,
args,
train_dataset=tokenized_dataset["train"],
callbacks=[base_loss_callback],
data_collator=data_collator,
)
assert trainer.model_accepts_loss_kwargs
trainer.train()

grad_accum_loss_callback = StoreLossCallback()
args = TrainingArguments(
"./generation",
**args_kwargs,
gradient_accumulation_steps=2,
per_device_train_batch_size=4,
)
set_seed(42)
model = AutoModelForCausalLM.from_pretrained(model_name)
trainer = Trainer(
model,
args,
train_dataset=tokenized_dataset["train"],
callbacks=[grad_accum_loss_callback],
data_collator=data_collator,
)
trainer.train()

set_seed(42)
model = AutoModelForCausalLM.from_pretrained(model_name)
broken_loss_callback = StoreLossCallback()
trainer = Trainer(
model,
args,
train_dataset=tokenized_dataset["train"],
callbacks=[broken_loss_callback],
data_collator=data_collator,
)
# disable model_accepts_loss_kwargs
trainer.model_accepts_loss_kwargs = False
trainer.train()

# Calculate the difference between the base loss and the grad_accum loss
diff_truth = [
abs(base - grad) for base, grad in zip(base_loss_callback.losses, grad_accum_loss_callback.losses)
]
diff_broken = [abs(base - grad) for base, grad in zip(base_loss_callback.losses, broken_loss_callback.losses)]

# all diff truth should be quite close
self.assertLess(max(diff_truth), 0.01, f"Difference {max(diff_truth)} is not within 0.01")

# max diff broken should be very off
self.assertGreater(max(diff_broken), 3, f"Difference {max(diff_broken)} is not greater than 3")

@slow
def test_gradient_accumulation_loss_alignment_with_loss_func(self):
set_seed(42)
import datasets

model_name = "roneneldan/TinyStories-33M"
dataset_name = "wikitext"
dataset_config = "wikitext-2-raw-v1"
dataset = datasets.load_dataset(dataset_name, dataset_config, split="train[:500]")
@@ -836,15 +927,16 @@ def compute_loss(logits, labels, vocab_size, num_items_in_batch, disable_num_ite
trainer.train()

# Calculate the difference between the base loss and the grad_accum loss
- diff_truth = [base - grad for base, grad in zip(base_loss_callback.losses, grad_accum_loss_callback.losses)]
- diff_broken = [base - grad for base, grad in zip(base_loss_callback.losses, broken_loss_callback.losses)]
- # These should be quite close
- for diff in diff_truth:
- self.assertLess(abs(diff), 0.1, f"Difference {diff} is not within 0.1")
-
- # These should be very off
- for diff in diff_broken:
- self.assertGreater(abs(diff), 0.1, f"Difference {diff} is not greater than 0.1")
+ diff_truth = [
+ abs(base - grad) for base, grad in zip(base_loss_callback.losses, grad_accum_loss_callback.losses)
+ ]
+ diff_broken = [abs(base - grad) for base, grad in zip(base_loss_callback.losses, broken_loss_callback.losses)]
+
+ # all diff truth should be quite close
+ self.assertLess(max(diff_truth), 0.01, f"Difference {max(diff_truth)} is not within 0.01")
+
+ # max diff broken should be very off
+ self.assertGreater(max(diff_broken), 3, f"Difference {max(diff_broken)} is not greater than 3")

def test_gradient_accumulation(self):
# Training with half the batch size but accumulation steps as 2 should give the same training losses.
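
Both tests rely on a `StoreLossCallback` to collect the per-step training losses that the assertions compare; its implementation is not part of this diff, but a hypothetical reconstruction using the standard `TrainerCallback.on_log` hook would look roughly like this:

```python
from transformers import TrainerCallback

class StoreLossCallback(TrainerCallback):
    """Hypothetical sketch: record the training loss reported at every logging step."""

    def __init__(self):
        self.losses = []

    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs is not None and "loss" in logs:
            self.losses.append(logs["loss"])
```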
