diff --git a/llmfoundry/callbacks/__init__.py b/llmfoundry/callbacks/__init__.py
index 6d91237bc8..4712de5d5e 100644
--- a/llmfoundry/callbacks/__init__.py
+++ b/llmfoundry/callbacks/__init__.py
@@ -22,6 +22,8 @@ from llmfoundry.callbacks.log_mbmoe_tok_per_expert_callback import (
     MegaBlocksMoE_TokPerExpert,
 )
+from llmfoundry.callbacks.loss_perp_v_len_callback import \
+    LossPerpVsContextLengthLogger
 from llmfoundry.callbacks.monolithic_ckpt_callback import (
     MonolithicCheckpointSaver,
 )
@@ -52,6 +54,8 @@ callbacks.register('mbmoe_tok_per_expert', func=MegaBlocksMoE_TokPerExpert)
 callbacks.register('run_timeout', func=RunTimeoutCallback)
 
+callbacks.register('loss_perp_v_len', func=LossPerpVsContextLengthLogger)
+
 callbacks_with_config.register('async_eval', func=AsyncEval)
 callbacks_with_config.register('curriculum_learning', func=CurriculumLearning)
@@ -66,4 +70,5 @@
     'MegaBlocksMoE_TokPerExpert',
     'AsyncEval',
     'CurriculumLearning',
+    'LossPerpVsContextLengthLogger',
 ]
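The registration above is what lets configs enable the callback by name. For reference, a minimal sketch of constructing it through the registry, mirroring what the test at the end of this PR does (the interval values here are illustrative, not recommendations):

```python
from llmfoundry import registry
from llmfoundry.utils.registry_utils import construct_from_registry

# Resolve 'loss_perp_v_len' from the callback registry, the same way the
# train script resolves the `callbacks` section of a config.
callback = construct_from_registry(
    name='loss_perp_v_len',
    registry=registry.callbacks,
    kwargs={
        # Compute the per-position metric every batch...
        'compute_batch_interval': 1,
        # ...but only do the (currently slow) MLflow table upload every
        # 100 batches; this must be >= compute_batch_interval.
        'log_batch_interval': 100,
    },
)
```

Note that `init` raises unless the trainer also has an `MLFlowLogger` destination and the model is a `ComposerMPTCausalLM`, per the checks below.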
diff --git a/llmfoundry/callbacks/loss_perp_v_len_callback.py b/llmfoundry/callbacks/loss_perp_v_len_callback.py
new file mode 100644
index 0000000000..1a3ac05651
--- /dev/null
+++ b/llmfoundry/callbacks/loss_perp_v_len_callback.py
@@ -0,0 +1,351 @@
+# Copyright 2024 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Dict, Mapping, Optional, Tuple
+
+import torch
+from composer.core import Callback, State
+from composer.loggers import Logger, MLFlowLogger
+from composer.utils import dist
+from torchmetrics import Metric
+
+from llmfoundry.models.mpt import ComposerMPTCausalLM
+from llmfoundry.utils.warnings import experimental_class
+
+__all__ = [
+    'LossPerpVsContextLengthLogger',
+]
+
+
+@experimental_class('LossPerpVsContextLengthLogger')
+class LossPerpVsContextLengthLogger(Callback):
+    """Logs the average loss and perplexity for every context length.
+
+    Note: Currently only works with the MLflow logger.
+
+    Args:
+        log_batch_interval (int): The interval for logging. Logging is currently slow because MLflow downloads the table, appends rows to it, and then re-uploads it. Once this is fixed, log_batch_interval will be removed and the metric will be logged as soon as it is computed.
+        compute_batch_interval (int): The interval for computing the metric.
+        ignore_index (int): Specifies a target value that is ignored when computing the loss.
+    """
+
+    def __init__(
+        self,
+        log_batch_interval: int,
+        compute_batch_interval: int,
+        ignore_index: int = -100,
+    ):
+        if compute_batch_interval > log_batch_interval:
+            raise ValueError(
+                'compute_batch_interval must not be greater than log_batch_interval for LossPerpVsContextLengthLogger.',
+            )
+        self.log_batch_interval = log_batch_interval
+        self.compute_batch_interval = compute_batch_interval
+        self.ignore_index = ignore_index
+        self.metric_dict = {}
+        self.loss_perp_v_len = LossPerpVLen(ignore_index)
+
+    def init(self, state: State, logger: Logger) -> None:
+        if not isinstance(state.model, ComposerMPTCausalLM):
+            raise ValueError(
+                'LossPerpVsContextLengthLogger is only supported for ComposerMPTCausalLM models.',
+            )
+        if state.model.shift_labels is None:
+            raise ValueError(
+                'state.model.shift_labels should be set for LossPerpVsContextLengthLogger.',
+            )
+        if all(
+            not isinstance(destination, MLFlowLogger)
+            for destination in logger.destinations
+        ):
+            raise NotImplementedError(
+                'Did not find MLflow in the list of loggers. LossPerpVsContextLengthLogger is only implemented for the MLflow logger.',
+            )
+
+    def after_backward(self, state: State, logger: Logger) -> None:
+        if state.timestamp.batch.value % self.compute_batch_interval == 0:
+            sequence_id = state.batch.get('sequence_id', None)
+            labels = state.batch['labels']
+            if state.model.shift_labels:
+                labels[:, :-1] = labels[:, 1:].detach().clone()
+                # The final position has no target after shifting.
+                labels[:, -1] = self.ignore_index
+            seq_parallel_world_size = getattr(
+                state.model.model.transformer,
+                'seq_parallel_world_size',
+                1,
+            )
+            seq_parallel_rank = state.model.model.transformer.seq_parallel_rank if seq_parallel_world_size > 1 else 0
+
+            if isinstance(state.outputs, Mapping):
+                logits = state.outputs['logits']  # type: ignore
+            elif isinstance(state.outputs, torch.Tensor):
+                logits = state.outputs
+            else:
+                raise ValueError(
+                    f'Type {type(state.outputs)} for the output is unsupported.',
+                )
+
+            if labels.shape[1] != logits.shape[1]:
+                raise ValueError(
+                    f'The length of labels, {labels.shape[1]=}, does not match the length of logits, {logits.shape[1]=}.',
+                )
+
+            labels, logits = self.preprocess_metric_inputs(
+                sequence_id,
+                labels,
+                logits,
+                seq_parallel_world_size,
+                seq_parallel_rank,
+            )
+
+            self.loss_perp_v_len.update(
+                labels,
+                logits,
+                sequence_id,
+                state.model.loss_fn,
+            )
+
+    def batch_end(self, state: State, logger: Logger) -> None:
+        if state.timestamp.batch.value % self.compute_batch_interval == 0:
+            current_metric_dict = self.loss_perp_v_len.compute()
+            if dist.get_global_rank() == 0:
+                for k, v in current_metric_dict.items():
+                    v = v.tolist()
+                    v.append(
+                        state.timestamp.batch.value,
+                    )  # Add the current batch index as the last column
+                    if k not in self.metric_dict:
+                        self.metric_dict[k] = []
+                    self.metric_dict[k].append(v)
+        if state.timestamp.batch.value % self.log_batch_interval == 0 and dist.get_global_rank() == 0:
+            for k, v in self.metric_dict.items():
+                columns = [
+                    f'context_length_{i}' for i in range(len(v[0]) - 1)
+                ]  # len(v[0]) - 1 because the last column is the batch index
+                columns.append(
+                    'batch_index',
+                )  # Add the batch index as the last column name
+                for destination in logger.destinations:
+                    if isinstance(destination, MLFlowLogger):
+                        destination.log_table(
+                            columns=columns,
+                            rows=v,
+                            name=f'metrics/train/LossPerpVLenTable/{k}',
+                            step=state.timestamp.batch.value,
+                        )
+            self.metric_dict = {}
+
+    def preprocess_metric_inputs(
+        self,
+        sequence_id: Optional[torch.Tensor],
+        labels: torch.Tensor,
+        logits: torch.Tensor,
+        seq_parallel_world_size: int,
+        seq_parallel_rank: int,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        del sequence_id, seq_parallel_rank
+        if seq_parallel_world_size > 1:
+            raise ValueError(
+                'LossPerpVsContextLengthLogger does not support sequence parallelism.',
+            )
+
+        return labels, logits
+
+
+class LossPerpVLen(Metric):
+
+    full_state_update = False
+
+    def __init__(
+        self,
+        ignore_index: int,
+        dist_sync_on_step: bool = False,
+    ):
+        super().__init__(dist_sync_on_step=dist_sync_on_step)
+
+        self.ignore_index = ignore_index
+        self.add_state('sum_loss', default=torch.Tensor(), dist_reduce_fx='sum')
+        self.add_state(
+            'sum_perplexity',
+            default=torch.Tensor(),
+            dist_reduce_fx='sum',
+        )
+        self.add_state(
+            'sum_length',
+            default=torch.Tensor(),
+            dist_reduce_fx='sum',
+        )
+
+        self.add_state(
+            'sum_loss_seq_id',
+            default=torch.Tensor(),
+            dist_reduce_fx='sum',
+        )
+        self.add_state(
+            'sum_perplexity_seq_id',
+            default=torch.Tensor(),
+            dist_reduce_fx='sum',
+        )
+        self.add_state(
+            'sum_length_seq_id',
+            default=torch.Tensor(),
+            dist_reduce_fx='sum',
+        )
+
+    def update(
+        self,
+        labels: torch.Tensor,
+        logits: torch.Tensor,
+        sequence_id: Optional[torch.Tensor],
+        loss_fn: Any,
+    ) -> None:
+        """Updates the internal state with results from a new batch.
+
+        Args:
+            labels (torch.Tensor): A Tensor of ground-truth values to compare against.
+            logits (torch.Tensor): A Tensor of model logits.
+            sequence_id (torch.Tensor | None): The sequence ids for tokens.
+            loss_fn (Any): The cross entropy loss to use. It must not reduce, so that it returns a per-token loss.
+        """
+        valid_labels_mask = torch.where(
+            labels != self.ignore_index,
+            torch.ones_like(labels),
+            torch.zeros_like(labels),
+        )
+        bsz, seq_len = labels.shape
+        loss = loss_fn(logits.view(bsz * seq_len, -1), labels.view(-1))
+        loss = loss.view(bsz, seq_len)
+        perplexity = torch.exp(loss)
+
+        if self.sum_loss.numel() == 0:
+            self.sum_loss = torch.zeros(  # type: ignore
+                seq_len,
+                device=loss.device,
+                dtype=loss.dtype,
+            )
+            self.sum_perplexity = torch.zeros(  # type: ignore
+                seq_len,
+                device=loss.device,
+                dtype=loss.dtype,
+            )
+            self.sum_length = torch.zeros(  # type: ignore
+                seq_len,
+                device=loss.device,
+                dtype=torch.long,
+            )
+            self.sum_loss_seq_id = torch.zeros(  # type: ignore
+                seq_len,
+                device=loss.device,
+                dtype=loss.dtype,
+            )
+            self.sum_perplexity_seq_id = torch.zeros(  # type: ignore
+                seq_len,
+                device=loss.device,
+                dtype=loss.dtype,
+            )
+            self.sum_length_seq_id = torch.zeros(  # type: ignore
+                seq_len,
+                device=loss.device,
+                dtype=torch.long,
+            )
+
+        self.sum_loss += torch.sum(loss, dim=0)
+        self.sum_perplexity += torch.sum(perplexity, dim=0)
+        self.sum_length += valid_labels_mask.sum(dim=0)
+
+        if sequence_id is not None:
+            # Pack each sequence's tokens to the left so that position i of a
+            # row refers to the i-th token within that sequence, rather than
+            # within the packed batch.
+            seq_id_expanded = torch.nn.functional.one_hot(
+                sequence_id,
+            ).transpose(-1, -2)
+            seq_lens = seq_id_expanded.sum(dim=-1)
+            max_num_seq = seq_lens.shape[1]
+            seq_tok_ids = torch.arange(seq_len, device=sequence_id.device)[
+                None, None, :].expand(bsz, max_num_seq, -1)
+            mask = seq_tok_ids < seq_lens[:, :, None]
+            seq_len_offsets = torch.nn.functional.pad(
+                seq_lens.cumsum(dim=1)[:, :-1],
+                (1, 0),
+                value=0,
+            )
+            seq_tok_ids = seq_tok_ids + seq_len_offsets[:, :, None]
+            seq_tok_ids = torch.where(
+                mask,
+                seq_tok_ids,
+                torch.zeros_like(seq_tok_ids),
+            )
+
+            loss = loss[:, None, :].expand(-1, max_num_seq, -1)
+            perplexity = perplexity[:, None, :].expand(-1, max_num_seq, -1)
+            valid_labels_mask = valid_labels_mask[:, None, :].expand(
+                -1,
+                max_num_seq,
+                -1,
+            )
+            loss = torch.where(
+                mask,
+                torch.gather(input=loss, dim=2, index=seq_tok_ids),
+                torch.zeros_like(loss),
+            )
+            perplexity = torch.where(
+                mask,
+                torch.gather(input=perplexity, dim=2, index=seq_tok_ids),
+                torch.zeros_like(perplexity),
+            )
+            mask = torch.where(
+                mask,
+                torch.gather(input=valid_labels_mask, dim=2, index=seq_tok_ids),
+                torch.zeros_like(valid_labels_mask),
+            )
+
+            self.sum_loss_seq_id += torch.sum(loss, dim=(0, 1))
+            self.sum_perplexity_seq_id += torch.sum(perplexity, dim=(0, 1))
+            self.sum_length_seq_id += torch.sum(mask, dim=(0, 1))
+
+    def compute(self) -> Dict[str, torch.Tensor]:
+        """Aggregates the state over all processes to compute the metric.
+
+        Returns:
+            A dict mapping each metric name to a tensor with one entry per
+            context length: the mean loss and perplexity at each position,
+            both overall and re-indexed by position within each sequence id,
+            plus the token counts used for each average. Positions that never
+            saw a valid token report -1.
+ """ + # Return average loss over entire dataset + sum_perplexity = torch.where( + self.sum_length != 0, + self.sum_perplexity, + -1, + ) + sum_loss = torch.where(self.sum_length != 0, self.sum_loss, -1) + sum_length = torch.where(self.sum_length != 0, self.sum_length, 1) + + sum_perplexity_seq_id = torch.where( + self.sum_length_seq_id != 0, + self.sum_perplexity_seq_id, + -1, + ) + sum_loss_seq_id = torch.where( + self.sum_length_seq_id != 0, + self.sum_loss_seq_id, + -1, + ) + sum_length_seq_id = torch.where( + self.sum_length_seq_id != 0, + self.sum_length_seq_id, + 1, + ) + + return { + 'mean_loss_v_len': + sum_loss / sum_length, + 'mean_perplexity_v_len': + sum_perplexity / sum_length, + 'sum_length': + self.sum_length, + 'mean_loss_seq_id_v_len': + sum_loss_seq_id / sum_length_seq_id, + 'mean_perplexity_seq_id_v_len': + sum_perplexity_seq_id / sum_length_seq_id, + 'sum_length_seq_id': + self.sum_length_seq_id, + } diff --git a/tests/callbacks/test_loss_perp_v_len_callback.py b/tests/callbacks/test_loss_perp_v_len_callback.py new file mode 100644 index 0000000000..46bde1c2f1 --- /dev/null +++ b/tests/callbacks/test_loss_perp_v_len_callback.py @@ -0,0 +1,174 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 +from unittest.mock import MagicMock + +import pytest +import torch +import transformers +from composer.core import State +from composer.core.precision import get_precision_context +from composer.devices import DeviceGPU +from composer.loggers import Logger +from composer.utils import get_device +from omegaconf import DictConfig +from omegaconf import OmegaConf as om + +from llmfoundry import registry +from llmfoundry.data.text_data import ( + StreamingTextDataset, + build_text_dataloader, +) +from llmfoundry.utils.builders import build_composer_model +from llmfoundry.utils.registry_utils import construct_from_registry + + +@pytest.mark.gpu +@pytest.mark.parametrize('shift_labels', [True, False]) +def test_loss_perp_v_len_callback( + shift_labels: bool, + monkeypatch: pytest.MonkeyPatch, +): + try: + from flash_attn.losses.cross_entropy import CrossEntropyLoss as FusedCrossEntropyLoss # type: ignore # isort: skip + except: + pytest.skip('Fused cross entropy was not installed') + + composer_device = get_device(None) + + model_max_length = 12 + + gptt = transformers.AutoTokenizer.from_pretrained('gpt2') + gptt.pad_token_id = gptt.eos_token_id + gptt.model_max_length = model_max_length + gptt.padding_side = 'right' + + cfg = { + 'dataset': { + 'local': 'dummy-path', + 'remote': 'dummy-path', + 'split': 'train', + 'max_seq_len': model_max_length, + 'shuffle': True, + 'shuffle_seed': 0, + 'eos_token_id': gptt.eos_token_id, + }, + 'drop_last': False, + 'num_workers': 0, + 'prefetch_factor': None, + 'pin_memory': False, + 'persistent_workers': False, + 'timeout': 0, + } + + ds_mock = MagicMock(spec=StreamingTextDataset) + ds_mock.tokenizer = gptt + monkeypatch.setattr( + 'llmfoundry.data.text_data.StreamingTextDataset', + lambda *args, + **kwargs: ds_mock, + ) + dl = build_text_dataloader( + **cfg, + tokenizer=gptt, + device_batch_size=1, + ) + + batch_strings = [ + 'hello hey' + gptt.eos_token + ' the quick brown fox jumps', + ] + + batch_tokenized = [gptt(b, padding=False) for b in batch_strings] + + batch_tokenized = [b['input_ids'] for b in batch_tokenized] + + batch = dl.dataloader.collate_fn(batch_tokenized) # type: ignore + + for k, v in batch.items(): # type: ignore + if isinstance(v, torch.Tensor): + batch[k] = 
diff --git a/tests/callbacks/test_loss_perp_v_len_callback.py b/tests/callbacks/test_loss_perp_v_len_callback.py
new file mode 100644
index 0000000000..46bde1c2f1
--- /dev/null
+++ b/tests/callbacks/test_loss_perp_v_len_callback.py
@@ -0,0 +1,174 @@
+# Copyright 2024 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+from unittest.mock import MagicMock
+
+import pytest
+import torch
+import transformers
+from composer.core import State
+from composer.core.precision import get_precision_context
+from composer.devices import DeviceGPU
+from composer.loggers import Logger
+from composer.utils import get_device
+from omegaconf import DictConfig
+from omegaconf import OmegaConf as om
+
+from llmfoundry import registry
+from llmfoundry.data.text_data import (
+    StreamingTextDataset,
+    build_text_dataloader,
+)
+from llmfoundry.utils.builders import build_composer_model
+from llmfoundry.utils.registry_utils import construct_from_registry
+
+
+@pytest.mark.gpu
+@pytest.mark.parametrize('shift_labels', [True, False])
+def test_loss_perp_v_len_callback(
+    shift_labels: bool,
+    monkeypatch: pytest.MonkeyPatch,
+):
+    try:
+        from flash_attn.losses.cross_entropy import CrossEntropyLoss as FusedCrossEntropyLoss  # type: ignore # isort: skip
+    except ImportError:
+        pytest.skip('Fused cross entropy was not installed')
+
+    composer_device = get_device(None)
+
+    model_max_length = 12
+
+    gptt = transformers.AutoTokenizer.from_pretrained('gpt2')
+    gptt.pad_token_id = gptt.eos_token_id
+    gptt.model_max_length = model_max_length
+    gptt.padding_side = 'right'
+
+    cfg = {
+        'dataset': {
+            'local': 'dummy-path',
+            'remote': 'dummy-path',
+            'split': 'train',
+            'max_seq_len': model_max_length,
+            'shuffle': True,
+            'shuffle_seed': 0,
+            'eos_token_id': gptt.eos_token_id,
+        },
+        'drop_last': False,
+        'num_workers': 0,
+        'prefetch_factor': None,
+        'pin_memory': False,
+        'persistent_workers': False,
+        'timeout': 0,
+    }
+
+    ds_mock = MagicMock(spec=StreamingTextDataset)
+    ds_mock.tokenizer = gptt
+    monkeypatch.setattr(
+        'llmfoundry.data.text_data.StreamingTextDataset',
+        lambda *args, **kwargs: ds_mock,
+    )
+    dl = build_text_dataloader(
+        **cfg,
+        tokenizer=gptt,
+        device_batch_size=1,
+    )
+
+    batch_strings = [
+        'hello hey' + gptt.eos_token + ' the quick brown fox jumps',
+    ]
+
+    batch_tokenized = [gptt(b, padding=False) for b in batch_strings]
+
+    batch_tokenized = [b['input_ids'] for b in batch_tokenized]
+
+    batch = dl.dataloader.collate_fn(batch_tokenized)  # type: ignore
+
+    for k, v in batch.items():  # type: ignore
+        if isinstance(v, torch.Tensor):
+            batch[k] = composer_device.tensor_to_device(v)  # type: ignore
+
+    attention_impl = 'flash'
+
+    conf_path = 'scripts/train/yamls/pretrain/testing.yaml'
+    with open(conf_path) as f:
+        test_cfg = om.load(f)
+
+    assert isinstance(test_cfg, DictConfig)
+
+    attn_config = {
+        'attn_type': 'grouped_query_attention',
+        'attn_impl': attention_impl,
+        'attn_uses_sequence_id': True,
+        'alibi': False,
+        'rope': True,
+        'rope_theta': 10000,
+        'rope_impl': 'dail',
+        'rope_dail_config': {
+            'type': 'original',
+            'pos_idx_in_fp32': True,
+            'xpos_scale_base': 512,
+        },
+    }
+    attn_config['kv_n_heads'] = 4
+
+    test_cfg.model.init_device = 'cpu'
+    test_cfg.model.init_config = {
+        'name': 'baseline_',
+        'init_std': 0.02,
+    }
+    test_cfg.model.attn_config = attn_config
+    test_cfg.model.n_layers = 2
+    test_cfg.model.n_heads = 8
+    test_cfg.model.d_model = 128
+
+    test_cfg = dict(om.to_container(test_cfg, resolve=True))  # type: ignore
+
+    model = build_composer_model(
+        name=test_cfg['model']['name'],
+        cfg=test_cfg['model'],
+        tokenizer=gptt,
+    )
+    assert model.shift_labels is True
+    model.shift_labels = shift_labels
+
+    model = composer_device.module_to_device(model)
+
+    with get_precision_context('amp_bf16'):
+        output = model(batch)
+        loss = model.loss(output, batch)
+
+    assert isinstance(loss, torch.Tensor)
+
+    callback = construct_from_registry(
+        name='loss_perp_v_len',
+        registry=registry.callbacks,
+        kwargs={
+            'log_batch_interval': 100,
+            'compute_batch_interval': 1,
+        },
+    )
+
+    callback.loss_perp_v_len = callback.loss_perp_v_len.to(loss.device)
+    state = State(
+        model=model,
+        rank_zero_seed=0,
+        run_name='test_state',
+        device=DeviceGPU(),
+    )
+    logger = Logger(state)
+    state.outputs = output
+    state.batch = batch
+
+    callback.after_backward(state, logger)
+    current_metric_dict = callback.loss_perp_v_len.compute()
+
+    # The count-weighted average of the per-position losses must recover the
+    # overall batch loss, for both the raw and the per-sequence-id views.
+    mean_loss_seq_id = torch.sum(
+        current_metric_dict['mean_loss_seq_id_v_len'] *
+        current_metric_dict['sum_length_seq_id'],
+    ) / torch.sum(current_metric_dict['sum_length_seq_id'])
+    mean_loss = torch.sum(
+        current_metric_dict['mean_loss_v_len'] *
+        current_metric_dict['sum_length'],
+    ) / torch.sum(current_metric_dict['sum_length'])
+    assert torch.allclose(loss, mean_loss_seq_id)
+    assert torch.allclose(loss, mean_loss)
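One read-back sketch to close the loop, since the tables only live in MLflow. This is not part of the PR: `mlflow.load_table` is real MLflow API, but the exact artifact filename is an assumption about how Composer's `MLFlowLogger.log_table` persists the table, and `'<run_id>'` is a placeholder:

```python
import mlflow

# ASSUMPTION: Composer stores each log_table() call as a JSON artifact named
# after the table the callback passes in, i.e. '<name>.json'.
df = mlflow.load_table(
    artifact_file='metrics/train/LossPerpVLenTable/mean_loss_v_len.json',
    run_ids=['<run_id>'],  # placeholder
)

# One row per computed batch; 'batch_index' is the last column and the
# remaining columns are 'context_length_0' ... 'context_length_{L-1}'.
latest = df.sort_values('batch_index').iloc[-1]
print(latest.filter(like='context_length_'))
```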