Add loss generating token counts #1610

Merged Oct 27, 2024 · 12 commits · Changes from all commits
5 changes: 3 additions & 2 deletions llmfoundry/callbacks/loss_perp_v_len_callback.py
@@ -10,6 +10,7 @@
from torchmetrics import Metric

from llmfoundry.models.mpt import ComposerMPTCausalLM
from llmfoundry.utils.consts import CROSS_ENTROPY_IGNORE_INDEX
from llmfoundry.utils.warnings import experimental_class

__all__ = [
@@ -33,7 +34,7 @@ def __init__(
self,
log_batch_interval: int,
compute_batch_interval: int,
ignore_index: int = -100,
ignore_index: int = CROSS_ENTROPY_IGNORE_INDEX,
):
if compute_batch_interval > log_batch_interval:
raise ValueError(
@@ -69,7 +70,7 @@ def after_backward(self, state: State, logger: Logger) -> None:
labels = state.batch['labels']
if state.model.shift_labels:
labels[:, :-1] = labels[:, 1:].detach().clone()
labels[:, -1] = -100
labels[:, -1] = CROSS_ENTROPY_IGNORE_INDEX
seq_parallel_world_size = getattr(
state.model.model.transformer,
'seq_parallel_world_size',
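Editor's note on the constant: -100 works as the ignore index because it is the default ignore_index of PyTorch's cross-entropy, so any position whose label equals it contributes nothing to the loss. A minimal, standalone sketch of that behavior (illustration only, not code from this PR):

import torch
import torch.nn.functional as F

CROSS_ENTROPY_IGNORE_INDEX = -100

logits = torch.randn(4, 10)  # 4 token positions, vocabulary of 10
labels = torch.tensor([3, CROSS_ENTROPY_IGNORE_INDEX, 7, CROSS_ENTROPY_IGNORE_INDEX])

# Positions labeled -100 are skipped; the mean is over the 2 real targets.
loss = F.cross_entropy(logits, labels, ignore_index=CROSS_ENTROPY_IGNORE_INDEX)

# Identical to dropping the masked positions by hand.
keep = labels != CROSS_ENTROPY_IGNORE_INDEX
assert torch.allclose(loss, F.cross_entropy(logits[keep], labels[keep]))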
24 changes: 13 additions & 11 deletions llmfoundry/data/finetuning/collator.py
@@ -8,15 +8,14 @@
import torch
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

from llmfoundry.utils.consts import CROSS_ENTROPY_IGNORE_INDEX

log = logging.getLogger(__name__)

__all__ = [
'Seq2SeqFinetuningCollator',
]

# HuggingFace hardcodes the ignore index to -100
_HF_IGNORE_INDEX = -100

TokenizedExample = dict[str, list[dict[str, list[int]]]]


@@ -79,7 +78,7 @@ def _sequence_to_labels_none(
cutoff: Optional[int] = None,
) -> list[int]:
del is_last_turn, cutoff # unused
return [_HF_IGNORE_INDEX] * len(sequence)
return [CROSS_ENTROPY_IGNORE_INDEX] * len(sequence)


def _sequence_to_labels_last(
@@ -91,7 +90,7 @@ def _sequence_to_labels_last(
if is_last_turn:
return sequence
else:
return [_HF_IGNORE_INDEX] * len(sequence)
return [CROSS_ENTROPY_IGNORE_INDEX] * len(sequence)


def _sequence_to_labels_cutoff(
@@ -105,7 +104,7 @@ def _sequence_to_labels_cutoff(
if len(sequence) >= cutoff:
return sequence
else:
return [_HF_IGNORE_INDEX] * len(sequence)
return [CROSS_ENTROPY_IGNORE_INDEX] * len(sequence)


_TARGET_POLICY_LOOKUP = {
@@ -352,7 +351,7 @@ def _process_and_batch_decoder_only(
labels = labels[:max_seq_len]

# Check to make sure there are still loss-generating tokens. Error if not.
if len([l for l in labels if l != _HF_IGNORE_INDEX]) == 0:
if len([l for l in labels if l != CROSS_ENTROPY_IGNORE_INDEX
],) == 0:
raise ValueError(
f'Truncating to max_seq_len={max_seq_len} has removed all loss-generating tokens. ' +\
f'Pre-truncation sequence length was {orig_size}. ' +\
@@ -375,7 +375,7 @@ def _process_and_batch_decoder_only(
# Annoyingly, we need to pad everything but input_ids
# and attention_mask ourselves
n_total = len(input_ids)
i_pad = [_HF_IGNORE_INDEX] * (max_seq_len - n_total)
i_pad = [CROSS_ENTROPY_IGNORE_INDEX] * (max_seq_len - n_total)
if self.tokenizer.padding_side == 'left':
labels = i_pad + labels
else:
@@ -444,7 +444,9 @@ def _process_and_batch_encoder_decoder(
for context, target in contexts_and_targets:
# We need to pad labels ourselves. Because HF.
if len(target) < max_seq_len:
i_pad = [_HF_IGNORE_INDEX] * (max_seq_len - len(target))
i_pad = [
CROSS_ENTROPY_IGNORE_INDEX,
] * (max_seq_len - len(target))
target = target + i_pad
else:
if not self._warned_target:
Expand Down Expand Up @@ -491,12 +493,12 @@ def _process_and_batch_encoder_decoder(
],
dim=1)
batch['decoder_input_ids'].masked_fill_(
batch['decoder_input_ids'] == _HF_IGNORE_INDEX,
batch['decoder_input_ids'] == CROSS_ENTROPY_IGNORE_INDEX,
self.tokenizer.pad_token_id,
)
batch['decoder_attention_mask'] = torch.not_equal(
batch['labels'],
_HF_IGNORE_INDEX,
CROSS_ENTROPY_IGNORE_INDEX,
)

# This logic prevents trimming on at least the first batch
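Editor's sketch (simplified; the helper name and signature are illustrative, not the collator's API): the three target policies above differ only in when a turn's tokens are kept as labels versus masked wholesale with the ignore index.

CROSS_ENTROPY_IGNORE_INDEX = -100

def labels_for_turn(sequence, policy, is_last_turn, cutoff=None):
    # 'none': never train on this turn; 'last': train only on the final
    # turn; 'cutoff': train only on turns of at least `cutoff` tokens.
    masked = [CROSS_ENTROPY_IGNORE_INDEX] * len(sequence)
    if policy == 'none':
        return masked
    if policy == 'last':
        return sequence if is_last_turn else masked
    if policy == 'cutoff':
        return sequence if len(sequence) >= cutoff else masked
    raise ValueError(f'unknown policy: {policy}')

assert labels_for_turn([11, 12], 'last', is_last_turn=False) == [-100, -100]
assert labels_for_turn([11, 12], 'last', is_last_turn=True) == [11, 12]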
16 changes: 8 additions & 8 deletions llmfoundry/data/finetuning/dataloader.py
@@ -26,6 +26,7 @@
from llmfoundry.data.packing import BinPackCollator, auto_packing_ratio
from llmfoundry.data.text_data import build_streams
from llmfoundry.utils.config_utils import to_dict_container
from llmfoundry.utils.consts import CROSS_ENTROPY_IGNORE_INDEX
from llmfoundry.utils.exceptions import (
FinetuningFileNotFoundError,
MissingHuggingFaceURLSplitError,
@@ -40,9 +41,6 @@
'build_finetuning_dataloader',
]

# HuggingFace hardcodes the ignore index to -100
_HF_IGNORE_INDEX = -100

# Extra keys present in the dataset config dictionary beyond the constructor keys
_ALLOWED_DATASET_KEYS = {
'shuffle',
@@ -786,7 +784,7 @@ def build_collate_fn(
)
context = torch.logical_and(
batch['attention_mask'][j] == 1,
batch['labels'][j] == _HF_IGNORE_INDEX,
batch['labels'][j] == CROSS_ENTROPY_IGNORE_INDEX,
)
print(
'\033[92m{}\033[00m\n'.format('CONTEXT: '),
@@ -804,7 +802,8 @@
j,
torch.logical_and(
is_subseq,
batch['labels'][j] != _HF_IGNORE_INDEX,
batch['labels'][j] !=
CROSS_ENTROPY_IGNORE_INDEX,
)],
skip_special_tokens=False,
clean_up_tokenization_spaces=True,
@@ -822,7 +821,7 @@
)
context = torch.logical_and(
batch['attention_mask'][j] == 1,
batch['labels'][j] == _HF_IGNORE_INDEX,
batch['labels'][j] == CROSS_ENTROPY_IGNORE_INDEX,
)
print(
'\033[92m{}\033[00m\n'.format('CONTEXT: '),
@@ -835,8 +834,9 @@
print(
'\033[91m{}\033[00m\n'.format('TARGET: '),
tokenizer.decode(
batch['input_ids'][
j, batch['labels'][j] != _HF_IGNORE_INDEX],
batch['input_ids']
[j,
batch['labels'][j] != CROSS_ENTROPY_IGNORE_INDEX],
skip_special_tokens=False,
clean_up_tokenization_spaces=True,
),
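Editor's sketch of the masking used by the debug printing above: context tokens are attended to but never trained on, while target tokens are exactly those whose label survives the ignore-index masking (standalone illustration with made-up values):

import torch

CROSS_ENTROPY_IGNORE_INDEX = -100

attention_mask = torch.tensor([1, 1, 1, 1, 0])  # last position is padding
labels = torch.tensor([
    CROSS_ENTROPY_IGNORE_INDEX, CROSS_ENTROPY_IGNORE_INDEX, 5, 6,
    CROSS_ENTROPY_IGNORE_INDEX,
])

context = torch.logical_and(
    attention_mask == 1,
    labels == CROSS_ENTROPY_IGNORE_INDEX,
)
target = labels != CROSS_ENTROPY_IGNORE_INDEX

print(context)  # tensor([ True,  True, False, False, False])
print(target)   # tensor([False, False,  True,  True, False])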
5 changes: 3 additions & 2 deletions llmfoundry/data/finetuning/tasks.py
@@ -62,11 +62,11 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
stream_remote_local_validate,
)
from llmfoundry.data.finetuning.collator import (
_HF_IGNORE_INDEX,
stitch_turns_decoder_only,
stitch_turns_encoder_decoder,
)
from llmfoundry.tokenizers import get_date_string
from llmfoundry.utils.consts import CROSS_ENTROPY_IGNORE_INDEX
# yapf: disable
from llmfoundry.utils.exceptions import (
ALLOWED_MESSAGES_KEYS,
@@ -501,7 +501,8 @@ def is_valid_ift_example(
if len(input_ids) == 0:
return False

if len([label for label in labels if label != _HF_IGNORE_INDEX]) == 0:
if len([label for label in labels if label != CROSS_ENTROPY_IGNORE_INDEX
],) == 0:
return False

return True
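The check above keeps an example only if at least one label survives masking. An equivalent, slightly more idiomatic form (editor's sketch; the helper name is hypothetical) uses any(), which short-circuits instead of building the filtered list:

CROSS_ENTROPY_IGNORE_INDEX = -100

def has_loss_generating_tokens(labels: list[int]) -> bool:
    # True as soon as one label is trainable; no intermediate list.
    return any(label != CROSS_ENTROPY_IGNORE_INDEX for label in labels)

assert not has_loss_generating_tokens([CROSS_ENTROPY_IGNORE_INDEX] * 3)
assert has_loss_generating_tokens([CROSS_ENTROPY_IGNORE_INDEX, 42])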
6 changes: 4 additions & 2 deletions llmfoundry/data/packing.py
@@ -10,6 +10,8 @@
from composer.utils import dist
from transformers import PreTrainedTokenizerBase

from llmfoundry.utils.consts import CROSS_ENTROPY_IGNORE_INDEX

log = logging.getLogger(__name__)

__all__ = [
@@ -152,7 +154,7 @@ def _convert_to_batch(

pad_vals = {
'input_ids': self.pad_token_id,
'labels': -100,
'labels': CROSS_ENTROPY_IGNORE_INDEX,
'attention_mask': 0,
'sequence_id': -1,
}
@@ -317,7 +319,7 @@ def _combine_in_place(
if 'labels' in add_on:
# Prevents the last token in example from being trained to
# predict the first token in add_on, which would make no sense.
add_on['labels'][0] = -100
add_on['labels'][0] = CROSS_ENTROPY_IGNORE_INDEX

for k in example.keys():
if k == 'sequence_id':
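Why _combine_in_place masks the boundary: without it, the last token of one packed example would be trained to predict the first token of the next, which is meaningless. A standalone sketch (pack_pair is a hypothetical helper, not the BinPackCollator API):

import torch

CROSS_ENTROPY_IGNORE_INDEX = -100

def pack_pair(example: dict, add_on: dict) -> dict:
    add_on = {k: v.clone() for k, v in add_on.items()}
    # Mask the first label of the appended example so no loss is computed
    # across the packing boundary.
    add_on['labels'][0] = CROSS_ENTROPY_IGNORE_INDEX
    return {k: torch.cat([example[k], add_on[k]]) for k in example}

a = {'input_ids': torch.tensor([1, 2]), 'labels': torch.tensor([1, 2])}
b = {'input_ids': torch.tensor([3, 4]), 'labels': torch.tensor([3, 4])}
print(pack_pair(a, b)['labels'])  # tensor([   1,    2, -100,    4])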
24 changes: 22 additions & 2 deletions llmfoundry/data/utils.py
@@ -15,6 +15,7 @@
from llmfoundry.data.finetuning.dataloader import build_collate_fn
from llmfoundry.data.packing import BinPackCollator
from llmfoundry.data.text_data import ConcatenatedSequenceCollatorWrapper
from llmfoundry.utils.consts import CROSS_ENTROPY_IGNORE_INDEX

log = logging.getLogger(__name__)

@@ -83,7 +84,7 @@ def get_data_spec(

def get_tokens_per_batch_func(
decoder_only: bool = True,
) -> Callable[[Batch], int]:
) -> Callable[[Batch], Union[int, dict[str, int]]]:
"""Returns a callable that counts the number of tokens in a batch.

Args:
@@ -95,7 +96,7 @@ def get_tokens_per_batch_func(
Callable[[Batch], int]: A callable that counts the number of tokens in a batch.
"""

def get_num_tokens_in_batch(batch: Batch) -> int:
def get_num_tokens_in_batch(batch: Batch) -> Union[int, dict[str, int]]:
if not isinstance(batch, Mapping) or (
'attention_mask' not in batch and 'input_ids' not in batch
):
@@ -114,13 +115,32 @@ def get_num_tokens_in_batch(batch: Batch) -> int:
else:
input_ids_tokens = batch['input_ids'].numel()

loss_generating_tokens = None
if 'labels' in batch:
loss_generating_tokens = int(
torch.sum(batch['labels'] != CROSS_ENTROPY_IGNORE_INDEX).item(),
)

# Subtract one for each example in the batch that starts with a non -100,
# because those will be shifted off
loss_generating_tokens -= int(
    torch.sum(
        batch['labels'][:, 0] != CROSS_ENTROPY_IGNORE_INDEX,
    ).item(),
)

Review thread on the subtraction above:

Contributor: @dakinggg I don't think this subtraction is necessary. Instead you can just do this:

    loss_generating_tokens = int(
        torch.sum(batch['labels'][..., 1:] != CROSS_ENTROPY_IGNORE_INDEX).item(),
    )

(I just came across this PR while looking into how Mosaic's libraries handle the gradient accumulation bug recently discussed on x.com.)

Collaborator (author): ah yeah, that should work too :)

# For encoder decoder models only
decoder_input_ids_tokens = 0
if not decoder_only:
decoder_input_ids_tokens = int(
torch.sum(batch['decoder_attention_mask']).item(),
)

if loss_generating_tokens is not None:
return {
'total': input_ids_tokens + decoder_input_ids_tokens,
'loss_generating': loss_generating_tokens,
}
return input_ids_tokens + decoder_input_ids_tokens

return get_num_tokens_in_batch
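The review thread above is easy to check numerically: subtracting one for each sequence whose first label is unmasked (what the PR merged) gives the same count as simply ignoring position 0 (the commenter's suggestion), since after the causal shift the first label never serves as a target. A quick standalone verification (editor's example):

import torch

CROSS_ENTROPY_IGNORE_INDEX = -100

labels = torch.tensor([
    [5, 6, CROSS_ENTROPY_IGNORE_INDEX, 7],
    [CROSS_ENTROPY_IGNORE_INDEX, 8, 9, CROSS_ENTROPY_IGNORE_INDEX],
])

# Merged approach: count everything, then subtract unmasked first positions.
pr_count = int(torch.sum(labels != CROSS_ENTROPY_IGNORE_INDEX).item())
pr_count -= int(torch.sum(labels[:, 0] != CROSS_ENTROPY_IGNORE_INDEX).item())

# Suggested approach: drop position 0 before counting.
suggested = int(torch.sum(labels[..., 1:] != CROSS_ENTROPY_IGNORE_INDEX).item())

assert pr_count == suggested == 4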
3 changes: 0 additions & 3 deletions llmfoundry/models/hf/hf_base.py
@@ -38,9 +38,6 @@

__all__ = ['BaseHuggingFaceModel']

# HuggingFace hardcodes the ignore index to -100
_HF_IGNORE_INDEX = -100

log = logging.getLogger(__name__)


3 changes: 2 additions & 1 deletion llmfoundry/models/inference_api_wrapper/interface.py
@@ -12,6 +12,7 @@

from llmfoundry.eval.metrics import InContextLearningMetric
from llmfoundry.metrics import DEFAULT_CAUSAL_LM_EVAL_METRICS
from llmfoundry.utils.consts import CROSS_ENTROPY_IGNORE_INDEX

__all__ = ['InferenceAPIEvalWrapper']

@@ -92,7 +93,7 @@ def update_metric(self, batch: Any, outputs: Any, metric: Metric) -> None:
batch = self.rebatch(batch)
self.labels = batch.pop('labels')
self.labels[:, :-1] = self.labels[:, 1:].clone()
self.labels[:, -1] = -100
self.labels[:, -1] = CROSS_ENTROPY_IGNORE_INDEX
if isinstance(
metric,
InContextLearningMetric,
10 changes: 6 additions & 4 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -81,6 +81,8 @@

log = logging.getLogger(__name__)

CROSS_ENTROPY_IGNORE_INDEX = -100


class InvalidConfigAccessError(KeyError):
pass
@@ -1181,7 +1183,7 @@ def forward(
loss = None
if labels is not None:
_labels = torch.roll(labels, shifts=-1)
_labels[:, -1] = -100
_labels[:, -1] = CROSS_ENTROPY_IGNORE_INDEX
loss = F.cross_entropy(
logits.view(-1, logits.size(-1)),
_labels.to(logits.device).view(-1),
@@ -1331,7 +1333,7 @@ def _reorder_cache(

def get_targets(labels: torch.Tensor) -> torch.Tensor:
targets = torch.roll(labels, shifts=-1)
targets[:, -1] = -100
targets[:, -1] = CROSS_ENTROPY_IGNORE_INDEX
return targets


@@ -1410,7 +1412,7 @@ def __init__(
CrossEntropyLoss as FusedCrossEntropyLoss

self.loss_fn = FusedCrossEntropyLoss(
ignore_index=-100,
ignore_index=CROSS_ENTROPY_IGNORE_INDEX,
reduction='none',
)
except:
@@ -1423,7 +1425,7 @@
)
elif loss_fn_config == 'torch_crossentropy':
self.loss_fn = nn.CrossEntropyLoss(
ignore_index=-100,
ignore_index=CROSS_ENTROPY_IGNORE_INDEX,
reduction='none',
)
else:
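For readers skimming the diff: get_targets rolls the labels one position left and masks the final slot, so the logits at position t are scored against the token at position t + 1. A standalone illustration of the same shift (editor's sketch):

import torch

CROSS_ENTROPY_IGNORE_INDEX = -100

def get_targets(labels: torch.Tensor) -> torch.Tensor:
    targets = torch.roll(labels, shifts=-1)  # shift every label left by one
    targets[:, -1] = CROSS_ENTROPY_IGNORE_INDEX  # nothing follows the last token
    return targets

labels = torch.tensor([[10, 11, 12, 13]])
print(get_targets(labels))  # tensor([[  11,   12,   13, -100]])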
2 changes: 2 additions & 0 deletions llmfoundry/utils/__init__.py
@@ -29,6 +29,7 @@
process_init_device,
update_batch_size_info,
)
from llmfoundry.utils.consts import CROSS_ENTROPY_IGNORE_INDEX
from llmfoundry.utils.data_prep_utils import (
DownloadingIterable,
merge_shard_groups,
@@ -111,4 +112,5 @@
'ExperimentalWarning',
'experimental_function',
'experimental_class',
'CROSS_ENTROPY_IGNORE_INDEX',
]
4 changes: 4 additions & 0 deletions llmfoundry/utils/consts.py
@@ -0,0 +1,4 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

CROSS_ENTROPY_IGNORE_INDEX = -100