
Commit

Made benchmark runnable -- not yet integrated with github actions in any meaningful way.
mmcdermott committed Sep 12, 2024
1 parent 4238c71 commit b1a48b4
Showing 3 changed files with 87 additions and 25 deletions.
62 changes: 38 additions & 24 deletions benchmark/benchmarkable_dataset.py
@@ -1,12 +1,15 @@
import logging
import os
import sys
from abc import ABC, abstractmethod
from collections import defaultdict
from contextlib import contextmanager
from datetime import datetime, timedelta
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any

import torch
from memray import Tracker
from mixins import TimeableMixin
from torch.utils.data import DataLoader, Dataset
@@ -19,11 +22,11 @@
import subprocess


def get_memray_stats(memray_tracker_fp: Path, memray_stats_fp: Path):
def get_memray_stats(memray_tracker_fp: Path, memray_stats_fp: Path) -> dict:
memray_stats_cmd = f"memray stats {memray_tracker_fp} --json -o {memray_stats_fp} -f"
subprocess.run(memray_stats_cmd, shell=True, check=True)
try:
json.loads(memray_stats_fp.read_text())
return json.loads(memray_stats_fp.read_text())
except Exception as e:
raise ValueError(f"Failed to parse memray stats file at {memray_stats_fp}") from e

@@ -32,47 +35,27 @@ class BenchmarkableDataset(Dataset, TimeableMixin, ABC):
def __init__(
self,
data_dir: Path,
memray_stats_fp: Path,
max_seq_len: int | None = None,
min_seq_len: int | None = None,
task_bounds: list[tuple[int, int, int]] | None = None,
):
super().__init__()
# TODO(mmd): Need to handle min seq length too.
self.max_seq_len = max_seq_len
self.min_seq_len = min_seq_len
self.task_bounds = task_bounds
self.read(data_dir)
if not hasattr(self, "N"):
raise AttributeError("Dataset must have attribute 'N' after reading data.")

@property
def total_memory_stats(self) -> dict:
if not self.memray_stats_fp.exists():
raise FileNotFoundError(f"Memray stats file not found at {self.memray_stats_fp}")
return json.loads(self.memray_stats_fp.read_text())

@classmethod
@contextmanager
def TemporaryDataset(cls, data: SAMPLE_DATASET_T, root_dir: Path):
memray_stats_fp = root_dir / "memray_stats.json"
with TemporaryDirectory(prefix=str(root_dir.resolve())) as tmpdir:
tmpdir = Path(tmpdir)
memray_fp = tmpdir / ".memray"

cnstr_kwargs, prep_times = cls._prep(data, tmpdir)
cnstr_kwargs["memray_stats_fp"] = memray_stats_fp

disk_size = sum((Path(d) / f).stat().st_size for d, _, files in os.walk(tmpdir) for f in files)

try:
if memray_fp.exists():
logger.warning(f"Memray tracker file already exists at {memray_fp}. Overwriting.")
memray_fp.unlink()

with Tracker(memray_fp, follow_fork=True):
yield cnstr_kwargs, prep_times, disk_size
finally:
_ = get_memray_stats(memray_fp, memray_stats_fp)
yield cnstr_kwargs, prep_times, disk_size

@classmethod
@abstractmethod
@@ -118,3 +101,34 @@ def collate(self, batch: list[dict]) -> dict:

def dataloader(self, *args, **kwargs) -> DataLoader:
return DataLoader(self, *args, collate_fn=self.collate, **kwargs)

@staticmethod
def tensor_size(a: torch.Tensor) -> int:
return sys.getsizeof(a) + torch.numel(a) * a.element_size()

@TimeableMixin.TimeAs
def benchmark(
self, batch_size: int, num_epochs: int = 1
) -> tuple[dict[str, list[int]], list[timedelta], dict]:
torch.manual_seed(1)

dataloader = self.dataloader(batch_size=batch_size, shuffle=True)

sizes = defaultdict(list)
epoch_durations = []

with TemporaryDirectory() as tmpdir:
tmpdir = Path(tmpdir)
memray_fp = tmpdir / ".memray"
memray_stats_fp = tmpdir / "memray_stats.json"

with Tracker(memray_fp, follow_fork=True):
for epoch in range(num_epochs):
epoch_start = datetime.now()
for B in dataloader:
for k, v in B.items():
sizes[k].append(BenchmarkableDataset.tensor_size(v))
epoch_durations.append(datetime.now() - epoch_start)
memray_stats = get_memray_stats(memray_fp, memray_stats_fp)

return sizes, epoch_durations, memray_stats
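
For orientation (this sketch is not part of the diff), the new benchmark method would presumably be used along these lines, with the dataset D assumed to be an already-constructed BenchmarkableDataset subclass as in run.py below; the metadata/peak_memory path matches how run.py reads the memray stats:

# Illustrative only -- D is an already-constructed BenchmarkableDataset subclass.
sizes, epoch_durations, memray_stats = D.benchmark(batch_size=32, num_epochs=1)
total_bytes_per_key = {k: sum(v) for k, v in sizes.items()}  # summed tensor bytes per batch key
peak_memory = memray_stats["metadata"]["peak_memory"]  # same field run.py reports below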
5 changes: 4 additions & 1 deletion benchmark/nrt_dataset.py
@@ -3,13 +3,15 @@
from pathlib import Path

import numpy as np
from benchmarkable_dataset import BenchmarkableDataset
import torch
from mixins import TimeableMixin
from torch.utils.data import default_collate

from nested_ragged_tensors.ragged_numpy import JointNestedRaggedTensorDict
from sample_dataset_builder import SAMPLE_DATASET_T

from .benchmarkable_dataset import BenchmarkableDataset


class NRTDataset(BenchmarkableDataset):
@classmethod
@@ -71,5 +73,6 @@ def __getitem__(self, i):
def collate(self, batch: list[tuple[dict, JointNestedRaggedTensorDict]]) -> dict:
dynamics = [d for _, d in batch]
collated_dynamics = JointNestedRaggedTensorDict.vstack(dynamics).to_dense()
collated_dynamics = {k: torch.from_numpy(v) for k, v in collated_dynamics.items()}
collated_static_data = default_collate([s for s, _ in batch])
return {**collated_static_data, **collated_dynamics}
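
The one-line change to collate above converts the dense numpy arrays that JointNestedRaggedTensorDict.to_dense() produces into torch tensors before merging them with the collated static data. A minimal sketch of the same numpy-to-torch step in isolation (the keys and shapes here are hypothetical):

import numpy as np
import torch

dense = {"code": np.zeros((2, 4), dtype=np.int64)}  # stand-in for the .to_dense() output
collated = {k: torch.from_numpy(v) for k, v in dense.items()}  # zero-copy conversion to tensors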
45 changes: 45 additions & 0 deletions benchmark/run.py
@@ -0,0 +1,45 @@
import rootutils

root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=False)

SAMPLE_DATASET_PATH = root / "sample_dataset" / "dataset.pkl"

import pickle
from pathlib import Path

import humanize
import numpy as np
import pytest

from benchmark.nrt_dataset import NRTDataset


@pytest.mark.parametrize("batch_size", [32])
@pytest.mark.parametrize("max_seq_len", [128, 256, 512])
@pytest.mark.parametrize("num_epochs", [5])
def test_profile(tmp_path: Path, batch_size: int, max_seq_len: int, num_epochs: int):
with open(SAMPLE_DATASET_PATH, mode="rb") as f:
raw_D = pickle.load(f)

out = {}
with NRTDataset.TemporaryDataset(raw_D, tmp_path) as (kwargs, prep_times, disk_size):
out["prep_times"] = prep_times
out["disk_size"] = disk_size
print(f"Dataset takes up {humanize.naturalsize(disk_size)}")

D = NRTDataset(**kwargs, max_seq_len=max_seq_len)
batch_sizes, epoch_durations, memray_stats = D.benchmark(
batch_size=batch_size,
num_epochs=num_epochs,
)

out["batch_sizes"] = batch_sizes
out["epoch_durations"] = epoch_durations
out["memray_stats"] = memray_stats
out["peak_memory"] = memray_stats["metadata"]["peak_memory"]

print(D._profile_durations())
print(f"Peak memory: {humanize.naturalsize(out['peak_memory'])}")

average_epoch_duration = np.mean(epoch_durations)
print(f"Average epoch duration: {average_epoch_duration} seconds")
