diff --git a/composer/profiler/profiler.py b/composer/profiler/profiler.py index 294dfb8471..876282dd99 100644 --- a/composer/profiler/profiler.py +++ b/composer/profiler/profiler.py @@ -76,6 +76,9 @@ def new_profiler_init(self, dummy_ellipsis=None, **kwargs): torch_prof_filename (str, optional): See :class:`~composer.profiler.torch_profiler.TorchProfiler`. torch_prof_remote_file_name (str, optional): See :class:`~composer.profiler.torch_profiler.TorchProfiler`. Additionally supports full object store paths e.g: s3://bucket/path/to/file. + torch_prof_memory_filename (str, optional): See :class:`~composer.profiler.torch_profiler.TorchProfiler`. + torch_prof_memory_remote_file_name (str, optional): See :class:`~composer.profiler.torch_profiler.TorchProfiler`. + Additionally supports full object store paths e.g: s3://bucket/path/to/file. torch_prof_overwrite (bool, optional): See :class:`~composer.profiler.torch_profiler.TorchProfiler`. torch_prof_use_gzip (bool, optional): See :class:`~composer.profiler.torch_profiler.TorchProfiler`. torch_prof_record_shapes (bool, optional): See :class:`~composer.profiler.torch_profiler.TorchProfiler`. @@ -97,6 +100,9 @@ def __init__( torch_prof_folder: str = '{run_name}/torch_traces', torch_prof_filename: str = 'rank{rank}.{batch}.pt.trace.json', torch_prof_remote_file_name: Optional[str] = '{run_name}/torch_traces/rank{rank}.{batch}.pt.trace.json', + torch_prof_memory_filename: str = 'rank{rank}.{batch}.pt.memory_trace.html', + torch_prof_memory_remote_file_name: Optional[ + str] = '{run_name}/torch_memory_traces/rank{rank}.{batch}.pt.memory_trace.html', torch_prof_overwrite: bool = False, torch_prof_use_gzip: bool = False, torch_prof_record_shapes: bool = False, @@ -116,6 +122,9 @@ def __init__( if torch_prof_remote_file_name: self.remote_filenames.append(torch_prof_remote_file_name) _, _, torch_prof_remote_file_name = parse_uri(torch_prof_remote_file_name) + if torch_prof_memory_remote_file_name: + self.remote_filenames.append(torch_prof_memory_remote_file_name) + _, _, torch_prof_memory_remote_file_name = parse_uri(torch_prof_memory_remote_file_name) for handler in self._trace_handlers: if isinstance(handler, JSONTraceHandler): if handler.remote_file_name: @@ -139,6 +148,8 @@ def __init__( TorchProfiler(filename=torch_prof_filename, folder=torch_prof_folder, remote_file_name=torch_prof_remote_file_name, + memory_filename=torch_prof_memory_filename, + memory_remote_file_name=torch_prof_memory_remote_file_name, num_traces_to_keep=torch_prof_num_traces_to_keep, overwrite=torch_prof_overwrite, record_shapes=torch_prof_record_shapes, diff --git a/composer/profiler/torch_profiler.py b/composer/profiler/torch_profiler.py index dc33d829aa..cfd4c0a48b 100644 --- a/composer/profiler/torch_profiler.py +++ b/composer/profiler/torch_profiler.py @@ -11,7 +11,9 @@ import textwrap from typing import TYPE_CHECKING, Optional, OrderedDict +import torch.cuda import torch.profiler +from packaging import version from torch.profiler.profiler import ProfilerAction as TorchProfilerAction from composer.core.callback import Callback @@ -92,9 +94,9 @@ class TorchProfiler(Callback): # noqa: D101 Each rank (process) will save traces to:: - awesome-training-run/torch_traces/ep1-ba42-rank0.json - awesome-training-run/torch_traces/ep1-ba42-rank1.json - awesome-training-run/torch_traces/ep1-ba42-rank2.json + awesome-training-run/torch_traces/ep1-ba42-rank0.pt.trace.json + awesome-training-run/torch_traces/ep1-ba42-rank1.pt.trace.json + awesome-training-run/torch_traces/ep1-ba42-rank2.pt.trace.json ... remote_file_name (str, optional): Format string for a Torch Profiler trace file's remote file name. @@ -107,6 +109,43 @@ class TorchProfiler(Callback): # noqa: D101 Leading slashes (``'/'``) will be stripped. + To disable uploading trace files, set this parameter to ``None``. + memory_filename (str, optional): A format string describing how to name Torch Profiler memory trace files. + Defaults to ``'rank{{rank}}.{{batch}}.pt.trace.memory.html'``. + + At the end of each batch where :meth:`~composer.profiler.Profiler.get_action` returns + :attr:`~composer.profiler._profiler_action.ProfilerAction.ACTIVE_AND_SAVE`, trace files are saved + approximately to ``{{folder.format(...)}}/{{memory_filename.format(...)}}``. + + The following format variables are available: + + {textwrap.indent(FORMAT_NAME_WITH_DIST_AND_TIME_TABLE, prefix=' ')} + + Consider the following scenario, where: + + * The :attr:`~.State.run_name` is ``'awesome-training-run'``. + * The default ``trace_folder='{{run_name}}/torch_traces'`` is used. + * The default ``name='rank{{rank}}.{{batch}}.pt.trace.memory.html'`` is used. + * The current epoch count is ``1``. + * The current batch count is ``42``. + + Each rank (process) will save traces to:: + + awesome-training-run/torch_traces/ep1-ba42-rank0.pt.trace.memory.html + awesome-training-run/torch_traces/ep1-ba42-rank1.pt.trace.memory.html + awesome-training-run/torch_traces/ep1-ba42-rank2.pt.trace.memory.html + ... + + memory_remote_file_name (str, optional): Format string for a Torch Profiler memory trace file's remote file name. + Defaults to ``'{{run_name}}/torch_traces/rank{{rank}}.{{batch}}.pt.trace.memory.json'``. + + Whenever a trace file is saved, it is also uploaded as a file according to this format string. + The same format variables as for ``filename`` are available. + + .. seealso:: :doc:`Uploading Files` for notes for file uploading. + + Leading slashes (``'/'``) will be stripped. + To disable uploading trace files, set this parameter to ``None``. overwrite (bool, optional): Whether to override existing Torch Profiler traces. Defaults to False. @@ -146,7 +185,10 @@ def __init__( folder: str = '{run_name}/torch_traces', filename: str = 'rank{rank}.{batch}.pt.trace.json', remote_file_name: Optional[str] = '{run_name}/torch_traces/rank{rank}.{batch}.pt.trace.json', - *, + memory_filename: Optional[str] = 'rank{rank}.{batch}.pt.trace.memory.html', + memory_remote_file_name: Optional[ + str] = '{run_name}/torch_memory_traces/rank{rank}.{batch}.pt.trace.memory.html', + memory_custom_plot: bool = True, overwrite: bool = False, use_gzip: bool = False, record_shapes: bool = False, @@ -157,12 +199,26 @@ def __init__( ) -> None: self.overwrite = overwrite self.folder = folder - if use_gzip and not filename.endswith('.gz'): - filename += '.gz' + + if use_gzip: + if not filename.endswith('.gz'): + filename += '.gz' self.filename = filename - if use_gzip and remote_file_name is not None and not remote_file_name.endswith('.gz'): - remote_file_name += '.gz' + + if use_gzip: + if remote_file_name is not None and not remote_file_name.endswith('.gz'): + remote_file_name += '.gz' self.remote_file_name = remote_file_name + + if memory_filename is not None: + assert memory_filename.endswith('.html'), f'memory_filename must end with .html, got {memory_filename}' + self.memory_filename = memory_filename + + if memory_remote_file_name is not None: + assert memory_remote_file_name.endswith( + '.html'), f'memory_remote_file_name must end with .html, got {memory_remote_file_name}' + self.memory_remote_file_name = memory_remote_file_name + self.record_shapes = record_shapes self.profile_memory = profile_memory self.with_stack = with_stack @@ -170,6 +226,7 @@ def __init__( self.num_traces_to_keep = num_traces_to_keep self.saved_traces = OrderedDict() self.profiler: Optional[torch.profiler.profile] = None + self.memory_custom_plot = memory_custom_plot def init(self, state: State, logger: Logger) -> None: if state.profiler is None: @@ -203,27 +260,63 @@ def handler_fn(prof: torch.profiler.profiler.profile): timestamp = state.timestamp - trace_file_name = os.path.join( - folder_name, - format_name_with_dist_and_time(self.filename, run_name=state.run_name, timestamp=timestamp), - ) - trace_file_dirname = os.path.dirname(trace_file_name) - if trace_file_dirname: - os.makedirs(trace_file_dirname, exist_ok=True) - prof.export_chrome_trace(trace_file_name) - state.profiler.record_chrome_json_trace_file(trace_file_name) - if self.remote_file_name is not None: - trace_remote_file_name = format_name_with_dist_and_time(self.remote_file_name, - run_name=state.run_name, - timestamp=timestamp) - trace_remote_file_name = trace_remote_file_name.lstrip('/') - logger.upload_file(remote_file_name=trace_remote_file_name, - file_path=trace_file_name, - overwrite=self.overwrite) + log.info(f'PyTorch Chrome trace profiler enabled: {self.filename if self.filename else False}') + if self.filename is not None: + trace_file_name = os.path.join( + folder_name, + format_name_with_dist_and_time(self.filename, run_name=state.run_name, timestamp=timestamp), + ) + trace_file_dirname = os.path.dirname(trace_file_name) + if trace_file_dirname: + os.makedirs(trace_file_dirname, exist_ok=True) + prof.export_chrome_trace(trace_file_name) + state.profiler.record_chrome_json_trace_file(trace_file_name) + if self.remote_file_name is not None: + trace_remote_file_name = format_name_with_dist_and_time(self.remote_file_name, + run_name=state.run_name, + timestamp=timestamp) + trace_remote_file_name = trace_remote_file_name.lstrip('/') + logger.upload_file(remote_file_name=trace_remote_file_name, + file_path=trace_file_name, + overwrite=self.overwrite) + + log.info( + f'PyTorch memory timeline profiler enabled: {self.memory_filename if self.memory_filename else False}') + if self.memory_filename is not None: + if version.parse(torch.__version__) > version.parse('2.1.0.dev'): # type: ignore + # memory timeline profiling is only supported in torch v2.1.0-rc1 or higher + memory_trace_file_name = os.path.join( + folder_name, + format_name_with_dist_and_time(self.memory_filename, + run_name=state.run_name, + timestamp=timestamp), + ) + log.debug(f'Saving memory trace to {memory_trace_file_name}') + memory_trace_file_dirname = os.path.dirname(memory_trace_file_name) + if memory_trace_file_dirname: + os.makedirs(memory_trace_file_dirname, exist_ok=True) + if self.memory_custom_plot: + from composer.profiler.utils import export_memory_timeline_html + export_memory_timeline_html(prof, memory_trace_file_name, + torch.cuda.current_device()) # type: ignore + else: + prof.export_memory_timeline(memory_trace_file_name, torch.cuda.current_device()) # type: ignore + log.debug(f'Uploaded memory trace to {self.memory_remote_file_name}') + if self.memory_remote_file_name is not None: + memory_trace_remote_file_name = format_name_with_dist_and_time(self.memory_remote_file_name, + run_name=state.run_name, + timestamp=timestamp) + memory_trace_remote_file_name = memory_trace_remote_file_name.lstrip('/') + log.debug( + f'Uploading memory trace to {memory_trace_remote_file_name} from {memory_trace_file_name}') + logger.upload_file(remote_file_name=memory_trace_remote_file_name, + file_path=memory_trace_file_name, + overwrite=self.overwrite) + else: + log.warning('Memory timeline is supported after PyTorch 2.1.0. Skipping memory trace.') if self.num_traces_to_keep >= 0: while len(self.saved_traces) > self.num_traces_to_keep: - # self.saved_traces is an ordered dict, so the zeroth item will be the oldest checkpoint timestamp, filepaths = next(iter(self.saved_traces.items())) if dist.get_global_rank() < len(filepaths): diff --git a/composer/profiler/utils.py b/composer/profiler/utils.py new file mode 100644 index 0000000000..b4df8396a7 --- /dev/null +++ b/composer/profiler/utils.py @@ -0,0 +1,97 @@ +# Copyright 2022 MosaicML Composer authors +# SPDX-License-Identifier: Apache-2.0 + +"""Utility functions for torch profiler.""" + +import importlib.util +import logging +from base64 import b64encode +from os import remove +from tempfile import NamedTemporaryFile +from typing import Any, Optional, Union + +import numpy as np +import torch +import torch.cuda +from packaging import version +from torch.profiler.profiler import profile as TorchProfile + +log = logging.getLogger(__name__) + + +def export_memory_timeline_html(prof: TorchProfile, + path: str, + device: Optional[str] = None, + figsize=(20, 12), + title=None, + yxis_step_size: float = 1.0, + return_fig: bool = False) -> Optional[Union[None, Any]]: + """Exports a memory timeline to an HTML file. Similar to the PyTorch plotting function, but with adjusted axis tickers and grids.""" + if version.parse(torch.__version__) <= version.parse('2.1.0.dev'): + log.warning('export_memory_timeline_html failed because memory timeline is supported after PyTorch 2.1.0.') + return + + from torch.profiler._memory_profiler import _CATEGORY_TO_COLORS, _CATEGORY_TO_INDEX, MemoryProfileTimeline + + # Default to device 0, if unset. Fallback on cpu. + if device is None and prof.use_device and prof.use_device != 'cuda': + device = prof.use_device + ':0' + + if device is None: + device = 'cuda:0' if torch.cuda.is_available() else 'cpu' + + # Construct the memory timeline plot data + mem_tl = MemoryProfileTimeline(prof._memory_profile()) + + # Check if user has matplotlib installed, return gracefully if not. + matplotlib_spec = importlib.util.find_spec('matplotlib') + if matplotlib_spec is None: + log.warning('export_memory_timeline_html failed because matplotlib was not found.') + return + import matplotlib.pyplot as plt + + mt = mem_tl._coalesce_timeline(device) + times, sizes = np.array(mt[0]), np.array(mt[1]) + stacked = np.cumsum(sizes, axis=1) / 1024**3 + max_memory_allocated = torch.cuda.max_memory_allocated() + max_memory_reserved = torch.cuda.max_memory_reserved() + + # Plot memory timeline as stacked data + fig = plt.figure(figsize=figsize, dpi=80) + axes = fig.gca() + for category, color in _CATEGORY_TO_COLORS.items(): + i = _CATEGORY_TO_INDEX[category] + axes.fill_between(times / 1e3, stacked[:, i], stacked[:, i + 1], color=color, alpha=0.7) + fig.legend(['Unknown' if i is None else i.name for i in _CATEGORY_TO_COLORS]) + axes.set_xlabel('Time (us)') + axes.set_ylabel('Memory (GB)') + _, end = axes.get_ylim() + axes.grid(True) + axes.set_yticks(np.arange(0, end, yxis_step_size)) + title = '\n\n'.join(([title] if title else []) + [ + f'Max memory allocated: {max_memory_allocated/(10**9):.2f} GB \n' + f'Max memory reserved: {max_memory_reserved/(10**9):.2f} GB' + ]) + axes.set_title(title) + + if return_fig: + return fig + + # Embed the memory timeline image into the HTML file + tmpfile = NamedTemporaryFile('wb', suffix='.png', delete=False) + tmpfile.close() + fig.savefig(tmpfile.name, format='png') + + with open(tmpfile.name, 'rb') as tmp: + encoded = b64encode(tmp.read()).decode('utf-8') + html = f""" +