Minor feature additions (#33)
* ensure we stop right at the `stop_at` spec

* Add `CheckpointerCallback.remove` field

* Add "reordered norm" transformer block implementation

* Add `WandBCallback.notes` field
epwalsh authored Aug 30, 2024
1 parent 535722c commit c009f89
Showing 7 changed files with 100 additions and 17 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -11,6 +11,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Add `Trainer.hard_stop` field.
- The trainer now catches `SIGTERM` and marks the run as canceled.
- Added `CheckpointerCallback.remove` strategy for configuring which old checkpoints found in the save folder are removed.
- Added `ReorderedNormTransformerBlock` implementation.
- Added `WandBCallback.notes` field.

### Fixed

8 changes: 7 additions & 1 deletion src/olmo_core/nn/transformer/__init__.py
@@ -2,7 +2,12 @@
Transformer building blocks.
"""

from .block import TransformerBlock, TransformerBlockConfig, TransformerBlockType
from .block import (
ReorderedNormTransformerBlock,
TransformerBlock,
TransformerBlockConfig,
TransformerBlockType,
)
from .init import InitMethod
from .model import (
Transformer,
@@ -16,6 +21,7 @@
"TransformerBlockType",
"TransformerBlockConfig",
"TransformerBlock",
"ReorderedNormTransformerBlock",
"TransformerActivationCheckpointingConfig",
"InitMethod",
]
33 changes: 30 additions & 3 deletions src/olmo_core/nn/transformer/block.py
@@ -18,6 +18,14 @@ class TransformerBlockType(StrEnum):
"""

default = "default"
"""
:class:`TransformerBlock`
"""

reordered_norm = "reordered_norm"
"""
:class:`ReorderedNormTransformerBlock`
"""


@dataclass
@@ -52,9 +60,9 @@ def build(
)

if self.name == TransformerBlockType.default:
return TransformerBlock(
**kwargs,
)
return TransformerBlock(**kwargs)
elif self.name == TransformerBlockType.reordered_norm:
return ReorderedNormTransformerBlock(**kwargs)
else:
raise NotImplementedError(self.name)

@@ -114,3 +122,22 @@ def forward(
self.attention(self.attention_norm(x), max_doc_len=max_doc_len, cu_doc_lens=cu_doc_lens)
)
return h + self.dropout(self.feed_forward(self.feed_forward_norm(h)))


class ReorderedNormTransformerBlock(TransformerBlock):
"""
Like :class:`TransformerBlock` except that the attention norm is applied on the output
of attention instead of the input, and likewise the feed-forward norm is applied on the output
of the feed-forward instead of the input.
"""

def forward(
self,
x: torch.Tensor,
max_doc_len: Optional[int] = None,
cu_doc_lens: Optional[torch.Tensor] = None,
) -> torch.Tensor:
h = x + self.dropout(
self.attention_norm(self.attention(x, max_doc_len=max_doc_len, cu_doc_lens=cu_doc_lens))
)
return h + self.dropout(self.feed_forward_norm(self.feed_forward(h)))
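For context on what the "reordered norm" block changes: the default block normalizes the input of each sublayer before the residual add (pre-norm), while the new block normalizes the sublayer's output instead. The standalone PyTorch sketch below illustrates just the two residual patterns with a generic sublayer and nn.LayerNorm; it is not olmo_core code, and the real blocks use their own attention, feed-forward, norm, and dropout modules plus the document-length arguments shown above.

import torch
import torch.nn as nn


class PreNormResidual(nn.Module):
    """Residual wiring like the default TransformerBlock: normalize the sublayer *input*."""

    def __init__(self, d_model: int, sublayer: nn.Module, dropout: float = 0.0):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.sublayer = sublayer
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x + dropout(sublayer(norm(x)))
        return x + self.dropout(self.sublayer(self.norm(x)))


class ReorderedNormResidual(nn.Module):
    """Residual wiring like ReorderedNormTransformerBlock: normalize the sublayer *output*."""

    def __init__(self, d_model: int, sublayer: nn.Module, dropout: float = 0.0):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.sublayer = sublayer
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x + dropout(norm(sublayer(x)))
        return x + self.dropout(self.norm(self.sublayer(x)))


if __name__ == "__main__":
    x = torch.randn(2, 8, 16)
    ffn = nn.Sequential(nn.Linear(16, 64), nn.GELU(), nn.Linear(64, 16))
    print(PreNormResidual(16, ffn)(x).shape)        # torch.Size([2, 8, 16])
    print(ReorderedNormResidual(16, ffn)(x).shape)  # torch.Size([2, 8, 16])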
3 changes: 2 additions & 1 deletion src/olmo_core/train/callbacks/__init__.py
@@ -1,5 +1,5 @@
from .callback import Callback
from .checkpointer import CheckpointerCallback
from .checkpointer import CheckpointerCallback, CheckpointRemovalStrategy
from .console_logger import ConsoleLoggerCallback
from .garbage_collector import GarbageCollectorCallback
from .gpu_memory_monitor import GPUMemoryMonitorCallback
@@ -11,6 +11,7 @@
__all__ = [
"Callback",
"CheckpointerCallback",
"CheckpointRemovalStrategy",
"ConsoleLoggerCallback",
"GarbageCollectorCallback",
"GPUMemoryMonitorCallback",
49 changes: 43 additions & 6 deletions src/olmo_core/train/callbacks/checkpointer.py
@@ -5,6 +5,7 @@

import torch.distributed as dist

from olmo_core.config import StrEnum
from olmo_core.distributed.utils import (
backend_supports_cpu,
get_fs_local_rank,
@@ -21,10 +22,34 @@
log = logging.getLogger(__name__)


class CheckpointRemovalStrategy(StrEnum):
"""
An enumeration of the different strategies for removing old checkpoints found in the save folder.
"""

ephemeral_only = "ephemeral_only"
"""
Only remove checkpoints that were saved at the :data:`CheckpointerCallback.ephemeral_save_interval`.
"""

all_non_permanent = "all_non_permanent"
"""
Remove all non-permanent checkpoints found, including ephemeral checkpoints and also
any other checkpoints that were not saved at the :data:`CheckpointerCallback.save_interval`.
"""

never = "never"
"""
Never remove any old checkpoints found in the save folder.
"""


@dataclass
class CheckpointerCallback(Callback):
"""
Used to configure checkpointing during training.
Manages checkpointing during training, including writing checkpoints at set intervals
determined by :data:`save_interval` and :data:`ephemeral_save_interval`, as well as removing
old checkpoints found in the save folder as determined by the :data:`remove` setting.
.. important::
This callback gets added automatically if you don't explicitly configure it.
@@ -38,8 +63,10 @@ class CheckpointerCallback(Callback):

ephemeral_save_interval: Optional[int] = None
"""
The interval, in steps, with which to save temporary checkpoints. It's useful to set this to
a frequent interval for preemptible jobs.
The interval, in steps, with which to save temporary checkpoints. These checkpoints are removed
each time a new checkpoint is saved.
It can be useful to set this to a relatively frequent interval for preemptible jobs.
"""

pre_train_checkpoint: Optional[bool] = None
@@ -52,6 +79,11 @@
Save checkpoints asynchronously. Requires a backend that supports CPU.
"""

remove: CheckpointRemovalStrategy = CheckpointRemovalStrategy.ephemeral_only
"""
The strategy for removing old checkpoints found in the save folder.
"""

# Bookkeeping

# NOTE: can't use type annotation here, omegaconf doesn't like it
@@ -138,16 +170,21 @@ def pre_train(self):
self._checkpoints.append(self._save_checkpoint())

# Collect existing ephemeral checkpoints from previous runs.
if self.ephemeral_save_interval is not None:
if self.remove != CheckpointRemovalStrategy.never:
ephemeral_checkpoints: List[Tuple[int, str]] = []

# Only search from rank 0 to avoid hammering remote file stores with requests.
if get_rank() == 0:
for step_num, path in self.checkpointer.find_checkpoints(self.save_folder):
if step_num == 0 and step_num % self.save_interval == 0:
if step_num == 0 or step_num % self.save_interval == 0:
continue
elif step_num % self.ephemeral_save_interval == 0:
elif (
self.remove == CheckpointRemovalStrategy.ephemeral_only
and self.ephemeral_save_interval is not None
and step_num % self.ephemeral_save_interval == 0
) or (self.remove == CheckpointRemovalStrategy.all_non_permanent):
ephemeral_checkpoints.append((step_num, path))

ephemeral_checkpoints = scatter_object(ephemeral_checkpoints)

# TODO: handle this if we ever restore callback state.
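A minimal usage sketch for the new removal behavior. The field names (save_interval, ephemeral_save_interval, remove) come from the diff above; the interval values are illustrative, and any other fields of CheckpointerCallback are assumed to have workable defaults.

from olmo_core.train.callbacks import CheckpointerCallback, CheckpointRemovalStrategy

# Default strategy: only clean up ephemeral checkpoints
# (those saved at `ephemeral_save_interval`).
checkpointer = CheckpointerCallback(
    save_interval=1000,            # permanent checkpoints every 1000 steps
    ephemeral_save_interval=100,   # temporary checkpoints every 100 steps
    remove=CheckpointRemovalStrategy.ephemeral_only,
)

# Remove every non-permanent checkpoint found in the save folder, i.e. anything
# not saved at a multiple of `save_interval` (step 0 is always kept).
aggressive = CheckpointerCallback(
    save_interval=1000,
    remove=CheckpointRemovalStrategy.all_non_permanent,
)

# Never remove old checkpoints.
conservative = CheckpointerCallback(
    save_interval=1000,
    remove=CheckpointRemovalStrategy.never,
)

Note that with ephemeral_only, the pre_train filter above only collects checkpoints for removal when ephemeral_save_interval is actually set.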
6 changes: 6 additions & 0 deletions src/olmo_core/train/callbacks/wandb.py
@@ -58,6 +58,11 @@ class WandBCallback(Callback):
Tags to assign the run.
"""

notes: Optional[str] = None
"""
A note/description of the run.
"""

config: Optional[Dict[str, Any]] = None
"""
The config to load to W&B.
@@ -111,6 +116,7 @@ def pre_train(self):
group=self.group,
name=self.name,
tags=self.tags,
notes=self.notes,
config=self.config,
)
self._run_path = self.run.path # type: ignore
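A short sketch of the new notes field in use, assuming WandBCallback is exported from olmo_core.train.callbacks like the other callbacks; the name, tag, and note strings are made up for illustration.

from olmo_core.train.callbacks import WandBCallback

wandb_callback = WandBCallback(
    name="reordered-norm-ablation",  # illustrative run name
    tags=["ablation"],               # illustrative tags
    notes="Comparing ReorderedNormTransformerBlock against the default block.",
)

Per the pre_train diff above, the notes value is simply forwarded to the W&B run initialization alongside group, name, tags, and config.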
15 changes: 9 additions & 6 deletions src/olmo_core/train/trainer.py
@@ -497,9 +497,6 @@ def fit(self):
self._cancel_reason = None
self._canceling_rank = None

# Install SIGTERM handler.
signal.signal(signal.SIGTERM, self._handle_sigterm)

# Maybe load a checkpoint.
if not self.checkpoint_loaded:
load_path = self.load_path if self.load_path is not None else self.save_folder
@@ -517,13 +514,19 @@

barrier()

# Install SIGTERM handler.
og_handler = signal.signal(signal.SIGTERM, self._handle_sigterm)

try:
while not self.training_complete:
self._fit_epoch()
except BaseException as exc:
for callback in self.callbacks.values():
callback.on_error(exc)
raise
finally:
# Restore original SIGTERM handler.
signal.signal(signal.SIGTERM, og_handler)

for callback in self.callbacks.values():
callback.post_train()
@@ -676,11 +679,11 @@ def write_file(self, name: str, contents: Union[str, bytes]) -> PathOrStr:

def _duration_due(self, duration: Duration) -> bool:
if duration.unit == DurationUnit.steps:
return self.global_step > duration.value
elif duration.unit == DurationUnit.epochs:
return self.epoch > duration.value
return self.global_step >= duration.value
elif duration.unit == DurationUnit.tokens:
return self.global_train_tokens_seen >= duration.value
elif duration.unit == DurationUnit.epochs:
return self.epoch > duration.value
else:
raise NotImplementedError

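The _duration_due change lines up with the first commit bullet ("ensure we stop right at the `stop_at` spec"): step-based durations are now due when the step counter reaches the target instead of one step after, tokens already used a >= comparison, and epochs keep the strict comparison. A tiny standalone sketch of the updated comparisons, using simplified stand-ins rather than the real olmo_core Duration/DurationUnit classes:

from dataclasses import dataclass
from enum import Enum


class DurationUnit(Enum):
    steps = "steps"
    tokens = "tokens"
    epochs = "epochs"


@dataclass
class Duration:
    value: int
    unit: DurationUnit


def duration_due(global_step: int, tokens_seen: int, epoch: int, duration: Duration) -> bool:
    # Mirrors the updated comparisons: steps and tokens are due *at* the target value,
    # epochs once the epoch counter exceeds the target.
    if duration.unit == DurationUnit.steps:
        return global_step >= duration.value
    elif duration.unit == DurationUnit.tokens:
        return tokens_seen >= duration.value
    elif duration.unit == DurationUnit.epochs:
        return epoch > duration.value
    else:
        raise NotImplementedError


stop_at = Duration(100, DurationUnit.steps)
print(duration_due(99, 0, 1, stop_at))   # False
print(duration_due(100, 0, 1, stop_at))  # True -> the trainer stops right at step 100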
