diff --git a/composer/profiler/marker.py b/composer/profiler/marker.py index a87bfc02ae..26dcc388ed 100644 --- a/composer/profiler/marker.py +++ b/composer/profiler/marker.py @@ -40,7 +40,7 @@ class Marker: .. testsetup:: from composer.profiler import Profiler, cyclic_schedule - profiler = Profiler(schedule=cyclic_schedule(), trace_handlers=[]) + profiler = Profiler(schedule=cyclic_schedule(), trace_handlers=[], torch_prof_memory_filename=None) profiler.bind_to_state(state) .. doctest:: @@ -57,7 +57,7 @@ class Marker: .. testsetup:: from composer.profiler import Profiler, cyclic_schedule - profiler = Profiler(schedule=cyclic_schedule(), trace_handlers=[]) + profiler = Profiler(schedule=cyclic_schedule(), trace_handlers=[], torch_prof_memory_filename=None) profiler.bind_to_state(state) .. doctest:: @@ -124,7 +124,7 @@ def start(self) -> None: .. testsetup:: from composer.profiler import Profiler, cyclic_schedule - profiler = Profiler(schedule=cyclic_schedule(), trace_handlers=[]) + profiler = Profiler(schedule=cyclic_schedule(), trace_handlers=[], torch_prof_memory_filename=None) profiler.bind_to_state(state) .. doctest:: @@ -187,7 +187,7 @@ def instant(self) -> None: .. testsetup:: from composer.profiler import Profiler, cyclic_schedule - profiler = Profiler(schedule=cyclic_schedule(), trace_handlers=[]) + profiler = Profiler(schedule=cyclic_schedule(), trace_handlers=[], torch_prof_memory_filename=None) profiler.bind_to_state(state) .. doctest:: @@ -213,7 +213,7 @@ def counter(self, values: Dict[str, Union[float, int]]) -> None: .. testsetup:: from composer.profiler import Profiler, cyclic_schedule - profiler = Profiler(schedule=cyclic_schedule(), trace_handlers=[]) + profiler = Profiler(schedule=cyclic_schedule(), trace_handlers=[], torch_prof_memory_filename=None) profiler.bind_to_state(state) .. doctest:: diff --git a/composer/profiler/profiler.py b/composer/profiler/profiler.py index 876282dd99..4e5a6bbbb2 100644 --- a/composer/profiler/profiler.py +++ b/composer/profiler/profiler.py @@ -45,6 +45,7 @@ class Profiler: def new_profiler_init(self, dummy_ellipsis=None, **kwargs): if 'trace_handlers' not in kwargs: kwargs['trace_handlers'] = [] + kwargs['torch_prof_memory_filename'] = None original_profiler_init(self, **kwargs) Profiler.__init__ = new_profiler_init @@ -62,6 +63,7 @@ def new_profiler_init(self, dummy_ellipsis=None, **kwargs): active=4, repeat=1, ), + torch_prof_memory_filename=None, ) trace_handlers (TraceHandler | Sequence[TraceHandler]): Trace handlers which record and @@ -100,7 +102,7 @@ def __init__( torch_prof_folder: str = '{run_name}/torch_traces', torch_prof_filename: str = 'rank{rank}.{batch}.pt.trace.json', torch_prof_remote_file_name: Optional[str] = '{run_name}/torch_traces/rank{rank}.{batch}.pt.trace.json', - torch_prof_memory_filename: str = 'rank{rank}.{batch}.pt.memory_trace.html', + torch_prof_memory_filename: Optional[str] = 'rank{rank}.{batch}.pt.memory_trace.html', torch_prof_memory_remote_file_name: Optional[ str] = '{run_name}/torch_memory_traces/rank{rank}.{batch}.pt.memory_trace.html', torch_prof_overwrite: bool = False, @@ -143,6 +145,17 @@ def __init__( profile_net=sys_prof_net, stats_thread_interval_seconds=sys_prof_stats_thread_interval_seconds)) + if torch_prof_memory_filename is not None: + if not (torch_prof_with_stack and torch_prof_record_shapes and torch_prof_profile_memory): + raise ValueError( + f'torch_prof_memory_filename is set. 
Generating the memory timeline graph requires all three flags torch_prof_with_stack, torch_prof_record_shapes, and torch_prof_profile_memory to be true. Got torch_prof_with_stack={torch_prof_with_stack}, torch_prof_record_shapes={torch_prof_record_shapes}, torch_prof_profile_memory={torch_prof_profile_memory}'
+                )
+            log.info(
+                f'Memory profiling is enabled and uses {torch_prof_memory_filename} as the filename to generate the memory timeline graph. To disable the memory timeline graph generation, explicitly set torch_prof_memory_filename to None.'
+            )
+        else:
+            log.info(f'torch_prof_memory_filename is explicitly set to None. Memory timeline will not be generated.')
+
         if torch_prof_record_shapes or torch_prof_profile_memory or torch_prof_with_stack or torch_prof_with_flops:
             self._callbacks.append(
                 TorchProfiler(filename=torch_prof_filename,
@@ -230,7 +243,7 @@ def marker(
 
                 from composer.profiler import Profiler, cyclic_schedule
 
-                profiler = Profiler(schedule=cyclic_schedule(), trace_handlers=[])
+                profiler = Profiler(schedule=cyclic_schedule(), trace_handlers=[], torch_prof_memory_filename=None)
                 profiler.bind_to_state(state)
                 state.profiler = profiler
diff --git a/composer/profiler/torch_profiler.py b/composer/profiler/torch_profiler.py
index cfd4c0a48b..ef3fad2554 100644
--- a/composer/profiler/torch_profiler.py
+++ b/composer/profiler/torch_profiler.py
@@ -188,7 +188,6 @@ def __init__(
         memory_filename: Optional[str] = 'rank{rank}.{batch}.pt.trace.memory.html',
         memory_remote_file_name: Optional[
             str] = '{run_name}/torch_memory_traces/rank{rank}.{batch}.pt.trace.memory.html',
-        memory_custom_plot: bool = True,
         overwrite: bool = False,
         use_gzip: bool = False,
         record_shapes: bool = False,
@@ -226,7 +225,6 @@ def __init__(
         self.num_traces_to_keep = num_traces_to_keep
         self.saved_traces = OrderedDict()
         self.profiler: Optional[torch.profiler.profile] = None
-        self.memory_custom_plot = memory_custom_plot
 
     def init(self, state: State, logger: Logger) -> None:
         if state.profiler is None:
@@ -295,12 +293,9 @@ def handler_fn(prof: torch.profiler.profiler.profile):
                     memory_trace_file_dirname = os.path.dirname(memory_trace_file_name)
                     if memory_trace_file_dirname:
                         os.makedirs(memory_trace_file_dirname, exist_ok=True)
-                    if self.memory_custom_plot:
-                        from composer.profiler.utils import export_memory_timeline_html
-                        export_memory_timeline_html(prof, memory_trace_file_name,
-                                                    torch.cuda.current_device())  # type: ignore
-                    else:
-                        prof.export_memory_timeline(memory_trace_file_name, torch.cuda.current_device())  # type: ignore
+                    from composer.profiler.utils import export_memory_timeline_html
+                    export_memory_timeline_html(prof, memory_trace_file_name,
+                                                torch.cuda.current_device())  # type: ignore
                     log.debug(f'Uploaded memory trace to {self.memory_remote_file_name}')
                     if self.memory_remote_file_name is not None:
                         memory_trace_remote_file_name = format_name_with_dist_and_time(self.memory_remote_file_name,
diff --git a/docs/source/trainer/performance_tutorials/profiling.md b/docs/source/trainer/performance_tutorials/profiling.md
index 0bc930cd47..6c87e2fea8 100644
--- a/docs/source/trainer/performance_tutorials/profiling.md
+++ b/docs/source/trainer/performance_tutorials/profiling.md
@@ -83,6 +83,7 @@ Note, we support both local and object store paths for the composer profiler, e.
 profiler = Profiler(
     trace_handlers=[JSONTraceHandler(remote_file_name='oci://your-bucket/composer_profiler/')],
     torch_remote_filename='s3://your-bucket/torch_profiler/',
+    torch_prof_memory_filename=None,
     ...
) ``` @@ -119,30 +120,30 @@ For example, let’s assume the profiling options are set as follows: Given the configuration above, profiling will be performed as follows: -| Epoch | Batch | Profiler State | Profiler Action | -| --- | --- | --- | --- | -| 0 | 0 | skip_first | Do not record | -| | 1 | wait | Do not record | -| | 2 | warmup | Record, Torch Profiler does not record | -| | 3 | active | Record | -| | 4 | active | Record | -| | 5 | wait | Do not record | -| | 6 | warmup | Record, Torch Profiler does not record | -| | 7 | active | Record | -| | 8 | active | Record | -| | 9 | disabled | Do not record | -| | ... | | | -| 1 | 0 | skip_first | Do not record | -| | 1 | wait | Do not record | -| | 2 | warmup | Record, Torch Profiler does not record | -| | 3 | active | Record | -| | 4 | active | Record | -| | 5 | wait | Do not record | -| | 6 | warmup | Record, Torch Profiler does not record | -| | 7 | active | Record | -| | 8 | active | Record | -| | 9 | disabled | Do not record | -| | ... | | | +| Epoch | Batch | Profiler State | Profiler Action | +| ----- | ----- | -------------- | -------------------------------------- | +| 0 | 0 | skip_first | Do not record | +| | 1 | wait | Do not record | +| | 2 | warmup | Record, Torch Profiler does not record | +| | 3 | active | Record | +| | 4 | active | Record | +| | 5 | wait | Do not record | +| | 6 | warmup | Record, Torch Profiler does not record | +| | 7 | active | Record | +| | 8 | active | Record | +| | 9 | disabled | Do not record | +| | ... | | | +| 1 | 0 | skip_first | Do not record | +| | 1 | wait | Do not record | +| | 2 | warmup | Record, Torch Profiler does not record | +| | 3 | active | Record | +| | 4 | active | Record | +| | 5 | wait | Do not record | +| | 6 | warmup | Record, Torch Profiler does not record | +| | 7 | active | Record | +| | 8 | active | Record | +| | 9 | disabled | Do not record | +| | ... 
| | | As we can see above, the profiler skips the first batch of each epoch and is in the wait state during the following batch, after which the profiler performs warms up in the next batch and actively records trace data for the diff --git a/examples/profiler_demo.py b/examples/profiler_demo.py index c166efa315..f06fa17f06 100644 --- a/examples/profiler_demo.py +++ b/examples/profiler_demo.py @@ -63,6 +63,7 @@ ), torch_prof_folder=torch_trace_dir, torch_prof_overwrite=True, + torch_prof_memory_filename=None, )) # [trainer-end] diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index 0e6b137369..695be08c55 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -53,7 +53,9 @@ def test_multiple_fit_start_and_end(self, cb_cls: Type[Callback], dummy_state: S """Test that callbacks do not crash when Event.FIT_START and Event.FIT_END is called multiple times.""" cb_kwargs = get_cb_kwargs(cb_cls) dummy_state.callbacks.append(cb_cls(**cb_kwargs)) - dummy_state.profiler = Profiler(schedule=lambda _: ProfilerAction.SKIP, trace_handlers=[]) + dummy_state.profiler = Profiler(schedule=lambda _: ProfilerAction.SKIP, + trace_handlers=[], + torch_prof_memory_filename=None) dummy_state.profiler.bind_to_state(dummy_state) logger = Logger(dummy_state) @@ -71,7 +73,9 @@ def test_idempotent_close(self, cb_cls: Type[Callback], dummy_state: State): """Test that callbacks do not crash when .close() and .post_close() are called multiple times.""" cb_kwargs = get_cb_kwargs(cb_cls) dummy_state.callbacks.append(cb_cls(**cb_kwargs)) - dummy_state.profiler = Profiler(schedule=lambda _: ProfilerAction.SKIP, trace_handlers=[]) + dummy_state.profiler = Profiler(schedule=lambda _: ProfilerAction.SKIP, + trace_handlers=[], + torch_prof_memory_filename=None) dummy_state.profiler.bind_to_state(dummy_state) logger = Logger(dummy_state) @@ -85,7 +89,9 @@ def test_multiple_init_and_close(self, cb_cls: Type[Callback], dummy_state: Stat """Test that callbacks do not crash when INIT/.close()/.post_close() are called multiple times in that order.""" cb_kwargs = get_cb_kwargs(cb_cls) dummy_state.callbacks.append(cb_cls(**cb_kwargs)) - dummy_state.profiler = Profiler(schedule=lambda _: ProfilerAction.SKIP, trace_handlers=[]) + dummy_state.profiler = Profiler(schedule=lambda _: ProfilerAction.SKIP, + trace_handlers=[], + torch_prof_memory_filename=None) dummy_state.profiler.bind_to_state(dummy_state) logger = Logger(dummy_state) @@ -125,7 +131,9 @@ def _get_trainer(self, cb: Callback, device_train_microbatch_size: int): device_train_microbatch_size=device_train_microbatch_size, callbacks=callbacks, loggers=loggers, - profiler=Profiler(schedule=lambda _: ProfilerAction.SKIP, trace_handlers=[]), + profiler=Profiler(schedule=lambda _: ProfilerAction.SKIP, + trace_handlers=[], + torch_prof_memory_filename=None), ) def test_trains(self, cb_cls: Type[Callback], device_train_microbatch_size: int, _remote: bool): diff --git a/tests/profiler/test_json_trace_handler.py b/tests/profiler/test_json_trace_handler.py index c09ae00fe6..1d13aed18e 100644 --- a/tests/profiler/test_json_trace_handler.py +++ b/tests/profiler/test_json_trace_handler.py @@ -34,6 +34,7 @@ def test_json_trace_profiler_handler(tmp_path: pathlib.Path): torch_prof_profile_memory=False, torch_prof_with_stack=False, torch_prof_with_flops=False, + torch_prof_memory_filename=None, ) trainer = Trainer( model=SimpleModel(), diff --git a/tests/profiler/test_memory_timeline.py b/tests/profiler/test_memory_timeline.py 
deleted file mode 100644 index c8e685df0b..0000000000 --- a/tests/profiler/test_memory_timeline.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright 2022 MosaicML Composer authors -# SPDX-License-Identifier: Apache-2.0 - -import os -import pathlib - -import pytest -import torch -from packaging import version - -from composer.profiler.utils import export_memory_timeline_html - - -@pytest.mark.gpu -def test_memory_timeline(tmp_path: pathlib.Path) -> None: - if version.parse(torch.__version__) <= version.parse('2.1.0.dev'): - # memory timeline is supported after PyTorch 2.1.0. - return - import torch.profiler._memory_profiler as _memory_profiler - - model = torch.nn.Sequential( - torch.nn.Linear(1024, 1024, bias=True), - torch.nn.ReLU(), - torch.nn.Linear(1024, 1024, bias=False), - torch.nn.Softmax(dim=1), - ).to('cuda') - optimizer = torch.optim.Adam(model.parameters(), lr=0.1) - - x = torch.ones((1024, 1024), device='cuda') - targets = torch.ones((1024, 1024), device='cuda') - with torch.profiler.profile(record_shapes=True, with_stack=True, profile_memory=True) as prof: - y = model(x) - loss = torch.nn.functional.mse_loss(y, targets) - loss.backward() - optimizer.step() - optimizer.zero_grad() - - memory_profile = prof._memory_profile() - timeline = memory_profile.timeline - - # this checks the default memory timeline event value (t == -1) for preexisting tensors - assert all((t == -1) if action == _memory_profiler.Action.PREEXISTING else (t > 0) for t, action, _, _ in timeline) - - fig = export_memory_timeline_html( - prof, - os.path.join(tmp_path, 'test_memory_timeline.html'), - yxis_step_size=0.01, - return_fig=True, - ) - assert fig is not None, 'export_memory_timeline_html should return a figure when return_fig=True' - _, end = fig.gca().get_ylim() - assert round(end, 2) == 0.06 diff --git a/tests/profiler/test_profiler.py b/tests/profiler/test_profiler.py index 5d960f0dd1..2ae9383d79 100644 --- a/tests/profiler/test_profiler.py +++ b/tests/profiler/test_profiler.py @@ -1,12 +1,18 @@ # Copyright 2022 MosaicML Composer authors # SPDX-License-Identifier: Apache-2.0 +import os +import pathlib +from typing import Union from unittest.mock import MagicMock import pytest +import torch +from packaging import version from composer.core import State from composer.profiler import Profiler, ProfilerAction, SystemProfiler, TorchProfiler, cyclic_schedule +from composer.profiler.utils import export_memory_timeline_html @pytest.mark.parametrize('repeat', [1, 0]) @@ -50,6 +56,7 @@ def test_profiler_init(minimal_state: State): trace_handlers=[mock_trace_handler], schedule=cyclic_schedule(), torch_prof_profile_memory=True, + torch_prof_memory_filename=None, sys_prof_cpu=True, ) profiler.bind_to_state(minimal_state) @@ -59,10 +66,9 @@ def test_profiler_init(minimal_state: State): def test_marker(dummy_state: State): mock_trace_handler = MagicMock() - profiler = Profiler( - trace_handlers=[mock_trace_handler], - schedule=cyclic_schedule(), - ) + profiler = Profiler(trace_handlers=[mock_trace_handler], + schedule=cyclic_schedule(), + torch_prof_memory_filename=None) profiler.bind_to_state(dummy_state) dummy_state.profiler = profiler marker = profiler.marker('name', @@ -94,3 +100,73 @@ def func_to_profile2(bar: int): assert mock_trace_handler.process_duration_event.call_count == 8 assert mock_trace_handler.process_instant_event.call_count == 1 + + +@pytest.mark.parametrize('torch_prof_with_stack', [True, False]) +@pytest.mark.parametrize('torch_prof_record_shapes', [True, False]) 
+@pytest.mark.parametrize('torch_prof_profile_memory', [True, False]) +@pytest.mark.parametrize('torch_prof_memory_filename', [None, 'test.html']) +def test_profiler_error_message(torch_prof_with_stack: bool, torch_prof_record_shapes: bool, + torch_prof_profile_memory: bool, torch_prof_memory_filename: Union[None, str]) -> None: + # Construct a profiler and assert that it triggers the ValueError if the arguments are invalid + if (torch_prof_memory_filename is not None and + not (torch_prof_with_stack and torch_prof_record_shapes and torch_prof_profile_memory)): + with pytest.raises(ValueError): + _ = Profiler( + trace_handlers=[MagicMock()], + schedule=cyclic_schedule(), + torch_prof_with_stack=torch_prof_with_stack, + torch_prof_record_shapes=torch_prof_record_shapes, + torch_prof_profile_memory=torch_prof_profile_memory, + torch_prof_memory_filename=torch_prof_memory_filename, + ) + else: + _ = Profiler( + trace_handlers=[MagicMock()], + schedule=cyclic_schedule(), + torch_prof_with_stack=torch_prof_with_stack, + torch_prof_record_shapes=torch_prof_record_shapes, + torch_prof_profile_memory=torch_prof_profile_memory, + torch_prof_memory_filename=torch_prof_memory_filename, + ) + + +@pytest.mark.gpu +def test_memory_timeline(tmp_path: pathlib.Path) -> None: + if version.parse(torch.__version__) <= version.parse('2.1.0.dev'): + # memory timeline is supported after PyTorch 2.1.0. + return + import torch.profiler._memory_profiler as _memory_profiler + + model = torch.nn.Sequential( + torch.nn.Linear(1024, 1024, bias=True), + torch.nn.ReLU(), + torch.nn.Linear(1024, 1024, bias=False), + torch.nn.Softmax(dim=1), + ).to('cuda') + optimizer = torch.optim.Adam(model.parameters(), lr=0.1) + + x = torch.ones((1024, 1024), device='cuda') + targets = torch.ones((1024, 1024), device='cuda') + with torch.profiler.profile(record_shapes=True, with_stack=True, profile_memory=True) as prof: + y = model(x) + loss = torch.nn.functional.mse_loss(y, targets) + loss.backward() + optimizer.step() + optimizer.zero_grad() + + memory_profile = prof._memory_profile() + timeline = memory_profile.timeline + + # this checks the default memory timeline event value (t == -1) for preexisting tensors + assert all((t == -1) if action == _memory_profiler.Action.PREEXISTING else (t > 0) for t, action, _, _ in timeline) + + fig = export_memory_timeline_html( + prof, + os.path.join(tmp_path, 'test_memory_timeline.html'), + yxis_step_size=0.01, + return_fig=True, + ) + assert fig is not None, 'export_memory_timeline_html should return a figure when return_fig=True' + _, end = fig.gca().get_ylim() + assert round(end, 2) == 0.06
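
Below is a minimal usage sketch of the `Profiler` memory-timeline arguments this diff touches. It is not part of the diff; it uses only names that appear above, and the variable names are illustrative.

```python
# Minimal sketch of the Profiler memory-timeline options, assuming only the
# arguments shown in this diff; variable names here are illustrative.
from composer.profiler import Profiler, cyclic_schedule

# Memory timeline enabled: torch_prof_memory_filename is set (the default), so
# torch_prof_with_stack, torch_prof_record_shapes, and torch_prof_profile_memory
# must all be True, otherwise Profiler.__init__ raises a ValueError.
profiler_with_memory_timeline = Profiler(
    schedule=cyclic_schedule(wait=0, warmup=1, active=4, repeat=1),
    trace_handlers=[],
    torch_prof_with_stack=True,
    torch_prof_record_shapes=True,
    torch_prof_profile_memory=True,
    torch_prof_memory_filename='rank{rank}.{batch}.pt.memory_trace.html',
)

# Memory timeline disabled: pass torch_prof_memory_filename=None explicitly,
# as the updated docstrings, profiler_demo, and tests do.
profiler_without_memory_timeline = Profiler(
    schedule=cyclic_schedule(),
    trace_handlers=[],
    torch_prof_memory_filename=None,
)
```

Requiring all three flags mirrors the arguments that `test_memory_timeline` passes to `torch.profiler.profile` (`record_shapes=True`, `with_stack=True`, `profile_memory=True`) before calling `export_memory_timeline_html`.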