Loss gen tokens #3677

Merged · 12 commits · Oct 23, 2024
Changes from 11 commits
24 changes: 22 additions & 2 deletions composer/core/data_spec.py
@@ -5,6 +5,7 @@
from __future__ import annotations

import collections.abc
import logging
import textwrap
import warnings
from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, Optional, Sequence, Union
@@ -20,6 +21,8 @@

__all__ = ['DataSpec', 'ensure_data_spec']

log = logging.getLogger(__name__)


def _split_list(l, microbatch_size: int):
if len(l) < microbatch_size:
@@ -185,14 +188,14 @@ def __init__(
device_transforms: Optional[Callable[[Batch], Batch]] = None,
split_batch: Optional[Callable[[Batch, Union[int, float]], Sequence[Batch]]] = None,
get_num_samples_in_batch: Optional[Callable[[Batch], Union[int, float]]] = None,
get_num_tokens_in_batch: Optional[Callable[[Batch], int]] = None,
get_num_tokens_in_batch: Optional[Callable[[Batch], Union[int, dict[str, int]]]] = None,
) -> None:
self.dataloader: Union[Iterable, torch.utils.data.DataLoader] = dataloader
self.num_tokens = num_tokens
self.device_transforms = self._default_device_transforms if device_transforms is None else device_transforms
self.split_batch = default_split_batch if split_batch is None else split_batch
self.get_num_samples_in_batch = self._default_get_num_samples_in_batch if get_num_samples_in_batch is None else get_num_samples_in_batch
self.get_num_tokens_in_batch = self._default_get_num_tokens_in_batch if get_num_tokens_in_batch is None else get_num_tokens_in_batch
self._get_num_tokens_in_batch = self._default_get_num_tokens_in_batch if get_num_tokens_in_batch is None else get_num_tokens_in_batch

if num_samples is not None:
self.num_samples = num_samples
@@ -295,6 +298,23 @@ def _default_get_num_tokens_in_batch(self, batch: Batch) -> int:
return self.dataloader.dataset.max_seq_len * samples_per_batch # type: ignore
return 0

def get_num_tokens_in_batch(self, batch: Batch, token_type: str = 'total') -> int:
num_tokens = self._get_num_tokens_in_batch(batch)

if isinstance(num_tokens, int):
if token_type != 'total':
log.warning(
f'get_num_tokens_in_batch returned an int, but token_type is {token_type}. ' +
'Returning the total number of tokens in the batch.',
)
return num_tokens
elif isinstance(num_tokens, dict):
if token_type not in num_tokens:
raise ValueError(f'Token type {token_type} not found in num_tokens dict.')
return num_tokens[token_type]
else:
raise ValueError(f'Unexpected return type from get_num_tokens_in_batch: {type(num_tokens)}')


def ensure_data_spec(dataloader: Union[DataSpec, Iterable, dict]) -> DataSpec:
"""Ensures that the ``dataloader`` is a :class:`.DataSpec`.
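For reference, a minimal usage sketch of the new dispatch added above. The batch contents, the `-100` label-masking convention, and the counter body are illustrative assumptions, not part of this PR: a counter that returns a dict can be queried per token type, while an int-returning counter keeps working for the default `'total'` type.

```python
import torch

from composer.core import DataSpec


def count_tokens(batch):
    # Illustrative counter: total tokens plus loss-generating tokens,
    # assuming the HF convention that labels of -100 are ignored by the loss.
    return {
        'total': batch['input_ids'].numel(),
        'loss_generating': (batch['labels'] != -100).sum().item(),
    }


data_spec = DataSpec(dataloader=[], get_num_tokens_in_batch=count_tokens)

batch = {
    'input_ids': torch.ones(2, 4, dtype=torch.long),
    'labels': torch.tensor([[-100, 5, 6, -100], [7, -100, -100, 8]]),
}
assert data_spec.get_num_tokens_in_batch(batch) == 8  # defaults to token_type='total'
assert data_spec.get_num_tokens_in_batch(batch, token_type='loss_generating') == 4
# Requesting a key the dict does not contain raises a ValueError.
```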
16 changes: 12 additions & 4 deletions composer/trainer/trainer.py
@@ -1027,7 +1027,10 @@ class Trainer:
it into sections of size ``device_train_microbatch_size``. If the batch size of the dataloader
is not divisible by ``device_train_microbatch_size``, the last section will be potentially smaller.
accumulate_train_batch_on_tokens (bool, optional): Whether training loss is accumulated over the number of tokens in a batch,
rather than the number of samples. Only works if the train data spec implements `get_num_tokens_in_batch`. (default: ``False``)
rather than the number of samples. Only works if the train data spec implements `get_num_tokens_in_batch`.
Note: If you are using this flag, you can optionally have your `get_num_tokens_in_batch` function return a dictionary
with two keys, `total` and `loss_generating`. Composer will then accumulate the loss over loss-generating tokens specifically,
while the total token count is still used everywhere else tokens are tracked. (default: ``False``)
seed (int, optional): The seed used in randomization. If ``None``, then a random seed
will be created. (default: ``None``)

@@ -3061,11 +3064,13 @@ def _train_microbatches(

# Tracker for gradient accumulation
if self.accumulate_train_batch_on_tokens:
current_batch_size = sum([self._train_data_spec.get_num_tokens_in_batch(b) for b in microbatches])
current_batch_size = sum([
self._train_data_spec.get_num_tokens_in_batch(b, token_type='loss_generating') for b in microbatches
])
if current_batch_size == 0:
raise ValueError(
textwrap.dedent(
'Requested loss accumulation based on number of tokens in training batch, '
'Requested loss accumulation based on number of loss generating tokens in training batch, '
'but zero tokens found (perhaps due to an improper DataSpec).',
),
)
@@ -3124,7 +3129,10 @@ def _train_microbatch(
device_batch = deepcopy(self.state.batch)

if self.accumulate_train_batch_on_tokens:
microbatch_size = self._train_data_spec.get_num_tokens_in_batch(self.state.batch)
microbatch_size = self._train_data_spec.get_num_tokens_in_batch(
self.state.batch,
token_type='loss_generating',
)
else:
microbatch_size = self._train_data_spec.get_num_samples_in_batch(self.state.batch)
if self.state.deepspeed_enabled or not isinstance(self.state.model, DistributedDataParallel):
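For intuition about the trainer changes above, here is a standalone arithmetic sketch (the counts and losses are made-up numbers, and the sketch reflects my reading of the accumulation rather than code from the PR): each microbatch's mean loss contributes in proportion to its share of the batch-level loss-generating token count, so padding-heavy microbatches no longer carry the same weight as dense ones.

```python
# Hypothetical per-microbatch loss-generating token counts and mean losses.
microbatch_loss_gen_tokens = [120, 95, 40]
microbatch_mean_losses = [2.31, 2.47, 2.10]

# Batch-level denominator, analogous to `current_batch_size` above.
current_batch_size = sum(microbatch_loss_gen_tokens)

# Each microbatch is weighted by its share of loss-generating tokens.
accumulated_loss = sum(
    loss * tokens / current_batch_size
    for loss, tokens in zip(microbatch_mean_losses, microbatch_loss_gen_tokens)
)
print(f'token-weighted loss: {accumulated_loss:.4f}')
```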
84 changes: 84 additions & 0 deletions tests/test_simple_nlp.py
@@ -217,3 +217,87 @@ def test_simple_nlp_mlm_token_batch(tiny_bert_tokenizer, device):
trainer2.fit()
assert trainer2.state.train_metrics is not None
assert trainer2.state.train_metrics['LanguageCrossEntropy'].compute() == cross_entropy


@device('gpu')
def test_simple_nlp_mlm_loss_gen_token_batch(tiny_bert_tokenizer, device):
transformers = pytest.importorskip('transformers')

vocab_size = tiny_bert_tokenizer.vocab_size
sequence_length = 32
size = 96
batch_size = 8
device = get_device(device)

train_dataset = RandomTextLMDataset(
size=size,
vocab_size=vocab_size,
sequence_length=sequence_length,
use_keys=True,
pad_token_id=tiny_bert_tokenizer.pad_token_id,
)
for i in range(size): # Proactively load dataset for consistent randomization
train_dataset[i]
collator = transformers.DataCollatorForLanguageModeling(tokenizer=tiny_bert_tokenizer)

# Get the model's state dict before training starts, so we can reproduce results
model = SimpleTransformerMaskedLM(vocab_size=vocab_size)
state_dict = model.state_dict()

# Set up the data spec that can count the non-padding tokens in a batch
train_dataloader = DataLoader(
train_dataset,
batch_size=batch_size,
sampler=dist.get_sampler(train_dataset),
collate_fn=collator,
)
data_spec = DataSpec(
dataloader=train_dataloader,
get_num_tokens_in_batch=lambda b: (b['input_ids'] != tiny_bert_tokenizer.pad_token_id).sum().item(),
)

# Arbitrarily divide num tokens by 2 to simulate loss-generating tokens
loss_gen_data_spec = DataSpec(
dataloader=train_dataloader,
get_num_tokens_in_batch=lambda b: {
'total': (b['input_ids'] != tiny_bert_tokenizer.pad_token_id).sum().item(),
'loss_generating': (b['input_ids'] != tiny_bert_tokenizer.pad_token_id).sum().item() // 2,
},
)

trainer = Trainer(
model=model,
seed=42,
train_dataloader=data_spec,
max_duration='2ep',
device_train_microbatch_size=batch_size // 2,
accumulate_train_batch_on_tokens=False,
device=device,
)
trainer.fit()

# Check that there is some train cross entropy
assert trainer.state.train_metrics is not None
cross_entropy = trainer.state.train_metrics['LanguageCrossEntropy'].compute()
assert cross_entropy != 0.0

# Set up a trainer that accumulates train loss based on token counts, after reloading original state dict
model.load_state_dict(state_dict)
token_trainer = Trainer(
model=model,
seed=42,
train_dataloader=loss_gen_data_spec,
max_duration='2ep',
device_train_microbatch_size=batch_size // 2,
accumulate_train_batch_on_tokens=True,
device=device,
)
token_trainer.fit()

# Check that there is some train cross entropy
assert token_trainer.state.train_metrics is not None
token_cross_entropy = token_trainer.state.train_metrics['LanguageCrossEntropy'].compute()
assert token_cross_entropy != 0.0

# Require that the train cross entropies are different between the trainers
assert cross_entropy != token_cross_entropy
37 changes: 36 additions & 1 deletion tests/trainer/test_dataspec.py
@@ -1,7 +1,8 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

from typing import Any
import contextlib
from typing import Any, Optional

import pytest
import torch
@@ -71,6 +72,40 @@ def test_get_num_tokens_hf_default(batch_size: int, sequence_length: int, use_ke
assert actual == expected


@pytest.mark.parametrize(
'return_dict,requested_key,expected',
[
[True, None, 8], # dict with default key
[False, None, 8], # int with default key
[False, 'loss_generating', 8], # int with non-default key
[True, 'loss_generating', 4], # dict with non-default key
],
)
def test_get_num_tokens_types(return_dict: bool, requested_key: Optional[str], expected: Optional[int]):
should_error = expected is None
error_context = pytest.raises(ValueError) if should_error else contextlib.nullcontext()

def get_num_tokens_in_batch(batch):
num_tokens = 8
num_loss_generating_tokens = 4

if return_dict:
return {'total': num_tokens, 'loss_generating': num_loss_generating_tokens}

return num_tokens

dataspec = DataSpec(dataloader=[], get_num_tokens_in_batch=get_num_tokens_in_batch)

batch = {}
extra_args = {}
if requested_key is not None:
extra_args['token_type'] = requested_key

with error_context:
actual = dataspec.get_num_tokens_in_batch(batch, **extra_args)
assert actual == expected


def test_small_batch_at_end_warning():
batch_size = 4
dataset_size = 17