Skip to content

Commit

Permalink
Free outputs callback (#2598)
Browse files Browse the repository at this point in the history
* free train metrics

* lint

* int

* rename

* add callback

* import

* wrap

* fix more tests

* Update composer/callbacks/free_outputs.py

Co-authored-by: Charles Tang <[email protected]>

---------

Co-authored-by: Charles Tang <[email protected]>
  • Loading branch information
mvpatel2000 and j316chuck authored Oct 3, 2023
1 parent e2386b3 commit 9c0ba84
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 3 deletions.
2 changes: 2 additions & 0 deletions composer/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from composer.callbacks.checkpoint_saver import CheckpointSaver
from composer.callbacks.early_stopper import EarlyStopper
from composer.callbacks.export_for_inference import ExportForInferenceCallback
from composer.callbacks.free_outputs import FreeOutputs
from composer.callbacks.generate import Generate
from composer.callbacks.health_checker import HealthChecker
from composer.callbacks.image_visualizer import ImageVisualizer
Expand Down Expand Up @@ -38,4 +39,5 @@
'RuntimeEstimator',
'SystemMetricsMonitor',
'Generate',
'FreeOutputs',
]
16 changes: 16 additions & 0 deletions composer/callbacks/free_outputs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Free train metrics."""

import torch

from composer.core import Callback, State
from composer.loggers import Logger


class FreeOutputs(Callback):
"""Free train metrics on AFTER_LOSS to reduce peak memory usage if not using train metrics."""

def after_loss(self, state: State, logger: Logger) -> None:
state.outputs = torch.Tensor()
1 change: 1 addition & 0 deletions composer/callbacks/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

"""Periodically log generations from a set of prompts."""

from typing import Any, List, Optional, Union, cast

from composer.callbacks.utils import create_interval_scheduler
Expand Down
10 changes: 7 additions & 3 deletions tests/callbacks/callback_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
import composer.loggers
import composer.profiler
from composer import Callback
from composer.callbacks import (EarlyStopper, ExportForInferenceCallback, Generate, HealthChecker, ImageVisualizer,
MemoryMonitor, MLPerfCallback, SpeedMonitor, SystemMetricsMonitor, ThresholdStopper)
from composer.callbacks import (EarlyStopper, ExportForInferenceCallback, FreeOutputs, Generate, HealthChecker,
ImageVisualizer, MemoryMonitor, MLPerfCallback, SpeedMonitor, SystemMetricsMonitor,
ThresholdStopper)
from composer.loggers import (CometMLLogger, ConsoleLogger, LoggerDestination, MLFlowLogger, ProgressBarLogger,
RemoteUploaderDownloader, TensorboardLogger, WandBLogger)
from composer.models.base import ComposerModel
Expand Down Expand Up @@ -223,5 +224,8 @@ def get_cb_model_and_datasets(cb: Callback,
)
return (configure_tiny_gpt2_hf_model(), dummy_gpt_lm_dataloader(size=dl_size),
dummy_gpt_lm_dataloader(size=dl_size))
return (SimpleModel(), DataLoader(RandomClassificationDataset(size=dl_size), **default_dl_kwargs),
model = SimpleModel()
if isinstance(cb, FreeOutputs):
model.get_metrics = lambda is_train=False: {}
return (model, DataLoader(RandomClassificationDataset(size=dl_size), **default_dl_kwargs),
DataLoader(RandomClassificationDataset(size=dl_size), **default_dl_kwargs))

0 comments on commit 9c0ba84

Please sign in to comment.