diff --git a/.github/workflows/pr_checks.yml b/.github/workflows/pr_checks.yml new file mode 100644 index 00000000..865ff0cc --- /dev/null +++ b/.github/workflows/pr_checks.yml @@ -0,0 +1,29 @@ +name: PR Checks + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +on: + pull_request: + branches: + - main + paths: + - 'src/**' + +jobs: + changelog: + name: CHANGELOG + runs-on: ubuntu-latest + if: github.event_name == 'pull_request' + + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 0 + + - name: Check that CHANGELOG has been updated + run: | + # If this step fails, this means you haven't updated the CHANGELOG.md + # file with notes on your contribution. + git diff --name-only $(git merge-base origin/main HEAD) | grep '^CHANGELOG.md$' && echo "Thanks for helping keep our CHANGELOG up-to-date!" diff --git a/CHANGELOG.md b/CHANGELOG.md index bcb4e689..624478f8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `block_idx` attribute to the `TransformerBlock` class. - Added `init_method` option to `Transformer` for controlling how the weights are initialized. +### Fixed + +- Fixed `list_directory` for remote folders. + ## [v1.0.1](https://github.com/allenai/OLMo-core/releases/tag/v1.0.1) - 2024-08-26 ### Fixed diff --git a/src/olmo_core/data/memmap_dataset.py b/src/olmo_core/data/memmap_dataset.py index 584cd0e4..7941c8ba 100644 --- a/src/olmo_core/data/memmap_dataset.py +++ b/src/olmo_core/data/memmap_dataset.py @@ -46,6 +46,7 @@ class MemMapDatasetConfig(Config): tokenizer: TokenizerConfig paths: Optional[List[str]] = None mix: Optional[DataMix] = None + mix_base_dir: Optional[str] = None memmap_dtype: Optional[MemMapDType] = None metadata: Optional[List[Dict[str, Any]]] = None include_instance_metadata: bool = True @@ -105,12 +106,9 @@ def get_memmap_dtype( raise ValueError("vocab size too big!") - def build(self, mix_base_dir: Optional[str] = None) -> MemMapDataset: + def build(self) -> MemMapDataset: """ Construct the corresponding :class:`MemMapDataset`. - - :param mix_base_dir: The base directory for the :data:`mix`, e.g. "s3://ai2-llm". - Required if initializing from a data mix. 
""" if (self.paths is None) == (self.mix is None): raise OLMoConfigurationError("Exactly one of 'paths' or 'mix' is required") @@ -131,7 +129,7 @@ def build(self, mix_base_dir: Optional[str] = None) -> MemMapDataset: paths = self.paths else: assert self.mix is not None - if mix_base_dir is None: + if self.mix_base_dir is None: raise OLMoConfigurationError( "'mix_base_dir' is required to build a dataset from a mix" ) @@ -139,7 +137,7 @@ def build(self, mix_base_dir: Optional[str] = None) -> MemMapDataset: raise OLMoConfigurationError( "Missing tokenizer identifier required to construct data mix" ) - paths = self.mix.build(mix_base_dir, self.tokenizer.identifier) + paths = self.mix.build(self.mix_base_dir, self.tokenizer.identifier) dataset = MemMapDataset( *paths, diff --git a/src/olmo_core/io.py b/src/olmo_core/io.py index f21aa07f..a4736740 100644 --- a/src/olmo_core/io.py +++ b/src/olmo_core/io.py @@ -613,11 +613,15 @@ def _s3_clear_directory(scheme: str, bucket_name: str, prefix: str, max_attempts def _s3_list_directory(scheme: str, bucket_name: str, prefix: str) -> Generator[str, None, None]: - response = _get_s3_client(scheme).list_objects(Bucket=bucket_name, Prefix=prefix, Delimiter="/") - assert not response["IsTruncated"] # need to handle this if it happens - for item in response.get("CommonPrefixes", []): - prefix = item["Prefix"].strip("/") - yield f"{scheme}://{bucket_name}/{prefix}" + client = _get_s3_client(scheme) + paginator = client.get_paginator("list_objects_v2") + if not prefix.endswith("/"): + prefix = prefix + "/" + for page in paginator.paginate(Bucket=bucket_name, Prefix=prefix, MaxKeys=50, Delimiter="/"): + for file_item in page.get("Contents", []): + yield f"{scheme}://{bucket_name}/{file_item['Key']}" + for dir_item in page.get("CommonPrefixes", []): + yield f"{scheme}://{bucket_name}/{dir_item['Prefix'].strip('/')}" ############################################# diff --git a/src/olmo_core/launch/beaker.py b/src/olmo_core/launch/beaker.py index 2370af61..7138acca 100644 --- a/src/olmo_core/launch/beaker.py +++ b/src/olmo_core/launch/beaker.py @@ -212,6 +212,7 @@ def default_env_vars(self) -> List[Tuple[str, str]]: ("S3_PROFILE", "S3"), ("WEKA_PROFILE", "WEKA"), ("NUM_NODES", str(self.num_nodes)), + ("OLMO_CORE_VERSION", VERSION), ] if self.shared_filesystem: env_vars.append((OLMO_SHARED_FS_ENV_VAR, "1")) @@ -306,8 +307,8 @@ def build_experiment_spec(self) -> ExperimentSpec: "set -exuo pipefail", "mkdir -p /olmo-core-runtime", "cd /olmo-core-runtime", - f"git clone https://github.com/{github_account}/{github_repo} .", - f"git checkout {git_ref}", + 'git clone "${REPO_URL}" .', + 'git checkout "${GIT_REF}"', "git submodule update --init --recursive", *self.setup_steps, " ".join(self._get_torchrun_cmd()) + " $@", @@ -333,6 +334,8 @@ def build_experiment_spec(self) -> ExperimentSpec: ) .with_dataset("/olmo-core", beaker=entrypoint_dataset.id) .with_constraint(cluster=self.clusters) + .with_env_var("REPO_URL", f"https://github.com/{github_account}/{github_repo}") + .with_env_var("GIT_REF", git_ref) ) for name, val in self._get_env_vars(): diff --git a/src/olmo_core/nn/transformer/model.py b/src/olmo_core/nn/transformer/model.py index 891d9a81..185c967f 100644 --- a/src/olmo_core/nn/transformer/model.py +++ b/src/olmo_core/nn/transformer/model.py @@ -1,6 +1,6 @@ import logging from dataclasses import dataclass -from typing import Literal, Optional, Sequence, Union +from typing import List, Literal, Optional, Sequence, Union, cast import torch import torch.nn as nn @@ 
-14,10 +14,10 @@ has_flash_attn, ) -from ..attention import Attention, AttentionConfig, AttentionType +from ..attention import AttentionConfig, AttentionType from ..buffer_cache import BufferCache from ..feed_forward import FeedForwardConfig -from ..layer_norm import LayerNormConfig, LayerNormType +from ..layer_norm import LayerNorm, LayerNormConfig, LayerNormType from ..rope import RoPEConfig, RoPEType from .block import TransformerBlock, TransformerBlockConfig, TransformerBlockType from .init import InitMethod @@ -532,21 +532,23 @@ def init_weights( self.init_method.init_embeddings(self.embeddings) for block in self.blocks: - assert isinstance(block, TransformerBlock) + # This might fail if it's wrapped. + # assert isinstance(block, TransformerBlock) + block = cast(TransformerBlock, block) + att = block.attention # Norms. - block_norms = [block.attention_norm, block.feed_forward_norm] - if isinstance(block.attention, Attention): - if block.attention.q_norm is not None: - block_norms.append(block.attention.q_norm) - if block.attention.k_norm is not None: - block_norms.append(block.attention.k_norm) + block_norms: List[LayerNorm] = [block.attention_norm, block.feed_forward_norm] + if hasattr(att, "q_norm") and att.q_norm is not None: + block_norms.append(att.q_norm) + if hasattr(att, "k_norm") and att.k_norm is not None: + block_norms.append(att.k_norm) for norm in block_norms: norm.reset_parameters() # Attention weights. self.init_method.init_attention( - block.attention, block_idx=block.block_idx, num_blocks=len(self.blocks) + att, block_idx=block.block_idx, num_blocks=len(self.blocks) ) # Feed-forward weights. @@ -555,8 +557,8 @@ def init_weights( ) # Warm up RoPE cache. - if max_seq_len is not None and block.attention.rope is not None: - block.attention.rope.warmup_cache(max_seq_len, device) + if max_seq_len is not None and att.rope is not None: + att.rope.warmup_cache(max_seq_len, device) if self.norm is not None: self.norm.reset_parameters() diff --git a/src/olmo_core/train/__init__.py b/src/olmo_core/train/__init__.py index 514df23f..3f5ba12b 100644 --- a/src/olmo_core/train/__init__.py +++ b/src/olmo_core/train/__init__.py @@ -49,13 +49,14 @@ from ..io import add_cached_path_clients from ..utils import LogFilterType, prepare_cli_environment, seed_all from .config import TrainerConfig -from .trainer import Trainer +from .trainer import LoadStrategy, Trainer __all__ = [ "prepare_training_environment", "teardown_training_environment", "TrainerConfig", "Trainer", + "LoadStrategy", ] diff --git a/src/olmo_core/train/callbacks/checkpointer.py b/src/olmo_core/train/callbacks/checkpointer.py index 4c776012..436eb1f0 100644 --- a/src/olmo_core/train/callbacks/checkpointer.py +++ b/src/olmo_core/train/callbacks/checkpointer.py @@ -29,12 +29,14 @@ class CheckpointerCallback(Callback): save_interval: int = 250 ephemeral_save_interval: Optional[int] = None - pre_train_checkpoint: bool = True + pre_train_checkpoint: Optional[bool] = None save_async: bool = False # Bookkeeping - _future: Optional[Future] = None + # NOTE: can't use type annotation here, omegaconf doesn't like it + # _future: Optional[Future] = None + _future = None _latest_checkpoint: int = -1 _checkpoints: List[str] = field(default_factory=list) _ephemeral_checkpoints: List[str] = field(default_factory=list) @@ -52,9 +54,10 @@ def _await_last_checkpoint(self, blocking: bool = True) -> Optional[Future]: def _save_checkpoint(self) -> str: self._await_last_checkpoint() self._latest_checkpoint = self.step - path = 
f"{self.trainer.save_folder}/step{self.step}" - log.info(f"Saving checkpoint for step {self.step} to {path}...") + dirname = self.trainer.checkpointer.checkpoint_dirname(self.step) + path = f"{self.trainer.save_folder}/{dirname}" if self.save_async: + log.info(f"Saving checkpoint for step {self.step} to '{path}' asynchronously...") self._future = self.trainer.checkpointer.save_async( path, self.trainer.model, @@ -62,6 +65,7 @@ def _save_checkpoint(self) -> str: self.trainer.state_dict(), ) else: + log.info(f"Saving checkpoint for step {self.step} to '{path}'...") self.trainer.checkpointer.save( path, self.trainer.model, @@ -81,7 +85,11 @@ def pre_train(self): ) self.trainer.checkpointer.process_group = dist.new_group() - if self.step == 0 and self.pre_train_checkpoint: + if ( + self.step == 0 + and self.pre_train_checkpoint is not False + and not self.trainer.checkpoint_loaded + ): self._checkpoints.append(self._save_checkpoint()) def post_train_batch(self): diff --git a/src/olmo_core/train/checkpoint.py b/src/olmo_core/train/checkpoint.py index 3da9d2ee..b7d80f50 100644 --- a/src/olmo_core/train/checkpoint.py +++ b/src/olmo_core/train/checkpoint.py @@ -1,10 +1,12 @@ +import json import os +import re import tempfile from concurrent.futures import Future from contextlib import contextmanager from dataclasses import dataclass from pathlib import Path -from typing import Any, Dict, Generator, Optional, Union +from typing import Any, ClassVar, Dict, Generator, Optional, Tuple, Union import torch import torch.distributed as dist @@ -13,6 +15,7 @@ from torch.optim import Optimizer from ..aliases import PathOrStr +from ..config import Config from ..distributed.checkpoint import ( async_save_model_and_optim_state, load_model_and_optim_state, @@ -30,6 +33,12 @@ upload, ) from ..utils import wait_for +from ..version import VERSION + + +@dataclass +class CheckpointMetadata(Config): + version: str = VERSION @dataclass @@ -38,6 +47,9 @@ class Checkpointer: Trainer checkpointer. """ + METADATA_FNAME: ClassVar[str] = ".metadata.json" + CHECKPOINT_DIR: ClassVar[str] = "step{step}" + save_overwrite: bool = False process_group: Optional[dist.ProcessGroup] = None @@ -48,7 +60,7 @@ def save(self, dir: PathOrStr, model: nn.Module, optim: Optimizer, train_state: dir = normalize_path(dir) with self._temporary_wd(dir) as wd: # Save trainer state. - self._save_train_state(wd, train_state) + self._save_train_state(dir, wd, train_state) # Save model and optim state. model_and_optim_dir = ( @@ -62,6 +74,8 @@ def save(self, dir: PathOrStr, model: nn.Module, optim: Optimizer, train_state: save_overwrite=self.save_overwrite, ) + self._save_metadata(dir, CheckpointMetadata()) + def save_async( self, dir: PathOrStr, model: nn.Module, optim: Optimizer, train_state: Dict[str, Any] ) -> Future[None]: @@ -77,11 +91,11 @@ def save_async( with self._temporary_wd(dir) as wd: # Save trainer state. - self._save_train_state(wd, train_state) + self._save_train_state(dir, wd, train_state) # Save model and optim state. model_and_optim_dir = f"{dir}/model_and_optim" - return async_save_model_and_optim_state( + future = async_save_model_and_optim_state( model_and_optim_dir, model, optim, @@ -89,6 +103,15 @@ def save_async( save_overwrite=self.save_overwrite, ) + def done_callback(fut: Future): + fut.result() + self._save_metadata(dir, CheckpointMetadata()) + + # Upload metadata when everything else is done. 
+ future.add_done_callback(done_callback) + + return future + def load( self, dir: PathOrStr, @@ -108,11 +131,13 @@ def load( trainer_state: Optional[Dict[str, Any]] = None if load_trainer_state: try: - trainer_state = torch.load(cached_path(f"{dir}/train/rank{get_rank()}.pt")) + trainer_state = torch.load( + cached_path(f"{dir}/train/rank{get_rank()}.pt", quiet=True) + ) except FileNotFoundError: # Fall back to rank 0 train state. # This can happen when we're restoring a checkpoint with a different world size. - trainer_state = torch.load(cached_path(f"{dir}/train/rank0.pt")) + trainer_state = torch.load(cached_path(f"{dir}/train/rank0.pt", quiet=True)) # Load model and optimizer state. load_model_and_optim_state( @@ -140,7 +165,9 @@ def write_file(self, dir: PathOrStr, fname: str, contents: Union[str, bytes]) -> Path(dir).mkdir(exist_ok=True, parents=True) mode = "wb" if isinstance(contents, bytes) else "wt" - tmp_file = tempfile.NamedTemporaryFile(mode=mode, delete=False, dir=dir) + tmp_file = tempfile.NamedTemporaryFile( + mode=mode, delete=False, dir=None if is_url(dir) else dir + ) tmp_path = Path(tmp_file.name) try: tmp_file.write(contents) @@ -161,40 +188,68 @@ def write_file(self, dir: PathOrStr, fname: str, contents: Union[str, bytes]) -> finally: tmp_path.unlink(missing_ok=True) + @classmethod + def checkpoint_dirname(cls, step: int) -> str: + return cls.CHECKPOINT_DIR.format(step=step) + @classmethod def dir_is_checkpoint(cls, dir: PathOrStr) -> bool: """ - Check if a directory contains a checkpoint. + Check if a directory is a checkpoint directory. """ dir = normalize_path(dir) - paths_to_check = [f"{dir}/train/rank0.pt", f"{dir}/model_and_optim/.metadata"] + paths_to_check = [ + f"{dir}/train/rank0.pt", + f"{dir}/model_and_optim/.metadata", + f"{dir}/{cls.METADATA_FNAME}", + ] for path in paths_to_check: if not file_exists(path): return False return True @classmethod - def latest_checkpoint(cls, dir: PathOrStr) -> str: + def find_checkpoints(cls, dir: PathOrStr) -> Generator[Tuple[int, str], None, None]: """ - Find the latest checkpoint in a directory of checkpoints. + Find checkpoints within a directory. """ dir = normalize_path(dir) - latest_step: Optional[int] = None - latest_checkpoint: Optional[str] = None for path in list_directory(dir): name = os.path.basename(path) - if not name.startswith("step"): - continue + if (m := re.match("^" + cls.CHECKPOINT_DIR.format(step=r"(\d+)$"), name)) is not None: + step = int(m.group(1)) - try: - step = int(name.replace("step", "")) - except ValueError: - continue + # Make sure the directory is a valid checkpoint dir. + if not cls.dir_is_checkpoint(path): + continue - # Make sure the directory is a valid checkpoint dir. - if not cls.dir_is_checkpoint(path): - continue + yield step, path + @classmethod + def contains_checkpoint(cls, dir: PathOrStr) -> bool: + """ + Check if a directory is a checkpoint directory or contains a child checkpoint directory. + """ + if cls.dir_is_checkpoint(dir): + return True + + try: + next(cls.find_checkpoints(dir)) + return True + except (StopIteration, FileNotFoundError): + return False + + @classmethod + def latest_checkpoint(cls, dir: PathOrStr) -> str: + """ + Find the latest checkpoint in a directory of checkpoints. + + :raises FileNotFoundError: If no checkpoints are found. 
+ """ + dir = normalize_path(dir) + latest_step: Optional[int] = None + latest_checkpoint: Optional[str] = None + for step, path in cls.find_checkpoints(dir): if latest_step is None or step > latest_step: latest_step = step latest_checkpoint = path @@ -204,13 +259,18 @@ def latest_checkpoint(cls, dir: PathOrStr) -> str: else: return latest_checkpoint - def _save_train_state(self, wd: Path, train_state: Dict[str, Any]): + def _save_train_state(self, dir: PathOrStr, wd: Path, train_state: Dict[str, Any]): train_dir = wd / "train" - if get_fs_local_rank() == 0: + # NOTE: if 'dir' is a URL, the 'wd' will be a different temp dir for each rank. + if is_url(dir) or get_fs_local_rank() == 0: train_dir.mkdir(exist_ok=True, parents=True) wait_for(train_dir.exists, description=f"Waiting on '{train_dir}' to be created...") torch.save(train_state, train_dir / f"rank{get_rank()}.pt") + def _save_metadata(self, dir: PathOrStr, metadata: CheckpointMetadata): + if get_rank() == 0: + self.write_file(dir, self.METADATA_FNAME, json.dumps(metadata.as_dict(json_safe=True))) + def _prepare_dir(self, dir: PathOrStr, ensure_exists: bool = True) -> str: dir = normalize_path(dir) @@ -271,19 +331,19 @@ def _teardown_tmp_dir(self, dir: PathOrStr, tmp_dir: Path): # So we wait here across all ranks until that final checkpoint directory is visible. wait_for(lambda: Path(dir).exists(), "Waiting for checkpoint directory", timeout=10.0) else: - if get_fs_local_rank() == 0: - # Upload files to final location. - for path in tmp_dir.glob("**/*"): - if not path.is_file(): - continue - upload( - path, - f"{dir}/{path.relative_to(tmp_dir)}", - save_overwrite=self.save_overwrite, - ) - - # Then remove the temp dir. - tmp_dir.unlink(missing_ok=True) + # NOTE: each rank will have its own tmp dir + # Upload files to final location. + for path in tmp_dir.glob("**/*"): + if not path.is_file(): + continue + upload( + path, + f"{dir}/{path.relative_to(tmp_dir)}", + save_overwrite=self.save_overwrite, + ) + + # Then remove the temp dir. + clear_directory(tmp_dir) barrier() diff --git a/src/olmo_core/train/config.py b/src/olmo_core/train/config.py index 26ccb959..2cf01d97 100644 --- a/src/olmo_core/train/config.py +++ b/src/olmo_core/train/config.py @@ -11,7 +11,7 @@ from ..utils import get_default_device from .callbacks import Callback from .checkpoint import Checkpointer -from .trainer import Trainer +from .trainer import LoadStrategy, Trainer from .utils import Duration, DurationUnit @@ -26,6 +26,9 @@ class TrainerConfig(Config): global_batch_size: int microbatch_size: int + load_path: Optional[str] = None + load_strategy: LoadStrategy = LoadStrategy.if_available + device: Optional[str] = None save_overwrite: bool = False max_duration: Duration = field( diff --git a/src/olmo_core/train/trainer.py b/src/olmo_core/train/trainer.py index 2c664e0b..de1c815b 100644 --- a/src/olmo_core/train/trainer.py +++ b/src/olmo_core/train/trainer.py @@ -15,6 +15,7 @@ from torch.utils.data import DataLoader from ..aliases import PathOrStr +from ..config import StrEnum from ..data import DataCollator, IterableDataset, MemMapDataset from ..distributed.utils import ( all_reduce_value, @@ -58,6 +59,27 @@ TRAIN_Z_LOSS_METRIC = "train/Z loss" +class LoadStrategy(StrEnum): + """ + Determines the strategy for loading checkpoints prior to training. + """ + + if_available = "if_available" + """ + Only load from the load path if a checkpoint exists there. + """ + + always = "always" + """ + Always try loading from the load path. 
+    """
+
+    never = "never"
+    """
+    Never load from the load path.
+    """
+
+
 @dataclass
 class Trainer:
     """
@@ -137,6 +159,17 @@ class Trainer:
     Microbatch size per rank, i.e. the number of instances to process at a time from each rank.
     """
 
+    load_path: Optional[PathOrStr] = None
+    """
+    Where to load a checkpoint from prior to training.
+    Defaults to ``save_folder``.
+    """
+
+    load_strategy: LoadStrategy = LoadStrategy.if_available
+    """
+    The strategy for loading a checkpoint prior to training.
+    """
+
     metrics_collect_interval: int = 5
     """
     How often (in steps) to collect, reduce, and pass on metrics to the
@@ -232,9 +265,12 @@ class Trainer:
     _rank_batch_size: Optional[int] = None
     _thread_pool: Optional[ThreadPoolExecutor] = None
     _bookkeeping_pg: Optional[dist.ProcessGroup] = None
+    _checkpoint_loaded: bool = False
 
     def __post_init__(self):
         self.save_folder = normalize_path(self.save_folder)
+        if self.load_path is not None:
+            self.load_path = normalize_path(self.load_path)
 
         # If save folder is a local directory, make sure we're using a shared filesystem.
         if not is_url(self.save_folder) and get_fs_local_rank() != get_rank():
@@ -414,6 +450,13 @@ def thread_pool(self) -> ThreadPoolExecutor:
             self._thread_pool = ThreadPoolExecutor(max_workers=1, thread_name_prefix="trainer")
         return self._thread_pool
 
+    @property
+    def checkpoint_loaded(self) -> bool:
+        """
+        Whether a checkpoint has been loaded.
+        """
+        return self._checkpoint_loaded
+
     def cancel_run(self, reason: str):
         """
         Mark the run canceled.
@@ -426,12 +469,21 @@ def cancel_run(self, reason: str):
 
     def fit(self):
         """
-        Fit the model.
+        Fit the model, potentially loading a checkpoint beforehand depending on the
+        :data:`load_strategy`.
         """
         self._canceled = False
         self._cancel_reason = None
         self._canceling_rank = None
 
+        # Maybe load a checkpoint.
+        if not self.checkpoint_loaded:
+            load_path = self.load_path if self.load_path is not None else self.save_folder
+            if self.load_strategy == LoadStrategy.always:
+                self.load_checkpoint(load_path)
+            elif self.load_strategy == LoadStrategy.if_available:
+                self.maybe_load_checkpoint(load_path)
+
         log.info(f"Training for {self.max_steps:,d} steps")
 
         self.model.train()
@@ -487,6 +539,8 @@ def load_state_dict(self, state_dict: Dict[str, Any]):
         ]
         self.epoch = state_dict["epoch"]
 
+        log.info(f"Will resume training from step {self.global_step}, epoch {self.epoch}")
+
         if state_dict["world_size"] == get_world_size():  # global world size here on purpose
             rng_state = EnvRngStates.from_dict(state_dict["rng"])
             if not rng_state.restore():
@@ -500,23 +554,26 @@
             )
 
     def load_checkpoint(
-        self, dir: PathOrStr, load_optimizer_state: bool = True, load_trainer_state: bool = True
+        self, dir: PathOrStr, *, load_optimizer_state: bool = True, load_trainer_state: bool = True
     ):
         """
         Load a checkpoint.
 
-        :param dir: The path/URL to the checkpoint.
+        .. note::
+            :meth:`fit()` may call this method automatically depending on the :data:`load_strategy`.
+
+        :param dir: The path/URL to a checkpoint or a folder of checkpoints.
         :param load_optimizer_state: Load optimizer state.
         :param load_trainer_state: Load trainer state.
         """
-        if not self.checkpointer.dir_is_checkpoint(dir):
+        dir = normalize_path(dir)
+
+        # NOTE: to avoid making a ton of client requests (S3 or otherwise) we only make those
+        # requests from rank 0 then scatter the result to the other ranks.
+        if get_rank() == 0 and not self.checkpointer.dir_is_checkpoint(dir):
             # Try to find the latest checkpoint in the directory.
- latest_checkpoint: Optional[str] = None - if get_rank() == 0: - latest_checkpoint = self.checkpointer.latest_checkpoint(dir) - latest_checkpoint = scatter_object(latest_checkpoint) - assert latest_checkpoint is not None - dir = latest_checkpoint + dir = self.checkpointer.latest_checkpoint(dir) + dir = scatter_object(dir) log.info(f"Loading checkpoint from '{dir}'...") trainer_state = self.checkpointer.load( @@ -529,8 +586,35 @@ def load_checkpoint( if load_trainer_state: assert trainer_state is not None self.load_state_dict(trainer_state) + + self._checkpoint_loaded = True log.info("Checkpoint successfully loaded") + def maybe_load_checkpoint( + self, dir: PathOrStr, *, load_optimizer_state: bool = True, load_trainer_state: bool = True + ) -> bool: + """ + Like :meth:`load_checkpoint()` but is a no-op if there is no checkpoint in the ``dir`` provided. + + .. note:: + :meth:`fit()` may call this method automatically depending on the :data:`load_strategy`. + + :returns: If a checkpoint was loaded. + """ + should_load: bool = True + if get_rank() == 0: + should_load = self.checkpointer.contains_checkpoint(dir) + should_load = scatter_object(should_load) + if should_load: + self.load_checkpoint( + dir, + load_optimizer_state=load_optimizer_state, + load_trainer_state=load_trainer_state, + ) + else: + log.warning(f"No checkpoint found in '{dir}', will train from scratch...") + return should_load + def record_metric( self, name: str, value: Union[float, torch.Tensor], reduce_type: Optional[ReduceType] = None ): diff --git a/src/scripts/train/OLMo-7B.py b/src/scripts/train/OLMo-7B.py index 817bd3f4..8572aa22 100644 --- a/src/scripts/train/OLMo-7B.py +++ b/src/scripts/train/OLMo-7B.py @@ -3,9 +3,10 @@ """ import json +import logging import sys from dataclasses import dataclass -from typing import List, Optional +from typing import List from beaker import Beaker @@ -13,6 +14,7 @@ from olmo_core.data import DataMix, MemMapDatasetConfig, TokenizerConfig from olmo_core.distributed.parallel import DataParallelConfig, DataParallelType from olmo_core.distributed.utils import get_num_nodes, get_rank, init_hybrid_shard_mesh +from olmo_core.io import is_url from olmo_core.launch.beaker import ( BeakerEnvSecret, BeakerLaunchConfig, @@ -36,10 +38,13 @@ ) from olmo_core.utils import generate_uuid, get_default_device, prepare_cli_environment +log = logging.getLogger(__name__) + class SubCmd(StrEnum): launch = "launch" train = "train" + dry_run = "dry_run" @dataclass @@ -50,26 +55,31 @@ class ExperimentConfig(Config): optim: AdamWConfig dataset: MemMapDatasetConfig trainer: TrainerConfig - load_path: Optional[str] = None seed: int = 3423 -def build_config(run_name: str, overrides: List[str]) -> ExperimentConfig: +def build_config(run_name: str, cluster: str, overrides: List[str]) -> ExperimentConfig: + root_dir: str = "weka://oe-training-default/ai2-llm" + weka_buckets: List[BeakerWekaBucket] = [] + if "jupiter" in cluster: + root_dir = "/weka/oe-training-default/ai2-llm" + weka_buckets.append(BeakerWekaBucket("oe-training-default", "/weka/oe-training-default")) + beaker_user = (Beaker.from_env().account.whoami().name).upper() launch_config = BeakerLaunchConfig( name=f"{run_name}-{generate_uuid()[:8]}", budget="ai2/oe-training", - cmd=["src/scripts/train/OLMo-7B.py", SubCmd.train, run_name, *overrides], + cmd=["src/scripts/train/OLMo-7B.py", SubCmd.train, run_name, cluster, *overrides], task_name="train", workspace="ai2/OLMo-core", description="Testing OLMo-core launch utilities", - 
clusters=["ai2/jupiter-cirrascale-2"], - weka_buckets=[BeakerWekaBucket("oe-training-default", "/weka/oe-training-default")], + clusters=[cluster], + weka_buckets=weka_buckets, beaker_image=OLMoCoreBeakerImage.nightly, # some features require nightly at the moment num_nodes=1, num_gpus=8, - shared_filesystem=True, + shared_filesystem=not is_url(root_dir), allow_dirty=False, env_secrets=[ BeakerEnvSecret(name="BEAKER_TOKEN", secret=f"{beaker_user}_BEAKER_TOKEN"), @@ -116,13 +126,13 @@ def build_config(run_name: str, overrides: List[str]) -> ExperimentConfig: DataMix.OLMoE_mix_0824, tokenizer=tokenizer_config, sequence_length=4096, + mix_base_dir=root_dir, ) - save_folder = f"/weka/oe-training-default/ai2-llm/checkpoints/OLMo-medium/{beaker_user.lower()}/{run_name}" - + save_folder = f"{root_dir}/checkpoints/OLMo-medium/{beaker_user.lower()}/{run_name}" trainer_config = ( TrainerConfig( - work_dir=save_folder, + work_dir=save_folder if not is_url(save_folder) else f"/tmp/{run_name}", save_folder=save_folder, global_batch_size=1024, microbatch_size=2, @@ -140,6 +150,13 @@ def build_config(run_name: str, overrides: List[str]) -> ExperimentConfig: num_flops_per_token=model_config.num_flops_per_token(dataset_config.sequence_length) ) ) + .with_callback( + CheckpointerCallback( + save_interval=10_000, + ephemeral_save_interval=250, + save_async=True, + ) + ) ) experiment_config = ExperimentConfig( @@ -151,15 +168,6 @@ def build_config(run_name: str, overrides: List[str]) -> ExperimentConfig: trainer=trainer_config, ).merge(overrides) - experiment_config.trainer.with_callback( - CheckpointerCallback( - save_interval=10_000, - ephemeral_save_interval=1024, - save_async=True, - pre_train_checkpoint=experiment_config.load_path is None, - ) - ) - return experiment_config @@ -189,39 +197,44 @@ def train(config: ExperimentConfig): dp_mesh=None if get_num_nodes() == 1 else init_hybrid_shard_mesh(), ) optim = config.optim.build(model) - dataset = config.dataset.build(mix_base_dir="/weka/oe-training-default/ai2-llm") + dataset = config.dataset.build() trainer = config.trainer.build(model, optim, dataset) - # Save config to file. + # Save the config to file. if get_rank() == 0: trainer.write_file("config.json", json.dumps(config_dict, indent=2)) - # Maybe load a checkpoint. - if config.load_path is not None: - trainer.load_checkpoint(config.load_path) - # Train. 
trainer.fit() if __name__ == "__main__": - usage = f"Usage: python {sys.argv[0]} {SubCmd.launch}|{SubCmd.train} run_name [OVERRIDES...]" + usage = ( + f"Usage: python {sys.argv[0]} {SubCmd.launch}|{SubCmd.train}|{SubCmd.dry_run} run_name cluster [OVERRIDES...]\n\n" + "Example:\n" + f"$ python {sys.argv[0]} {SubCmd.launch} OLMo-core-7B ai2/pluto-cirrascale --launch.num_nodes=2" + ) - if len(sys.argv) < 3: + if len(sys.argv) < 4: print(usage) sys.exit(1) cmd = sys.argv[1] run_name = sys.argv[2] - overrides = sys.argv[3:] + cluster = sys.argv[3] + overrides = sys.argv[4:] if sys.argv[1] == SubCmd.launch: prepare_cli_environment() - config = build_config(run_name, overrides) + config = build_config(run_name, cluster, overrides) launch(config) + elif sys.argv[1] == SubCmd.dry_run: + prepare_cli_environment() + config = build_config(run_name, cluster, overrides) + log.info(config) else: prepare_training_environment() - config = build_config(run_name, overrides) + config = build_config(run_name, cluster, overrides) try: train(config) finally: diff --git a/src/test/distributed/utils.py b/src/test/distributed/utils.py index 36a2ef25..1cc82fe8 100644 --- a/src/test/distributed/utils.py +++ b/src/test/distributed/utils.py @@ -117,7 +117,7 @@ def log_record_factory(*args, **kwargs) -> logging.LogRecord: log.info("Starting test...") - if torch.cuda.is_available(): + if "nccl" in backend: torch.cuda.set_device(int(process_rank)) try: diff --git a/src/test/io_test.py b/src/test/io_test.py index 74095537..e925d542 100644 --- a/src/test/io_test.py +++ b/src/test/io_test.py @@ -1,6 +1,40 @@ -from olmo_core.io import deserialize_from_tensor, serialize_to_tensor +from olmo_core.io import ( + deserialize_from_tensor, + list_directory, + serialize_to_tensor, + upload, +) def test_serde_from_tensor(): data = {"a": (1, 2)} assert deserialize_from_tensor(serialize_to_tensor(data)) == data + + +def test_list_local_directory(tmp_path): + (tmp_path / "file1.json").touch() + (tmp_path / "dir1").mkdir() + (tmp_path / "dir1" / "file2").touch() + + # Should only list immediate children (files and dirs), but not files in subdirs. + # The paths returned should be full paths. + assert set(list_directory(tmp_path)) == {f"{tmp_path}/file1.json", f"{tmp_path}/dir1"} + + +def test_list_remote_directory(tmp_path, s3_checkpoint_dir): + (tmp_path / "file1.json").touch() + (tmp_path / "dir1").mkdir() + (tmp_path / "dir1" / "file2").touch() + + for path in tmp_path.glob("**/*"): + if not path.is_file(): + continue + rel_path = path.relative_to(tmp_path) + upload(path, f"{s3_checkpoint_dir}/{rel_path}") + + # Should only list immediate children (files and dirs), but not files in subdirs. + # The paths returned should be full paths. 
+ assert set(list_directory(s3_checkpoint_dir)) == { + f"{s3_checkpoint_dir}/file1.json", + f"{s3_checkpoint_dir}/dir1", + } diff --git a/src/test/train/checkpoint_test.py b/src/test/train/checkpoint_test.py index 7f2f1892..1b4095ff 100644 --- a/src/test/train/checkpoint_test.py +++ b/src/test/train/checkpoint_test.py @@ -1,16 +1,21 @@ import os +import time import torch import torch.distributed as dist -from olmo_core.distributed.utils import get_rank +from olmo_core.distributed.utils import barrier, get_rank +from olmo_core.io import dir_is_empty, file_exists, is_url, normalize_path from olmo_core.train.checkpoint import Checkpointer from ..distributed.utils import run_distributed_test -def run_checkpointer_with_local_dir(dir, model_factory): - os.environ["OLMO_SHARED_FS"] = "1" +def run_checkpointer(base_dir, model_factory): + dir = f"{normalize_path(base_dir)}/{Checkpointer.checkpoint_dirname(10)}" + + if not is_url(dir): + os.environ["OLMO_SHARED_FS"] = "1" checkpointer = Checkpointer() model = model_factory() @@ -18,10 +23,14 @@ def run_checkpointer_with_local_dir(dir, model_factory): # Save checkpoint. checkpointer.save(dir, model, optim, {"rank": get_rank()}) - assert (dir / "train").is_dir() - assert (dir / "train" / "rank0.pt").is_file() - assert (dir / "train" / "rank1.pt").is_file() - assert (dir / "model_and_optim").is_dir() + barrier() + + assert file_exists((f"{dir}/train/rank0.pt")) + assert file_exists((f"{dir}/train/rank1.pt")) + assert not dir_is_empty((f"{dir}/model_and_optim")) + assert checkpointer.dir_is_checkpoint(dir) + assert list(checkpointer.find_checkpoints(base_dir)) == [(10, dir)] + assert checkpointer.latest_checkpoint(base_dir) == dir # Load checkpoint. train_state = checkpointer.load(dir, model, optim) @@ -30,11 +39,20 @@ def run_checkpointer_with_local_dir(dir, model_factory): def test_checkpointer_with_local_dir(tmp_path, tiny_model_factory): - run_distributed_test(run_checkpointer_with_local_dir, func_args=(tmp_path, tiny_model_factory)) + run_distributed_test(run_checkpointer, func_args=(tmp_path, tiny_model_factory)) + +def test_checkpointer_with_remote_dir(s3_checkpoint_dir, tiny_model_factory): + run_distributed_test( + run_checkpointer, func_args=(s3_checkpoint_dir, tiny_model_factory), start_method="spawn" + ) -def run_async_checkpointer_with_local_dir(dir, model_factory): - os.environ["OLMO_SHARED_FS"] = "1" + +def run_async_checkpointer(dir, model_factory): + dir = normalize_path(dir) + + if not is_url(dir): + os.environ["OLMO_SHARED_FS"] = "1" checkpointer = Checkpointer(process_group=dist.new_group()) model = model_factory() @@ -43,11 +61,13 @@ def run_async_checkpointer_with_local_dir(dir, model_factory): # Save checkpoint. future = checkpointer.save_async(dir, model, optim, {"rank": get_rank()}) future.result() + time.sleep(0.1) # allow done callback to run. + barrier() - assert (dir / "train").is_dir() - assert (dir / "train" / "rank0.pt").is_file() - assert (dir / "train" / "rank1.pt").is_file() - assert (dir / "model_and_optim").is_dir() + assert file_exists((f"{dir}/train/rank0.pt")) + assert file_exists((f"{dir}/train/rank1.pt")) + assert not dir_is_empty((f"{dir}/model_and_optim")) + assert checkpointer.dir_is_checkpoint(dir) # Load checkpoint. 
train_state = checkpointer.load(dir, model, optim) @@ -56,6 +76,12 @@ def run_async_checkpointer_with_local_dir(dir, model_factory): def test_async_checkpointer_with_local_dir(tmp_path, tiny_model_factory): + run_distributed_test(run_async_checkpointer, func_args=(tmp_path, tiny_model_factory)) + + +def test_async_checkpointer_with_remote_dir(s3_checkpoint_dir, tiny_model_factory): run_distributed_test( - run_async_checkpointer_with_local_dir, func_args=(tmp_path, tiny_model_factory) + run_async_checkpointer, + func_args=(s3_checkpoint_dir, tiny_model_factory), + start_method="spawn", )