Skip to content

Commit

Permalink
Add loss generating tokens for loss accumulation (#3677)
Browse files Browse the repository at this point in the history
  • Loading branch information
dakinggg authored Oct 23, 2024
1 parent ef42f54 commit 5aaa8c9
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 7 deletions.
19 changes: 17 additions & 2 deletions composer/core/data_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from __future__ import annotations

import collections.abc
import logging
import textwrap
import warnings
from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, Optional, Sequence, Union
Expand All @@ -20,6 +21,8 @@

__all__ = ['DataSpec', 'ensure_data_spec']

log = logging.getLogger(__name__)


def _split_list(l, microbatch_size: int):
if len(l) < microbatch_size:
Expand Down Expand Up @@ -185,14 +188,14 @@ def __init__(
device_transforms: Optional[Callable[[Batch], Batch]] = None,
split_batch: Optional[Callable[[Batch, Union[int, float]], Sequence[Batch]]] = None,
get_num_samples_in_batch: Optional[Callable[[Batch], Union[int, float]]] = None,
get_num_tokens_in_batch: Optional[Callable[[Batch], int]] = None,
get_num_tokens_in_batch: Optional[Callable[[Batch], Union[int, dict[str, int]]]] = None,
) -> None:
self.dataloader: Union[Iterable, torch.utils.data.DataLoader] = dataloader
self.num_tokens = num_tokens
self.device_transforms = self._default_device_transforms if device_transforms is None else device_transforms
self.split_batch = default_split_batch if split_batch is None else split_batch
self.get_num_samples_in_batch = self._default_get_num_samples_in_batch if get_num_samples_in_batch is None else get_num_samples_in_batch
self.get_num_tokens_in_batch = self._default_get_num_tokens_in_batch if get_num_tokens_in_batch is None else get_num_tokens_in_batch
self._get_num_tokens_in_batch = self._default_get_num_tokens_in_batch if get_num_tokens_in_batch is None else get_num_tokens_in_batch

if num_samples is not None:
self.num_samples = num_samples
Expand Down Expand Up @@ -295,6 +298,18 @@ def _default_get_num_tokens_in_batch(self, batch: Batch) -> int:
return self.dataloader.dataset.max_seq_len * samples_per_batch # type: ignore
return 0

def get_num_tokens_in_batch(self, batch: Batch, token_type: str = 'total') -> int:
    """Return the number of tokens of kind ``token_type`` found in ``batch``.

    The count is produced by the callable stored on ``self._get_num_tokens_in_batch``.
    That callable may return either a plain ``int`` or a ``dict`` of counts keyed by
    token type.

    Args:
        batch (Batch): The batch to count tokens in.
        token_type (str): Which entry to read when the underlying callable returns a
            ``dict``. Ignored when it returns a plain ``int``. (default: ``'total'``)

    Returns:
        int: The plain count, or the ``dict`` entry stored under ``token_type``.

    Raises:
        ValueError: If the callable returns a ``dict`` without ``token_type``, or a value
            that is neither an ``int`` nor a ``dict``.
    """
    token_counts = self._get_num_tokens_in_batch(batch)

    # A plain integer applies to every token type, so the requested type is irrelevant.
    if isinstance(token_counts, int):
        return token_counts

    if not isinstance(token_counts, dict):
        raise ValueError(f'Unexpected return type from get_num_tokens_in_batch: {type(token_counts)}')

    if token_type in token_counts:
        return token_counts[token_type]
    raise ValueError(f'Token type {token_type} not found in num_tokens dict.')


def ensure_data_spec(dataloader: Union[DataSpec, Iterable, dict]) -> DataSpec:
"""Ensures that the ``dataloader`` is a :class:`.DataSpec`.
Expand Down
16 changes: 12 additions & 4 deletions composer/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1027,7 +1027,10 @@ class Trainer:
it into sections of size ``device_train_microbatch_size``. If the batch size of the dataloader
is not divisible by ``device_train_microbatch_size``, the last section will be potentially smaller.
accumulate_train_batch_on_tokens (bool, optional): Whether training loss is accumulated over the number of tokens in a batch,
rather than the number of samples. Only works if the train data spec implements `get_num_tokens_in_batch`. (default: ``False``)
rather than the number of samples. Only works if the train data spec implements `get_num_tokens_in_batch`.
Note: When this flag is set, your `get_num_tokens_in_batch` function may optionally return a dictionary
with two keys (`total` and `loss_generating`). Composer will then accumulate the batch on the loss-generating token count,
while the `total` count is still used everywhere else a token count is needed. (default: ``False``)
seed (int, optional): The seed used in randomization. If ``None``, then a random seed
will be created. (default: ``None``)
Expand Down Expand Up @@ -3061,11 +3064,13 @@ def _train_microbatches(

# Tracker for gradient accumulation
if self.accumulate_train_batch_on_tokens:
current_batch_size = sum([self._train_data_spec.get_num_tokens_in_batch(b) for b in microbatches])
current_batch_size = sum([
self._train_data_spec.get_num_tokens_in_batch(b, token_type='loss_generating') for b in microbatches
])
if current_batch_size == 0:
raise ValueError(
textwrap.dedent(
'Requested loss accumulation based on number of tokens in training batch, '
'Requested loss accumulation based on number of loss generating tokens in training batch, '
'but zero tokens found (perhaps due to an improper DataSpec).',
),
)
Expand Down Expand Up @@ -3124,7 +3129,10 @@ def _train_microbatch(
device_batch = deepcopy(self.state.batch)

if self.accumulate_train_batch_on_tokens:
microbatch_size = self._train_data_spec.get_num_tokens_in_batch(self.state.batch)
microbatch_size = self._train_data_spec.get_num_tokens_in_batch(
self.state.batch,
token_type='loss_generating',
)
else:
microbatch_size = self._train_data_spec.get_num_samples_in_batch(self.state.batch)
if self.state.deepspeed_enabled or not isinstance(self.state.model, DistributedDataParallel):
Expand Down
84 changes: 84 additions & 0 deletions tests/test_simple_nlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,3 +217,87 @@ def test_simple_nlp_mlm_token_batch(tiny_bert_tokenizer, device):
trainer2.fit()
assert trainer2.state.train_metrics is not None
assert trainer2.state.train_metrics['LanguageCrossEntropy'].compute() == cross_entropy


@device('gpu')
def test_simple_nlp_mlm_loss_gen_token_batch(tiny_bert_tokenizer, device):
    """Check that accumulating on loss-generating tokens changes the training loss.

    Two trainers are fit from the same initial weights and seed on the same dataloader:
    one accumulating the loss over samples, and one accumulating over the
    ``loss_generating`` token count reported by the data spec. The resulting
    ``LanguageCrossEntropy`` train metrics must both be non-zero and must differ.
    """
    # Imported locally so this test does not depend on the module's import block.
    import copy

    transformers = pytest.importorskip('transformers')

    vocab_size = tiny_bert_tokenizer.vocab_size
    sequence_length = 32
    size = 96
    batch_size = 8
    device = get_device(device)

    train_dataset = RandomTextLMDataset(
        size=size,
        vocab_size=vocab_size,
        sequence_length=sequence_length,
        use_keys=True,
        pad_token_id=tiny_bert_tokenizer.pad_token_id,
    )
    for i in range(size):  # Proactively load dataset for consistent randomization
        train_dataset[i]
    collator = transformers.DataCollatorForLanguageModeling(tokenizer=tiny_bert_tokenizer)

    # Snapshot the model's weights before training starts, so we can reproduce results.
    # ``state_dict()`` hands back references to the live parameters, so take a deep copy;
    # otherwise in-place training updates could leak into the "original" weights that are
    # reloaded for the second trainer.
    model = SimpleTransformerMaskedLM(vocab_size=vocab_size)
    state_dict = copy.deepcopy(model.state_dict())

    def count_non_padding_tokens(batch):
        # Number of entries in ``input_ids`` that are not the tokenizer's padding token.
        return (batch['input_ids'] != tiny_bert_tokenizer.pad_token_id).sum().item()

    def count_total_and_loss_generating_tokens(batch):
        # Arbitrarily divide num tokens by 2 to simulate loss-generating tokens
        total = count_non_padding_tokens(batch)
        return {'total': total, 'loss_generating': total // 2}

    # Set up the data spec that can count the non-padding tokens in a batch
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        sampler=dist.get_sampler(train_dataset),
        collate_fn=collator,
    )
    data_spec = DataSpec(
        dataloader=train_dataloader,
        get_num_tokens_in_batch=count_non_padding_tokens,
    )

    # Same dataloader, but the token counter reports a dict with both token types
    loss_gen_data_spec = DataSpec(
        dataloader=train_dataloader,
        get_num_tokens_in_batch=count_total_and_loss_generating_tokens,
    )

    trainer = Trainer(
        model=model,
        seed=42,
        train_dataloader=data_spec,
        max_duration='2ep',
        device_train_microbatch_size=batch_size // 2,
        accumulate_train_batch_on_tokens=False,
        device=device,
    )
    trainer.fit()

    # Check that there is some train cross entropy
    assert trainer.state.train_metrics is not None
    cross_entropy = trainer.state.train_metrics['LanguageCrossEntropy'].compute()
    assert cross_entropy != 0.0

    # Set up a trainer that accumulates train loss based on token counts, after reloading original state dict
    model.load_state_dict(state_dict)
    token_trainer = Trainer(
        model=model,
        seed=42,
        train_dataloader=loss_gen_data_spec,
        max_duration='2ep',
        device_train_microbatch_size=batch_size // 2,
        accumulate_train_batch_on_tokens=True,
        device=device,
    )
    token_trainer.fit()

    # Check that there is some train cross entropy
    assert token_trainer.state.train_metrics is not None
    token_cross_entropy = token_trainer.state.train_metrics['LanguageCrossEntropy'].compute()
    assert token_cross_entropy != 0.0

    # Require that the train cross entropies are different between the trainers
    assert cross_entropy != token_cross_entropy
37 changes: 36 additions & 1 deletion tests/trainer/test_dataspec.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

from typing import Any
import contextlib
from typing import Any, Optional

import pytest
import torch
Expand Down Expand Up @@ -71,6 +72,40 @@ def test_get_num_tokens_hf_default(batch_size: int, sequence_length: int, use_ke
assert actual == expected


@pytest.mark.parametrize(
    'return_dict,requested_key,expected',
    [
        [True, None, 8],  # dict with default key
        [False, None, 8],  # int with default key
        [False, 'loss_generating', 8],  # int with non-default key
        [True, 'loss_generating', 4],  # dict with non-default key
        [False, 'unknown_key', 8],  # int return ignores the requested key
        [True, 'unknown_key', None],  # dict without the requested key must raise
    ],
)
def test_get_num_tokens_types(return_dict: bool, requested_key: Optional[str], expected: Optional[int]):
    """Check ``DataSpec.get_num_tokens_in_batch`` for int and dict token counters.

    Args:
        return_dict: Whether the user-supplied counter returns a dict of counts
            (``total`` and ``loss_generating``) instead of a plain int.
        requested_key: The ``token_type`` to request, or ``None`` to rely on the
            method's default.
        expected: The expected count, or ``None`` when a ``ValueError`` is expected.
    """
    # ``expected is None`` marks the cases where the lookup is supposed to fail.
    should_error = expected is None
    error_context = pytest.raises(ValueError) if should_error else contextlib.nullcontext()

    def get_num_tokens_in_batch(batch):
        num_tokens = 8
        num_loss_generating_tokens = 4

        if return_dict:
            return {'total': num_tokens, 'loss_generating': num_loss_generating_tokens}

        return num_tokens

    dataspec = DataSpec(dataloader=[], get_num_tokens_in_batch=get_num_tokens_in_batch)

    batch = {}
    # Only pass ``token_type`` when a key was requested, so the default is exercised too.
    extra_args = {}
    if requested_key is not None:
        extra_args['token_type'] = requested_key

    with error_context:
        actual = dataspec.get_num_tokens_in_batch(batch, **extra_args)
        assert actual == expected


def test_small_batch_at_end_warning():
batch_size = 4
dataset_size = 17
Expand Down

0 comments on commit 5aaa8c9

Please sign in to comment.