diff --git a/src/olmo_core/nn/attention.py b/src/olmo_core/nn/attention.py index 6476333f..c424deb2 100644 --- a/src/olmo_core/nn/attention.py +++ b/src/olmo_core/nn/attention.py @@ -7,6 +7,7 @@ import torch.nn.functional as F from ..config import Config, DType, StrEnum +from ..doc_utils import beta_feature from ..exceptions import OLMoConfigurationError from .buffer_cache import BufferCache from .functional import l2_normalize @@ -50,14 +51,49 @@ class AttentionConfig(Config): """ n_heads: int = 16 n_kv_heads: Optional[int] = None - bias: bool = True + bias: Optional[bool] = None rope: Optional[RoPEConfig] = None clip_qkv: Optional[float] = None qk_norm: Optional[LayerNormConfig] = None - dropout: float = 0.0 + dropout: Optional[float] = None use_flash: Optional[bool] = None dtype: DType = DType.float32 + def num_params(self, d_model: int) -> int: + n_heads = self.n_heads + n_kv_heads = self.n_kv_heads or n_heads + head_dim = d_model // n_heads + bias = self.bias if self.bias is not None else self.name != AttentionType.normalized + + params = 0 + + # Block attention Q projection. + params += d_model * d_model + if bias: + params += d_model + + # Block attention KV projections. + params += 2 * d_model * n_kv_heads * head_dim + if bias: + params += 2 * n_kv_heads * head_dim + + # Block attention QK norm. + if self.qk_norm is not None: + params += 2 * self.qk_norm.num_params(d_model) + + # Block attention out. + params += d_model * d_model + if bias: + params += d_model + + # Block QK scaling factors. + if self.name == AttentionType.normalized: + head_dim = d_model // n_heads + params += n_heads * head_dim + params += n_kv_heads * head_dim + + return params + def build( self, d_model: int, @@ -72,29 +108,27 @@ def build( """ kwargs = self.as_dict(exclude_none=True, recurse=False) kwargs.pop("name") - kwargs["dtype"] = kwargs["dtype"].as_pt() kwargs.update( - dict( - d_model=d_model, - init_device=init_device, - cache=cache, - ) + dtype=kwargs.pop("dtype").as_pt(), + d_model=d_model, + init_device=init_device, + cache=cache, ) - if self.name == "default": - return Attention(**kwargs) - elif self.name == "fused": - kwargs.pop("use_flash", None) - return FusedAttention(**kwargs) - elif self.name == "normalized": - bias = kwargs.pop("bias") - if bias: - raise OLMoConfigurationError(f"'bias' is invalid for '{self.name}' attention") - if kwargs.pop("dropout") > 0.0: - raise OLMoConfigurationError(f"'dropout' is invalid for '{self.name}' attention") - return NormalizedAttention(**kwargs) - else: - raise NotImplementedError(self.name) + try: + if self.name == "default": + return Attention(**kwargs) + elif self.name == "fused": + kwargs.pop("use_flash", None) + return FusedAttention(**kwargs) + elif self.name == "normalized": + return NormalizedAttention(**kwargs) + else: + raise NotImplementedError(self.name) + except TypeError as e: + raise OLMoConfigurationError( + f"invalid options for '{self.name}' {self.__class__.__name__}, {e}" + ) from e class Attention(nn.Module): @@ -300,6 +334,7 @@ def forward( return self.w_out(att) +@beta_feature class NormalizedAttention(Attention): """ An nGPT attention implementation. 
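
Illustrative note (not part of the diff): a worked example of the new AttentionConfig.num_params() accounting under grouped-query attention, using only the formula added above. The sizes (d_model=256, n_heads=8, n_kv_heads=2) are hypothetical; with bias left unset it resolves to True for non-normalized attention, and there is no QK norm or scaling-factor term.

from olmo_core.nn.attention import AttentionConfig

# Hypothetical GQA sizing: d_model=256, n_heads=8, n_kv_heads=2, head_dim=32.
cfg = AttentionConfig(n_heads=8, n_kv_heads=2)
# Q projection:    256*256 + 256        = 65_792
# K/V projections: 2*256*64 + 2*64      = 32_896
# out projection:  256*256 + 256        = 65_792
assert cfg.num_params(256) == 65_792 + 32_896 + 65_792 == 164_480
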
@@ -312,6 +347,7 @@ def __init__( n_heads: int, n_kv_heads: Optional[int] = None, rope: Optional[RoPEConfig] = None, + qk_norm: Optional[LayerNormConfig] = None, use_flash: bool = False, dtype: torch.dtype = torch.float32, init_device: str = "cpu", @@ -322,6 +358,7 @@ def __init__( n_heads=n_heads, n_kv_heads=n_kv_heads, rope=rope, + qk_norm=qk_norm, use_flash=use_flash, bias=False, dtype=dtype, @@ -364,11 +401,15 @@ def forward( # (batch_size, seq_len, n_kv_heads * head_dim) q, k, v = self.w_q(x), self.w_k(x), self.w_v(x) + if self.q_norm is not None and self.k_norm is not None: + q = self.q_norm(q) + k = self.k_norm(k) + sq = (self.sq * (self.sq_init_value / self.sq_init_scaling)).view(1, 1, -1) - q = sq * l2_normalize(q) + q = sq * q sk = (self.sk * (self.sk_init_value / self.sk_init_scaling)).view(1, 1, -1) - k = sk * l2_normalize(k) + k = sk * k # shape: (batch_size, seq_len, n_heads, head_dim) q = q.view(B, T, self.n_heads, self.head_dim) diff --git a/src/olmo_core/nn/feed_forward.py b/src/olmo_core/nn/feed_forward.py index bbc03503..38d29273 100644 --- a/src/olmo_core/nn/feed_forward.py +++ b/src/olmo_core/nn/feed_forward.py @@ -1,11 +1,13 @@ import math from dataclasses import dataclass +from typing import Optional import torch import torch.nn as nn import torch.nn.functional as F from ..config import Config, DType, StrEnum +from ..doc_utils import beta_feature from ..exceptions import OLMoConfigurationError from .functional import l2_normalize @@ -38,27 +40,40 @@ class FeedForwardConfig(Config): hidden_size: int name: FeedForwardType = FeedForwardType.default - bias: bool = True + bias: Optional[bool] = None dtype: DType = DType.float32 + def num_params(self, d_model: int) -> int: + bias = self.bias if self.bias is not None else self.name != FeedForwardType.normalized + + params = 0 + + params += 3 * d_model * self.hidden_size + if bias: + params += 2 * self.hidden_size + d_model + + # w1 + w3 scaling factors + if self.name == FeedForwardType.normalized: + params += 2 * self.hidden_size + + return params + def build(self, d_model: int, init_device: str = "cpu") -> "FeedForward": - if self.name == FeedForwardType.default: - return FeedForward( - d_model=d_model, - hidden_size=self.hidden_size, - bias=self.bias, - dtype=self.dtype.as_pt(), - init_device=init_device, - ) - else: - if self.bias: - raise OLMoConfigurationError(f"'bias' is invalid for '{self.name}' feed-forward") - return NormalizedFeedForward( - d_model=d_model, - hidden_size=self.hidden_size, - dtype=self.dtype.as_pt(), - init_device=init_device, - ) + kwargs = self.as_dict(exclude_none=True) + kwargs.pop("name") + kwargs.update(d_model=d_model, init_device=init_device, dtype=kwargs.pop("dtype").as_pt()) + + try: + if self.name == FeedForwardType.default: + return FeedForward(**kwargs) + elif self.name == FeedForwardType.normalized: + return NormalizedFeedForward(**kwargs) + else: + raise NotImplementedError(self.name) + except TypeError as e: + raise OLMoConfigurationError( + f"invalid options for '{self.name}' {self.__class__.__name__}, {e}" + ) from e class FeedForward(nn.Module): @@ -91,6 +106,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.w2(F.silu(self.w1(x)) * self.w3(x)) +@beta_feature class NormalizedFeedForward(FeedForward): """ An nGPT feed-forward implementation. 
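
Illustrative note (not part of the diff): switching bias (and dropout) from hard defaults to Optional[...] = None is what lets a single config class serve both variants. A None field is dropped by as_dict(exclude_none=True), so the target module's own default applies, and num_params() resolves None the same way: True for the default feed-forward, False for the normalized one. A small sketch of both resolutions, with hypothetical sizes:

from olmo_core.nn.feed_forward import FeedForwardConfig, FeedForwardType

d_model, hidden = 64, 256
default_cfg = FeedForwardConfig(hidden_size=hidden)                                # bias -> True
ngpt_cfg = FeedForwardConfig(hidden_size=hidden, name=FeedForwardType.normalized)  # bias -> False

# w1, w2, w3 weights plus biases (2 * hidden + d_model) for the default block.
assert default_cfg.num_params(d_model) == 3 * d_model * hidden + 2 * hidden + d_model
# No biases, but w1/w3 scaling factors (2 * hidden) for the normalized block.
assert ngpt_cfg.num_params(d_model) == 3 * d_model * hidden + 2 * hidden
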
diff --git a/src/olmo_core/nn/layer_norm.py b/src/olmo_core/nn/layer_norm.py index 92efd578..e4a3c234 100644 --- a/src/olmo_core/nn/layer_norm.py +++ b/src/olmo_core/nn/layer_norm.py @@ -1,12 +1,15 @@ from dataclasses import dataclass +from typing import Optional import torch import torch.nn as nn import torch.nn.functional as F from ..config import Config, DType, StrEnum +from ..exceptions import OLMoConfigurationError +from .functional import l2_normalize -__all__ = ["LayerNormType", "LayerNormConfig", "LayerNorm", "RMSNorm", "FusedRMSNorm"] +__all__ = ["LayerNormType", "LayerNormConfig", "LayerNorm", "RMSNorm", "FusedRMSNorm", "L2Norm"] class LayerNormType(StrEnum): @@ -21,6 +24,7 @@ class LayerNormType(StrEnum): default = "default" rms = "rms" fused_rms = "fused_rms" + l2_norm = "l2_norm" @dataclass @@ -37,11 +41,25 @@ class LayerNormConfig(Config): - "rms" ➡️ :class:`RMSNorm` - "fused_rms" ➡️ :class:`FusedRMSNorm` """ - eps: float = 1e-5 - elementwise_affine: bool = True - bias: bool = True - full_precision: bool = True - dtype: DType = DType.float32 + eps: Optional[float] = None + elementwise_affine: Optional[bool] = None + bias: Optional[bool] = None + full_precision: Optional[bool] = None + dtype: Optional[DType] = None + + def num_params(self, size: int) -> int: + elementwise_affine = ( + self.elementwise_affine + if self.elementwise_affine is not None + else self.name != LayerNormType.l2_norm + ) + bias = self.bias if self.bias is not None else self.name != LayerNormType.l2_norm + ln_params = 0 + if elementwise_affine: + ln_params += size + if bias: + ln_params += size + return ln_params def build(self, size: int, init_device: str = "cpu") -> "LayerNorm": """ @@ -51,23 +69,24 @@ def build(self, size: int, init_device: str = "cpu") -> "LayerNorm": """ kwargs = self.as_dict(exclude_none=True) kwargs.pop("name") - dtype = kwargs["dtype"].as_pt() - kwargs.update( - dict( - size=size, - init_device=init_device, - dtype=dtype, - ) - ) - - if self.name == LayerNormType.default: - return LayerNorm(**kwargs) - elif self.name == LayerNormType.rms: - return RMSNorm(**kwargs) - elif self.name == LayerNormType.fused_rms: - return FusedRMSNorm(**kwargs) - else: - raise NotImplementedError(self.name) + if (dtype := kwargs.pop("dtype", None)) is not None: + kwargs.update(dtype=dtype.as_pt()) + + try: + if self.name == LayerNormType.default: + return LayerNorm(size=size, init_device=init_device, **kwargs) + elif self.name == LayerNormType.rms: + return RMSNorm(size=size, init_device=init_device, **kwargs) + elif self.name == LayerNormType.fused_rms: + return FusedRMSNorm(size=size, init_device=init_device, **kwargs) + elif self.name == LayerNormType.l2_norm: + return L2Norm(size=size, **kwargs) + else: + raise NotImplementedError(self.name) + except TypeError as e: + raise OLMoConfigurationError( + f"invalid options for '{self.name}' {self.__class__.__name__}, {e}" + ) from e class LayerNorm(nn.Module): @@ -245,3 +264,20 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: None if self.bias is None else self.bias.type_as(x), eps=self.eps, ).to(og_dtype) + + +class L2Norm(LayerNorm): + """ + A variant of layer norm that just normalizes the last dimension of the input by its L2 norm, + as done in nGPT. 
+ """ + + def __init__( + self, + *, + size: int, + ): + super().__init__(size=size, elementwise_affine=False, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return l2_normalize(x) diff --git a/src/olmo_core/nn/lm_head.py b/src/olmo_core/nn/lm_head.py index bc233a5a..b96bc999 100644 --- a/src/olmo_core/nn/lm_head.py +++ b/src/olmo_core/nn/lm_head.py @@ -6,6 +6,7 @@ import torch.nn as nn from ..config import Config, DType, StrEnum +from ..doc_utils import beta_feature from ..exceptions import OLMoConfigurationError from .functional import l2_normalize from .layer_norm import LayerNormConfig @@ -37,32 +38,47 @@ class LMHeadConfig(Config): name: LMHeadType = LMHeadType.default layer_norm: Optional[LayerNormConfig] = None - bias: bool = True + bias: Optional[bool] = None dtype: DType = DType.float32 + def num_params(self, d_model: int, vocab_size: int) -> int: + bias = self.bias if self.bias is not None else self.name != LMHeadType.normalized + + params = 0 + if self.layer_norm is not None: + params += self.layer_norm.num_params(d_model) + + params += d_model * vocab_size + if bias: + params += vocab_size + + # Final scaling factor. + if self.name == LMHeadType.normalized: + params += vocab_size + + return params + def build(self, *, d_model: int, vocab_size: int, init_device: str = "cpu") -> "LMHead": - if self.name == LMHeadType.default: - return LMHead( - d_model=d_model, - vocab_size=vocab_size, - layer_norm=self.layer_norm, - dtype=self.dtype.as_pt(), - bias=self.bias, - init_device=init_device, - ) - elif self.name == LMHeadType.normalized: - if self.bias: - raise OLMoConfigurationError(f"'bias' is invalid for '{self.name}' LM head") - if self.layer_norm is not None: - raise OLMoConfigurationError(f"'layer_norm' is invalid for '{self.name}' LM head") - return NormalizedLMHead( - d_model=d_model, - vocab_size=vocab_size, - dtype=self.dtype.as_pt(), - init_device=init_device, - ) - else: - raise NotImplementedError(self.name) + kwargs = self.as_dict(exclude_none=True, recurse=False) + kwargs.pop("name") + kwargs.update( + d_model=d_model, + vocab_size=vocab_size, + init_device=init_device, + dtype=kwargs.pop("dtype").as_pt(), + ) + + try: + if self.name == LMHeadType.default: + return LMHead(**kwargs) + elif self.name == LMHeadType.normalized: + return NormalizedLMHead(**kwargs) + else: + raise NotImplementedError(self.name) + except TypeError as e: + raise OLMoConfigurationError( + f"invalid options for '{self.name}' {self.__class__.__name__}, {e}" + ) from e class LMHead(nn.Module): @@ -75,7 +91,7 @@ def __init__( *, d_model: int, vocab_size: int, - layer_norm: Optional[LayerNormConfig], + layer_norm: Optional[LayerNormConfig] = None, dtype: torch.dtype = torch.float32, bias: bool = True, init_device: str = "cpu", @@ -91,6 +107,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.w_out(h) +@beta_feature class NormalizedLMHead(LMHead): """ An nGPT LM head implementation. 
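
Illustrative note (not part of the diff): the recurring try/except refactor in these build() methods replaces the explicit "'X' is invalid for ..." checks with a single pattern: options a given variant does not accept reach its constructor, fail with TypeError, and are re-raised as OLMoConfigurationError. The new lm_head test below exercises this; a minimal standalone sketch of the same behavior:

from olmo_core.exceptions import OLMoConfigurationError
from olmo_core.nn.lm_head import LMHeadConfig, LMHeadType

try:
    # Per the new test, the normalized head rejects 'bias', so build()
    # re-raises the constructor's TypeError as a configuration error.
    LMHeadConfig(name=LMHeadType.normalized, bias=True).build(d_model=64, vocab_size=128)
except OLMoConfigurationError as e:
    print(f"rejected as expected: {e}")
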
diff --git a/src/olmo_core/nn/rope.py b/src/olmo_core/nn/rope.py index 137ec889..26a184a4 100644 --- a/src/olmo_core/nn/rope.py +++ b/src/olmo_core/nn/rope.py @@ -7,6 +7,7 @@ import torch.nn as nn from ..config import Config, StrEnum +from ..exceptions import OLMoConfigurationError from .buffer_cache import BufferCache __all__ = [ @@ -95,17 +96,21 @@ def build( """ kwargs = self.as_dict(exclude_none=True, recurse=False) kwargs.pop("name") - kwargs["head_shape"] = head_shape - kwargs["cache"] = cache - - if self.name == "default": - return RotaryEmbedding(**kwargs) - elif self.name == "fused": - return FusedRotaryEmbedding(**kwargs) - elif self.name == "complex": - return ComplexRotaryEmbedding(**kwargs) - else: - raise NotImplementedError(self.name) + kwargs.update(head_shape=head_shape, cache=cache) + + try: + if self.name == "default": + return RotaryEmbedding(**kwargs) + elif self.name == "fused": + return FusedRotaryEmbedding(**kwargs) + elif self.name == "complex": + return ComplexRotaryEmbedding(**kwargs) + else: + raise NotImplementedError(self.name) + except TypeError as e: + raise OLMoConfigurationError( + f"invalid options for '{self.name}' {self.__class__.__name__}, {e}" + ) from e class RotaryEmbeddingBase(nn.Module): diff --git a/src/olmo_core/nn/transformer/block.py b/src/olmo_core/nn/transformer/block.py index 36e62143..81e3da3d 100644 --- a/src/olmo_core/nn/transformer/block.py +++ b/src/olmo_core/nn/transformer/block.py @@ -7,6 +7,7 @@ import torch.nn as nn from olmo_core.config import Config, StrEnum +from olmo_core.doc_utils import beta_feature from olmo_core.exceptions import OLMoConfigurationError from ..attention import AttentionConfig @@ -74,7 +75,7 @@ class TransformerBlockConfig(Config): """ The block type. """ - dropout: float = 0.0 + dropout: Optional[float] = None """ Dropout probability. 
""" @@ -87,79 +88,32 @@ def build( init_device: str = "cpu", cache: Optional[BufferCache] = None, ) -> "TransformerBlockBase": - if self.name == TransformerBlockType.default: - if self.feed_forward is None: - raise OLMoConfigurationError("'feed_forward' config is required") - if self.layer_norm is None: - raise OLMoConfigurationError("'layer_norm' config is required") - return TransformerBlock( - d_model=d_model, - block_idx=block_idx, - attention=self.attention, - feed_forward=self.feed_forward, - layer_norm=self.layer_norm, - dropout=self.dropout, - init_device=init_device, - cache=cache, - ) - elif self.name == TransformerBlockType.reordered_norm: - if self.feed_forward is None: - raise OLMoConfigurationError("'feed_forward' config is required") - if self.layer_norm is None: - raise OLMoConfigurationError("'layer_norm' config is required") - return ReorderedNormTransformerBlock( - d_model=d_model, - block_idx=block_idx, - attention=self.attention, - feed_forward=self.feed_forward, - layer_norm=self.layer_norm, - dropout=self.dropout, - init_device=init_device, - cache=cache, - ) - elif self.name == TransformerBlockType.normalized: - if self.feed_forward is None: - raise OLMoConfigurationError("'feed_forward' config is required") - return NormalizedTransformerBlock( - d_model=d_model, - block_idx=block_idx, - attention=self.attention, - feed_forward=self.feed_forward, - init_device=init_device, - cache=cache, - ) - elif self.name == TransformerBlockType.moe: - if self.feed_forward_moe is None: - raise OLMoConfigurationError("'feed_forward_moe' config is required for MoE blocks") - if self.layer_norm is None: - raise OLMoConfigurationError("'layer_norm' config is required") - return MoETransformerBlock( - d_model=d_model, - block_idx=block_idx, - attention=self.attention, - feed_forward_moe=self.feed_forward_moe, - layer_norm=self.layer_norm, - dropout=self.dropout, - init_device=init_device, - cache=cache, - ) - elif self.name == TransformerBlockType.moe_reordered_norm: - if self.feed_forward_moe is None: - raise OLMoConfigurationError("'feed_forward_moe' config is required for MoE blocks") - if self.layer_norm is None: - raise OLMoConfigurationError("'layer_norm' config is required") - return MoEReorderedNormTransformerBlock( - d_model=d_model, - block_idx=block_idx, - attention=self.attention, - feed_forward_moe=self.feed_forward_moe, - layer_norm=self.layer_norm, - dropout=self.dropout, - init_device=init_device, - cache=cache, - ) - else: - raise NotImplementedError(self.name) + kwargs = self.as_dict(exclude_none=True, recurse=False) + kwargs.pop("name") + kwargs.update( + d_model=d_model, + block_idx=block_idx, + init_device=init_device, + cache=cache, + ) + + try: + if self.name == TransformerBlockType.default: + return TransformerBlock(**kwargs) + elif self.name == TransformerBlockType.reordered_norm: + return ReorderedNormTransformerBlock(**kwargs) + elif self.name == TransformerBlockType.normalized: + return NormalizedTransformerBlock(**kwargs) + elif self.name == TransformerBlockType.moe: + return MoETransformerBlock(**kwargs) + elif self.name == TransformerBlockType.moe_reordered_norm: + return MoEReorderedNormTransformerBlock(**kwargs) + else: + raise NotImplementedError(self.name) + except TypeError as e: + raise OLMoConfigurationError( + f"invalid options for '{self.name}' {self.__class__.__name__}, {e}" + ) from e class TransformerBlockBase(nn.Module): @@ -253,6 +207,7 @@ def forward( return h + self.dropout(self.feed_forward_norm(self.feed_forward(h))) +@beta_feature 
class NormalizedTransformerBlock(TransformerBlockBase): """ An nGPT block implementation to be used with the :class:`~olmo_core.nn.attention.NormalizedAttention` diff --git a/src/olmo_core/nn/transformer/config.py b/src/olmo_core/nn/transformer/config.py index 5d8e1eef..4e344859 100644 --- a/src/olmo_core/nn/transformer/config.py +++ b/src/olmo_core/nn/transformer/config.py @@ -228,14 +228,6 @@ def num_params(self) -> int: The total number of parameters that a model from this config would have. """ - def layer_norm_params(layer_norm: LayerNormConfig) -> int: - ln_params = 0 - if layer_norm.elementwise_affine: - ln_params += self.d_model - if layer_norm.bias: - ln_params += self.d_model - return ln_params - num_params = 0 # Embedding params. @@ -243,77 +235,32 @@ def layer_norm_params(layer_norm: LayerNormConfig) -> int: block_params = 0 - n_heads = self.block.attention.n_heads - n_kv_heads = self.block.attention.n_kv_heads or n_heads - head_dim = self.d_model // n_heads - # Block attn and MLP scaling factors. if self.block.name == TransformerBlockType.normalized: block_params += 2 * self.d_model - # Block attention Q projection. - block_params += self.d_model * self.d_model - if self.block.attention.bias: - block_params += self.d_model - - # Block attention KV projections. - block_params += 2 * self.d_model * n_kv_heads * head_dim - if self.block.attention.bias: - block_params += 2 * n_kv_heads * head_dim - - # Block attention QK norm. - if self.block.attention.qk_norm is not None: - block_params += 2 * layer_norm_params(self.block.attention.qk_norm) - - # Block attention out. - block_params += self.d_model * self.d_model - if self.block.attention.bias: - block_params += self.d_model + # Block attention params. + block_params += self.block.attention.num_params(self.d_model) # Block attention norm. if self.block.layer_norm is not None: - block_params += layer_norm_params(self.block.layer_norm) - - # Block QK scaling factors. - if self.block.attention.name == AttentionType.normalized: - head_dim = self.d_model // self.block.attention.n_heads - block_params += self.block.attention.n_heads * head_dim - block_params += ( - self.block.attention.n_kv_heads or self.block.attention.n_heads - ) * head_dim + block_params += self.block.layer_norm.num_params(self.d_model) # Block feed forward. - if "moe" not in self.block.name: - assert self.block.feed_forward is not None - block_params += 3 * self.d_model * self.block.feed_forward.hidden_size - if self.block.feed_forward.bias: - block_params += 2 * self.block.feed_forward.hidden_size + self.d_model - # w1 + w3 scaling factors - if self.block.feed_forward.name == FeedForwardType.normalized: - block_params += 2 * self.block.feed_forward.hidden_size - else: - assert self.block.feed_forward_moe is not None + if self.block.feed_forward is not None: + block_params += self.block.feed_forward.num_params(self.d_model) + elif self.block.feed_forward_moe is not None: block_params += self.block.feed_forward_moe.num_params(self.d_model) # Block feed forward norm. if self.block.layer_norm is not None: - block_params += layer_norm_params(self.block.layer_norm) + block_params += self.block.layer_norm.num_params(self.d_model) # All block params. num_params += self.n_layers * block_params - # Final layer norm. - if self.lm_head.layer_norm is not None: - num_params += layer_norm_params(self.lm_head.layer_norm) - - # Final FF out. - num_params += self.d_model * self.vocab_size - if self.lm_head.bias: - num_params += self.vocab_size - - # Final scaling factor. 
- if self.name == TransformerType.normalized: - num_params += self.vocab_size + # LM head. + num_params += self.lm_head.num_params(self.d_model, self.vocab_size) return num_params @@ -670,6 +617,7 @@ def ngpt_like( n_layers: int, n_heads: int, n_kv_heads: Optional[int] = None, + qk_norm: bool = True, rope_theta: int = 500_000, hidden_size_multiple_of: int = 256, hidden_size_multiplier: Optional[float] = None, @@ -699,15 +647,14 @@ def ngpt_like( name=AttentionType.normalized, n_heads=n_heads, n_kv_heads=n_kv_heads, - bias=False, + qk_norm=None if not qk_norm else LayerNormConfig(name=LayerNormType.l2_norm), rope=RoPEConfig(name=RoPEType.default, theta=rope_theta), use_flash=use_flash, dtype=dtype, ), feed_forward=FeedForwardConfig( - name=FeedForwardType.normalized, hidden_size=hidden_size, bias=False, dtype=dtype + name=FeedForwardType.normalized, hidden_size=hidden_size, dtype=dtype ), - layer_norm=None, ) return cls( @@ -716,9 +663,7 @@ def ngpt_like( vocab_size=vocab_size, n_layers=n_layers, block=block, - lm_head=LMHeadConfig( - name=LMHeadType.normalized, layer_norm=None, bias=False, dtype=dtype - ), + lm_head=LMHeadConfig(name=LMHeadType.normalized, dtype=dtype), dtype=dtype, compile=compile, init_method=InitMethod.normalized, diff --git a/src/olmo_core/nn/transformer/model.py b/src/olmo_core/nn/transformer/model.py index 9b328c59..0a0caa98 100644 --- a/src/olmo_core/nn/transformer/model.py +++ b/src/olmo_core/nn/transformer/model.py @@ -437,6 +437,7 @@ def num_flops_per_token(self, seq_len: int) -> int: return flop_per_token +@beta_feature class NormalizedTransformer(Transformer): """ A nGPT transformer implementation, to be used with the :class:`NormalizedTransformerBlock` block diff --git a/src/test/nn/layer_norm_test.py b/src/test/nn/layer_norm_test.py index 718d34cd..dd058e89 100644 --- a/src/test/nn/layer_norm_test.py +++ b/src/test/nn/layer_norm_test.py @@ -1,7 +1,14 @@ import pytest import torch -from olmo_core.nn.layer_norm import FusedRMSNorm, RMSNorm +from olmo_core.exceptions import OLMoConfigurationError +from olmo_core.nn.layer_norm import ( + FusedRMSNorm, + L2Norm, + LayerNormConfig, + LayerNormType, + RMSNorm, +) from ..utils import requires_flash_attn, requires_gpu @@ -21,3 +28,11 @@ def test_fused_rms_norm(bias, dtype): y1 = norm(x) y2 = norm_fused(x) torch.testing.assert_close(y1, y2) + + +def test_layer_norm_builder_config(): + norm = LayerNormConfig(name=LayerNormType.l2_norm).build(size=1024) + assert isinstance(norm, L2Norm) + + with pytest.raises(OLMoConfigurationError): + LayerNormConfig(name=LayerNormType.l2_norm, elementwise_affine=True).build(size=1024) diff --git a/src/test/nn/lm_head_test.py b/src/test/nn/lm_head_test.py new file mode 100644 index 00000000..efe6b301 --- /dev/null +++ b/src/test/nn/lm_head_test.py @@ -0,0 +1,15 @@ +import pytest + +from olmo_core.exceptions import OLMoConfigurationError +from olmo_core.nn.lm_head import LMHeadConfig, LMHeadType + + +def test_lm_head_builder_config(): + lm_head = LMHeadConfig(name=LMHeadType.default).build(d_model=64, vocab_size=128) + assert lm_head.w_out.bias is not None + + lm_head = LMHeadConfig(name=LMHeadType.default, bias=False).build(d_model=64, vocab_size=128) + assert lm_head.w_out.bias is None + + with pytest.raises(OLMoConfigurationError): + LMHeadConfig(name=LMHeadType.normalized, bias=True).build(d_model=64, vocab_size=128) diff --git a/src/test/nn/transformer/model_test.py b/src/test/nn/transformer/model_test.py index ddcea420..6aca21e7 100644 --- a/src/test/nn/transformer/model_test.py 
+++ b/src/test/nn/transformer/model_test.py @@ -1,3 +1,5 @@ +import logging + import pytest import torch import torch.nn as nn @@ -7,6 +9,8 @@ from ...utils import GPU_MARKS +log = logging.getLogger(__name__) + @pytest.mark.parametrize( "init_device, device", @@ -17,6 +21,7 @@ ) def test_small_llama2_config_builder(init_device, device): config = TransformerConfig.llama2_271M(vocab_size=50257) + log.info(config) model = config.build(init_device=init_device, device=torch.device(device)) # Make sure num params estimate is correct. @@ -60,3 +65,19 @@ def test_small_ngpt_config_builder(init_device, device): # Make sure block_idx is set correctly. assert model.blocks[0].block_idx == 0 assert model.blocks[-1].block_idx == len(model.blocks) - 1 + + # Make sure all weights are normalized in the embedding dimension. + for name, module in model.named_modules(): + if isinstance(module, nn.Linear): + assert module.bias is None + w = module.weight + if w.shape[1] == config.d_model and "attention.w_out" not in name: + pass + elif w.shape[0] == config.d_model: + w = w.transpose(0, 1) + else: + continue + + log.info(f"Checking norm for '{name}'") + norm = torch.linalg.vector_norm(w, dim=1) + torch.testing.assert_close(norm, torch.ones_like(norm))
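
Illustrative note (not part of the diff): with the per-module num_params() helpers added here, TransformerConfig's total parameter count becomes the sum of the embedding, per-block (attention, norms, feed-forward) and LM-head counts. A quick end-to-end check, mirroring the existing config-builder tests; it assumes TransformerConfig is importable from olmo_core.nn.transformer and that num_params is exposed as a property, as the tests suggest.

import torch

from olmo_core.nn.transformer import TransformerConfig

config = TransformerConfig.llama2_271M(vocab_size=50257)
model = config.build(init_device="cpu", device=torch.device("cpu"))
actual = sum(p.numel() for p in model.parameters())
assert actual == config.num_params, (actual, config.num_params)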