Skip to content

Commit

Permalink
Add callback
Browse files Browse the repository at this point in the history
  • Loading branch information
epwalsh committed Nov 19, 2024
1 parent 801cddf commit 5471345
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 2 deletions.
2 changes: 2 additions & 0 deletions src/examples/ngpt/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
GPUMemoryMonitorCallback,
GradClipperCallback,
LMEvaluatorCallbackConfig,
MatrixNormalizerCallback,
ProfilerCallback,
SchedulerCallback,
SequenceLengthSchedulerCallback,
Expand Down Expand Up @@ -93,6 +94,7 @@ def build_config(run_name: str, overrides: List[str]) -> ExperimentConfig:
metrics_collect_interval=5,
cancel_check_interval=5,
)
.with_callback("matrix_normalizer", MatrixNormalizerCallback())
.with_callback("lr_scheduler", SchedulerCallback(scheduler=CosWithWarmup(warmup_steps=0)))
.with_callback(
"seq_len_scheduler",
Expand Down
38 changes: 37 additions & 1 deletion src/olmo_core/nn/transformer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@

from olmo_core.config import StrEnum
from olmo_core.data.utils import get_cumulative_document_lengths
from olmo_core.distributed.utils import get_local_tensor
from olmo_core.doc_utils import beta_feature
from olmo_core.utils import get_default_device

from ..buffer_cache import BufferCache
from ..layer_norm import LayerNormConfig
from ..utils import selective_checkpointing_context_fn
from ..utils import l2_normalize, selective_checkpointing_context_fn
from .block import TransformerBlock, TransformerBlockConfig
from .init import InitMethod

Expand Down Expand Up @@ -434,6 +435,14 @@ def num_flops_per_token(self, seq_len: int) -> int:


class NormalizedTransformer(Transformer):
"""
A nGPT transformer implementation.
.. warning::
When training this model you should use the :class:`~olmo_core.train.callbacks.MatrixNormalizerCallback`
to re-normalize the weight matrices after each optimizer step.
"""

def __init__(
self,
*,
Expand Down Expand Up @@ -474,6 +483,33 @@ def init_weights(
super().init_weights(max_seq_len=max_seq_len, device=device)
nn.init.ones_(self.sz)
self.sz.mul_(self.sz_init_scaling)
self.normalize_matrices()

@torch.no_grad()
def normalize_matrices(self):
"""
Normalize the weights in all matrices. This should be called after each optimizer step, which
the :class:`~olmo_core.train.callbacks.MatrixNormalizerCallback` will handle for you.
"""
if self.embeddings is not None:
self._normalize_matrix(self.embeddings.weight)

for block in self.blocks:
self._normalize_matrix(block.attention.w_q.weight)
self._normalize_matrix(block.attention.w_k.weight)
self._normalize_matrix(block.attention.w_v.weight)
self._normalize_matrix(block.attention.w_out.weight, dim=0)

self._normalize_matrix(block.feed_forward.w1.weight)
self._normalize_matrix(block.feed_forward.w2.weight, dim=0)
self._normalize_matrix(block.feed_forward.w3.weight)

if self.w_out is not None:
self._normalize_matrix(self.w_out.weight)

def _normalize_matrix(self, p: torch.Tensor, dim: int = -1):
w = get_local_tensor(p.data)
w.copy_(l2_normalize(w, dim=dim))

def forward(
self,
Expand Down
2 changes: 1 addition & 1 deletion src/olmo_core/nn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,4 @@ def selective_checkpointing_context_fn():


def l2_normalize(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
return x / x.norm(p=2, dim=dim, keepdim=True)
return x / x.float().norm(p=2, dim=dim, keepdim=True).type_as(x)
2 changes: 2 additions & 0 deletions src/olmo_core/train/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .garbage_collector import GarbageCollectorCallback
from .gpu_memory_monitor import GPUMemoryMonitorCallback
from .grad_clipper import GradClipperCallback
from .matrix_normalizer import MatrixNormalizerCallback
from .moe_handler import MoEHandlerCallback
from .profiler import ProfilerCallback
from .scheduler import SchedulerCallback
Expand All @@ -36,6 +37,7 @@
"GarbageCollectorCallback",
"GPUMemoryMonitorCallback",
"GradClipperCallback",
"MatrixNormalizerCallback",
"ProfilerCallback",
"SchedulerCallback",
"SequenceLengthSchedulerCallback",
Expand Down
3 changes: 3 additions & 0 deletions src/olmo_core/train/callbacks/float8_handler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dataclasses import dataclass, field
from typing import ClassVar

from olmo_core.float8 import Float8Config, Float8Handler

Expand All @@ -18,6 +19,8 @@ class Float8HandlerCallback(Callback):
on your model prior to training to replace the linear layers with ``Float8Linear`` layers.
"""

priority: ClassVar[int] = -1

config: Float8Config = field(default_factory=Float8Config)

_handler = None
Expand Down
14 changes: 14 additions & 0 deletions src/olmo_core/train/callbacks/matrix_normalizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from dataclasses import dataclass

from .callback import Callback


@dataclass
class MatrixNormalizerCallback(Callback):
"""
A callback to be used in conjunction with :class:`~olmo_core.nn.transformer.NormalizedTransformer`
(nGPT) models to re-normalize the weight matrices after each optimizer step.
"""

def post_train_batch(self):
self.trainer.model.normalize_matrices()

0 comments on commit 5471345

Please sign in to comment.