Skip to content

Commit

Permalink
Add OLMo November 2024 (#34551)
Browse files Browse the repository at this point in the history
* Add model skeletion with transformers-cli add-new-model-like

* Convert config to modular, add rms_norm_eps, delete clip_qkv

* Convert model to modular, add RMSNorm

* Add flash attention with qk norm and no qkv clipping

* Add decoder layer with RMSNorm after attention/feedforward layers

* Add base and causal model

* Add converter improvements from OLMo repo

* Update weight loading in OLMo to HF converter

* Set correct default for rms_norm_eps

* Set correct pipeline_model_mapping in test

* Run make fixup

* Fix model type

* Re-run modular conversion

* Manually set config docs to fix build errors

* Convert olmo-1124 to olmo_1124 to fix flash attention docs errors

* Start updating tests

* Update tests

* Copy upstream test_eager_matches_sdpa_inference_1_bfloat16 changes to olmo_1124

* Rename input_layernorm and post_attention_layernorm to reflect their ops better

* Use correct tokenizer

* Remove test unsupported by GPT2 tokenizer

* Create GenerationConfig outside of from_pretrained call

* Use simpler init file structure

* Add explicit __all__ to support simplified init

* Make safetensor serialization the default

* Update OLMo November 2024 docs
  • Loading branch information
2015aroras authored Nov 18, 2024
1 parent 1349321 commit 3ee24e2
Show file tree
Hide file tree
Showing 17 changed files with 2,641 additions and 0 deletions.
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,8 @@
title: Nyströmformer
- local: model_doc/olmo
title: OLMo
- local: model_doc/olmo_1124
title: OLMo November 2024
- local: model_doc/olmoe
title: OLMoE
- local: model_doc/open-llama
Expand Down
1 change: 1 addition & 0 deletions docs/source/en/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ Flax), PyTorch, and/or TensorFlow.
| [Nougat](model_doc/nougat) ||||
| [Nyströmformer](model_doc/nystromformer) ||||
| [OLMo](model_doc/olmo) ||||
| [OLMo November 2024](model_doc/olmo_1124) ||||
| [OLMoE](model_doc/olmoe) ||||
| [OmDet-Turbo](model_doc/omdet-turbo) ||||
| [OneFormer](model_doc/oneformer) ||||
Expand Down
46 changes: 46 additions & 0 deletions docs/source/en/model_doc/olmo_1124.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

# OLMo November 2024

## Overview

The OLMo November 2024 model is a successor of the OLMo model, which was proposed in
[OLMo: Accelerating the Science of Language Models](https://arxiv.org/abs/2402.00838).

The architectural changes from the original OLMo model to this model are:

- RMSNorm is used instead of standard layer norm.
- Norm is applied to attention queries and keys.
- Norm is applied after attention/feedforward layers rather than before.

This model was contributed by [shanearora](https://huggingface.co/shanearora).
The original code can be found [here](https://github.com/allenai/OLMo/tree/main/olmo).


## Olmo1124Config

[[autodoc]] Olmo1124Config

## Olmo1124Model

[[autodoc]] Olmo1124Model
- forward

## Olmo1124ForCausalLM

[[autodoc]] Olmo1124ForCausalLM
- forward
2 changes: 2 additions & 0 deletions docs/source/en/perf_infer_gpu_one.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ FlashAttention-2 is currently supported for the following architectures:
* [Nemotron](https://huggingface.co/docs/transformers/model_doc/nemotron)
* [NLLB](https://huggingface.co/docs/transformers/model_doc/nllb)
* [OLMo](https://huggingface.co/docs/transformers/model_doc/olmo#transformers.OlmoModel)
* [OLMo November 2024](https://huggingface.co/docs/transformers/model_doc/olmo_1124#transformers.Olmo1124Model)
* [OLMoE](https://huggingface.co/docs/transformers/model_doc/olmoe#transformers.OlmoeModel)
* [OPT](https://huggingface.co/docs/transformers/model_doc/opt#transformers.OPTModel)
* [PaliGemma](https://huggingface.co/docs/transformers/model_doc/paligemma#transformers.PaliGemmaForConditionalGeneration)
Expand Down Expand Up @@ -260,6 +261,7 @@ For now, Transformers supports SDPA inference and training for the following arc
* [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel)
* [NLLB](https://huggingface.co/docs/transformers/model_doc/nllb)
* [OLMo](https://huggingface.co/docs/transformers/model_doc/olmo#transformers.OlmoModel)
* [OLMo November 2024](https://huggingface.co/docs/transformers/model_doc/olmo_1124#transformers.Olmo1124Model)
* [OLMoE](https://huggingface.co/docs/transformers/model_doc/olmoe#transformers.OlmoeModel)
* [OPT](https://huggingface.co/docs/transformers/en/model_doc/opt)
* [PaliGemma](https://huggingface.co/docs/transformers/model_doc/paligemma#transformers.PaliGemmaForConditionalGeneration)
Expand Down
14 changes: 14 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,6 +620,7 @@
"models.nougat": ["NougatProcessor"],
"models.nystromformer": ["NystromformerConfig"],
"models.olmo": ["OlmoConfig"],
"models.olmo_1124": ["Olmo1124Config"],
"models.olmoe": ["OlmoeConfig"],
"models.omdet_turbo": [
"OmDetTurboConfig",
Expand Down Expand Up @@ -2919,6 +2920,13 @@
"OlmoPreTrainedModel",
]
)
_import_structure["models.olmo_1124"].extend(
[
"Olmo1124ForCausalLM",
"Olmo1124Model",
"Olmo1124PreTrainedModel",
]
)
_import_structure["models.olmoe"].extend(
[
"OlmoeForCausalLM",
Expand Down Expand Up @@ -5506,6 +5514,7 @@
NystromformerConfig,
)
from .models.olmo import OlmoConfig
from .models.olmo_1124 import Olmo1124Config
from .models.olmoe import OlmoeConfig
from .models.omdet_turbo import (
OmDetTurboConfig,
Expand Down Expand Up @@ -7523,6 +7532,11 @@
OlmoModel,
OlmoPreTrainedModel,
)
from .models.olmo_1124 import (
Olmo1124ForCausalLM,
Olmo1124Model,
Olmo1124PreTrainedModel,
)
from .models.olmoe import (
OlmoeForCausalLM,
OlmoeModel,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@
nougat,
nystromformer,
olmo,
olmo_1124,
olmoe,
omdet_turbo,
oneformer,
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@
("nougat", "VisionEncoderDecoderConfig"),
("nystromformer", "NystromformerConfig"),
("olmo", "OlmoConfig"),
("olmo_1124", "Olmo1124Config"),
("olmoe", "OlmoeConfig"),
("omdet-turbo", "OmDetTurboConfig"),
("oneformer", "OneFormerConfig"),
Expand Down Expand Up @@ -510,6 +511,7 @@
("nougat", "Nougat"),
("nystromformer", "Nyströmformer"),
("olmo", "OLMo"),
("olmo_1124", "OLMo November 2024"),
("olmoe", "OLMoE"),
("omdet-turbo", "OmDet-Turbo"),
("oneformer", "OneFormer"),
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@
("nllb-moe", "NllbMoeModel"),
("nystromformer", "NystromformerModel"),
("olmo", "OlmoModel"),
("olmo_1124", "Olmo1124Model"),
("olmoe", "OlmoeModel"),
("omdet-turbo", "OmDetTurboForObjectDetection"),
("oneformer", "OneFormerModel"),
Expand Down Expand Up @@ -516,6 +517,7 @@
("mvp", "MvpForCausalLM"),
("nemotron", "NemotronForCausalLM"),
("olmo", "OlmoForCausalLM"),
("olmo_1124", "Olmo1124ForCausalLM"),
("olmoe", "OlmoeForCausalLM"),
("open-llama", "OpenLlamaForCausalLM"),
("openai-gpt", "OpenAIGPTLMHeadModel"),
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/auto/tokenization_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,7 @@
),
),
("olmo", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
("olmo_1124", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
("olmoe", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
(
"omdet-turbo",
Expand Down
27 changes: 27 additions & 0 deletions src/transformers/models/olmo_1124/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright 2024 EleutherAI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
from .configuration_olmo_1124 import *
from .modeling_olmo_1124 import *
else:
import sys

_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
166 changes: 166 additions & 0 deletions src/transformers/models/olmo_1124/configuration_olmo_1124.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/olmo_1124/modular_olmo_1124.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_olmo_1124.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨

from ...configuration_utils import PretrainedConfig


class Olmo1124Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Olmo1124Model`]. It is used to instantiate an OLMo November 2024
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the [allenai/Olmo1124-7B-hf](https://huggingface.co/allenai/Olmo1124-7B-hf).
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 50304):
Vocabulary size of the Olmo1124 model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`Olmo1124Model`]
hidden_size (`int`, *optional*, defaults to 4096):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 11008):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (`int`, *optional*):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
`num_attention_heads`.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 2048):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
pad_token_id (`int`, *optional*, defaults to 1):
Padding token id.
bos_token_id (`int`, *optional*):
Beginning of stream token id.
eos_token_id (`int`, *optional*, defaults to 50279):
End of stream token id.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether to tie weight embeddings
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
these scaling strategies behave:
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
experimental feature, subject to breaking API changes in future versions.
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon used by the rms normalization layers.
```python
>>> from transformers import Olmo1124Model, Olmo1124Config
>>> # Initializing a Olmo November 2024 7B style configuration
>>> configuration = Olmo1124Config()
>>> # Initializing a model from the Olmo November 2024 7B style configuration
>>> model = Olmo1124Model(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```
"""

model_type = "olmo_1124"
keys_to_ignore_at_inference = ["past_key_values"]

def __init__(
self,
vocab_size=50304,
hidden_size=4096,
intermediate_size=11008,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=None,
hidden_act="silu",
max_position_embeddings=2048,
initializer_range=0.02,
use_cache=True,
pad_token_id=1,
bos_token_id=None,
eos_token_id=50279,
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
rms_norm_eps=1e-5,
**kwargs,
):
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads

# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads

self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self._rope_scaling_validation()
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout

self.rms_norm_eps = rms_norm_eps

def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
"""
if self.rope_scaling is None:
return

if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
raise ValueError(
"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " f"got {self.rope_scaling}"
)
rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_factor = self.rope_scaling.get("factor", None)
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
raise ValueError(
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
)
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")


__all__ = ["Olmo1124Config"]
Loading

0 comments on commit 3ee24e2

Please sign in to comment.