From 7324ba0eadea285a86893b953399abba292e6850 Mon Sep 17 00:00:00 2001 From: Shane A Date: Mon, 18 Nov 2024 01:43:10 -0800 Subject: [PATCH] Add OLMo November 2024 (#34551) * Add model skeletion with transformers-cli add-new-model-like * Convert config to modular, add rms_norm_eps, delete clip_qkv * Convert model to modular, add RMSNorm * Add flash attention with qk norm and no qkv clipping * Add decoder layer with RMSNorm after attention/feedforward layers * Add base and causal model * Add converter improvements from OLMo repo * Update weight loading in OLMo to HF converter * Set correct default for rms_norm_eps * Set correct pipeline_model_mapping in test * Run make fixup * Fix model type * Re-run modular conversion * Manually set config docs to fix build errors * Convert olmo-1124 to olmo_1124 to fix flash attention docs errors * Start updating tests * Update tests * Copy upstream test_eager_matches_sdpa_inference_1_bfloat16 changes to olmo_1124 * Rename input_layernorm and post_attention_layernorm to reflect their ops better * Use correct tokenizer * Remove test unsupported by GPT2 tokenizer * Create GenerationConfig outside of from_pretrained call * Use simpler init file structure * Add explicit __all__ to support simplified init * Make safetensor serialization the default * Update OLMo November 2024 docs --- docs/source/en/_toctree.yml | 2 + docs/source/en/index.md | 1 + docs/source/en/model_doc/olmo_1124.md | 46 + docs/source/en/perf_infer_gpu_one.md | 2 + src/transformers/__init__.py | 14 + src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 2 + src/transformers/models/auto/modeling_auto.py | 2 + .../models/auto/tokenization_auto.py | 1 + src/transformers/models/olmo_1124/__init__.py | 27 + .../olmo_1124/configuration_olmo_1124.py | 166 +++ .../convert_olmo_1124_weights_to_hf.py | 304 +++++ .../models/olmo_1124/modeling_olmo_1124.py | 1095 +++++++++++++++++ .../models/olmo_1124/modular_olmo_1124.py | 489 ++++++++ src/transformers/utils/dummy_pt_objects.py | 21 + tests/models/olmo_1124/__init__.py | 0 .../olmo_1124/test_modeling_olmo_1124.py | 468 +++++++ 17 files changed, 2641 insertions(+) create mode 100644 docs/source/en/model_doc/olmo_1124.md create mode 100644 src/transformers/models/olmo_1124/__init__.py create mode 100644 src/transformers/models/olmo_1124/configuration_olmo_1124.py create mode 100644 src/transformers/models/olmo_1124/convert_olmo_1124_weights_to_hf.py create mode 100644 src/transformers/models/olmo_1124/modeling_olmo_1124.py create mode 100644 src/transformers/models/olmo_1124/modular_olmo_1124.py create mode 100644 tests/models/olmo_1124/__init__.py create mode 100644 tests/models/olmo_1124/test_modeling_olmo_1124.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 8788c26c090643..44badd0f412302 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -514,6 +514,8 @@ title: Nyströmformer - local: model_doc/olmo title: OLMo + - local: model_doc/olmo_1124 + title: OLMo November 2024 - local: model_doc/olmoe title: OLMoE - local: model_doc/open-llama diff --git a/docs/source/en/index.md b/docs/source/en/index.md index c5c6763126dd6e..3ed0f97090b85e 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -241,6 +241,7 @@ Flax), PyTorch, and/or TensorFlow. | [Nougat](model_doc/nougat) | ✅ | ✅ | ✅ | | [Nyströmformer](model_doc/nystromformer) | ✅ | ❌ | ❌ | | [OLMo](model_doc/olmo) | ✅ | ❌ | ❌ | +| [OLMo November 2024](model_doc/olmo_1124) | ✅ | ❌ | ❌ | | [OLMoE](model_doc/olmoe) | ✅ | ❌ | ❌ | | [OmDet-Turbo](model_doc/omdet-turbo) | ✅ | ❌ | ❌ | | [OneFormer](model_doc/oneformer) | ✅ | ❌ | ❌ | diff --git a/docs/source/en/model_doc/olmo_1124.md b/docs/source/en/model_doc/olmo_1124.md new file mode 100644 index 00000000000000..f36ec438e57a84 --- /dev/null +++ b/docs/source/en/model_doc/olmo_1124.md @@ -0,0 +1,46 @@ + + +# OLMo November 2024 + +## Overview + +The OLMo November 2024 model is a successor of the OLMo model, which was proposed in +[OLMo: Accelerating the Science of Language Models](https://arxiv.org/abs/2402.00838). + + The architectural changes from the original OLMo model to this model are: + +- RMSNorm is used instead of standard layer norm. +- Norm is applied to attention queries and keys. +- Norm is applied after attention/feedforward layers rather than before. + +This model was contributed by [shanearora](https://huggingface.co/shanearora). +The original code can be found [here](https://github.com/allenai/OLMo/tree/main/olmo). + + +## Olmo1124Config + +[[autodoc]] Olmo1124Config + +## Olmo1124Model + +[[autodoc]] Olmo1124Model + - forward + +## Olmo1124ForCausalLM + +[[autodoc]] Olmo1124ForCausalLM + - forward diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index 67bd31fdaeede5..84109746f95998 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -77,6 +77,7 @@ FlashAttention-2 is currently supported for the following architectures: * [Nemotron](https://huggingface.co/docs/transformers/model_doc/nemotron) * [NLLB](https://huggingface.co/docs/transformers/model_doc/nllb) * [OLMo](https://huggingface.co/docs/transformers/model_doc/olmo#transformers.OlmoModel) +* [OLMo November 2024](https://huggingface.co/docs/transformers/model_doc/olmo_1124#transformers.Olmo1124Model) * [OLMoE](https://huggingface.co/docs/transformers/model_doc/olmoe#transformers.OlmoeModel) * [OPT](https://huggingface.co/docs/transformers/model_doc/opt#transformers.OPTModel) * [PaliGemma](https://huggingface.co/docs/transformers/model_doc/paligemma#transformers.PaliGemmaForConditionalGeneration) @@ -260,6 +261,7 @@ For now, Transformers supports SDPA inference and training for the following arc * [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel) * [NLLB](https://huggingface.co/docs/transformers/model_doc/nllb) * [OLMo](https://huggingface.co/docs/transformers/model_doc/olmo#transformers.OlmoModel) +* [OLMo November 2024](https://huggingface.co/docs/transformers/model_doc/olmo_1124#transformers.Olmo1124Model) * [OLMoE](https://huggingface.co/docs/transformers/model_doc/olmoe#transformers.OlmoeModel) * [OPT](https://huggingface.co/docs/transformers/en/model_doc/opt) * [PaliGemma](https://huggingface.co/docs/transformers/model_doc/paligemma#transformers.PaliGemmaForConditionalGeneration) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 104347af018d72..4c5121b5932d8e 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -621,6 +621,7 @@ "models.nougat": ["NougatProcessor"], "models.nystromformer": ["NystromformerConfig"], "models.olmo": ["OlmoConfig"], + "models.olmo_1124": ["Olmo1124Config"], "models.olmoe": ["OlmoeConfig"], "models.omdet_turbo": [ "OmDetTurboConfig", @@ -2928,6 +2929,13 @@ "OlmoPreTrainedModel", ] ) + _import_structure["models.olmo_1124"].extend( + [ + "Olmo1124ForCausalLM", + "Olmo1124Model", + "Olmo1124PreTrainedModel", + ] + ) _import_structure["models.olmoe"].extend( [ "OlmoeForCausalLM", @@ -5516,6 +5524,7 @@ NystromformerConfig, ) from .models.olmo import OlmoConfig + from .models.olmo_1124 import Olmo1124Config from .models.olmoe import OlmoeConfig from .models.omdet_turbo import ( OmDetTurboConfig, @@ -7539,6 +7548,11 @@ OlmoModel, OlmoPreTrainedModel, ) + from .models.olmo_1124 import ( + Olmo1124ForCausalLM, + Olmo1124Model, + Olmo1124PreTrainedModel, + ) from .models.olmoe import ( OlmoeForCausalLM, OlmoeModel, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 69ed5a446b55ba..198db614157357 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -178,6 +178,7 @@ nougat, nystromformer, olmo, + olmo_1124, olmoe, omdet_turbo, oneformer, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 8151f251f0d691..17f13813e13c2f 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -196,6 +196,7 @@ ("nougat", "VisionEncoderDecoderConfig"), ("nystromformer", "NystromformerConfig"), ("olmo", "OlmoConfig"), + ("olmo_1124", "Olmo1124Config"), ("olmoe", "OlmoeConfig"), ("omdet-turbo", "OmDetTurboConfig"), ("oneformer", "OneFormerConfig"), @@ -512,6 +513,7 @@ ("nougat", "Nougat"), ("nystromformer", "Nyströmformer"), ("olmo", "OLMo"), + ("olmo_1124", "OLMo November 2024"), ("olmoe", "OLMoE"), ("omdet-turbo", "OmDet-Turbo"), ("oneformer", "OneFormer"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index ba39fa27a6cc10..dfad3fad246b46 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -185,6 +185,7 @@ ("nllb-moe", "NllbMoeModel"), ("nystromformer", "NystromformerModel"), ("olmo", "OlmoModel"), + ("olmo_1124", "Olmo1124Model"), ("olmoe", "OlmoeModel"), ("omdet-turbo", "OmDetTurboForObjectDetection"), ("oneformer", "OneFormerModel"), @@ -517,6 +518,7 @@ ("mvp", "MvpForCausalLM"), ("nemotron", "NemotronForCausalLM"), ("olmo", "OlmoForCausalLM"), + ("olmo_1124", "Olmo1124ForCausalLM"), ("olmoe", "OlmoeForCausalLM"), ("open-llama", "OpenLlamaForCausalLM"), ("openai-gpt", "OpenAIGPTLMHeadModel"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 7674ea51a53377..4ed67df0e84b52 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -348,6 +348,7 @@ ), ), ("olmo", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)), + ("olmo_1124", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)), ("olmoe", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)), ( "omdet-turbo", diff --git a/src/transformers/models/olmo_1124/__init__.py b/src/transformers/models/olmo_1124/__init__.py new file mode 100644 index 00000000000000..5d4127766c293f --- /dev/null +++ b/src/transformers/models/olmo_1124/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 EleutherAI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_olmo_1124 import * + from .modeling_olmo_1124 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/olmo_1124/configuration_olmo_1124.py b/src/transformers/models/olmo_1124/configuration_olmo_1124.py new file mode 100644 index 00000000000000..b7f6c57ae446e7 --- /dev/null +++ b/src/transformers/models/olmo_1124/configuration_olmo_1124.py @@ -0,0 +1,166 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/olmo_1124/modular_olmo_1124.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_olmo_1124.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 + +from ...configuration_utils import PretrainedConfig + + +class Olmo1124Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Olmo1124Model`]. It is used to instantiate an OLMo November 2024 + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the [allenai/Olmo1124-7B-hf](https://huggingface.co/allenai/Olmo1124-7B-hf). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 50304): + Vocabulary size of the Olmo1124 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Olmo1124Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*, defaults to 1): + Padding token id. + bos_token_id (`int`, *optional*): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 50279): + End of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling + strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is + `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. See the following thread for more information on how + these scaling strategies behave: + https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an + experimental feature, subject to breaking API changes in future versions. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + + ```python + >>> from transformers import Olmo1124Model, Olmo1124Config + + >>> # Initializing a Olmo November 2024 7B style configuration + >>> configuration = Olmo1124Config() + + >>> # Initializing a model from the Olmo November 2024 7B style configuration + >>> model = Olmo1124Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "olmo_1124" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=50304, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + use_cache=True, + pad_token_id=1, + bos_token_id=None, + eos_token_id=50279, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + rms_norm_eps=1e-5, + **kwargs, + ): + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self._rope_scaling_validation() + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + self.rms_norm_eps = rms_norm_eps + + def _rope_scaling_validation(self): + """ + Validate the `rope_scaling` configuration. + """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: + raise ValueError( + "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " f"got {self.rope_scaling}" + ) + rope_scaling_type = self.rope_scaling.get("type", None) + rope_scaling_factor = self.rope_scaling.get("factor", None) + if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: + raise ValueError( + f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" + ) + if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: + raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") + + +__all__ = ["Olmo1124Config"] diff --git a/src/transformers/models/olmo_1124/convert_olmo_1124_weights_to_hf.py b/src/transformers/models/olmo_1124/convert_olmo_1124_weights_to_hf.py new file mode 100644 index 00000000000000..da35fc8301fd37 --- /dev/null +++ b/src/transformers/models/olmo_1124/convert_olmo_1124_weights_to_hf.py @@ -0,0 +1,304 @@ +# Copyright 2024 EleutherAI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import gc +import json +import os +import shutil +from pathlib import Path +from typing import Any, Dict + +import torch +import yaml +from tokenizers import Tokenizer + +from transformers import Olmo1124Config, Olmo1124ForCausalLM +from transformers.models.gpt2.tokenization_gpt2_fast import GPT2TokenizerFast + + +""" +Sample usage: + +``` +python src/transformers/models/olmo_1124/convert_olmo_1124_weights_to_hf.py \ + --input_dir /path/to/downloaded/olmo_1124/weights --model_size 7B --output_dir /output/path +``` + +Thereafter, models can be loaded via: + +```py +from transformers import Olmo1124ForCausalLM, AutoTokenizer + +model = Olmo1124ForCausalLM.from_pretrained("/output/path") +tokenizer = AutoTokenizer.from_pretrained("/output/path") +``` + +Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions +come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM). +""" + + +def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256): + return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of) + + +def read_json(path): + with open(path, "r") as f: + return json.load(f) + + +def write_json(text, path): + with open(path, "w") as f: + json.dump(text, f) + + +def write_model( + model_path, + input_base_path, + include_tokenizer=True, + tokenizer_path=None, + safe_serialization=True, + fix_eos_token_id=True, + tmp_cleanup=True, +): + os.makedirs(model_path, exist_ok=True) + tmp_model_path = os.path.join(model_path, "tmp") + os.makedirs(tmp_model_path, exist_ok=True) + + config_path = Path(input_base_path) / "config.yaml" + olmo_1124_config = yaml.safe_load(config_path.read_text())["model"] + + if not olmo_1124_config.get("attention_layer_norm", False): + raise RuntimeError("OLMo November 2024 checkpoints must have attention layer norm") + if not olmo_1124_config.get("norm_after", False): + raise RuntimeError("OLMo November 2024 checkpoints must set norm_after to True") + + n_layers = olmo_1124_config["n_layers"] + n_heads = olmo_1124_config["n_heads"] + dim = olmo_1124_config["d_model"] + dims_per_head = dim // n_heads + base = olmo_1124_config["rope_theta"] + inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head)) + max_position_embeddings = olmo_1124_config["max_sequence_length"] + + vocab_size = olmo_1124_config.get("embedding_size", olmo_1124_config["vocab_size"]) + + if olmo_1124_config.get("n_kv_heads", None) is not None: + num_key_value_heads = olmo_1124_config["n_kv_heads"] # for GQA / MQA + elif olmo_1124_config["multi_query_attention"]: # compatibility with other checkpoints + num_key_value_heads = 1 + else: + num_key_value_heads = n_heads + + print(f"Fetching all parameters from the checkpoint at {input_base_path}.") + + # Not sharded + # (The sharded implementation would also work, but this is simpler.) + loaded = torch.load(os.path.join(input_base_path, "model.pt"), map_location="cpu") + + param_count = 0 + index_dict: Dict[str, Any] = {"weight_map": {}} + for layer_i in range(n_layers): + filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin" + # Unsharded + # TODO: Layernorm stuff + # TODO: multi query attention + fused_dims = [dim, dims_per_head * num_key_value_heads, dims_per_head * num_key_value_heads] + q_proj_weight, k_proj_weight, v_proj_weight = torch.split( + loaded[f"transformer.blocks.{layer_i}.att_proj.weight"], fused_dims, dim=0 + ) + up_proj_weight, gate_proj_weight = torch.chunk( + loaded[f"transformer.blocks.{layer_i}.ff_proj.weight"], 2, dim=0 + ) + state_dict = { + f"model.layers.{layer_i}.self_attn.q_proj.weight": q_proj_weight, + f"model.layers.{layer_i}.self_attn.k_proj.weight": k_proj_weight, + f"model.layers.{layer_i}.self_attn.v_proj.weight": v_proj_weight, + f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"transformer.blocks.{layer_i}.attn_out.weight"], + f"model.layers.{layer_i}.self_attn.q_norm.weight": loaded[f"transformer.blocks.{layer_i}.q_norm.weight"], + f"model.layers.{layer_i}.self_attn.k_norm.weight": loaded[f"transformer.blocks.{layer_i}.k_norm.weight"], + f"model.layers.{layer_i}.mlp.gate_proj.weight": gate_proj_weight, + f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"transformer.blocks.{layer_i}.ff_out.weight"], + f"model.layers.{layer_i}.mlp.up_proj.weight": up_proj_weight, + f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[ + f"transformer.blocks.{layer_i}.attn_norm.weight" + ], + f"model.layers.{layer_i}.post_feedforward_layernorm.weight": loaded[ + f"transformer.blocks.{layer_i}.ff_norm.weight" + ], + } + + state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq + + for k, v in state_dict.items(): + index_dict["weight_map"][k] = filename + param_count += v.numel() + torch.save(state_dict, os.path.join(tmp_model_path, filename)) + + filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin" + + # Unsharded + # TODO: Deal with weight-tying + state_dict = { + "model.embed_tokens.weight": loaded["transformer.wte.weight"], + "model.norm.weight": loaded["transformer.ln_f.weight"], + "lm_head.weight": loaded["transformer.ff_out.weight"] + if "transformer.ff_out.weight" in loaded + else loaded["transformer.wte.weight"], + } + + for k, v in state_dict.items(): + index_dict["weight_map"][k] = filename + param_count += v.numel() + torch.save(state_dict, os.path.join(tmp_model_path, filename)) + + # Write configs + index_dict["metadata"] = {"total_size": param_count * 2} + write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json")) + + if olmo_1124_config.get("mlp_hidden_size", None) is not None: + intermediate_size = olmo_1124_config["mlp_hidden_size"] // 2 + else: + intermediate_size = (dim * olmo_1124_config["mlp_ratio"]) // 2 + + if fix_eos_token_id and olmo_1124_config["eos_token_id"] == 0: + # Fixing a bug in OLMo where eos token id was incorrectly set + print("Changing eos_token_id from 0 to 50279.") + olmo_1124_config["eos_token_id"] = 50279 + + config = Olmo1124Config( + vocab_size=vocab_size, + hidden_size=dim, + intermediate_size=intermediate_size, + num_hidden_layers=n_layers, + num_attention_heads=n_heads, + num_key_value_heads=num_key_value_heads, + max_position_embeddings=max_position_embeddings, + pad_token_id=olmo_1124_config["pad_token_id"], + bos_token_id=None, + eos_token_id=olmo_1124_config["eos_token_id"], + tie_word_embeddings=olmo_1124_config["weight_tying"], + rms_norm_eps=olmo_1124_config["layer_norm_eps"], + rope_theta=base, + ) + config.save_pretrained(tmp_model_path) + + # Make space so we can load the model properly now. + del state_dict + del loaded + gc.collect() + + if include_tokenizer: + _write_tokenizer(model_path, config, input_base_path, tokenizer_path) + + print("Loading the checkpoint in a OLMo November 2024 model.") + model = Olmo1124ForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.float32, low_cpu_mem_usage=True) + # Avoid saving this as part of the config. + del model.config._name_or_path + print("Saving in the Transformers format.") + model.save_pretrained(model_path, safe_serialization=safe_serialization) + if tmp_cleanup: + # Make cleanup optional; attempting to `rmtree` the `tmp_model_path` causes + # errors if using NFS. + shutil.rmtree(tmp_model_path) + + +def _write_tokenizer( + output_path: Path, + config: Olmo1124Config, + checkpoint_dir: str, + input_tokenizer_path: Path | None, +) -> None: + print(f"Saving a {GPT2TokenizerFast.__name__} to {output_path}.") + + if input_tokenizer_path is not None: + base_tokenizer = Tokenizer.from_file(str(input_tokenizer_path)) + else: + config_path = Path(checkpoint_dir) / "config.yaml" + tokenizer_config = yaml.safe_load(config_path.read_text())["tokenizer"] + + # Initialize tokenizer and validate vocab size. + if Path(tokenizer_config["identifier"]).is_file(): + base_tokenizer = Tokenizer.from_file(tokenizer_config["identifier"]) + else: + base_tokenizer = Tokenizer.from_pretrained(tokenizer_config["identifier"]) + + eos_token_id = config.eos_token_id if config.eos_token_id is not None else base_tokenizer.get_vocab_size() - 1 + pad_token_id = config.pad_token_id if config.pad_token_id is not None else eos_token_id + + tokenizer = GPT2TokenizerFast( + tokenizer_object=base_tokenizer, + eos_token=base_tokenizer.decode([eos_token_id], skip_special_tokens=False), + pad_token=base_tokenizer.decode([pad_token_id], skip_special_tokens=False), + ) + + tokenizer.save_pretrained(output_path) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--input_dir", + required=True, + help="Location of OLMo November 2024 weights, which contains config.yaml and model.pt.", + ) + parser.add_argument( + "--no_tokenizer", + action="store_false", + dest="include_tokenizer", + help="If set, do not convert OLMo tokenizer to HF tokenizer.", + ) + parser.add_argument( + "--tokenizer_json_path", + type=Path, + default=None, + help="Location of OLMo November 2024 tokenizer json file. Defaults to what is set in the config file.", + ) + parser.add_argument( + "--output_dir", + required=True, + help="Location to write HF model and tokenizer", + ) + parser.add_argument( + "--no_fix_eos_token_id", + action="store_false", + dest="fix_eos_token_id", + help="If set, does not change eos token id from 0 to 50279 if it is 0. Changing 0 to 50279 is a bug fix, so use this option with care.", + ) + parser.add_argument( + "--no_tmp_cleanup", + action="store_false", + dest="tmp_cleanup", + help="If passed, don't remove temp dir at end of HF conversion.", + ) + parser.add_argument( + "--no_safe_serialization", + action="store_false", + dest="safe_serialization", + help="Whether or not to save using `safetensors`.", + ) + args = parser.parse_args() + write_model( + model_path=args.output_dir, + input_base_path=args.input_dir, + safe_serialization=args.safe_serialization, + include_tokenizer=args.include_tokenizer, + tokenizer_path=args.tokenizer_json_path, + fix_eos_token_id=args.fix_eos_token_id, + tmp_cleanup=args.tmp_cleanup, + ) + + +if __name__ == "__main__": + main() diff --git a/src/transformers/models/olmo_1124/modeling_olmo_1124.py b/src/transformers/models/olmo_1124/modeling_olmo_1124.py new file mode 100644 index 00000000000000..52eb1f544bb484 --- /dev/null +++ b/src/transformers/models/olmo_1124/modeling_olmo_1124.py @@ -0,0 +1,1095 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/olmo_1124/modular_olmo_1124.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_olmo_1124.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +import math +from typing import List, Optional, Tuple, Union + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_utils import _flash_attention_forward +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_utils import PreTrainedModel +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from .configuration_olmo_1124 import Olmo1124Config + + +if is_flash_attn_2_available(): + from ...modeling_flash_attention_utils import _flash_attention_forward + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "Olmo1124Config" + + +class Olmo1124RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Olmo1124RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +# copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Olmo1124 +# TODO(joao): add me back asap :) +class Olmo1124RotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + super().__init__() + self.scaling_factor = scaling_factor + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + # For BC we register cos and sin cached + self.max_seq_len_cached = max_position_embeddings + + @torch.no_grad() + def forward(self, x, position_ids): + # x: [bs, num_attention_heads, seq_len, head_size] + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Olmo1124 +# TODO(joao): add me back asap :) +class Olmo1124LinearScalingRotaryEmbedding(Olmo1124RotaryEmbedding): + """Olmo1124RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def forward(self, x, position_ids): + # difference to the original RoPE: a scaling factor is aplied to the position ids + position_ids = position_ids.float() / self.scaling_factor + cos, sin = super().forward(x, position_ids) + return cos, sin + + +# copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Olmo1124 +# TODO(joao): add me back asap :) +class Olmo1124DynamicNTKScalingRotaryEmbedding(Olmo1124RotaryEmbedding): + """Olmo1124RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def forward(self, x, position_ids): + # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / ( + base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation + + cos, sin = super().forward(x, position_ids) + return cos, sin + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class Olmo1124Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + # copied from transformers.models.llama.modeling_llama.LlamaAttention.__init__ with Llama->Olmo1124 + # TODO(joao): add me back asap :) + def __init__(self, config: Olmo1124Config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) + self._init_rope() + self.q_norm = Olmo1124RMSNorm(self.num_heads * self.head_dim, config.rms_norm_eps) + self.k_norm = Olmo1124RMSNorm(self.num_key_value_heads * self.head_dim, config.rms_norm_eps) + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = Olmo1124RotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = Olmo1124LinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = Olmo1124DynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_norm(self.q_proj(hidden_states)) + key_states = self.k_norm(self.k_proj(hidden_states)) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class Olmo1124FlashAttention2(Olmo1124Attention): + """ + Olmo1124 flash attention module. This module inherits from `Olmo1124Attention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + + OLMo November 2024 flash attention module. This module inherits from `Olmo1124Attention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_norm(self.q_proj(hidden_states)) + key_states = self.k_norm(self.k_proj(hidden_states)) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (OlmoRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class Olmo1124SdpaAttention(Olmo1124Attention): + """ + Olmo1124 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `Olmo1124Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from Olmo1124Attention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "Olmo1124Model is using Olmo1124SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + bsz, q_len, _ = hidden_states.size() + query_states = self.q_norm(self.q_proj(hidden_states)) + key_states = self.k_norm(self.k_proj(hidden_states)) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + causal_mask = attention_mask + # if attention_mask is not None and cache_position is not None: + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + is_causal = True if causal_mask is None and q_len > 1 else False + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) + return attn_output, None, past_key_value + + +class Olmo1124MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +OLMO_1124_ATTENTION_CLASSES = { + "eager": Olmo1124Attention, + "flash_attention_2": Olmo1124FlashAttention2, + "sdpa": Olmo1124SdpaAttention, +} + + +class Olmo1124DecoderLayer(nn.Module): + def __init__(self, config: Olmo1124Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = OLMO_1124_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + + self.mlp = Olmo1124MLP(config) + self.post_attention_layernorm = Olmo1124RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_feedforward_layernorm = Olmo1124RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + # copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward + # TODO(joao): add me back asap :) + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_feedforward_layernorm(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) + if use_cache: + outputs += (present_key_value,) + return outputs + + +OLMO_1124_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Olmo1124Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Olmo1124 Model outputting raw hidden-states without any specific head on top.", + OLMO_1124_START_DOCSTRING, +) +class Olmo1124PreTrainedModel(PreTrainedModel): + config_class = Olmo1124Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Olmo1124DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +OLMO_1124_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare Olmo1124 Model outputting raw hidden-states without any specific head on top.", + OLMO_1124_START_DOCSTRING, +) +class Olmo1124Model(Olmo1124PreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Olmo1124DecoderLayer`] + + Args: + config: Olmo1124Config + """ + + def __init__(self, config: Olmo1124Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Olmo1124DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Olmo1124RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(OLMO_1124_INPUTS_DOCSTRING) + # copied from transformers.models.llama.modeling_llama.LlamaModel.forward + # TODO(joao): add me back asap :) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + if past_key_values is None: + past_key_values = DynamicCache() + else: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " + "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " + "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" + ) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +class Olmo1124ForCausalLM(Olmo1124PreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: Olmo1124Config): + super().__init__(config) + self.model = Olmo1124Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(OLMO_1124_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + # Ignore copy + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + **loss_kwargs, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, Olmo1124ForCausalLM + + >>> model = Olmo1124ForCausalLM.from_pretrained("allenai/Olmo1124-1B-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("allenai/Olmo1124-1B-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + 'Hey, are you conscious? Can you talk to me?\nI’m not sure if you’re conscious of this, but I’m' + ``` + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = ["Olmo1124ForCausalLM", "Olmo1124Model", "Olmo1124PreTrainedModel"] diff --git a/src/transformers/models/olmo_1124/modular_olmo_1124.py b/src/transformers/models/olmo_1124/modular_olmo_1124.py new file mode 100644 index 00000000000000..2305b1f400054a --- /dev/null +++ b/src/transformers/models/olmo_1124/modular_olmo_1124.py @@ -0,0 +1,489 @@ +import math +from typing import Optional, Tuple + +import torch +from torch import nn + +from ...cache_utils import Cache +from ...pytorch_utils import ALL_LAYERNORM_LAYERS +from ...utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging +from ..llama.modeling_llama import LlamaRMSNorm +from ..olmo.configuration_olmo import OlmoConfig +from ..olmo.modeling_olmo import ( + OlmoAttention, + OlmoDecoderLayer, + OlmoFlashAttention2, + OlmoForCausalLM, + OlmoModel, + OlmoPreTrainedModel, + OlmoSdpaAttention, + apply_rotary_pos_emb, + repeat_kv, +) + + +if is_flash_attn_2_available(): + from ...modeling_flash_attention_utils import _flash_attention_forward + +logger = logging.get_logger(__name__) + + +class Olmo1124Config(OlmoConfig): + r""" + This is the configuration class to store the configuration of a [`Olmo1124Model`]. It is used to instantiate an OLMo November 2024 + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the [allenai/Olmo1124-7B-hf](https://huggingface.co/allenai/Olmo1124-7B-hf). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 50304): + Vocabulary size of the Olmo1124 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Olmo1124Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*, defaults to 1): + Padding token id. + bos_token_id (`int`, *optional*): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 50279): + End of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling + strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is + `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. See the following thread for more information on how + these scaling strategies behave: + https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an + experimental feature, subject to breaking API changes in future versions. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + + ```python + >>> from transformers import Olmo1124Model, Olmo1124Config + + >>> # Initializing a Olmo November 2024 7B style configuration + >>> configuration = Olmo1124Config() + + >>> # Initializing a model from the Olmo November 2024 7B style configuration + >>> model = Olmo1124Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "olmo_1124" + + def __init__( + self, + vocab_size=50304, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + use_cache=True, + pad_token_id=1, + bos_token_id=None, + eos_token_id=50279, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + rms_norm_eps=1e-5, + **kwargs, + ): + super().__init__( + vocab_size=vocab_size, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + hidden_act=hidden_act, + max_position_embeddings=max_position_embeddings, + initializer_range=initializer_range, + use_cache=use_cache, + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + attention_bias=attention_bias, + attention_dropout=attention_dropout, + **kwargs, + ) + + self.rms_norm_eps = rms_norm_eps + del self.clip_qkv + + +class Olmo1124RMSNorm(LlamaRMSNorm): + pass + + +ALL_LAYERNORM_LAYERS.append(Olmo1124RMSNorm) + + +# Olmo1124 attention is identical to OLMo attention except: +# - Norm is applied to attention queries and keys. +# - No qkv clipping. +class Olmo1124Attention(OlmoAttention): + def __init__(self, config: Olmo1124Config, layer_idx: Optional[int] = None): + super().__init__(config, layer_idx=layer_idx) + self.q_norm = Olmo1124RMSNorm(self.num_heads * self.head_dim, config.rms_norm_eps) + self.k_norm = Olmo1124RMSNorm(self.num_key_value_heads * self.head_dim, config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_norm(self.q_proj(hidden_states)) + key_states = self.k_norm(self.k_proj(hidden_states)) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class Olmo1124FlashAttention2(OlmoFlashAttention2, Olmo1124Attention): + """ + OLMo November 2024 flash attention module. This module inherits from `Olmo1124Attention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + Olmo1124Attention.__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_norm(self.q_proj(hidden_states)) + key_states = self.k_norm(self.k_proj(hidden_states)) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (OlmoRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class Olmo1124SdpaAttention(OlmoSdpaAttention, Olmo1124Attention): + # Adapted from Olmo1124Attention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "Olmo1124Model is using Olmo1124SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + bsz, q_len, _ = hidden_states.size() + query_states = self.q_norm(self.q_proj(hidden_states)) + key_states = self.k_norm(self.k_proj(hidden_states)) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + causal_mask = attention_mask + # if attention_mask is not None and cache_position is not None: + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + is_causal = True if causal_mask is None and q_len > 1 else False + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) + return attn_output, None, past_key_value + + +# The OLMo November 2024 layers are identical to those of the OLMo model except: +# - RMSNorm is used instead of standard layer norm. +# - Norm is applied after attention/feedforward rather than before. +class Olmo1124DecoderLayer(OlmoDecoderLayer): + def __init__(self, config: Olmo1124Config, layer_idx: int): + super().__init__(config, layer_idx=layer_idx) + self.post_attention_layernorm = Olmo1124RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_feedforward_layernorm = Olmo1124RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + del self.input_layernorm + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_feedforward_layernorm(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) + if use_cache: + outputs += (present_key_value,) + return outputs + + +class Olmo1124PreTrainedModel(OlmoPreTrainedModel): + pass + + +# The OLMo November 2024 model is identical to the OLMo model, except RMSNorm is used instead of +# standard layer norm for the output norm. +class Olmo1124Model(OlmoModel): + def __init__(self, config: Olmo1124Config): + super().__init__(config) + self.layers = nn.ModuleList( + [Olmo1124DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Olmo1124RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + +# The heads now only need to redefine the model inside to the correct `RobertaModel` +class Olmo1124ForCausalLM(OlmoForCausalLM): + def __init__(self, config: Olmo1124Config): + super().__init__(config) + self.model = Olmo1124Model(config) + + +__all__ = [ + "Olmo1124Config", + "Olmo1124ForCausalLM", + "Olmo1124Model", + "Olmo1124PreTrainedModel", +] diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 10fd0aa020391f..52627327a17907 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -6786,6 +6786,27 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class Olmo1124ForCausalLM(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Olmo1124Model(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Olmo1124PreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class OlmoeForCausalLM(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/models/olmo_1124/__init__.py b/tests/models/olmo_1124/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/models/olmo_1124/test_modeling_olmo_1124.py b/tests/models/olmo_1124/test_modeling_olmo_1124.py new file mode 100644 index 00000000000000..9dad7932f1a570 --- /dev/null +++ b/tests/models/olmo_1124/test_modeling_olmo_1124.py @@ -0,0 +1,468 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch OLMo November 2024 model.""" + +import unittest + +from packaging import version +from parameterized import parameterized + +from transformers import Olmo1124Config, is_torch_available, set_seed +from transformers.generation.configuration_utils import GenerationConfig +from transformers.models.auto.tokenization_auto import AutoTokenizer +from transformers.testing_utils import ( + require_tokenizers, + require_torch, + slow, + torch_device, +) + +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, ids_tensor +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_torch_available(): + import torch + + from transformers import ( + Olmo1124ForCausalLM, + Olmo1124Model, + ) + + +class Olmo1124ModelTester: + def __init__( + self, + parent, + batch_size=13, + seq_length=7, + is_training=True, + use_input_mask=True, + use_token_type_ids=False, + use_labels=True, + vocab_size=99, + hidden_size=32, + num_hidden_layers=2, + num_attention_heads=4, + intermediate_size=37, + hidden_act="silu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=16, + type_sequence_label_size=2, + initializer_range=0.02, + num_labels=3, + num_choices=4, + pad_token_id=0, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_input_mask = use_input_mask + self.use_token_type_ids = use_token_type_ids + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.num_labels = num_labels + self.num_choices = num_choices + self.pad_token_id = pad_token_id + self.scope = scope + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device)) + + token_type_ids = None + if self.use_token_type_ids: + token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) + + sequence_labels = None + token_labels = None + choice_labels = None + if self.use_labels: + sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) + choice_labels = ids_tensor([self.batch_size], self.num_choices) + + config = self.get_config() + + return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + + def get_config(self): + return Olmo1124Config( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + type_vocab_size=self.type_vocab_size, + is_decoder=False, + initializer_range=self.initializer_range, + pad_token_id=self.pad_token_id, + ) + + def create_and_check_model( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = Olmo1124Model(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask) + result = model(input_ids) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + + def create_and_check_model_as_decoder( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ): + config.add_cross_attention = True + model = Olmo1124Model(config) + model.to(torch_device) + model.eval() + result = model( + input_ids, + attention_mask=input_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + ) + result = model( + input_ids, + attention_mask=input_mask, + encoder_hidden_states=encoder_hidden_states, + ) + result = model(input_ids, attention_mask=input_mask) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + + def create_and_check_for_causal_lm( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ): + model = Olmo1124ForCausalLM(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask, labels=token_labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + + def create_and_check_decoder_model_past_large_inputs( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ): + config.is_decoder = True + config.add_cross_attention = True + model = Olmo1124ForCausalLM(config=config) + model.to(torch_device) + model.eval() + + # first forward pass + outputs = model( + input_ids, + attention_mask=input_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=True, + ) + past_key_values = outputs.past_key_values + + # create hypothetical multiple next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) + next_mask = ids_tensor((self.batch_size, 3), vocab_size=2) + + # append to next input_ids and + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + next_attention_mask = torch.cat([input_mask, next_mask], dim=-1) + + output_from_no_past = model( + next_input_ids, + attention_mask=next_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_hidden_states=True, + )["hidden_states"][0] + output_from_past = model( + next_tokens, + attention_mask=next_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + output_hidden_states=True, + )["hidden_states"][0] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, :, random_slice_idx].detach() + + self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1]) + + # test that outputs are equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = config_and_inputs + inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} + return config, inputs_dict + + +@require_torch +class Olmo1124ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = (Olmo1124Model, Olmo1124ForCausalLM) if is_torch_available() else () + all_generative_model_classes = (Olmo1124ForCausalLM,) if is_torch_available() else () + pipeline_model_mapping = ( + { + "feature-extraction": Olmo1124Model, + "text-generation": Olmo1124ForCausalLM, + } + if is_torch_available() + else {} + ) + test_pruning = False + fx_compatible = False + + # Need to use `0.8` instead of `0.9` for `test_cpu_offload` + # This is because we are hitting edge cases with the causal_mask buffer + model_split_percents = [0.5, 0.7, 0.8] + + def setUp(self): + self.model_tester = Olmo1124ModelTester(self) + self.config_tester = ConfigTester(self, config_class=Olmo1124Config, hidden_size=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + @unittest.skip(reason="OLMo November 2024 does not support head pruning.") + def test_headmasking(self): + pass + + def test_model_various_embeddings(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + for type in ["absolute", "relative_key", "relative_key_query"]: + config_and_inputs[0].position_embedding_type = type + self.model_tester.create_and_check_model(*config_and_inputs) + + @unittest.skip(reason="OLMo November 2024 buffers include complex numbers, which breaks this test") + def test_save_load_fast_init_from_base(self): + pass + + @parameterized.expand([("linear",), ("dynamic",)]) + def test_model_rope_scaling(self, scaling_type): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + short_input = ids_tensor([1, 10], config.vocab_size) + long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size) + + set_seed(42) # Fixed seed at init time so the two models get the same random weights + original_model = Olmo1124Model(config) + original_model.to(torch_device) + original_model.eval() + original_short_output = original_model(short_input).last_hidden_state + original_long_output = original_model(long_input).last_hidden_state + + set_seed(42) # Fixed seed at init time so the two models get the same random weights + config.rope_scaling = {"type": scaling_type, "factor": 10.0} + scaled_model = Olmo1124Model(config) + scaled_model.to(torch_device) + scaled_model.eval() + scaled_short_output = scaled_model(short_input).last_hidden_state + scaled_long_output = scaled_model(long_input).last_hidden_state + + # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original + # maximum sequence length, so the outputs for the short input should match. + if scaling_type == "dynamic": + self.assertTrue(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) + else: + self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) + + # The output should be different for long inputs + self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) + + +@require_torch +class Olmo1124IntegrationTest(unittest.TestCase): + @slow + def test_model_7b_logits(self): + input_ids = [[1, 306, 4658, 278, 6593, 310, 2834, 338]] + model = Olmo1124ForCausalLM.from_pretrained("shanearora/OLMo-7B-1124-hf", device_map="auto") + out = model(torch.tensor(input_ids)).logits.float() + # Expected mean on dim = -1 + EXPECTED_MEAN = torch.tensor( + [[-13.0244, -13.9564, -11.8270, -11.3047, -12.3794, -12.4215, -15.6030, -12.7962]] + ) + torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, atol=1e-2, rtol=1e-2) + # slicing logits[0, 0, 0:30] + EXPECTED_SLICE = torch.tensor([-5.3909, -13.9841, -13.6123, -14.5780, -13.9455, -13.2265, -13.4734, -11.9079, -9.2879, -12.6139, -11.4819, -5.9607, -11.9657, -6.3618, -11.1065, -7.3075, -6.5674, -6.7154, -7.3409, -7.9662, -8.0863, -8.1682, -8.7341, -8.7665, -8.8742, -9.7813, -8.0620, -12.5937, -7.6440, -11.3966]) # fmt: skip + torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, atol=1e-2, rtol=1e-2) + + @slow + def test_model_7b_greedy_generation(self): + EXPECTED_TEXT_COMPLETION = """Simply put, the theory of relativity states that 1) the speed of light is constant, 2) the speed of light is the fastest speed possible, and 3) the speed of light is the same for all observers, regardless of their relative motion. The theory of relativity is based on the idea that the speed of light is constant. This means that""" + prompt = "Simply put, the theory of relativity states that " + tokenizer = AutoTokenizer.from_pretrained("shanearora/OLMo-7B-1124-hf", device_map="auto") + model = Olmo1124ForCausalLM.from_pretrained("shanearora/OLMo-7B-1124-hf", device_map="auto") + input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device) + + # greedy generation outputs + generated_ids = model.generate(input_ids, max_new_tokens=64, top_p=None, temperature=1, do_sample=False) + text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, text) + + @require_tokenizers + def test_simple_encode_decode(self): + rust_tokenizer = AutoTokenizer.from_pretrained("shanearora/OLMo-7B-1124-hf") + + self.assertEqual(rust_tokenizer.encode("This is a test"), [2028, 374, 264, 1296]) + self.assertEqual(rust_tokenizer.decode([2028, 374, 264, 1296], skip_special_tokens=True), "This is a test") + + # bytefallback showcase + self.assertEqual(rust_tokenizer.encode("生活的真谛是"), [21990, 76706, 9554, 89151, 39013, 249, 21043]) # fmt: skip + self.assertEqual( + rust_tokenizer.decode([21990, 76706, 9554, 89151, 39013, 249, 21043], skip_special_tokens=True), + "生活的真谛是", + ) + + # Inner spaces showcase + self.assertEqual(rust_tokenizer.encode("Hi Hello"), [13347, 220, 22691]) + self.assertEqual(rust_tokenizer.decode([13347, 220, 22691], skip_special_tokens=True), "Hi Hello") + + self.assertEqual(rust_tokenizer.encode("Hi Hello"), [13347, 256, 22691]) + self.assertEqual(rust_tokenizer.decode([13347, 256, 22691], skip_special_tokens=True), "Hi Hello") + + self.assertEqual(rust_tokenizer.encode(""), []) + + self.assertEqual(rust_tokenizer.encode(" "), [220]) + + self.assertEqual(rust_tokenizer.encode(" "), [256]) + + self.assertEqual(rust_tokenizer.encode(" Hello"), [22691]) + + @slow + def test_export_static_cache(self): + if version.parse(torch.__version__) < version.parse("2.4.0"): + self.skipTest(reason="This test requires torch >= 2.4 to run.") + + from transformers.integrations.executorch import ( + TorchExportableModuleWithStaticCache, + convert_and_export_with_cache, + ) + + olmo_1124_model = "shanearora/OLMo-7B-1124-hf" + + tokenizer = AutoTokenizer.from_pretrained(olmo_1124_model, pad_token="", padding_side="right") + EXPECTED_TEXT_COMPLETION = [ + "Simply put, the theory of relativity states that 1) the speed of light is constant, 2) the speed of light", + ] + max_generation_length = tokenizer(EXPECTED_TEXT_COMPLETION, return_tensors="pt", padding=True)[ + "input_ids" + ].shape[-1] + + # Load model + device = "cpu" + dtype = torch.bfloat16 + cache_implementation = "static" + attn_implementation = "sdpa" + batch_size = 1 + generation_config = GenerationConfig( + use_cache=True, + cache_implementation=cache_implementation, + max_length=max_generation_length, + cache_config={ + "batch_size": batch_size, + "max_cache_len": max_generation_length, + }, + ) + model = Olmo1124ForCausalLM.from_pretrained( + olmo_1124_model, + device_map=device, + torch_dtype=dtype, + attn_implementation=attn_implementation, + generation_config=generation_config, + ) + + prompts = ["Simply put, the theory of relativity states that "] + prompt_tokens = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) + prompt_token_ids = prompt_tokens["input_ids"] + max_new_tokens = max_generation_length - prompt_token_ids.shape[-1] + + # Static Cache + eager + eager_generated_ids = model.generate( + **prompt_tokens, max_new_tokens=max_new_tokens, do_sample=False, cache_implementation=cache_implementation + ) + eager_generated_text = tokenizer.batch_decode(eager_generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, eager_generated_text) + + # Static Cache + export + exported_program = convert_and_export_with_cache(model) + ep_generated_ids = TorchExportableModuleWithStaticCache.generate( + exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens + ) + ep_generated_text = tokenizer.batch_decode(ep_generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, ep_generated_text)