Fixes for remote checkpointing (#30)
epwalsh authored Aug 29, 2024
1 parent 3d53c65 commit 35145b0
Showing 15 changed files with 396 additions and 127 deletions.
29 changes: 29 additions & 0 deletions .github/workflows/pr_checks.yml
@@ -0,0 +1,29 @@
name: PR Checks

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true

on:
pull_request:
branches:
- main
paths:
- 'src/**'

jobs:
changelog:
name: CHANGELOG
runs-on: ubuntu-latest
if: github.event_name == 'pull_request'

steps:
- uses: actions/checkout@v3
with:
fetch-depth: 0

- name: Check that CHANGELOG has been updated
run: |
# If this step fails, this means you haven't updated the CHANGELOG.md
# file with notes on your contribution.
git diff --name-only $(git merge-base origin/main HEAD) | grep '^CHANGELOG.md$' && echo "Thanks for helping keep our CHANGELOG up-to-date!"
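
The step above just greps the PR's changed files for CHANGELOG.md. As a rough local pre-check, the same logic can be expressed in Python; this is a sketch, not part of the repo, and the helper name and base_ref default are illustrative:

import subprocess

def changelog_updated(base_ref: str = "origin/main") -> bool:
    # Find the merge base with the target branch, then list the files that differ from it.
    merge_base = subprocess.run(
        ["git", "merge-base", base_ref, "HEAD"],
        check=True, capture_output=True, text=True,
    ).stdout.strip()
    changed = subprocess.run(
        ["git", "diff", "--name-only", merge_base],
        check=True, capture_output=True, text=True,
    ).stdout.splitlines()
    return "CHANGELOG.md" in changed

if __name__ == "__main__":
    print("Thanks for helping keep our CHANGELOG up-to-date!" if changelog_updated()
          else "CHANGELOG.md has not been updated.")
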
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -15,6 +15,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `block_idx` attribute to the `TransformerBlock` class.
- Added `init_method` option to `Transformer` for controlling how the weights are initialized.

### Fixed

- Fixed `list_directory` for remote folders.

## [v1.0.1](https://github.com/allenai/OLMo-core/releases/tag/v1.0.1) - 2024-08-26

### Fixed
10 changes: 4 additions & 6 deletions src/olmo_core/data/memmap_dataset.py
@@ -46,6 +46,7 @@ class MemMapDatasetConfig(Config):
tokenizer: TokenizerConfig
paths: Optional[List[str]] = None
mix: Optional[DataMix] = None
mix_base_dir: Optional[str] = None
memmap_dtype: Optional[MemMapDType] = None
metadata: Optional[List[Dict[str, Any]]] = None
include_instance_metadata: bool = True
@@ -105,12 +106,9 @@ def get_memmap_dtype(

raise ValueError("vocab size too big!")

-def build(self, mix_base_dir: Optional[str] = None) -> MemMapDataset:
+def build(self) -> MemMapDataset:
"""
Construct the corresponding :class:`MemMapDataset`.
-:param mix_base_dir: The base directory for the :data:`mix`, e.g. "s3://ai2-llm".
-    Required if initializing from a data mix.
"""
if (self.paths is None) == (self.mix is None):
raise OLMoConfigurationError("Exactly one of 'paths' or 'mix' is required")
@@ -131,15 +129,15 @@ def build(self, mix_base_dir: Optional[str] = None) -> MemMapDataset:
paths = self.paths
else:
assert self.mix is not None
-if mix_base_dir is None:
+if self.mix_base_dir is None:
raise OLMoConfigurationError(
"'mix_base_dir' is required to build a dataset from a mix"
)
if self.tokenizer.identifier is None:
raise OLMoConfigurationError(
"Missing tokenizer identifier required to construct data mix"
)
-paths = self.mix.build(mix_base_dir, self.tokenizer.identifier)
+paths = self.mix.build(self.mix_base_dir, self.tokenizer.identifier)

dataset = MemMapDataset(
*paths,
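
The net effect of this change is that the mix base directory moves from a build() argument onto the config itself. A minimal sketch of the new call site follows; `cfg` stands for an already-constructed MemMapDatasetConfig built from a DataMix and is not defined here:

# Hypothetical usage sketch; `cfg` is an existing MemMapDatasetConfig.
cfg.mix_base_dir = "s3://ai2-llm"   # base dir, previously passed to build()
dataset = cfg.build()               # new signature takes no arguments
# pre-change equivalent, for comparison:
# dataset = cfg.build(mix_base_dir="s3://ai2-llm")
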
14 changes: 9 additions & 5 deletions src/olmo_core/io.py
@@ -613,11 +613,15 @@ def _s3_clear_directory(scheme: str, bucket_name: str, prefix: str, max_attempts


def _s3_list_directory(scheme: str, bucket_name: str, prefix: str) -> Generator[str, None, None]:
-response = _get_s3_client(scheme).list_objects(Bucket=bucket_name, Prefix=prefix, Delimiter="/")
-assert not response["IsTruncated"] # need to handle this if it happens
-for item in response.get("CommonPrefixes", []):
-    prefix = item["Prefix"].strip("/")
-    yield f"{scheme}://{bucket_name}/{prefix}"
+client = _get_s3_client(scheme)
+paginator = client.get_paginator("list_objects_v2")
+if not prefix.endswith("/"):
+    prefix = prefix + "/"
+for page in paginator.paginate(Bucket=bucket_name, Prefix=prefix, MaxKeys=50, Delimiter="/"):
+    for file_item in page.get("Contents", []):
+        yield f"{scheme}://{bucket_name}/{file_item['Key']}"
+    for dir_item in page.get("CommonPrefixes", []):
+        yield f"{scheme}://{bucket_name}/{dir_item['Prefix'].strip('/')}"


#############################################
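
For context, the old implementation issued a single list_objects call and simply asserted that the response was not truncated, so directories with more than one page of keys broke remote checkpointing. Below is a standalone sketch of the paginated pattern adopted above, using plain boto3; the bucket and prefix values are placeholders:

import boto3

def list_dir(bucket: str, prefix: str):
    """Yield s3:// URLs for files and immediate sub-'directories' under `prefix`."""
    client = boto3.client("s3")
    if not prefix.endswith("/"):
        prefix += "/"
    paginator = client.get_paginator("list_objects_v2")
    # The paginator follows continuation tokens, so listings longer than one
    # page are no longer silently truncated.
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix, Delimiter="/"):
        for obj in page.get("Contents", []):            # objects directly under the prefix
            yield f"s3://{bucket}/{obj['Key']}"
        for common in page.get("CommonPrefixes", []):   # one level of "sub-directories"
            yield f"s3://{bucket}/{common['Prefix'].rstrip('/')}"

# Example (placeholder bucket/prefix):
# for path in list_dir("my-bucket", "checkpoints/step1000"):
#     print(path)
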
7 changes: 5 additions & 2 deletions src/olmo_core/launch/beaker.py
@@ -212,6 +212,7 @@ def default_env_vars(self) -> List[Tuple[str, str]]:
("S3_PROFILE", "S3"),
("WEKA_PROFILE", "WEKA"),
("NUM_NODES", str(self.num_nodes)),
("OLMO_CORE_VERSION", VERSION),
]
if self.shared_filesystem:
env_vars.append((OLMO_SHARED_FS_ENV_VAR, "1"))
@@ -306,8 +307,8 @@ def build_experiment_spec(self) -> ExperimentSpec:
"set -exuo pipefail",
"mkdir -p /olmo-core-runtime",
"cd /olmo-core-runtime",
f"git clone https://github.com/{github_account}/{github_repo} .",
f"git checkout {git_ref}",
'git clone "${REPO_URL}" .',
'git checkout "${GIT_REF}"',
"git submodule update --init --recursive",
*self.setup_steps,
" ".join(self._get_torchrun_cmd()) + " $@",
@@ -333,6 +334,8 @@ def build_experiment_spec(self) -> ExperimentSpec:
)
.with_dataset("/olmo-core", beaker=entrypoint_dataset.id)
.with_constraint(cluster=self.clusters)
.with_env_var("REPO_URL", f"https://github.com/{github_account}/{github_repo}")
.with_env_var("GIT_REF", git_ref)
)

for name, val in self._get_env_vars():
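
The launch change above stops interpolating the repository URL and git ref directly into the shell command and instead passes them through REPO_URL and GIT_REF environment variables attached to the task. A hedged sketch of what the entrypoint side of that pattern can look like (this helper is illustrative, not code from the repo):

import os
import subprocess

def clone_and_checkout(workdir: str) -> None:
    # The launcher sets REPO_URL and GIT_REF on the task; reading them here
    # avoids quoting/escaping issues with values baked into a command string.
    repo_url = os.environ["REPO_URL"]
    git_ref = os.environ["GIT_REF"]
    subprocess.run(["git", "clone", repo_url, "."], cwd=workdir, check=True)
    subprocess.run(["git", "checkout", git_ref], cwd=workdir, check=True)
    subprocess.run(["git", "submodule", "update", "--init", "--recursive"], cwd=workdir, check=True)
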
28 changes: 15 additions & 13 deletions src/olmo_core/nn/transformer/model.py
@@ -1,6 +1,6 @@
import logging
from dataclasses import dataclass
-from typing import Literal, Optional, Sequence, Union
+from typing import List, Literal, Optional, Sequence, Union, cast

import torch
import torch.nn as nn
@@ -14,10 +14,10 @@
has_flash_attn,
)

-from ..attention import Attention, AttentionConfig, AttentionType
+from ..attention import AttentionConfig, AttentionType
from ..buffer_cache import BufferCache
from ..feed_forward import FeedForwardConfig
-from ..layer_norm import LayerNormConfig, LayerNormType
+from ..layer_norm import LayerNorm, LayerNormConfig, LayerNormType
from ..rope import RoPEConfig, RoPEType
from .block import TransformerBlock, TransformerBlockConfig, TransformerBlockType
from .init import InitMethod
@@ -532,21 +532,23 @@ def init_weights(
self.init_method.init_embeddings(self.embeddings)

for block in self.blocks:
-assert isinstance(block, TransformerBlock)
+# This might fail if it's wrapped.
+# assert isinstance(block, TransformerBlock)
+block = cast(TransformerBlock, block)
+att = block.attention

# Norms.
-block_norms = [block.attention_norm, block.feed_forward_norm]
-if isinstance(block.attention, Attention):
-    if block.attention.q_norm is not None:
-        block_norms.append(block.attention.q_norm)
-    if block.attention.k_norm is not None:
-        block_norms.append(block.attention.k_norm)
+block_norms: List[LayerNorm] = [block.attention_norm, block.feed_forward_norm]
+if hasattr(att, "q_norm") and att.q_norm is not None:
+    block_norms.append(att.q_norm)
+if hasattr(att, "k_norm") and att.k_norm is not None:
+    block_norms.append(att.k_norm)
for norm in block_norms:
norm.reset_parameters()

# Attention weights.
self.init_method.init_attention(
-block.attention, block_idx=block.block_idx, num_blocks=len(self.blocks)
+att, block_idx=block.block_idx, num_blocks=len(self.blocks)
)

# Feed-forward weights.
Expand All @@ -555,8 +557,8 @@ def init_weights(
)

# Warm up RoPE cache.
-if max_seq_len is not None and block.attention.rope is not None:
-    block.attention.rope.warmup_cache(max_seq_len, device)
+if max_seq_len is not None and att.rope is not None:
+    att.rope.warmup_cache(max_seq_len, device)

if self.norm is not None:
self.norm.reset_parameters()
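
The switch from an isinstance assertion to cast plus hasattr checks matters once blocks may be wrapped (e.g. by FSDP or activation checkpointing): the wrapper is no longer a TransformerBlock instance even though it still forwards attribute access. A minimal illustration with a generic delegating wrapper (not OLMo-core code):

class Block:
    def __init__(self):
        self.q_norm = None  # stand-in for an optional norm module

class Wrapped:
    """Generic wrapper that delegates attribute access to the wrapped module."""
    def __init__(self, inner):
        self._inner = inner

    def __getattr__(self, name):
        return getattr(self._inner, name)

block = Wrapped(Block())
print(isinstance(block, Block))   # False -- the wrapper breaks isinstance checks
print(hasattr(block, "q_norm"))   # True  -- attribute access still reaches the inner block
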
3 changes: 2 additions & 1 deletion src/olmo_core/train/__init__.py
@@ -49,13 +49,14 @@
from ..io import add_cached_path_clients
from ..utils import LogFilterType, prepare_cli_environment, seed_all
from .config import TrainerConfig
-from .trainer import Trainer
+from .trainer import LoadStrategy, Trainer

__all__ = [
"prepare_training_environment",
"teardown_training_environment",
"TrainerConfig",
"Trainer",
"LoadStrategy",
]


18 changes: 13 additions & 5 deletions src/olmo_core/train/callbacks/checkpointer.py
@@ -29,12 +29,14 @@ class CheckpointerCallback(Callback):

save_interval: int = 250
ephemeral_save_interval: Optional[int] = None
-pre_train_checkpoint: bool = True
+pre_train_checkpoint: Optional[bool] = None
save_async: bool = False

# Bookkeeping

-_future: Optional[Future] = None
+# NOTE: can't use type annotation here, omegaconf doesn't like it
+# _future: Optional[Future] = None
+_future = None
_latest_checkpoint: int = -1
_checkpoints: List[str] = field(default_factory=list)
_ephemeral_checkpoints: List[str] = field(default_factory=list)
@@ -52,16 +54,18 @@ def _await_last_checkpoint(self, blocking: bool = True) -> Optional[Future]:
def _save_checkpoint(self) -> str:
self._await_last_checkpoint()
self._latest_checkpoint = self.step
path = f"{self.trainer.save_folder}/step{self.step}"
log.info(f"Saving checkpoint for step {self.step} to {path}...")
dirname = self.trainer.checkpointer.checkpoint_dirname(self.step)
path = f"{self.trainer.save_folder}/{dirname}"
if self.save_async:
log.info(f"Saving checkpoint for step {self.step} to '{path}' asynchronously...")
self._future = self.trainer.checkpointer.save_async(
path,
self.trainer.model,
self.trainer.optim,
self.trainer.state_dict(),
)
else:
log.info(f"Saving checkpoint for step {self.step} to '{path}'...")
self.trainer.checkpointer.save(
path,
self.trainer.model,
@@ -81,7 +85,11 @@ def pre_train(self):
)
self.trainer.checkpointer.process_group = dist.new_group()

-if self.step == 0 and self.pre_train_checkpoint:
+if (
+    self.step == 0
+    and self.pre_train_checkpoint is not False
+    and not self.trainer.checkpoint_loaded
+):
self._checkpoints.append(self._save_checkpoint())

def post_train_batch(self):
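
Taken together, the callback changes make the step-0 checkpoint opt-out rather than unconditional: the default (None) still saves on a fresh run, while an explicit False or a run that resumed from an existing checkpoint skips it. A standalone sketch of that decision, mirroring the condition above (not the actual callback code):

def should_save_pretrain_checkpoint(step: int, pre_train_checkpoint, checkpoint_loaded: bool) -> bool:
    # Mirrors the condition in CheckpointerCallback.pre_train() above.
    return step == 0 and pre_train_checkpoint is not False and not checkpoint_loaded

assert should_save_pretrain_checkpoint(0, None, False)        # fresh run: save
assert not should_save_pretrain_checkpoint(0, None, True)     # resumed from a checkpoint: skip
assert not should_save_pretrain_checkpoint(0, False, False)   # explicitly disabled: skip
assert not should_save_pretrain_checkpoint(100, True, False)  # not at step 0: skip
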