Add init_method option to Transformer
epwalsh committed Aug 27, 2024
1 parent 65e21ac commit 3d53c65
Showing 7 changed files with 154 additions and 47 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -13,7 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `data.TokenizerConfig` config class and `data.TokenizerName` enumeration.
- Added data mixes with `data.DataMix` API.
- Added `block_idx` attribute to the `TransformerBlock` class.
- Added `init_func` parameter to `Transformer.init_weights()` and `TransformerConfig.build()`.
- Added `init_method` option to `Transformer` for controlling how the weights are initialized.

## [v1.0.1](https://github.com/allenai/OLMo-core/releases/tag/v1.0.1) - 2024-08-26

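To make the CHANGELOG entry above concrete, here is a hypothetical usage sketch (not part of the commit). It assumes `TransformerConfig` and `InitMethod` are importable as exported in `src/olmo_core/nn/transformer/__init__.py` below, that `config` is an already-constructed `TransformerConfig`, and that the keyword arguments to `build()` match its docstring in `model.py`:

from olmo_core.nn.transformer import InitMethod, TransformerConfig

def build_with_llama_init(config: TransformerConfig):
    # The new field defaults to InitMethod.normal; switch to the LLaMA-style
    # scheme, which scales the "output" projection std by model depth.
    config.init_method = InitMethod.llama
    # build() now forwards init_method to the Transformer and calls
    # init_weights() itself; the old init_func callback is gone.
    return config.build(init_device="meta", max_seq_len=4096)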
12 changes: 0 additions & 12 deletions src/olmo_core/nn/attention.py
@@ -172,12 +172,6 @@ def __init__(
self._flash_attn_func = flash_attn_func
self._flash_attn_varlen_func = flash_attn_varlen_func

def reset_parameters(self):
for w in (self.w_q, self.w_k, self.w_v, self.w_out):
nn.init.trunc_normal_(w.weight, mean=0.0, std=0.02)
if w.bias is not None:
nn.init.zeros_(w.bias)

def forward(
self,
x: torch.Tensor,
@@ -334,12 +328,6 @@ def __init__(
self._flash_attn_qkvpacked_func = flash_attn_qkvpacked_func
self._flash_attn_varlen_qkvpacked_func = flash_attn_varlen_qkvpacked_func

def reset_parameters(self):
for w in (self.w_qkv, self.w_out):
nn.init.trunc_normal_(w.weight, mean=0.0, std=0.02)
if w.bias is not None:
nn.init.zeros_(w.bias)

def forward(
self,
x: torch.Tensor,
6 changes: 0 additions & 6 deletions src/olmo_core/nn/feed_forward.py
@@ -50,12 +50,6 @@ def __init__(
self.w2 = nn.Linear(hidden_size, d_model, bias=bias, dtype=dtype, device=init_device)
self.w3 = nn.Linear(d_model, hidden_size, bias=bias, dtype=dtype, device=init_device)

def reset_parameters(self):
for w in (self.w1, self.w2, self.w3):
nn.init.trunc_normal_(w.weight, mean=0.0, std=0.02)
if w.bias is not None:
nn.init.zeros_(w.bias)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Run the feed-forward on the input ``x``.
6 changes: 4 additions & 2 deletions src/olmo_core/nn/transformer/__init__.py
@@ -3,17 +3,19 @@
"""

from .block import TransformerBlock, TransformerBlockConfig, TransformerBlockType
from .init import InitMethod
from .model import (
Transformer,
TransformerActivationCheckpointingConfig,
TransformerConfig,
)

__all__ = [
"TransformerConfig",
"Transformer",
"TransformerBlockType",
"TransformerBlockConfig",
"TransformerBlock",
"TransformerConfig",
"TransformerActivationCheckpointingConfig",
"Transformer",
"InitMethod",
]
75 changes: 75 additions & 0 deletions src/olmo_core/nn/transformer/init.py
@@ -0,0 +1,75 @@
from typing import Union

import torch.nn as nn

from olmo_core.config import StrEnum

from ..attention import Attention, FusedAttention
from ..feed_forward import FeedForward


class InitMethod(StrEnum):
normal = "normal"
"""
Every linear and embedding layer is initialized from a truncated normal distribution
with standard deviation 0.02.
"""

llama = "llama"
"""
Like :data:`normal`, but "output" layers are initialized with a standard deviation that's
dependent on either ``d_model`` or the number of layers.
"""

llama_depth = "llama_depth"
"""
Like :data:`normal`, but "output" layers are initialized with a standard deviation that's
dependent on either ``d_model`` or the layer index.
"""

def _init_linear(self, m: nn.Linear, *, std: float = 0.02):
nn.init.trunc_normal_(m.weight, mean=0.0, std=std, a=-3 * std, b=3 * std)
if m.bias is not None:
nn.init.zeros_(m.bias)

def init_embeddings(self, m: nn.Embedding):
if self in (InitMethod.llama, InitMethod.llama_depth):
nn.init.normal_(m.weight)
else:
nn.init.trunc_normal_(m.weight, mean=0.0, std=0.02, a=-3 * 0.02, b=3 * 0.02)

def init_final_w_out(self, m: nn.Linear, *, d_model: int):
std = 0.02
if self in (InitMethod.llama, InitMethod.llama_depth):
std = d_model**-0.5
self._init_linear(m, std=std)

def init_attention(
self, m: Union[Attention, FusedAttention], *, block_idx: int, num_blocks: int
):
std = 0.02
if self == InitMethod.llama:
std = 0.02 / (2 * num_blocks) ** 0.5
elif self == InitMethod.llama_depth:
std = 0.02 / (2 * (block_idx + 1)) ** 0.5

if isinstance(m, Attention):
for w in (m.w_q, m.w_k, m.w_v):
self._init_linear(w, std=0.02)
elif isinstance(m, FusedAttention):
self._init_linear(m.w_qkv, std=0.02)
else:
raise NotImplementedError(m)

self._init_linear(m.w_out, std=std)

def init_feed_forward(self, m: FeedForward, *, block_idx: int, num_blocks: int):
std = 0.02
if self == InitMethod.llama:
std = 0.02 / (2 * num_blocks) ** 0.5
elif self == InitMethod.llama_depth:
std = 0.02 / (2 * (block_idx + 1)) ** 0.5

self._init_linear(m.w1, std=0.02)
self._init_linear(m.w2, std=std)
self._init_linear(m.w3, std=std)
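As an editor's aside (not part of the commit), a tiny standalone sketch of the standard deviations the three methods assign to the per-block "output" projections (``w_out``, ``w2``), mirroring the scaling in ``init_attention()`` and ``init_feed_forward()`` above and assuming a 32-layer model:

num_blocks = 32

def output_std(method: str, block_idx: int) -> float:
    if method == "llama":        # same std for every block, scaled by total depth
        return 0.02 / (2 * num_blocks) ** 0.5
    if method == "llama_depth":  # std shrinks with the block's index
        return 0.02 / (2 * (block_idx + 1)) ** 0.5
    return 0.02                  # "normal": flat 0.02 everywhere

print(round(output_std("normal", 0), 4))        # 0.02
print(round(output_std("llama", 0), 4))         # 0.0025
print(round(output_std("llama_depth", 0), 4))   # 0.0141 (first block)
print(round(output_std("llama_depth", 31), 4))  # 0.0025 (last block)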
95 changes: 70 additions & 25 deletions src/olmo_core/nn/transformer/model.py
@@ -1,6 +1,6 @@
import logging
from dataclasses import dataclass
from typing import Callable, Literal, Optional, Sequence, Union
from typing import Literal, Optional, Sequence, Union

import torch
import torch.nn as nn
@@ -14,12 +14,13 @@
has_flash_attn,
)

from ..attention import AttentionConfig, AttentionType
from ..attention import Attention, AttentionConfig, AttentionType
from ..buffer_cache import BufferCache
from ..feed_forward import FeedForwardConfig
from ..layer_norm import LayerNormConfig, LayerNormType
from ..rope import RoPEConfig, RoPEType, RotaryEmbeddingBase
from .block import TransformerBlockConfig, TransformerBlockType
from ..rope import RoPEConfig, RoPEType
from .block import TransformerBlock, TransformerBlockConfig, TransformerBlockType
from .init import InitMethod
from .utils import apply_activation_checkpointing_to_transformer_block

__all__ = ["TransformerConfig", "Transformer"]
@@ -60,6 +61,7 @@ class TransformerConfig(Config):
layer_norm: LayerNormConfig
bias: bool = True
dtype: DType = DType.float32
init_method: InitMethod = InitMethod.normal
compile: bool = False
dp_config: Optional[DataParallelConfig] = None
ac_config: Optional[TransformerActivationCheckpointingConfig] = None
@@ -71,10 +73,18 @@ def build(
device: Optional[torch.device] = None,
dp_mesh: Optional[DeviceMesh] = None,
max_seq_len: Optional[int] = None,
init_func: Optional[Callable[[nn.Module], None]] = None,
) -> "Transformer":
"""
Build the model corresponding to this config.
Build the model corresponding to this config, potentially applying activation checkpointing,
compilation, FSDP or DDP, etc., and eventually calling :meth:`Transformer.init_weights()`.
:param init_device: The device to put the parameters on during initialization. In a
distributed setting it usually makes sense to set this to "meta".
:param device: The device to put the model on after initialization.
:param dp_mesh: Data parallel device mesh. This can be used to configure hybrid sharding
with FSDP. See :func:`~olmo_core.distributed.utils.init_hybrid_shard_mesh()` for
easily creating such a mesh.
:param max_seq_len: The maximum sequence length expected.
"""
device = device or get_default_device()

@@ -90,6 +100,7 @@ def build(
layer_norm=self.layer_norm,
bias=self.bias,
dtype=self.dtype.as_pt(),
init_method=self.init_method,
init_device=init_device,
)
log.info("%s", model)
@@ -122,7 +133,7 @@ def build(
# Materialize and init parameters.
if device != torch.device(init_device):
model.to_empty(device=device)
model.init_weights(max_seq_len=max_seq_len, device=device, init_func=init_func)
model.init_weights(max_seq_len=max_seq_len, device=device)

return model

@@ -470,6 +481,7 @@ def __init__(
layer_norm: LayerNormConfig,
bias: bool = True,
dtype: torch.dtype = torch.float32,
init_method: InitMethod = InitMethod.normal,
init_device: str = "cpu",
):
super().__init__()
@@ -490,6 +502,7 @@ def __init__(
)
self.norm = layer_norm.build(d_model, init_device=init_device)
self.w_out = nn.Linear(d_model, vocab_size, bias=bias, dtype=dtype, device=init_device)
self.init_method = InitMethod(init_method)
self._cache = cache

@property
@@ -499,40 +512,57 @@ def device(self) -> torch.device:
return p.device
return get_default_device()

@torch.no_grad()
def init_weights(
self,
*,
max_seq_len: Optional[int] = None,
device: Optional[torch.device] = None,
init_func: Optional[Callable[[nn.Module], None]] = None,
):
"""
Initialize the model weights.
:param max_seq_len: The maximum sequence length expected during training. This is used
to warm up the RoPE cache.
:param device: The device the local copy of the model will be trained on.
:param init_func: The function used to initialize the weights of each module.
By default this just calls ``m.reset_parameters()`` if defined.
"""
device = device or self.device

def reset_params(m: nn.Module):
if init_func is not None:
init_func(m)
elif hasattr(m, "reset_parameters"):
m.reset_parameters()
if self.embeddings is not None:
self.init_method.init_embeddings(self.embeddings)

if max_seq_len is not None and isinstance(m, RotaryEmbeddingBase):
m.warmup_cache(max_seq_len, device)
for block in self.blocks:
assert isinstance(block, TransformerBlock)

# Norms.
block_norms = [block.attention_norm, block.feed_forward_norm]
if isinstance(block.attention, Attention):
if block.attention.q_norm is not None:
block_norms.append(block.attention.q_norm)
if block.attention.k_norm is not None:
block_norms.append(block.attention.k_norm)
for norm in block_norms:
norm.reset_parameters()

# Attention weights.
self.init_method.init_attention(
block.attention, block_idx=block.block_idx, num_blocks=len(self.blocks)
)

self.apply(reset_params)
# Feed-forward weights.
self.init_method.init_feed_forward(
block.feed_forward, block_idx=block.block_idx, num_blocks=len(self.blocks)
)

def reset_parameters(self):
nn.init.trunc_normal_(self.embeddings.weight, mean=0.0, std=0.02)
nn.init.trunc_normal_(self.w_out.weight, mean=0.0, std=0.02)
if self.w_out.bias is not None:
nn.init.zeros_(self.bias)
# Warm up RoPE cache.
if max_seq_len is not None and block.attention.rope is not None:
block.attention.rope.warmup_cache(max_seq_len, device)

if self.norm is not None:
self.norm.reset_parameters()

if self.w_out is not None:
self.init_method.init_final_w_out(self.w_out, d_model=self.d_model)

def forward(
self,
@@ -574,6 +604,10 @@ def apply_activation_checkpointing(
"""
Apply activation checkpointing to the model.
.. warning::
Usually this does not need to be called directly, as :meth:`TransformerConfig.build()`
will call it for you.
:param mode: Either "full" to apply AC to each block, or "selective" which depends on
the value of ``selective_option``.
:param selective_option: If "op", AC is applied to individual operations. If an int, it's
@@ -593,8 +627,11 @@ def apply_compile(self):
due to repeated structure.
.. warning::
This should be called after :meth:`apply_activation_checkpointing()` but before
:meth:`apply_fsdp2()` or :meth:`apply_ddp2()`.
Usually this does not need to be called directly, as :meth:`TransformerConfig.build()`
will call it for you.
If you do use this directly, note that it must be called after
:meth:`apply_activation_checkpointing()` but before :meth:`apply_fsdp()` or :meth:`apply_ddp()`.
"""
for block_id, block in self.blocks.named_children():
block = torch.compile(block, fullgraph=False)
@@ -612,6 +649,10 @@ def apply_fsdp(
"""
Apply FSDP(2) to the model.
.. warning::
Usually this does not need to be called directly, as :meth:`TransformerConfig.build()`
will call it for you.
:param dp_mesh: The data parallel device mesh.
:param param_dtype: The data type to materialize params in. Defaults to the current param dtype.
:param reduce_dtype: The data type for gradient reduction.
@@ -653,6 +694,10 @@ def apply_ddp(
):
"""
Apply DDP to the model.
.. warning::
Usually this does not need to be called directly, as :meth:`TransformerConfig.build()`
will call it for you.
"""
from torch.distributed._composable.replicate import replicate

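For reference, a hypothetical sketch (not from the commit) of wiring these steps up manually in the order the warnings above prescribe. The exact parameter lists are truncated in this diff, so the keyword arguments and the `model`/`dp_mesh` objects are assumptions; `TransformerConfig.build()` normally does all of this for you:

import torch

def prepare_model(model, dp_mesh, max_seq_len: int = 4096):
    device = torch.device("cuda")
    # Per the apply_compile() warning: AC first, then compile, then FSDP/DDP.
    model.apply_activation_checkpointing(mode="full")
    model.apply_compile()
    model.apply_fsdp(dp_mesh=dp_mesh)
    # Materialize the meta parameters, then run the new init path, which uses
    # the InitMethod set at construction time instead of an init_func callback.
    model.to_empty(device=device)
    model.init_weights(max_seq_len=max_seq_len, device=device)
    return model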
5 changes: 4 additions & 1 deletion src/test/nn/transformer/model_test.py
@@ -15,10 +15,13 @@ def test_small_llama2_config_builder():
num_actual_params += p.numel()
assert config.num_params == num_actual_params

# Make sure there are no biases anywhere.
for module in model.modules():
# Make sure there are no biases anywhere and layer norm weights are all 1.
if isinstance(module, (nn.Linear, LayerNorm)):
assert module.bias is None
if isinstance(module, LayerNorm):
assert module.weight is not None
assert (module.weight == 1).all()

# Make sure block_idx is set correctly.
assert model.blocks[0].block_idx == 0
