Skip to content

Commit

Permalink
initial-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
ArthurZucker committed Dec 16, 2023
1 parent 238d2e3 commit 81c642f
Show file tree
Hide file tree
Showing 14 changed files with 1,840 additions and 1 deletion.
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,8 @@
title: M2M100
- local: model_doc/madlad-400
title: MADLAD-400
- local: model_doc/mamba
title: Mamba
- local: model_doc/marian
title: MarianMT
- local: model_doc/markuplm
Expand Down
118 changes: 118 additions & 0 deletions docs/source/en/model_doc/mamba.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

# Mamba

# Mamba

# Mamba

## Overview

The Mamba model was proposed in [<INSERT PAPER NAME HERE>](<INSERT PAPER LINK HERE>) by <INSERT AUTHORS HERE>.
<INSERT SHORT SUMMARY HERE>

The abstract from the paper is the following:

*<INSERT PAPER ABSTRACT HERE>*

Tips:

<INSERT TIPS ABOUT MODEL HERE>

This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/<INSERT YOUR HF USERNAME HERE>).
The original code can be found [here](<INSERT LINK TO GITHUB REPO HERE>).


## MambaConfig

[[autodoc]] MambaConfig

## MambaModel

[[autodoc]] MambaModel
- forward

## MambaLMHeadModel

[[autodoc]] MambaForCausalLM
- forward

## Mamba attention and the recurrent formulas

In a traditional auto-regressive Transformer, attention is written as

$$O = \hbox{softmax}(QK^{T} / \sqrt{d}) V$$

with \\(Q\\), \\(K\\) and \\(V\\) are matrices of shape `seq_len x hidden_size` named query, key and value (they are actually bigger matrices with a batch dimension and an attention head dimension but we're only interested in the last two, which is where the matrix product is taken, so for the sake of simplicity we only consider those two). The product \\(QK^{T}\\) then has shape `seq_len x seq_len` and we can take the maxtrix product with \\(V\\) to get the output \\(O\\) of the same shape as the others.

Replacing the softmax by its value gives:

$$O_{i} = \frac{\sum_{j=1}^{i} e^{Q_{i} K_{j}^{T} / \sqrt{d}} V_{j}}{\sum_{j=1}^{i} e^{Q_{i} K_{j}^{T} / \sqrt{d}}}$$

Note that the entries in \\(QK^{T}\\) corresponding to \\(j > i\\) are masked (the sum stops at j) because the attention is not allowed to look at future tokens (only past ones).

In comparison, the MAMBA attention is given by

$$O_{i} = \sigma(R_{i}) \frac{\sum_{j=1}^{i} e^{W_{i-j} + K_{j}} V_{j}}{\sum_{j=1}^{i} e^{W_{i-j} + K_{j}}}$$

where \\(R\\) is a new matrix called receptance by the author, \\(K\\) and \\(V\\) are still the key and value (\\(\sigma\\) here is the sigmoid function). \\(W\\) is a new vector that represents the position of the token and is given by

$$W_{0} = u \hbox{ and } W_{k} = (k-1)w \hbox{ for } k \geq 1$$

with \\(u\\) and \\(w\\) learnable parameters called in the code `time_first` and `time_decay` respectively. The numerator and denominator can both be expressed recursively. Naming them \\(N_{i}\\) and \\(D_{i}\\) we have:

$$N_{i} = e^{u + K_{i}} V_{i} + \hat{N}_{i} \hbox{ where } \hat{N}_{i} = e^{K_{i-1}} V_{i-1} + e^{w + K_{i-2}} V_{i-2} \cdots + e^{(i-2)w + K_{1}} V_{1}$$

so \\(\hat{N}_{i}\\) (called `numerator_state` in the code) satistfies

$$\hat{N}_{0} = 0 \hbox{ and } \hat{N}_{j+1} = e^{K_{j}} V_{j} + e^{w} \hat{N}_{j}$$

and

$$D_{i} = e^{u + K_{i}} + \hat{D}_{i} \hbox{ where } \hat{D}_{i} = e^{K_{i-1}} + e^{w + K_{i-2}} \cdots + e^{(i-2)w + K_{1}}$$

so \\(\hat{D}_{i}\\) (called `denominator_state` in the code) satistfies

$$\hat{D}_{0} = 0 \hbox{ and } \hat{D}_{j+1} = e^{K_{j}} + e^{w} \hat{D}_{j}$$

The actual recurrent formula used are a tiny bit more complex, as for numerical stability we don't want to compute exponentials of big numbers. Usually the softmax is not computed as is, but the exponential of the maximum term is divided of the numerator and denominator:

$$\frac{e^{x_{i}}}{\sum_{j=1}^{n} e^{x_{j}}} = \frac{e^{x_{i} - M}}{\sum_{j=1}^{n} e^{x_{j} - M}}$$

with \\(M\\) the maximum of all \\(x_{j}\\). So here on top of saving the numerator state (\\(\hat{N}\\)) and the denominator state (\\(\hat{D}\\)) we also keep track of the maximum of all terms encountered in the exponentials. So we actually use

$$\tilde{N}_{i} = e^{-M_{i}} \hat{N}_{i} \hbox{ and } \tilde{D}_{i} = e^{-M_{i}} \hat{D}_{i}$$

defined by the following recurrent formulas:

$$\tilde{N}_{0} = 0 \hbox{ and } \tilde{N}_{j+1} = e^{K_{j} - q} V_{j} + e^{w + M_{j} - q} \tilde{N}_{j} \hbox{ where } q = \max(K_{j}, w + M_{j})$$

and

$$\tilde{D}_{0} = 0 \hbox{ and } \tilde{D}_{j+1} = e^{K_{j} - q} + e^{w + M_{j} - q} \tilde{D}_{j} \hbox{ where } q = \max(K_{j}, w + M_{j})$$

and \\(M_{j+1} = q\\). With those, we can then compute

$$N_{i} = e^{u + K_{i} - q} V_{i} + e^{M_{i}} \tilde{N}_{i} \hbox{ where } q = \max(u + K_{i}, M_{i})$$

and

$$D_{i} = e^{u + K_{i} - q} + e^{M_{i}} \tilde{D}_{i} \hbox{ where } q = \max(u + K_{i}, M_{i})$$

which finally gives us

$$O_{i} = \sigma(R_{i}) \frac{N_{i}}{D_{i}}$$
16 changes: 16 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,6 +735,7 @@
"RoFormerTokenizer",
],
"models.rwkv": ["RWKV_PRETRAINED_CONFIG_ARCHIVE_MAP", "RwkvConfig"],
"models.mamba": ["MAMBA_PRETRAINED_CONFIG_ARCHIVE_MAP", "MambaConfig"],
"models.sam": [
"SAM_PRETRAINED_CONFIG_ARCHIVE_MAP",
"SamConfig",
Expand Down Expand Up @@ -3078,6 +3079,14 @@
"RwkvPreTrainedModel",
]
)
_import_structure["models.mamba"].extend(
[
"MAMBA_PRETRAINED_MODEL_ARCHIVE_LIST",
"MambaForCausalLM",
"MambaModel",
"MambaPreTrainedModel",
]
)
_import_structure["models.sam"].extend(
[
"SAM_PRETRAINED_MODEL_ARCHIVE_LIST",
Expand Down Expand Up @@ -5396,6 +5405,7 @@
RoFormerTokenizer,
)
from .models.rwkv import RWKV_PRETRAINED_CONFIG_ARCHIVE_MAP, RwkvConfig
from .models.mamba import MAMBA_PRETRAINED_CONFIG_ARCHIVE_MAP, MambaConfig
from .models.sam import (
SAM_PRETRAINED_CONFIG_ARCHIVE_MAP,
SamConfig,
Expand Down Expand Up @@ -7428,6 +7438,12 @@
RwkvModel,
RwkvPreTrainedModel,
)
from .models.mamba import (
MAMBA_PRETRAINED_MODEL_ARCHIVE_LIST,
MambaForCausalLM,
MambaModel,
MambaPreTrainedModel,
)
from .models.sam import (
SAM_PRETRAINED_MODEL_ARCHIVE_LIST,
SamModel,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@
roc_bert,
roformer,
rwkv,
mamba,
sam,
seamless_m4t,
seamless_m4t_v2,
Expand Down
3 changes: 3 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@
("roc_bert", "RoCBertConfig"),
("roformer", "RoFormerConfig"),
("rwkv", "RwkvConfig"),
("mamba", "MambaConfig"),
("sam", "SamConfig"),
("seamless_m4t", "SeamlessM4TConfig"),
("seamless_m4t_v2", "SeamlessM4Tv2Config"),
Expand Down Expand Up @@ -411,6 +412,7 @@
("roc_bert", "ROC_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("roformer", "ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("rwkv", "RWKV_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("mamba", "MAMBA_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("sam", "SAM_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("seamless_m4t", "SEAMLESS_M4T_PRETRAINED_CONFIG_ARCHIVE_MAP"),
("seamless_m4t_v2", "SEAMLESS_M4T_V2_PRETRAINED_CONFIG_ARCHIVE_MAP"),
Expand Down Expand Up @@ -655,6 +657,7 @@
("roc_bert", "RoCBert"),
("roformer", "RoFormer"),
("rwkv", "RWKV"),
("mamba", "Mamba"),
("sam", "SAM"),
("seamless_m4t", "SeamlessM4T"),
("seamless_m4t_v2", "SeamlessM4Tv2"),
Expand Down
4 changes: 4 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@
("roc_bert", "RoCBertModel"),
("roformer", "RoFormerModel"),
("rwkv", "RwkvModel"),
("mamba", "MambaModel"),
("sam", "SamModel"),
("seamless_m4t", "SeamlessM4TModel"),
("seamless_m4t_v2", "SeamlessM4Tv2Model"),
Expand Down Expand Up @@ -291,6 +292,7 @@
("roberta-prelayernorm", "RobertaPreLayerNormForMaskedLM"),
("roc_bert", "RoCBertForPreTraining"),
("rwkv", "RwkvForCausalLM"),
("mamba", "MambaForCausalLM"),
("splinter", "SplinterForPreTraining"),
("squeezebert", "SqueezeBertForMaskedLM"),
("switch_transformers", "SwitchTransformersForConditionalGeneration"),
Expand Down Expand Up @@ -380,6 +382,7 @@
("roc_bert", "RoCBertForMaskedLM"),
("roformer", "RoFormerForMaskedLM"),
("rwkv", "RwkvForCausalLM"),
("mamba", "MambaForCausalLM"),
("speech_to_text", "Speech2TextForConditionalGeneration"),
("squeezebert", "SqueezeBertForMaskedLM"),
("switch_transformers", "SwitchTransformersForConditionalGeneration"),
Expand Down Expand Up @@ -453,6 +456,7 @@
("roc_bert", "RoCBertForCausalLM"),
("roformer", "RoFormerForCausalLM"),
("rwkv", "RwkvForCausalLM"),
("mamba", "MambaForCausalLM"),
("speech_to_text_2", "Speech2Text2ForCausalLM"),
("transfo-xl", "TransfoXLLMHeadModel"),
("trocr", "TrOCRForCausalLM"),
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/auto/tokenization_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,7 @@
("roc_bert", ("RoCBertTokenizer", None)),
("roformer", ("RoFormerTokenizer", "RoFormerTokenizerFast" if is_tokenizers_available() else None)),
("rwkv", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
("mamba", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)),
(
"seamless_m4t",
(
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/flava/modeling_flava.py
Original file line number Diff line number Diff line change
Expand Up @@ -1949,7 +1949,7 @@ def forward(

if mim_labels is not None:
mim_labels = mim_labels[pos_mask]

# MMM Image Loss
if multimodal_masked_embeddings is not None and self.mmm_image_weight > 0:
sequence_for_image = multimodal_masked_embeddings
Expand Down
60 changes: 60 additions & 0 deletions src/transformers/models/mamba/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_torch_available,
)


_import_structure = {
"configuration_mamba": ["MAMBA_PRETRAINED_CONFIG_ARCHIVE_MAP", "MambaConfig", "MambaOnnxConfig"],
}

try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_mamba"] = [
"MAMBA_PRETRAINED_MODEL_ARCHIVE_LIST",
"MambaForCausalLM",
"MambaModel",
"MambaPreTrainedModel",
]


if TYPE_CHECKING:
from .configuration_mamba import MAMBA_PRETRAINED_CONFIG_ARCHIVE_MAP, MambaConfig, MambaOnnxConfig

try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_mamba import (
MAMBA_PRETRAINED_MODEL_ARCHIVE_LIST,
MambaForCausalLM,
MambaModel,
MambaPreTrainedModel,
)
else:
import sys

sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
Loading

0 comments on commit 81c642f

Please sign in to comment.