From 9c0ba844101ea8cdf4bd3b8d1740605d052354aa Mon Sep 17 00:00:00 2001 From: Mihir Patel Date: Tue, 3 Oct 2023 11:55:05 -0400 Subject: [PATCH] Free outputs callback (#2598) * free train metrics * lint * int * rename * add callback * import * wrap * fix more tests * Update composer/callbacks/free_outputs.py Co-authored-by: Charles Tang --------- Co-authored-by: Charles Tang --- composer/callbacks/__init__.py | 2 ++ composer/callbacks/free_outputs.py | 16 ++++++++++++++++ composer/callbacks/generate.py | 1 + tests/callbacks/callback_settings.py | 10 +++++++--- 4 files changed, 26 insertions(+), 3 deletions(-) create mode 100644 composer/callbacks/free_outputs.py diff --git a/composer/callbacks/__init__.py b/composer/callbacks/__init__.py index 5d22d2cc57..1fdfed4767 100644 --- a/composer/callbacks/__init__.py +++ b/composer/callbacks/__init__.py @@ -10,6 +10,7 @@ from composer.callbacks.checkpoint_saver import CheckpointSaver from composer.callbacks.early_stopper import EarlyStopper from composer.callbacks.export_for_inference import ExportForInferenceCallback +from composer.callbacks.free_outputs import FreeOutputs from composer.callbacks.generate import Generate from composer.callbacks.health_checker import HealthChecker from composer.callbacks.image_visualizer import ImageVisualizer @@ -38,4 +39,5 @@ 'RuntimeEstimator', 'SystemMetricsMonitor', 'Generate', + 'FreeOutputs', ] diff --git a/composer/callbacks/free_outputs.py b/composer/callbacks/free_outputs.py new file mode 100644 index 0000000000..f8cabe24bd --- /dev/null +++ b/composer/callbacks/free_outputs.py @@ -0,0 +1,16 @@ +# Copyright 2022 MosaicML Composer authors +# SPDX-License-Identifier: Apache-2.0 + +"""Free ``state.outputs`` on AFTER_LOSS.""" + +import torch + +from composer.core import Callback, State +from composer.loggers import Logger + + +class FreeOutputs(Callback): + """Free ``state.outputs`` on AFTER_LOSS to reduce peak memory usage if not using train metrics.""" + + def after_loss(self, state: State, logger: 
Logger) -> None: + state.outputs = torch.Tensor() diff --git a/composer/callbacks/generate.py b/composer/callbacks/generate.py index 25206bb0fb..ef854e6c0d 100644 --- a/composer/callbacks/generate.py +++ b/composer/callbacks/generate.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 """Periodically log generations from a set of prompts.""" + from typing import Any, List, Optional, Union, cast from composer.callbacks.utils import create_interval_scheduler diff --git a/tests/callbacks/callback_settings.py b/tests/callbacks/callback_settings.py index 01cca57be3..ef9fe12187 100644 --- a/tests/callbacks/callback_settings.py +++ b/tests/callbacks/callback_settings.py @@ -11,8 +11,9 @@ import composer.loggers import composer.profiler from composer import Callback -from composer.callbacks import (EarlyStopper, ExportForInferenceCallback, Generate, HealthChecker, ImageVisualizer, - MemoryMonitor, MLPerfCallback, SpeedMonitor, SystemMetricsMonitor, ThresholdStopper) +from composer.callbacks import (EarlyStopper, ExportForInferenceCallback, FreeOutputs, Generate, HealthChecker, + ImageVisualizer, MemoryMonitor, MLPerfCallback, SpeedMonitor, SystemMetricsMonitor, + ThresholdStopper) from composer.loggers import (CometMLLogger, ConsoleLogger, LoggerDestination, MLFlowLogger, ProgressBarLogger, RemoteUploaderDownloader, TensorboardLogger, WandBLogger) from composer.models.base import ComposerModel @@ -223,5 +224,8 @@ def get_cb_model_and_datasets(cb: Callback, ) return (configure_tiny_gpt2_hf_model(), dummy_gpt_lm_dataloader(size=dl_size), dummy_gpt_lm_dataloader(size=dl_size)) - return (SimpleModel(), DataLoader(RandomClassificationDataset(size=dl_size), **default_dl_kwargs), + model = SimpleModel() + if isinstance(cb, FreeOutputs): + model.get_metrics = lambda is_train=False: {} + return (model, DataLoader(RandomClassificationDataset(size=dl_size), **default_dl_kwargs), DataLoader(RandomClassificationDataset(size=dl_size), **default_dl_kwargs))