Fix Profiler schedule skip_first (mosaicml#2992)
* fix skip_first for resumption

* update doc

* v2

* move after_load callback to profiler

* fix unit tests

---------

Co-authored-by: Mihir Patel <[email protected]>
bigning and mvpatel2000 authored Feb 12, 2024
1 parent 6d4575d commit 375ea0c
Showing 4 changed files with 67 additions and 10 deletions.
13 changes: 11 additions & 2 deletions composer/profiler/profiler.py
@@ -9,6 +9,8 @@
import pathlib
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Sequence, Tuple, Union

from composer.core import Callback
from composer.loggers import Logger
from composer.profiler.json_trace_handler import JSONTraceHandler
from composer.profiler.marker import Marker
from composer.profiler.profiler_action import ProfilerAction
@@ -18,14 +20,14 @@
from composer.utils import ensure_tuple, parse_uri

if TYPE_CHECKING:
from composer.core import Callback, State
from composer.core import State

__all__ = ['Profiler']

log = logging.getLogger(__name__)


class Profiler:
class Profiler(Callback):
"""Composer Profiler.
See the :doc:`Profiling Guide </trainer/performance_tutorials/profiling>` for additional information.
@@ -118,6 +120,8 @@ def __init__(
self.schedule = schedule
self.state = None
self._callbacks: List[Callback] = []
# Used to count skip_first starting from resumption timestamp
self.resumption_batch_idx: int = 0
self.remote_filenames: List[str] = []
# First, add each remote file name to self.remote_filenames to create RemoteUploaderDownloader logger in trainer. [s3://bucket/path/to/file]
# Then modify remote file name to be a local path to pass into torch_profiler and system_profiler. e.g: path/to/file
@@ -185,6 +189,7 @@ def bind_to_state(
state (State): The training state.
"""
self.state = state
self.state.callbacks.append(self)
self.state.callbacks.extend(self._callbacks)
self.state.callbacks.extend(self._trace_handlers)

@@ -289,3 +294,7 @@ def should_record(state: State) -> bool:
)
self._names_to_markers[name].categories = categories
return self._names_to_markers[name]

def after_load(self, state: State, logger: Logger) -> None:
del logger
self.resumption_batch_idx = int(state.timestamp.batch_in_epoch)
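
The effect of the profiler.py change is that Profiler, now a Callback, hooks the AFTER_LOAD event and remembers the batch_in_epoch at which training resumed. Below is a minimal sketch of that behavior; the mocked trace handler mirrors how the tests construct the profiler, and the SimpleNamespace object is only a stand-in for Composer's real State.

from types import SimpleNamespace
from unittest.mock import MagicMock

from composer.profiler import Profiler, cyclic_schedule

# Same minimal construction the test settings below use: a mocked trace handler
# plus a cyclic schedule.
profiler = Profiler(trace_handlers=[MagicMock()], schedule=cyclic_schedule(skip_first=1))
assert profiler.resumption_batch_idx == 0  # initialized in __init__

# Stand-in for a State restored from a checkpoint saved at batch_in_epoch == 4.
fake_state = SimpleNamespace(timestamp=SimpleNamespace(batch_in_epoch=4))
profiler.after_load(fake_state, logger=MagicMock())  # logger is unused by after_load

# The schedule will now count skip_first from batch 4 rather than batch 0.
assert profiler.resumption_batch_idx == 4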
19 changes: 12 additions & 7 deletions composer/profiler/profiler_schedule.py
@@ -23,10 +23,11 @@ def cyclic_schedule(
This function returns a schedule function that uses a cyclic profiling window. The resulting function can be
passed as the ``prof_schedule`` argument to the :class:`.Trainer`.
The cyclic window skips the first ``skip_first`` batches in every epoch. Then, it performs a cycle of
skipping ``wait`` batches, warming up for ``warmup`` batches, and recording ``active`` batches.
It repeats this cycle up to ``repeat`` times per epoch (or for the entire epoch, if ``repeat`` is 0).
This logic repeats every epoch.
The cyclic window skips the first ``skip_first`` + ``resumption_batch_idx`` batches in every epoch.
``resumption_batch_idx`` is read from ``state.profiler`` and is the value of ``state.timestamp.batch_in_epoch``
at the time training resumed. The window then performs a cycle of skipping ``wait`` batches, warming up for
``warmup`` batches, and recording ``active`` batches. It repeats this cycle up to ``repeat`` times per epoch
(or for the entire epoch, if ``repeat`` is 0). This logic repeats every epoch.
Args:
skip_first (int, optional): Number of batches to skip profiling at epoch start. Defaults to ``0``.
@@ -46,12 +47,16 @@ def schedule(state: State):
# do wait, then warmup, then active, up to repeat times per cycle
cycle_len = wait + warmup + active
batch_idx = int(state.timestamp.batch_in_epoch)
if batch_idx < skip_first:
if state.profiler is not None:
skip_first_after_resumption = skip_first + state.profiler.resumption_batch_idx
else:
skip_first_after_resumption = skip_first
if batch_idx < skip_first_after_resumption:
return ProfilerAction.SKIP
if repeat != 0 and batch_idx >= cycle_len * repeat + skip_first:
if repeat != 0 and batch_idx >= cycle_len * repeat + skip_first_after_resumption:
# exhausted the repeat
return ProfilerAction.SKIP
position_in_cycle = (batch_idx - skip_first) % cycle_len
position_in_cycle = (batch_idx - skip_first_after_resumption) % cycle_len
if position_in_cycle < wait:
return ProfilerAction.SKIP
if position_in_cycle < wait + warmup:
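
To make the window arithmetic concrete, here is a standalone sketch of the same logic as a plain function. ``cyclic_action`` is a hypothetical helper written for illustration, not the library's ``schedule`` closure, and it collapses the real schedule's active/save distinction on the last active batch into a single 'ACTIVE' value.

def cyclic_action(batch_idx: int, skip_first: int, wait: int, warmup: int,
                  active: int, repeat: int, resumption_batch_idx: int = 0) -> str:
    """Mirror the cyclic schedule: skip_first is counted from the resumption batch."""
    skip_first_after_resumption = skip_first + resumption_batch_idx
    cycle_len = wait + warmup + active
    if batch_idx < skip_first_after_resumption:
        return 'SKIP'
    if repeat != 0 and batch_idx >= cycle_len * repeat + skip_first_after_resumption:
        return 'SKIP'  # exhausted the allowed number of cycles
    position_in_cycle = (batch_idx - skip_first_after_resumption) % cycle_len
    if position_in_cycle < wait:
        return 'SKIP'
    if position_in_cycle < wait + warmup:
        return 'WARMUP'
    return 'ACTIVE'

# Mirrors the new unit test below: with skip_first=1, wait=2, warmup=3, active=4,
# resuming at batch 4 pushes batch 7 from the active window back into warmup.
assert cyclic_action(7, skip_first=1, wait=2, warmup=3, active=4, repeat=1,
                     resumption_batch_idx=0) == 'ACTIVE'
assert cyclic_action(7, skip_first=1, wait=2, warmup=3, active=4, repeat=1,
                     resumption_batch_idx=4) == 'WARMUP'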
5 changes: 5 additions & 0 deletions tests/callbacks/callback_settings.py
@@ -3,6 +3,7 @@

import os
from typing import Any, Dict, List, Tuple, Type
from unittest.mock import MagicMock

import pytest
from torch.utils.data import DataLoader
@@ -125,6 +126,10 @@
NeptuneLogger: {
'mode': 'debug',
},
composer.profiler.Profiler: {
'trace_handlers': [MagicMock()],
'schedule': composer.profiler.cyclic_schedule(),
}
}

_callback_marks: Dict[Type[Callback], List[pytest.MarkDecorator],] = {
40 changes: 39 additions & 1 deletion tests/profiler/test_profiler.py
@@ -9,8 +9,10 @@
import pytest
import torch
from packaging import version
from torch.profiler.profiler import ProfilerAction as TorchProfilerAction

from composer.core import State
from composer.core import Engine, Event, State, Timestamp
from composer.loggers import Logger
from composer.profiler import Profiler, ProfilerAction, SystemProfiler, TorchProfiler, cyclic_schedule
from composer.profiler.utils import export_memory_timeline_html

@@ -170,3 +172,39 @@ def test_memory_timeline(tmp_path: pathlib.Path) -> None:
assert fig is not None, 'export_memory_timeline_html should return a figure when return_fig=True'
_, end = fig.gca().get_ylim()
assert round(end, 2) == 0.06


def test_skip_first_after_resumption(minimal_state: State) -> None:
skip_first = 1
wait = 2
warmup = 3
active = 4
repeat = 1
schedule = cyclic_schedule(skip_first=skip_first, wait=wait, warmup=warmup, active=active, repeat=repeat)
mock_trace_handler = MagicMock()
profiler = Profiler(
trace_handlers=[mock_trace_handler],
schedule=schedule,
)
profiler.bind_to_state(minimal_state)
minimal_state.profiler = profiler

assert len(profiler._callbacks) >= 1
assert isinstance(profiler._callbacks[-1], TorchProfiler)
torch_profiler = profiler._callbacks[-1]

# Create torch.profiler.profile
logger = Logger(minimal_state)
engine = Engine(state=minimal_state, logger=logger)
engine.run_event(Event.INIT)
assert torch_profiler.profiler is not None

minimal_state.timestamp = Timestamp(batch_in_epoch=7)
assert torch_profiler.profiler.schedule(0) == TorchProfilerAction.RECORD

# Load checkpoint at batch 4
minimal_state.timestamp = Timestamp(batch_in_epoch=4)
engine.run_event(Event.BEFORE_LOAD)
engine.run_event(Event.AFTER_LOAD)
minimal_state.timestamp = Timestamp(batch_in_epoch=7)
assert torch_profiler.profiler.schedule(0) == TorchProfilerAction.WARMUP
