-
Notifications
You must be signed in to change notification settings - Fork 421
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add the memory timeline profiling support through the PyTorch profile…
…r. (#2771) * v1 * fix issues * add logs * change names * comment * add device * uncomment original trace * add custome plot * fix pyright * Update composer/profiler/torch_profiler.py Co-authored-by: Charles Tang <[email protected]> * address comments * fix code check * fix formatting * address comments * add unit test * fix check * fix check * fix check * fix check * fix print * add test comment * add test comment --------- Co-authored-by: Mihir Patel <[email protected]> Co-authored-by: Charles Tang <[email protected]>
- Loading branch information
1 parent
f497e60
commit a7cad7c
Showing
4 changed files
with
279 additions
and
26 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
# Copyright 2022 MosaicML Composer authors | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
"""Utility functions for torch profiler.""" | ||
|
||
import importlib.util | ||
import logging | ||
from base64 import b64encode | ||
from os import remove | ||
from tempfile import NamedTemporaryFile | ||
from typing import Any, Optional, Union | ||
|
||
import numpy as np | ||
import torch | ||
import torch.cuda | ||
from packaging import version | ||
from torch.profiler.profiler import profile as TorchProfile | ||
|
||
log = logging.getLogger(__name__) | ||
|
||
|
||
def export_memory_timeline_html(prof: TorchProfile, | ||
path: str, | ||
device: Optional[str] = None, | ||
figsize=(20, 12), | ||
title=None, | ||
yxis_step_size: float = 1.0, | ||
return_fig: bool = False) -> Optional[Union[None, Any]]: | ||
"""Exports a memory timeline to an HTML file. Similar to the PyTorch plotting function, but with adjusted axis tickers and grids.""" | ||
if version.parse(torch.__version__) <= version.parse('2.1.0.dev'): | ||
log.warning('export_memory_timeline_html failed because memory timeline is supported after PyTorch 2.1.0.') | ||
return | ||
|
||
from torch.profiler._memory_profiler import _CATEGORY_TO_COLORS, _CATEGORY_TO_INDEX, MemoryProfileTimeline | ||
|
||
# Default to device 0, if unset. Fallback on cpu. | ||
if device is None and prof.use_device and prof.use_device != 'cuda': | ||
device = prof.use_device + ':0' | ||
|
||
if device is None: | ||
device = 'cuda:0' if torch.cuda.is_available() else 'cpu' | ||
|
||
# Construct the memory timeline plot data | ||
mem_tl = MemoryProfileTimeline(prof._memory_profile()) | ||
|
||
# Check if user has matplotlib installed, return gracefully if not. | ||
matplotlib_spec = importlib.util.find_spec('matplotlib') | ||
if matplotlib_spec is None: | ||
log.warning('export_memory_timeline_html failed because matplotlib was not found.') | ||
return | ||
import matplotlib.pyplot as plt | ||
|
||
mt = mem_tl._coalesce_timeline(device) | ||
times, sizes = np.array(mt[0]), np.array(mt[1]) | ||
stacked = np.cumsum(sizes, axis=1) / 1024**3 | ||
max_memory_allocated = torch.cuda.max_memory_allocated() | ||
max_memory_reserved = torch.cuda.max_memory_reserved() | ||
|
||
# Plot memory timeline as stacked data | ||
fig = plt.figure(figsize=figsize, dpi=80) | ||
axes = fig.gca() | ||
for category, color in _CATEGORY_TO_COLORS.items(): | ||
i = _CATEGORY_TO_INDEX[category] | ||
axes.fill_between(times / 1e3, stacked[:, i], stacked[:, i + 1], color=color, alpha=0.7) | ||
fig.legend(['Unknown' if i is None else i.name for i in _CATEGORY_TO_COLORS]) | ||
axes.set_xlabel('Time (us)') | ||
axes.set_ylabel('Memory (GB)') | ||
_, end = axes.get_ylim() | ||
axes.grid(True) | ||
axes.set_yticks(np.arange(0, end, yxis_step_size)) | ||
title = '\n\n'.join(([title] if title else []) + [ | ||
f'Max memory allocated: {max_memory_allocated/(10**9):.2f} GB \n' | ||
f'Max memory reserved: {max_memory_reserved/(10**9):.2f} GB' | ||
]) | ||
axes.set_title(title) | ||
|
||
if return_fig: | ||
return fig | ||
|
||
# Embed the memory timeline image into the HTML file | ||
tmpfile = NamedTemporaryFile('wb', suffix='.png', delete=False) | ||
tmpfile.close() | ||
fig.savefig(tmpfile.name, format='png') | ||
|
||
with open(tmpfile.name, 'rb') as tmp: | ||
encoded = b64encode(tmp.read()).decode('utf-8') | ||
html = f"""<html> | ||
<head><meta charset="utf-8" /><title>GPU Memory Timeline HTML</title></head> | ||
<body> | ||
<img src='data:image/png;base64,{encoded}'> | ||
</body> | ||
</html>""" | ||
|
||
with open(path, 'w') as f: | ||
f.write(html) | ||
log.debug('Memory timeline exported to', path, '.') | ||
remove(tmpfile.name) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
# Copyright 2022 MosaicML Composer authors | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import os | ||
import pathlib | ||
|
||
import pytest | ||
import torch | ||
from packaging import version | ||
|
||
from composer.profiler.utils import export_memory_timeline_html | ||
|
||
|
||
@pytest.mark.gpu | ||
def test_memory_timeline(tmp_path: pathlib.Path) -> None: | ||
if version.parse(torch.__version__) <= version.parse('2.1.0.dev'): | ||
# memory timeline is supported after PyTorch 2.1.0. | ||
return | ||
import torch.profiler._memory_profiler as _memory_profiler | ||
|
||
model = torch.nn.Sequential( | ||
torch.nn.Linear(1024, 1024, bias=True), | ||
torch.nn.ReLU(), | ||
torch.nn.Linear(1024, 1024, bias=False), | ||
torch.nn.Softmax(dim=1), | ||
).to('cuda') | ||
optimizer = torch.optim.Adam(model.parameters(), lr=0.1) | ||
|
||
x = torch.ones((1024, 1024), device='cuda') | ||
targets = torch.ones((1024, 1024), device='cuda') | ||
with torch.profiler.profile(record_shapes=True, with_stack=True, profile_memory=True) as prof: | ||
y = model(x) | ||
loss = torch.nn.functional.mse_loss(y, targets) | ||
loss.backward() | ||
optimizer.step() | ||
optimizer.zero_grad() | ||
|
||
memory_profile = prof._memory_profile() | ||
timeline = memory_profile.timeline | ||
|
||
# this checks the default memory timeline event value (t == -1) for preexisting tensors | ||
assert all((t == -1) if action == _memory_profiler.Action.PREEXISTING else (t > 0) for t, action, _, _ in timeline) | ||
|
||
fig = export_memory_timeline_html( | ||
prof, | ||
os.path.join(tmp_path, 'test_memory_timeline.html'), | ||
yxis_step_size=0.01, | ||
return_fig=True, | ||
) | ||
assert fig is not None, 'export_memory_timeline_html should return a figure when return_fig=True' | ||
_, end = fig.gca().get_ylim() | ||
assert round(end, 2) == 0.06 |