From e2678b9475a90c5ef56a3ed4fb35d81a651848f7 Mon Sep 17 00:00:00 2001 From: Saaketh Date: Tue, 11 Jun 2024 14:55:22 -0700 Subject: [PATCH 01/13] yo --- composer/metrics/nlp.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/composer/metrics/nlp.py b/composer/metrics/nlp.py index e6877292cf..81d5d3802e 100644 --- a/composer/metrics/nlp.py +++ b/composer/metrics/nlp.py @@ -83,7 +83,14 @@ def __init__(self, dist_sync_on_step: bool = False, ignore_index: int = -100): super().__init__(dist_sync_on_step=dist_sync_on_step) self.ignore_index = ignore_index - self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=ignore_index, reduction='sum') + try: + from flash_attn.losses.cross_entropy import CrossEntropyLoss as FusedCrossEntropyLoss + self.loss_fn = FusedCrossEntropyLoss(ignore_index=ignore_index, reduction='sum') + except ImportError: + log.debug( + 'Package `flash_attn` not installed. Using torch.nn.CrossEntropyLoss ' + + 'to compute LanguageCrossEntropy metric, which will be slower.', + ) self.add_state('sum_loss', default=torch.tensor(0.), dist_reduce_fx='sum') self.add_state('total_items', default=torch.tensor(0), dist_reduce_fx='sum') From 502b8281cde88701ce90196cc7d244710ae0ce7b Mon Sep 17 00:00:00 2001 From: Saaketh Date: Tue, 11 Jun 2024 17:39:48 -0700 Subject: [PATCH 02/13] slam --- composer/metrics/nlp.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/composer/metrics/nlp.py b/composer/metrics/nlp.py index 81d5d3802e..15212e1cd2 100644 --- a/composer/metrics/nlp.py +++ b/composer/metrics/nlp.py @@ -86,11 +86,16 @@ def __init__(self, dist_sync_on_step: bool = False, ignore_index: int = -100): try: from flash_attn.losses.cross_entropy import CrossEntropyLoss as FusedCrossEntropyLoss self.loss_fn = FusedCrossEntropyLoss(ignore_index=ignore_index, reduction='sum') + log.debug( + 'Found `flash_attn` installation. Using CrossEntropyLoss from `flash_attn`' + + 'to compute LanguageCrossEntropy metric, which will be faster.', + ) except ImportError: log.debug( 'Package `flash_attn` not installed. Using torch.nn.CrossEntropyLoss ' + 'to compute LanguageCrossEntropy metric, which will be slower.', ) + self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=ignore_index, reduction='sum') self.add_state('sum_loss', default=torch.tensor(0.), dist_reduce_fx='sum') self.add_state('total_items', default=torch.tensor(0), dist_reduce_fx='sum') From ed0a21926a7bf00f6a577721c0fd1452025417f2 Mon Sep 17 00:00:00 2001 From: Saaketh Date: Fri, 14 Jun 2024 15:37:52 -0700 Subject: [PATCH 03/13] cuda --- composer/metrics/nlp.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/composer/metrics/nlp.py b/composer/metrics/nlp.py index 15212e1cd2..d619297624 100644 --- a/composer/metrics/nlp.py +++ b/composer/metrics/nlp.py @@ -85,11 +85,18 @@ def __init__(self, dist_sync_on_step: bool = False, ignore_index: int = -100): self.ignore_index = ignore_index try: from flash_attn.losses.cross_entropy import CrossEntropyLoss as FusedCrossEntropyLoss - self.loss_fn = FusedCrossEntropyLoss(ignore_index=ignore_index, reduction='sum') - log.debug( - 'Found `flash_attn` installation. Using CrossEntropyLoss from `flash_attn`' + - 'to compute LanguageCrossEntropy metric, which will be faster.', - ) + if torch.cuda.is_available(): + self.loss_fn = FusedCrossEntropyLoss(ignore_index=ignore_index, reduction='sum') + log.debug( + 'Found `flash_attn` installation. Using CrossEntropyLoss from `flash_attn`' + + 'to compute LanguageCrossEntropy metric, which will be faster.', + ) + else: + log.debug( + 'No cuda devices available. Using torch.nn.CrossEntropyLoss ' + + 'to compute LanguageCrossEntropy metric.', + ) + self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=ignore_index, reduction='sum') except ImportError: log.debug( 'Package `flash_attn` not installed. Using torch.nn.CrossEntropyLoss ' + From 48cfac0c3b0b7a81db9cd7e287e4daf2d6da03f8 Mon Sep 17 00:00:00 2001 From: Saaketh Date: Fri, 14 Jun 2024 16:05:54 -0700 Subject: [PATCH 04/13] cuda checks --- .github/workflows/pr-cpu.yaml | 2 +- composer/metrics/nlp.py | 31 +++++++++++++++---------------- 2 files changed, 16 insertions(+), 17 deletions(-) diff --git a/.github/workflows/pr-cpu.yaml b/.github/workflows/pr-cpu.yaml index 1bdb383823..12f471749e 100644 --- a/.github/workflows/pr-cpu.yaml +++ b/.github/workflows/pr-cpu.yaml @@ -22,7 +22,7 @@ jobs: markers: not daily and not remote and not gpu and not doctest pytest_command: coverage run -m pytest - name: cpu-3.11-2.3 - container: mosaicml/pytorch:2.3.1_cu121-python3.11-ubuntu20.04 + container: mosaicml/pytorch:2.3.1_cpu-python3.11-ubuntu20.04 markers: not daily and not remote and not gpu and not doctest pytest_command: coverage run -m pytest - name: cpu-doctest diff --git a/composer/metrics/nlp.py b/composer/metrics/nlp.py index d619297624..c1562e5936 100644 --- a/composer/metrics/nlp.py +++ b/composer/metrics/nlp.py @@ -83,26 +83,21 @@ def __init__(self, dist_sync_on_step: bool = False, ignore_index: int = -100): super().__init__(dist_sync_on_step=dist_sync_on_step) self.ignore_index = ignore_index + self.flash_loss_fn = None try: from flash_attn.losses.cross_entropy import CrossEntropyLoss as FusedCrossEntropyLoss + log.debug( + 'Found `flash_attn` installation. Using CrossEntropyLoss from `flash_attn`' + + 'to compute LanguageCrossEntropy metric for CUDA tensors, which will be faster.', + ) + self.flash_loss_fn = FusedCrossEntropyLoss(ignore_index=ignore_index, reduction='sum') + except ImportError: if torch.cuda.is_available(): - self.loss_fn = FusedCrossEntropyLoss(ignore_index=ignore_index, reduction='sum') - log.debug( - 'Found `flash_attn` installation. Using CrossEntropyLoss from `flash_attn`' + - 'to compute LanguageCrossEntropy metric, which will be faster.', - ) - else: log.debug( - 'No cuda devices available. Using torch.nn.CrossEntropyLoss ' + - 'to compute LanguageCrossEntropy metric.', + 'Package `flash_attn` not installed. Using torch.nn.CrossEntropyLoss ' + + 'to compute LanguageCrossEntropy metric for CUDA tensors, which will be slower.', ) - self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=ignore_index, reduction='sum') - except ImportError: - log.debug( - 'Package `flash_attn` not installed. Using torch.nn.CrossEntropyLoss ' + - 'to compute LanguageCrossEntropy metric, which will be slower.', - ) - self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=ignore_index, reduction='sum') + self.torch_loss_fn = torch.nn.CrossEntropyLoss(ignore_index=ignore_index, reduction='sum') self.add_state('sum_loss', default=torch.tensor(0.), dist_reduce_fx='sum') self.add_state('total_items', default=torch.tensor(0), dist_reduce_fx='sum') @@ -123,7 +118,11 @@ def update(self, output: Union[Mapping, Tensor], target: Tensor) -> None: target = target.view(-1) logits = logits.view(target.shape[0], -1) - losses = self.loss_fn(logits, target) + # Use Flash attn's CE loss function, if available, if inputs are both CUDA tensors. + if self.flash_loss_fn is not None and target.is_cuda and logits.is_cuda: + losses = self.flash_loss_fn(logits, target) + else: + losses = self.torch_loss_fn(logits, target) total_items = (target != self.ignore_index).sum() self.total_items += total_items #type: ignore (third-party) From fb1268e3e942b3a5015aa6f5601c3e2ab72afb1e Mon Sep 17 00:00:00 2001 From: Saaketh Date: Fri, 14 Jun 2024 16:29:32 -0700 Subject: [PATCH 05/13] test --- tests/metrics/test_nlp_metrics.py | 82 +++++++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) diff --git a/tests/metrics/test_nlp_metrics.py b/tests/metrics/test_nlp_metrics.py index 7fe854bd96..8bc2e5cbcf 100644 --- a/tests/metrics/test_nlp_metrics.py +++ b/tests/metrics/test_nlp_metrics.py @@ -14,6 +14,7 @@ LanguagePerplexity, MaskedAccuracy, ) +from tests.common import device @pytest.mark.parametrize('ignore_index', [-100]) @@ -50,12 +51,93 @@ def test_masked_accuracy(ignore_index, num_classes): assert abs(final_acc - (1.0 / num_classes)) < 0.02 +@device('cpu', 'gpu') @pytest.mark.parametrize('ignore_index', [-100]) @pytest.mark.parametrize('batch_size', [1e2, 1e3]) @pytest.mark.parametrize('sequence_length', [128]) @pytest.mark.parametrize('num_classes', [2, 10]) @pytest.mark.parametrize('minibatch_size', [56, 256, 768]) +@pytest.mark.parametrize('tensor_device', ['cpu', 'gpu']) def test_cross_entropy( + device: str, + batch_size: float, + ignore_index: Optional[int], + sequence_length: int, + num_classes: int, + minibatch_size: int, + tensor_device: str, +): + """Sanity check to make sure that batched CrossEntropyLoss matches the expected performance. + + Generates a predicted distribution from a normal distribution, and a ground truth from a normal distribution. + Verifies Cross Entropy Loss against the baseline performance. + + Args: + device (str): the device to run the test on + batch_size (int): how many samples are in each batch + ignore_index (Optional[int]): if present, the class index to ignore in accuracy calculations. + sequence_length (int): the length of the generated sequence + num_classes (int): the number of classes in the classification task + minibatch_size (int): the minibatch size to simulate for model predictions + tensor_device (str): which device the input tensors to the metric are on + """ + + if device == 'cpu' and tensor_device == 'gpu': + pytest.skip('Skipping test that would try to use GPU tensors when only CPU is available.') + + batch_size = int(batch_size) + generated_preds = torch.randn((batch_size, sequence_length, num_classes)) + generated_true = torch.randint(low=0, high=num_classes, size=(batch_size, sequence_length)) + + assert ignore_index is not None + torchmetrics_xent = LanguageCrossEntropy(dist_sync_on_step=False, ignore_index=ignore_index) + ce_with_keys_metric = LanguageCrossEntropy(dist_sync_on_step=False, ignore_index=ignore_index) + + if device == 'gpu': + assert torchmetrics_xent.flash_loss_fn is not None + + labels_mask = torch.rand((batch_size, sequence_length)) + labels_mask[labels_mask > 0.8] = 1 + labels_mask[labels_mask <= 0.8] = 0 + labels_mask = labels_mask.bool() + generated_true[labels_mask] = ignore_index + + num_batches = math.ceil(batch_size / minibatch_size) + for batch_idx in range(num_batches): + begin_idx = (batch_idx * minibatch_size) + end_idx = ((batch_idx + 1) * minibatch_size) + preds_subset = generated_preds[begin_idx:end_idx] + true_subset = generated_true[begin_idx:end_idx] + + if tensor_device == 'cpu': + preds_subset = preds_subset.cpu() + true_subset = true_subset.cpu() + if tensor_device == 'gpu': + preds_subset = preds_subset.cuda() + true_subset = true_subset.cuda() + + torchmetrics_xent.update(preds_subset, true_subset) + ce_with_keys_metric.update( + { + 'logits': preds_subset.view(-1, num_classes), + 'loss': cross_entropy(preds_subset.view(-1, num_classes), true_subset.view(-1)), + }, + true_subset.view(-1), + ) + + torchmetrics_loss = torchmetrics_xent.compute() + ce_with_keys_loss = ce_with_keys_metric.compute() + correct_loss = cross_entropy(generated_preds.view(-1, num_classes), generated_true.view(-1)) + assert torchmetrics_loss == ce_with_keys_loss + assert torch.isclose(correct_loss, torchmetrics_loss) + + +@pytest.mark.parametrize('ignore_index', [-100]) +@pytest.mark.parametrize('batch_size', [1e2, 1e3]) +@pytest.mark.parametrize('sequence_length', [128]) +@pytest.mark.parametrize('num_classes', [2, 10]) +@pytest.mark.parametrize('minibatch_size', [56, 256, 768]) +def test_torch_cpu_cross_entropy( batch_size: float, ignore_index: Optional[int], sequence_length: int, From 4c6b6ae598b58846f8c258cc37848cee33e3d99d Mon Sep 17 00:00:00 2001 From: Saaketh Date: Fri, 14 Jun 2024 20:48:45 -0700 Subject: [PATCH 06/13] fix_test --- tests/metrics/test_nlp_metrics.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/metrics/test_nlp_metrics.py b/tests/metrics/test_nlp_metrics.py index 8bc2e5cbcf..9b198003d3 100644 --- a/tests/metrics/test_nlp_metrics.py +++ b/tests/metrics/test_nlp_metrics.py @@ -93,6 +93,13 @@ def test_cross_entropy( torchmetrics_xent = LanguageCrossEntropy(dist_sync_on_step=False, ignore_index=ignore_index) ce_with_keys_metric = LanguageCrossEntropy(dist_sync_on_step=False, ignore_index=ignore_index) + if tensor_device == 'cpu': + torchmetrics_xent = torchmetrics_xent.to('cpu') + ce_with_keys_metric = ce_with_keys_metric.to('cpu') + elif tensor_device == 'gpu': + torchmetrics_xent = torchmetrics_xent.to('cuda') + ce_with_keys_metric = ce_with_keys_metric.to('cuda') + if device == 'gpu': assert torchmetrics_xent.flash_loss_fn is not None @@ -112,7 +119,7 @@ def test_cross_entropy( if tensor_device == 'cpu': preds_subset = preds_subset.cpu() true_subset = true_subset.cpu() - if tensor_device == 'gpu': + elif tensor_device == 'gpu': preds_subset = preds_subset.cuda() true_subset = true_subset.cuda() From 0084af589db742c4520fd350037b59a1e1bdefd2 Mon Sep 17 00:00:00 2001 From: Saaketh Date: Sun, 16 Jun 2024 12:36:28 -0700 Subject: [PATCH 07/13] gloo --- composer/devices/device_gpu.py | 2 ++ tests/metrics/test_nlp_metrics.py | 1 + 2 files changed, 3 insertions(+) diff --git a/composer/devices/device_gpu.py b/composer/devices/device_gpu.py index 19cb0a774a..9b9bb9bb35 100644 --- a/composer/devices/device_gpu.py +++ b/composer/devices/device_gpu.py @@ -42,6 +42,8 @@ def __init__( ): if not torch.cuda.is_available(): raise ValueError('DeviceGPU cannot be created as torch.cuda is not available.') + if torch.distributed.is_gloo_available(): + DeviceGPU.dist_backend = 'cuda:nccl,gpu:gloo' if device_id is None: device_id = dist.get_local_rank() self._device = torch.device(f'cuda:{device_id}') diff --git a/tests/metrics/test_nlp_metrics.py b/tests/metrics/test_nlp_metrics.py index 9b198003d3..bf4b6a25bf 100644 --- a/tests/metrics/test_nlp_metrics.py +++ b/tests/metrics/test_nlp_metrics.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import math +import time from typing import Optional import pytest From f8ee4c36f02c396ae7c102d69fbe01c0bf88eba5 Mon Sep 17 00:00:00 2001 From: Saaketh Date: Sun, 16 Jun 2024 12:40:58 -0700 Subject: [PATCH 08/13] gloo --- composer/devices/device_gpu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/composer/devices/device_gpu.py b/composer/devices/device_gpu.py index 9b9bb9bb35..ce5cf38744 100644 --- a/composer/devices/device_gpu.py +++ b/composer/devices/device_gpu.py @@ -43,7 +43,7 @@ def __init__( if not torch.cuda.is_available(): raise ValueError('DeviceGPU cannot be created as torch.cuda is not available.') if torch.distributed.is_gloo_available(): - DeviceGPU.dist_backend = 'cuda:nccl,gpu:gloo' + DeviceGPU.dist_backend = 'cuda:nccl,cpu:gloo' if device_id is None: device_id = dist.get_local_rank() self._device = torch.device(f'cuda:{device_id}') From 7f05db7058158b42b0103c33b9bb80367457810c Mon Sep 17 00:00:00 2001 From: Saaketh Date: Mon, 17 Jun 2024 10:43:59 -0700 Subject: [PATCH 09/13] lint --- composer/devices/device_gpu.py | 3 ++- tests/checkpoint/test_state_dict.py | 5 ++++- tests/metrics/test_nlp_metrics.py | 1 - 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/composer/devices/device_gpu.py b/composer/devices/device_gpu.py index ce5cf38744..401368576e 100644 --- a/composer/devices/device_gpu.py +++ b/composer/devices/device_gpu.py @@ -12,6 +12,7 @@ import torch.backends.cudnn import torch.cuda import torch.cuda.amp +import torch.distributed as torch_dist import torch.utils.data from composer.devices.device import Device @@ -42,7 +43,7 @@ def __init__( ): if not torch.cuda.is_available(): raise ValueError('DeviceGPU cannot be created as torch.cuda is not available.') - if torch.distributed.is_gloo_available(): + if torch_dist.is_gloo_available(): DeviceGPU.dist_backend = 'cuda:nccl,cpu:gloo' if device_id is None: device_id = dist.get_local_rank() diff --git a/tests/checkpoint/test_state_dict.py b/tests/checkpoint/test_state_dict.py index af0ca34961..27ce77b660 100644 --- a/tests/checkpoint/test_state_dict.py +++ b/tests/checkpoint/test_state_dict.py @@ -530,7 +530,10 @@ def test_get_metadata_sharded_model(model_type: str, tensor_type: str, world_siz assert 'model_name' in metadata_sd assert 'dist_backend' in metadata_sd - assert metadata_sd['dist_backend'] == 'nccl' + if torch.distributed.is_gloo_available(): + assert metadata_sd['dist_backend'] == 'cuda:nccl,cpu:gloo' + else: + assert metadata_sd['dist_backend'] == 'nccl' @pytest.mark.filterwarnings('ignore:SWA has') diff --git a/tests/metrics/test_nlp_metrics.py b/tests/metrics/test_nlp_metrics.py index bf4b6a25bf..9b198003d3 100644 --- a/tests/metrics/test_nlp_metrics.py +++ b/tests/metrics/test_nlp_metrics.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 import math -import time from typing import Optional import pytest From 636b1571a5f82e1e4a31888f94daf09fe0d6fa3f Mon Sep 17 00:00:00 2001 From: Saaketh Date: Mon, 17 Jun 2024 11:13:38 -0700 Subject: [PATCH 10/13] lint --- tests/checkpoint/test_state_dict.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/checkpoint/test_state_dict.py b/tests/checkpoint/test_state_dict.py index 27ce77b660..bd14154dc9 100644 --- a/tests/checkpoint/test_state_dict.py +++ b/tests/checkpoint/test_state_dict.py @@ -7,6 +7,7 @@ import pytest import torch +import torch.distributed as torch_dist from packaging import version from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.optim import adam @@ -530,7 +531,7 @@ def test_get_metadata_sharded_model(model_type: str, tensor_type: str, world_siz assert 'model_name' in metadata_sd assert 'dist_backend' in metadata_sd - if torch.distributed.is_gloo_available(): + if torch_dist.is_gloo_available(): assert metadata_sd['dist_backend'] == 'cuda:nccl,cpu:gloo' else: assert metadata_sd['dist_backend'] == 'nccl' From e2e56ab84991ba9f88723f1547ce2db7eab1fb23 Mon Sep 17 00:00:00 2001 From: Saaketh Date: Mon, 5 Aug 2024 20:48:15 -0400 Subject: [PATCH 11/13] yo --- composer/devices/device_gpu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/composer/devices/device_gpu.py b/composer/devices/device_gpu.py index c17dee3a3a..0d8fc25041 100644 --- a/composer/devices/device_gpu.py +++ b/composer/devices/device_gpu.py @@ -44,7 +44,7 @@ def __init__( ): if not torch.cuda.is_available(): raise ValueError('DeviceGPU cannot be created as torch.cuda is not available.') - if torch_dist.is_gloo_available() and version.parse(torch.__version__) >= version.parse('2.3.0'): + if torch_dist.is_gloo_available() and #version.parse(torch.__version__) >= version.parse('2.3.0'): # Composer checkpoint load / save from before torch 2.3.0 is not compatible with gloo + nccl backends. DeviceGPU.dist_backend = 'cuda:nccl,cpu:gloo' if device_id is None: From 7b2ae26eea99e92e670a1ac2cc8e023045c23a61 Mon Sep 17 00:00:00 2001 From: Saaketh Date: Mon, 5 Aug 2024 20:48:58 -0400 Subject: [PATCH 12/13] yo --- composer/devices/device_gpu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/composer/devices/device_gpu.py b/composer/devices/device_gpu.py index 0d8fc25041..c17dee3a3a 100644 --- a/composer/devices/device_gpu.py +++ b/composer/devices/device_gpu.py @@ -44,7 +44,7 @@ def __init__( ): if not torch.cuda.is_available(): raise ValueError('DeviceGPU cannot be created as torch.cuda is not available.') - if torch_dist.is_gloo_available() and #version.parse(torch.__version__) >= version.parse('2.3.0'): + if torch_dist.is_gloo_available() and version.parse(torch.__version__) >= version.parse('2.3.0'): # Composer checkpoint load / save from before torch 2.3.0 is not compatible with gloo + nccl backends. DeviceGPU.dist_backend = 'cuda:nccl,cpu:gloo' if device_id is None: From 3ee41e7390f8ae2108944f24ac2d737f5ee4d716 Mon Sep 17 00:00:00 2001 From: Saaketh Date: Mon, 5 Aug 2024 20:52:03 -0400 Subject: [PATCH 13/13] cputest --- tests/metrics/test_nlp_metrics.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/metrics/test_nlp_metrics.py b/tests/metrics/test_nlp_metrics.py index 9b198003d3..0f6d989102 100644 --- a/tests/metrics/test_nlp_metrics.py +++ b/tests/metrics/test_nlp_metrics.py @@ -6,6 +6,7 @@ import pytest import torch +from packaging import version from torch.nn.functional import cross_entropy from composer.metrics.nlp import ( @@ -82,8 +83,11 @@ def test_cross_entropy( tensor_device (str): which device the input tensors to the metric are on """ - if device == 'cpu' and tensor_device == 'gpu': - pytest.skip('Skipping test that would try to use GPU tensors when only CPU is available.') + if device == 'cpu': + if tensor_device == 'gpu': + pytest.skip('Skipping test that would try to use GPU tensors when only CPU is available.') + if version.parse(torch.__version__) < version.parse('2.3.0'): + pytest.skip('Skipping test that would try to use gloo + nccl backend on torch < 2.3.0.') batch_size = int(batch_size) generated_preds = torch.randn((batch_size, sequence_length, num_classes))