Skip to content

Commit

Permalink
Add strategy to store results in evaluation loop (#30267)
Browse files Browse the repository at this point in the history
* Add evaluation loop container for interm. results

* Add tests for EvalLoopContainer

* Formatting

* Fix padding_index in test and typo

* Move EvalLoopContainer to pr_utils to avoid additional imports

* Fix `eval_do_concat_batches` arg description

* Fix EvalLoopContainer import
  • Loading branch information
qubvel authored and ydshieh committed Apr 23, 2024
1 parent c7b7418 commit 4c86c7d
Show file tree
Hide file tree
Showing 4 changed files with 176 additions and 60 deletions.
85 changes: 25 additions & 60 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
)
from .trainer_pt_utils import (
DistributedTensorGatherer,
EvalLoopContainer,
IterableDatasetShard,
LabelSmoother,
LayerWiseDummyOptimizer,
Expand Down Expand Up @@ -3627,20 +3628,14 @@ def evaluation_loop(
self._past = None

# Initialize containers
# losses/preds/labels on GPU/TPU (accumulated for eval_accumulation_steps)
losses_host = None
preds_host = None
labels_host = None
inputs_host = None

# losses/preds/labels on CPU (final containers)
all_losses = None
all_preds = None
all_labels = None
all_inputs = None
# Will be useful when we have an iterable dataset so don't know its length.
all_losses = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)
all_preds = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)
all_labels = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)
all_inputs = EvalLoopContainer(self.args.eval_do_concat_batches, padding_index=-100)

# Will be useful when we have an iterable dataset so don't know its length.
observed_num_examples = 0

# Main evaluation loop
for step, inputs in enumerate(dataloader):
# Update the observed num examples
Expand All @@ -3659,56 +3654,33 @@ def evaluation_loop(
if is_torch_xla_available():
xm.mark_step()

# Update containers on host
# Update containers
if loss is not None:
losses = self.gather_function((loss.repeat(batch_size)))
losses_host = losses if losses_host is None else nested_concat(losses_host, losses, padding_index=-100)
if labels is not None:
labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100)
all_losses.add(losses)
if inputs_decode is not None:
inputs_decode = self.accelerator.pad_across_processes(inputs_decode, dim=1, pad_index=-100)
inputs_decode = self.gather_function((inputs_decode))
inputs_host = (
inputs_decode
if inputs_host is None
else nested_concat(inputs_host, inputs_decode, padding_index=-100)
)
all_inputs.add(inputs_decode)
if logits is not None:
logits = self.accelerator.pad_across_processes(logits, dim=1, pad_index=-100)
if self.preprocess_logits_for_metrics is not None:
logits = self.preprocess_logits_for_metrics(logits, labels)
logits = self.gather_function((logits))
preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)

all_preds.add(logits)
if labels is not None:
labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100)
labels = self.gather_function((labels))
labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
all_labels.add(labels)

self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)

# Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
if losses_host is not None:
losses = nested_numpify(losses_host)
all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
if preds_host is not None:
logits = nested_numpify(preds_host)
all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
if inputs_host is not None:
inputs_decode = nested_numpify(inputs_host)
all_inputs = (
inputs_decode
if all_inputs is None
else nested_concat(all_inputs, inputs_decode, padding_index=-100)
)
if labels_host is not None:
labels = nested_numpify(labels_host)
all_labels = (
labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
)

# Set back to None to begin a new accumulation
losses_host, preds_host, inputs_host, labels_host = None, None, None, None
all_losses.to_cpu_and_numpy()
all_preds.to_cpu_and_numpy()
all_labels.to_cpu_and_numpy()
all_inputs.to_cpu_and_numpy()

# After all calls to `.gather_function`, reset to `gather_for_metrics`:
self.gather_function = self.accelerator.gather_for_metrics
Expand All @@ -3717,20 +3689,10 @@ def evaluation_loop(
delattr(self, "_past")

# Gather all remaining tensors and put them back on the CPU
if losses_host is not None:
losses = nested_numpify(losses_host)
all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
if preds_host is not None:
logits = nested_numpify(preds_host)
all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
if inputs_host is not None:
inputs_decode = nested_numpify(inputs_host)
all_inputs = (
inputs_decode if all_inputs is None else nested_concat(all_inputs, inputs_decode, padding_index=-100)
)
if labels_host is not None:
labels = nested_numpify(labels_host)
all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
all_losses = all_losses.get_arrays()
all_preds = all_preds.get_arrays()
all_labels = all_labels.get_arrays()
all_inputs = all_inputs.get_arrays()

# Number of samples
if has_length(eval_dataset):
Expand Down Expand Up @@ -3761,7 +3723,9 @@ def evaluation_loop(
# To be JSON-serializable, we need to remove numpy types or zero-d tensors
metrics = denumpify_detensorize(metrics)

if all_losses is not None:
if isinstance(all_losses, list) and all_losses:
metrics[f"{metric_key_prefix}_loss"] = np.concatenate(all_losses).mean().item()
elif isinstance(all_losses, np.ndarray):
metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()
if hasattr(self, "jit_compilation_time"):
metrics[f"{metric_key_prefix}_jit_compilation_time"] = self.jit_compilation_time
Expand Down Expand Up @@ -4204,6 +4168,7 @@ def prediction_loop(
logger.info(f"***** Running {description} *****")
logger.info(f" Num examples = {num_examples}")
logger.info(f" Batch size = {batch_size}")

losses_host: torch.Tensor = None
preds_host: Union[torch.Tensor, List[torch.Tensor]] = None
labels_host: Union[torch.Tensor, List[torch.Tensor]] = None
Expand Down
52 changes: 52 additions & 0 deletions src/transformers/trainer_pt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,58 @@ def __iter__(self):
return iter(indices)


class EvalLoopContainer:
"""
Container to store intermediate results of evaluation loop
Args:
do_nested_concat (`bool`, *optional*, defaults to `True`):
If set to `True`, each iteration will recursively concatenate a new object containing tensors to
the existing stored tensors, provided that the structure of the existing object and the new one
are identical. If set to `False`, all newly added tensors will be stored in a list.
padding_index (`int`, *optional*, defaults to -100):
Value used to pad tensors of different shapes when `do_nested_concat=True`.
"""

def __init__(self, do_nested_concat: bool = True, padding_index: int = -100):
self.do_nested_concat = do_nested_concat
self.padding_index = padding_index
self.tensors = None
self.arrays = None

def add(self, tensors) -> None:
"""Add tensors to the stored objects. If `do_nested_concat=True`, the tensors will be concatenated recursively."""
if self.tensors is None:
self.tensors = tensors if self.do_nested_concat else [tensors]
elif self.do_nested_concat:
self.tensors = nested_concat(self.tensors, tensors, padding_index=self.padding_index)
else:
self.tensors.append(tensors)

def to_cpu_and_numpy(self) -> None:
"""Move tensors in stored objects to CPU and convert them to numpy arrays."""

# Check if we have something to add, if not just return
if self.tensors is None:
return

new_arrays = nested_numpify(self.tensors)
if self.arrays is None:
self.arrays = new_arrays
elif self.do_nested_concat:
self.arrays = nested_concat(self.arrays, new_arrays, padding_index=self.padding_index)
else:
self.arrays.extend(new_arrays)

# reset device tensors after adding to cpu
self.tensors = None

def get_arrays(self):
"""Returns the numpified and moved to CPU stored objects."""
self.to_cpu_and_numpy()
return self.arrays


class SequentialDistributedSampler(Sampler):
"""
Distributed Sampler that subsamples indices sequentially, making it easier to collate all results at the end.
Expand Down
9 changes: 9 additions & 0 deletions src/transformers/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,6 +670,9 @@ class TrainingArguments:
include_inputs_for_metrics (`bool`, *optional*, defaults to `False`):
Whether or not the inputs will be passed to the `compute_metrics` function. This is intended for metrics
that need inputs, predictions and references for scoring calculation in Metric class.
eval_do_concat_batches (`bool`, *optional*, defaults to `True`):
Whether to recursively concat inputs/losses/labels/predictions across batches. If `False`,
will instead store them as lists, with each batch kept separate.
auto_find_batch_size (`bool`, *optional*, defaults to `False`)
Whether to find a batch size that will fit into memory automatically through exponential decay, avoiding
CUDA Out-of-Memory errors. Requires accelerate to be installed (`pip install accelerate`)
Expand Down Expand Up @@ -1289,6 +1292,12 @@ class TrainingArguments:
include_inputs_for_metrics: bool = field(
default=False, metadata={"help": "Whether or not the inputs will be passed to the `compute_metrics` function."}
)
eval_do_concat_batches: bool = field(
default=True,
metadata={
"help": "Whether to recursively concat inputs/losses/labels/predictions across batches. If `False`, will instead store them as lists, with each batch kept separate."
},
)
# Deprecated arguments
fp16_backend: str = field(
default="auto",
Expand Down
90 changes: 90 additions & 0 deletions tests/trainer/test_trainer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
DistributedLengthGroupedSampler,
DistributedSamplerWithLoop,
DistributedTensorGatherer,
EvalLoopContainer,
IterableDatasetShard,
LabelSmoother,
LengthGroupedSampler,
Expand Down Expand Up @@ -497,3 +498,92 @@ def info(self, msg):
remove_columns_collator(data_batch)
self.assertEqual(logger.called, 1)
self.assertIn("col3", logger.last_msg)

def test_eval_loop_container(self):
batch_1 = [
torch.ones([8, 5]),
{"loss": torch.tensor(1.0)},
(torch.ones([8, 2, 3]), torch.ones([8, 2])),
]
batch_2 = [
torch.ones([4, 5]),
{"loss": torch.tensor(2.0)},
(torch.ones([4, 2, 3]), torch.ones([4, 6])),
]

concat_container = EvalLoopContainer(do_nested_concat=True, padding_index=-100)
concat_container.add(batch_1)
concat_container.add(batch_2)
concat_container.to_cpu_and_numpy()
arrays = concat_container.get_arrays()

# Test two nested batches concatenation
self.assertIsInstance(arrays, list)
self.assertEqual(len(arrays), 3)
self.assertIsInstance(arrays[0], np.ndarray)
self.assertEqual(arrays[0].shape, (12, 5))
self.assertIsInstance(arrays[1], dict)
self.assertIsInstance(arrays[1]["loss"], np.ndarray)
self.assertEqual(arrays[1]["loss"].shape, (2,))
self.assertTrue(np.allclose(arrays[1]["loss"], np.array([1.0, 2.0])))
self.assertIsInstance(arrays[2], tuple)
self.assertEqual(len(arrays[2]), 2)
self.assertEqual(arrays[2][0].shape, (12, 2, 3))
self.assertEqual(arrays[2][1].shape, (12, 6))
# check that first batch padded with padding index -100 after concatenation
self.assertEqual(arrays[2][1][0][2], -100)

# Test two batches with no concatenation
list_container = EvalLoopContainer(do_nested_concat=False)
list_container.add(batch_1)
list_container.add(batch_2)
list_container.to_cpu_and_numpy()
arrays = list_container.get_arrays()

self.assertEqual(len(arrays), 2)
self.assertIsInstance(arrays, list)
np_batch_1, np_batch_2 = arrays

self.assertIsInstance(np_batch_1, list)
self.assertEqual(len(np_batch_1), 3)
self.assertIsInstance(np_batch_1[0], np.ndarray)
self.assertIsInstance(np_batch_1[1], dict)
self.assertIsInstance(np_batch_1[2], tuple)
self.assertEqual(np_batch_1[0].shape, (8, 5))
self.assertEqual(np_batch_1[1]["loss"].shape, ())
self.assertEqual(np_batch_1[2][0].shape, (8, 2, 3))
self.assertEqual(np_batch_1[2][1].shape, (8, 2))

self.assertIsInstance(np_batch_2, list)
self.assertEqual(len(np_batch_2), 3)
self.assertIsInstance(np_batch_2[0], np.ndarray)
self.assertIsInstance(np_batch_2[1], dict)
self.assertIsInstance(np_batch_2[2], tuple)
self.assertEqual(np_batch_2[0].shape, (4, 5))
self.assertEqual(np_batch_2[1]["loss"].shape, ())
self.assertEqual(np_batch_2[2][0].shape, (4, 2, 3))
self.assertEqual(np_batch_2[2][1].shape, (4, 6))

# Test no batches
none_arr = EvalLoopContainer(do_nested_concat=True, padding_index=-100).get_arrays()
self.assertIsNone(none_arr)

none_arr = EvalLoopContainer(do_nested_concat=False).get_arrays()
self.assertIsNone(none_arr)

# Test one batch
concat_container = EvalLoopContainer(do_nested_concat=True, padding_index=-100)
concat_container.add(batch_1)
arrays = concat_container.get_arrays()
self.assertIsInstance(arrays, list)
self.assertEqual(len(arrays), 3)
self.assertIsInstance(arrays[0], np.ndarray)
self.assertEqual(arrays[0].shape, (8, 5))
self.assertIsInstance(arrays[1], dict)
self.assertIsInstance(arrays[1]["loss"], np.ndarray)
self.assertEqual(arrays[1]["loss"].shape, ())
self.assertTrue(np.allclose(arrays[1]["loss"], np.array([1.0])))
self.assertIsInstance(arrays[2], tuple)
self.assertEqual(len(arrays[2]), 2)
self.assertEqual(arrays[2][0].shape, (8, 2, 3))
self.assertEqual(arrays[2][1].shape, (8, 2))

0 comments on commit 4c86c7d

Please sign in to comment.