
Improve torch memory profiling arguments processing (mosaicml#2777)
* improve torch profile args

* improve torch profile args

* change default torch_prof_memory_filename

* add memory profiling arg test

* fix check

* fix check

* fix check

* fix check

* fix check

* fix check
cli99 authored Dec 13, 2023
1 parent a7cad7c commit db3d187
Showing 9 changed files with 142 additions and 99 deletions.
10 changes: 5 additions & 5 deletions composer/profiler/marker.py
@@ -40,7 +40,7 @@ class Marker:
.. testsetup::
from composer.profiler import Profiler, cyclic_schedule
- profiler = Profiler(schedule=cyclic_schedule(), trace_handlers=[])
+ profiler = Profiler(schedule=cyclic_schedule(), trace_handlers=[], torch_prof_memory_filename=None)
profiler.bind_to_state(state)
.. doctest::
@@ -57,7 +57,7 @@ class Marker:
.. testsetup::
from composer.profiler import Profiler, cyclic_schedule
- profiler = Profiler(schedule=cyclic_schedule(), trace_handlers=[])
+ profiler = Profiler(schedule=cyclic_schedule(), trace_handlers=[], torch_prof_memory_filename=None)
profiler.bind_to_state(state)
.. doctest::
@@ -124,7 +124,7 @@ def start(self) -> None:
.. testsetup::
from composer.profiler import Profiler, cyclic_schedule
- profiler = Profiler(schedule=cyclic_schedule(), trace_handlers=[])
+ profiler = Profiler(schedule=cyclic_schedule(), trace_handlers=[], torch_prof_memory_filename=None)
profiler.bind_to_state(state)
.. doctest::
@@ -187,7 +187,7 @@ def instant(self) -> None:
.. testsetup::
from composer.profiler import Profiler, cyclic_schedule
- profiler = Profiler(schedule=cyclic_schedule(), trace_handlers=[])
+ profiler = Profiler(schedule=cyclic_schedule(), trace_handlers=[], torch_prof_memory_filename=None)
profiler.bind_to_state(state)
.. doctest::
@@ -213,7 +213,7 @@ def counter(self, values: Dict[str, Union[float, int]]) -> None:
.. testsetup::
from composer.profiler import Profiler, cyclic_schedule
- profiler = Profiler(schedule=cyclic_schedule(), trace_handlers=[])
+ profiler = Profiler(schedule=cyclic_schedule(), trace_handlers=[], torch_prof_memory_filename=None)
profiler.bind_to_state(state)
.. doctest::
17 changes: 15 additions & 2 deletions composer/profiler/profiler.py
@@ -45,6 +45,7 @@ class Profiler:
def new_profiler_init(self, dummy_ellipsis=None, **kwargs):
if 'trace_handlers' not in kwargs:
kwargs['trace_handlers'] = []
+ kwargs['torch_prof_memory_filename'] = None
original_profiler_init(self, **kwargs)
Profiler.__init__ = new_profiler_init
@@ -62,6 +63,7 @@ def new_profiler_init(self, dummy_ellipsis=None, **kwargs):
active=4,
repeat=1,
),
+ torch_prof_memory_filename=None,
)
trace_handlers (TraceHandler | Sequence[TraceHandler]): Trace handlers which record and
@@ -100,7 +102,7 @@ def __init__(
torch_prof_folder: str = '{run_name}/torch_traces',
torch_prof_filename: str = 'rank{rank}.{batch}.pt.trace.json',
torch_prof_remote_file_name: Optional[str] = '{run_name}/torch_traces/rank{rank}.{batch}.pt.trace.json',
- torch_prof_memory_filename: str = 'rank{rank}.{batch}.pt.memory_trace.html',
+ torch_prof_memory_filename: Optional[str] = 'rank{rank}.{batch}.pt.memory_trace.html',
torch_prof_memory_remote_file_name: Optional[
str] = '{run_name}/torch_memory_traces/rank{rank}.{batch}.pt.memory_trace.html',
torch_prof_overwrite: bool = False,
@@ -143,6 +145,17 @@ def __init__(
profile_net=sys_prof_net,
stats_thread_interval_seconds=sys_prof_stats_thread_interval_seconds))

+ if torch_prof_memory_filename is not None:
+     if not (torch_prof_with_stack and torch_prof_record_shapes and torch_prof_profile_memory):
+         raise ValueError(
+             f'torch_prof_memory_filename is set. Generating the memory timeline graph requires all the three flags torch_prof_with_stack, torch_prof_record_shapes, and torch_prof_profile_memory to be true. Got torch_prof_with_stack={torch_prof_with_stack}, torch_prof_record_shapes={torch_prof_record_shapes}, torch_prof_profile_memory={torch_prof_profile_memory}'
+         )
+     log.info(
+         f'Memory profiling is enabled and uses {torch_prof_memory_filename} as the filename to generate the memory timeline graph. To disable the memory timeline graph generation, explicitly set torch_prof_memory_filename to None.'
+     )
+ else:
+     log.info(f'torch_prof_memory_filename is explicitly set to None. Memory timeline will not be generated.')

if torch_prof_record_shapes or torch_prof_profile_memory or torch_prof_with_stack or torch_prof_with_flops:
self._callbacks.append(
TorchProfiler(filename=torch_prof_filename,
@@ -230,7 +243,7 @@ def marker(
from composer.profiler import Profiler, cyclic_schedule
- profiler = Profiler(schedule=cyclic_schedule(), trace_handlers=[])
+ profiler = Profiler(schedule=cyclic_schedule(), trace_handlers=[], torch_prof_memory_filename=None)
profiler.bind_to_state(state)
state.profiler = profiler
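Taken together, the profiler.py changes make `torch_prof_memory_filename` optional, keep an HTML memory-trace filename as its default, and validate the related torch profiler flags up front. A minimal sketch of the resulting behaviour (assuming Composer at this commit; the empty trace-handler list is only for illustration):

```python
from composer.profiler import Profiler, cyclic_schedule

# The default memory filename is kept, but one of the three required flags is
# off, so the new up-front check raises ValueError instead of failing later.
try:
    Profiler(
        schedule=cyclic_schedule(),
        trace_handlers=[],
        torch_prof_with_stack=False,      # memory timeline needs this to be True
        torch_prof_record_shapes=True,
        torch_prof_profile_memory=True,
    )
except ValueError as err:
    print(err)

# Explicitly opting out of the memory timeline skips the check entirely.
profiler = Profiler(schedule=cyclic_schedule(), trace_handlers=[], torch_prof_memory_filename=None)
```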
11 changes: 3 additions & 8 deletions composer/profiler/torch_profiler.py
@@ -188,7 +188,6 @@ def __init__(
memory_filename: Optional[str] = 'rank{rank}.{batch}.pt.trace.memory.html',
memory_remote_file_name: Optional[
str] = '{run_name}/torch_memory_traces/rank{rank}.{batch}.pt.trace.memory.html',
- memory_custom_plot: bool = True,
overwrite: bool = False,
use_gzip: bool = False,
record_shapes: bool = False,
@@ -226,7 +225,6 @@ def __init__(
self.num_traces_to_keep = num_traces_to_keep
self.saved_traces = OrderedDict()
self.profiler: Optional[torch.profiler.profile] = None
- self.memory_custom_plot = memory_custom_plot

def init(self, state: State, logger: Logger) -> None:
if state.profiler is None:
@@ -295,12 +293,9 @@ def handler_fn(prof: torch.profiler.profiler.profile):
memory_trace_file_dirname = os.path.dirname(memory_trace_file_name)
if memory_trace_file_dirname:
os.makedirs(memory_trace_file_dirname, exist_ok=True)
- if self.memory_custom_plot:
-     from composer.profiler.utils import export_memory_timeline_html
-     export_memory_timeline_html(prof, memory_trace_file_name,
-                                 torch.cuda.current_device())  # type: ignore
- else:
-     prof.export_memory_timeline(memory_trace_file_name, torch.cuda.current_device())  # type: ignore
+ from composer.profiler.utils import export_memory_timeline_html
+ export_memory_timeline_html(prof, memory_trace_file_name,
+                             torch.cuda.current_device())  # type: ignore
log.debug(f'Uploaded memory trace to {self.memory_remote_file_name}')
if self.memory_remote_file_name is not None:
memory_trace_remote_file_name = format_name_with_dist_and_time(self.memory_remote_file_name,
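With `memory_custom_plot` removed, Composer now always renders the memory timeline through its own `export_memory_timeline_html` helper. For context, the deleted else-branch relied on PyTorch's built-in exporter; a standalone sketch of that underlying API (assuming PyTorch >= 2.1 and an available CUDA device; the workload is arbitrary):

```python
import torch
from torch.profiler import ProfilerActivity, profile

# The memory timeline needs stacks, shapes, and memory profiling enabled,
# mirroring the three flags the Composer check above enforces.
with profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        profile_memory=True,
        record_shapes=True,
        with_stack=True,
) as prof:
    x = torch.randn(1024, 1024, device='cuda', requires_grad=True)
    (x @ x).sum().backward()

# PyTorch's built-in HTML export, which the removed `memory_custom_plot=False`
# branch called directly.
prof.export_memory_timeline('memory_trace.html', device='cuda:0')
```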
49 changes: 25 additions & 24 deletions docs/source/trainer/performance_tutorials/profiling.md
@@ -83,6 +83,7 @@ Note, we support both local and object store paths for the composer profiler, e.
profiler = Profiler(
trace_handlers=[JSONTraceHandler(remote_file_name='oci://your-bucket/composer_profiler/')],
torch_remote_filename='s3://your-bucket/torch_profiler/',
+ torch_prof_memory_filename=None,
...
)
```
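Conversely, to actually produce the memory timeline, the snippet above would keep a memory filename (the default) and enable the three required torch profiler flags. A hedged sketch building on the same names; the schedule and parameter values are illustrative, not taken from the tutorial:

```python
profiler = Profiler(
    trace_handlers=[JSONTraceHandler(remote_file_name='oci://your-bucket/composer_profiler/')],
    schedule=cyclic_schedule(wait=0, warmup=1, active=4, repeat=1),
    torch_prof_memory_filename='rank{rank}.{batch}.pt.memory_trace.html',  # the default
    torch_prof_with_stack=True,
    torch_prof_record_shapes=True,
    torch_prof_profile_memory=True,
)
```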
@@ -119,30 +120,30 @@ For example, let’s assume the profiling options are set as follows:

Given the configuration above, profiling will be performed as follows:

- | Epoch | Batch | Profiler State | Profiler Action |
- | --- | --- | --- | --- |
- | 0 | 0 | skip_first | Do not record |
- | | 1 | wait | Do not record |
- | | 2 | warmup | Record, Torch Profiler does not record |
- | | 3 | active | Record |
- | | 4 | active | Record |
- | | 5 | wait | Do not record |
- | | 6 | warmup | Record, Torch Profiler does not record |
- | | 7 | active | Record |
- | | 8 | active | Record |
- | | 9 | disabled | Do not record |
- | | ... | | |
- | 1 | 0 | skip_first | Do not record |
- | | 1 | wait | Do not record |
- | | 2 | warmup | Record, Torch Profiler does not record |
- | | 3 | active | Record |
- | | 4 | active | Record |
- | | 5 | wait | Do not record |
- | | 6 | warmup | Record, Torch Profiler does not record |
- | | 7 | active | Record |
- | | 8 | active | Record |
- | | 9 | disabled | Do not record |
- | | ... | | |
+ | Epoch | Batch | Profiler State | Profiler Action |
+ | ----- | ----- | -------------- | -------------------------------------- |
+ | 0 | 0 | skip_first | Do not record |
+ | | 1 | wait | Do not record |
+ | | 2 | warmup | Record, Torch Profiler does not record |
+ | | 3 | active | Record |
+ | | 4 | active | Record |
+ | | 5 | wait | Do not record |
+ | | 6 | warmup | Record, Torch Profiler does not record |
+ | | 7 | active | Record |
+ | | 8 | active | Record |
+ | | 9 | disabled | Do not record |
+ | | ... | | |
+ | 1 | 0 | skip_first | Do not record |
+ | | 1 | wait | Do not record |
+ | | 2 | warmup | Record, Torch Profiler does not record |
+ | | 3 | active | Record |
+ | | 4 | active | Record |
+ | | 5 | wait | Do not record |
+ | | 6 | warmup | Record, Torch Profiler does not record |
+ | | 7 | active | Record |
+ | | 8 | active | Record |
+ | | 9 | disabled | Do not record |
+ | | ... | | |

As we can see above, the profiler skips the first batch of each epoch and is in the wait state during the following
batch, after which the profiler warms up in the next batch and actively records trace data for the
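One `cyclic_schedule` configuration consistent with the table above would be the following; this is illustrative only, since the actual option values come from the part of the tutorial that this hunk does not show:

```python
from composer.profiler import cyclic_schedule

# skip_first=1: skip batch 0 of every epoch.
# wait=1, warmup=1, active=2: one idle batch, one warmup batch, then two
# actively recorded batches per cycle.
# repeat=2: run the wait/warmup/active cycle twice per epoch, then stay
# disabled until the next epoch resets the schedule.
schedule = cyclic_schedule(skip_first=1, wait=1, warmup=1, active=2, repeat=2)
```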
1 change: 1 addition & 0 deletions examples/profiler_demo.py
@@ -63,6 +63,7 @@
),
torch_prof_folder=torch_trace_dir,
torch_prof_overwrite=True,
+ torch_prof_memory_filename=None,
))
# [trainer-end]

16 changes: 12 additions & 4 deletions tests/callbacks/test_callbacks.py
@@ -53,7 +53,9 @@ def test_multiple_fit_start_and_end(self, cb_cls: Type[Callback], dummy_state: S
"""Test that callbacks do not crash when Event.FIT_START and Event.FIT_END is called multiple times."""
cb_kwargs = get_cb_kwargs(cb_cls)
dummy_state.callbacks.append(cb_cls(**cb_kwargs))
- dummy_state.profiler = Profiler(schedule=lambda _: ProfilerAction.SKIP, trace_handlers=[])
+ dummy_state.profiler = Profiler(schedule=lambda _: ProfilerAction.SKIP,
+                                 trace_handlers=[],
+                                 torch_prof_memory_filename=None)
dummy_state.profiler.bind_to_state(dummy_state)

logger = Logger(dummy_state)
@@ -71,7 +73,9 @@ def test_idempotent_close(self, cb_cls: Type[Callback], dummy_state: State):
"""Test that callbacks do not crash when .close() and .post_close() are called multiple times."""
cb_kwargs = get_cb_kwargs(cb_cls)
dummy_state.callbacks.append(cb_cls(**cb_kwargs))
- dummy_state.profiler = Profiler(schedule=lambda _: ProfilerAction.SKIP, trace_handlers=[])
+ dummy_state.profiler = Profiler(schedule=lambda _: ProfilerAction.SKIP,
+                                 trace_handlers=[],
+                                 torch_prof_memory_filename=None)
dummy_state.profiler.bind_to_state(dummy_state)

logger = Logger(dummy_state)
@@ -85,7 +89,9 @@ def test_multiple_init_and_close(self, cb_cls: Type[Callback], dummy_state: Stat
"""Test that callbacks do not crash when INIT/.close()/.post_close() are called multiple times in that order."""
cb_kwargs = get_cb_kwargs(cb_cls)
dummy_state.callbacks.append(cb_cls(**cb_kwargs))
- dummy_state.profiler = Profiler(schedule=lambda _: ProfilerAction.SKIP, trace_handlers=[])
+ dummy_state.profiler = Profiler(schedule=lambda _: ProfilerAction.SKIP,
+                                 trace_handlers=[],
+                                 torch_prof_memory_filename=None)
dummy_state.profiler.bind_to_state(dummy_state)

logger = Logger(dummy_state)
@@ -125,7 +131,9 @@ def _get_trainer(self, cb: Callback, device_train_microbatch_size: int):
device_train_microbatch_size=device_train_microbatch_size,
callbacks=callbacks,
loggers=loggers,
- profiler=Profiler(schedule=lambda _: ProfilerAction.SKIP, trace_handlers=[]),
+ profiler=Profiler(schedule=lambda _: ProfilerAction.SKIP,
+                   trace_handlers=[],
+                   torch_prof_memory_filename=None),
)

def test_trains(self, cb_cls: Type[Callback], device_train_microbatch_size: int, _remote: bool):
1 change: 1 addition & 0 deletions tests/profiler/test_json_trace_handler.py
@@ -34,6 +34,7 @@ def test_json_trace_profiler_handler(tmp_path: pathlib.Path):
torch_prof_profile_memory=False,
torch_prof_with_stack=False,
torch_prof_with_flops=False,
+ torch_prof_memory_filename=None,
)
trainer = Trainer(
model=SimpleModel(),
52 changes: 0 additions & 52 deletions tests/profiler/test_memory_timeline.py

This file was deleted.

