Skip to content

Commit

Permalink
Add function to get checkpoint metadata
Browse files Browse the repository at this point in the history
  • Loading branch information
epwalsh committed Nov 12, 2024
1 parent c0e47cc commit 2e2b35b
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 0 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Added

- Added `olmo_core.distributed.checkpoint.get_checkpoint_metadata()` function.

### Fixed

- Old ephemeral checkpoints won't be removed until after the latest ephemeral checkpoint is saved successfully.
Expand Down
13 changes: 13 additions & 0 deletions src/olmo_core/distributed/checkpoint/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import torch.distributed.checkpoint as dist_cp
import torch.distributed.checkpoint.state_dict as dist_cp_sd
import torch.nn as nn
from torch.distributed.checkpoint.metadata import Metadata

from olmo_core.aliases import PathOrStr
from olmo_core.io import clear_directory, dir_is_empty, is_url, normalize_path
Expand All @@ -49,6 +50,7 @@
"async_save_model_and_optim_state",
"load_model_and_optim_state",
"unshard_checkpoint",
"get_checkpoint_metadata",
]

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -310,6 +312,17 @@ def save(state_dict: Dict[str, Any], path: Path):
return model_path, optim_path


def get_checkpoint_metadata(dir: PathOrStr) -> Metadata:
    """
    Read the metadata of a checkpoint.

    :param dir: The path/URL to the checkpoint.

    :returns: The metadata object produced by the storage reader's
        ``read_metadata()`` call for the normalized checkpoint location.
    """
    # Normalize the location once, then let the reader fetch the metadata.
    checkpoint_dir = normalize_path(dir)
    return RemoteFileSystemReader(checkpoint_dir).read_metadata()


def _prepare_env_for_save(
dir: PathOrStr,
process_group: Optional[dist.ProcessGroup] = None,
Expand Down

0 comments on commit 2e2b35b

Please sign in to comment.