OLMo1124 classification #1

Open
wants to merge 25 commits into base: shanea/add-olmo1124
Commits (25)
8e0ccca
Add model skeleton with transformers-cli add-new-model-like
2015aroras Oct 31, 2024
d6b18e2
Convert config to modular, add rms_norm_eps, delete clip_qkv
2015aroras Oct 31, 2024
6605eb9
Convert model to modular, add RMSNorm
2015aroras Oct 31, 2024
916a841
Add flash attention with qk norm and no qkv clipping
2015aroras Oct 31, 2024
c19d622
Add decoder layer with RMSNorm after attention/feedforward layers
2015aroras Oct 31, 2024
ebec833
Add base and causal model
2015aroras Oct 31, 2024
ee2ad8a
Add converter improvements from OLMo repo
2015aroras Oct 31, 2024
84f72ba
Update weight loading in OLMo to HF converter
2015aroras Oct 31, 2024
68c6763
Set correct default for rms_norm_eps
2015aroras Oct 31, 2024
8e5a74a
Set correct pipeline_model_mapping in test
2015aroras Oct 31, 2024
a5f92c2
Run make fixup
2015aroras Oct 31, 2024
fe2e478
Fix model type
2015aroras Nov 4, 2024
029e843
Merge remote-tracking branch 'upstream/main' into shanea/add-olmo1124
2015aroras Nov 4, 2024
4349938
Re-run modular conversion
2015aroras Nov 4, 2024
bd94d9c
Manually set config docs to fix build errors
2015aroras Nov 4, 2024
7a0cbbe
Convert olmo-1124 to olmo_1124 to fix flash attention docs errors
2015aroras Nov 4, 2024
ada451c
Start updating tests
2015aroras Nov 4, 2024
4a2bf2e
Update tests
2015aroras Nov 4, 2024
9a7c61e
Merge branch 'main' into shanea/add-olmo1124
2015aroras Nov 4, 2024
5cd260e
Merge branch 'main' into shanea/add-olmo1124
2015aroras Nov 6, 2024
a330c90
Copy upstream test_eager_matches_sdpa_inference_1_bfloat16 changes to…
2015aroras Nov 6, 2024
1af9cbd
Rename input_layernorm and post_attention_layernorm to reflect their …
2015aroras Nov 6, 2024
083cc93
Use correct tokenizer
2015aroras Nov 7, 2024
b7bd6bd
Remove test unsupported by GPT2 tokenizer
2015aroras Nov 7, 2024
e3a1ede
quick push
vwxyzjn Nov 12, 2024
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -514,6 +514,8 @@
        title: Nyströmformer
      - local: model_doc/olmo
        title: OLMo
      - local: model_doc/olmo_1124
        title: OLMo November 2024
      - local: model_doc/olmoe
        title: OLMoE
      - local: model_doc/open-llama
1 change: 1 addition & 0 deletions docs/source/en/index.md
@@ -240,6 +240,7 @@ Flax), PyTorch, and/or TensorFlow.
| [Nougat](model_doc/nougat) | ✅ | ✅ | ✅ |
| [Nyströmformer](model_doc/nystromformer) | ✅ | ❌ | ❌ |
| [OLMo](model_doc/olmo) | ✅ | ❌ | ❌ |
| [OLMo November 2024](model_doc/olmo_1124) | ✅ | ❌ | ❌ |
| [OLMoE](model_doc/olmoe) | ✅ | ❌ | ❌ |
| [OmDet-Turbo](model_doc/omdet-turbo) | ✅ | ❌ | ❌ |
| [OneFormer](model_doc/oneformer) | ✅ | ❌ | ❌ |
48 changes: 48 additions & 0 deletions docs/source/en/model_doc/olmo_1124.md
@@ -0,0 +1,48 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# OLMo November 2024

## Overview

The OLMo November 2024 model was proposed in [<INSERT PAPER NAME HERE>](<INSERT PAPER LINK HERE>) by <INSERT AUTHORS HERE>.
<INSERT SHORT SUMMARY HERE>

The abstract from the paper is the following:

*<INSERT PAPER ABSTRACT HERE>*

Tips:

<INSERT TIPS ABOUT MODEL HERE>

This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/<INSERT YOUR HF USERNAME HERE>).
The original code can be found [here](<INSERT LINK TO GITHUB REPO HERE>).


## Olmo1124Config

[[autodoc]] Olmo1124Config

## Olmo1124Model

[[autodoc]] Olmo1124Model
- forward

## Olmo1124ForCausalLM

[[autodoc]] Olmo1124ForCausalLM
- forward
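
A minimal usage sketch for the classes documented above; the checkpoint id below is a placeholder, since this PR does not pin a released Hub repository:

from transformers import AutoTokenizer, Olmo1124ForCausalLM

checkpoint = "allenai/OLMo-1124-example"  # placeholder repo id, for illustration only

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = Olmo1124ForCausalLM.from_pretrained(checkpoint)

inputs = tokenizer("Language modeling is", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=20)  # greedy decoding by default
print(tokenizer.decode(outputs[0], skip_special_tokens=True))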
2 changes: 2 additions & 0 deletions docs/source/en/perf_infer_gpu_one.md
@@ -77,6 +77,7 @@ FlashAttention-2 is currently supported for the following architectures:
* [Nemotron](https://huggingface.co/docs/transformers/model_doc/nemotron)
* [NLLB](https://huggingface.co/docs/transformers/model_doc/nllb)
* [OLMo](https://huggingface.co/docs/transformers/model_doc/olmo#transformers.OlmoModel)
* [OLMo November 2024](https://huggingface.co/docs/transformers/model_doc/olmo_1124#transformers.Olmo1124Model)
* [OLMoE](https://huggingface.co/docs/transformers/model_doc/olmoe#transformers.OlmoeModel)
* [OPT](https://huggingface.co/docs/transformers/model_doc/opt#transformers.OPTModel)
* [PaliGemma](https://huggingface.co/docs/transformers/model_doc/paligemma#transformers.PaliGemmaForConditionalGeneration)
@@ -260,6 +261,7 @@ For now, Transformers supports SDPA inference and training for the following architectures:
* [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel)
* [NLLB](https://huggingface.co/docs/transformers/model_doc/nllb)
* [OLMo](https://huggingface.co/docs/transformers/model_doc/olmo#transformers.OlmoModel)
* [OLMo November 2024](https://huggingface.co/docs/transformers/model_doc/olmo_1124#transformers.Olmo1124Model)
* [OLMoE](https://huggingface.co/docs/transformers/model_doc/olmoe#transformers.OlmoeModel)
* [OPT](https://huggingface.co/docs/transformers/en/model_doc/opt)
* [PaliGemma](https://huggingface.co/docs/transformers/model_doc/paligemma#transformers.PaliGemmaForConditionalGeneration)
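
Both backends are selected through the standard attn_implementation argument at load time. A hedged sketch (placeholder checkpoint id; "flash_attention_2" additionally requires the flash-attn package and a supported GPU):

import torch
from transformers import Olmo1124ForCausalLM

model = Olmo1124ForCausalLM.from_pretrained(
    "allenai/OLMo-1124-example",  # placeholder repo id
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # or "sdpa"
)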
14 changes: 14 additions & 0 deletions src/transformers/__init__.py
@@ -620,6 +620,7 @@
"models.nougat": ["NougatProcessor"],
"models.nystromformer": ["NystromformerConfig"],
"models.olmo": ["OlmoConfig"],
"models.olmo_1124": ["Olmo1124Config"],
"models.olmoe": ["OlmoeConfig"],
"models.omdet_turbo": [
"OmDetTurboConfig",
@@ -2919,6 +2920,13 @@
"OlmoPreTrainedModel",
]
)
_import_structure["models.olmo_1124"].extend(
[
"Olmo1124ForCausalLM",
"Olmo1124Model",
"Olmo1124PreTrainedModel",
]
)
_import_structure["models.olmoe"].extend(
[
"OlmoeForCausalLM",
@@ -5506,6 +5514,7 @@
    NystromformerConfig,
)
from .models.olmo import OlmoConfig
from .models.olmo_1124 import Olmo1124Config
from .models.olmoe import OlmoeConfig
from .models.omdet_turbo import (
    OmDetTurboConfig,
@@ -7523,6 +7532,11 @@
    OlmoModel,
    OlmoPreTrainedModel,
)
from .models.olmo_1124 import (
    Olmo1124ForCausalLM,
    Olmo1124Model,
    Olmo1124PreTrainedModel,
)
from .models.olmoe import (
    OlmoeForCausalLM,
    OlmoeModel,
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
@@ -177,6 +177,7 @@
    nougat,
    nystromformer,
    olmo,
    olmo_1124,
    olmoe,
    omdet_turbo,
    oneformer,
2 changes: 2 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
@@ -195,6 +195,7 @@
("nougat", "VisionEncoderDecoderConfig"),
("nystromformer", "NystromformerConfig"),
("olmo", "OlmoConfig"),
("olmo_1124", "Olmo1124Config"),
("olmoe", "OlmoeConfig"),
("omdet-turbo", "OmDetTurboConfig"),
("oneformer", "OneFormerConfig"),
@@ -510,6 +511,7 @@
("nougat", "Nougat"),
("nystromformer", "Nyströmformer"),
("olmo", "OLMo"),
("olmo_1124", "OLMo November 2024"),
("olmoe", "OLMoE"),
("omdet-turbo", "OmDet-Turbo"),
("oneformer", "OneFormer"),
2 changes: 2 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -184,6 +184,7 @@
("nllb-moe", "NllbMoeModel"),
("nystromformer", "NystromformerModel"),
("olmo", "OlmoModel"),
("olmo_1124", "Olmo1124Model"),
("olmoe", "OlmoeModel"),
("omdet-turbo", "OmDetTurboForObjectDetection"),
("oneformer", "OneFormerModel"),
@@ -516,6 +517,7 @@
("mvp", "MvpForCausalLM"),
("nemotron", "NemotronForCausalLM"),
("olmo", "OlmoForCausalLM"),
("olmo_1124", "Olmo1124ForCausalLM"),
("olmoe", "OlmoeForCausalLM"),
("open-llama", "OpenLlamaForCausalLM"),
("openai-gpt", "OpenAIGPTLMHeadModel"),
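
With the registrations above (and the matching entry in configuration_auto.py), the "olmo_1124" model type resolves through the auto classes. A small sketch of the effect, building a randomly initialized model from default config values:

from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.for_model("olmo_1124")  # resolves to Olmo1124Config
model = AutoModelForCausalLM.from_config(config)  # resolves to Olmo1124ForCausalLM
print(type(model).__name__)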
1 change: 1 addition & 0 deletions src/transformers/models/auto/tokenization_auto.py
@@ -348,6 +348,7 @@
            ),
        ),
        ("olmo", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
        ("olmo_1124", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
        ("olmoe", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
        (
            "omdet-turbo",
110 changes: 110 additions & 0 deletions src/transformers/models/olmo/modeling_olmo.py
@@ -26,6 +26,8 @@
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss


from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
@@ -34,6 +36,7 @@
from ...modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    SequenceClassifierOutputWithPast,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
@@ -1140,3 +1143,110 @@ def forward(
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class OlmoForSequenceClassification(OlmoPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = OlmoModel(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
                sequence_lengths = sequence_lengths % input_ids.shape[-1]
                sequence_lengths = sequence_lengths.to(logits.device)
            else:
                sequence_lengths = -1

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)
        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
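
The new head uses the standard decoder-only pooling: sequence_lengths locates the last non-padding position of each row and score projects that hidden state to num_labels logits. A usage sketch under stated assumptions (placeholder checkpoint id, and the class imported directly since this PR does not add it to the auto mappings):

import torch
from transformers import AutoTokenizer
from transformers.models.olmo.modeling_olmo import OlmoForSequenceClassification

checkpoint = "allenai/OLMo-1B-hf"  # placeholder; the score head is newly initialized
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # common fallback for decoder-only models

model = OlmoForSequenceClassification.from_pretrained(
    checkpoint, num_labels=2, pad_token_id=tokenizer.pad_token_id
)

inputs = tokenizer(["Great movie.", "Awful."], padding=True, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits  # (batch_size, num_labels), pooled at the last non-pad token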
57 changes: 57 additions & 0 deletions src/transformers/models/olmo_1124/__init__.py
@@ -0,0 +1,57 @@
# Copyright 2024 EleutherAI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import (
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_torch_available,
)


_import_structure = {
    "configuration_olmo_1124": ["Olmo1124Config"],
}

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_olmo_1124"] = [
        "Olmo1124ForCausalLM",
        "Olmo1124Model",
        "Olmo1124PreTrainedModel",
    ]

if TYPE_CHECKING:
    from .configuration_olmo_1124 import Olmo1124Config

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_olmo_1124 import (
            Olmo1124ForCausalLM,
            Olmo1124Model,
            Olmo1124PreTrainedModel,
        )

else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
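
The _LazyModule pattern above keeps `import transformers` cheap: modeling_olmo_1124 (and therefore torch) is only imported when one of its exported names is first accessed. A sketch of the user-visible behavior, assuming torch is installed:

from transformers import Olmo1124Config  # triggers the lazy import of configuration_olmo_1124

config = Olmo1124Config()  # built from the defaults defined in configuration_olmo_1124.py
print(config.model_type)  # "olmo_1124"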