Make nn configs more flexible (#112)
epwalsh authored Nov 22, 2024
1 parent 0bcc840 commit d68d47a
Showing 11 changed files with 310 additions and 243 deletions.
89 changes: 65 additions & 24 deletions src/olmo_core/nn/attention.py
@@ -7,6 +7,7 @@
import torch.nn.functional as F

from ..config import Config, DType, StrEnum
from ..doc_utils import beta_feature
from ..exceptions import OLMoConfigurationError
from .buffer_cache import BufferCache
from .functional import l2_normalize
@@ -50,14 +51,49 @@ class AttentionConfig(Config):
"""
n_heads: int = 16
n_kv_heads: Optional[int] = None
bias: bool = True
bias: Optional[bool] = None
rope: Optional[RoPEConfig] = None
clip_qkv: Optional[float] = None
qk_norm: Optional[LayerNormConfig] = None
dropout: float = 0.0
dropout: Optional[float] = None
use_flash: Optional[bool] = None
dtype: DType = DType.float32

def num_params(self, d_model: int) -> int:
n_heads = self.n_heads
n_kv_heads = self.n_kv_heads or n_heads
head_dim = d_model // n_heads
bias = self.bias if self.bias is not None else self.name != AttentionType.normalized

params = 0

# Block attention Q projection.
params += d_model * d_model
if bias:
params += d_model

# Block attention KV projections.
params += 2 * d_model * n_kv_heads * head_dim
if bias:
params += 2 * n_kv_heads * head_dim

# Block attention QK norm.
if self.qk_norm is not None:
params += 2 * self.qk_norm.num_params(d_model)

# Block attention out.
params += d_model * d_model
if bias:
params += d_model

# Block QK scaling factors.
if self.name == AttentionType.normalized:
head_dim = d_model // n_heads
params += n_heads * head_dim
params += n_kv_heads * head_dim

return params
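
As a sanity check on the accounting above, here is a small standalone sketch (not part of the commit) that restates the same arithmetic for the default attention variant, ignoring QK norm and the nGPT scaling factors; the sizes in the example call are arbitrary.

```python
def attention_num_params(d_model: int, n_heads: int, n_kv_heads: int, bias: bool) -> int:
    """Standalone restatement of the accounting in AttentionConfig.num_params
    for the default variant (no QK norm, no nGPT scaling factors)."""
    head_dim = d_model // n_heads
    params = d_model * d_model + (d_model if bias else 0)  # Q projection
    params += 2 * d_model * n_kv_heads * head_dim + (2 * n_kv_heads * head_dim if bias else 0)  # K and V projections
    params += d_model * d_model + (d_model if bias else 0)  # output projection
    return params


# Example: d_model=4096, 32 query heads, 8 KV heads (grouped-query attention), no bias.
print(attention_num_params(4096, 32, 8, bias=False))  # 41943040
```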

def build(
self,
d_model: int,
@@ -72,29 +108,27 @@
"""
kwargs = self.as_dict(exclude_none=True, recurse=False)
kwargs.pop("name")
kwargs["dtype"] = kwargs["dtype"].as_pt()
kwargs.update(
dict(
d_model=d_model,
init_device=init_device,
cache=cache,
)
dtype=kwargs.pop("dtype").as_pt(),
d_model=d_model,
init_device=init_device,
cache=cache,
)

if self.name == "default":
return Attention(**kwargs)
elif self.name == "fused":
kwargs.pop("use_flash", None)
return FusedAttention(**kwargs)
elif self.name == "normalized":
bias = kwargs.pop("bias")
if bias:
raise OLMoConfigurationError(f"'bias' is invalid for '{self.name}' attention")
if kwargs.pop("dropout") > 0.0:
raise OLMoConfigurationError(f"'dropout' is invalid for '{self.name}' attention")
return NormalizedAttention(**kwargs)
else:
raise NotImplementedError(self.name)
try:
if self.name == "default":
return Attention(**kwargs)
elif self.name == "fused":
kwargs.pop("use_flash", None)
return FusedAttention(**kwargs)
elif self.name == "normalized":
return NormalizedAttention(**kwargs)
else:
raise NotImplementedError(self.name)
except TypeError as e:
raise OLMoConfigurationError(
f"invalid options for '{self.name}' {self.__class__.__name__}, {e}"
) from e
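
The try/except above replaces the earlier explicit checks: unsupported options (such as bias for the normalized variant, which hard-codes bias=False) now surface as a TypeError from the constructor and are re-raised as a configuration error. A self-contained illustration of that pattern, using hypothetical stand-in classes rather than the real modules:

```python
class ConfigurationError(Exception):
    """Stand-in for OLMoConfigurationError."""


class NormalizedStandIn:
    """Hypothetical stand-in for NormalizedAttention: accepts no 'bias' keyword."""

    def __init__(self, *, d_model: int):
        self.d_model = d_model


def build(name: str, **kwargs):
    try:
        return NormalizedStandIn(**kwargs)
    except TypeError as e:
        # Unknown keyword arguments surface as a TypeError; re-raise with config context.
        raise ConfigurationError(f"invalid options for '{name}' attention, {e}") from e


try:
    build("normalized", d_model=64, bias=True)
except ConfigurationError as err:
    print(err)  # mentions the unexpected 'bias' keyword argument
```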


class Attention(nn.Module):
@@ -300,6 +334,7 @@ def forward(
return self.w_out(att)


@beta_feature
class NormalizedAttention(Attention):
"""
An nGPT attention implementation.
@@ -312,6 +347,7 @@ def __init__(
n_heads: int,
n_kv_heads: Optional[int] = None,
rope: Optional[RoPEConfig] = None,
qk_norm: Optional[LayerNormConfig] = None,
use_flash: bool = False,
dtype: torch.dtype = torch.float32,
init_device: str = "cpu",
@@ -322,6 +358,7 @@
n_heads=n_heads,
n_kv_heads=n_kv_heads,
rope=rope,
qk_norm=qk_norm,
use_flash=use_flash,
bias=False,
dtype=dtype,
@@ -364,11 +401,15 @@ def forward(
# (batch_size, seq_len, n_kv_heads * head_dim)
q, k, v = self.w_q(x), self.w_k(x), self.w_v(x)

if self.q_norm is not None and self.k_norm is not None:
q = self.q_norm(q)
k = self.k_norm(k)

sq = (self.sq * (self.sq_init_value / self.sq_init_scaling)).view(1, 1, -1)
q = sq * l2_normalize(q)
q = sq * q

sk = (self.sk * (self.sk_init_value / self.sk_init_scaling)).view(1, 1, -1)
k = sk * l2_normalize(k)
k = sk * k

# shape: (batch_size, seq_len, n_heads, head_dim)
q = q.view(B, T, self.n_heads, self.head_dim)
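In the normalized (nGPT) forward pass above, the explicit l2_normalize calls on q and k have moved into the configurable qk_norm (typically the new L2Norm), leaving only the learned scaling here. A rough standalone sketch of the resulting q treatment; the shapes and the init values for sq are illustrative assumptions, not taken from the commit:

```python
import math

import torch
import torch.nn.functional as F

B, T, d_model, n_heads = 2, 8, 64, 4
head_dim = d_model // n_heads

q = torch.randn(B, T, d_model)  # output of w_q

# qk_norm (an "l2_norm" LayerNormConfig in the nGPT setup): unit L2 norm over the last dim.
q = F.normalize(q, dim=-1)  # assumed equivalent of l2_normalize / L2Norm

# Learned scaling, stored at sq_init_scaling and rescaled at use time.
# These init values are illustrative only.
sq_init_value, sq_init_scaling = 1.0, 1.0 / math.sqrt(d_model)
sq = torch.full((n_heads * head_dim,), sq_init_scaling)  # an nn.Parameter in the real module
q = (sq * (sq_init_value / sq_init_scaling)).view(1, 1, -1) * q

# Split heads as in the forward pass shown above.
q = q.view(B, T, n_heads, head_dim)
```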
52 changes: 34 additions & 18 deletions src/olmo_core/nn/feed_forward.py
@@ -1,11 +1,13 @@
import math
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..config import Config, DType, StrEnum
from ..doc_utils import beta_feature
from ..exceptions import OLMoConfigurationError
from .functional import l2_normalize

@@ -38,27 +40,40 @@ class FeedForwardConfig(Config):

hidden_size: int
name: FeedForwardType = FeedForwardType.default
bias: bool = True
bias: Optional[bool] = None
dtype: DType = DType.float32

def num_params(self, d_model: int) -> int:
bias = self.bias if self.bias is not None else self.name != FeedForwardType.normalized

params = 0

params += 3 * d_model * self.hidden_size
if bias:
params += 2 * self.hidden_size + d_model

# w1 + w3 scaling factors
if self.name == FeedForwardType.normalized:
params += 2 * self.hidden_size

return params
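
The formula counts the three SwiGLU projections used by FeedForward.forward, w2(silu(w1(x)) * w3(x)). A quick check against plain nn.Linear layers with arbitrary sizes:

```python
import torch.nn as nn

d_model, hidden_size, bias = 512, 2048, True

w1 = nn.Linear(d_model, hidden_size, bias=bias)  # gate projection
w2 = nn.Linear(hidden_size, d_model, bias=bias)  # down projection
w3 = nn.Linear(d_model, hidden_size, bias=bias)  # up projection

actual = sum(p.numel() for m in (w1, w2, w3) for p in m.parameters())
expected = 3 * d_model * hidden_size + (2 * hidden_size + d_model if bias else 0)
assert actual == expected  # 3,150,336 with these sizes
```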

def build(self, d_model: int, init_device: str = "cpu") -> "FeedForward":
if self.name == FeedForwardType.default:
return FeedForward(
d_model=d_model,
hidden_size=self.hidden_size,
bias=self.bias,
dtype=self.dtype.as_pt(),
init_device=init_device,
)
else:
if self.bias:
raise OLMoConfigurationError(f"'bias' is invalid for '{self.name}' feed-forward")
return NormalizedFeedForward(
d_model=d_model,
hidden_size=self.hidden_size,
dtype=self.dtype.as_pt(),
init_device=init_device,
)
kwargs = self.as_dict(exclude_none=True)
kwargs.pop("name")
kwargs.update(d_model=d_model, init_device=init_device, dtype=kwargs.pop("dtype").as_pt())

try:
if self.name == FeedForwardType.default:
return FeedForward(**kwargs)
elif self.name == FeedForwardType.normalized:
return NormalizedFeedForward(**kwargs)
else:
raise NotImplementedError(self.name)
except TypeError as e:
raise OLMoConfigurationError(
f"invalid options for '{self.name}' {self.__class__.__name__}, {e}"
) from e


class FeedForward(nn.Module):
@@ -91,6 +106,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.w2(F.silu(self.w1(x)) * self.w3(x))


@beta_feature
class NormalizedFeedForward(FeedForward):
"""
An nGPT feed-forward implementation.
82 changes: 59 additions & 23 deletions src/olmo_core/nn/layer_norm.py
@@ -1,12 +1,15 @@
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..config import Config, DType, StrEnum
from ..exceptions import OLMoConfigurationError
from .functional import l2_normalize

__all__ = ["LayerNormType", "LayerNormConfig", "LayerNorm", "RMSNorm", "FusedRMSNorm"]
__all__ = ["LayerNormType", "LayerNormConfig", "LayerNorm", "RMSNorm", "FusedRMSNorm", "L2Norm"]


class LayerNormType(StrEnum):
@@ -21,6 +24,7 @@ class LayerNormType(StrEnum):
default = "default"
rms = "rms"
fused_rms = "fused_rms"
l2_norm = "l2_norm"


@dataclass
@@ -37,11 +41,25 @@ class LayerNormConfig(Config):
- "rms" ➡️ :class:`RMSNorm`
- "fused_rms" ➡️ :class:`FusedRMSNorm`
"""
eps: float = 1e-5
elementwise_affine: bool = True
bias: bool = True
full_precision: bool = True
dtype: DType = DType.float32
eps: Optional[float] = None
elementwise_affine: Optional[bool] = None
bias: Optional[bool] = None
full_precision: Optional[bool] = None
dtype: Optional[DType] = None

def num_params(self, size: int) -> int:
elementwise_affine = (
self.elementwise_affine
if self.elementwise_affine is not None
else self.name != LayerNormType.l2_norm
)
bias = self.bias if self.bias is not None else self.name != LayerNormType.l2_norm
ln_params = 0
if elementwise_affine:
ln_params += size
if bias:
ln_params += size
return ln_params

def build(self, size: int, init_device: str = "cpu") -> "LayerNorm":
"""
@@ -51,23 +69,24 @@ def build(self, size: int, init_device: str = "cpu") -> "LayerNorm":
"""
kwargs = self.as_dict(exclude_none=True)
kwargs.pop("name")
dtype = kwargs["dtype"].as_pt()
kwargs.update(
dict(
size=size,
init_device=init_device,
dtype=dtype,
)
)

if self.name == LayerNormType.default:
return LayerNorm(**kwargs)
elif self.name == LayerNormType.rms:
return RMSNorm(**kwargs)
elif self.name == LayerNormType.fused_rms:
return FusedRMSNorm(**kwargs)
else:
raise NotImplementedError(self.name)
if (dtype := kwargs.pop("dtype", None)) is not None:
kwargs.update(dtype=dtype.as_pt())

try:
if self.name == LayerNormType.default:
return LayerNorm(size=size, init_device=init_device, **kwargs)
elif self.name == LayerNormType.rms:
return RMSNorm(size=size, init_device=init_device, **kwargs)
elif self.name == LayerNormType.fused_rms:
return FusedRMSNorm(size=size, init_device=init_device, **kwargs)
elif self.name == LayerNormType.l2_norm:
return L2Norm(size=size, **kwargs)
else:
raise NotImplementedError(self.name)
except TypeError as e:
raise OLMoConfigurationError(
f"invalid options for '{self.name}' {self.__class__.__name__}, {e}"
) from e


class LayerNorm(nn.Module):
@@ -245,3 +264,20 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
None if self.bias is None else self.bias.type_as(x),
eps=self.eps,
).to(og_dtype)


class L2Norm(LayerNorm):
"""
A variant of layer norm that just normalizes the last dimension of the input by its L2 norm,
as done in nGPT.
"""

def __init__(
self,
*,
size: int,
):
super().__init__(size=size, elementwise_affine=False, bias=False)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return l2_normalize(x)
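
L2Norm is parameter-free (elementwise_affine=False, bias=False), so it adds nothing to num_params. A minimal sketch of its assumed behavior, treating l2_normalize as standard unit L2 normalization over the last dimension:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 16, 64)
y = F.normalize(x, p=2.0, dim=-1)  # assumed equivalent of l2_normalize / L2Norm.forward
print(y.norm(dim=-1))              # ~1.0 everywhere: each last-dim vector now has unit length
```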