diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 45b45992bf425a..92025cb979d331 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -82,6 +82,7 @@
 )
 from .trainer_pt_utils import (
     DistributedTensorGatherer,
+    EvalLoopContainer,
     IterableDatasetShard,
     LabelSmoother,
     LayerWiseDummyOptimizer,
@@ -3627,20 +3628,14 @@ def evaluation_loop(
             self._past = None
 
         # Initialize containers
-        # losses/preds/labels on GPU/TPU (accumulated for eval_accumulation_steps)
-        losses_host = None
-        preds_host = None
-        labels_host = None
-        inputs_host = None
-
-        # losses/preds/labels on CPU (final containers)
-        all_losses = None
-        all_preds = None
-        all_labels = None
-        all_inputs = None
-        # Will be useful when we have an iterable dataset so don't know its length.
+        all_losses = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)
+        all_preds = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)
+        all_labels = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)
+        all_inputs = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)
+        # Will be useful when we have an iterable dataset so don't know its length.
         observed_num_examples = 0
+
         # Main evaluation loop
         for step, inputs in enumerate(dataloader):
             # Update the observed num examples
@@ -3659,56 +3654,33 @@ def evaluation_loop(
             if is_torch_xla_available():
                 xm.mark_step()
 
-            # Update containers on host
+            # Update containers
             if loss is not None:
                 losses = self.gather_function((loss.repeat(batch_size)))
-                losses_host = losses if losses_host is None else nested_concat(losses_host, losses, padding_index=-100)
-            if labels is not None:
-                labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100)
+                all_losses.add(losses)
             if inputs_decode is not None:
                 inputs_decode = self.accelerator.pad_across_processes(inputs_decode, dim=1, pad_index=-100)
                 inputs_decode = self.gather_function((inputs_decode))
-                inputs_host = (
-                    inputs_decode
-                    if inputs_host is None
-                    else nested_concat(inputs_host, inputs_decode, padding_index=-100)
-                )
+                all_inputs.add(inputs_decode)
             if logits is not None:
                 logits = self.accelerator.pad_across_processes(logits, dim=1, pad_index=-100)
                 if self.preprocess_logits_for_metrics is not None:
                     logits = self.preprocess_logits_for_metrics(logits, labels)
                 logits = self.gather_function((logits))
-                preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
-
+                all_preds.add(logits)
             if labels is not None:
+                labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100)
                 labels = self.gather_function((labels))
-                labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
+                all_labels.add(labels)
 
             self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
 
             # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
             if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
-                if losses_host is not None:
-                    losses = nested_numpify(losses_host)
-                    all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
-                if preds_host is not None:
-                    logits = nested_numpify(preds_host)
-                    all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
-                if inputs_host is not None:
-                    inputs_decode = nested_numpify(inputs_host)
-                    all_inputs = (
-                        inputs_decode
-                        if all_inputs is None
-                        else nested_concat(all_inputs, inputs_decode, padding_index=-100)
-                    )
-                if labels_host is not None:
-                    labels = nested_numpify(labels_host)
-                    all_labels = (
-                        labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
-                    )
-
-                # Set back to None to begin a new accumulation
-                losses_host, preds_host, inputs_host, labels_host = None, None, None, None
+                all_losses.to_cpu_and_numpy()
+                all_preds.to_cpu_and_numpy()
+                all_labels.to_cpu_and_numpy()
+                all_inputs.to_cpu_and_numpy()
 
         # After all calls to `.gather_function`, reset to `gather_for_metrics`:
         self.gather_function = self.accelerator.gather_for_metrics
@@ -3717,20 +3689,10 @@ def evaluation_loop(
             delattr(self, "_past")
 
         # Gather all remaining tensors and put them back on the CPU
-        if losses_host is not None:
-            losses = nested_numpify(losses_host)
-            all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
-        if preds_host is not None:
-            logits = nested_numpify(preds_host)
-            all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
-        if inputs_host is not None:
-            inputs_decode = nested_numpify(inputs_host)
-            all_inputs = (
-                inputs_decode if all_inputs is None else nested_concat(all_inputs, inputs_decode, padding_index=-100)
-            )
-        if labels_host is not None:
-            labels = nested_numpify(labels_host)
-            all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
+        all_losses = all_losses.get_arrays()
+        all_preds = all_preds.get_arrays()
+        all_labels = all_labels.get_arrays()
+        all_inputs = all_inputs.get_arrays()
 
         # Number of samples
         if has_length(eval_dataset):
@@ -3761,7 +3723,9 @@
         # To be JSON-serializable, we need to remove numpy types or zero-d tensors
         metrics = denumpify_detensorize(metrics)
 
-        if all_losses is not None:
+        if isinstance(all_losses, list) and all_losses:
+            metrics[f"{metric_key_prefix}_loss"] = np.concatenate(all_losses).mean().item()
+        elif isinstance(all_losses, np.ndarray):
             metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()
         if hasattr(self, "jit_compilation_time"):
             metrics[f"{metric_key_prefix}_jit_compilation_time"] = self.jit_compilation_time
@@ -4204,6 +4168,7 @@ def prediction_loop(
         logger.info(f"***** Running {description} *****")
         logger.info(f"  Num examples = {num_examples}")
        logger.info(f"  Batch size = {batch_size}")
+
         losses_host: torch.Tensor = None
         preds_host: Union[torch.Tensor, List[torch.Tensor]] = None
         labels_host: Union[torch.Tensor, List[torch.Tensor]] = None
diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py
index 9ee670e94288b8..a4372ae78a79a2 100644
--- a/src/transformers/trainer_pt_utils.py
+++ b/src/transformers/trainer_pt_utils.py
@@ -299,6 +299,58 @@ def __iter__(self):
         return iter(indices)
 
 
+class EvalLoopContainer:
+    """
+    Container to store intermediate results of evaluation loop
+
+    Args:
+        do_nested_concat (`bool`, *optional*, defaults to `True`):
+            If set to `True`, each iteration will recursively concatenate a new object containing tensors to
+            the existing stored tensors, provided that the structure of the existing object and the new one
+            are identical. If set to `False`, all newly added tensors will be stored in a list.
+        padding_index (`int`, *optional*, defaults to -100):
+            Value used to pad tensors of different shapes when `do_nested_concat=True`.
+    """
+
+    def __init__(self, do_nested_concat: bool = True, padding_index: int = -100):
+        self.do_nested_concat = do_nested_concat
+        self.padding_index = padding_index
+        self.tensors = None
+        self.arrays = None
+
+    def add(self, tensors) -> None:
+        """Add tensors to the stored objects. If `do_nested_concat=True`, the tensors will be concatenated recursively."""
+        if self.tensors is None:
+            self.tensors = tensors if self.do_nested_concat else [tensors]
+        elif self.do_nested_concat:
+            self.tensors = nested_concat(self.tensors, tensors, padding_index=self.padding_index)
+        else:
+            self.tensors.append(tensors)
+
+    def to_cpu_and_numpy(self) -> None:
+        """Move tensors in stored objects to CPU and convert them to numpy arrays."""
+
+        # Check if we have something to add, if not just return
+        if self.tensors is None:
+            return
+
+        new_arrays = nested_numpify(self.tensors)
+        if self.arrays is None:
+            self.arrays = new_arrays
+        elif self.do_nested_concat:
+            self.arrays = nested_concat(self.arrays, new_arrays, padding_index=self.padding_index)
+        else:
+            self.arrays.extend(new_arrays)
+
+        # reset device tensors after adding to cpu
+        self.tensors = None
+
+    def get_arrays(self):
+        """Returns the numpified and moved to CPU stored objects."""
+        self.to_cpu_and_numpy()
+        return self.arrays
+
+
 class SequentialDistributedSampler(Sampler):
     """
     Distributed Sampler that subsamples indices sequentially, making it easier to collate all results at the end.
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 8c226459fc3db5..338bb116dddece 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -670,6 +670,9 @@ class TrainingArguments:
         include_inputs_for_metrics (`bool`, *optional*, defaults to `False`):
             Whether or not the inputs will be passed to the `compute_metrics` function. This is intended for metrics
             that need inputs, predictions and references for scoring calculation in Metric class.
+        eval_do_concat_batches (`bool`, *optional*, defaults to `True`):
+            Whether to recursively concat inputs/losses/labels/predictions across batches. If `False`,
+            will instead store them as lists, with each batch kept separate.
         auto_find_batch_size (`bool`, *optional*, defaults to `False`)
             Whether to find a batch size that will fit into memory automatically through exponential decay, avoiding
             CUDA Out-of-Memory errors. Requires accelerate to be installed (`pip install accelerate`)
@@ -1289,6 +1292,12 @@ class TrainingArguments:
     include_inputs_for_metrics: bool = field(
         default=False, metadata={"help": "Whether or not the inputs will be passed to the `compute_metrics` function."}
     )
+    eval_do_concat_batches: bool = field(
+        default=True,
+        metadata={
+            "help": "Whether to recursively concat inputs/losses/labels/predictions across batches. If `False`, will instead store them as lists, with each batch kept separate."
+        },
+    )
     # Deprecated arguments
     fp16_backend: str = field(
         default="auto",
diff --git a/tests/trainer/test_trainer_utils.py b/tests/trainer/test_trainer_utils.py
index ccf162677e9fce..a730ff07ccb2bb 100644
--- a/tests/trainer/test_trainer_utils.py
+++ b/tests/trainer/test_trainer_utils.py
@@ -35,6 +35,7 @@
         DistributedLengthGroupedSampler,
         DistributedSamplerWithLoop,
         DistributedTensorGatherer,
+        EvalLoopContainer,
         IterableDatasetShard,
         LabelSmoother,
         LengthGroupedSampler,
@@ -497,3 +498,92 @@ def info(self, msg):
         remove_columns_collator(data_batch)
         self.assertEqual(logger.called, 1)
         self.assertIn("col3", logger.last_msg)
+
+    def test_eval_loop_container(self):
+        batch_1 = [
+            torch.ones([8, 5]),
+            {"loss": torch.tensor(1.0)},
+            (torch.ones([8, 2, 3]), torch.ones([8, 2])),
+        ]
+        batch_2 = [
+            torch.ones([4, 5]),
+            {"loss": torch.tensor(2.0)},
+            (torch.ones([4, 2, 3]), torch.ones([4, 6])),
+        ]
+
+        concat_container = EvalLoopContainer(do_nested_concat=True, padding_index=-100)
+        concat_container.add(batch_1)
+        concat_container.add(batch_2)
+        concat_container.to_cpu_and_numpy()
+        arrays = concat_container.get_arrays()
+
+        # Test two nested batches concatenation
+        self.assertIsInstance(arrays, list)
+        self.assertEqual(len(arrays), 3)
+        self.assertIsInstance(arrays[0], np.ndarray)
+        self.assertEqual(arrays[0].shape, (12, 5))
+        self.assertIsInstance(arrays[1], dict)
+        self.assertIsInstance(arrays[1]["loss"], np.ndarray)
+        self.assertEqual(arrays[1]["loss"].shape, (2,))
+        self.assertTrue(np.allclose(arrays[1]["loss"], np.array([1.0, 2.0])))
+        self.assertIsInstance(arrays[2], tuple)
+        self.assertEqual(len(arrays[2]), 2)
+        self.assertEqual(arrays[2][0].shape, (12, 2, 3))
+        self.assertEqual(arrays[2][1].shape, (12, 6))
+        # check that first batch padded with padding index -100 after concatenation
+        self.assertEqual(arrays[2][1][0][2], -100)
+
+        # Test two batches with no concatenation
+        list_container = EvalLoopContainer(do_nested_concat=False)
+        list_container.add(batch_1)
+        list_container.add(batch_2)
+        list_container.to_cpu_and_numpy()
+        arrays = list_container.get_arrays()
+
+        self.assertEqual(len(arrays), 2)
+        self.assertIsInstance(arrays, list)
+        np_batch_1, np_batch_2 = arrays
+
+        self.assertIsInstance(np_batch_1, list)
+        self.assertEqual(len(np_batch_1), 3)
+        self.assertIsInstance(np_batch_1[0], np.ndarray)
+        self.assertIsInstance(np_batch_1[1], dict)
+        self.assertIsInstance(np_batch_1[2], tuple)
+        self.assertEqual(np_batch_1[0].shape, (8, 5))
+        self.assertEqual(np_batch_1[1]["loss"].shape, ())
+        self.assertEqual(np_batch_1[2][0].shape, (8, 2, 3))
+        self.assertEqual(np_batch_1[2][1].shape, (8, 2))
+
+        self.assertIsInstance(np_batch_2, list)
+        self.assertEqual(len(np_batch_2), 3)
+        self.assertIsInstance(np_batch_2[0], np.ndarray)
+        self.assertIsInstance(np_batch_2[1], dict)
+        self.assertIsInstance(np_batch_2[2], tuple)
+        self.assertEqual(np_batch_2[0].shape, (4, 5))
+        self.assertEqual(np_batch_2[1]["loss"].shape, ())
+        self.assertEqual(np_batch_2[2][0].shape, (4, 2, 3))
+        self.assertEqual(np_batch_2[2][1].shape, (4, 6))
+
+        # Test no batches
+        none_arr = EvalLoopContainer(do_nested_concat=True, padding_index=-100).get_arrays()
+        self.assertIsNone(none_arr)
+
+        none_arr = EvalLoopContainer(do_nested_concat=False).get_arrays()
+        self.assertIsNone(none_arr)
+
+        # Test one batch
+        concat_container = EvalLoopContainer(do_nested_concat=True, padding_index=-100)
+        concat_container.add(batch_1)
+        arrays = concat_container.get_arrays()
+        self.assertIsInstance(arrays, list)
+        self.assertEqual(len(arrays), 3)
+        self.assertIsInstance(arrays[0], np.ndarray)
+        self.assertEqual(arrays[0].shape, (8, 5))
+        self.assertIsInstance(arrays[1], dict)
+        self.assertIsInstance(arrays[1]["loss"], np.ndarray)
+        self.assertEqual(arrays[1]["loss"].shape, ())
+        self.assertTrue(np.allclose(arrays[1]["loss"], np.array([1.0])))
+        self.assertIsInstance(arrays[2], tuple)
+        self.assertEqual(len(arrays[2]), 2)
+        self.assertEqual(arrays[2][0].shape, (8, 2, 3))
+        self.assertEqual(arrays[2][1].shape, (8, 2))
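For context, a minimal usage sketch of the `EvalLoopContainer` API added above (`add`, `to_cpu_and_numpy`, `get_arrays`), outside of the `Trainer` internals. The loop, tensor shapes, and the `all_preds` variable are illustrative stand-ins for per-batch eval outputs and are not part of this PR.

# Minimal sketch of EvalLoopContainer usage (illustrative shapes, not Trainer internals).
import torch

from transformers.trainer_pt_utils import EvalLoopContainer

all_preds = EvalLoopContainer(do_nested_concat=True, padding_index=-100)

for step in range(4):
    # Stand-in for per-batch logits whose sequence length varies across batches.
    logits = torch.ones([8, 2 + step])
    all_preds.add(logits)

    # Periodically offload accumulated tensors to CPU/numpy, mirroring eval_accumulation_steps.
    if (step + 1) % 2 == 0:
        all_preds.to_cpu_and_numpy()

# With do_nested_concat=True the result is a single array padded to the longest
# sequence with -100, here of shape (32, 5). With do_nested_concat=False,
# get_arrays() would instead return a list with one entry per added batch.
arrays = all_preds.get_arrays()
print(arrays.shape)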