Allow customizing init func of transformer
epwalsh committed Aug 27, 2024
1 parent 6a5c5cf commit 65e21ac
Showing 4 changed files with 44 additions and 19 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -12,6 +12,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added support for unsharding model state into `safetensors` format with `olmo_core.distributed.checkpoint.unshard_checkpoint(..., use_safetensors=True)`.
- Added `data.TokenizerConfig` config class and `data.TokenizerName` enumeration.
- Added data mixes with `data.DataMix` API.
- Added `block_idx` attribute to the `TransformerBlock` class.
- Added `init_func` parameter to `Transformer.init_weights()` and `TransformerConfig.build()`.

## [v1.0.1](https://github.com/allenai/OLMo-core/releases/tag/v1.0.1) - 2024-08-26

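The two changelog entries above cover the new surface area: every `TransformerBlock` now records its own `block_idx`, and both `Transformer.init_weights()` and `TransformerConfig.build()` accept an optional `init_func` callback that replaces the default per-module initialization. Below is a minimal sketch of passing a custom `init_func` through the config; the import path, the `llama2_271M` builder, and its arguments are assumptions for illustration, not part of this commit.

```python
import torch
import torch.nn as nn

from olmo_core.nn.transformer import TransformerConfig  # assumed re-export path


def custom_init(m: nn.Module):
    # When init_func is given, init_weights() calls it *instead of*
    # reset_parameters(), so the callback owns the full per-module init.
    if hasattr(m, "reset_parameters"):
        m.reset_parameters()
    bias = getattr(m, "bias", None)
    if isinstance(bias, torch.Tensor):
        nn.init.zeros_(bias)  # illustrative policy, not prescribed by this commit


config = TransformerConfig.llama2_271M(vocab_size=50257)  # assumed builder/args
model = config.build(
    device=torch.device("cuda"),
    max_seq_len=4096,
    init_func=custom_init,  # new in this commit
)
```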
7 changes: 6 additions & 1 deletion src/olmo_core/nn/transformer/block.py
@@ -34,8 +34,9 @@ class TransformerBlockConfig(Config):

def build(
self,
d_model: int,
*,
d_model: int,
block_idx: int,
init_device: str = "cpu",
cache: Optional[BufferCache] = None,
) -> "TransformerBlock":
@@ -44,6 +45,7 @@ def build(
kwargs.update(
dict(
d_model=d_model,
block_idx=block_idx,
init_device=init_device,
cache=cache,
)
@@ -62,6 +64,7 @@ class TransformerBlock(nn.Module):
A typical "Llama-style" transformer block implementation.
:param d_model: The model dimensionality.
:param block_idx: The index/position of the block within the model. Ranges from 0 to ``n_layers - 1``.
:param attention: The attention module config.
:param feed_forward: The feed forward module config.
:param layer_norm: The layer norm config for both the attention LN and the feed forward LN.
@@ -73,6 +76,7 @@ def __init__(
self,
*,
d_model: int,
block_idx: int,
attention: AttentionConfig,
feed_forward: FeedForwardConfig,
layer_norm: LayerNormConfig,
@@ -82,6 +86,7 @@ ):
):
super().__init__()
self.d_model = d_model
self.block_idx = block_idx
self.attention = attention.build(d_model, init_device=init_device, cache=cache)
self.attention_norm = layer_norm.build(d_model, init_device=init_device)
self.feed_forward = feed_forward.build(d_model=d_model, init_device=init_device)
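The point of exposing `block_idx` is that an `init_func` can now make depth-dependent decisions. A sketch of such a callback, assuming only what the diff shows (each `TransformerBlock` carries a `block_idx` attribute); the `1/sqrt(2 * (idx + 1))` scaling rule itself is illustrative, not something this commit prescribes:

```python
import math

import torch
import torch.nn as nn


def depth_scaled_init(m: nn.Module):
    # Default behaviour first: modules that define reset_parameters() use it.
    if hasattr(m, "reset_parameters"):
        m.reset_parameters()
    # TransformerBlock now carries block_idx, so the callback can rescale each
    # block's weight matrices by a depth-dependent factor. nn.Module.apply()
    # visits children before the block itself, so their reset has already run.
    if hasattr(m, "block_idx"):
        scale = 1.0 / math.sqrt(2 * (m.block_idx + 1))
        with torch.no_grad():
            for p in m.parameters():
                if p.dim() >= 2:  # leave norm gains and biases alone
                    p.mul_(scale)
```

Passing this as `config.build(init_func=depth_scaled_init)` or `model.init_weights(init_func=depth_scaled_init)` applies it to every module via `nn.Module.apply`.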
50 changes: 32 additions & 18 deletions src/olmo_core/nn/transformer/model.py
@@ -1,6 +1,6 @@
import logging
from dataclasses import dataclass
from typing import Literal, Optional, Sequence, Union
from typing import Callable, Literal, Optional, Sequence, Union

import torch
import torch.nn as nn
@@ -71,6 +71,7 @@ def build(
device: Optional[torch.device] = None,
dp_mesh: Optional[DeviceMesh] = None,
max_seq_len: Optional[int] = None,
init_func: Optional[Callable[[nn.Module], None]] = None,
) -> "Transformer":
"""
Build the model corresponding to this config.
@@ -121,7 +122,7 @@ def build(
# Materialize and init parameters.
if device != torch.device(init_device):
model.to_empty(device=device)
model.init_weights(max_seq_len=max_seq_len)
model.init_weights(max_seq_len=max_seq_len, device=device, init_func=init_func)

return model

@@ -479,40 +480,53 @@ def __init__(
self.blocks = nn.ModuleList(
[
block.build(
d_model,
d_model=d_model,
block_idx=block_idx,
init_device=init_device,
cache=cache,
)
for _ in range(n_layers)
for block_idx in range(n_layers)
]
)
self.norm = layer_norm.build(d_model, init_device=init_device)
self.w_out = nn.Linear(d_model, vocab_size, bias=bias, dtype=dtype, device=init_device)
self._cache = cache

def init_weights(self, max_seq_len: Optional[int] = None):
@property
def device(self) -> torch.device:
for p in self.parameters():
if p.numel() > 0:
return p.device
return get_default_device()

def init_weights(
self,
*,
max_seq_len: Optional[int] = None,
device: Optional[torch.device] = None,
init_func: Optional[Callable[[nn.Module], None]] = None,
):
"""
Initialize the model weights.
:param max_seq_len: The maximum sequence length expected during training. This is used
to warm up the RoPE cache.
:param device: The device the local copy of the model will be trained on.
:param init_func: The function used to initialize the weights of each module.
By default this just calls ``m.reset_parameters()`` if defined.
"""
device = device or self.device

def reset_params(m: nn.Module):
if hasattr(m, "reset_parameters"):
if init_func is not None:
init_func(m)
elif hasattr(m, "reset_parameters"):
m.reset_parameters()

self.apply(reset_params)

if max_seq_len is None:
return

# Warmup RoPE embedding caches.
device = self.w_out.weight.device

def warmup_cache(m: nn.Module):
if isinstance(m, RotaryEmbeddingBase):
assert max_seq_len is not None
if max_seq_len is not None and isinstance(m, RotaryEmbeddingBase):
m.warmup_cache(max_seq_len, device)

self.apply(warmup_cache)
self.apply(reset_params)

def reset_parameters(self):
nn.init.trunc_normal_(self.embeddings.weight, mean=0.0, std=0.02)
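On the calling side, `TransformerConfig.build()` now materializes the model on the target device and forwards `init_func` straight through to `init_weights()`. Doing the same thing by hand looks roughly like the sketch below; only the `to_empty()` + `init_weights()` sequence is taken from the diff, while the `Transformer` import path and the meta-device construction of the model are assumptions.

```python
import torch

from olmo_core.nn.transformer import Transformer  # assumed re-export path


def materialize_and_init(model: Transformer, device: torch.device, max_seq_len: int = 4096):
    """Mirror the to_empty() + init_weights() sequence in TransformerConfig.build()."""
    # Assumes the model was constructed with init_device="meta".
    model.to_empty(device=device)
    model.init_weights(
        max_seq_len=max_seq_len,  # warms up the RoPE caches for this length
        device=device,            # optional; otherwise inferred via the new `device` property
        init_func=None,           # None -> default path: reset_parameters() where defined
    )
```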
4 changes: 4 additions & 0 deletions src/test/nn/transformer/model_test.py
@@ -19,3 +19,7 @@ def test_small_llama2_config_builder():
for module in model.modules():
if isinstance(module, (nn.Linear, LayerNorm)):
assert module.bias is None

# Make sure block_idx is set correctly.
assert model.blocks[0].block_idx == 0
assert model.blocks[-1].block_idx == len(model.blocks) - 1
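A natural follow-up test would check that a user-supplied `init_func` is actually invoked for every module during `build()`. A sketch in the style of the existing test; the config builder, its arguments, and the import are assumptions rather than content of this commit.

```python
import torch

from olmo_core.nn.transformer import TransformerConfig  # assumed import, as in the existing test


def test_custom_init_func_is_called():
    seen = []

    def tracking_init(m):
        seen.append(type(m).__name__)
        if hasattr(m, "reset_parameters"):
            m.reset_parameters()

    config = TransformerConfig.llama2_271M(vocab_size=50257)  # assumed builder/args
    model = config.build(device=torch.device("cpu"), init_func=tracking_init)

    # apply() visits each module exactly once, so the callback should too.
    assert len(seen) == len(list(model.modules()))
```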
