diff --git a/src/examples/ngpt/train.py b/src/examples/ngpt/train.py index e07791d7..131c4962 100644 --- a/src/examples/ngpt/train.py +++ b/src/examples/ngpt/train.py @@ -35,6 +35,7 @@ GPUMemoryMonitorCallback, GradClipperCallback, LMEvaluatorCallbackConfig, + MatrixNormalizerCallback, ProfilerCallback, SchedulerCallback, SequenceLengthSchedulerCallback, @@ -93,6 +94,7 @@ def build_config(run_name: str, overrides: List[str]) -> ExperimentConfig: metrics_collect_interval=5, cancel_check_interval=5, ) + .with_callback("matrix_normalizer", MatrixNormalizerCallback()) .with_callback("lr_scheduler", SchedulerCallback(scheduler=CosWithWarmup(warmup_steps=0))) .with_callback( "seq_len_scheduler", diff --git a/src/olmo_core/nn/transformer/model.py b/src/olmo_core/nn/transformer/model.py index 5d8c6eb9..52c5ed16 100644 --- a/src/olmo_core/nn/transformer/model.py +++ b/src/olmo_core/nn/transformer/model.py @@ -9,12 +9,13 @@ from olmo_core.config import StrEnum from olmo_core.data.utils import get_cumulative_document_lengths +from olmo_core.distributed.utils import get_local_tensor from olmo_core.doc_utils import beta_feature from olmo_core.utils import get_default_device from ..buffer_cache import BufferCache from ..layer_norm import LayerNormConfig -from ..utils import selective_checkpointing_context_fn +from ..utils import l2_normalize, selective_checkpointing_context_fn from .block import TransformerBlock, TransformerBlockConfig from .init import InitMethod @@ -434,6 +435,14 @@ def num_flops_per_token(self, seq_len: int) -> int: class NormalizedTransformer(Transformer): + """ + A nGPT transformer implementation. + + .. warning:: + When training this model you should use the :class:`~olmo_core.train.callbacks.MatrixNormalizerCallback` + to re-normalize the weight matrices after each optimizer step. 
# --- src/olmo_core/nn/transformer/model.py: methods of ``NormalizedTransformer`` ---


@torch.no_grad()
def normalize_matrices(self):
    """
    Re-normalize the weights of all matrices in place.

    This should be called after each optimizer step, which the
    :class:`~olmo_core.train.callbacks.MatrixNormalizerCallback` will handle for you.
    It is also invoked at the end of ``init_weights()``.

    The embeddings, ``w_q``/``w_k``/``w_v``, ``w1``/``w3`` and the final ``w_out`` are
    normalized along the last dimension; each block's ``attention.w_out`` and
    ``feed_forward.w2`` are normalized along dimension 0.
    """
    if self.embeddings is not None:
        self._normalize_matrix(self.embeddings.weight)

    for block in self.blocks:
        attn, ff = block.attention, block.feed_forward
        # (linear layer, dim to normalize along), in the same order as before.
        for linear, dim in (
            (attn.w_q, -1),
            (attn.w_k, -1),
            (attn.w_v, -1),
            (attn.w_out, 0),
            (ff.w1, -1),
            (ff.w2, 0),
            (ff.w3, -1),
        ):
            self._normalize_matrix(linear.weight, dim=dim)

    if self.w_out is not None:
        self._normalize_matrix(self.w_out.weight)


def _normalize_matrix(self, p: torch.Tensor, dim: int = -1):
    """
    L2-normalize ``p`` along ``dim`` and write the result back in place.

    :param p: The parameter (or tensor) to normalize.
    :param dim: The dimension along which to compute the norm.
    """
    # NOTE(review): the norm is computed on whatever ``get_local_tensor`` returns. If a
    # parameter is sharded along ``dim`` (e.g. dim-0 sharding combined with ``dim=0``),
    # this would only cover the local shard -- confirm against the sharding strategy.
    w = get_local_tensor(p.data)
    w.copy_(l2_normalize(w, dim=dim))


# --- src/olmo_core/nn/utils.py ---


def l2_normalize(x: torch.Tensor, dim: int = -1, eps: float = 1e-12) -> torch.Tensor:
    """
    Scale ``x`` to unit L2 norm along ``dim``.

    The norm is computed in float32 and cast back to the dtype of ``x`` before dividing.
    It is floored at ``eps`` so that an all-zero slice stays zero instead of becoming NaN
    (a NaN here would be copied straight back into the weights by the caller).

    :param x: The tensor to normalize.
    :param dim: The dimension along which to compute the norm.
    :param eps: Lower bound applied to the norm before dividing.

    :returns: A new tensor with the same shape as ``x``.
    """
    norm = x.float().norm(p=2, dim=dim, keepdim=True)
    floor = eps
    if x.is_floating_point():
        # Keep the floor representable after casting back to a low-precision dtype,
        # otherwise it could underflow to zero and re-introduce the division by zero.
        floor = max(eps, torch.finfo(x.dtype).tiny)
    return x / norm.clamp_min(floor).type_as(x)
@dataclass
class MatrixNormalizerCallback(Callback):
    """
    Companion callback for :class:`~olmo_core.nn.transformer.NormalizedTransformer` (nGPT)
    models: it re-normalizes the model's weight matrices after every training batch,
    i.e. after each optimizer step.
    """

    def post_train_batch(self):
        """Ask the trainer's model to re-normalize its weight matrices."""
        model = self.trainer.model
        model.normalize_matrices()