Commit

improve docs
epwalsh committed Nov 22, 2024
1 parent d68d47a commit 37e0e88
Showing 13 changed files with 194 additions and 98 deletions.
22 changes: 16 additions & 6 deletions docs/source/conf.py
@@ -129,12 +129,22 @@ def filter(self, record: logging.LogRecord) -> bool:


def autodoc_skip_member(app, what, name, obj, skip, options):
"""
Skip documenting these Pydantic-specific attributes.
"""
del app, what, obj, skip, options
exclude = name in {"model_config", "model_fields", "model_computed_fields"}
return True if exclude else None
import inspect

del app, name, options

module = inspect.getmodule(obj)
module_name = None if module is None else module.__name__
if (
what == "class"
and module_name is not None
and module_name.startswith("olmo_core.train.callbacks")
and module_name != "olmo_core.train.callbacks.callback"
):
if inspect.isfunction(obj) or inspect.ismethod(obj):
return True

return skip


def setup(app):
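The body of setup() is not shown in this diff. For context, a skip hook like the one above is normally registered against Sphinx's autodoc-skip-member event; the following is a minimal sketch, not the project's actual setup() body:

```python
def setup(app):
    # Minimal sketch: register the hook above with autodoc.
    # The real setup() in this repo likely does more than this.
    app.connect("autodoc-skip-member", autodoc_skip_member)
```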
26 changes: 18 additions & 8 deletions src/olmo_core/nn/attention.py
@@ -25,29 +25,33 @@
class AttentionType(StrEnum):
"""
An enumeration of the different attention implementations.
- "default" ➡️ :class:`Attention`
- "fused" ➡️ :class:`FusedAttention`
- "normalized" ➡️ :class:`NormalizedAttention`
"""

default = "default"
"""
➡️ :class:`Attention`
"""
fused = "fused"
"""
➡️ :class:`FusedAttention`
"""
normalized = "normalized"
"""
➡️ :class:`NormalizedAttention`
"""


@dataclass
class AttentionConfig(Config):
"""
A configuration class for easily building any of the different attention modules.
See :class:`Attention` for a description of the parameters.
See the individual :class:`Attention` subclasses for a description of the configuration options.
"""

name: AttentionType = AttentionType.default
"""
- "default" ➡️ :class:`Attention`
- "fused" ➡️ :class:`FusedAttention`
The name of the implementation.
"""
n_heads: int = 16
n_kv_heads: Optional[int] = None
@@ -60,6 +64,11 @@ class AttentionConfig(Config):
dtype: DType = DType.float32

def num_params(self, d_model: int) -> int:
"""
The number of params that the attention implementation will have once built.
:param d_model: The model dimensionality.
"""
n_heads = self.n_heads
n_kv_heads = self.n_kv_heads or n_heads
head_dim = d_model // n_heads
@@ -104,7 +113,8 @@ def build(
"""
Build the corresponding attention module.
See :class:`Attention` for a description of the parameters.
:param d_model: The model dimensionality.
:param init_device: The device to initialize the parameters on, e.g. "cpu", "meta".
"""
kwargs = self.as_dict(exclude_none=True, recurse=False)
kwargs.pop("name")
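For context, a config like the one documented above might be used roughly as follows. This is a sketch under assumptions: the sizes are illustrative, and build() is assumed to need only the d_model and optional init_device named in its docstring.

```python
from olmo_core.nn.attention import AttentionConfig, AttentionType

# Illustrative values; the remaining fields fall back to their defaults.
config = AttentionConfig(name=AttentionType.default, n_heads=16)
print(config.num_params(d_model=1024))  # parameter count implied by the config
attention = config.build(d_model=1024, init_device="meta")  # assumes no other required arguments
```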
24 changes: 18 additions & 6 deletions src/olmo_core/nn/feed_forward.py
@@ -11,7 +11,7 @@
from ..exceptions import OLMoConfigurationError
from .functional import l2_normalize

__all__ = ["FeedForwardConfig", "FeedForwardType", "FeedForward", "NormalizedFeedForward"]
__all__ = ["FeedForwardType", "FeedForwardConfig", "FeedForward", "NormalizedFeedForward"]


class FeedForwardType(StrEnum):
@@ -21,29 +21,35 @@ class FeedForwardType(StrEnum):

default = "default"
"""
:class:`FeedForward`.
➡️ :class:`FeedForward`
"""

normalized = "normalized"
"""
:class:`NormalizedFeedForward`.
➡️ :class:`NormalizedFeedForward`
"""


@dataclass
class FeedForwardConfig(Config):
"""
A config for building :class:`FeedForward` modules.
See :class:`FeedForward` for parameter descriptions.
"""

hidden_size: int
name: FeedForwardType = FeedForwardType.default
"""
The name of the implementation.
"""
bias: Optional[bool] = None
dtype: DType = DType.float32

def num_params(self, d_model: int) -> int:
"""
The number of params that the module will have once built.
:param d_model: The model dimensionality.
"""
bias = self.bias if self.bias is not None else self.name != FeedForwardType.normalized

params = 0
@@ -58,7 +64,13 @@ def num_params(self, d_model: int) -> int:

return params

def build(self, d_model: int, init_device: str = "cpu") -> "FeedForward":
def build(self, d_model: int, *, init_device: str = "cpu") -> "FeedForward":
"""
Build the corresponding feed-forward module.
:param d_model: The model dimensionality.
:param init_device: The device to initialize the parameters on, e.g. "cpu", "meta".
"""
kwargs = self.as_dict(exclude_none=True)
kwargs.pop("name")
kwargs.update(d_model=d_model, init_device=init_device, dtype=kwargs.pop("dtype").as_pt())
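As a usage sketch for the config documented above (the hidden size and d_model are illustrative, not taken from this commit):

```python
from olmo_core.nn.feed_forward import FeedForwardConfig

config = FeedForwardConfig(hidden_size=4096)
print(config.num_params(d_model=1024))  # parameter count implied by the config
mlp = config.build(d_model=1024, init_device="cpu")  # init_device is keyword-only
```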
46 changes: 26 additions & 20 deletions src/olmo_core/nn/layer_norm.py
@@ -15,31 +15,37 @@
class LayerNormType(StrEnum):
"""
An enumeration of the different layer norm implementations.
- "default" ➡️ :class:`LayerNorm`
- "rms" ➡️ :class:`RMSNorm`
- "fused_rms" ➡️ :class:`FusedRMSNorm`
"""

default = "default"
"""
➡️ :class:`LayerNorm`
"""
rms = "rms"
"""
➡️ :class:`RMSNorm`
"""
fused_rms = "fused_rms"
"""
➡️ :class:`FusedRMSNorm`
"""
l2_norm = "l2_norm"
"""
➡️ :class:`L2Norm`
"""


@dataclass
class LayerNormConfig(Config):
"""
A config for conveniently building any one of the different layer norm classes.
See :class:`LayerNorm` for a description of the parameters.
See the :class:`LayerNorm` subclasses to learn which fields are valid for each implementation.
"""

name: LayerNormType = LayerNormType.default
"""
- "default" ➡️ :class:`LayerNorm`
- "rms" ➡️ :class:`RMSNorm`
- "fused_rms" ➡️ :class:`FusedRMSNorm`
The name of the implementation.
"""
eps: Optional[float] = None
elementwise_affine: Optional[bool] = None
@@ -48,6 +54,11 @@ class LayerNormConfig(Config):
dtype: Optional[DType] = None

def num_params(self, size: int) -> int:
"""
The number of parameters in the module once built.
:param size: The size of the input along the dimension to be normalized.
"""
elementwise_affine = (
self.elementwise_affine
if self.elementwise_affine is not None
@@ -65,7 +76,8 @@ def build(self, size: int, init_device: str = "cpu") -> "LayerNorm":
"""
Construct the corresponding LayerNorm class.
See :class:`LayerNorm` for a description of the parameters.
:param size: The size of the input along the dimension to be normalized.
:param init_device: The device to initialize the parameters on, e.g. "cpu", "meta".
"""
kwargs = self.as_dict(exclude_none=True)
kwargs.pop("name")
@@ -93,11 +105,7 @@ class LayerNorm(nn.Module):
"""
Layer normalization.
.. seealso::
- :class:`RMSNorm`
- :class:`FusedRMSNorm`
:param size: The hidden size / dimensionality of the input.
:param size: The size of the input along the dimension to be normalized.
:param eps: The epsilon used for numerical stability.
:param elementwise_affine: Whether to include an element-wise affine transform.
:param bias: Whether the element-wise affine should include an element-wise bias.
@@ -178,16 +186,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

class RMSNorm(LayerNorm):
"""
RMS norm, a simplified layer norm implementation.
.. seealso::
- :class:`LayerNorm`
- :class:`FusedRMSNorm`
RMSNorm, a simplified layer norm implementation.
"""

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Apply RMS norm.
Apply RMSNorm.
:param x: The input.
"""
@@ -270,6 +274,8 @@ class L2Norm(LayerNorm):
"""
A variant of layer norm that just normalizes the last dimension of the input by its L2 norm,
as done in nGPT.
:param size: The size of the input along the dimension to be normalized.
"""

def __init__(
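A usage sketch for the config documented above; the chosen implementation name and size are illustrative assumptions:

```python
from olmo_core.nn.layer_norm import LayerNormConfig, LayerNormType

config = LayerNormConfig(name=LayerNormType.rms)
# The exact count depends on the elementwise_affine/bias defaults for this implementation.
print(config.num_params(size=1024))
norm = config.build(size=1024, init_device="cpu")
```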
32 changes: 26 additions & 6 deletions src/olmo_core/nn/lm_head.py
@@ -11,37 +11,45 @@
from .functional import l2_normalize
from .layer_norm import LayerNormConfig

__all__ = ["LMHeadConfig", "LMHeadType", "LMHead", "NormalizedLMHead"]
__all__ = ["LMHeadType", "LMHeadConfig", "LMHead", "NormalizedLMHead"]


class LMHeadType(StrEnum):
"""
An enumeration of LM head types.
An enumeration of the different LM head types.
"""

default = "default"
"""
:class:`LMHead`
➡️ :class:`LMHead`
"""

normalized = "normalized"
"""
:class:`NormalizedLMHead`
➡️ :class:`NormalizedLMHead`
"""


@dataclass
class LMHeadConfig(Config):
"""
A configuration class for building an :class:`LMHead`.
A configuration class for building any of the :class:`LMHead` implementations.
See the :class:`LMHead` subclasses to learn which fields are valid for each implementation.
"""

name: LMHeadType = LMHeadType.default
"""
The name of the implementation.
"""
layer_norm: Optional[LayerNormConfig] = None
bias: Optional[bool] = None
dtype: DType = DType.float32

def num_params(self, d_model: int, vocab_size: int) -> int:
"""
The number of parameters in the module once built.
"""
bias = self.bias if self.bias is not None else self.name != LMHeadType.normalized

params = 0
@@ -59,6 +67,12 @@ def num_params(self, d_model: int, vocab_size: int) -> int:
return params

def build(self, *, d_model: int, vocab_size: int, init_device: str = "cpu") -> "LMHead":
"""
Construct the corresponding LM head implementation.
:param d_model: The model dimensionality.
:param init_device: The device to initialize the parameters on, e.g. "cpu", "meta".
"""
kwargs = self.as_dict(exclude_none=True, recurse=False)
kwargs.pop("name")
kwargs.update(
@@ -83,7 +97,7 @@ def build(self, *, d_model: int, vocab_size: int, init_device: str = "cpu") -> "LMHead":

class LMHead(nn.Module):
"""
The default LM head implementation.
The default language modeling head implementation.
"""

def __init__(
Expand All @@ -103,6 +117,9 @@ def __init__(
self.w_out = nn.Linear(d_model, vocab_size, bias=bias, dtype=dtype, device=init_device)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Apply the LM head to the hidden state ``x``, returning the logits.
"""
h = self.norm(x) if self.norm is not None else x
return self.w_out(h)

Expand Down Expand Up @@ -136,6 +153,9 @@ def __init__(
)

def reset_parameters(self):
"""
Reset the scaling parameter.
"""
nn.init.ones_(self.sz)
self.sz.mul_(self.sz_init_scaling)

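A usage sketch for the config documented above; d_model and vocab_size are illustrative assumptions:

```python
from olmo_core.nn.lm_head import LMHeadConfig, LMHeadType

config = LMHeadConfig(name=LMHeadType.default)
print(config.num_params(d_model=1024, vocab_size=50304))  # parameter count implied by the config
lm_head = config.build(d_model=1024, vocab_size=50304, init_device="cpu")
```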