diff --git a/composer/metrics/nlp.py b/composer/metrics/nlp.py
index e6877292cf..c1562e5936 100644
--- a/composer/metrics/nlp.py
+++ b/composer/metrics/nlp.py
@@ -83,7 +83,21 @@
     def __init__(self, dist_sync_on_step: bool = False, ignore_index: int = -100):
         super().__init__(dist_sync_on_step=dist_sync_on_step)
         self.ignore_index = ignore_index
-        self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=ignore_index, reduction='sum')
+        self.flash_loss_fn = None
+        try:
+            from flash_attn.losses.cross_entropy import CrossEntropyLoss as FusedCrossEntropyLoss
+            log.debug(
+                'Found `flash_attn` installation. Using CrossEntropyLoss from `flash_attn` ' +
+                'to compute LanguageCrossEntropy metric for CUDA tensors, which will be faster.',
+            )
+            self.flash_loss_fn = FusedCrossEntropyLoss(ignore_index=ignore_index, reduction='sum')
+        except ImportError:
+            if torch.cuda.is_available():
+                log.debug(
+                    'Package `flash_attn` not installed. Using torch.nn.CrossEntropyLoss ' +
+                    'to compute LanguageCrossEntropy metric for CUDA tensors, which will be slower.',
+                )
+        self.torch_loss_fn = torch.nn.CrossEntropyLoss(ignore_index=ignore_index, reduction='sum')
         self.add_state('sum_loss', default=torch.tensor(0.), dist_reduce_fx='sum')
         self.add_state('total_items', default=torch.tensor(0), dist_reduce_fx='sum')
 
@@ -104,7 +118,11 @@ def update(self, output: Union[Mapping, Tensor], target: Tensor) -> None:
 
         target = target.view(-1)
         logits = logits.view(target.shape[0], -1)
-        losses = self.loss_fn(logits, target)
+        # Use flash_attn's CE loss function, if available and both inputs are CUDA tensors.
+        if self.flash_loss_fn is not None and target.is_cuda and logits.is_cuda:
+            losses = self.flash_loss_fn(logits, target)
+        else:
+            losses = self.torch_loss_fn(logits, target)
 
         total_items = (target != self.ignore_index).sum()
         self.total_items += total_items  #type: ignore (third-party)
diff --git a/tests/metrics/test_nlp_metrics.py b/tests/metrics/test_nlp_metrics.py
index 7fe854bd96..0f6d989102 100644
--- a/tests/metrics/test_nlp_metrics.py
+++ b/tests/metrics/test_nlp_metrics.py
@@ -6,6 +6,7 @@
 
 import pytest
 import torch
+from packaging import version
 from torch.nn.functional import cross_entropy
 
 from composer.metrics.nlp import (
@@ -14,6 +15,7 @@
     LanguagePerplexity,
     MaskedAccuracy,
 )
+from tests.common import device
 
 
 @pytest.mark.parametrize('ignore_index', [-100])
@@ -50,12 +52,103 @@ def test_masked_accuracy(ignore_index, num_classes):
     assert abs(final_acc - (1.0 / num_classes)) < 0.02
 
 
+@device('cpu', 'gpu')
 @pytest.mark.parametrize('ignore_index', [-100])
 @pytest.mark.parametrize('batch_size', [1e2, 1e3])
 @pytest.mark.parametrize('sequence_length', [128])
 @pytest.mark.parametrize('num_classes', [2, 10])
 @pytest.mark.parametrize('minibatch_size', [56, 256, 768])
+@pytest.mark.parametrize('tensor_device', ['cpu', 'gpu'])
 def test_cross_entropy(
+    device: str,
+    batch_size: float,
+    ignore_index: Optional[int],
+    sequence_length: int,
+    num_classes: int,
+    minibatch_size: int,
+    tensor_device: str,
+):
+    """Sanity check to make sure that batched CrossEntropyLoss matches the expected performance.
+
+    Generates predicted logits from a normal distribution and ground-truth labels from a uniform distribution over classes.
+    Verifies the metric's cross entropy against a baseline computed with torch.nn.functional.cross_entropy.
+
+    Args:
+        device (str): the device to run the test on
+        batch_size (float): how many samples are in each batch
+        ignore_index (Optional[int]): if present, the class index to ignore in accuracy calculations
+        sequence_length (int): the length of the generated sequence
+        num_classes (int): the number of classes in the classification task
+        minibatch_size (int): the minibatch size to simulate for model predictions
+        tensor_device (str): which device the input tensors to the metric are on
+    """
+
+    if device == 'cpu':
+        if tensor_device == 'gpu':
+            pytest.skip('Skipping test that would try to use GPU tensors when only CPU is available.')
+        if version.parse(torch.__version__) < version.parse('2.3.0'):
+            pytest.skip('Skipping test that would try to use gloo + nccl backend on torch < 2.3.0.')
+
+    batch_size = int(batch_size)
+    generated_preds = torch.randn((batch_size, sequence_length, num_classes))
+    generated_true = torch.randint(low=0, high=num_classes, size=(batch_size, sequence_length))
+
+    assert ignore_index is not None
+    torchmetrics_xent = LanguageCrossEntropy(dist_sync_on_step=False, ignore_index=ignore_index)
+    ce_with_keys_metric = LanguageCrossEntropy(dist_sync_on_step=False, ignore_index=ignore_index)
+
+    if tensor_device == 'cpu':
+        torchmetrics_xent = torchmetrics_xent.to('cpu')
+        ce_with_keys_metric = ce_with_keys_metric.to('cpu')
+    elif tensor_device == 'gpu':
+        torchmetrics_xent = torchmetrics_xent.to('cuda')
+        ce_with_keys_metric = ce_with_keys_metric.to('cuda')
+
+    if device == 'gpu':
+        assert torchmetrics_xent.flash_loss_fn is not None
+
+    labels_mask = torch.rand((batch_size, sequence_length))
+    labels_mask[labels_mask > 0.8] = 1
+    labels_mask[labels_mask <= 0.8] = 0
+    labels_mask = labels_mask.bool()
+    generated_true[labels_mask] = ignore_index
+
+    num_batches = math.ceil(batch_size / minibatch_size)
+    for batch_idx in range(num_batches):
+        begin_idx = (batch_idx * minibatch_size)
+        end_idx = ((batch_idx + 1) * minibatch_size)
+        preds_subset = generated_preds[begin_idx:end_idx]
+        true_subset = generated_true[begin_idx:end_idx]
+
+        if tensor_device == 'cpu':
+            preds_subset = preds_subset.cpu()
+            true_subset = true_subset.cpu()
+        elif tensor_device == 'gpu':
+            preds_subset = preds_subset.cuda()
+            true_subset = true_subset.cuda()
+
+        torchmetrics_xent.update(preds_subset, true_subset)
+        ce_with_keys_metric.update(
+            {
+                'logits': preds_subset.view(-1, num_classes),
+                'loss': cross_entropy(preds_subset.view(-1, num_classes), true_subset.view(-1)),
+            },
+            true_subset.view(-1),
+        )
+
+    torchmetrics_loss = torchmetrics_xent.compute()
+    ce_with_keys_loss = ce_with_keys_metric.compute()
+    correct_loss = cross_entropy(generated_preds.view(-1, num_classes), generated_true.view(-1))
+    assert torchmetrics_loss == ce_with_keys_loss
+    assert torch.isclose(correct_loss, torchmetrics_loss)
+
+
+@pytest.mark.parametrize('ignore_index', [-100])
+@pytest.mark.parametrize('batch_size', [1e2, 1e3])
+@pytest.mark.parametrize('sequence_length', [128])
+@pytest.mark.parametrize('num_classes', [2, 10])
+@pytest.mark.parametrize('minibatch_size', [56, 256, 768])
+def test_torch_cpu_cross_entropy(
     batch_size: float,
     ignore_index: Optional[int],
     sequence_length: int,
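
Note for reviewers (not part of the patch): the following is a minimal standalone sketch of the fallback path this diff introduces. It drives LanguageCrossEntropy with CPU tensors, where update() should route through torch.nn.CrossEntropyLoss even if flash_attn is installed, since the fused kernel is only used when both logits and targets are CUDA tensors. Shapes mirror those in test_cross_entropy; the specific sizes are arbitrary.

# Sketch only: exercises the CPU fallback of LanguageCrossEntropy.
import torch

from composer.metrics.nlp import LanguageCrossEntropy

batch_size, sequence_length, num_classes = 8, 16, 10
metric = LanguageCrossEntropy(ignore_index=-100)

logits = torch.randn(batch_size, sequence_length, num_classes)
targets = torch.randint(0, num_classes, (batch_size, sequence_length))
targets[torch.rand(batch_size, sequence_length) > 0.8] = -100  # mask ~20% of positions

# update() flattens targets to (batch * seq,) and logits to (batch * seq, num_classes),
# then accumulates the summed per-token loss and the count of non-ignored tokens.
metric.update(logits, targets)
print(metric.compute())  # mean loss over non-ignored tokens

Because flash_loss_fn is left as None when the import fails, the CUDA-availability check in __init__ only affects the debug log; the runtime dispatch in update() depends solely on whether flash_loss_fn exists and the inputs are CUDA tensors.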