From b1a48b4feb573e4c1ba8fe7b0c4fc97835ef7675 Mon Sep 17 00:00:00 2001 From: Matthew McDermott Date: Wed, 11 Sep 2024 22:23:03 -0400 Subject: [PATCH] Made benchmark runnable -- not yet integrated with github actions in any meaningful way. --- benchmark/benchmarkable_dataset.py | 62 ++++++++++++++++++------------ benchmark/nrt_dataset.py | 5 ++- benchmark/run.py | 45 ++++++++++++++++++++++ 3 files changed, 87 insertions(+), 25 deletions(-) create mode 100644 benchmark/run.py diff --git a/benchmark/benchmarkable_dataset.py b/benchmark/benchmarkable_dataset.py index 49c5267..4806e56 100644 --- a/benchmark/benchmarkable_dataset.py +++ b/benchmark/benchmarkable_dataset.py @@ -1,12 +1,15 @@ import logging import os +import sys from abc import ABC, abstractmethod +from collections import defaultdict from contextlib import contextmanager from datetime import datetime, timedelta from pathlib import Path from tempfile import TemporaryDirectory from typing import Any +import torch from memray import Tracker from mixins import TimeableMixin from torch.utils.data import DataLoader, Dataset @@ -19,11 +22,11 @@ import subprocess -def get_memray_stats(memray_tracker_fp: Path, memray_stats_fp: Path): +def get_memray_stats(memray_tracker_fp: Path, memray_stats_fp: Path) -> dict: memray_stats_cmd = f"memray stats {memray_tracker_fp} --json -o {memray_stats_fp} -f" subprocess.run(memray_stats_cmd, shell=True, check=True) try: - json.loads(memray_stats_fp.read_text()) + return json.loads(memray_stats_fp.read_text()) except Exception as e: raise ValueError(f"Failed to parse memray stats file at {memray_stats_fp}") from e @@ -32,47 +35,27 @@ class BenchmarkableDataset(Dataset, TimeableMixin, ABC): def __init__( self, data_dir: Path, - memray_stats_fp: Path, max_seq_len: int | None = None, - min_seq_len: int | None = None, task_bounds: list[tuple[int, int, int]] | None = None, ): super().__init__() + # TODO(mmd): Need to handle min seq length too. self.max_seq_len = max_seq_len - self.min_seq_len = min_seq_len self.task_bounds = task_bounds self.read(data_dir) if not hasattr(self, "N"): raise AttributeError("Dataset must have attribute 'N' after reading data.") - @property - def total_memory_stats(self) -> dict: - if not self.memray_stats_fp.exists(): - raise FileNotFoundError(f"Memray stats file not found at {self.memray_stats_fp}") - return json.loads(self.memray_stats_fp.read_text()) - @classmethod @contextmanager def TemporaryDataset(cls, data: SAMPLE_DATASET_T, root_dir: Path): - memray_stats_fp = root_dir / "memray_stats.json" with TemporaryDirectory(prefix=str(root_dir.resolve())) as tmpdir: tmpdir = Path(tmpdir) - memray_fp = tmpdir / ".memray" cnstr_kwargs, prep_times = cls._prep(data, tmpdir) - cnstr_kwargs["memray_stats_fp"] = memray_stats_fp disk_size = sum((Path(d) / f).stat().st_size for d, _, files in os.walk(tmpdir) for f in files) - - try: - if memray_fp.exists(): - logger.warning(f"Memray tracker file already exists at {memray_fp}. Overwriting.") - memray_fp.unlink() - - with Tracker(memray_fp, follow_fork=True): - yield cnstr_kwargs, prep_times, disk_size - finally: - _ = get_memray_stats(memray_fp, memray_stats_fp) + yield cnstr_kwargs, prep_times, disk_size @classmethod @abstractmethod @@ -118,3 +101,34 @@ def collate(self, batch: list[dict]) -> dict: def dataloader(self, *args, **kwargs) -> DataLoader: return DataLoader(self, *args, collate_fn=self.collate, **kwargs) + + @staticmethod + def tensor_size(a: torch.Tensor) -> int: + return sys.getsizeof(a) + torch.numel(a) * a.element_size() + + @TimeableMixin.TimeAs + def benchmark( + self, batch_size: int, num_epochs: int = 1 + ) -> tuple[dict[str, list[int]], list[timedelta], dict]: + torch.manual_seed(1) + + dataloader = self.dataloader(batch_size=batch_size, shuffle=True) + + sizes = defaultdict(list) + epoch_durations = [] + + with TemporaryDirectory() as tmpdir: + tmpdir = Path(tmpdir) + memray_fp = tmpdir / ".memray" + memray_stats_fp = tmpdir / "memray_stats.json" + + with Tracker(memray_fp, follow_fork=True): + for epoch in range(num_epochs): + epoch_start = datetime.now() + for B in dataloader: + for k, v in B.items(): + sizes[k].append(BenchmarkableDataset.tensor_size(v)) + epoch_durations.append(datetime.now() - epoch_start) + memray_stats = get_memray_stats(memray_fp, memray_stats_fp) + + return sizes, epoch_durations, memray_stats diff --git a/benchmark/nrt_dataset.py b/benchmark/nrt_dataset.py index 9c36ece..e94ccb3 100644 --- a/benchmark/nrt_dataset.py +++ b/benchmark/nrt_dataset.py @@ -3,13 +3,15 @@ from pathlib import Path import numpy as np -from benchmarkable_dataset import BenchmarkableDataset +import torch from mixins import TimeableMixin from torch.utils.data import default_collate from nested_ragged_tensors.ragged_numpy import JointNestedRaggedTensorDict from sample_dataset_builder import SAMPLE_DATASET_T +from .benchmarkable_dataset import BenchmarkableDataset + class NRTDataset(BenchmarkableDataset): @classmethod @@ -71,5 +73,6 @@ def __getitem__(self, i): def collate(self, batch: list[tuple[dict, JointNestedRaggedTensorDict]]) -> dict: dynamics = [d for _, d in batch] collated_dynamics = JointNestedRaggedTensorDict.vstack(dynamics).to_dense() + collated_dynamics = {k: torch.from_numpy(v) for k, v in collated_dynamics.items()} collated_static_data = default_collate([s for s, _ in batch]) return {**collated_static_data, **collated_dynamics} diff --git a/benchmark/run.py b/benchmark/run.py new file mode 100644 index 0000000..5ccd758 --- /dev/null +++ b/benchmark/run.py @@ -0,0 +1,45 @@ +import rootutils + +root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=False) + +SAMPLE_DATASET_PATH = root / "sample_dataset" / "dataset.pkl" + +import pickle +from pathlib import Path + +import humanize +import numpy as np +import pytest + +from benchmark.nrt_dataset import NRTDataset + + +@pytest.mark.parametrize("batch_size", [32]) +@pytest.mark.parametrize("max_seq_len", [128, 256, 512]) +@pytest.mark.parametrize("num_epochs", [5]) +def test_profile(tmp_path: Path, batch_size: int, max_seq_len: int, num_epochs: int): + with open(SAMPLE_DATASET_PATH, mode="rb") as f: + raw_D = pickle.load(f) + + out = {} + with NRTDataset.TemporaryDataset(raw_D, tmp_path) as (kwargs, prep_times, disk_size): + out["prep_times"] = prep_times + out["disk_size"] = disk_size + print(f"Dataset takes up {humanize.naturalsize(disk_size)}") + + D = NRTDataset(**kwargs, max_seq_len=max_seq_len) + batch_sizes, epoch_durations, memray_stats = D.benchmark( + batch_size=batch_size, + num_epochs=num_epochs, + ) + + out["batch_sizes"] = batch_sizes + out["epoch_durations"] = epoch_durations + out["memray_stats"] = memray_stats + out["peak_memory"] = memray_stats["metadata"]["peak_memory"] + + print(D._profile_durations()) + print(f"Peak memory: {humanize.naturalsize(out['peak_memory'])}") + + average_epoch_duration = np.mean(epoch_durations) + print(f"Average epoch duration: {average_epoch_duration} seconds")