From 8e0ccca30892c6a916efabb08bbad38512621f2d Mon Sep 17 00:00:00 2001 From: Shane A Date: Thu, 31 Oct 2024 15:44:23 -0700 Subject: [PATCH 01/22] Add model skeletion with transformers-cli add-new-model-like --- docs/source/en/_toctree.yml | 2 + docs/source/en/model_doc/olmo-1124.md | 48 + src/transformers/__init__.py | 14 + src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 2 + src/transformers/models/auto/modeling_auto.py | 2 + .../models/auto/tokenization_auto.py | 1 + src/transformers/models/olmo_1124/__init__.py | 57 + .../olmo_1124/configuration_olmo_1124.py | 181 +++ .../convert_olmo_1124_weights_to_hf.py | 248 ++++ .../models/olmo_1124/modeling_olmo_1124.py | 1149 +++++++++++++++++ tests/models/olmo_1124/__init__.py | 0 .../olmo_1124/test_modeling_olmo_1124.py | 498 +++++++ 13 files changed, 2203 insertions(+) create mode 100644 docs/source/en/model_doc/olmo-1124.md create mode 100644 src/transformers/models/olmo_1124/__init__.py create mode 100644 src/transformers/models/olmo_1124/configuration_olmo_1124.py create mode 100644 src/transformers/models/olmo_1124/convert_olmo_1124_weights_to_hf.py create mode 100644 src/transformers/models/olmo_1124/modeling_olmo_1124.py create mode 100644 tests/models/olmo_1124/__init__.py create mode 100644 tests/models/olmo_1124/test_modeling_olmo_1124.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index a7806059afaa59..193af77f5693ea 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -514,6 +514,8 @@ title: NystrΓΆmformer - local: model_doc/olmo title: OLMo + - local: model_doc/olmo-1124 + title: OLMo November 2024 - local: model_doc/olmoe title: OLMoE - local: model_doc/open-llama diff --git a/docs/source/en/model_doc/olmo-1124.md b/docs/source/en/model_doc/olmo-1124.md new file mode 100644 index 00000000000000..8388f14798ac87 --- /dev/null +++ b/docs/source/en/model_doc/olmo-1124.md @@ -0,0 +1,48 @@ + + +# OLMo November 2024 + +## Overview + +The OLMo November 2024 model was proposed in []() by . + + +The abstract from the paper is the following: + +** + +Tips: + + + +This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/). +The original code can be found [here](). + + +## Olmo1124Config + +[[autodoc]] Olmo1124Config + +## Olmo1124Model + +[[autodoc]] Olmo1124Model + - forward + +## Olmo1124ForCausalLM + +[[autodoc]] Olmo1124ForCausalLM + - forward diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index cc8b07395024a8..8d5f9f75fc0b2a 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -620,6 +620,7 @@ "models.nougat": ["NougatProcessor"], "models.nystromformer": ["NystromformerConfig"], "models.olmo": ["OlmoConfig"], + "models.olmo_1124": ["Olmo1124Config"], "models.olmoe": ["OlmoeConfig"], "models.omdet_turbo": [ "OmDetTurboConfig", @@ -2918,6 +2919,13 @@ "OlmoPreTrainedModel", ] ) + _import_structure["models.olmo_1124"].extend( + [ + "Olmo1124ForCausalLM", + "Olmo1124Model", + "Olmo1124PreTrainedModel", + ] + ) _import_structure["models.olmoe"].extend( [ "OlmoeForCausalLM", @@ -5505,6 +5513,7 @@ NystromformerConfig, ) from .models.olmo import OlmoConfig + from .models.olmo_1124 import Olmo1124Config from .models.olmoe import OlmoeConfig from .models.omdet_turbo import ( OmDetTurboConfig, @@ -7521,6 +7530,11 @@ OlmoModel, OlmoPreTrainedModel, ) + from .models.olmo_1124 import ( + Olmo1124ForCausalLM, + Olmo1124Model, + Olmo1124PreTrainedModel, + ) from .models.olmoe import ( OlmoeForCausalLM, OlmoeModel, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 9155f629e63f91..0d4b9f2f94de9b 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -177,6 +177,7 @@ nougat, nystromformer, olmo, + olmo_1124, olmoe, omdet_turbo, oneformer, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 48625ea3f346cd..7769edb6608dba 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -195,6 +195,7 @@ ("nougat", "VisionEncoderDecoderConfig"), ("nystromformer", "NystromformerConfig"), ("olmo", "OlmoConfig"), + ("olmo-1124", "Olmo1124Config"), ("olmoe", "OlmoeConfig"), ("omdet-turbo", "OmDetTurboConfig"), ("oneformer", "OneFormerConfig"), @@ -510,6 +511,7 @@ ("nougat", "Nougat"), ("nystromformer", "NystrΓΆmformer"), ("olmo", "OLMo"), + ("olmo-1124", "OLMo November 2024"), ("olmoe", "OLMoE"), ("omdet-turbo", "OmDet-Turbo"), ("oneformer", "OneFormer"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 67c539fca66496..572ec81d6c5e56 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -184,6 +184,7 @@ ("nllb-moe", "NllbMoeModel"), ("nystromformer", "NystromformerModel"), ("olmo", "OlmoModel"), + ("olmo-1124", "Olmo1124Model"), ("olmoe", "OlmoeModel"), ("omdet-turbo", "OmDetTurboForObjectDetection"), ("oneformer", "OneFormerModel"), @@ -516,6 +517,7 @@ ("mvp", "MvpForCausalLM"), ("nemotron", "NemotronForCausalLM"), ("olmo", "OlmoForCausalLM"), + ("olmo-1124", "Olmo1124ForCausalLM"), ("olmoe", "OlmoeForCausalLM"), ("open-llama", "OpenLlamaForCausalLM"), ("openai-gpt", "OpenAIGPTLMHeadModel"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 7674ea51a53377..e4a4533b48ea7f 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -348,6 +348,7 @@ ), ), ("olmo", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)), + ("olmo-1124", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)), ("olmoe", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)), ( "omdet-turbo", diff --git a/src/transformers/models/olmo_1124/__init__.py b/src/transformers/models/olmo_1124/__init__.py new file mode 100644 index 00000000000000..93e8177b923b43 --- /dev/null +++ b/src/transformers/models/olmo_1124/__init__.py @@ -0,0 +1,57 @@ +# Copyright 2024 EleutherAI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import ( + OptionalDependencyNotAvailable, + _LazyModule, + is_torch_available, +) + + +_import_structure = { + "configuration_olmo_1124": ["Olmo1124Config"], +} + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_olmo_1124"] = [ + "Olmo1124ForCausalLM", + "Olmo1124Model", + "Olmo1124PreTrainedModel", + ] + +if TYPE_CHECKING: + from .configuration_olmo_1124 import Olmo1124Config + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_olmo_1124 import ( + Olmo1124ForCausalLM, + Olmo1124Model, + Olmo1124PreTrainedModel, + ) + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/transformers/models/olmo_1124/configuration_olmo_1124.py b/src/transformers/models/olmo_1124/configuration_olmo_1124.py new file mode 100644 index 00000000000000..942dd816664786 --- /dev/null +++ b/src/transformers/models/olmo_1124/configuration_olmo_1124.py @@ -0,0 +1,181 @@ +# coding=utf-8 +# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""OLMo November 2024 model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class Olmo1124Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Olmo1124Model`]. It is used to instantiate an OLMo November 2024 + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the [allenai/OLMo November 2024-7B-hf](https://huggingface.co/allenai/OLMo November 2024-7B-hf). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 50304): + Vocabulary size of the OLMo November 2024 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Olmo1124Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*, defaults to 1): + Padding token id. + bos_token_id (`int`, *optional*): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 50279): + End of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling + strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is + `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. See the following thread for more information on how + these scaling strategies behave: + https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an + experimental feature, subject to breaking API changes in future versions. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + clip_qkv (`float`, *optional*): + If not `None`, elements of query, key and value attention states are clipped so that their + absolute value does not exceed this value. + + ```python + >>> from transformers import Olmo1124Model, Olmo1124Config + + >>> # Initializing a OLMo November 2024 7B style configuration + >>> configuration = Olmo1124Config() + + >>> # Initializing a model from the OLMo November 2024 7B style configuration + >>> model = Olmo1124Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "olmo-1124" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=50304, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + use_cache=True, + pad_token_id=1, + bos_token_id=None, + eos_token_id=50279, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + clip_qkv=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self._rope_scaling_validation() + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.clip_qkv = clip_qkv + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + def _rope_scaling_validation(self): + """ + Validate the `rope_scaling` configuration. + """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: + raise ValueError( + "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " f"got {self.rope_scaling}" + ) + rope_scaling_type = self.rope_scaling.get("type", None) + rope_scaling_factor = self.rope_scaling.get("factor", None) + if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: + raise ValueError( + f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" + ) + if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: + raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") diff --git a/src/transformers/models/olmo_1124/convert_olmo_1124_weights_to_hf.py b/src/transformers/models/olmo_1124/convert_olmo_1124_weights_to_hf.py new file mode 100644 index 00000000000000..c021940b7e8b36 --- /dev/null +++ b/src/transformers/models/olmo_1124/convert_olmo_1124_weights_to_hf.py @@ -0,0 +1,248 @@ +# Copyright 2024 EleutherAI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import gc +import json +import os +import shutil +from pathlib import Path + +import torch +import yaml +from tokenizers import Tokenizer + +from transformers import Olmo1124Config, Olmo1124ForCausalLM +from transformers.models.gpt_neox.tokenization_gpt_neox_fast import GPTNeoXTokenizerFast + + +""" +Sample usage: + +``` +python src/transformers/models/olmo_1124/convert_olmo_1124_weights_to_hf.py \ + --input_dir /path/to/downloaded/olmo_1124/weights --model_size 7B --output_dir /output/path +``` + +Thereafter, models can be loaded via: + +```py +from transformers import Olmo1124ForCausalLM, AutoTokenizer + +model = Olmo1124ForCausalLM.from_pretrained("/output/path") +tokenizer = AutoTokenizer.from_pretrained("/output/path") +``` + +Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions +come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM). +""" + + +def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256): + return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of) + + +def read_json(path): + with open(path, "r") as f: + return json.load(f) + + +def write_json(text, path): + with open(path, "w") as f: + json.dump(text, f) + + +def write_model(model_path, input_base_path, tokenizer_path=None, safe_serialization=True, fix_eos_token_id=True): + os.makedirs(model_path, exist_ok=True) + tmp_model_path = os.path.join(model_path, "tmp") + os.makedirs(tmp_model_path, exist_ok=True) + + config_path = Path(input_base_path) / "config.yaml" + olmo_1124_config = yaml.safe_load(config_path.read_text())["model"] + + n_layers = olmo_1124_config["n_layers"] + n_heads = olmo_1124_config["n_heads"] + dim = olmo_1124_config["d_model"] + dims_per_head = dim // n_heads + base = 10000.0 + inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head)) + max_position_embeddings = olmo_1124_config["max_sequence_length"] + + vocab_size = olmo_1124_config.get("embedding_size", olmo_1124_config["vocab_size"]) + + if olmo_1124_config.get("n_kv_heads", None) is not None: + num_key_value_heads = olmo_1124_config["n_kv_heads"] # for GQA / MQA + elif olmo_1124_config["multi_query_attention"]: # compatibility with other checkpoints + num_key_value_heads = 1 + else: + num_key_value_heads = n_heads + + print(f"Fetching all parameters from the checkpoint at {input_base_path}.") + + # Not sharded + # (The sharded implementation would also work, but this is simpler.) + loaded = torch.load(os.path.join(input_base_path, "model.pt"), map_location="cpu") + + param_count = 0 + index_dict = {"weight_map": {}} + for layer_i in range(n_layers): + filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin" + # Unsharded + # TODO: Layernorm stuff + # TODO: multi query attention + fused_dims = [dim, dims_per_head * num_key_value_heads, dims_per_head * num_key_value_heads] + q_proj_weight, k_proj_weight, v_proj_weight = torch.split( + loaded[f"transformer.blocks.{layer_i}.att_proj.weight"], fused_dims, dim=0 + ) + up_proj_weight, gate_proj_weight = torch.chunk( + loaded[f"transformer.blocks.{layer_i}.ff_proj.weight"], 2, dim=0 + ) + state_dict = { + f"model.layers.{layer_i}.self_attn.q_proj.weight": q_proj_weight, + f"model.layers.{layer_i}.self_attn.k_proj.weight": k_proj_weight, + f"model.layers.{layer_i}.self_attn.v_proj.weight": v_proj_weight, + f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"transformer.blocks.{layer_i}.attn_out.weight"], + f"model.layers.{layer_i}.mlp.gate_proj.weight": gate_proj_weight, + f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"transformer.blocks.{layer_i}.ff_out.weight"], + f"model.layers.{layer_i}.mlp.up_proj.weight": up_proj_weight, + } + + state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq + + for k, v in state_dict.items(): + index_dict["weight_map"][k] = filename + param_count += v.numel() + torch.save(state_dict, os.path.join(tmp_model_path, filename)) + + filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin" + + # Unsharded + # TODO: Deal with weight-tying + state_dict = { + "model.embed_tokens.weight": loaded["transformer.wte.weight"], + "lm_head.weight": loaded["transformer.ff_out.weight"] + if "transformer.ff_out.weight" in loaded + else loaded["transformer.wte.weight"], + } + + for k, v in state_dict.items(): + index_dict["weight_map"][k] = filename + param_count += v.numel() + torch.save(state_dict, os.path.join(tmp_model_path, filename)) + + # Write configs + index_dict["metadata"] = {"total_size": param_count * 2} + write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json")) + + if olmo_1124_config.get("mlp_hidden_size", None) is not None: + intermediate_size = olmo_1124_config["mlp_hidden_size"] // 2 + else: + intermediate_size = (dim * olmo_1124_config["mlp_ratio"]) // 2 + + config = Olmo1124Config( + vocab_size=vocab_size, + hidden_size=dim, + intermediate_size=intermediate_size, + num_hidden_layers=n_layers, + num_attention_heads=n_heads, + num_key_value_heads=num_key_value_heads, + max_position_embeddings=max_position_embeddings, + pad_token_id=olmo_1124_config["pad_token_id"], + bos_token_id=None, + eos_token_id=olmo_1124_config["eos_token_id"], + tie_word_embeddings=olmo_1124_config["weight_tying"], + rope_theta=base, + clip_qkv=olmo_1124_config.get("clip_qkv"), + ) + config.save_pretrained(tmp_model_path) + + # Make space so we can load the model properly now. + del state_dict + del loaded + gc.collect() + + if tokenizer_path is not None: + _write_tokenizer(model_path, config, tokenizer_path, fix_eos_token_id) + + print("Loading the checkpoint in a OLMo November 2024 model.") + model = Olmo1124ForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.float32, low_cpu_mem_usage=True) + # Avoid saving this as part of the config. + del model.config._name_or_path + print("Saving in the Transformers format.") + model.save_pretrained(model_path, safe_serialization=safe_serialization) + shutil.rmtree(tmp_model_path) + + +def _write_tokenizer( + output_path: Path, config: Olmo1124Config, input_tokenizer_path: Path, fix_eos_token_id: bool = True +) -> None: + print(f"Saving a {GPTNeoXTokenizerFast.__name__} to {output_path}.") + + base_tokenizer = Tokenizer.from_file(str(input_tokenizer_path)) + + eos_token_id = config.eos_token_id if config.eos_token_id is not None else base_tokenizer.get_vocab_size() - 1 + pad_token_id = config.pad_token_id if config.pad_token_id is not None else eos_token_id + + if fix_eos_token_id and eos_token_id == 0: + # Fixing a bug in OLMo November 2024 where eos token id was incorrectly set + print("Changing eos_token_id from 0 to 50279.") + eos_token_id = 50279 + + tokenizer = GPTNeoXTokenizerFast( + tokenizer_object=base_tokenizer, + eos_token=base_tokenizer.decode([eos_token_id], skip_special_tokens=False), + pad_token=base_tokenizer.decode([pad_token_id], skip_special_tokens=False), + unk_token=None, + bos_token=None, + ) + + tokenizer.save_pretrained(output_path) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--input_dir", + required=True, + help="Location of OLMo November 2024 weights, which contains config.yaml and model.pt.", + ) + parser.add_argument( + "--tokenizer_json_path", + default=None, + help="Location of OLMo November 2024 tokenizer json file.", + ) + parser.add_argument( + "--output_dir", + required=True, + help="Location to write HF model and tokenizer", + ) + parser.add_argument( + "--no_fix_eos_token_id", + action="store_false", + dest="fix_eos_token_id", + help="If set, does not change eos token id from 0 to 50279 if it is 0. Changing 0 to 50279 is a bug fix, so use this option with care.", + ) + parser.add_argument("--safe_serialization", type=bool, help="Whether or not to save using `safetensors`.") + # Different OLMo November 2024 versions used different default values for max_position_embeddings, hence the need to be able to specify which version is being used. + args = parser.parse_args() + write_model( + model_path=args.output_dir, + input_base_path=args.input_dir, + safe_serialization=args.safe_serialization, + tokenizer_path=args.tokenizer_json_path, + fix_eos_token_id=args.fix_eos_token_id, + ) + + +if __name__ == "__main__": + main() diff --git a/src/transformers/models/olmo_1124/modeling_olmo_1124.py b/src/transformers/models/olmo_1124/modeling_olmo_1124.py new file mode 100644 index 00000000000000..49c8bf3924d1f6 --- /dev/null +++ b/src/transformers/models/olmo_1124/modeling_olmo_1124.py @@ -0,0 +1,1149 @@ +# coding=utf-8 +# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch OLMo November 2024 model.""" + +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import ALL_LAYERNORM_LAYERS +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from .configuration_olmo_1124 import Olmo1124Config + + +if is_flash_attn_2_available(): + from ...modeling_flash_attention_utils import _flash_attention_forward + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "Olmo1124Config" + + +# Copied from transformers.models.olmo.modeling_olmo.OlmoLayerNorm with Olmo->Olmo1124 +class Olmo1124LayerNorm(nn.Module): + """LayerNorm but with no learnable weight or bias.""" + + def __init__(self, hidden_size: int) -> None: + super().__init__() + self.normalized_shape = (hidden_size,) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + orig_dtype = hidden_states.dtype + return F.layer_norm(hidden_states.to(dtype=torch.float32), self.normalized_shape, None, None, eps=1e-5).to( + orig_dtype + ) + + +ALL_LAYERNORM_LAYERS.append(Olmo1124LayerNorm) + + +# copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Olmo1124 +# TODO(joao): add me back asap :) +# Copied from transformers.models.olmo.modeling_olmo.OlmoRotaryEmbedding with Olmo->Olmo1124 +class Olmo1124RotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + super().__init__() + self.scaling_factor = scaling_factor + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + # For BC we register cos and sin cached + self.max_seq_len_cached = max_position_embeddings + + @torch.no_grad() + def forward(self, x, position_ids): + # x: [bs, num_attention_heads, seq_len, head_size] + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Olmo1124 +# TODO(joao): add me back asap :) +# Copied from transformers.models.olmo.modeling_olmo.OlmoLinearScalingRotaryEmbedding with Olmo->Olmo1124 +class Olmo1124LinearScalingRotaryEmbedding(Olmo1124RotaryEmbedding): + """Olmo1124RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def forward(self, x, position_ids): + # difference to the original RoPE: a scaling factor is aplied to the position ids + position_ids = position_ids.float() / self.scaling_factor + cos, sin = super().forward(x, position_ids) + return cos, sin + + +# copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Olmo1124 +# TODO(joao): add me back asap :) +# Copied from transformers.models.olmo.modeling_olmo.OlmoDynamicNTKScalingRotaryEmbedding with Olmo->Olmo1124 +class Olmo1124DynamicNTKScalingRotaryEmbedding(Olmo1124RotaryEmbedding): + """Olmo1124RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def forward(self, x, position_ids): + # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / ( + base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation + + cos, sin = super().forward(x, position_ids) + return cos, sin + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +# Copied from transformers.models.olmo.modeling_olmo.OlmoMLP with Olmo->Olmo1124 +class Olmo1124MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +# Copied from transformers.models.olmo.modeling_olmo.OlmoAttention with Olmo->Olmo1124 +class Olmo1124Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + # copied from transformers.models.llama.modeling_llama.LlamaAttention.__init__ with Llama->Olmo1124 + # TODO(joao): add me back asap :) + def __init__(self, config: Olmo1124Config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) + self._init_rope() + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = Olmo1124RotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = Olmo1124LinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = Olmo1124DynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + if self.config.clip_qkv is not None: + query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Copied from transformers.models.olmo.modeling_olmo.OlmoFlashAttention2 with Olmo->Olmo1124,OLMo->OLMo November 2024 +class Olmo1124FlashAttention2(Olmo1124Attention): + """ + OLMo November 2024 flash attention module. This module inherits from `Olmo1124Attention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + if self.config.clip_qkv is not None: + query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (Olmo1124RMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Copied from transformers.models.olmo.modeling_olmo.OlmoSdpaAttention with Olmo->Olmo1124,OLMo->OLMo November 2024 +class Olmo1124SdpaAttention(Olmo1124Attention): + """ + OLMo November 2024 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `Olmo1124Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from Olmo1124Attention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "Olmo1124Model is using Olmo1124SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + if self.config.clip_qkv is not None: + query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + # if attention_mask is not None and cache_position is not None: + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +OLMO_1124_ATTENTION_CLASSES = { + "eager": Olmo1124Attention, + "flash_attention_2": Olmo1124FlashAttention2, + "sdpa": Olmo1124SdpaAttention, +} + + +# Copied from transformers.models.olmo.modeling_olmo.OlmoDecoderLayer with OLMO->OLMO_1124,Olmo->Olmo1124 +class Olmo1124DecoderLayer(nn.Module): + def __init__(self, config: Olmo1124Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = OLMO_1124_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + + self.mlp = Olmo1124MLP(config) + self.input_layernorm = Olmo1124LayerNorm(config.hidden_size) + self.post_attention_layernorm = Olmo1124LayerNorm(config.hidden_size) + + # copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward + # TODO(joao): add me back asap :) + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +OLMO_1124_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Olmo1124Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Olmo1124 Model outputting raw hidden-states without any specific head on top.", + OLMO_1124_START_DOCSTRING, +) +# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->Olmo1124 +class Olmo1124PreTrainedModel(PreTrainedModel): + config_class = Olmo1124Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Olmo1124DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +OLMO_1124_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare Olmo1124 Model outputting raw hidden-states without any specific head on top.", + OLMO_1124_START_DOCSTRING, +) +# Copied from transformers.models.olmo.modeling_olmo.OlmoModel with OLMO->OLMO_1124,Olmo->Olmo1124 +class Olmo1124Model(Olmo1124PreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Olmo1124DecoderLayer`] + + Args: + config: Olmo1124Config + """ + + def __init__(self, config: Olmo1124Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Olmo1124DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Olmo1124LayerNorm(config.hidden_size) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(OLMO_1124_INPUTS_DOCSTRING) + # copied from transformers.models.llama.modeling_llama.LlamaModel.forward + # TODO(joao): add me back asap :) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + if past_key_values is None: + past_key_values = DynamicCache() + else: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " + "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " + "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" + ) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->OLMO_1124,Llama->Olmo1124 +class Olmo1124ForCausalLM(Olmo1124PreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = Olmo1124Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(OLMO_1124_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + # Ignore copy + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + **loss_kwargs, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, Olmo1124ForCausalLM + + >>> model = Olmo1124ForCausalLM.from_pretrained("allenai/OLMo-7B-1124-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-7B-1124-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + 'Hey, are you conscious? Can you talk to me?\nI’m not sure if you’re conscious of this, but I’m' + ``` + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/tests/models/olmo_1124/__init__.py b/tests/models/olmo_1124/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/models/olmo_1124/test_modeling_olmo_1124.py b/tests/models/olmo_1124/test_modeling_olmo_1124.py new file mode 100644 index 00000000000000..af56c475a3005b --- /dev/null +++ b/tests/models/olmo_1124/test_modeling_olmo_1124.py @@ -0,0 +1,498 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the PyTorch OLMo November 2024 model.""" + +import unittest + +from packaging import version +from parameterized import parameterized + +from transformers import Olmo1124Config, is_torch_available, set_seed +from transformers.generation.configuration_utils import GenerationConfig +from transformers.models.auto.tokenization_auto import AutoTokenizer +from transformers.models.gpt_neox.tokenization_gpt_neox_fast import GPTNeoXTokenizerFast +from transformers.testing_utils import ( + require_tokenizers, + require_torch, + slow, + torch_device, +) + +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, ids_tensor +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_torch_available(): + import torch + + from transformers import ( + Olmo1124ForCausalLM, + Olmo1124Model, + ) + + +class Olmo1124ModelTester: + def __init__( + self, + parent, + batch_size=13, + seq_length=7, + is_training=True, + use_input_mask=True, + use_token_type_ids=False, + use_labels=True, + vocab_size=99, + hidden_size=32, + num_hidden_layers=2, + num_attention_heads=4, + intermediate_size=37, + hidden_act="silu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=16, + type_sequence_label_size=2, + initializer_range=0.02, + num_labels=3, + num_choices=4, + pad_token_id=0, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_input_mask = use_input_mask + self.use_token_type_ids = use_token_type_ids + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.num_labels = num_labels + self.num_choices = num_choices + self.pad_token_id = pad_token_id + self.scope = scope + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = torch.tril(torch.ones_like(input_ids).to(torch_device)) + + token_type_ids = None + if self.use_token_type_ids: + token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) + + sequence_labels = None + token_labels = None + choice_labels = None + if self.use_labels: + sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) + choice_labels = ids_tensor([self.batch_size], self.num_choices) + + config = self.get_config() + + return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + + def get_config(self): + return Olmo1124Config( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + type_vocab_size=self.type_vocab_size, + is_decoder=False, + initializer_range=self.initializer_range, + pad_token_id=self.pad_token_id, + ) + + def create_and_check_model( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = Olmo1124Model(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask) + result = model(input_ids) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + + def create_and_check_model_as_decoder( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ): + config.add_cross_attention = True + model = Olmo1124Model(config) + model.to(torch_device) + model.eval() + result = model( + input_ids, + attention_mask=input_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + ) + result = model( + input_ids, + attention_mask=input_mask, + encoder_hidden_states=encoder_hidden_states, + ) + result = model(input_ids, attention_mask=input_mask) + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + + def create_and_check_for_causal_lm( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ): + model = Olmo1124ForCausalLM(config=config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=input_mask, labels=token_labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + + def create_and_check_decoder_model_past_large_inputs( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ): + config.is_decoder = True + config.add_cross_attention = True + model = Olmo1124ForCausalLM(config=config) + model.to(torch_device) + model.eval() + + # first forward pass + outputs = model( + input_ids, + attention_mask=input_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=True, + ) + past_key_values = outputs.past_key_values + + # create hypothetical multiple next token and extent to next_input_ids + next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) + next_mask = ids_tensor((self.batch_size, 3), vocab_size=2) + + # append to next input_ids and + next_input_ids = torch.cat([input_ids, next_tokens], dim=-1) + next_attention_mask = torch.cat([input_mask, next_mask], dim=-1) + + output_from_no_past = model( + next_input_ids, + attention_mask=next_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_hidden_states=True, + )["hidden_states"][0] + output_from_past = model( + next_tokens, + attention_mask=next_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + output_hidden_states=True, + )["hidden_states"][0] + + # select random slice + random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() + output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach() + output_from_past_slice = output_from_past[:, :, random_slice_idx].detach() + + self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1]) + + # test that outputs are equal for slice + self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = config_and_inputs + inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask} + return config, inputs_dict + + +@require_torch +class Olmo1124ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = (Olmo1124Model, Olmo1124ForCausalLM) if is_torch_available() else () + all_generative_model_classes = (Olmo1124ForCausalLM,) if is_torch_available() else () + test_pruning = False + fx_compatible = False + + # Need to use `0.8` instead of `0.9` for `test_cpu_offload` + # This is because we are hitting edge cases with the causal_mask buffer + model_split_percents = [0.5, 0.7, 0.8] + + def setUp(self): + self.model_tester = Olmo1124ModelTester(self) + self.config_tester = ConfigTester(self, config_class=Olmo1124Config, hidden_size=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + @unittest.skip(reason="OLMo November 2024 does not support head pruning.") + def test_headmasking(self): + pass + + def test_model_various_embeddings(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + for type in ["absolute", "relative_key", "relative_key_query"]: + config_and_inputs[0].position_embedding_type = type + self.model_tester.create_and_check_model(*config_and_inputs) + + @unittest.skip(reason="OLMo November 2024 buffers include complex numbers, which breaks this test") + def test_save_load_fast_init_from_base(self): + pass + + @parameterized.expand([("linear",), ("dynamic",)]) + def test_model_rope_scaling(self, scaling_type): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + short_input = ids_tensor([1, 10], config.vocab_size) + long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size) + + set_seed(42) # Fixed seed at init time so the two models get the same random weights + original_model = Olmo1124Model(config) + original_model.to(torch_device) + original_model.eval() + original_short_output = original_model(short_input).last_hidden_state + original_long_output = original_model(long_input).last_hidden_state + + set_seed(42) # Fixed seed at init time so the two models get the same random weights + config.rope_scaling = {"type": scaling_type, "factor": 10.0} + scaled_model = Olmo1124Model(config) + scaled_model.to(torch_device) + scaled_model.eval() + scaled_short_output = scaled_model(short_input).last_hidden_state + scaled_long_output = scaled_model(long_input).last_hidden_state + + # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original + # maximum sequence length, so the outputs for the short input should match. + if scaling_type == "dynamic": + self.assertTrue(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) + else: + self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) + + # The output should be different for long inputs + self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) + + +@require_torch +class Olmo1124IntegrationTest(unittest.TestCase): + @slow + def test_model_1b_logits(self): + input_ids = [[1, 306, 4658, 278, 6593, 310, 2834, 338]] + model = Olmo1124ForCausalLM.from_pretrained("allenai/OLMo-7B-1124-hf", device_map="auto") + out = model(torch.tensor(input_ids)).logits.float() + # Expected mean on dim = -1 + EXPECTED_MEAN = torch.tensor([[2.2869, 0.3315, 0.9876, 1.4146, 1.8804, 2.0430, 1.7055, 1.2065]]) + torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, atol=1e-2, rtol=1e-2) + # slicing logits[0, 0, 0:30] + EXPECTED_SLICE = torch.tensor([2.5551, -1.1230, 11.0510, 12.4977, 7.9651, 7.2342, 6.1885, 7.8340, 9.9847, 12.6695, 12.2345, 10.7970, 8.4749, 14.2483, 12.9588, 13.9233, 11.0496, 5.5749, 7.4466, 7.7914, 6.8440, 5.8951, 4.8180, 4.1935, 4.5216, 4.7256, 3.9553, 12.2870, 12.4990, 8.1591]) # fmt: skip + torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, atol=1e-2, rtol=1e-2) + + @slow + def test_model_7b_logits(self): + input_ids = [[1, 306, 4658, 278, 6593, 310, 2834, 338]] + model = Olmo1124ForCausalLM.from_pretrained("allenai/OLMo November 2024-7B-hf", device_map="auto") + out = model(torch.tensor(input_ids)).logits.float() + # Expected mean on dim = -1 + EXPECTED_MEAN = torch.tensor([[0.0271, 0.0249, -0.0578, -0.0870, 0.0167, 0.0710, 0.1002, 0.0677]]) + torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, atol=1e-2, rtol=1e-2) + # slicing logits[0, 0, 0:30] + EXPECTED_SLICE = torch.tensor([-1.7433, -1.6685, 7.4941, 6.1506, 0.1364, -0.1127, 1.3224, 4.5458, 4.2068, 5.8296, 7.4723, 2.7925, 3.1245, 10.8872, 10.0758, 10.6717, 7.0945, 1.2398, 3.6766, 4.2365, 2.5655, 2.2222, 1.7418, 0.5223, 0.7753, 1.0938, 0.6723, 6.2522, 6.2264, 1.8105]) # fmt: skip + torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, atol=1e-2, rtol=1e-2) + + @slow + def test_model_7b_twin_2t_logits(self): + input_ids = [[1, 306, 4658, 278, 6593, 310, 2834, 338]] + model = Olmo1124ForCausalLM.from_pretrained("allenai/OLMo November 2024-7B-Twin-2T-hf", device_map="auto") + out = model(torch.tensor(input_ids)).logits.float() + # Expected mean on dim = -1 + EXPECTED_MEAN = torch.tensor([[-0.3636, -0.3825, -0.4800, -0.3696, -0.8388, -0.9737, -0.9849, -0.8356]]) + torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, atol=1e-2, rtol=1e-2) + # slicing logits[0, 0, 0:30] + EXPECTED_SLICE = torch.tensor([-2.0833, -1.9234, 8.7312, 7.8049, 1.0372, 0.8941, 3.1548, 1.8502, 5.5511, 5.5793, 8.1166, 4.5906, 1.8691, 11.6377, 8.9858, 11.6447, 7.4549, 1.4725, 2.8399, 2.7568, 1.4011, 1.6958, 0.5572, 0.5231, 0.3068, 0.5364, 0.6769, 7.9636, 8.2379, 1.7950]) # fmt: skip + torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, atol=1e-2, rtol=1e-2) + + @slow + def test_model_7b_greedy_generation(self): + EXPECTED_TEXT_COMPLETION = """Simply put, the theory of relativity states that \nthe speed of light is the same for all observers.\n\nThe theory of relativity is a theory of physics that describes the \nmovement of objects in space and time.\n\nThe theory of relativity is a theory of physics that describes the \nmovement of objects in space and time.\n\n""" + prompt = "Simply put, the theory of relativity states that " + tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo November 2024-7B-hf", device_map="auto") + input_ids = tokenizer.encode(prompt, return_tensors="pt") + model = Olmo1124ForCausalLM.from_pretrained("allenai/OLMo November 2024-7B-hf", device_map="auto") + + # greedy generation outputs + generated_ids = model.generate(input_ids, max_new_tokens=64, top_p=None, temperature=1, do_sample=False) + text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, text) + + @require_tokenizers + def test_fast_special_tokens(self): + fast_tokenizer = GPTNeoXTokenizerFast.from_pretrained("allenai/OLMo-7B-1124-hf") + + original_add_eos_token = fast_tokenizer.add_eos_token + + fast_tokenizer.add_eos_token = False + fast = fast_tokenizer.encode("A sample test") + self.assertEqual(fast, [34, 3410, 1071]) + + fast_tokenizer.add_eos_token = True + fast = fast_tokenizer.encode("A sample test") + self.assertEqual(fast, [34, 3410, 1071, 50279]) + + fast_tokenizer.add_eos_token = original_add_eos_token + + @require_tokenizers + def test_simple_encode_decode(self): + rust_tokenizer = GPTNeoXTokenizerFast.from_pretrained("allenai/OLMo-7B-1124-hf") + + self.assertEqual(rust_tokenizer.encode("This is a test"), [1552, 310, 247, 1071]) + self.assertEqual(rust_tokenizer.decode([1552, 310, 247, 1071], skip_special_tokens=True), "This is a test") + + # bytefallback showcase + self.assertEqual(rust_tokenizer.encode("η”Ÿζ΄»ηš„ηœŸθ°›ζ˜―"), [20025, 46549, 5225, 48561, 33656, 238, 12105]) # fmt: skip + self.assertEqual( + rust_tokenizer.decode([20025, 46549, 5225, 48561, 33656, 238, 12105], skip_special_tokens=True), + "η”Ÿζ΄»ηš„ηœŸθ°›ζ˜―", + ) + + # Inner spaces showcase + self.assertEqual(rust_tokenizer.encode("Hi Hello"), [12764, 50276, 12092]) + self.assertEqual(rust_tokenizer.decode([12764, 50276, 12092], skip_special_tokens=True), "Hi Hello") + + self.assertEqual(rust_tokenizer.encode("Hi Hello"), [12764, 50275, 12092]) + self.assertEqual(rust_tokenizer.decode([12764, 50275, 12092], skip_special_tokens=True), "Hi Hello") + + self.assertEqual(rust_tokenizer.encode(""), []) + + self.assertEqual(rust_tokenizer.encode(" "), [209]) + + self.assertEqual(rust_tokenizer.encode(" "), [50276]) + + self.assertEqual(rust_tokenizer.encode(" Hello"), [24387]) + + @slow + def test_export_static_cache(self): + if version.parse(torch.__version__) < version.parse("2.4.0"): + self.skipTest(reason="This test requires torch >= 2.4 to run.") + + from transformers.integrations.executorch import ( + TorchExportableModuleWithStaticCache, + convert_and_export_with_cache, + ) + + olmo_1124_model = "allenai/OLMo-7B-1124-hf" + + tokenizer = AutoTokenizer.from_pretrained(olmo_1124_model, pad_token="", padding_side="right") + EXPECTED_TEXT_COMPLETION = [ + "Simply put, the theory of relativity states that \nthe speed of light is the same in all reference frames.\n\nThe speed of light", + ] + max_generation_length = tokenizer(EXPECTED_TEXT_COMPLETION, return_tensors="pt", padding=True)[ + "input_ids" + ].shape[-1] + + # Load model + device = "cpu" + dtype = torch.bfloat16 + cache_implementation = "static" + attn_implementation = "sdpa" + batch_size = 1 + model = Olmo1124ForCausalLM.from_pretrained( + olmo_1124_model, + device_map=device, + torch_dtype=dtype, + attn_implementation=attn_implementation, + generation_config=GenerationConfig( + use_cache=True, + cache_implementation=cache_implementation, + max_length=max_generation_length, + cache_config={ + "batch_size": batch_size, + "max_cache_len": max_generation_length, + }, + ), + ) + + prompts = ["Simply put, the theory of relativity states that "] + prompt_tokens = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) + prompt_token_ids = prompt_tokens["input_ids"] + max_new_tokens = max_generation_length - prompt_token_ids.shape[-1] + + # Static Cache + eager + eager_generated_ids = model.generate( + **prompt_tokens, max_new_tokens=max_new_tokens, do_sample=False, cache_implementation=cache_implementation + ) + eager_generated_text = tokenizer.batch_decode(eager_generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, eager_generated_text) + + # Static Cache + export + exported_program = convert_and_export_with_cache(model) + ep_generated_ids = TorchExportableModuleWithStaticCache.generate( + exported_program=exported_program, prompt_token_ids=prompt_token_ids, max_new_tokens=max_new_tokens + ) + ep_generated_text = tokenizer.batch_decode(ep_generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, ep_generated_text) From d6b18e26fcec32703cf5f1684018335e933f7929 Mon Sep 17 00:00:00 2001 From: Shane A Date: Thu, 31 Oct 2024 15:56:26 -0700 Subject: [PATCH 02/22] Convert config to modular, add rms_norm_eps, delete clip_qkv --- .../olmo_1124/configuration_olmo_1124.py | 52 ++++++---------- .../models/olmo_1124/modular_olmo_1124.py | 60 +++++++++++++++++++ 2 files changed, 79 insertions(+), 33 deletions(-) create mode 100644 src/transformers/models/olmo_1124/modular_olmo_1124.py diff --git a/src/transformers/models/olmo_1124/configuration_olmo_1124.py b/src/transformers/models/olmo_1124/configuration_olmo_1124.py index 942dd816664786..10ac098e6d8a9c 100644 --- a/src/transformers/models/olmo_1124/configuration_olmo_1124.py +++ b/src/transformers/models/olmo_1124/configuration_olmo_1124.py @@ -1,36 +1,18 @@ -# coding=utf-8 -# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""OLMo November 2024 model configuration""" +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/olmo_1124/modular_olmo_1124.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_olmo_1124.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 from ...configuration_utils import PretrainedConfig -from ...utils import logging - - -logger = logging.get_logger(__name__) class Olmo1124Config(PretrainedConfig): r""" - This is the configuration class to store the configuration of a [`Olmo1124Model`]. It is used to instantiate an OLMo November 2024 + This is the configuration class to store the configuration of a [`Olmo1124Model`]. It is used to instantiate an Olmo1124 model according to the specified arguments, defining the model architecture. Instantiating a configuration with the - defaults will yield a similar configuration to that of the [allenai/OLMo November 2024-7B-hf](https://huggingface.co/allenai/OLMo November 2024-7B-hf). + defaults will yield a similar configuration to that of the [allenai/Olmo1124-7B-hf](https://huggingface.co/allenai/Olmo1124-7B-hf). Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. @@ -38,7 +20,7 @@ class Olmo1124Config(PretrainedConfig): Args: vocab_size (`int`, *optional*, defaults to 50304): - Vocabulary size of the OLMo November 2024 model. Defines the number of different tokens that can be represented by the + Vocabulary size of the Olmo1124 model. Defines the number of different tokens that can be represented by the `inputs_ids` passed when calling [`Olmo1124Model`] hidden_size (`int`, *optional*, defaults to 4096): Dimension of the hidden representations. @@ -94,17 +76,20 @@ class Olmo1124Config(PretrainedConfig): ```python >>> from transformers import Olmo1124Model, Olmo1124Config - >>> # Initializing a OLMo November 2024 7B style configuration + >>> # Initializing a Olmo1124 7B style configuration >>> configuration = Olmo1124Config() - >>> # Initializing a model from the OLMo November 2024 7B style configuration + >>> # Initializing a model from the Olmo1124 7B style configuration >>> model = Olmo1124Model(configuration) >>> # Accessing the model configuration >>> configuration = model.config - ```""" + ``` + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + """ - model_type = "olmo-1124" + model_type = "olmo_1124" keys_to_ignore_at_inference = ["past_key_values"] def __init__( @@ -127,7 +112,7 @@ def __init__( rope_scaling=None, attention_bias=False, attention_dropout=0.0, - clip_qkv=None, + rms_norm_eps=1e-6, **kwargs, ): self.vocab_size = vocab_size @@ -150,7 +135,8 @@ def __init__( self._rope_scaling_validation() self.attention_bias = attention_bias self.attention_dropout = attention_dropout - self.clip_qkv = clip_qkv + + self.rms_norm_eps = rms_norm_eps super().__init__( pad_token_id=pad_token_id, diff --git a/src/transformers/models/olmo_1124/modular_olmo_1124.py b/src/transformers/models/olmo_1124/modular_olmo_1124.py new file mode 100644 index 00000000000000..f2d269796a3a6f --- /dev/null +++ b/src/transformers/models/olmo_1124/modular_olmo_1124.py @@ -0,0 +1,60 @@ +from ...utils import logging +from ..olmo.configuration_olmo import OlmoConfig + + +logger = logging.get_logger(__name__) + + +class Olmo1124Config(OlmoConfig): + r""" + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + """ + + def __init__( + self, + vocab_size=50304, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + use_cache=True, + pad_token_id=1, + bos_token_id=None, + eos_token_id=50279, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + rms_norm_eps=1e-6, + **kwargs, + ): + super().__init__( + vocab_size=vocab_size, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + hidden_act=hidden_act, + max_position_embeddings=max_position_embeddings, + initializer_range=initializer_range, + use_cache=use_cache, + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + attention_bias=attention_bias, + attention_dropout=attention_dropout, + **kwargs, + ) + + self.rms_norm_eps = rms_norm_eps + del self.clip_qkv From 6605eb9c13f865b2f221ae56f9b2a02ee95ae056 Mon Sep 17 00:00:00 2001 From: Shane A Date: Thu, 31 Oct 2024 15:58:50 -0700 Subject: [PATCH 03/22] Convert model to modular, add RMSNorm --- .../models/olmo_1124/modeling_olmo_1124.py | 1163 +---------------- .../models/olmo_1124/modular_olmo_1124.py | 9 + 2 files changed, 31 insertions(+), 1141 deletions(-) diff --git a/src/transformers/models/olmo_1124/modeling_olmo_1124.py b/src/transformers/models/olmo_1124/modeling_olmo_1124.py index 49c8bf3924d1f6..fad03273e9fb93 100644 --- a/src/transformers/models/olmo_1124/modeling_olmo_1124.py +++ b/src/transformers/models/olmo_1124/modeling_olmo_1124.py @@ -1,1149 +1,30 @@ -# coding=utf-8 -# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PyTorch OLMo November 2024 model.""" - -import math -from typing import List, Optional, Tuple, Union +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/olmo_1124/modular_olmo_1124.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_olmo_1124.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 import torch -import torch.nn.functional as F import torch.utils.checkpoint from torch import nn -from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, StaticCache -from ...generation import GenerationMixin -from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, -) -from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import ALL_LAYERNORM_LAYERS -from ...utils import ( - add_start_docstrings, - add_start_docstrings_to_model_forward, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal_2_10, - logging, - replace_return_docstrings, -) -from .configuration_olmo_1124 import Olmo1124Config - - -if is_flash_attn_2_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward - - -logger = logging.get_logger(__name__) - -_CONFIG_FOR_DOC = "Olmo1124Config" - - -# Copied from transformers.models.olmo.modeling_olmo.OlmoLayerNorm with Olmo->Olmo1124 -class Olmo1124LayerNorm(nn.Module): - """LayerNorm but with no learnable weight or bias.""" - - def __init__(self, hidden_size: int) -> None: - super().__init__() - self.normalized_shape = (hidden_size,) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - orig_dtype = hidden_states.dtype - return F.layer_norm(hidden_states.to(dtype=torch.float32), self.normalized_shape, None, None, eps=1e-5).to( - orig_dtype - ) - - -ALL_LAYERNORM_LAYERS.append(Olmo1124LayerNorm) - - -# copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Olmo1124 -# TODO(joao): add me back asap :) -# Copied from transformers.models.olmo.modeling_olmo.OlmoRotaryEmbedding with Olmo->Olmo1124 -class Olmo1124RotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): - super().__init__() - self.scaling_factor = scaling_factor - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - # For BC we register cos and sin cached - self.max_seq_len_cached = max_position_embeddings - - @torch.no_grad() - def forward(self, x, position_ids): - # x: [bs, num_attention_heads, seq_len, head_size] - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 since bfloat16 loses precision on long contexts - # See https://github.com/huggingface/transformers/pull/29285 - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - -# copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Olmo1124 -# TODO(joao): add me back asap :) -# Copied from transformers.models.olmo.modeling_olmo.OlmoLinearScalingRotaryEmbedding with Olmo->Olmo1124 -class Olmo1124LinearScalingRotaryEmbedding(Olmo1124RotaryEmbedding): - """Olmo1124RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" - - def forward(self, x, position_ids): - # difference to the original RoPE: a scaling factor is aplied to the position ids - position_ids = position_ids.float() / self.scaling_factor - cos, sin = super().forward(x, position_ids) - return cos, sin - - -# copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Olmo1124 -# TODO(joao): add me back asap :) -# Copied from transformers.models.olmo.modeling_olmo.OlmoDynamicNTKScalingRotaryEmbedding with Olmo->Olmo1124 -class Olmo1124DynamicNTKScalingRotaryEmbedding(Olmo1124RotaryEmbedding): - """Olmo1124RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" - - def forward(self, x, position_ids): - # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length - seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_position_embeddings: - base = self.base * ( - (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) - ) ** (self.dim / (self.dim - 2)) - inv_freq = 1.0 / ( - base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim) - ) - self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation - - cos, sin = super().forward(x, position_ids) - return cos, sin - - -# Copied from transformers.models.llama.modeling_llama.rotate_half -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. - """ - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -# Copied from transformers.models.olmo.modeling_olmo.OlmoMLP with Olmo->Olmo1124 -class Olmo1124MLP(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - - -# Copied from transformers.models.llama.modeling_llama.repeat_kv -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -# Copied from transformers.models.olmo.modeling_olmo.OlmoAttention with Olmo->Olmo1124 -class Olmo1124Attention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - # copied from transformers.models.llama.modeling_llama.LlamaAttention.__init__ with Llama->Olmo1124 - # TODO(joao): add me back asap :) - def __init__(self, config: Olmo1124Config, layer_idx: Optional[int] = None): - super().__init__() - self.config = config - self.layer_idx = layer_idx - if layer_idx is None: - logger.warning_once( - f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " - "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) - - self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - self.is_causal = True - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" - f" and `num_heads`: {self.num_heads})." - ) - - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) - self._init_rope() - - def _init_rope(self): - if self.config.rope_scaling is None: - self.rotary_emb = Olmo1124RotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - base=self.rope_theta, - ) - else: - scaling_type = self.config.rope_scaling["type"] - scaling_factor = self.config.rope_scaling["factor"] - if scaling_type == "linear": - self.rotary_emb = Olmo1124LinearScalingRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - scaling_factor=scaling_factor, - base=self.rope_theta, - ) - elif scaling_type == "dynamic": - self.rotary_emb = Olmo1124DynamicNTKScalingRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - scaling_factor=scaling_factor, - base=self.rope_theta, - ) - else: - raise ValueError(f"Unknown RoPE scaling type {scaling_type}") - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - if self.config.clip_qkv is not None: - query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) - key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) - value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -# Copied from transformers.models.olmo.modeling_olmo.OlmoFlashAttention2 with Olmo->Olmo1124,OLMo->OLMo November 2024 -class Olmo1124FlashAttention2(Olmo1124Attention): - """ - OLMo November 2024 flash attention module. This module inherits from `Olmo1124Attention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - output_attentions = False - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - if self.config.clip_qkv is not None: - query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) - key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) - value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) - - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - dropout_rate = self.attention_dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (Olmo1124RMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - position_ids=position_ids, - dropout=dropout_rate, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - is_causal=self.is_causal, - ) - - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -# Copied from transformers.models.olmo.modeling_olmo.OlmoSdpaAttention with Olmo->Olmo1124,OLMo->OLMo November 2024 -class Olmo1124SdpaAttention(Olmo1124Attention): - """ - OLMo November 2024 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from - `Olmo1124Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to - SDPA API. - """ - - # Adapted from Olmo1124Attention.forward - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if output_attentions: - # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "Olmo1124Model is using Olmo1124SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - if self.config.clip_qkv is not None: - query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) - key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) - value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - causal_mask = attention_mask - # if attention_mask is not None and cache_position is not None: - if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and causal_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.attention_dropout if self.training else 0.0, - is_causal=is_causal, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - -OLMO_1124_ATTENTION_CLASSES = { - "eager": Olmo1124Attention, - "flash_attention_2": Olmo1124FlashAttention2, - "sdpa": Olmo1124SdpaAttention, -} - - -# Copied from transformers.models.olmo.modeling_olmo.OlmoDecoderLayer with OLMO->OLMO_1124,Olmo->Olmo1124 -class Olmo1124DecoderLayer(nn.Module): - def __init__(self, config: Olmo1124Config, layer_idx: int): - super().__init__() - self.hidden_size = config.hidden_size - - self.self_attn = OLMO_1124_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) - - self.mlp = Olmo1124MLP(config) - self.input_layernorm = Olmo1124LayerNorm(config.hidden_size) - self.post_attention_layernorm = Olmo1124LayerNorm(config.hidden_size) - - # copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward - # TODO(joao): add me back asap :) - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - **kwargs, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: +class Olmo1124RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence - kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model + Olmo1124RMSNorm is equivalent to T5LayerNorm """ - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - **kwargs, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - -OLMO_1124_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`Olmo1124Config`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - - -@add_start_docstrings( - "The bare Olmo1124 Model outputting raw hidden-states without any specific head on top.", - OLMO_1124_START_DOCSTRING, -) -# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->Olmo1124 -class Olmo1124PreTrainedModel(PreTrainedModel): - config_class = Olmo1124Config - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["Olmo1124DecoderLayer"] - _skip_keys_device_placement = ["past_key_values"] - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_cache_class = True - _supports_quantized_cache = True - _supports_static_cache = True - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - -OLMO_1124_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - If `past_key_values` is used, optionally only the last `input_ids` have to be input (see - `past_key_values`). - - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] - and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more - information on the default strategy. - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. - - [What are position IDs?](../glossary#position-ids) - past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): - Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` - returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. - - Two formats are allowed: - - a [`~cache_utils.Cache`] instance, see our - [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); - - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of - shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy - cache format. - - The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the - legacy cache format will be returned. - - If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't - have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` - of shape `(batch_size, sequence_length)`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, - this tensor is not affected by padding. It is used to update the cache in the correct position and to infer - the complete sequence length. -""" - - -@add_start_docstrings( - "The bare Olmo1124 Model outputting raw hidden-states without any specific head on top.", - OLMO_1124_START_DOCSTRING, -) -# Copied from transformers.models.olmo.modeling_olmo.OlmoModel with OLMO->OLMO_1124,Olmo->Olmo1124 -class Olmo1124Model(Olmo1124PreTrainedModel): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Olmo1124DecoderLayer`] - - Args: - config: Olmo1124Config - """ - - def __init__(self, config: Olmo1124Config): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList( - [Olmo1124DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) - self.norm = Olmo1124LayerNorm(config.hidden_size) - self.gradient_checkpointing = False - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - @add_start_docstrings_to_model_forward(OLMO_1124_INPUTS_DOCSTRING) - # copied from transformers.models.llama.modeling_llama.LlamaModel.forward - # TODO(joao): add me back asap :) - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - if self.gradient_checkpointing and self.training and use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." - ) - use_cache = False - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - # kept for BC (non `Cache` `past_key_values` inputs) - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - if past_key_values is None: - past_key_values = DynamicCache() - else: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - logger.warning_once( - "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " - "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " - "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" - ) - - if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) - if position_ids is None: - position_ids = cache_position.unsqueeze(0) - - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions - ) - - # embed positions - hidden_states = inputs_embeds - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = None - - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - causal_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - if return_legacy_cache: - next_cache = next_cache.to_legacy_cache() - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - def _update_causal_mask( - self, - attention_mask: torch.Tensor, - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool, - ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: - return attention_mask - return None - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - is_training=self.training, - ): - return None - - dtype, device = input_tensor.dtype, input_tensor.device - sequence_length = input_tensor.shape[1] - if using_static_cache: - target_length = past_key_values.get_max_cache_shape() - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - device=device, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type == "cuda" - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - min_dtype = torch.finfo(dtype).min - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask - - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - device: torch.device, - cache_position: torch.Tensor, - batch_size: int, - **kwargs, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape - `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, - to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to plcae the 4D attention mask on. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device - ) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - - -# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->OLMO_1124,Llama->Olmo1124 -class Olmo1124ForCausalLM(Olmo1124PreTrainedModel, GenerationMixin): - _tied_weights_keys = ["lm_head.weight"] - - def __init__(self, config): - super().__init__(config) - self.model = Olmo1124Model(config) - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - @add_start_docstrings_to_model_forward(OLMO_1124_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) - # Ignore copy - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, - **loss_kwargs, - ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, Olmo1124ForCausalLM - - >>> model = Olmo1124ForCausalLM.from_pretrained("allenai/OLMo-7B-1124-hf") - >>> tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-7B-1124-hf") - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - 'Hey, are you conscious? Can you talk to me?\nI’m not sure if you’re conscious of this, but I’m' - ``` - """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - ) - - hidden_states = outputs[0] - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) - - loss = None - if labels is not None: - loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" diff --git a/src/transformers/models/olmo_1124/modular_olmo_1124.py b/src/transformers/models/olmo_1124/modular_olmo_1124.py index f2d269796a3a6f..702c971416c843 100644 --- a/src/transformers/models/olmo_1124/modular_olmo_1124.py +++ b/src/transformers/models/olmo_1124/modular_olmo_1124.py @@ -1,4 +1,6 @@ +from ...pytorch_utils import ALL_LAYERNORM_LAYERS from ...utils import logging +from ..llama.modeling_llama import LlamaRMSNorm from ..olmo.configuration_olmo import OlmoConfig @@ -58,3 +60,10 @@ def __init__( self.rms_norm_eps = rms_norm_eps del self.clip_qkv + + +class Olmo1124RMSNorm(LlamaRMSNorm): + pass + + +ALL_LAYERNORM_LAYERS.append(Olmo1124RMSNorm) From 916a8417b0c8c4f9880de5ee26f22884941b2798 Mon Sep 17 00:00:00 2001 From: Shane A Date: Thu, 31 Oct 2024 16:01:06 -0700 Subject: [PATCH 04/22] Add flash attention with qk norm and no qkv clipping --- .../olmo_1124/configuration_olmo_1124.py | 1 + .../models/olmo_1124/modeling_olmo_1124.py | 440 ++++++++++++++++++ .../models/olmo_1124/modular_olmo_1124.py | 263 ++++++++++- 3 files changed, 703 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/olmo_1124/configuration_olmo_1124.py b/src/transformers/models/olmo_1124/configuration_olmo_1124.py index 10ac098e6d8a9c..7f29b5586a0382 100644 --- a/src/transformers/models/olmo_1124/configuration_olmo_1124.py +++ b/src/transformers/models/olmo_1124/configuration_olmo_1124.py @@ -5,6 +5,7 @@ # modular_olmo_1124.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 + from ...configuration_utils import PretrainedConfig diff --git a/src/transformers/models/olmo_1124/modeling_olmo_1124.py b/src/transformers/models/olmo_1124/modeling_olmo_1124.py index fad03273e9fb93..5e052d6e6f1c08 100644 --- a/src/transformers/models/olmo_1124/modeling_olmo_1124.py +++ b/src/transformers/models/olmo_1124/modeling_olmo_1124.py @@ -4,11 +4,26 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_olmo_1124.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +import math +from typing import Optional, Tuple import torch import torch.utils.checkpoint from torch import nn +from ...cache_utils import Cache +from ...modeling_flash_attention_utils import _flash_attention_forward +from ...utils import ( + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, +) +from .configuration_olmo_1124 import Olmo1124Config + + +if is_flash_attn_2_available(): + from ...modeling_flash_attention_utils import _flash_attention_forward + class Olmo1124RMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): @@ -28,3 +43,428 @@ def forward(self, hidden_states): def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +logger = logging.get_logger(__name__) + + +# copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Olmo1124 +# TODO(joao): add me back asap :) +class Olmo1124RotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + super().__init__() + self.scaling_factor = scaling_factor + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + # For BC we register cos and sin cached + self.max_seq_len_cached = max_position_embeddings + + @torch.no_grad() + def forward(self, x, position_ids): + # x: [bs, num_attention_heads, seq_len, head_size] + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Olmo1124 +# TODO(joao): add me back asap :) +class Olmo1124LinearScalingRotaryEmbedding(Olmo1124RotaryEmbedding): + """Olmo1124RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def forward(self, x, position_ids): + # difference to the original RoPE: a scaling factor is aplied to the position ids + position_ids = position_ids.float() / self.scaling_factor + cos, sin = super().forward(x, position_ids) + return cos, sin + + +# copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Olmo1124 +# TODO(joao): add me back asap :) +class Olmo1124DynamicNTKScalingRotaryEmbedding(Olmo1124RotaryEmbedding): + """Olmo1124RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def forward(self, x, position_ids): + # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / ( + base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation + + cos, sin = super().forward(x, position_ids) + return cos, sin + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class Olmo1124Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + # copied from transformers.models.llama.modeling_llama.LlamaAttention.__init__ with Llama->Olmo1124 + # TODO(joao): add me back asap :) + def __init__(self, config: Olmo1124Config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) + self.q_norm = Olmo1124RMSNorm(self.num_heads * self.head_dim, config.rms_norm_eps) + self.k_norm = Olmo1124RMSNorm(self.num_key_value_heads * self.head_dim, config.rms_norm_eps) + self._init_rope() + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = Olmo1124RotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = Olmo1124LinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = Olmo1124DynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_norm(self.q_proj(hidden_states)) + key_states = self.k_norm(self.k_proj(hidden_states)) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class Olmo1124FlashAttention2(Olmo1124Attention): + """ + Olmo1124 flash attention module. This module inherits from `Olmo1124Attention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + + OLMo November 2024 flash attention module. This module inherits from `Olmo1124Attention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_norm(self.q_proj(hidden_states)) + key_states = self.k_norm(self.k_proj(hidden_states)) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (OlmoRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class Olmo1124SdpaAttention(Olmo1124Attention): + """ + Olmo1124 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `Olmo1124Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from Olmo1124Attention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "Olmo1124Model is using Olmo1124SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + bsz, q_len, _ = hidden_states.size() + query_states = self.q_norm(self.q_proj(hidden_states)) + key_states = self.k_norm(self.k_proj(hidden_states)) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + causal_mask = attention_mask + # if attention_mask is not None and cache_position is not None: + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + is_causal = True if causal_mask is None and q_len > 1 else False + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) + return attn_output, None, past_key_value diff --git a/src/transformers/models/olmo_1124/modular_olmo_1124.py b/src/transformers/models/olmo_1124/modular_olmo_1124.py index 702c971416c843..01f8126565dc41 100644 --- a/src/transformers/models/olmo_1124/modular_olmo_1124.py +++ b/src/transformers/models/olmo_1124/modular_olmo_1124.py @@ -1,9 +1,26 @@ +import math +from typing import Optional, Tuple + +import torch +from torch import nn + +from ...cache_utils import Cache from ...pytorch_utils import ALL_LAYERNORM_LAYERS -from ...utils import logging +from ...utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging from ..llama.modeling_llama import LlamaRMSNorm from ..olmo.configuration_olmo import OlmoConfig +from ..olmo.modeling_olmo import ( + OlmoAttention, + OlmoFlashAttention2, + OlmoSdpaAttention, + apply_rotary_pos_emb, + repeat_kv, +) +if is_flash_attn_2_available(): + from ...modeling_flash_attention_utils import _flash_attention_forward + logger = logging.get_logger(__name__) @@ -67,3 +84,247 @@ class Olmo1124RMSNorm(LlamaRMSNorm): ALL_LAYERNORM_LAYERS.append(Olmo1124RMSNorm) + + +# Olmo1124 attention is identical to OLMo attention except: +# - Norm is applied to attention queries and keys. +# - No qkv clipping. +class Olmo1124Attention(OlmoAttention): + def __init__(self, config: Olmo1124Config, layer_idx: Optional[int] = None): + super().__init__(config, layer_idx=layer_idx) + self.q_norm = Olmo1124RMSNorm(self.num_heads * self.head_dim, config.rms_norm_eps) + self.k_norm = Olmo1124RMSNorm(self.num_key_value_heads * self.head_dim, config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_norm(self.q_proj(hidden_states)) + key_states = self.k_norm(self.k_proj(hidden_states)) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class Olmo1124FlashAttention2(OlmoFlashAttention2, Olmo1124Attention): + """ + OLMo November 2024 flash attention module. This module inherits from `Olmo1124Attention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + Olmo1124Attention.__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_norm(self.q_proj(hidden_states)) + key_states = self.k_norm(self.k_proj(hidden_states)) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (OlmoRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class Olmo1124SdpaAttention(OlmoSdpaAttention, Olmo1124Attention): + # Adapted from Olmo1124Attention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "Olmo1124Model is using Olmo1124SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + bsz, q_len, _ = hidden_states.size() + query_states = self.q_norm(self.q_proj(hidden_states)) + key_states = self.k_norm(self.k_proj(hidden_states)) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + causal_mask = attention_mask + # if attention_mask is not None and cache_position is not None: + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + is_causal = True if causal_mask is None and q_len > 1 else False + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) + return attn_output, None, past_key_value From c19d622de2b7000d4764975295cb47515cf8b1c9 Mon Sep 17 00:00:00 2001 From: Shane A Date: Thu, 31 Oct 2024 16:02:43 -0700 Subject: [PATCH 05/22] Add decoder layer with RMSNorm after attention/feedforward layers --- .../models/olmo_1124/modeling_olmo_1124.py | 96 +++++++++++++++++++ .../models/olmo_1124/modular_olmo_1124.py | 51 ++++++++++ 2 files changed, 147 insertions(+) diff --git a/src/transformers/models/olmo_1124/modeling_olmo_1124.py b/src/transformers/models/olmo_1124/modeling_olmo_1124.py index 5e052d6e6f1c08..4ed50006675fde 100644 --- a/src/transformers/models/olmo_1124/modeling_olmo_1124.py +++ b/src/transformers/models/olmo_1124/modeling_olmo_1124.py @@ -11,6 +11,7 @@ import torch.utils.checkpoint from torch import nn +from ...activations import ACT2FN from ...cache_utils import Cache from ...modeling_flash_attention_utils import _flash_attention_forward from ...utils import ( @@ -286,6 +287,21 @@ def forward( return attn_output, attn_weights, past_key_value +class Olmo1124MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + class Olmo1124FlashAttention2(Olmo1124Attention): """ Olmo1124 flash attention module. This module inherits from `Olmo1124Attention` as the weights of the module stays @@ -468,3 +484,83 @@ def forward( attn_output = attn_output.view(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) return attn_output, None, past_key_value + + +OLMO_1124_ATTENTION_CLASSES = { + "eager": Olmo1124Attention, + "flash_attention_2": Olmo1124FlashAttention2, + "sdpa": Olmo1124SdpaAttention, +} + + +class Olmo1124DecoderLayer(nn.Module): + def __init__(self, config: Olmo1124Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = OLMO_1124_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + + self.mlp = Olmo1124MLP(config) + self.input_layernorm = Olmo1124RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Olmo1124RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + # copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward + # TODO(joao): add me back asap :) + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = self.input_layernorm(hidden_states) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) + if use_cache: + outputs += (present_key_value,) + return outputs diff --git a/src/transformers/models/olmo_1124/modular_olmo_1124.py b/src/transformers/models/olmo_1124/modular_olmo_1124.py index 01f8126565dc41..d2c667490a1887 100644 --- a/src/transformers/models/olmo_1124/modular_olmo_1124.py +++ b/src/transformers/models/olmo_1124/modular_olmo_1124.py @@ -11,6 +11,7 @@ from ..olmo.configuration_olmo import OlmoConfig from ..olmo.modeling_olmo import ( OlmoAttention, + OlmoDecoderLayer, OlmoFlashAttention2, OlmoSdpaAttention, apply_rotary_pos_emb, @@ -328,3 +329,53 @@ def forward( attn_output = attn_output.view(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) return attn_output, None, past_key_value + + +# The OLMo November 2024 layers are identical to those of the OLMo model except: +# - RMSNorm is used instead of standard layer norm. +# - Norm is applied after attention/feedforward rather than before. +class Olmo1124DecoderLayer(OlmoDecoderLayer): + def __init__(self, config: Olmo1124Config, layer_idx: int): + super().__init__(config, layer_idx=layer_idx) + self.input_layernorm = Olmo1124RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Olmo1124RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = self.input_layernorm(hidden_states) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) + if use_cache: + outputs += (present_key_value,) + return outputs From ebec833da007c2aaa4653828977ad12ffc50c544 Mon Sep 17 00:00:00 2001 From: Shane A Date: Thu, 31 Oct 2024 16:03:33 -0700 Subject: [PATCH 06/22] Add base and causal model --- .../models/olmo_1124/modeling_olmo_1124.py | 535 +++++++++++++++++- .../models/olmo_1124/modular_olmo_1124.py | 20 + 2 files changed, 553 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/olmo_1124/modeling_olmo_1124.py b/src/transformers/models/olmo_1124/modeling_olmo_1124.py index 4ed50006675fde..42f8a5df902b6a 100644 --- a/src/transformers/models/olmo_1124/modeling_olmo_1124.py +++ b/src/transformers/models/olmo_1124/modeling_olmo_1124.py @@ -5,19 +5,29 @@ # modular_olmo_1124.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 import math -from typing import Optional, Tuple +from typing import List, Optional, Tuple, Union import torch import torch.utils.checkpoint from torch import nn from ...activations import ACT2FN -from ...cache_utils import Cache +from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import _flash_attention_forward +from ...modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +from ...modeling_utils import PreTrainedModel from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging, + replace_return_docstrings, ) from .configuration_olmo_1124 import Olmo1124Config @@ -564,3 +574,524 @@ def forward( if use_cache: outputs += (present_key_value,) return outputs + + +OLMO_1124_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Olmo1124Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Olmo1124 Model outputting raw hidden-states without any specific head on top.", + OLMO_1124_START_DOCSTRING, +) +class Olmo1124PreTrainedModel(PreTrainedModel): + config_class = Olmo1124Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Olmo1124DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +_CONFIG_FOR_DOC = "Olmo1124Config" + + +OLMO_1124_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare Olmo1124 Model outputting raw hidden-states without any specific head on top.", + OLMO_1124_START_DOCSTRING, +) +class Olmo1124Model(Olmo1124PreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Olmo1124DecoderLayer`] + + Args: + config: Olmo1124Config + """ + + def __init__(self, config: Olmo1124Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Olmo1124DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Olmo1124RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(OLMO_1124_INPUTS_DOCSTRING) + # copied from transformers.models.llama.modeling_llama.LlamaModel.forward + # TODO(joao): add me back asap :) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + if past_key_values is None: + past_key_values = DynamicCache() + else: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " + "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " + "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" + ) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +class Olmo1124ForCausalLM(Olmo1124PreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: Olmo1124Config): + super().__init__(config) + self.model = Olmo1124Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(OLMO_1124_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + # Ignore copy + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + **loss_kwargs, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, Olmo1124ForCausalLM + + >>> model = Olmo1124ForCausalLM.from_pretrained("allenai/Olmo1124-1B-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("allenai/Olmo1124-1B-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + 'Hey, are you conscious? Can you talk to me?\nI’m not sure if you’re conscious of this, but I’m' + ``` + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/transformers/models/olmo_1124/modular_olmo_1124.py b/src/transformers/models/olmo_1124/modular_olmo_1124.py index d2c667490a1887..0febeb3d60c84e 100644 --- a/src/transformers/models/olmo_1124/modular_olmo_1124.py +++ b/src/transformers/models/olmo_1124/modular_olmo_1124.py @@ -13,6 +13,8 @@ OlmoAttention, OlmoDecoderLayer, OlmoFlashAttention2, + OlmoForCausalLM, + OlmoModel, OlmoSdpaAttention, apply_rotary_pos_emb, repeat_kv, @@ -379,3 +381,21 @@ def forward( if use_cache: outputs += (present_key_value,) return outputs + + +# The OLMo November 2024 model is identical to the OLMo model, except RMSNorm is used instead of +# standard layer norm for the output norm. +class Olmo1124Model(OlmoModel): + def __init__(self, config: Olmo1124Config): + super().__init__(config) + self.layers = nn.ModuleList( + [Olmo1124DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = Olmo1124RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + +# The heads now only need to redefine the model inside to the correct `RobertaModel` +class Olmo1124ForCausalLM(OlmoForCausalLM): + def __init__(self, config: Olmo1124Config): + super().__init__(config) + self.model = Olmo1124Model(config) \ No newline at end of file From ee2ad8a97db66f8fda089ea6eeef3411f04faa93 Mon Sep 17 00:00:00 2001 From: Shane A Date: Thu, 31 Oct 2024 16:07:05 -0700 Subject: [PATCH 07/22] Add converter improvements from OLMo repo --- .../convert_olmo_1124_weights_to_hf.py | 66 +++++++++++++++---- 1 file changed, 53 insertions(+), 13 deletions(-) diff --git a/src/transformers/models/olmo_1124/convert_olmo_1124_weights_to_hf.py b/src/transformers/models/olmo_1124/convert_olmo_1124_weights_to_hf.py index c021940b7e8b36..cb1783290f3a84 100644 --- a/src/transformers/models/olmo_1124/convert_olmo_1124_weights_to_hf.py +++ b/src/transformers/models/olmo_1124/convert_olmo_1124_weights_to_hf.py @@ -17,6 +17,7 @@ import os import shutil from pathlib import Path +from typing import Any, Dict import torch import yaml @@ -62,7 +63,15 @@ def write_json(text, path): json.dump(text, f) -def write_model(model_path, input_base_path, tokenizer_path=None, safe_serialization=True, fix_eos_token_id=True): +def write_model( + model_path, + input_base_path, + include_tokenizer=True, + tokenizer_path=None, + safe_serialization=True, + fix_eos_token_id=True, + tmp_cleanup=True, +): os.makedirs(model_path, exist_ok=True) tmp_model_path = os.path.join(model_path, "tmp") os.makedirs(tmp_model_path, exist_ok=True) @@ -94,7 +103,7 @@ def write_model(model_path, input_base_path, tokenizer_path=None, safe_serializa loaded = torch.load(os.path.join(input_base_path, "model.pt"), map_location="cpu") param_count = 0 - index_dict = {"weight_map": {}} + index_dict: Dict[str, Any] = {"weight_map": {}} for layer_i in range(n_layers): filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin" # Unsharded @@ -149,6 +158,11 @@ def write_model(model_path, input_base_path, tokenizer_path=None, safe_serializa else: intermediate_size = (dim * olmo_1124_config["mlp_ratio"]) // 2 + if fix_eos_token_id and olmo_1124_config["eos_token_id"] == 0: + # Fixing a bug in OLMo where eos token id was incorrectly set + print("Changing eos_token_id from 0 to 50279.") + olmo_1124_config["eos_token_id"] = 50279 + config = Olmo1124Config( vocab_size=vocab_size, hidden_size=dim, @@ -171,8 +185,8 @@ def write_model(model_path, input_base_path, tokenizer_path=None, safe_serializa del loaded gc.collect() - if tokenizer_path is not None: - _write_tokenizer(model_path, config, tokenizer_path, fix_eos_token_id) + if include_tokenizer: + _write_tokenizer(model_path, config, input_base_path, tokenizer_path) print("Loading the checkpoint in a OLMo November 2024 model.") model = Olmo1124ForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.float32, low_cpu_mem_usage=True) @@ -180,24 +194,35 @@ def write_model(model_path, input_base_path, tokenizer_path=None, safe_serializa del model.config._name_or_path print("Saving in the Transformers format.") model.save_pretrained(model_path, safe_serialization=safe_serialization) - shutil.rmtree(tmp_model_path) + if tmp_cleanup: + # Make cleanup optional; attempting to `rmtree` the `tmp_model_path` causes + # errors if using NFS. + shutil.rmtree(tmp_model_path) def _write_tokenizer( - output_path: Path, config: Olmo1124Config, input_tokenizer_path: Path, fix_eos_token_id: bool = True + output_path: Path, + config: Olmo1124Config, + checkpoint_dir: str, + input_tokenizer_path: Path | None, ) -> None: print(f"Saving a {GPTNeoXTokenizerFast.__name__} to {output_path}.") - base_tokenizer = Tokenizer.from_file(str(input_tokenizer_path)) + if input_tokenizer_path is not None: + base_tokenizer = Tokenizer.from_file(str(input_tokenizer_path)) + else: + config_path = Path(checkpoint_dir) / "config.yaml" + tokenizer_config = yaml.safe_load(config_path.read_text())["tokenizer"] + + # Initialize tokenizer and validate vocab size. + if Path(tokenizer_config["identifier"]).is_file(): + base_tokenizer = Tokenizer.from_file(tokenizer_config["identifier"]) + else: + base_tokenizer = Tokenizer.from_pretrained(tokenizer_config["identifier"]) eos_token_id = config.eos_token_id if config.eos_token_id is not None else base_tokenizer.get_vocab_size() - 1 pad_token_id = config.pad_token_id if config.pad_token_id is not None else eos_token_id - if fix_eos_token_id and eos_token_id == 0: - # Fixing a bug in OLMo November 2024 where eos token id was incorrectly set - print("Changing eos_token_id from 0 to 50279.") - eos_token_id = 50279 - tokenizer = GPTNeoXTokenizerFast( tokenizer_object=base_tokenizer, eos_token=base_tokenizer.decode([eos_token_id], skip_special_tokens=False), @@ -216,10 +241,17 @@ def main(): required=True, help="Location of OLMo November 2024 weights, which contains config.yaml and model.pt.", ) + parser.add_argument( + "--no_tokenizer", + action="store_false", + dest="include_tokenizer", + help="If set, do not convert OLMo tokenizer to HF tokenizer.", + ) parser.add_argument( "--tokenizer_json_path", + type=Path, default=None, - help="Location of OLMo November 2024 tokenizer json file.", + help="Location of OLMo November 2024 tokenizer json file. Defaults to what is set in the config file.", ) parser.add_argument( "--output_dir", @@ -232,6 +264,12 @@ def main(): dest="fix_eos_token_id", help="If set, does not change eos token id from 0 to 50279 if it is 0. Changing 0 to 50279 is a bug fix, so use this option with care.", ) + parser.add_argument( + "--no_tmp_cleanup", + action="store_false", + dest="tmp_cleanup", + help="If passed, don't remove temp dir at end of HF conversion.", + ) parser.add_argument("--safe_serialization", type=bool, help="Whether or not to save using `safetensors`.") # Different OLMo November 2024 versions used different default values for max_position_embeddings, hence the need to be able to specify which version is being used. args = parser.parse_args() @@ -239,8 +277,10 @@ def main(): model_path=args.output_dir, input_base_path=args.input_dir, safe_serialization=args.safe_serialization, + include_tokenizer=args.include_tokenizer, tokenizer_path=args.tokenizer_json_path, fix_eos_token_id=args.fix_eos_token_id, + tmp_cleanup=args.tmp_cleanup, ) From 84f72ba0ea35dc3c4dcde154c37b2a97c2ada032 Mon Sep 17 00:00:00 2001 From: Shane A Date: Thu, 31 Oct 2024 16:12:28 -0700 Subject: [PATCH 08/22] Update weight loading in OLMo to HF converter --- .../olmo_1124/convert_olmo_1124_weights_to_hf.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/olmo_1124/convert_olmo_1124_weights_to_hf.py b/src/transformers/models/olmo_1124/convert_olmo_1124_weights_to_hf.py index cb1783290f3a84..34f0b01cdb5ec6 100644 --- a/src/transformers/models/olmo_1124/convert_olmo_1124_weights_to_hf.py +++ b/src/transformers/models/olmo_1124/convert_olmo_1124_weights_to_hf.py @@ -79,11 +79,16 @@ def write_model( config_path = Path(input_base_path) / "config.yaml" olmo_1124_config = yaml.safe_load(config_path.read_text())["model"] + if not olmo_1124_config.get("attention_layer_norm", False): + raise RuntimeError("OLMo November 2024 checkpoints must have attention layer norm") + if not olmo_1124_config.get("norm_after", False): + raise RuntimeError("OLMo November 2024 checkpoints must set norm_after to True") + n_layers = olmo_1124_config["n_layers"] n_heads = olmo_1124_config["n_heads"] dim = olmo_1124_config["d_model"] dims_per_head = dim // n_heads - base = 10000.0 + base = olmo_1124_config["rope_theta"] inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head)) max_position_embeddings = olmo_1124_config["max_sequence_length"] @@ -121,9 +126,15 @@ def write_model( f"model.layers.{layer_i}.self_attn.k_proj.weight": k_proj_weight, f"model.layers.{layer_i}.self_attn.v_proj.weight": v_proj_weight, f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"transformer.blocks.{layer_i}.attn_out.weight"], + f"model.layers.{layer_i}.self_attn.q_norm.weight": loaded[f"transformer.blocks.{layer_i}.q_norm.weight"], + f"model.layers.{layer_i}.self_attn.k_norm.weight": loaded[f"transformer.blocks.{layer_i}.k_norm.weight"], f"model.layers.{layer_i}.mlp.gate_proj.weight": gate_proj_weight, f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"transformer.blocks.{layer_i}.ff_out.weight"], f"model.layers.{layer_i}.mlp.up_proj.weight": up_proj_weight, + f"model.layers.{layer_i}.input_layernorm.weight": loaded[f"transformer.blocks.{layer_i}.attn_norm.weight"], + f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[ + f"transformer.blocks.{layer_i}.ff_norm.weight" + ], } state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq @@ -139,6 +150,7 @@ def write_model( # TODO: Deal with weight-tying state_dict = { "model.embed_tokens.weight": loaded["transformer.wte.weight"], + "model.norm.weight": loaded["transformer.ln_f.weight"], "lm_head.weight": loaded["transformer.ff_out.weight"] if "transformer.ff_out.weight" in loaded else loaded["transformer.wte.weight"], @@ -175,8 +187,8 @@ def write_model( bos_token_id=None, eos_token_id=olmo_1124_config["eos_token_id"], tie_word_embeddings=olmo_1124_config["weight_tying"], + rms_norm_eps=olmo_1124_config["layer_norm_eps"], rope_theta=base, - clip_qkv=olmo_1124_config.get("clip_qkv"), ) config.save_pretrained(tmp_model_path) From 68c67638cb0b748ce582fce6ca832b1f57664993 Mon Sep 17 00:00:00 2001 From: Shane A Date: Thu, 31 Oct 2024 16:12:52 -0700 Subject: [PATCH 09/22] Set correct default for rms_norm_eps --- src/transformers/models/olmo_1124/configuration_olmo_1124.py | 2 +- src/transformers/models/olmo_1124/modular_olmo_1124.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/olmo_1124/configuration_olmo_1124.py b/src/transformers/models/olmo_1124/configuration_olmo_1124.py index 7f29b5586a0382..31eb25ebb327a0 100644 --- a/src/transformers/models/olmo_1124/configuration_olmo_1124.py +++ b/src/transformers/models/olmo_1124/configuration_olmo_1124.py @@ -113,7 +113,7 @@ def __init__( rope_scaling=None, attention_bias=False, attention_dropout=0.0, - rms_norm_eps=1e-6, + rms_norm_eps=1e-5, **kwargs, ): self.vocab_size = vocab_size diff --git a/src/transformers/models/olmo_1124/modular_olmo_1124.py b/src/transformers/models/olmo_1124/modular_olmo_1124.py index 0febeb3d60c84e..f72c2b59b06cee 100644 --- a/src/transformers/models/olmo_1124/modular_olmo_1124.py +++ b/src/transformers/models/olmo_1124/modular_olmo_1124.py @@ -53,7 +53,7 @@ def __init__( rope_scaling=None, attention_bias=False, attention_dropout=0.0, - rms_norm_eps=1e-6, + rms_norm_eps=1e-5, **kwargs, ): super().__init__( From 8e5a74af9996a0f32d48743aeeafa6d8c713d1eb Mon Sep 17 00:00:00 2001 From: Shane A Date: Thu, 31 Oct 2024 16:14:29 -0700 Subject: [PATCH 10/22] Set correct pipeline_model_mapping in test --- tests/models/olmo_1124/test_modeling_olmo_1124.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/models/olmo_1124/test_modeling_olmo_1124.py b/tests/models/olmo_1124/test_modeling_olmo_1124.py index af56c475a3005b..f732276ebbc8f9 100644 --- a/tests/models/olmo_1124/test_modeling_olmo_1124.py +++ b/tests/models/olmo_1124/test_modeling_olmo_1124.py @@ -275,6 +275,14 @@ def prepare_config_and_inputs_for_common(self): class Olmo1124ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = (Olmo1124Model, Olmo1124ForCausalLM) if is_torch_available() else () all_generative_model_classes = (Olmo1124ForCausalLM,) if is_torch_available() else () + pipeline_model_mapping = ( + { + "feature-extraction": Olmo1124Model, + "text-generation": Olmo1124ForCausalLM, + } + if is_torch_available() + else {} + ) test_pruning = False fx_compatible = False From a5f92c2b69574ed4d38e792412daa89a1849d464 Mon Sep 17 00:00:00 2001 From: Shane A Date: Thu, 31 Oct 2024 16:50:44 -0700 Subject: [PATCH 11/22] Run make fixup --- docs/source/en/index.md | 1 + .../olmo_1124/configuration_olmo_1124.py | 4 +--- .../models/olmo_1124/modular_olmo_1124.py | 2 +- src/transformers/utils/dummy_pt_objects.py | 21 +++++++++++++++++++ 4 files changed, 24 insertions(+), 4 deletions(-) diff --git a/docs/source/en/index.md b/docs/source/en/index.md index aaff45ab65dfb6..5cb80aac80d828 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -240,6 +240,7 @@ Flax), PyTorch, and/or TensorFlow. | [Nougat](model_doc/nougat) | βœ… | βœ… | βœ… | | [NystrΓΆmformer](model_doc/nystromformer) | βœ… | ❌ | ❌ | | [OLMo](model_doc/olmo) | βœ… | ❌ | ❌ | +| [OLMo November 2024](model_doc/olmo-1124) | βœ… | ❌ | ❌ | | [OLMoE](model_doc/olmoe) | βœ… | ❌ | ❌ | | [OmDet-Turbo](model_doc/omdet-turbo) | βœ… | ❌ | ❌ | | [OneFormer](model_doc/oneformer) | βœ… | ❌ | ❌ | diff --git a/src/transformers/models/olmo_1124/configuration_olmo_1124.py b/src/transformers/models/olmo_1124/configuration_olmo_1124.py index 31eb25ebb327a0..b430df8378be1c 100644 --- a/src/transformers/models/olmo_1124/configuration_olmo_1124.py +++ b/src/transformers/models/olmo_1124/configuration_olmo_1124.py @@ -70,9 +70,7 @@ class Olmo1124Config(PretrainedConfig): Whether to use a bias in the query, key, value and output projection layers during self-attention. attention_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities. - clip_qkv (`float`, *optional*): - If not `None`, elements of query, key and value attention states are clipped so that their - absolute value does not exceed this value. + rms_norm_eps (``, *optional*, defaults to 1e-05): ```python >>> from transformers import Olmo1124Model, Olmo1124Config diff --git a/src/transformers/models/olmo_1124/modular_olmo_1124.py b/src/transformers/models/olmo_1124/modular_olmo_1124.py index f72c2b59b06cee..1af827d709cc9d 100644 --- a/src/transformers/models/olmo_1124/modular_olmo_1124.py +++ b/src/transformers/models/olmo_1124/modular_olmo_1124.py @@ -398,4 +398,4 @@ def __init__(self, config: Olmo1124Config): class Olmo1124ForCausalLM(OlmoForCausalLM): def __init__(self, config: Olmo1124Config): super().__init__(config) - self.model = Olmo1124Model(config) \ No newline at end of file + self.model = Olmo1124Model(config) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 36e1ff2cfe65c4..3bf6d6eb288a9a 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -6758,6 +6758,27 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class Olmo1124ForCausalLM(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Olmo1124Model(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class Olmo1124PreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class OlmoeForCausalLM(metaclass=DummyObject): _backends = ["torch"] From fe2e478fb7b64948bd150fb2c841ae05c70f35db Mon Sep 17 00:00:00 2001 From: Shane A Date: Mon, 4 Nov 2024 11:28:52 -0800 Subject: [PATCH 12/22] Fix model type --- .../models/olmo_1124/configuration_olmo_1124.py | 6 ++++-- src/transformers/models/olmo_1124/modular_olmo_1124.py | 2 ++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/olmo_1124/configuration_olmo_1124.py b/src/transformers/models/olmo_1124/configuration_olmo_1124.py index b430df8378be1c..e9877ad631d6f4 100644 --- a/src/transformers/models/olmo_1124/configuration_olmo_1124.py +++ b/src/transformers/models/olmo_1124/configuration_olmo_1124.py @@ -70,7 +70,9 @@ class Olmo1124Config(PretrainedConfig): Whether to use a bias in the query, key, value and output projection layers during self-attention. attention_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities. - rms_norm_eps (``, *optional*, defaults to 1e-05): + clip_qkv (`float`, *optional*): + If not `None`, elements of query, key and value attention states are clipped so that their + absolute value does not exceed this value. ```python >>> from transformers import Olmo1124Model, Olmo1124Config @@ -88,7 +90,7 @@ class Olmo1124Config(PretrainedConfig): The epsilon used by the rms normalization layers. """ - model_type = "olmo_1124" + model_type = "olmo-1124" keys_to_ignore_at_inference = ["past_key_values"] def __init__( diff --git a/src/transformers/models/olmo_1124/modular_olmo_1124.py b/src/transformers/models/olmo_1124/modular_olmo_1124.py index 1af827d709cc9d..6a59cff18159c6 100644 --- a/src/transformers/models/olmo_1124/modular_olmo_1124.py +++ b/src/transformers/models/olmo_1124/modular_olmo_1124.py @@ -33,6 +33,8 @@ class Olmo1124Config(OlmoConfig): The epsilon used by the rms normalization layers. """ + model_type = "olmo-1124" + def __init__( self, vocab_size=50304, From 43499385d3dc32a6a2b40a986f1fba1b69b4e48c Mon Sep 17 00:00:00 2001 From: Shane A Date: Mon, 4 Nov 2024 11:48:36 -0800 Subject: [PATCH 13/22] Re-run modular conversion --- .../olmo_1124/configuration_olmo_1124.py | 16 +++--- .../models/olmo_1124/modeling_olmo_1124.py | 49 +++++++++---------- 2 files changed, 29 insertions(+), 36 deletions(-) diff --git a/src/transformers/models/olmo_1124/configuration_olmo_1124.py b/src/transformers/models/olmo_1124/configuration_olmo_1124.py index e9877ad631d6f4..c7e5d04060843e 100644 --- a/src/transformers/models/olmo_1124/configuration_olmo_1124.py +++ b/src/transformers/models/olmo_1124/configuration_olmo_1124.py @@ -5,7 +5,6 @@ # modular_olmo_1124.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 - from ...configuration_utils import PretrainedConfig @@ -116,6 +115,13 @@ def __init__( rms_norm_eps=1e-5, **kwargs, ): + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings self.hidden_size = hidden_size @@ -139,14 +145,6 @@ def __init__( self.rms_norm_eps = rms_norm_eps - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) - def _rope_scaling_validation(self): """ Validate the `rope_scaling` configuration. diff --git a/src/transformers/models/olmo_1124/modeling_olmo_1124.py b/src/transformers/models/olmo_1124/modeling_olmo_1124.py index 42f8a5df902b6a..457335c3680120 100644 --- a/src/transformers/models/olmo_1124/modeling_olmo_1124.py +++ b/src/transformers/models/olmo_1124/modeling_olmo_1124.py @@ -8,7 +8,6 @@ from typing import List, Optional, Tuple, Union import torch -import torch.utils.checkpoint from torch import nn from ...activations import ACT2FN @@ -16,10 +15,7 @@ from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import _flash_attention_forward -from ...modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, -) +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, @@ -36,6 +32,11 @@ from ...modeling_flash_attention_utils import _flash_attention_forward +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "Olmo1124Config" + + class Olmo1124RMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ @@ -56,9 +57,6 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" -logger = logging.get_logger(__name__) - - # copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Olmo1124 # TODO(joao): add me back asap :) class Olmo1124RotaryEmbedding(nn.Module): @@ -205,9 +203,9 @@ def __init__(self, config: Olmo1124Config, layer_idx: Optional[int] = None): self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) + self._init_rope() self.q_norm = Olmo1124RMSNorm(self.num_heads * self.head_dim, config.rms_norm_eps) self.k_norm = Olmo1124RMSNorm(self.num_key_value_heads * self.head_dim, config.rms_norm_eps) - self._init_rope() def _init_rope(self): if self.config.rope_scaling is None: @@ -297,21 +295,6 @@ def forward( return attn_output, attn_weights, past_key_value -class Olmo1124MLP(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - - class Olmo1124FlashAttention2(Olmo1124Attention): """ Olmo1124 flash attention module. This module inherits from `Olmo1124Attention` as the weights of the module stays @@ -496,6 +479,21 @@ def forward( return attn_output, None, past_key_value +class Olmo1124MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + OLMO_1124_ATTENTION_CLASSES = { "eager": Olmo1124Attention, "flash_attention_2": Olmo1124FlashAttention2, @@ -621,9 +619,6 @@ def _init_weights(self, module): module.weight.data[module.padding_idx].zero_() -_CONFIG_FOR_DOC = "Olmo1124Config" - - OLMO_1124_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): From bd94d9cab9a361adfffb33987ca14ea8ea7388bf Mon Sep 17 00:00:00 2001 From: Shane A Date: Mon, 4 Nov 2024 12:28:18 -0800 Subject: [PATCH 14/22] Manually set config docs to fix build errors --- .../olmo_1124/configuration_olmo_1124.py | 13 ++-- .../models/olmo_1124/modular_olmo_1124.py | 76 ++++++++++++++++++- 2 files changed, 79 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/olmo_1124/configuration_olmo_1124.py b/src/transformers/models/olmo_1124/configuration_olmo_1124.py index c7e5d04060843e..d0c5051a832c64 100644 --- a/src/transformers/models/olmo_1124/configuration_olmo_1124.py +++ b/src/transformers/models/olmo_1124/configuration_olmo_1124.py @@ -10,7 +10,7 @@ class Olmo1124Config(PretrainedConfig): r""" - This is the configuration class to store the configuration of a [`Olmo1124Model`]. It is used to instantiate an Olmo1124 + This is the configuration class to store the configuration of a [`Olmo1124Model`]. It is used to instantiate an OLMo November 2024 model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the [allenai/Olmo1124-7B-hf](https://huggingface.co/allenai/Olmo1124-7B-hf). @@ -69,24 +69,21 @@ class Olmo1124Config(PretrainedConfig): Whether to use a bias in the query, key, value and output projection layers during self-attention. attention_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities. - clip_qkv (`float`, *optional*): - If not `None`, elements of query, key and value attention states are clipped so that their - absolute value does not exceed this value. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. ```python >>> from transformers import Olmo1124Model, Olmo1124Config - >>> # Initializing a Olmo1124 7B style configuration + >>> # Initializing a Olmo November 2024 7B style configuration >>> configuration = Olmo1124Config() - >>> # Initializing a model from the Olmo1124 7B style configuration + >>> # Initializing a model from the Olmo November 2024 7B style configuration >>> model = Olmo1124Model(configuration) >>> # Accessing the model configuration >>> configuration = model.config ``` - rms_norm_eps (`float`, *optional*, defaults to 1e-06): - The epsilon used by the rms normalization layers. """ model_type = "olmo-1124" diff --git a/src/transformers/models/olmo_1124/modular_olmo_1124.py b/src/transformers/models/olmo_1124/modular_olmo_1124.py index 6a59cff18159c6..8fecf65b59e20a 100644 --- a/src/transformers/models/olmo_1124/modular_olmo_1124.py +++ b/src/transformers/models/olmo_1124/modular_olmo_1124.py @@ -29,8 +29,80 @@ class Olmo1124Config(OlmoConfig): r""" - rms_norm_eps (`float`, *optional*, defaults to 1e-06): - The epsilon used by the rms normalization layers. + This is the configuration class to store the configuration of a [`Olmo1124Model`]. It is used to instantiate an OLMo November 2024 + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the [allenai/Olmo1124-7B-hf](https://huggingface.co/allenai/Olmo1124-7B-hf). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 50304): + Vocabulary size of the Olmo1124 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Olmo1124Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*, defaults to 1): + Padding token id. + bos_token_id (`int`, *optional*): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 50279): + End of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling + strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is + `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. See the following thread for more information on how + these scaling strategies behave: + https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an + experimental feature, subject to breaking API changes in future versions. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + + ```python + >>> from transformers import Olmo1124Model, Olmo1124Config + + >>> # Initializing a Olmo November 2024 7B style configuration + >>> configuration = Olmo1124Config() + + >>> # Initializing a model from the Olmo November 2024 7B style configuration + >>> model = Olmo1124Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` """ model_type = "olmo-1124" From 7a0cbbea5347af211d6cbe8d003a321e08ecab4d Mon Sep 17 00:00:00 2001 From: Shane A Date: Mon, 4 Nov 2024 12:36:30 -0800 Subject: [PATCH 15/22] Convert olmo-1124 to olmo_1124 to fix flash attention docs errors --- docs/source/en/_toctree.yml | 2 +- docs/source/en/index.md | 2 +- docs/source/en/model_doc/{olmo-1124.md => olmo_1124.md} | 0 docs/source/en/perf_infer_gpu_one.md | 2 ++ src/transformers/models/auto/configuration_auto.py | 4 ++-- src/transformers/models/auto/modeling_auto.py | 4 ++-- src/transformers/models/auto/tokenization_auto.py | 2 +- src/transformers/models/olmo_1124/configuration_olmo_1124.py | 2 +- src/transformers/models/olmo_1124/modular_olmo_1124.py | 2 +- 9 files changed, 11 insertions(+), 9 deletions(-) rename docs/source/en/model_doc/{olmo-1124.md => olmo_1124.md} (100%) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 193af77f5693ea..d800e40ecbd69d 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -514,7 +514,7 @@ title: NystrΓΆmformer - local: model_doc/olmo title: OLMo - - local: model_doc/olmo-1124 + - local: model_doc/olmo_1124 title: OLMo November 2024 - local: model_doc/olmoe title: OLMoE diff --git a/docs/source/en/index.md b/docs/source/en/index.md index 5cb80aac80d828..341cb417c7b8ac 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -240,7 +240,7 @@ Flax), PyTorch, and/or TensorFlow. | [Nougat](model_doc/nougat) | βœ… | βœ… | βœ… | | [NystrΓΆmformer](model_doc/nystromformer) | βœ… | ❌ | ❌ | | [OLMo](model_doc/olmo) | βœ… | ❌ | ❌ | -| [OLMo November 2024](model_doc/olmo-1124) | βœ… | ❌ | ❌ | +| [OLMo November 2024](model_doc/olmo_1124) | βœ… | ❌ | ❌ | | [OLMoE](model_doc/olmoe) | βœ… | ❌ | ❌ | | [OmDet-Turbo](model_doc/omdet-turbo) | βœ… | ❌ | ❌ | | [OneFormer](model_doc/oneformer) | βœ… | ❌ | ❌ | diff --git a/docs/source/en/model_doc/olmo-1124.md b/docs/source/en/model_doc/olmo_1124.md similarity index 100% rename from docs/source/en/model_doc/olmo-1124.md rename to docs/source/en/model_doc/olmo_1124.md diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index 67bd31fdaeede5..84109746f95998 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -77,6 +77,7 @@ FlashAttention-2 is currently supported for the following architectures: * [Nemotron](https://huggingface.co/docs/transformers/model_doc/nemotron) * [NLLB](https://huggingface.co/docs/transformers/model_doc/nllb) * [OLMo](https://huggingface.co/docs/transformers/model_doc/olmo#transformers.OlmoModel) +* [OLMo November 2024](https://huggingface.co/docs/transformers/model_doc/olmo_1124#transformers.Olmo1124Model) * [OLMoE](https://huggingface.co/docs/transformers/model_doc/olmoe#transformers.OlmoeModel) * [OPT](https://huggingface.co/docs/transformers/model_doc/opt#transformers.OPTModel) * [PaliGemma](https://huggingface.co/docs/transformers/model_doc/paligemma#transformers.PaliGemmaForConditionalGeneration) @@ -260,6 +261,7 @@ For now, Transformers supports SDPA inference and training for the following arc * [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel) * [NLLB](https://huggingface.co/docs/transformers/model_doc/nllb) * [OLMo](https://huggingface.co/docs/transformers/model_doc/olmo#transformers.OlmoModel) +* [OLMo November 2024](https://huggingface.co/docs/transformers/model_doc/olmo_1124#transformers.Olmo1124Model) * [OLMoE](https://huggingface.co/docs/transformers/model_doc/olmoe#transformers.OlmoeModel) * [OPT](https://huggingface.co/docs/transformers/en/model_doc/opt) * [PaliGemma](https://huggingface.co/docs/transformers/model_doc/paligemma#transformers.PaliGemmaForConditionalGeneration) diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 7769edb6608dba..7f0182b50085c5 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -195,7 +195,7 @@ ("nougat", "VisionEncoderDecoderConfig"), ("nystromformer", "NystromformerConfig"), ("olmo", "OlmoConfig"), - ("olmo-1124", "Olmo1124Config"), + ("olmo_1124", "Olmo1124Config"), ("olmoe", "OlmoeConfig"), ("omdet-turbo", "OmDetTurboConfig"), ("oneformer", "OneFormerConfig"), @@ -511,7 +511,7 @@ ("nougat", "Nougat"), ("nystromformer", "NystrΓΆmformer"), ("olmo", "OLMo"), - ("olmo-1124", "OLMo November 2024"), + ("olmo_1124", "OLMo November 2024"), ("olmoe", "OLMoE"), ("omdet-turbo", "OmDet-Turbo"), ("oneformer", "OneFormer"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 572ec81d6c5e56..5206972b72efde 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -184,7 +184,7 @@ ("nllb-moe", "NllbMoeModel"), ("nystromformer", "NystromformerModel"), ("olmo", "OlmoModel"), - ("olmo-1124", "Olmo1124Model"), + ("olmo_1124", "Olmo1124Model"), ("olmoe", "OlmoeModel"), ("omdet-turbo", "OmDetTurboForObjectDetection"), ("oneformer", "OneFormerModel"), @@ -517,7 +517,7 @@ ("mvp", "MvpForCausalLM"), ("nemotron", "NemotronForCausalLM"), ("olmo", "OlmoForCausalLM"), - ("olmo-1124", "Olmo1124ForCausalLM"), + ("olmo_1124", "Olmo1124ForCausalLM"), ("olmoe", "OlmoeForCausalLM"), ("open-llama", "OpenLlamaForCausalLM"), ("openai-gpt", "OpenAIGPTLMHeadModel"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index e4a4533b48ea7f..4ed67df0e84b52 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -348,7 +348,7 @@ ), ), ("olmo", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)), - ("olmo-1124", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)), + ("olmo_1124", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)), ("olmoe", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)), ( "omdet-turbo", diff --git a/src/transformers/models/olmo_1124/configuration_olmo_1124.py b/src/transformers/models/olmo_1124/configuration_olmo_1124.py index d0c5051a832c64..c451c48af032f2 100644 --- a/src/transformers/models/olmo_1124/configuration_olmo_1124.py +++ b/src/transformers/models/olmo_1124/configuration_olmo_1124.py @@ -86,7 +86,7 @@ class Olmo1124Config(PretrainedConfig): ``` """ - model_type = "olmo-1124" + model_type = "olmo_1124" keys_to_ignore_at_inference = ["past_key_values"] def __init__( diff --git a/src/transformers/models/olmo_1124/modular_olmo_1124.py b/src/transformers/models/olmo_1124/modular_olmo_1124.py index 8fecf65b59e20a..b9de62beb6031e 100644 --- a/src/transformers/models/olmo_1124/modular_olmo_1124.py +++ b/src/transformers/models/olmo_1124/modular_olmo_1124.py @@ -105,7 +105,7 @@ class Olmo1124Config(OlmoConfig): ``` """ - model_type = "olmo-1124" + model_type = "olmo_1124" def __init__( self, From ada451cf993bb073b39be26e1233ed0a8b6be843 Mon Sep 17 00:00:00 2001 From: Shane A Date: Mon, 4 Nov 2024 13:04:53 -0800 Subject: [PATCH 16/22] Start updating tests --- .../olmo_1124/test_modeling_olmo_1124.py | 36 ++++--------------- 1 file changed, 6 insertions(+), 30 deletions(-) diff --git a/tests/models/olmo_1124/test_modeling_olmo_1124.py b/tests/models/olmo_1124/test_modeling_olmo_1124.py index f732276ebbc8f9..87a85913351140 100644 --- a/tests/models/olmo_1124/test_modeling_olmo_1124.py +++ b/tests/models/olmo_1124/test_modeling_olmo_1124.py @@ -349,22 +349,10 @@ def test_model_rope_scaling(self, scaling_type): @require_torch class Olmo1124IntegrationTest(unittest.TestCase): - @slow - def test_model_1b_logits(self): - input_ids = [[1, 306, 4658, 278, 6593, 310, 2834, 338]] - model = Olmo1124ForCausalLM.from_pretrained("allenai/OLMo-7B-1124-hf", device_map="auto") - out = model(torch.tensor(input_ids)).logits.float() - # Expected mean on dim = -1 - EXPECTED_MEAN = torch.tensor([[2.2869, 0.3315, 0.9876, 1.4146, 1.8804, 2.0430, 1.7055, 1.2065]]) - torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, atol=1e-2, rtol=1e-2) - # slicing logits[0, 0, 0:30] - EXPECTED_SLICE = torch.tensor([2.5551, -1.1230, 11.0510, 12.4977, 7.9651, 7.2342, 6.1885, 7.8340, 9.9847, 12.6695, 12.2345, 10.7970, 8.4749, 14.2483, 12.9588, 13.9233, 11.0496, 5.5749, 7.4466, 7.7914, 6.8440, 5.8951, 4.8180, 4.1935, 4.5216, 4.7256, 3.9553, 12.2870, 12.4990, 8.1591]) # fmt: skip - torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, atol=1e-2, rtol=1e-2) - @slow def test_model_7b_logits(self): input_ids = [[1, 306, 4658, 278, 6593, 310, 2834, 338]] - model = Olmo1124ForCausalLM.from_pretrained("allenai/OLMo November 2024-7B-hf", device_map="auto") + model = Olmo1124ForCausalLM.from_pretrained("shanearora/OLMo-7B-1124-hf", device_map="auto") out = model(torch.tensor(input_ids)).logits.float() # Expected mean on dim = -1 EXPECTED_MEAN = torch.tensor([[0.0271, 0.0249, -0.0578, -0.0870, 0.0167, 0.0710, 0.1002, 0.0677]]) @@ -373,25 +361,13 @@ def test_model_7b_logits(self): EXPECTED_SLICE = torch.tensor([-1.7433, -1.6685, 7.4941, 6.1506, 0.1364, -0.1127, 1.3224, 4.5458, 4.2068, 5.8296, 7.4723, 2.7925, 3.1245, 10.8872, 10.0758, 10.6717, 7.0945, 1.2398, 3.6766, 4.2365, 2.5655, 2.2222, 1.7418, 0.5223, 0.7753, 1.0938, 0.6723, 6.2522, 6.2264, 1.8105]) # fmt: skip torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, atol=1e-2, rtol=1e-2) - @slow - def test_model_7b_twin_2t_logits(self): - input_ids = [[1, 306, 4658, 278, 6593, 310, 2834, 338]] - model = Olmo1124ForCausalLM.from_pretrained("allenai/OLMo November 2024-7B-Twin-2T-hf", device_map="auto") - out = model(torch.tensor(input_ids)).logits.float() - # Expected mean on dim = -1 - EXPECTED_MEAN = torch.tensor([[-0.3636, -0.3825, -0.4800, -0.3696, -0.8388, -0.9737, -0.9849, -0.8356]]) - torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, atol=1e-2, rtol=1e-2) - # slicing logits[0, 0, 0:30] - EXPECTED_SLICE = torch.tensor([-2.0833, -1.9234, 8.7312, 7.8049, 1.0372, 0.8941, 3.1548, 1.8502, 5.5511, 5.5793, 8.1166, 4.5906, 1.8691, 11.6377, 8.9858, 11.6447, 7.4549, 1.4725, 2.8399, 2.7568, 1.4011, 1.6958, 0.5572, 0.5231, 0.3068, 0.5364, 0.6769, 7.9636, 8.2379, 1.7950]) # fmt: skip - torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, atol=1e-2, rtol=1e-2) - @slow def test_model_7b_greedy_generation(self): EXPECTED_TEXT_COMPLETION = """Simply put, the theory of relativity states that \nthe speed of light is the same for all observers.\n\nThe theory of relativity is a theory of physics that describes the \nmovement of objects in space and time.\n\nThe theory of relativity is a theory of physics that describes the \nmovement of objects in space and time.\n\n""" prompt = "Simply put, the theory of relativity states that " - tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo November 2024-7B-hf", device_map="auto") + tokenizer = AutoTokenizer.from_pretrained("shanearora/OLMo-7B-1124-hf", device_map="auto") input_ids = tokenizer.encode(prompt, return_tensors="pt") - model = Olmo1124ForCausalLM.from_pretrained("allenai/OLMo November 2024-7B-hf", device_map="auto") + model = Olmo1124ForCausalLM.from_pretrained("shanearora/OLMo-7B-1124-hf", device_map="auto") # greedy generation outputs generated_ids = model.generate(input_ids, max_new_tokens=64, top_p=None, temperature=1, do_sample=False) @@ -400,7 +376,7 @@ def test_model_7b_greedy_generation(self): @require_tokenizers def test_fast_special_tokens(self): - fast_tokenizer = GPTNeoXTokenizerFast.from_pretrained("allenai/OLMo-7B-1124-hf") + fast_tokenizer = GPTNeoXTokenizerFast.from_pretrained("shanearora/OLMo-7B-1124-hf") original_add_eos_token = fast_tokenizer.add_eos_token @@ -416,7 +392,7 @@ def test_fast_special_tokens(self): @require_tokenizers def test_simple_encode_decode(self): - rust_tokenizer = GPTNeoXTokenizerFast.from_pretrained("allenai/OLMo-7B-1124-hf") + rust_tokenizer = GPTNeoXTokenizerFast.from_pretrained("shanearora/OLMo-7B-1124-hf") self.assertEqual(rust_tokenizer.encode("This is a test"), [1552, 310, 247, 1071]) self.assertEqual(rust_tokenizer.decode([1552, 310, 247, 1071], skip_special_tokens=True), "This is a test") @@ -453,7 +429,7 @@ def test_export_static_cache(self): convert_and_export_with_cache, ) - olmo_1124_model = "allenai/OLMo-7B-1124-hf" + olmo_1124_model = "shanearora/OLMo-7B-1124-hf" tokenizer = AutoTokenizer.from_pretrained(olmo_1124_model, pad_token="", padding_side="right") EXPECTED_TEXT_COMPLETION = [ From 4a2bf2ea2674f6e3b36d2855bbc2650892f92c25 Mon Sep 17 00:00:00 2001 From: Shane A Date: Mon, 4 Nov 2024 15:27:06 -0800 Subject: [PATCH 17/22] Update tests --- .../olmo_1124/test_modeling_olmo_1124.py | 48 ++++++++++++------- 1 file changed, 30 insertions(+), 18 deletions(-) diff --git a/tests/models/olmo_1124/test_modeling_olmo_1124.py b/tests/models/olmo_1124/test_modeling_olmo_1124.py index 87a85913351140..e3d623f50ee257 100644 --- a/tests/models/olmo_1124/test_modeling_olmo_1124.py +++ b/tests/models/olmo_1124/test_modeling_olmo_1124.py @@ -24,8 +24,10 @@ from transformers.models.auto.tokenization_auto import AutoTokenizer from transformers.models.gpt_neox.tokenization_gpt_neox_fast import GPTNeoXTokenizerFast from transformers.testing_utils import ( + is_flaky, require_tokenizers, require_torch, + require_torch_sdpa, slow, torch_device, ) @@ -346,6 +348,14 @@ def test_model_rope_scaling(self, scaling_type): # The output should be different for long inputs self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) + @require_torch_sdpa + @slow + @is_flaky() + # Copied from tests.models.mllama.test_modeling_mllama.MllamaForConditionalGenerationModelTest.test_eager_matches_sdpa_inference_1_bfloat16 + def test_eager_matches_sdpa_inference_1_bfloat16(self): + # A workaround to override parametrized test with flaky decorator + super().test_eager_matches_sdpa_inference_1_bfloat16() + @require_torch class Olmo1124IntegrationTest(unittest.TestCase): @@ -355,19 +365,21 @@ def test_model_7b_logits(self): model = Olmo1124ForCausalLM.from_pretrained("shanearora/OLMo-7B-1124-hf", device_map="auto") out = model(torch.tensor(input_ids)).logits.float() # Expected mean on dim = -1 - EXPECTED_MEAN = torch.tensor([[0.0271, 0.0249, -0.0578, -0.0870, 0.0167, 0.0710, 0.1002, 0.0677]]) + EXPECTED_MEAN = torch.tensor( + [[-13.0244, -13.9564, -11.8270, -11.3047, -12.3794, -12.4215, -15.6030, -12.7962]] + ) torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, atol=1e-2, rtol=1e-2) # slicing logits[0, 0, 0:30] - EXPECTED_SLICE = torch.tensor([-1.7433, -1.6685, 7.4941, 6.1506, 0.1364, -0.1127, 1.3224, 4.5458, 4.2068, 5.8296, 7.4723, 2.7925, 3.1245, 10.8872, 10.0758, 10.6717, 7.0945, 1.2398, 3.6766, 4.2365, 2.5655, 2.2222, 1.7418, 0.5223, 0.7753, 1.0938, 0.6723, 6.2522, 6.2264, 1.8105]) # fmt: skip + EXPECTED_SLICE = torch.tensor([-5.3909, -13.9841, -13.6123, -14.5780, -13.9455, -13.2265, -13.4734, -11.9079, -9.2879, -12.6139, -11.4819, -5.9607, -11.9657, -6.3618, -11.1065, -7.3075, -6.5674, -6.7154, -7.3409, -7.9662, -8.0863, -8.1682, -8.7341, -8.7665, -8.8742, -9.7813, -8.0620, -12.5937, -7.6440, -11.3966]) # fmt: skip torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, atol=1e-2, rtol=1e-2) @slow def test_model_7b_greedy_generation(self): - EXPECTED_TEXT_COMPLETION = """Simply put, the theory of relativity states that \nthe speed of light is the same for all observers.\n\nThe theory of relativity is a theory of physics that describes the \nmovement of objects in space and time.\n\nThe theory of relativity is a theory of physics that describes the \nmovement of objects in space and time.\n\n""" + EXPECTED_TEXT_COMPLETION = """Simply put, the theory of relativity states that 1) the speed of light is constant, 2) the speed of light is the fastest speed possible, and 3) the speed of light is the same for all observers, regardless of their relative motion. The theory of relativity is based on the idea that the speed of light is constant. This means that""" prompt = "Simply put, the theory of relativity states that " tokenizer = AutoTokenizer.from_pretrained("shanearora/OLMo-7B-1124-hf", device_map="auto") - input_ids = tokenizer.encode(prompt, return_tensors="pt") model = Olmo1124ForCausalLM.from_pretrained("shanearora/OLMo-7B-1124-hf", device_map="auto") + input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device) # greedy generation outputs generated_ids = model.generate(input_ids, max_new_tokens=64, top_p=None, temperature=1, do_sample=False) @@ -382,11 +394,11 @@ def test_fast_special_tokens(self): fast_tokenizer.add_eos_token = False fast = fast_tokenizer.encode("A sample test") - self.assertEqual(fast, [34, 3410, 1071]) + self.assertEqual(fast, [32, 6205, 1296]) fast_tokenizer.add_eos_token = True fast = fast_tokenizer.encode("A sample test") - self.assertEqual(fast, [34, 3410, 1071, 50279]) + self.assertEqual(fast, [32, 6205, 1296, 100257]) fast_tokenizer.add_eos_token = original_add_eos_token @@ -394,30 +406,30 @@ def test_fast_special_tokens(self): def test_simple_encode_decode(self): rust_tokenizer = GPTNeoXTokenizerFast.from_pretrained("shanearora/OLMo-7B-1124-hf") - self.assertEqual(rust_tokenizer.encode("This is a test"), [1552, 310, 247, 1071]) - self.assertEqual(rust_tokenizer.decode([1552, 310, 247, 1071], skip_special_tokens=True), "This is a test") + self.assertEqual(rust_tokenizer.encode("This is a test"), [2028, 374, 264, 1296]) + self.assertEqual(rust_tokenizer.decode([2028, 374, 264, 1296], skip_special_tokens=True), "This is a test") # bytefallback showcase - self.assertEqual(rust_tokenizer.encode("η”Ÿζ΄»ηš„ηœŸθ°›ζ˜―"), [20025, 46549, 5225, 48561, 33656, 238, 12105]) # fmt: skip + self.assertEqual(rust_tokenizer.encode("η”Ÿζ΄»ηš„ηœŸθ°›ζ˜―"), [21990, 76706, 9554, 89151, 39013, 249, 21043]) # fmt: skip self.assertEqual( - rust_tokenizer.decode([20025, 46549, 5225, 48561, 33656, 238, 12105], skip_special_tokens=True), + rust_tokenizer.decode([21990, 76706, 9554, 89151, 39013, 249, 21043], skip_special_tokens=True), "η”Ÿζ΄»ηš„ηœŸθ°›ζ˜―", ) # Inner spaces showcase - self.assertEqual(rust_tokenizer.encode("Hi Hello"), [12764, 50276, 12092]) - self.assertEqual(rust_tokenizer.decode([12764, 50276, 12092], skip_special_tokens=True), "Hi Hello") + self.assertEqual(rust_tokenizer.encode("Hi Hello"), [13347, 220, 22691]) + self.assertEqual(rust_tokenizer.decode([13347, 220, 22691], skip_special_tokens=True), "Hi Hello") - self.assertEqual(rust_tokenizer.encode("Hi Hello"), [12764, 50275, 12092]) - self.assertEqual(rust_tokenizer.decode([12764, 50275, 12092], skip_special_tokens=True), "Hi Hello") + self.assertEqual(rust_tokenizer.encode("Hi Hello"), [13347, 256, 22691]) + self.assertEqual(rust_tokenizer.decode([13347, 256, 22691], skip_special_tokens=True), "Hi Hello") self.assertEqual(rust_tokenizer.encode(""), []) - self.assertEqual(rust_tokenizer.encode(" "), [209]) + self.assertEqual(rust_tokenizer.encode(" "), [220]) - self.assertEqual(rust_tokenizer.encode(" "), [50276]) + self.assertEqual(rust_tokenizer.encode(" "), [256]) - self.assertEqual(rust_tokenizer.encode(" Hello"), [24387]) + self.assertEqual(rust_tokenizer.encode(" Hello"), [22691]) @slow def test_export_static_cache(self): @@ -433,7 +445,7 @@ def test_export_static_cache(self): tokenizer = AutoTokenizer.from_pretrained(olmo_1124_model, pad_token="", padding_side="right") EXPECTED_TEXT_COMPLETION = [ - "Simply put, the theory of relativity states that \nthe speed of light is the same in all reference frames.\n\nThe speed of light", + "Simply put, the theory of relativity states that 1) the speed of light is constant, 2) the speed of light", ] max_generation_length = tokenizer(EXPECTED_TEXT_COMPLETION, return_tensors="pt", padding=True)[ "input_ids" From a330c9035de11f18e1c8efb991500b5c9e7745f0 Mon Sep 17 00:00:00 2001 From: Shane A Date: Tue, 5 Nov 2024 16:11:02 -0800 Subject: [PATCH 18/22] Copy upstream test_eager_matches_sdpa_inference_1_bfloat16 changes to olmo_1124 --- tests/models/olmo_1124/test_modeling_olmo_1124.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/tests/models/olmo_1124/test_modeling_olmo_1124.py b/tests/models/olmo_1124/test_modeling_olmo_1124.py index e3d623f50ee257..9c2b8534ba022e 100644 --- a/tests/models/olmo_1124/test_modeling_olmo_1124.py +++ b/tests/models/olmo_1124/test_modeling_olmo_1124.py @@ -24,10 +24,8 @@ from transformers.models.auto.tokenization_auto import AutoTokenizer from transformers.models.gpt_neox.tokenization_gpt_neox_fast import GPTNeoXTokenizerFast from transformers.testing_utils import ( - is_flaky, require_tokenizers, require_torch, - require_torch_sdpa, slow, torch_device, ) @@ -348,14 +346,6 @@ def test_model_rope_scaling(self, scaling_type): # The output should be different for long inputs self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) - @require_torch_sdpa - @slow - @is_flaky() - # Copied from tests.models.mllama.test_modeling_mllama.MllamaForConditionalGenerationModelTest.test_eager_matches_sdpa_inference_1_bfloat16 - def test_eager_matches_sdpa_inference_1_bfloat16(self): - # A workaround to override parametrized test with flaky decorator - super().test_eager_matches_sdpa_inference_1_bfloat16() - @require_torch class Olmo1124IntegrationTest(unittest.TestCase): From 1af9cbd9e523714acb65865dda7ac793b8e6f0b8 Mon Sep 17 00:00:00 2001 From: Shane A Date: Wed, 6 Nov 2024 15:16:49 -0800 Subject: [PATCH 19/22] Rename input_layernorm and post_attention_layernorm to reflect their ops better --- .../models/olmo_1124/convert_olmo_1124_weights_to_hf.py | 4 +++- src/transformers/models/olmo_1124/modeling_olmo_1124.py | 6 +++--- src/transformers/models/olmo_1124/modular_olmo_1124.py | 7 ++++--- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/olmo_1124/convert_olmo_1124_weights_to_hf.py b/src/transformers/models/olmo_1124/convert_olmo_1124_weights_to_hf.py index 34f0b01cdb5ec6..fb6cc02b64d9e8 100644 --- a/src/transformers/models/olmo_1124/convert_olmo_1124_weights_to_hf.py +++ b/src/transformers/models/olmo_1124/convert_olmo_1124_weights_to_hf.py @@ -131,8 +131,10 @@ def write_model( f"model.layers.{layer_i}.mlp.gate_proj.weight": gate_proj_weight, f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"transformer.blocks.{layer_i}.ff_out.weight"], f"model.layers.{layer_i}.mlp.up_proj.weight": up_proj_weight, - f"model.layers.{layer_i}.input_layernorm.weight": loaded[f"transformer.blocks.{layer_i}.attn_norm.weight"], f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[ + f"transformer.blocks.{layer_i}.attn_norm.weight" + ], + f"model.layers.{layer_i}.post_feedforward_layernorm.weight": loaded[ f"transformer.blocks.{layer_i}.ff_norm.weight" ], } diff --git a/src/transformers/models/olmo_1124/modeling_olmo_1124.py b/src/transformers/models/olmo_1124/modeling_olmo_1124.py index 457335c3680120..843b13b41de213 100644 --- a/src/transformers/models/olmo_1124/modeling_olmo_1124.py +++ b/src/transformers/models/olmo_1124/modeling_olmo_1124.py @@ -509,8 +509,8 @@ def __init__(self, config: Olmo1124Config, layer_idx: int): self.self_attn = OLMO_1124_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) self.mlp = Olmo1124MLP(config) - self.input_layernorm = Olmo1124RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = Olmo1124RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_feedforward_layernorm = Olmo1124RMSNorm(config.hidden_size, eps=config.rms_norm_eps) # copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward # TODO(joao): add me back asap :) @@ -557,13 +557,13 @@ def forward( cache_position=cache_position, **kwargs, ) - hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = residual + hidden_states # Fully Connected residual = hidden_states hidden_states = self.mlp(hidden_states) - hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.post_feedforward_layernorm(hidden_states) hidden_states = residual + hidden_states outputs = (hidden_states,) diff --git a/src/transformers/models/olmo_1124/modular_olmo_1124.py b/src/transformers/models/olmo_1124/modular_olmo_1124.py index b9de62beb6031e..4927b3537506c6 100644 --- a/src/transformers/models/olmo_1124/modular_olmo_1124.py +++ b/src/transformers/models/olmo_1124/modular_olmo_1124.py @@ -413,8 +413,9 @@ def forward( class Olmo1124DecoderLayer(OlmoDecoderLayer): def __init__(self, config: Olmo1124Config, layer_idx: int): super().__init__(config, layer_idx=layer_idx) - self.input_layernorm = Olmo1124RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = Olmo1124RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_feedforward_layernorm = Olmo1124RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + del self.input_layernorm def forward( self, @@ -440,13 +441,13 @@ def forward( cache_position=cache_position, **kwargs, ) - hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = residual + hidden_states # Fully Connected residual = hidden_states hidden_states = self.mlp(hidden_states) - hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.post_feedforward_layernorm(hidden_states) hidden_states = residual + hidden_states outputs = (hidden_states,) From 083cc93479314e34bb70c3ccfa4a8e9519c8d5ef Mon Sep 17 00:00:00 2001 From: Shane A Date: Thu, 7 Nov 2024 14:53:02 -0800 Subject: [PATCH 20/22] Use correct tokenizer --- .../models/olmo_1124/convert_olmo_1124_weights_to_hf.py | 8 +++----- tests/models/olmo_1124/test_modeling_olmo_1124.py | 5 ++--- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/olmo_1124/convert_olmo_1124_weights_to_hf.py b/src/transformers/models/olmo_1124/convert_olmo_1124_weights_to_hf.py index fb6cc02b64d9e8..d10abb6f8ebb5a 100644 --- a/src/transformers/models/olmo_1124/convert_olmo_1124_weights_to_hf.py +++ b/src/transformers/models/olmo_1124/convert_olmo_1124_weights_to_hf.py @@ -24,7 +24,7 @@ from tokenizers import Tokenizer from transformers import Olmo1124Config, Olmo1124ForCausalLM -from transformers.models.gpt_neox.tokenization_gpt_neox_fast import GPTNeoXTokenizerFast +from transformers.models.gpt2.tokenization_gpt2_fast import GPT2TokenizerFast """ @@ -220,7 +220,7 @@ def _write_tokenizer( checkpoint_dir: str, input_tokenizer_path: Path | None, ) -> None: - print(f"Saving a {GPTNeoXTokenizerFast.__name__} to {output_path}.") + print(f"Saving a {GPT2TokenizerFast.__name__} to {output_path}.") if input_tokenizer_path is not None: base_tokenizer = Tokenizer.from_file(str(input_tokenizer_path)) @@ -237,12 +237,10 @@ def _write_tokenizer( eos_token_id = config.eos_token_id if config.eos_token_id is not None else base_tokenizer.get_vocab_size() - 1 pad_token_id = config.pad_token_id if config.pad_token_id is not None else eos_token_id - tokenizer = GPTNeoXTokenizerFast( + tokenizer = GPT2TokenizerFast( tokenizer_object=base_tokenizer, eos_token=base_tokenizer.decode([eos_token_id], skip_special_tokens=False), pad_token=base_tokenizer.decode([pad_token_id], skip_special_tokens=False), - unk_token=None, - bos_token=None, ) tokenizer.save_pretrained(output_path) diff --git a/tests/models/olmo_1124/test_modeling_olmo_1124.py b/tests/models/olmo_1124/test_modeling_olmo_1124.py index 9c2b8534ba022e..b61447f01d9f12 100644 --- a/tests/models/olmo_1124/test_modeling_olmo_1124.py +++ b/tests/models/olmo_1124/test_modeling_olmo_1124.py @@ -22,7 +22,6 @@ from transformers import Olmo1124Config, is_torch_available, set_seed from transformers.generation.configuration_utils import GenerationConfig from transformers.models.auto.tokenization_auto import AutoTokenizer -from transformers.models.gpt_neox.tokenization_gpt_neox_fast import GPTNeoXTokenizerFast from transformers.testing_utils import ( require_tokenizers, require_torch, @@ -378,7 +377,7 @@ def test_model_7b_greedy_generation(self): @require_tokenizers def test_fast_special_tokens(self): - fast_tokenizer = GPTNeoXTokenizerFast.from_pretrained("shanearora/OLMo-7B-1124-hf") + fast_tokenizer = AutoTokenizer.from_pretrained("shanearora/OLMo-7B-1124-hf") original_add_eos_token = fast_tokenizer.add_eos_token @@ -394,7 +393,7 @@ def test_fast_special_tokens(self): @require_tokenizers def test_simple_encode_decode(self): - rust_tokenizer = GPTNeoXTokenizerFast.from_pretrained("shanearora/OLMo-7B-1124-hf") + rust_tokenizer = AutoTokenizer.from_pretrained("shanearora/OLMo-7B-1124-hf") self.assertEqual(rust_tokenizer.encode("This is a test"), [2028, 374, 264, 1296]) self.assertEqual(rust_tokenizer.decode([2028, 374, 264, 1296], skip_special_tokens=True), "This is a test") From b7bd6bdd99d655a8af065abe385d791b77cfbea0 Mon Sep 17 00:00:00 2001 From: Shane A Date: Thu, 7 Nov 2024 15:01:24 -0800 Subject: [PATCH 21/22] Remove test unsupported by GPT2 tokenizer --- .../models/olmo_1124/test_modeling_olmo_1124.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/tests/models/olmo_1124/test_modeling_olmo_1124.py b/tests/models/olmo_1124/test_modeling_olmo_1124.py index b61447f01d9f12..2540e5c9d47f53 100644 --- a/tests/models/olmo_1124/test_modeling_olmo_1124.py +++ b/tests/models/olmo_1124/test_modeling_olmo_1124.py @@ -375,22 +375,6 @@ def test_model_7b_greedy_generation(self): text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) self.assertEqual(EXPECTED_TEXT_COMPLETION, text) - @require_tokenizers - def test_fast_special_tokens(self): - fast_tokenizer = AutoTokenizer.from_pretrained("shanearora/OLMo-7B-1124-hf") - - original_add_eos_token = fast_tokenizer.add_eos_token - - fast_tokenizer.add_eos_token = False - fast = fast_tokenizer.encode("A sample test") - self.assertEqual(fast, [32, 6205, 1296]) - - fast_tokenizer.add_eos_token = True - fast = fast_tokenizer.encode("A sample test") - self.assertEqual(fast, [32, 6205, 1296, 100257]) - - fast_tokenizer.add_eos_token = original_add_eos_token - @require_tokenizers def test_simple_encode_decode(self): rust_tokenizer = AutoTokenizer.from_pretrained("shanearora/OLMo-7B-1124-hf") From e3a1ede655187cb1988803dad48b18211c2f44ea Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Tue, 12 Nov 2024 04:25:24 +0000 Subject: [PATCH 22/22] quick push --- src/transformers/models/olmo/modeling_olmo.py | 110 ++++++++++++++++++ .../models/olmo_1124/modeling_olmo_1124.py | 109 ++++++++++++++++- .../models/olmo_1124/modular_olmo_1124.py | 7 ++ 3 files changed, 225 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 60225d4759c6ab..194cdb81454efd 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -26,6 +26,8 @@ import torch.nn.functional as F import torch.utils.checkpoint from torch import nn +from torch.nn import MSELoss, CrossEntropyLoss, BCEWithLogitsLoss + from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache @@ -34,6 +36,7 @@ from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, ) from ...modeling_utils import PreTrainedModel from ...pytorch_utils import ALL_LAYERNORM_LAYERS @@ -1140,3 +1143,110 @@ def forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) + + +class OlmoForSequenceClassification(OlmoPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = OlmoModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) diff --git a/src/transformers/models/olmo_1124/modeling_olmo_1124.py b/src/transformers/models/olmo_1124/modeling_olmo_1124.py index 843b13b41de213..16bb805050b07d 100644 --- a/src/transformers/models/olmo_1124/modeling_olmo_1124.py +++ b/src/transformers/models/olmo_1124/modeling_olmo_1124.py @@ -9,13 +9,14 @@ import torch from torch import nn +from torch.nn import MSELoss, CrossEntropyLoss, BCEWithLogitsLoss from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import _flash_attention_forward -from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, @@ -1090,3 +1091,109 @@ def forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) + +class Olmo1124ForSequenceClassification(Olmo1124PreTrainedModel): + def __init__(self, config: Olmo1124Config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = Olmo1124Model(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) \ No newline at end of file diff --git a/src/transformers/models/olmo_1124/modular_olmo_1124.py b/src/transformers/models/olmo_1124/modular_olmo_1124.py index 4927b3537506c6..56dfe777fa5096 100644 --- a/src/transformers/models/olmo_1124/modular_olmo_1124.py +++ b/src/transformers/models/olmo_1124/modular_olmo_1124.py @@ -14,6 +14,7 @@ OlmoDecoderLayer, OlmoFlashAttention2, OlmoForCausalLM, + OlmoForSequenceClassification, OlmoModel, OlmoSdpaAttention, apply_rotary_pos_emb, @@ -474,3 +475,9 @@ class Olmo1124ForCausalLM(OlmoForCausalLM): def __init__(self, config: Olmo1124Config): super().__init__(config) self.model = Olmo1124Model(config) + + +class Olmo1124ForSequenceClassification(OlmoForSequenceClassification): + def __init__(self, config: Olmo1124Config): + super().__init__(config) + self.model = Olmo1124Model(config)